QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2713|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    0 @+ ?* r% [( v  R6 [0 x/ ]1 S4 C/ }+ q  w, T+ v: ]
    文章目录6 v0 p* q5 f8 t: y2 i9 q! f4 ]2 q
    卷积网络实战 对花进行分类
    - {& e3 r  `- Y+ z" g- M数据预处理部分
    . t2 R( Y" v; O/ ?0 \网络模块设置4 o0 M, o0 ~& D' d
    网络模型的保存与测试
    0 u1 r; T2 ~# `$ [+ Y2 D  M数据下载:7 w& z* o- h( @& u. @
    1. 导入工具包# t7 D( ?0 g# a. N+ p  ^
    2. 数据预处理与操作
    2 d9 V! D) y" P9 l. M3. 制作好数据源
    : a2 T( G* `: [& @读取标签对应的实际名字5 Q# n8 S1 i4 V7 ?, q; f( l* [/ z5 r
    4.展示一下数据; R. k' R0 T. o5 e1 T, {
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数$ v9 {0 Q: s* K" e1 i
    6.初始化模型架构
    7 M% s) Y! H* S. U/ O& I4 n0 p7. 设置需要训练的参数0 l6 w9 Z& V1 t  p4 Y) E5 l( E+ `
    7. 训练与预测
    % I: ?' W- j" E1 m! h* e) w. t7.1 优化器设置$ @3 p: M4 M* z: N; z
    7.2 开始训练模型, m/ X) c/ C" u" K2 y" ?' F, R& o0 ]
    7.3 训练所有层
      Q0 f7 [- N' I/ l* B开始训练
    ) \6 x. ]8 m3 }8. 加载已经训练的模型5 ]: ~+ u' p* ]- `4 V. t5 x" l$ [
    9. 推理
    * y' \  p* c9 F) o9.1 计算得到最大概率& U4 f% O4 L# Y5 C- T
    9.2 展示预测结果: |5 A" C" x' @8 R9 g
    写在最后% `8 x/ U6 X9 R6 T. @  L$ l5 r
    卷积网络实战 对花进行分类0 a$ Z% r. ~7 l$ g, V- F
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    % U4 y, B6 A% r, v' O* @/ h5 u; c: D* @) z& j
    在文件夹中有102种花,我们主要要对这些花进行分类任务# U0 F5 \( W& q9 Q0 I# p/ @; s
    文件夹结构
    # a) d" Z3 Y1 |/ I
    , B/ V. a0 g. g# A  Zflower_data! h. l0 C! g/ G1 Y+ L

    . X8 Q" e9 x" r$ x, C3 wtrain
    % z1 c. z3 j+ M% K# y
    . n& ?$ k: V) f1(类别)
    / p$ g$ `5 l2 v+ h+ `9 a2$ H: C% ~" D( Z0 b
    xxx.png / xxx.jpg' k" ]7 }8 M4 l% W
    valid
    7 ?/ w: q7 B( P; x
    + o: K; u% ]  C, N! b& ~8 i主要分为以下几个大模块8 ]' e' G- l) \. ^" ?+ T' U7 X/ j

    ! [9 ^4 N8 \( [& F7 n& c) s数据预处理部分
    4 \9 \4 k6 b6 k* J( P1 I7 u) m5 b数据增强' j; d- T. A, i. |9 f7 t9 h( q/ ~1 e
    数据预处理) V8 J( x/ s. a: I) ~
    网络模块设置2 K$ W6 J. `# {1 V" H3 C, E8 E
    加载预训练模型,直接调用torchVision的经典网络架构3 q1 x7 N4 K* C* k, ~
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务: _% O, i# ^& Z
    网络模型的保存与测试
    8 m2 g0 j  O5 n+ W, ]7 O. Z模型保存可以带有选择性
    ; M8 @7 ]( V. d8 E数据下载:" p+ |6 j, h2 `; G  L; N/ [
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    ! A( p) L; ^( F6 \4 v. ~# B2 A5 `; v( o  w+ q. V$ S9 C# f- n5 U
    改一下文件名,然后将它放到同一根目录就可以了
    2 j/ P! C: V* K/ _
    ) l+ _! d0 z2 v( ^+ k下面是我的数据根目录
    3 `' L3 C0 ?  f9 Z% Z1 y9 m8 k7 N7 r% @% T5 X5 B# ]9 l
    " `! V0 G5 f  ^) g" `" S
    1. 导入工具包
    : c- p9 p$ W3 |! i" Eimport os
    3 q" \2 H. s' h$ b5 g# G. p5 |( p5 jimport matplotlib.pyplot as plt
    2 X* G# V8 O$ T* V* n" o# 内嵌入绘图简去show的句柄
    ( H6 t4 M6 R0 _# L7 l% ~%matplotlib inline
    . Y5 s3 F- C6 Z5 K9 pimport numpy as np
    * ]1 d& F( [0 W% H" g: aimport torch
    2 O5 z+ v3 o$ Mfrom torch import nn& f8 L3 e$ i9 r( R$ Z% l: o1 n7 r7 j0 b

    3 l" q8 Z* f( e' }% o1 `import torch.optim as optim
    % _& `8 U0 n  m  qimport torchvision: f8 G! Z5 f; M
    from torchvision import transforms, models, datasets
    8 v! \# M. W) M# L5 c/ [! V7 g" u2 ~! s3 h
    ( L8 ^) ?2 p% o& ], ]  aimport imageio
      n1 k1 v# F0 h7 \import time5 z1 K* [; F" X& z
    import warnings4 {  c: R. s% V1 x
    import random8 j" ~1 y! A/ a: X" e) c
    import sys$ P, b; M& D, w; K: z& q% p
    import copy% ]+ `1 h) ]: P7 L
    import json
    5 |- W' `& b6 u- o' U- g) zfrom PIL import Image
    4 Q8 k" X+ C( ]' g9 L  T5 ]! C" \: w

    & H6 j) P" l. [. L0 h3 q1
    ' @. S! j2 [! f& q% s7 T  K( A3 C" Q2! P! G* Y% ]  i0 G
    3
    , ?3 h) K# l- c( v# x4+ d7 w, V; F  B+ W- C: U+ G
    5/ o0 I3 O4 u: r/ F6 R
    6  W' \' ]9 k: A4 a! @' h
    7! x1 c, m2 ^0 Z4 g
    8$ A* A% e# A4 p5 H, ~
    9
    & L. I+ N5 ]# w# S10
    - n* o7 ^4 ]' U, Q5 g11
    ) k3 @$ w$ I+ |12: z9 o- x; C" W8 G% A& a* z% b
    13
    0 [$ ~! _4 ?  E2 `14$ A% ]; E) V# t% y, I- ^: V
    15
    7 _7 Z4 o+ S, A2 q+ M& r16
    4 F, O$ u: L. {" g17
    1 v  u. @: c/ _/ Z7 _7 O  [18
    4 b2 ~& F! Y" a* f+ p& l19. D8 a+ C" e/ W3 M# u
    20
    " h; \9 m9 W9 g2 v21
    , c  \* F' j  m# L: T2. 数据预处理与操作
    5 v, W. R3 p) W# k#路径设置/ ~2 E8 ^, H0 F) `" X" ^; Z
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录
      l6 O2 Z( F4 L$ z; ~& u  Q2 f1 Ttrain_dir = data_dir + '/train'% K0 o  l7 S4 N1 o
    valid_dir = data_dir + '/valid'
    ) T2 U9 O7 @. p* \5 ^1
    & x7 M2 j+ L; S) V  B3 `" \" _) `2! i. I. b' T4 G7 D6 ~
    39 w# @; \7 a* d8 ]$ f- u1 }- o- F
    4& v) |- o1 Z. n1 G4 M6 ^# r7 I
    python目录点杠的组合与区别' i+ }& J, i7 F  v/ T' C
    注: 里面注明了点杠和斜杠的操作
    , D# c6 s8 e  r/ Y; {' q: B/ s+ f6 j! `0 h+ n
    3. 制作好数据源; a) K( P8 A+ W1 b; p9 r& R9 j1 j2 k% l
    data_transforms中制定了所有图像预处理的操作2 D- ?( q" {% S
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片$ r+ h' z( B, }6 V  v6 Z7 ]. w
    data_transforms = {1 Y8 A% Z6 G5 L7 j! U# F1 \% \, V
        # 分成两部分,一部分是训练
    9 R$ r6 ^/ x- a+ J    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    7 X0 c( j( ~0 G0 L" h: p                                 transforms.CenterCrop(224), # 从中心处开始裁剪4 C7 O$ ~0 k* w5 |
                                     # 以某个随机的概率决定是否翻转 55开' V9 X( V9 |' x0 E* d* J, g
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    7 Q% }4 i) p4 y6 H4 L                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    6 P# C5 {+ O" k1 Y+ W, W- ?1 @                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相! t) p* p. C- g, Y
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),5 e$ g9 F) t/ f, m; B* W3 [/ f/ @- q. x0 u
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB+ ]3 Q5 M2 {% b. w3 X" s/ P
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    # W. M( }$ i! y+ K( Y                                 transforms.ToTensor(),
      v' o! h0 [! C4 r, J( G. S                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差- m8 X+ W8 v* L/ O& g: m* u0 O
                                    ]),
    3 u0 {* e6 U* I+ G! P* I    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化, O5 t0 `7 y; l8 m! e9 o
        'valid': transforms.Compose([transforms.Resize(256),7 x$ N% q6 W- z5 L  k# l% Q
                                     transforms.CenterCrop(224),5 X0 N' B+ D! H3 O
                                     transforms.ToTensor(),
    2 H) }$ G1 S# O                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同& n+ m8 `% C7 m$ w2 A
                                    ]),
    7 r1 z* P6 T- G) l  }8 Y" y}
    * _. q0 [6 p0 I! {
    3 s# x/ l5 x/ E, r% k  C$ P9 Q1
    5 v$ o9 h. h' X$ t. e2/ V+ U) o0 _  X( _9 B7 ~1 M
    3+ E# A$ C- l7 Q1 R4 G6 v. {
    48 @! Z+ L) z' t
    5
    8 K. c0 G; J' R- s/ n) i5 x* x6# p' K$ L9 e" g& x6 @! J# }
    7# x! D* s2 T9 {+ C
    8  e2 Z7 C" G  b0 u2 H( T
    96 y" U& @# W' e& w4 K, w
    107 f3 S7 G( |4 T$ y9 T6 `
    11& T  }$ ~: |" {* R4 T1 {
    120 l$ d* T; Z/ d/ m8 f8 J
    13
    ' y1 k  R% Q2 C5 }7 `  ^140 @/ C. g; U8 T% G8 U! x
    15" E3 [  J8 B9 H. _% D. K
    168 i- u) ^: ~1 m! |- w( \
    17
    4 o5 H& Z9 s/ u5 h( d7 T* m4 W18, @( S1 g  N+ V" H% A- s, T
    19
    ; r3 h( d6 r" N20" [9 }+ i/ d. G2 [6 N
    21: ^" L: U& W2 Q& J: V$ E
    batch_size = 84 D" B1 U" b  I6 T0 G! v
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    . S: `0 v8 Q. X, N/ _) fdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    9 \! N& {% m, tdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} 3 X3 w( p0 t9 t! @
    class_names = image_datasets['train'].classes
    . e8 H* S# E1 ~4 U$ T: \, E: {
    0 Y3 O% {* U+ C: Y3 T  M#查看数据集合
    + A3 H5 B1 a6 s+ v- Q# {image_datasets) u) u" n$ H0 w0 F. w, @

    . G; i# O6 K) @  u' s1
    & z9 }& ?3 S/ C3 i7 C" u. B2
    ( d0 J: t7 x3 ]3
    2 L; F5 f0 j3 ]( {4; G1 P9 n9 y! ?- X
    5  i. B5 `1 z) M9 |* ?9 k7 q9 e3 F! g* z2 u1 `
    6
    5 k' m0 [) o7 p* R74 n8 G8 z4 n; R2 @
    8
    8 ?3 W& ?2 i! J8 L9
    3 q/ j8 M* E4 D. Z5 k' _" m{'train': Dataset ImageFolder
    % U1 y# W. Q, m     Number of datapoints: 6552" x# w, y4 c& V, C4 l9 a
         Root location: ./flower_data/train# L/ t2 p! S: K5 X2 o
         StandardTransform
    # J7 @; q: a* w  e% N7 R) |# Y$ I Transform: Compose($ c5 ^3 |3 W. ]2 l2 U  @: b$ ], b- P" C
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
    2 y- x- z; z7 ]  c: Z                CenterCrop(size=(224, 224))
    , n$ _# Y, @7 O/ l/ S8 k                RandomHorizontalFlip(p=0.5)" M# s( ?. s$ F1 n
                    RandomVerticalFlip(p=0.5)+ n" I. h. R# K# U% P
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])& m" t5 a: P/ Y5 R
                    RandomGrayscale(p=0.025)
    ' i; c( n+ _$ v8 v# F. G                ToTensor()
    1 d' n( m, {$ H; D* i0 v$ l# Q                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])6 }# q  _( v1 W# ]% N
                ),' `1 o3 l# ]3 n. h0 O3 E
    'valid': Dataset ImageFolder
    ; r4 t8 x( H( O4 `$ q2 `     Number of datapoints: 8188 F- Z5 d* f' K% Y1 e0 N# N
         Root location: ./flower_data/valid+ j# k8 v1 K  A, z$ Q% A  U6 s
         StandardTransform9 J* ]" P' p1 u
    Transform: Compose(! x5 k/ T& N0 U% S% \
                    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    * `6 v* d( V9 d- F% s: n                CenterCrop(size=(224, 224))/ e2 r! ^" U9 j# g
                    ToTensor()
    * }1 D! ?% R/ c* e5 ?, j                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) R0 _5 ~1 P$ M3 p  b; C: H
                )}
    . B6 P+ {7 z) _9 n+ x, e
    ) }% j: t4 N9 u' G1
    % J7 L% Y; ^9 [- p" ^2% N- V* `8 R: _( n6 ^2 s) t
    3% u- m1 b, w( x1 J
    4
    6 N. _/ `. D- q$ c+ r5
    ; V" V: i; a* t' l6
    ! c7 f% E2 u# u5 g; K. m8 L7' f) [1 }- L- w0 Q2 N
    83 X. W6 p: \& i# P
    93 {# d+ s( j* o, X) X
    10
    # F+ Z4 k5 D1 k# L2 l- Q9 R4 o# K$ H11
    4 g( c! N4 t5 A( b, _  n6 j+ I12
    - j" B- S  C" y/ N2 {, g1 t13
    1 f  C' ]0 w2 O& N; R+ O14
      D# Z8 h  D$ p4 ]* h15
    - a- B9 V  H/ e! Z# p16
    * t/ w7 z' g2 g, N4 u9 A2 F. ^: _& W17$ u2 Q% z' V6 i( u$ D$ U
    18
    ! e/ b+ o. w% \/ G; u: I19
    6 P* e3 P( n+ f20; C/ K: h. h- s
    21
    % k3 K; x' n* r* M# ~: X) p22
    / E( e5 U4 n! i+ T; [: y23
    5 B  B0 S9 n2 d! Z24
    2 e& A* [' ?  V: Y# 验证一下数据是否已经被处理完毕! z( Z% J! C& J& |$ M: i  ]
    dataloaders
    ! J9 \0 U7 e- y1
    0 w& @1 d; ^& x4 T- q1 `4 t4 b2+ O& j9 r* C4 s6 u- G2 E. h
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
    ' S" ^; H! }, I; D- ^4 z9 Y 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}. D- Z1 \* |1 L5 i3 k/ _" y0 W' w
    12 f( ]6 [' Q7 [6 }8 B
    2
    ( n2 |- T6 c; b# d+ v/ ?dataset_sizes5 ]: ~, L- [/ s8 A5 `
    1
    " P4 _( }% Q% r# y{'train': 6552, 'valid': 818}
    + y3 U/ s2 l% w6 {1+ b# v' |# \3 X3 {- O; @
    读取标签对应的实际名字$ ?4 C0 Q# Y9 D0 t$ F
    使用同一目录下的json文件,反向映射出花对应的名字6 y7 ?& [8 f* K, Q. s: W

    8 H1 P2 A4 X. pwith open('./flower_data/cat_to_name.json', 'r') as f:
    $ w7 E  Q! J+ a6 y$ B    cat_to_name = json.load(f)% U1 U) l* U. H! n
    1
    ' V3 ~' b, I5 R2
    6 ~$ {$ g; z2 M! Bcat_to_name( F1 w4 w9 t9 ^* I- L$ X, r
    1' k* O8 B: \  X0 i
    {'21': 'fire lily',9 p! a9 c3 f0 F  |/ x/ x9 i- u. K, b. e3 Z
    '3': 'canterbury bells',  ]& j' [2 s- P1 h. I
    '45': 'bolero deep blue',1 l6 B2 M" x* g7 T: L9 d  g, W
    '1': 'pink primrose',
    7 ?" p; O. @7 ?  g/ | '34': 'mexican aster',
    * o5 |, w: r# y  l; b" v3 R '27': 'prince of wales feathers',& L1 g5 k* v3 j) f: a% k
    '7': 'moon orchid',
    ! X2 r$ w  U2 s8 f( E* h; o '16': 'globe-flower',
    + E! Q2 |' k; O4 G6 ]7 R '25': 'grape hyacinth',7 U& r6 S; g' P9 }
    '26': 'corn poppy',
    " X$ x( z! i; e: m! U  X '79': 'toad lily',
    6 a- f: H  w, Z '39': 'siam tulip'," Z2 Z2 A. P. J4 U  O7 F; k7 `
    '24': 'red ginger',
    / b1 z0 ~  r7 e; U& p+ v4 }0 {/ w '67': 'spring crocus',
    9 M, J$ @+ C& Z1 C0 E '35': 'alpine sea holly',2 k+ `  r2 L' u
    '32': 'garden phlox',
    ) P% O8 d# K% _; O% {, n '10': 'globe thistle',
    6 w+ S: R' A. {' k  c '6': 'tiger lily',
    $ ~: |" ?/ Z7 D5 ~/ Z '93': 'ball moss',
    + ^7 _( D$ O2 `& r '33': 'love in the mist',
    ( G8 T$ B& V: q, ?& z- p '9': 'monkshood',
    . {; C% w5 X6 E '102': 'blackberry lily',/ z7 s. q, ]4 v) F
    '14': 'spear thistle',. c# r) i- f, e/ ~* v
    '19': 'balloon flower',
    * r) @% Y" u. u '100': 'blanket flower',
    % j4 U6 o; l; Z# Q1 M '13': 'king protea',
      a. i' \( z5 g; \ '49': 'oxeye daisy',* W  f: U  c; c7 p, ?0 g6 ]
    '15': 'yellow iris',: \4 I  W; b# x8 ]- Z! C  F: {
    '61': 'cautleya spicata',
    / q: I  b2 |/ |, z& L" E" f* Y+ t '31': 'carnation',
      e/ v: g# r3 y- D$ u( V* D3 H '64': 'silverbush',
    1 G+ T. d/ }: }& B+ E& S& o% | '68': 'bearded iris',5 u$ G: U! O3 m3 n+ a
    '63': 'black-eyed susan',
    : r% i6 `& j# O6 n6 K '69': 'windflower',- l6 k* {- P' ~$ j2 N
    '62': 'japanese anemone',- @' @+ |! R0 C, E3 e
    '20': 'giant white arum lily'," f3 y* N3 C2 H, ~3 q; i
    '38': 'great masterwort',+ R8 m, S  R4 b/ H
    '4': 'sweet pea',
    : q: u  ^" G1 d# g& l6 ? '86': 'tree mallow',
    6 ~- v/ z: u5 p' _6 G* X0 w '101': 'trumpet creeper',: O+ J6 g* E: H$ p/ H$ X
    '42': 'daffodil',
    ! R, G9 V1 N* [: X9 U '22': 'pincushion flower',
    3 k; K: d% _3 s1 ?4 { '2': 'hard-leaved pocket orchid',! R6 |( p" J3 Z
    '54': 'sunflower',
    8 H) _  `+ q: d  E3 h8 ` '66': 'osteospermum',# D+ G' r0 ^) \' {0 g9 k
    '70': 'tree poppy',; E0 t5 K1 P. h6 `
    '85': 'desert-rose',$ W' g1 E% H9 u; e5 J( e
    '99': 'bromelia',# S! g; @+ Y2 A. B9 W
    '87': 'magnolia',
    : Y# }5 s0 L) I '5': 'english marigold',
      V2 p4 J+ i8 j6 W% O2 ]+ S '92': 'bee balm',
    , ]4 S2 f$ e4 I$ H7 Z '28': 'stemless gentian',
    . A- [+ A+ l( i% F$ | '97': 'mallow',
    / d0 _5 W/ I3 t6 E' a- \ '57': 'gaura',
      R& v5 K7 `( g '40': 'lenten rose',& q' C) {) Y2 o+ z. O, k. v0 d
    '47': 'marigold',
    ; ^) b4 e0 W4 P4 k2 g+ u4 R '59': 'orange dahlia',0 ?) N& K) x) T, A4 X
    '48': 'buttercup',$ o6 X& E5 v* K7 i0 D' I# e1 H
    '55': 'pelargonium',
    ; \% ]: I' T: s4 f. F' {9 s '36': 'ruby-lipped cattleya',
    ) C" m7 p" z$ R9 ^' c9 Z '91': 'hippeastrum',6 g3 _2 d1 \" `" u) d
    '29': 'artichoke',
    ; j! d! p' }$ L& q '71': 'gazania',+ e- l8 Z# h) u+ V/ U2 y
    '90': 'canna lily',
    2 \* }9 x. q% T- ^+ r '18': 'peruvian lily',
    0 H: U& g, ?  B2 | '98': 'mexican petunia',
    7 ^# l5 H& `; ~ '8': 'bird of paradise',
    - e, U" C5 V! w/ K+ ^) g '30': 'sweet william',4 I) q1 H) K6 K/ l
    '17': 'purple coneflower',
    " V/ Y2 ]( f. Y/ g) p1 i) e6 d7 ` '52': 'wild pansy',
    1 L: A9 v( O" ~# H8 p '84': 'columbine',1 O! }* g! d! U. Z$ W
    '12': "colt's foot",4 t7 s' h6 [9 o
    '11': 'snapdragon',
    % b: E' U8 Y) t; b0 o( a& r; U '96': 'camellia',
    . v/ v! E( J/ ^, x0 k& x '23': 'fritillary',
    5 W7 W! l1 J7 p8 Y+ a '50': 'common dandelion',( d, k: _# s7 a. y3 @+ d0 K
    '44': 'poinsettia',
    $ s  y0 }$ J0 X0 V- W& J3 J '53': 'primula',, @6 z$ M3 \# K- t
    '72': 'azalea',
    % W! y9 f4 N: |- g; ` '65': 'californian poppy',+ N4 I# G) [; {4 s) \
    '80': 'anthurium',) Z6 y/ r* V- ]6 e' Y
    '76': 'morning glory',6 X- s1 N" f; a
    '37': 'cape flower',: J; s6 }) {4 v, z$ K# H0 H% f
    '56': 'bishop of llandaff',3 o) v" F( x% l9 {, L0 @6 _
    '60': 'pink-yellow dahlia',5 J- r8 a; Q7 |9 t0 a% M1 c
    '82': 'clematis',) e( j3 }5 f) q
    '58': 'geranium',( q3 s9 Y0 y( O# S3 F; y/ I
    '75': 'thorn apple',$ B. ~( t; z/ `6 ~& y) Q" A
    '41': 'barbeton daisy',
    1 w2 s2 e9 @3 T& M8 o: A) \: P '95': 'bougainvillea',5 n. `5 X+ I' \$ W
    '43': 'sword lily',( w/ Z3 [3 k  f7 a/ Z1 N& c/ K
    '83': 'hibiscus',( v6 d$ b$ F! {. V" W4 G
    '78': 'lotus lotus',
    8 W, r# C) r# I4 L% H" d  d) x '88': 'cyclamen',# q6 [% m3 ?5 v9 K- L
    '94': 'foxglove',
    ! ^' m( Y; a9 M5 Q0 n$ o: q '81': 'frangipani',
    . T7 D/ x! S6 h '74': 'rose',
    & ]1 ?6 c0 k) E- [3 U" [ '89': 'watercress',
    2 N0 x, \6 c- t! P '73': 'water lily',, Z+ {' {' q% u6 J* B" M& [1 l
    '46': 'wallflower',
    ! {3 ]+ G% b# i0 [4 V '77': 'passion flower',
    0 I9 |  H% x; o! a8 {! J '51': 'petunia'}3 C) F1 d- {0 c" ^: V
    & U% l5 F5 ]" v
    1
    2 G% c) _* s2 s. f6 L# U24 ~: `9 X$ h: x. k* t& Y$ W1 P
    34 N8 E) E% E5 c8 h
    4
    % }- j# G/ p4 Y7 {& O4 _& z  U, v8 q( M5! z) T1 o; x, y
    6& R! D6 ?. O. I, C% C' o
    7# W; S7 E" W* F* S# D. _
    87 U  a: A, N- t
    9
    : m. Y* a2 u# C. _# a: J( S10. Z" ]; C' E" B# U8 ~
    11
    $ @& C& j9 P# e  M* t0 G% S6 S+ r125 u; h: d* ?8 G1 I: {  d, |
    139 o: m+ u/ D. t" ?" V. ~2 [
    14
    # m+ l& G7 j" }3 L3 a2 ?) k15
    % h5 P; a* K3 h1 q6 h0 z9 B4 i16" D# L* D' _( q, F2 K( W: R
    17) \. U2 a! m! @( F% I
    189 P$ w0 S$ E8 a$ L2 D. i/ O/ H2 h: L
    19# W8 k7 |2 l1 L% x8 {
    20' b/ k, V: ?- Z
    21/ e8 N" y- L( h4 N1 _% ~
    22
    ' I/ `4 \: t6 _$ H233 c" s; b/ H8 h  U8 l
    24
    ! r7 F. u2 x* Z4 o% P" b  h1 A6 W( Q25
    # H! R- Y, I* |26
    ) R" P0 v5 [4 _& u8 d2 Y27
    . F% J/ O, \( s0 w, B2 F28. W* L; W4 H- j. R
    29
    3 r0 w1 L) b, U6 `: m) E30  F9 w& A/ V) c5 ]9 F
    31
    5 V5 `  f! \4 ^$ V, Y8 X326 l' l& O! ^. n" Q3 @
    33
    7 h# S( S$ Q7 {$ X5 ]8 n34. C) U5 B  ~" O
    35
    ! m' n: S1 H7 X3 w% s* [( S$ I36; j& m7 y7 g# }1 ~1 e! x2 s1 Z
    37% U9 r6 ?4 K" `  j. s
    385 c! p; I4 F+ W: p/ j; K
    39
    ; M2 v& [% E) _6 z& G40
    0 t% ~( I1 ?- \* T+ |1 j1 S41
    1 A- J. K: f4 _9 \& |42
    2 V- W3 Z, g, V9 i8 ^' V" M43
    % h5 S3 J: e+ \  `" T44- ], E; k/ b8 G/ M1 Q% j
    45
    1 n% q$ e9 ~( k* U46+ k( i9 e" j: V
    47
    + f$ b3 T% e) u+ L8 [48
    , q( _9 A) m2 n2 e; \- a492 a( g. \, v- Z0 i& w
    50+ l9 y+ X1 Q" i* U- ]' L; L7 w
    51
    3 A* L! i4 Q4 y52( t) v! `, w+ d5 O
    53& ]* y) L! J" m* k( R
    54
    7 L0 C! j" e4 {7 a- o( t55+ b& n. ]/ L  O7 g: Y' s/ Y
    56* y- n) T7 x! X! v2 ?* P
    57; B& U. {& k5 D3 S
    58  D4 Y6 ~1 v, R! j. i
    599 y* \) @" K4 s( O- x4 x' y6 u7 l
    60* C0 E. F0 _+ t$ ]: V1 d
    617 C9 U/ L* Q' w, g2 }9 a3 P
    62
    . Y, F7 f* U) Z63
    - R2 I# S3 W& D" i0 d& A64  Q0 D0 {4 M6 t. ~4 V
    65
    9 t6 }- z# Z4 [& c4 ?8 C' }66
    , W- O- |( h9 ^8 z67
    & A# Q9 h4 Y! V& l68
    , _4 l& N% i' H! F69
    4 Q/ i2 r7 x0 `2 r0 A" c  S7 b# [2 D70
    6 ^9 G5 }3 ?4 z4 f: v71
    9 T! X, o7 ]: ^. y: [9 o72
    & e9 M" B: z6 N' ~73
    8 c0 W$ _6 ~/ J! x74
    * R3 u$ R0 v$ [9 ~759 i. h% T7 R/ G6 T
    76
    9 E- O& K! e9 \7 L0 _77
    % l+ q5 V7 T" ]5 D! C, o; _783 I: g1 `2 o* H. S1 ?
    79
    4 I6 @1 K% U  p, H. I80
    ' @! t3 X3 K( S# N818 h  z! S9 ~" O7 Y8 q8 k* H
    82
    $ A5 x3 X& t6 b3 V& y! z83- w* W: p  v* }( a& E4 @
    84
    7 l0 u2 x. s$ r% l$ ?85
    3 `) O* h- k/ o# u, k. P86$ s9 }. v( q% W1 T0 _( r( f
    872 c2 G& K, I2 G, @. \: ^5 N. _1 [, g
    887 y* a7 w" N) J$ E
    89
      J3 R8 d0 y) _8 z903 Q) o( W, K# l! E; ]5 `
    91
    7 P3 |( c% c4 _" j929 [) f5 V) d/ }+ s+ @* t4 O9 D
    93% S2 t& n* K( F7 M
    94! s& {5 Q2 d% m& K. x! g+ X
    95
    8 H0 c) ]( v9 n% l960 H5 v5 [2 \7 y' h
    972 K6 |$ E6 W  S! v. y" S/ j
    98/ S8 \2 h; x( Q" o  F; j
    99& J: w: d' A* @) Y3 z; v) W/ L' a
    1004 P! V* T. \, P9 S& M2 ?
    1011 P. ~# Z* H9 E8 u1 H/ W/ C' w
    102
    ( a/ k, `8 j& ?. Z4.展示一下数据
    " k1 C$ H3 ^9 _def im_convert(tensor):
    6 g6 m4 ^7 U, ^6 U( e5 u! P5 J    """数据展示"""
    . G# O( f! i  T7 \# |    image = tensor.to("cpu").clone().detach()
    ! A4 T0 H/ x+ v$ `    image = image.numpy().squeeze()
    3 \" a, u9 ^" I: H    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
      s2 j! t% m' L" P# p    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    ; a* J( r" A* h! |+ S    image = image.transpose(1, 2, 0)5 W0 c" |- o3 p
        # 反正则化(反标准化)% I# H; H4 h8 b1 F* u
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))! T% V$ ?: X' B' V- e/ l: u6 n
      |; s1 N, A6 c
        # 将图像中小于0 的都换成0,大于的都变成1- M  F" H4 l$ r1 t
        image = image.clip(0, 1)/ ~2 ?* m$ e3 a% U3 @
    / ~% T! O8 L+ `1 X; N
        return image! u0 \7 j* G0 c# z- e% X, T
    1
    ) p6 M8 M* V6 Y26 ^# d  w/ A: D
    3! [+ w" B; y: J3 q8 ^) R% c6 A# ~
    4
    6 g* z  j; X6 S5
    ; l+ `: s5 \  k1 N- T6 V" T) U$ ~7 \( j* Z6
    7 T- j( C# r. j7
    * v$ _, E' P+ f' M8
    6 Q' w0 ~# a, ]. X( O5 [" B95 p+ P. T0 L2 h2 G+ U. ^9 ]9 i
    10+ N# p; q1 y0 S% b$ m/ `
    11
    / S2 C$ R6 w: x+ X' @; E; [+ m12! m, E* ~6 c2 C' n2 j
    13
    * v& o3 x2 K$ v) V3 D, m( F* f14
    6 X: z7 k- }$ G8 O: u# 使用上面定义好的类进行画图4 ~& B. h7 N2 s
    fig = plt.figure(figsize = (20, 12))
    0 h) Y* ]* }0 U3 ncolumns = 4
    9 }& E- ^& z9 srows = 2# q" p" _! R( p
    : X% L) {' w! t6 {: M6 h; N
    # iter迭代器
    2 k, |7 w" F& E1 P1 Z5 H3 s3 j# 随便找一个Batch数据进行展示* P; u4 n: r6 v9 |
    dataiter = iter(dataloaders['valid'])
    1 @# W* a5 e: ^+ |( pinputs, classes = dataiter.next()
    . t. y7 U: o+ ~' ]8 g" S3 Z8 K' c" Y( \* d4 P
    for idx in range(columns * rows):' ?8 l& V& z2 ~8 l$ S
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
      y6 K3 t! D% x    # 利用json文件将其对应花的类型打印在图片中! X3 v3 d/ _, @! ]$ c' r- E
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    ' v4 z( o/ T4 N! {( K    plt.imshow(im_convert(inputs[idx]))
      M) p% z% f3 I2 B. u* R; Q8 y& ?$ rplt.show()) b9 ?% E! M4 ~, J9 k- I
    9 R; v  Y4 x, j9 M' P& |
    1- g- b  D1 o) u$ @
    29 G1 p0 r3 F" w  E/ l! ^/ v) G) J5 }
    38 @" z/ f8 I3 _) Z3 ~
    4
    ( B6 W8 r) z# l# l, D* y5% E7 U2 i6 @; B- e8 ~9 F5 e
    6. J' s/ r5 U6 D. x0 b& `. D, \) ^
    7
    + E# p! |$ U. B: ?$ ^9 t9 Z8
    % G+ _& [9 h9 k  j: x& \4 U$ u0 ^9
    # S& C# Z8 ]/ B7 S/ H10. S. f3 ?* r# Z6 |
    119 n: M6 P: q3 b! n6 q, q# `6 b
    124 Q) \5 f" u$ {
    136 Z* W4 s* P2 P" W; L
    14  e/ o+ q2 ?/ t8 |
    15/ r, \- t& N% G3 ~
    16
    " C. K& ~; L7 V7 X* @* J" y0 s* ~& i4 u: ~0 u% n. E( c
    ( l# u/ d# Y0 l: |
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数5 q6 ^2 L5 D  Y  Z( f  w$ M
    model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    5 ?  p$ k% O1 h+ X9 q# 主要的图像识别用resnet来做4 [% j. T) C  [- A  j% W( z  `. }
    # 是否用人家训练好的特征
    : J  V; L5 Z$ c- a3 Lfeature_extract = True6 G( o( k( U+ \9 y$ t% ]
    1
      S3 Y' u9 a, k3 T& y  q2+ ^$ _8 E( A3 U: W! J( P9 v
    3( V3 d3 _/ H0 [
    4( a/ {; W7 X4 `; ~* l
    # 是否用GPU进行训练
    3 n3 a. Y( h/ F: O/ `1 C$ Mtrain_on_gpu = torch.cuda.is_available()- u. u9 L% \, O- S% r& b/ Y
    . ?3 c1 o7 [5 E
    if not train_on_gpu:- O0 O( I& j6 w9 J! o9 I
        print('CUDA is not available.   Training on CPU ...')1 }& [" e3 z" x+ v6 M5 K
    else:
    0 \$ f8 ^; q' K" b0 ^1 u5 Z    print('CUDA is available! Training on GPU ...'). {3 R' @3 x% H* \+ U5 _
    ; H" a3 q, x& }; l8 o7 X2 L: f# A
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu'); n9 Y, \2 V6 s5 }4 J- t
    1
    1 r- M) O! @2 f4 k( h2
    * p2 X# |" ?2 U$ G3
    ' }8 j& A# ^) a/ b4
    * x& Z5 @5 O! N$ N; ~" p, {( ]5
    $ V% B0 L, z- G5 c9 o9 G; q6
    / Q5 E5 u8 H" W1 o  x7
    % N9 Z. I6 ~# s# _8* T' Y3 F! S6 ~0 @0 z
    9
    ( d5 A9 ]! O5 h! o0 ^9 w' ZCUDA is not available.   Training on CPU ...) Z! Z" o! n" ~4 x; ]* E7 `
    1
    * O/ w8 B' m+ J2 D" U8 U" v# 将一些层定义为false,使其不自动更新' G# U0 h5 h, H7 r& d) t
    def set_parameter_requires_grad(model, feature_extracting):! b( w$ l, V+ Y
        if feature_extracting:
    # B/ p/ ?0 j) Y) {/ p* M( ]        for param in model.parameters():
    $ a- ?& }7 P0 y  `" L3 ^  `0 j            param.requires_grad = False3 B# R5 u3 x" ?) `0 `0 P. o1 }0 f
    1, r8 `) ~( ~/ c3 M0 T
    2
    7 c6 g  N/ b, Y3 S8 K9 M3' H5 G$ i! m1 u! V7 ^  t( O
    4
    4 e- k6 g2 x& W$ Y0 u# x/ r2 T5& e  ?% L4 P$ {% m3 p7 i" d
    # 打印模型架构告知是怎么一步一步去完成的
    & a  |+ H/ r) @# 主要是为我们提取特征的1 g0 f2 ?$ J3 S. c* x- e+ E1 V

    $ E' L$ f. z2 c" t7 \model_ft = models.resnet152()
      R5 \9 K& T; J4 w% w! ?4 g/ rmodel_ft  O+ ~) z( [# `; a
    1
    6 _" g% ~0 K, [9 L$ E2 c; N6 t& R2
    . Y' D' T* b5 M3
    8 o( T3 k& I. C$ S# o9 Z- g48 }, A! s$ H; y6 J; ], d
    5
    , O& ?- `( @; f+ K+ ~) e6 G& iResNet(
    2 y- P# }) A6 [0 x! k) w/ N  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    ) m9 O& `; C6 n3 V6 T  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    / L5 [( V. o7 v$ _) b9 p  (relu): ReLU(inplace=True)
    : r: f. F7 R9 M6 k$ `( `0 P% x  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    2 n% ]: ^  A" o% Z3 w: l  (layer1): Sequential(
    , ?3 O0 b0 x! s9 a6 R! Q& N. \8 ~    (0): Bottleneck(/ a6 y/ o9 U& G+ @" N
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      R# w* V1 v  w      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)% B# q% S' ^2 R; H4 f" `
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    + u5 ^& B. K8 r' [2 G" m  ?      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)4 J" r1 O& A  }! ?. ^- K
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    & C2 h6 k! Z, G5 Z0 t      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    $ ?4 v; i# M" s+ f# e: c$ D      (relu): ReLU(inplace=True)( `3 A. Z) p# j0 @# w0 o
          (downsample): Sequential(
    7 Q% Z# ^0 G" r) p- i3 k        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    * P6 _& x4 _4 H( C* C! N3 E        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)' u! A5 d) u( P% \& P5 A
          )
    9 w8 ^; j5 L4 {7 R  ?' t* N* Y    )
    5 j$ ^6 G5 S. y" }4 Q) I" K中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    ) g/ V" H- }3 r9 D- E    (2): Bottleneck(! |9 w* F. v; j$ c$ S
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)1 o. R" Z; W# R6 K. x8 J
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)  `/ p. @; s5 U1 `
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False). c' X. J' Y1 t; b1 q
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)7 F* O$ M# {* b; I, k) r
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False). Q4 N3 S% L. c2 f: l8 Y
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ; ~( K+ e3 h+ b% S( m) |8 ?      (relu): ReLU(inplace=True)" U. Y* h5 m& T; m; v: t: v& f
        )
    # X+ ?8 X# [% Z* g; O& R" m  )
    8 s4 W% m+ D, }! x  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    ! u9 ?8 N* V% N# D  (fc): Linear(in_features=2048, out_features=1000, bias=True)1 w9 U0 |. `+ z) `- J! T$ l
    )
    ; l, a. U6 D' Y- J9 O, s3 }
    8 I, g' [! o3 B1 o  b! F* E1 n/ o17 j7 e8 h( x  g/ f9 M' u% L3 E
    2
    4 J# A* |# v# X$ {3 E, V" t3
    : m( g$ F0 h# D- G8 P4 r4 t4
    & F: _% c5 O3 s0 J( x& ]5 y5
    ( T: d' Q! `+ d" s61 N* _! N; R3 {, L% {" g3 A/ W9 ]9 f
    73 a7 d& N1 k* B7 ~" A( i
    8; M9 n- s" Z" R& N. u
    9
    2 }$ g; E" A: W$ X9 k10
    % N7 t5 @' H9 S9 g, c- R/ Q6 {5 V4 R1 ]11
    5 [) r, e7 B5 y+ Y: _) N3 N12: ^$ a6 n2 j/ B& Y" k( m) m
    136 b- I7 R0 [) f% t9 c0 M# ?- g
    14- F1 K& N8 V7 f  p8 F5 p
    15
      n! M* C& N6 e  e16$ K1 _4 j! |" s, d
    17  o6 f7 B, N0 s0 }+ }$ @
    18
    8 r7 P/ ?6 R& w- ]19" n; B/ X( d; [! n
    20
    5 Y( A  L7 j" E) m2 v; |2 w219 K; _5 R3 E" ]1 `- l
    22
    7 H, w8 t3 a+ o1 n5 c23/ I; B& ^8 [% S3 ?& a5 d
    24. _/ g' b& V  U2 b
    25, P% b- m1 X, u7 {4 Y7 @" \5 D- g
    26
    / l# R6 Y( ^/ f8 d2 V- F+ |270 g8 L$ B' }! Y! ^
    28
    - R6 d. D+ f! T' c# w9 ~8 }) k29
    2 ]- T* I3 Q" |' O# W2 I30# |2 T8 l  m& k6 w3 S$ ]3 m
    31
    9 S; @9 B$ Z' S+ a32
    & q1 D9 d# d) a6 \8 ]0 m5 @33
    : b/ {) V& e4 S9 Z( k最后是1000分类,2048输入,分为1000个分类
      @; o- {) d, e6 x3 V; k而我们需要将我们的任务进行调整,将1000分类改为102输出
    * |+ n2 G- W& T; [2 h5 x5 o
    ( x) ~2 b" T5 x) O' T. ^9 R7 d% N6.初始化模型架构
    " ~; e6 V7 D8 @) [0 \! z7 H步骤如下:# X8 d4 y  ]& ]& K7 f

    . c# l* o" f) {7 D' a+ \+ D! w* ~将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
    $ Q/ U: i6 K5 X, i可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)6 B1 e% m9 \# B+ q" v2 T" v
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
    2 z( @9 t; f9 u5 A# g: T官方文档链接
    / d( Z( f9 n% p  yhttps://pytorch.org/vision/stable/models.html
    2 ]. L- f. W. j7 ?. D
    $ r5 w2 ]5 H8 a  M# 将他人的模型加载进来
    / O( |( t! o* a+ `, A8 T. k. d1 tdef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    ' z0 H' q& M; a1 p    # 选择适合的模型,不同的模型初始化参数不同# q8 X/ t6 g- d; H% C' l+ p( W# z3 [
        model_ft = None" C/ v, {: `, Z1 u. h+ l
        input_size = 0/ R0 S' Z& R. [" u7 A) f4 g
    1 ?# a. x* X" P' ]1 M
        if model_name == "resnet":) Z" b8 B8 j/ n( I
            """( n. h" T/ M" T; b+ R% u7 T
            Resnet152& i" W1 f5 H0 D! ?8 j
            """
    8 @7 \. y' w; t
    , l4 I; h+ @7 A9 x( T: X        # 1. 加载与训练网络
    1 ]( e, e, e) h        model_ft = models.resnet152(pretrained = use_pretrained); c  n% H9 S, ]/ {8 B: E- |
            # 2. 是否将提取特征的模块冻住,只训练FC层# Y7 Y: g* F7 p, H. p* F
            set_parameter_requires_grad(model_ft, feature_extract)7 z) J. G) u& c
            # 3. 获得全连接层输入特征5 m+ L. K3 u7 I+ @# f% d/ k
            num_frts = model_ft.fc.in_features
    ' O+ r# G5 P1 x, D5 A" c        # 4. 重新加载全连接层,设置输出1021 \0 q8 E) M  N$ E  ~& _8 M9 }
            model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),: D( r, J( [/ v% i
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    " `! }- t* D0 r4 {        input_size = 224' _" C2 Z( Y* i( v0 n+ k' \
    ; U7 H$ P( d  a: |! R8 J+ d4 d. A7 d3 r
        elif model_name == "alexnet":
    ! C, t% k9 ~# O. `; f- W$ n        """
    & V8 p9 D0 q6 v9 O# p3 r        Alexnet
    * A, ?3 t% `; p        """
    5 \( G- x7 t. |7 s3 C4 }' W: a- b        model_ft = models.alexnet(pretrained = use_pretrained). P; l9 C$ ?" L
            set_parameter_requires_grad(model_ft, feature_extract)! r  r) ~( g' k$ ~# l! u
    " i: o5 t* \& s' z& Y0 _
            # 将最后一个特征输出替换 序号为【6】的分类器
    : f% e" E) o( E8 y0 V        num_frts = model_ft.classifier[6].in_features # 获得FC层输入' I+ B+ ^/ b: y6 j0 {0 T, R
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    1 S) l2 R! X$ |8 }        input_size = 2240 _0 E' f5 c; y4 z. J
    3 ^8 |3 s: W) T
        elif model_name == "vgg":7 b7 n7 \$ w' g/ f
            """
    8 w8 _! F3 n2 x0 j* I        VGG11_bn3 s7 m9 b  J4 K/ T2 ]- D8 w5 b8 D
            """) R  G% |- C  g5 r( S
            model_ft = models.vgg16(pretrained = use_pretrained)
    . t3 \! P; |# ~( t$ J+ \/ m7 `        set_parameter_requires_grad(model_ft, feature_extract)
    ! c: {* `6 S2 x* E        num_frts = model_ft.classifier[6].in_features
    - t- s% M! u% g- H3 e        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)) `9 a6 Q  o+ |8 z% D% v  s5 S
            input_size = 224( m/ M% y5 \9 y( q) [( b
    4 T: Z) `( Y* Z
        elif model_name == "squeezenet":
    8 B6 B- m9 K8 y        """
    0 U; ~  f) J1 S* y7 m/ E        Squeezenet) g( i' n9 a+ U2 M( y
            """$ O( _$ R( [( W( b
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)9 x8 C; Q- y% ]6 i1 |
            set_parameter_requires_grad(model_ft, feature_extract)
    , ?2 [+ y! [5 Q6 \1 d7 N        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
    9 B+ I+ m6 t: m" {        model_ft.num_classes = num_classes
    6 ]8 r+ j  M- W5 E; Q4 S5 E0 b        input_size = 224; P' Y' g' c3 z* N0 I
    2 d! p+ f# ^& C$ |
        elif model_name == "densenet":$ c+ ~* f; L3 R
            """+ N5 w6 M; }7 o
            Densenet, s2 K1 F+ }3 I6 N
            """" H  p. Q1 @# k  U$ `, k5 `
            model_ft = models.desenet121(pretrained = use_pretrained)
    , x9 V; a" b, I. K3 U' I% G        set_parameter_requires_grad(model_ft, feature_extract)& D' Q- v& `. g  J4 j, h  y
            num_frts = model_ft.classifier.in_features; g2 {; a, q6 S
            model_ft.classifier = nn.Linear(num_frts, num_classes)
    1 G" L# t& R& O; P) n        input_size = 2248 P( e! z. |" i9 u! V" Q

    ' i" s& i2 j% L9 I; o, Q    elif model_name == "inception":' _+ ~  s; }7 |% n. ~( h4 v1 T
            """
    + o/ [; j6 D0 \! r. ?  Z        Inception V3
    % s$ {3 H4 F' o) y# i        """: F: O  b  u5 n/ K. x
            model_ft = models.inception_V(pretrained = use_pretrained)
    6 x$ J/ D$ @( ~# j        set_parameter_requires_grad(model_ft, feature_extract)5 g6 M( E. m. p& r! j
    ! F9 z. `& U. G) ?% k+ A) ~& ^
            num_frts = model_ft.AuxLogits.fc.in_features( d/ }7 Z  [/ K: _) m: ~$ A) W
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
    ( C" I  k  A- `1 p6 y0 T8 F8 ~8 R" G2 V/ G; a+ T+ @. A
            num_frts = model_ft.fc.in_features: ]& j+ u; e6 c; ~0 |
            model_ft.fc = nn.Linear(num_frts, num_classes)
    2 |! u6 s4 B0 g0 A4 X0 Q% r        input_size = 299
    - [) P4 E+ `8 Y  e! a4 b
    4 u& n5 v3 a; m5 f, i5 D, r    else:9 a! W4 f- d9 k' F) C
            print("Invalid model name, exiting...")
    2 w( ]/ h' G( {, I5 ^        exit()4 `) N& B  W7 t- W) b; ~" s
    " ~0 D& C) C: @+ F
        return model_ft, input_size& d; |0 [3 K% I% v7 T: T: l

    3 C7 j5 F6 D# J1
    ! o1 ]; w' v" e5 [& F8 i7 z28 V) s5 a2 t, A# r+ b7 i
    37 J. v( \* `: V# p! {
    4
    , l8 Z/ h# f0 r7 L8 S* {- [57 T) y* `8 g0 e6 v8 @! ^9 s
    6
    . z5 ?8 Y( \# @% _% R# [7
    4 C$ k  j1 I+ L( M2 \81 t7 ~/ o3 B1 C# [
    91 k2 H. C' o7 O
    10# Z( [3 x% q* n/ n  d% F
    11
    # e: ?: F) |! w5 J' t12% D4 e' X% D+ q, d6 L1 ^
    13: A2 r, x& z4 J+ X: w
    145 E) ~: D3 |. c1 @/ T% @+ {
    152 b8 F4 Q3 o7 Y0 t' k3 I' h" {" a
    16( E' N: R8 U# N
    176 o% [4 h9 i1 J. d% X, ]
    18" s6 w# m: V2 ^& f4 X) g
    19& I1 F) q. b" g% d" Q: ?. t
    20( U4 }; ^1 b  L8 b9 A* B
    21
    - P5 n) q) B+ M! p22
    7 _, e' g% y2 T6 k! q' n/ u232 ~6 q! a6 E8 o- Q/ d
    24. D! J3 ?5 u! ?1 P6 h
    25
    3 C0 E: r: h: r" x) p/ K262 }$ e' F& k7 H( J- M
    27
    4 h7 l/ D1 J0 ~/ ]. @- T0 g28
    - Y6 f" ?; Q, Y: ]" D% n+ L! b294 ^/ B2 Q  n3 E$ r6 l# V) s
    30  W5 Y0 e9 L' f' j2 j; Z6 X' w# ~- C
    31
    - K. Y. n4 }" R2 ?32
    ! ]) g9 A% P: e2 j4 _8 g337 }5 X* E" W+ _5 N, H' V
    34
    / I3 w8 l: R2 }- H/ i) _) Q35
    - X9 R9 ]2 X# M7 I+ z5 Z* w36" \( a( m: h( h
    37, j. |) G% l- y* h: c
    38% x* P6 l3 K* W' @3 s
    39
    $ ~# ~0 ?% a" Y  s# J+ J40' T6 g0 o5 i1 m6 {
    41
    : u  P1 d/ p$ W. p3 r: p& i4 R; T42
    , _6 _$ `9 }+ E8 f! |/ b43
    7 n" C0 h4 ^# b" u) Q0 X444 S  s# F1 N7 r
    45
    2 N, r3 e; {' }5 D/ E: [  A46) \1 r: V% [( u) f
    47
    - A- }# A7 l! A4 y* k48
    ! b8 q3 W* o2 q4 G8 [6 x) C% t49
    & x: l8 \, V& \4 |+ e' P/ z4 I504 t' Z1 s  W/ W, e9 U; Q8 h! q% Z
    51: `: a6 o* ]" A  ?7 G
    52" g/ B) a6 a" B( l2 h; H( ]
    53
    . R6 R/ z  U! `0 B% u* X9 N54
    " {, R* z) r7 q) o55
    7 {! O! J' z" F! ~! r# r: ?56+ Q- q5 t, M& B" W- F3 r! [
    57- P, r* m/ R9 ^# B
    58
    $ S1 n8 _- ^- b" ~+ u599 t, }2 B9 O* }# ?
    60
    ) O0 L( g% `' U7 Y) ?1 P61
    ) I9 ^# o; ^2 v! s% B" u$ Y) j1 f62
    . L- U! S& x! W* l* T+ O4 L63
    1 m/ s" U4 o7 s: j4 R$ s8 _: U3 X64
    1 A+ ]/ P: o: R% B8 ~$ [8 @2 }65# n5 P/ r8 z6 x& C
    666 d, Z9 ~  Q, e/ s% R& y
    67/ y3 h( Q+ E5 n: O* X3 }& ^% y
    683 i' p0 U- Y1 a5 r  g; i/ y3 O7 m
    69( B+ R' K! E8 a$ j- {
    70
    * b2 g# c  H# g" H71) N; T$ f- R( r- C" ^$ a3 X4 c
    72
    4 d. s6 s5 z' b- B9 W* M73! x5 u5 X5 F# i- I& l& o! Q( E& ?
    74
    4 y6 t7 j; ?& N75. m7 Y) n, ]  c! u. n
    76
    . ^2 y8 m; f7 A2 |& D772 z" i9 }* B  X+ Q
    78- r2 Z; C2 _# _* w9 {0 V7 T0 R- b
    79' }5 a) W( X6 K. B- v" Y
    80
    5 b7 Q* G0 b  Z7 t" K: c- E" ?8 c81
    0 S$ q& `9 z( j( C, Z! N8 a8 @82
    5 Z) E6 F- J9 `, [8 U83+ z9 {4 x9 [% X8 N
    7. 设置需要训练的参数* c5 \5 M6 q  v# |( R2 [
    # 设置模型名字、输出分类数
    ' T5 Y, ?( c. \) _0 cmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    0 u. \) s% s6 ?) O& P7 j
    " R, s* d8 B5 `+ l2 T( j" {; k# GPU 计算
    ' U0 K' [! m) e! n' ^model_ft = model_ft.to(device)
    1 Q; A) B( {7 |' Q' I; d
    - n0 r/ Q% e; H/ i) E2 e) L2 }# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    8 }" l( A1 [- w! P0 c+ ]  qfilename = 'checkpoint.pth'8 B, v5 f* `; c

    % m/ ~& \9 z; \7 s, e' Q# 是否训练所有层3 m7 Y# ]* D; h( `
    params_to_update = model_ft.parameters()
    ( R% }- g% R' A- ^! f- w; z# 打印出需要训练的层8 F# k% }1 a9 n
    print("Params to learn:")
    # N% N, J& A5 [4 r& ]. i" aif feature_extract:
    & Y+ x+ N2 e! l  j! S3 B    params_to_update = []3 G) W  T& R' S
        for name, param in model_ft.named_parameters():
    0 y$ F$ j; j6 O# ]6 z# g        if param.requires_grad == True:
    ! f- v1 ?; V( K9 p6 ?$ S, T2 Q            params_to_update.append(param)
    1 `! ]- v3 P) C; a% z            print("\t", name)
    ( I/ h0 y7 _+ Y; l- p- q, q5 Melse:
    , w( s( q, p5 s* U0 d    for name, param in model_ft.named_parameters():) [8 K) ]$ o7 K9 k2 g3 n3 p" A
            if param.requires_grad ==True:2 q1 }: g( l6 E) m8 x. _
                print("\t", name)
    ( K, w% F& Q* q( i$ D# k2 r7 T+ [  b0 {- w- |; [% N
    1
    0 ^5 `. t5 {$ Z$ r0 ]# V2" U" y0 n  j' v. D
    3
    ' r" @; P% S  ]49 A, j' o+ E  c9 Z3 j' B; g5 c  p
    5
    ; P' ?6 j' r6 K$ F6/ j' d# ]6 D0 B7 @
    76 M; A- c  j) m) R  z
    87 \$ Q$ r( n, t" H
    9- Q3 Q3 g3 s  l7 F& V+ E! M
    10- ^* `  f+ Y1 T. o% q
    11
    3 r) C( m  Q- R- E8 \2 G9 Z) s5 W1 ^12
    5 X# Q' O' n& ?, j9 P0 d! m; I1 f( E; X13
    3 t# v/ y7 G# R14% @4 P# i7 R$ \* ?4 ~3 H3 w
    15
    8 E# Z$ A$ C  p* o4 q7 B* k  d16: B$ u: j  B5 @6 b
    17
    7 U% x4 p: u7 c( ~) }189 k/ I* U8 P# Y& ~& I% m& a
    19  I; r( d) N* O  C- b  X% P( ~
    202 H- w, t6 F6 h  ^: q
    21
    5 R0 x* ?9 ^- B- b0 _8 x$ n22
    2 v/ a8 [# E' `5 B23
    8 n& ?/ \( T6 e% kParams to learn:# g$ O# ~" P, z/ x) B9 C) D/ s6 j
             fc.0.weight, O6 K9 O4 G1 l" G: H" N5 \- a$ a
             fc.0.bias( G, R! h4 G; _8 e" ]* u
    1  L* U. h0 x6 m+ J
    2
    ! R7 P: V* _* O" ^8 P+ u  t3" h7 j; F. |: L
    7. 训练与预测! J) }4 J! v: f. t8 ?; Z4 m
    7.1 优化器设置
    * \* @3 A! }# H# H/ k; a/ I9 `# 优化器设置
    ( B: w# f& N* Foptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2): [: ?, M" ^9 p) D
    # 学习率衰减策略& J6 M7 E: d! ^* g: j  \
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1); }# d1 W- }$ C2 ]2 o
    # 学习率每7个epoch衰减为原来的1/10% Q" k& W% h4 i5 j
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    / q) ?2 g& d7 t
    $ V* W; g& E- @( |9 c  e; ~criterion = nn.NLLLoss()
    + {6 b# Y! K9 n  ]2 E5 `; `1
      n# \  c" o0 G* y( t2- J0 _% l4 J9 o
    3
    0 Z2 a# A& F4 s5 e6 E4: O$ v, j; F* E: g0 a
    5
      c" t$ m+ v+ ^63 ~% F) @: q) d$ X$ s4 Z* h
    7. W. Q% Y! X; r  K9 C
    8
    & x5 i, P+ Q5 L  ^( q/ K# 定义训练函数
      M' S8 [/ j0 e7 N#is_inception:要不要用其他的网络- z5 D5 t* T- A# {) `8 i
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    , z$ d6 h) A; A/ R- }3 t5 F# C    since = time.time()
    3 g) G  Y7 V* M( V    #保存最好的准确率1 Y+ B: k! @/ i
        best_acc = 0/ l, ?5 K( v* i2 y" R) L0 h4 L
        """+ t& A7 d2 U' w3 F4 M, T" H' P
        checkpoint = torch.load(filename)( P% M6 g/ g/ ~9 i
        best_acc = checkpoint['best_acc']
    0 @# E- Y5 D' l; S  r    model.load_state_dict(checkpoint['state_dict'])9 v) J9 J. C' S$ A& g  _
        optimizer.load_state_dict(checkpoint['optimizer'])
    ) |. S9 t" z. h9 b; e    model.class_to_idx = checkpoint['mapping']
    8 x- r4 g. ^; _6 B5 N    """
    5 C+ A  Z% j: y9 O( p* i) R    #指定用GPU还是CPU
    5 z9 m3 D0 Q9 j# h6 \7 E    model.to(device)
    * Y% n' m" E7 n    #下面是为展示做的. L& [1 M5 X1 z3 ]3 C% ^, ^, F( u
        val_acc_history = []  ~$ \7 P: `! r4 E) _
        train_acc_history = []
    ) G1 j: N! |. M8 C* O) s( I    train_losses = []5 K) X. w" {) }7 g  {# H
        valid_losses = []; w8 W1 ?* q9 X6 ^% p4 I( X
        LRs = [optimizer.param_groups[0]['lr']]
    . w5 k( y# Q( E    #最好的一次存下来
    ' ^; S0 d+ @0 v% v+ E  F    best_model_wts = copy.deepcopy(model.state_dict())& F" J% F0 l) y$ R3 r3 A# y
    * g0 ~- Q; a* W$ }
        for epoch in range(num_epochs):( |/ ?4 K% R+ T9 |; ~. x7 l; [
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))# o. A) Q6 g% k! j8 m. g  P
            print('-' * 10)( p; V( k8 N% ~- o) J: ^# e, ^
    $ l0 k" f% ?9 P
            # 训练和验证
    1 _. ?+ G6 ~7 y$ F7 Y7 l7 q' [$ R7 z        for phase in ['train', 'valid']:
    5 H0 S$ ]! J$ y. `* \. |! j            if phase == 'train':
    5 v0 D! {- P' z% N                model.train()  # 训练
    9 u9 V( T. m( c0 S            else:
    0 u- y! _7 _7 l                model.eval()   # 验证
    & Q2 [$ B1 @) M. G  Y
    ' [6 S$ v& W# D3 l8 |8 |            running_loss = 0.0& f% _/ M# F7 T3 u
                running_corrects = 0  E9 ^0 N  B" W( G# ~1 |3 s5 \

    2 L, h# h. p" ^# p; ^' u6 E  T            # 把数据都取个遍
    9 e& B1 d9 p7 N9 e& e            for inputs, labels in dataloaders[phase]:
    8 |$ @  A+ ^$ D. F' c4 a                #下面是将inputs,labels传到GPU
    - s# v* i$ t, P" V3 c                inputs = inputs.to(device), F! |! E) G5 y( P& K$ H
                    labels = labels.to(device)( c% ]1 r. n5 R

    ) T1 y* h+ G# D+ V; W+ n+ Z9 S8 l                # 清零
    - m, Q: }6 k$ S% }# U4 `% g; W                optimizer.zero_grad()" R/ K# f( {  u0 `4 x- \
                    # 只有训练的时候计算和更新梯度
    % X( N7 ?9 p% G' C1 I1 ^                with torch.set_grad_enabled(phase == 'train'):$ g% f% t, k! Y& [  P1 Z, x" Y
                        #if这面不需要计算,可忽略
    1 W- ^1 q5 }# c, y5 X: Y                    if is_inception and phase == 'train':7 _+ N7 i9 w4 S: S3 V( z
                            outputs, aux_outputs = model(inputs)/ a+ Z' Y0 h/ B' w  h: f, P
                            loss1 = criterion(outputs, labels)0 h9 U% _* d3 U. r0 b
                            loss2 = criterion(aux_outputs, labels)
    * i8 x$ Q3 p& R6 g                        loss = loss1 + 0.4*loss2! ?  W# ^/ s# s% z1 W! T2 f
                        else:#resnet执行的是这里
    & R2 Y2 h, ]& X2 d  |4 F! |/ k# ~                        outputs = model(inputs)
    4 E# ?& [' e7 p7 s0 G3 E                        loss = criterion(outputs, labels)
    5 \! p" x& I, \4 X( k. N/ w; b1 t2 U: U& Z1 O2 x- ~
                            #概率最大的返回preds, d; q& H9 O2 o8 O+ m
                        _, preds = torch.max(outputs, 1)
    * [7 @2 i; Y! w7 t2 ^
    ' p( H8 ~$ L, r! }                    # 训练阶段更新权重* y7 Q7 g5 j* s/ r9 x0 P0 J
                        if phase == 'train':
    2 }7 ^* X- ~% T+ L                        loss.backward()
    ! `1 y6 H) u  s9 s4 ]1 ?                        optimizer.step()7 }2 i" U. P( S# K) A) c
    . X( P6 V& w/ h- B! H$ D; k/ K$ P
                    # 计算损失, Z7 b9 n: M, s: p1 Z! B
                    running_loss += loss.item() * inputs.size(0)% b9 ?% }" B) ~' |5 U
                    running_corrects += torch.sum(preds == labels.data)
    " f$ }( \# w( C! d) U  ^& O8 a+ T0 ?) W4 N5 F7 X3 q5 _
                #打印操作
    2 j9 Z5 K- `, Q; G# \3 h  m% q3 ]            epoch_loss = running_loss / len(dataloaders[phase].dataset)0 T5 W( g- w- W  [& W8 Z8 T& }8 d3 h
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)+ y7 S. C. u1 {3 W# p) }8 N& Y/ J( l* v
    % b2 J6 o# C% C! @

    8 v( J2 I- q# {# O- N            time_elapsed = time.time() - since6 J/ y- x& D+ b' V8 s$ _" F" v
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    0 G' U  c, p) _5 b# X, T9 R* t3 v            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    3 b6 v4 H/ h# i5 C- I/ j% ]0 W" }# y" j+ A
      s) n. l4 h; e7 j) U
                # 得到最好那次的模型
    + i2 g, o' N- _6 C2 J' H            if phase == 'valid' and epoch_acc > best_acc:
    8 h- P; M" ~* e  Z                best_acc = epoch_acc' f2 p% ?$ o2 [, @+ {
                    #模型保存
    - s8 V- r  ~0 s                best_model_wts = copy.deepcopy(model.state_dict())/ c( Z# g" S2 F$ r5 {. c1 \
                    state = {$ `) W  `8 `- ]9 G. ~1 R
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    . S7 e8 ~5 c0 _9 Y! m                  'state_dict': model.state_dict(),7 ~$ `" x7 Q: f8 Y. A
                      'best_acc': best_acc,: ?) O/ Y1 h2 P- T5 q1 ]
                      'optimizer' : optimizer.state_dict(),5 H6 ]9 z3 a9 `
                    }
    / M0 B4 B; N9 _5 {* |9 G- r                torch.save(state, filename)
    + A0 O& v, A4 B* I/ }' T            if phase == 'valid':7 ^+ G3 A; a' @5 l9 g4 z" d! m
                    val_acc_history.append(epoch_acc)! f* R& n6 c/ ~8 C7 s; b
                    valid_losses.append(epoch_loss)9 R1 @) ]( Y6 {/ A+ ^% ^
                    scheduler.step(epoch_loss)
    ; b/ ?, c# F& ~5 M' J7 J4 N( G4 c4 c            if phase == 'train':
    " j' {+ H- o* E. h                train_acc_history.append(epoch_acc)
    / B" m$ ^( G- d& z. A" p                train_losses.append(epoch_loss)
    6 ^. Z; R4 q! B& A* P
    / R* K% S  V3 I, S# Q        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))/ P/ c$ n  c7 r- m' K1 g! d, ^
            LRs.append(optimizer.param_groups[0]['lr'])+ j0 F+ e/ D  ^0 h) J3 a
            print()4 G, Q6 C. q/ @' l8 R
    ' V& a5 H0 Q- Y; R
        time_elapsed = time.time() - since. H* f: D$ e1 P' Q, Q. s7 M
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))6 B! o, y% o: n
        print('Best val Acc: {:4f}'.format(best_acc))
    ; t/ z+ Q# f: b7 O
    $ Z2 c: W4 h: ^7 W& z# y- Q    # 保存训练完后用最好的一次当做模型最终的结果% S5 r+ q1 F' }' M; [
        model.load_state_dict(best_model_wts)
    $ l, t3 T$ i$ F: y/ D9 R    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 1 s, _( k% B2 z, E% _& i5 m
    ; @  E; h/ f: x: ?7 {) e
    3 y$ O, O* O2 b% C- L2 R
    1
    $ ^- l) z0 z# v2
    6 w2 R* b9 k( O+ I1 b  v9 ]3
    ) B% D0 T$ n  o4
    - [, S: ^" t! D# t5# E+ v$ w' E. P1 c# x7 Z
    6
    * e# \1 v- _  S/ v, F71 `  o* D! O- \
    8$ o/ ~6 y6 x+ w' a/ H( }2 R, F
    9
    2 B- D3 u+ E* v' w% v10
    ' G/ j! }- D' Y" J5 f" F11
    5 r+ D2 d9 b3 q' S* O# T9 R1 c3 {12. P9 }4 v. s# o+ z8 ]. Q5 L
    13
    % l' Q7 U+ m7 C$ P) e4 _# T: J: [) v. e14
    ; c4 ~# t% Y, P# L- ~) G& [15
    0 r8 V( c' q1 _# K* n! j9 z16
    3 n( p7 h8 l( x$ ]  J173 x: D/ _$ u+ A3 \: ?- @
    18
    " }4 R! Y% k; J  j  E. r19
    ( D% G- Q! _# l# y# R9 d20+ X2 r( h- K0 ]
    21; P% {4 t- ]" w
    22
    ' Q: o4 o8 T  W23- f  c9 i' @6 ^$ f9 d" p, r9 j
    24
      ~$ Z: i. E7 ?0 @0 g/ E255 e$ `. X6 {0 j  U
    26
    6 l* }% X2 X9 y4 i7 @1 k2 M/ [27( H% `' ^# A$ }; _' s* g9 ?  Q% C
    28
    6 w2 W2 ^0 ^) C29  _2 y+ E" o/ G/ X9 Y6 ~6 G
    30( M9 O& k1 U- `/ n" R( y9 ?
    316 p7 w% ?& U. E4 o
    327 n6 `7 a* t8 ?* k) u- s( r2 @) `3 ~
    33
    ; I$ ^4 J" y8 n  f345 x7 }0 u+ H9 u2 }
    35# R$ g6 u( D+ |* K' O# J$ ^
    36
    ; Q& t5 z8 l( `9 n% |37
    0 F% s8 Q4 y3 H: K, ^/ [38) Y" l4 `! S5 P9 v. s1 _
    39
    3 [+ s, Z% t5 S. Y7 D4 B' Q- d40. J' T) \  H. K+ M# M% l, I1 j) c
    41
    & j6 n4 p' V/ C0 e/ a0 f' k  W: L42- d! {* \; |, |3 e2 ]# K
    43* {+ E$ C1 Q/ X
    44! d  n* x0 y6 S6 P7 ]$ |
    45
    3 T) e' O3 H: ?4 C, M46
    & z1 ?9 }3 N% A; N47" t! V9 C) [8 Q4 O" H5 K
    48
    8 W3 w; v2 i; _+ m& f8 U49
    + I5 ?6 G# M2 h- K% Z50& `; R. ~( ]" Z& i) w
    51
    ! f' y' a' q2 |; I0 C52
    3 E' ~/ n. o( a4 C9 a" j53
    . q" @7 E6 a  Q) A( k% D54
    : o" t' k6 j" \- t' \6 }55" ^% u) t: M# D1 \* z( D
    56
    # V6 m; ~  W1 N" I57
    3 U9 ]) f; ?; m- ]' U( F- T58
    , v; D2 w( \2 D2 `1 ^: _59
    * `+ `% I* R) S3 u. v" ]60
    ; b* R( ~0 K& v" @" G61! _# x5 Y' \0 I- {
    62+ j  K( h2 N; F) |' u
    63$ M- y. `' |9 A' X/ W& I- e1 G: o) U
    64& R* i1 B4 i+ m# \' y, [- i
    65
    : v* w8 |# l9 h6 U( ~) o/ N66
    , }9 }# U7 P( U3 |% W  T67
    / c2 z; K# x3 i" v9 x) i68
    ' \# z% \& W+ P+ e9 G69
    + T  \% J0 s  k% o4 U0 ?6 k& y70
    0 C8 h4 h. ^( ]( _0 `/ r71
    . l# V+ W1 E" I, x, }. A5 X: H72
    # l: l; J' D: |" \73) W2 t+ u# R, U- j
    74$ Q) \! u$ L% p
    75  c- \- b/ V9 q; a. \+ E5 m
    76
    ( `/ p8 ]. N+ K2 H77
    8 p! V$ I, f" d  C) B! Q0 B78
    ' x1 }- `+ i5 q) h$ A+ R& S* H79
    1 s( {- k4 _6 n% B80
      \6 g- N6 C' h8 T) _1 C1 v4 ?81* F8 V8 l4 p) x; s. N4 G
    829 x1 X4 R; u) r1 r$ [4 w
    83
    ) j* {% j) Z* T84
    8 Z! F- C( F  X" W85) f: q, H/ ^4 s
    86. n: i; M" q2 \; e8 t. h. H
    87. |+ c" o0 N# N* D; n: K1 y
    88& ~4 l8 f. |$ c1 }! }, v
    89
    + a0 G+ u2 `9 J90
    ! H" i% l/ g: _) f* ?. i, e91
    # w! L" F" v& U- i92) A9 g/ p* J9 }5 X! p/ Z$ J% c8 a
    93) m3 j, s7 |- p/ d
    94# ~( j7 Y. l" G' ~4 I/ H% c
    95$ F+ ^& b$ t& V" D# J
    96
    $ B4 P- ~) l8 }9 P7 I' f: L6 n97
    8 d0 K/ f4 T2 H5 v% h" C1 i98: u5 r- z" l+ @( y
    99# Y# Z$ o0 ]7 M9 q  z
    100
    , y" K7 r- H( r) J& f* V101
    , Y2 K( m8 d& ?2 e9 R: E1027 B8 C! J8 y& g( q
    103. ?3 L2 k) C( e9 f+ t: G
    104
    ' a" L* c# A0 ?. f105; l+ X% ^7 A" o* e' U4 u
    1069 t3 b. }* I# w' {
    107
    ; V6 W9 V, _# r! ~/ O+ Y. P108
    # p; L, j; d% j1090 S" Y5 h9 C; k( d5 N
    110% l  S8 ]; M$ L- p% ~2 w! l$ B2 R
    1112 L% d9 z( l& e, h
    112/ S4 `  D( K# I7 y" }
    7.2 开始训练模型
    0 {' C( l- {; g) d+ o3 [+ v我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次4 r2 B/ g' T5 q8 k7 Q
    ( ]; L7 U: V7 ]0 B
    #若太慢,把epoch调低,迭代50次可能好些$ z# V- Z1 R: ~& W/ o3 {" F) A8 R
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合( p  x) L' {% x8 U% I) _4 T
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
      R) p4 x. E2 H9 P8 ~
    - [% X! n7 g5 I1& Q" ?$ O+ X6 a) M8 X, r, u- H4 u
    2! X% O" X+ J7 |) f' y1 p& j
    36 u/ @  W3 v# F; C
    44 R' Y2 s" o: `0 B( x- T
    Epoch 0/4
    : q: y" A3 Z- Q) z----------
    ( N) V6 v4 g$ |  v" \5 P' CTime elapsed 29m 41s
    3 P) i! A1 P( R& z3 K" g/ U& h4 Ktrain Loss: 10.4774 Acc: 0.3147  `* J) q( i% x
    Time elapsed 32m 54s
    # S" ]) ~$ C- j* Ovalid Loss: 8.2902 Acc: 0.47194 Q5 k# m& w' Y
    Optimizer learning rate : 0.0010000
    & t; H, }3 T% A8 Y% i
    & @5 W' p) Q! O/ B$ A' gEpoch 1/4
    + `. t3 K$ o2 j5 [: r8 R  s----------0 V& s& u4 a4 m1 b9 i- M
    Time elapsed 60m 11s
    ! r7 n- w3 y6 D, Wtrain Loss: 2.3126 Acc: 0.7053
    5 l4 `+ L% k" d4 ~/ F( O5 VTime elapsed 63m 16s
    $ X, \9 Z4 O1 r  a) f1 J8 Z8 jvalid Loss: 3.2325 Acc: 0.6626
    3 y$ d) S3 m4 \5 p" ?; v! IOptimizer learning rate : 0.0100000" f9 \, a5 f! G& k7 J, U8 }
    * ~, Y: f4 c8 R
    Epoch 2/4' i7 E' y% Z5 ~. V# e
    ----------
    - Z# J  Z& \" T4 k2 K9 YTime elapsed 90m 58s! u$ u( w( ]8 X
    train Loss: 9.9720 Acc: 0.4734
    6 Y) j5 X! |! p0 v1 R- F% K/ LTime elapsed 94m 4s8 d/ s# D6 X' a7 o
    valid Loss: 14.0426 Acc: 0.4413
    & |* U0 D4 @) KOptimizer learning rate : 0.0001000
    % t! x! c; k' S% \5 d, k. o6 F
    % V$ p, B. m4 G: @( K* eEpoch 3/4
    " o' w9 f( w: e& a9 S----------1 y# t3 ~1 x# R/ P7 p- @+ r
    Time elapsed 132m 49s
    ) @% a- ^! }' O4 `+ {$ ^5 q- @( Xtrain Loss: 5.4290 Acc: 0.6548
    & W- w) E: j! p( TTime elapsed 138m 49s
    " |/ U/ G# o* d* l6 t' S( pvalid Loss: 6.4208 Acc: 0.6027) T$ ^# _/ n" D, ]4 X
    Optimizer learning rate : 0.0100000/ d0 h, a0 L, G4 v1 s. i
    . r! m- c& N$ ^( U
    Epoch 4/42 z- }& @+ O6 t
    ----------. e4 S- h' t0 P6 L
    Time elapsed 195m 56s
    " m0 n" m& x: x' gtrain Loss: 8.8911 Acc: 0.55192 s/ N: S0 F( f$ X. L; B
    Time elapsed 199m 16s! a2 y4 {' d; m1 Z2 L- z* _
    valid Loss: 13.2221 Acc: 0.4914) m9 c. Z$ N- y/ U/ {
    Optimizer learning rate : 0.0010000
    " r) d3 r( p. B; \# i5 a, R# {5 x3 j' W0 H8 s) A& ]
    Training complete in 199m 16s
    5 z  l+ p7 G/ o6 U  pBest val Acc: 0.662592! W3 E$ Z6 K; r+ v0 A% i5 n
    ! [/ y; ^7 s$ v$ u$ {6 z/ t
    1  p9 D6 P  z: V
    2
    3 W/ }1 H* u' L# S3) X  Z/ t+ m+ W5 @/ d3 s0 D
    4
    - J6 W! y2 D1 ]4 ^3 o4 z( p50 j) f3 A: d% ]: b2 K
    6  D0 m3 ?% g; @, k% i( l
    7
    - B2 t/ d5 p5 R* W2 \$ d& [8
    . r! o' O/ ?; a0 j9. f; B  S4 ]2 N
    10
      g5 P0 h/ G/ e11
    . M. V* ?% s9 w, J4 |6 G12# t# H# B" ]& ]- r4 ]
    13
    . H5 a6 S* S) O3 q0 o14
    4 O8 M6 P# L( I% }( @% _15
    6 h) d/ I2 a& S# y, E1 i( Z0 s$ w16
    8 z6 ~/ s) x- B; l, @$ X. @$ K17: v9 c6 x, n& n% F; J0 R4 ^0 x
    18
    . @$ b' I( q8 {190 d/ \4 }8 i, l) `% E0 P
    20
    8 R( D% r+ {3 J5 f6 {# g  ?21
    . J0 h6 T0 }1 Q& O- V8 f22
    ( q6 V1 M7 P0 p. R6 m4 ^, B23
    " b4 L. W0 s0 ]" P! V( I3 ?! e24. c& I$ C8 t# b9 C7 f$ I+ ?0 t
    25
    # Q1 y1 N1 l5 J' k3 O. z26
    * M& x7 a- ~6 a) H6 k3 m27
    % W) m. S" K; H28
    ' q7 I. Q$ ]0 Z; K$ K29" R* a# q0 o6 c4 W% q
    30
    5 J0 f, a3 l& {& {: L0 B31
    ! F! ~' u& w2 T% n1 Y  t% Z+ r32
    1 h1 h7 F2 q& q8 A. ]33
    * Q7 G! k; J7 p* \- W2 f* H" f34
    : M  J; N0 P) q35: U. r; @  |+ T
    36
    ) q4 M" H8 I7 I2 X) x+ V0 x8 R4 h3 M37
      g- Z" M: m1 ?7 |8 o2 y% L/ r$ }38) T! t2 B% \; e6 m8 |3 o
    39
    ) q. B! g  \. y, G, X4 m0 c40" N! l& L1 K3 w( f! C
    41
    : S, m1 K! ]2 w( I, ?42# O( A+ q; t. \" {7 r) B3 c
    7.3 训练所有层: o% H: P6 X! {; e# S
    # 将全部网络解锁进行训练
    # s  C! H5 |2 _/ T4 E2 Hfor param in model_ft.parameters():
    " x8 _1 I8 v0 F. K* \. @    param.requires_grad = True
    1 K4 S7 J) \1 ~5 e2 l" n3 L
    6 R/ G) ]6 _, F$ Z% J+ V2 S# 再继续训练所有的参数,学习率调小一点\
    6 ?' ]4 C. u6 hoptimizer = optim.Adam(params_to_update, lr = 1e-4)2 N. Y6 M6 H; p  D
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    6 H4 O9 i' C8 J8 ~: G7 Q
    2 n1 @* R: X/ \. F# 损失函数
    / ]& }) x$ j  r2 K$ m2 v. Pcriterion = nn.NLLLoss()' S7 e9 T& x9 m5 W- p& C8 h( S
    12 Y; [7 G  P4 n7 W  b4 T8 {5 k3 _
    2. _2 p9 P! \0 c
    3
    % ~9 t) q6 c) h1 ~# K4
    ' c: y  O5 g( H. i$ F2 L) x51 t# D8 B2 t9 X( X
    6
    ; {% X5 Z: m. |& ~2 b78 z2 o. J0 f9 R* R. ]/ v! _
    8
    # U/ a$ u5 I* A  K4 u* F0 Q) n9' z8 O$ G! H* i5 G
    10& @2 R/ T4 R/ Q8 A. F" d: L
    # 加载保存的参数
    , X) M) C. o2 \2 b/ v. t# 并在原有的模型基础上继续训练3 m: B& c+ R5 p7 E1 f0 O8 E7 Y
    # 下面保存的是刚刚训练效果较好的路径4 I. B( B2 C! j# M/ [* e' k  M
    checkpoint = torch.load(filename)/ q5 q7 z$ h! V# h5 ?
    best_acc = checkpoint['best_acc']$ U9 `3 ]; n& A( C" _; E, I
    model_ft.load_state_dict(checkpoint['state_dict'])& E# h, @9 Y. h& F( c& Q
    optimizer.load_state_dict(checkpoint['optimizer'])
    4 ?" H: n0 ?! q11 r1 J1 A6 k  I) J9 m# ^0 x
    2" T' N$ K" v: _+ ^2 h; ]$ w7 H
    3: E+ m, o( o% g+ N
    42 B8 C7 F& ]' }, m( u& }7 \* f7 H' u. P
    56 B! b! p; ?0 r5 s. s/ e) }: R
    6( {6 P& d9 c$ q$ K; u3 X8 C$ ^
    7
    # d$ r$ }) s; S. D7 s开始训练. [, v% I8 _( r0 p
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考2 d' r/ [+ K! U8 K3 p6 y, w7 {: B) l$ y
    9 z& d& O+ T1 q- U: B' \
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception")), S% H; Q: ]* c4 ]+ T1 g
    14 |6 g- F/ ^7 V
    Epoch 0/1/ W+ A# O- F, f% G- r7 w6 ]/ N
    ----------
    5 p3 O: g- r% w7 }- y- s& PTime elapsed 35m 22s
    , h- C3 {1 r" T/ x) a( z( ?0 p5 T* R, `train Loss: 1.7636 Acc: 0.7346" O+ V! j, G' ~9 j& |! l
    Time elapsed 38m 42s
    * z! K* W  k7 O/ l" c! vvalid Loss: 3.6377 Acc: 0.6455" \4 i! {2 _# k4 A
    Optimizer learning rate : 0.00100004 g  f' X0 m; C3 z% P4 [

      y0 s+ E5 v9 N) {9 j5 u; KEpoch 1/1& C& a: g( y+ M: K; `0 |+ d2 m
    ----------& S' o+ A! N6 b1 W! m( d( i( f
    Time elapsed 82m 59s/ D0 K7 s( f/ g
    train Loss: 1.7543 Acc: 0.73408 O5 }# }8 I4 z6 U8 b
    Time elapsed 86m 11s
    ! w# ?4 r8 V: ~1 F  y! Hvalid Loss: 3.8275 Acc: 0.6137( d0 F% ?: i7 }  E
    Optimizer learning rate : 0.0010000
    5 Q( p" ^1 F) q# E1 ?2 O7 |7 B8 K
    2 Y$ f. d0 j: p* S, u% LTraining complete in 86m 11s
    * B( a( B# G" p. Y1 o9 tBest val Acc: 0.645477
    ! {5 Z. Q& @9 }/ k* J1 `
    : a8 R; g3 F* o. c- \% r8 ~1
    7 I1 D7 u; z7 x6 P3 B" v/ |2
    ( k' A/ H$ m2 i/ P35 j/ _0 S6 J' Z3 H4 n
    4
    " [7 ^/ m! s2 S* l  L$ F" q% Q5, |+ t( S* X6 Z/ q% ]
    6+ ?5 Q7 `+ j/ [3 |' v
    7* |$ o' ]) B8 `, h8 `8 T: K; g1 f$ z! |
    8
    ! u3 q, t$ y; Q% O7 B' ]9
    2 U5 K5 g/ \) O: b3 \* ]100 Z6 V$ `. }, Z- O2 `
    11
    / s- t% n/ M; m3 ^12+ u6 z7 Y; o& _! @+ f  u/ `
    13
    9 K- C/ D  }5 ]# }& g6 [% ^14
    2 A# t8 n2 o" r% G8 q15
    8 S- K0 t4 P* e* r7 u16+ m8 a$ W9 q3 D
    17
    1 D8 x7 ?$ l( {0 Q4 K/ D( H( d18
    & i' C2 ?" R8 P- j! Y; p; f3 Z8. 加载已经训练的模型
    ! G9 Y' t* G- w+ j相当于做一次简单的前向传播(逻辑推理),不用更新参数4 ~0 y' Y8 M. d; O! G

    1 R7 k9 H+ W, Z& Y  Tmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)3 b$ h2 e, a& ~- y: ?$ q

    * e# |4 q8 s' z5 V: t2 `# GPU 模式
    + ?  J) N- J& b' Cmodel_ft = model_ft.to(device) # 扔到GPU中" h. a6 w6 Z( Y$ C' T- L  H
    % ~3 \2 K% M3 l
    # 保存文件的名字6 A9 ]6 w8 n1 y' H+ e0 G
    filename='checkpoint.pth'
    2 A# J) n$ W& h- Z
    ' c* L) `! r  i) K/ \# 加载模型% R" M4 ]3 Z8 H
    checkpoint = torch.load(filename)- A2 I9 Q0 F1 R8 t- H9 |
    best_acc = checkpoint['best_acc']- {$ z* @: d7 D! D
    model_ft.load_state_dict(checkpoint['state_dict'])- N% ?4 t4 I1 ]
    16 ]$ u+ s, v, Z$ P4 L# ^5 f+ h& I
    2, Y9 V7 U% c! o1 X5 D
    3$ i3 i( q/ d: n; G
    45 m& e9 T6 a! T# `/ o8 W, p
    5
    ! L0 J3 ?4 ]* ~, T* I- }* J6 }6* d! E' L$ ]8 D( V, K$ V5 Q: s( N
    7
    6 g9 v( d8 P6 M* f# j8. n0 `: N9 }$ j- q
    9" v4 M, ~& }# _) b4 x4 K  A
    10
    ; f0 x# V0 m1 ^' W+ j110 s7 e8 j% \9 k! v# L2 e
    12  U- a6 a0 @7 T! u
    <All keys matched successfully>
    $ b, Q# A6 \% I0 u: Z: q1
    + Q; M) W" v- a4 e% Wdef process_image(image_path):
    ' I5 o  I. J# I. {    # 读取测试集数据
    # d; @8 ]( F! a% L* N6 q    img = Image.open(image_path)
    * U; D6 k" q5 x. W& T    # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    2 o! [& q# i6 [" R, ]) n- v    # 与Resize不同+ ~: W  @6 w2 _! `$ V8 \9 ^
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    7 \# s' Y/ @( X7 t+ L    # 而且对象调用方法会直接改变其大小,返回None, r& K4 U& u, R
        if img.size[0] > img.size[1]:( N. Q$ }! @+ d" ~: s
            img.thumbnail((10000, 256))' a% }& s7 k# b! h. E% `. z
        else:$ V1 J1 ]4 t4 e* Y! W+ L* C
            img.thumbnail((256, 10000))
    4 h% G& K" ]  o  V: S, ]( ]* W3 [' v' r0 j4 D- @, C
        # crop操作, 将图像再次裁剪为 224 * 224, H  D8 Y1 h7 }& }" x0 {" @% N
        left_margin = (img.width - 224) / 2 # 取中间的部分
    % q! }0 N* O' G* f8 g' |  I( ?    bottom_margin = (img.height - 224) / 2
    2 [6 }% D* z- M$ g  s    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    ( `, j2 E& m) e/ G1 S) r  S5 `    top_margin = bottom_margin + 224' F9 V) S8 j; l7 y; v0 P- r

    : c1 |! u8 H* G: h- t    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    % O5 a9 Q- k( {+ G  E
    * r0 t& J7 f% t  z5 O- R    # 相同预处理的方法
    + _/ j2 O2 c% P    # 归一化
    / r- ^$ ^- S. z' O. s# `! c    img = np.array(img) / 255& L  P/ ?0 ~3 p2 f) R- o
        mean = np.array([0.485, 0.456, 0.406])
      K  x+ ?5 c/ M6 |    std = np.array([0.229, 0.224, 0.225])& L' t$ _, Q. q) e& b
        img = (img - mean) / std
    6 a4 \# w9 O$ y3 J( h& J$ E
    , X/ ]# U8 `" U: z" E    # 注意颜色通道和位置
    1 E1 T5 d% e. h* f# @7 h    img = img.transpose((2, 0, 1))
      C; p4 X- |* a1 M: z7 y
    1 m: H# N6 g- ?0 ]% e1 @1 h( R    return img4 g2 z6 Z% ]& C% t4 m$ q7 R' ~# s. C& n
    5 C( g" L/ a& Y# T- v2 ^6 z$ g
    def imshow(image, ax = None, title = None):: ?: i1 ?9 R5 ]! }1 v
        """展示数据"""
    5 W% s! ]+ @% z    if ax is None:
    3 o& S8 z' g( t6 O/ s        fig, ax = plt.subplots(); e9 m9 h2 V2 b4 a) }4 n+ |$ C

    : Y, v. h1 ?4 l9 @6 H, f    # 颜色通道进行还原$ [  y- F! A8 U. o9 }
        image = np.array(image).transpose((1, 2, 0))# x, `* R1 k* c5 V! p
    1 n( s) ~' e- S- y" c1 E0 f
        # 预处理还原
    8 R& p3 q  L8 D' b% U7 O    mean = np.array([0.485, 0.456, 0.406])
    ( o+ t) |' K* N    std = np.array([0.229, 0.224, 0.225])! p: K1 N& w9 O
        image = std * image + mean5 C* J& G! d% f
        image = np.clip(image, 0, 1)' b5 h& t+ C. u

    ; }& `' }4 G+ U- H, X8 D    ax.imshow(image)# R3 @$ }' [+ _& G$ Z) a
        ax.set_title(title)) M6 U; i  W( v& [
    8 t+ D* d! J/ R5 |. I
        return ax: T) L2 k" i& M/ d. j" \- B6 A

    # R- u: S# k& b. s9 {( t; ]% Qimage_path = r'./flower_data/valid/3/image_06621.jpg'
    * Y* \, v; r: j7 m, L% K" Cimg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理0 z% e% G* z& |. ?3 y3 |$ ~2 Q
    imshow(img)
    8 z& S# x, X7 k) T; ]0 j# T3 K% C/ `9 B- h) Q8 Z% a' u% A
    1
    , \% j1 F5 x- M3 Z7 G5 ~2
    , k6 M; w& }: f3
    " L7 l0 m  |6 t: S4
    + ^: H. N* _/ U+ S  ^5! r7 U1 D" F* }! C
    6
    ' t7 {6 z* i; w6 \& @& z7
    , h/ h1 t6 k0 M& l! c) x1 H- X8
    1 @7 `! @- T: b6 ]' K9" M2 [$ l' S4 z: \* U/ X
    10
    4 S4 q! m9 N5 y. w. e3 w3 M% j+ g117 R0 O! H. M; H9 x+ z8 i* ]
    12
    $ D4 s7 Y0 @2 w+ ?6 r4 ~13
    : V$ X! B% r& L* }+ {+ ?# n/ w14! r0 s2 x( L: ~& u; ~- s; i
    15
    & _7 V5 z! f3 u5 W. x16( v3 y9 D$ d7 D/ y
    17: ^1 ?1 ?9 k/ B5 `# k( U
    183 I/ m, `7 x# _& v/ g* `
    19
    1 R" v' f) x; h20
    * X' I+ u9 O: |5 h& J* e21
    ! F/ N2 Q* o. ~) C! e0 s22
    ' o2 _9 `" e1 j! o' \23
    , f1 [( c; ~% D! i/ P! O# P24& k- A/ U% H" \1 Y: T! [/ e8 m/ Z8 N
    25
    - j# y, h7 {, |4 N& M* }6 z  n26! N  B0 e3 R' |; i7 H1 v; t# k" x: Z
    27' M  w" p) [% G# s3 X
    28& c/ n2 k. P; E) l: C$ x2 q
    29
    : @( q* m; `/ s! k7 W30
    : U% b/ K3 d% t  Q' e31" b2 ~& k& R, ^# g/ y1 u
    32
    1 y9 }( Z. a# j' }& w33
    * l! u$ W0 _( J9 m8 ~& d34
    , I7 f5 z  Z' A; v, P5 [35
    ; l, @7 n- P' x" I, `36$ E( c2 W4 n9 h9 Z. d0 U9 Y
    37
    0 d& f$ o" c- T6 `( s& Z5 I0 |" I2 I38! e# X: _5 y/ \% a# b; g! X4 C  p
    39
    ! d* x$ k$ z0 w/ {401 T0 ]9 A+ K: R2 H" R
    415 ^9 d. Z1 O. Z2 [2 G9 u& k
    42: w2 J: A8 }2 t: ~$ r2 v6 G
    43
    & M$ N. @( j) o6 p  G0 f3 x44; L9 n( I6 q: W' t& H2 k  T' g
    457 Z- G( ?9 a7 S3 _2 {$ _
    46
    - N8 n- ~0 P8 t3 R47) J; q, C8 }6 e8 e, h" e* q7 H
    48
    1 {7 Q; v3 J- M& {! r8 u49
    ; v4 O: d( [- a1 ]5 K6 v7 n50! O6 Z3 W+ a2 K- s
    51
    % o$ z/ i3 B! ]# E7 \/ d6 a52
    6 f* e2 f$ R8 o) L# V53
    - f% S! r) z' o* d/ a) p6 u54
    2 O% v1 s8 d' d: Z2 w; I$ _<AxesSubplot:>! i  j* E$ {) c# d
    1
    5 v5 s2 y# m$ w( D
    2 `# v% m2 s9 A( ?2 _3 }上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确% m# _/ Y  [2 r- C/ V
      }+ b  l1 e- a5 w) M
    img.shape# M# h/ E# J) J; K3 }& u1 B/ N
    1
    : w) n" x6 M6 j9 ^7 o+ h(3, 224, 224)% L) Y* y; e5 e! }8 W  ^/ I( {
    1
    1 `- _; ]6 ~3 C$ e证明了通道提前了,而且大小没改变
    / V3 }+ P& Q8 Y* W- A- I, d; z. }6 o1 y% z+ j! t* n6 |& z9 S
    9. 推理+ }: S+ T$ }4 O$ p" j
    img.shape
    + ?& R1 c; m  K" a+ z- x& r; N; c& N+ m
    # 得到一个batch的测试数据. T9 `7 }! ]1 I# E" ^
    dataiter = iter(dataloaders['valid'])6 Z, h% f# X5 t4 W" N4 D/ M' R
    images, labels = dataiter.next()
    4 l9 g. F2 I) }
    5 U' W2 x: ]+ K' n( j/ mmodel_ft.eval()) s* [- \7 i/ d1 ?( g' \0 Q

    & P  W* g! n0 R' S1 X% ^8 ?; eif train_on_gpu:
    ) y  x* ~: b# {( T0 T0 e    # 前向传播跑一次会得到output( W$ E2 [) U5 d, h
        output = model_ft(images.cuda())8 F9 q( k2 K, w; M. C/ M0 M+ t
    else:
    ) B; v# ?" x4 T0 K4 w( q+ }& F    output = model_ft(images)1 B' t2 N# o4 f/ e7 _5 G

    # s3 B0 C9 L! d( g9 Z# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值: }$ {8 d% ^2 }6 W3 q- \! L
    output.shape7 a) E) p$ @3 W* L, K5 R

    0 @4 H$ c+ P/ Q) Q* l3 {% X# o13 p/ Q9 ^; b$ R$ ]! A
    2
    , C' o, u$ A+ O9 J2 y3
    8 S: O  [- S- l4
    : B7 ~3 N- M* K' s; h5
    + n- R2 z8 G0 _- c2 E3 s6) p7 u3 R' ]9 k( n
    7
    6 {* U0 G3 |& ?. L8 t% n# D8- d2 h7 v5 e" ~  |
    9
    * A! F9 j0 }3 \4 L8 o" O10- E" ^. n. v. _! C- U
    11
    " P8 R. }4 k9 W12% N# j  X+ h  U3 `( Z; f
    13
    ' W8 Q6 `4 ~2 w/ Y2 u; A9 m14
    & B, Q1 y( V7 y  T  ^15" W3 {$ I) v4 y9 ^2 a3 a
    16$ n. X5 f" R* j0 f
    torch.Size([8, 102])& H* h0 f; {3 F
    1
    ; ^+ v* E# H7 G- K4 z9.1 计算得到最大概率7 f" G- g6 j) X; t
    _, preds_tensor = torch.max(output, 1)& f# u( e& g  G# q* r. Y

    : G+ g9 v: v# v* [" K, p1 kpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量' C) r: Z/ u5 N. h
    1
    ) J/ y' |3 F  n& p9 ~27 a3 D& {* Z2 T& ^: A
    3
    . S2 D8 U6 u, c7 z& P! M3 q0 I& I9.2 展示预测结果
    ! y  M+ P: h5 U& M1 Rfig = plt.figure(figsize = (20, 20))+ V6 p, q  N: f9 }9 R( N
    columns = 4
    + Z- `, b% ~2 T( Trows = 2
    # t& u( a: e" n' z" Z/ @( X
    ' f  S. b1 D5 q9 Cfor idx in range(columns * rows):
    , F! J% u2 p' @3 J    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])3 \* E: o/ a3 l& F6 |2 b/ i6 L
        plt.imshow(im_convert(images[idx]))7 _. s8 L4 y3 c& d4 U4 v$ i
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), 5 s0 G4 |, _; n& O3 F, K
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    5 v/ ^. N3 ?3 y( M6 h2 @8 i) O1 vplt.show()
    $ k: S3 u0 z9 e( }# 绿色的表示预测是对的,红色表示预测错了
    ; j/ o* Z: h3 F8 N1
    ) A# p5 @9 T& \4 z7 y: S3 U% L* @" \27 I+ l* P  S$ D2 _" K% l4 P* }3 O
    3* L- H- s/ O& t: b
    47 i9 O# @4 [" [" _# Z
    5
    3 o- g& j! L) M# E# ?63 d/ n2 e/ j$ H) K2 }0 F& A
    7
    ; ], ^7 B$ a% _& n9 b8: |/ A/ x( Y0 s4 L; x6 y! @( _
    9) v+ X( W1 l7 k6 k9 g5 L" j1 o3 s
    10* V  I' _& @$ x- R$ H
    11
    3 n3 v7 a( @0 G! k& V" `0 m3 P( R3 z( h) @6 {8 @

      U* {2 M3 m; L7 R: H+ H
    & a' l  @9 W( z# _————————————————" ?2 c8 \2 i  G3 n6 _
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。! N. t: A6 P" c+ m( f" p
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    ) S; E  R: ?2 O( w* d) J" r: E( d- j" @4 X8 V% J5 Y

    2 u6 u/ s% e  Y+ ~9 d
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-10 17:09 , Processed in 0.351011 second(s), 51 queries .

    回顶部