QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2755|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    . R. x: [5 t4 X0 s+ q, s* D
    + ]5 |$ s3 L, c% J' z: _4 r4 Y文章目录
      B# a; P6 Q9 \4 e3 t3 |卷积网络实战 对花进行分类" x3 \/ m# b. H7 N
    数据预处理部分
    " B1 o1 V! H# J# \% u* A网络模块设置: h+ t  W/ f- i! I# {8 P  ^
    网络模型的保存与测试
    ; ~" Y5 R1 ~" ~* i数据下载:: N0 f; ]! p0 T  x, C
    1. 导入工具包
    7 H/ S" p' d! ~$ N! x3 t# H2. 数据预处理与操作9 Y8 G. ]' h- [$ j4 n
    3. 制作好数据源1 I( V2 |- P4 A
    读取标签对应的实际名字- g6 z( N' V. ]! n! \& a8 a
    4.展示一下数据
    ( u3 t5 K& o' J( \5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    . [9 I. Y' [# Y" P8 ^4 q! |6.初始化模型架构
      q9 Z' i4 \" K" ^- ~) S7. 设置需要训练的参数
      z; R7 z4 d8 E, U" c7. 训练与预测
    $ }3 Y3 j) H& X3 }# l$ ^5 [% X; ?7.1 优化器设置
    4 a; ^% _+ k1 m) w. f) @1 i7.2 开始训练模型! ]1 C/ J& ^7 j7 w9 M1 I8 e
    7.3 训练所有层9 W' B5 T( i+ f: B% A% u
    开始训练
    3 a! _: z% g& T1 v9 k8. 加载已经训练的模型
    / v( w; c' n+ q9. 推理6 ~; L% C! y6 u: Z* k4 w! R
    9.1 计算得到最大概率2 t7 a* W1 r6 q) d+ z( Y4 y
    9.2 展示预测结果
    # r/ U9 {7 k% p1 w0 P* k" x$ Y写在最后
    $ [2 y) Q0 s* e# g卷积网络实战 对花进行分类
    * R/ d/ ~/ x; I) a本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
      ]; _' x* |% s0 h  M  `5 @7 u4 H9 n1 c6 s0 t
    在文件夹中有102种花,我们主要要对这些花进行分类任务- P/ O; p( k8 K5 I5 a  o
    文件夹结构8 \# V6 z  B5 s3 N, B* S2 L+ L

    , I, T$ i+ _9 }) T( T# dflower_data8 {4 m$ e/ [$ P
    - E0 e  W& E+ @3 a
    train3 L! ]. k! ?8 l" x
      o3 H7 C7 P7 m5 u# s6 U7 t6 a9 k
    1(类别)
    & m0 `' Z, r8 \- [7 h6 k9 K7 q2. p; y' c9 G8 h! r
    xxx.png / xxx.jpg
    % p. a. m. ?* Q. `" ?valid
    7 g. _' O8 Q1 C8 I- F/ u' `. {0 U# m7 y; R- Y' ?# o
    主要分为以下几个大模块
    . I- {6 v9 n/ ~  i) @) k4 x. e  [
      [. ~7 k, d& `" ~数据预处理部分
    ! U5 t; ?7 h9 D" k, ?2 W数据增强
    * }( n) e4 b, o* K! t* r7 g) ^数据预处理
    3 a( o# g* I5 C网络模块设置6 b5 L, z8 w9 @+ i1 U. f1 M
    加载预训练模型,直接调用torchVision的经典网络架构
    ! F5 C- ^9 B2 Z因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务! U8 ?+ N  S, z
    网络模型的保存与测试
    7 A" G/ _6 B0 D5 M7 q模型保存可以带有选择性6 T! b8 X/ C3 G" w- Q
    数据下载:; U6 E4 I7 `1 Y* M
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    * B1 O5 @; z; \- M, J+ }) G! V9 b3 \6 o( s% Z1 ~1 d: `
    改一下文件名,然后将它放到同一根目录就可以了
    ! N$ c- v0 V& t: ]  V* k
    1 F3 S. q. x3 M/ `7 U' t下面是我的数据根目录/ P& @$ F/ G; l/ C
    . i$ C0 t5 Y9 [' p- u
    " |/ K& M+ w3 `+ R3 G
    1. 导入工具包
    ( d8 ]$ y8 x$ H  B7 ?+ |9 Yimport os
    4 f6 Y3 \# x4 x6 F5 oimport matplotlib.pyplot as plt
    0 B* a+ J/ ]4 W' v* m) @# 内嵌入绘图简去show的句柄
    6 ?; _) j- u! K- n- v  Q%matplotlib inline
    ( I, u  B0 U: m  G( X, y7 C  Oimport numpy as np
    $ r9 E  l* x# W7 N" ]- |import torch" W* \! w6 u9 m% {; n1 G7 R
    from torch import nn) P0 i4 c( B" X) B/ D

    # Y) z7 s$ }+ l. q- Nimport torch.optim as optim" ?' W: @. }9 F% B
    import torchvision
    6 j1 w& x) N. a$ k, C0 j5 y2 _( t9 yfrom torchvision import transforms, models, datasets8 r- x7 s; i+ W
    " N9 U6 w. W) k- @) |, P8 f
    import imageio  H% S2 l, [0 {$ i4 k
    import time. X5 z$ k1 N" B) G% ]- A
    import warnings
    $ b# ~' ^. v% g: N# Dimport random
    1 L* P, |7 @& v9 iimport sys
    - o: y* w! t/ A" r/ Bimport copy4 h: M( ]6 C) `( ?! l- Y
    import json
    : u" n$ f. U. O' }( R5 z$ zfrom PIL import Image
    ) X- T- Q5 {) Y3 I. k  o& A
    & L/ I- N+ e( O5 M5 K' c" L  D
    8 h$ x% s7 g7 l. i1
    - `: r7 E$ y7 N+ o- q6 }2
    6 w0 q; z& k. A% L4 P7 c+ u5 Q, D# j3
    0 V6 i/ n2 X2 ^- ^' Q4$ A4 N9 K. N. o  e# f% [3 y8 ]
    5
    2 U+ t# F! L6 F8 x- h  O63 p4 b2 d- Y7 k
    7
    ) B+ o; k- v4 r8 F8
    ! w9 `  }1 v- |: n- S! B# f9* x0 x' [" V) Y" [, f
    10; n+ ]+ L) \7 g) ]  f. L
    11
    3 |2 [" L: J" [* J12: S' p( @% ^& ?0 }3 K& X) p* ~
    13
    ; F7 o; X% n+ T# r& k+ y! K145 E3 r0 w* b3 ]1 n* C( U
    15
    8 c1 S: Q6 z) C  ?" l16. r! K+ q" Z" }
    17
    % S8 u$ I- H6 X* t! o6 _18# L( l6 m# w# ^3 s5 b2 l
    19
    2 F, J2 m. V3 [7 D20
    2 l: _4 A. J$ }0 T8 X+ B21
    # w  Z# z/ E8 g5 s; f" I3 G$ ~3 g2. 数据预处理与操作& w5 r8 u) s( l7 `
    #路径设置
    3 c& |6 X& c; S5 _0 gdata_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    ' E" v& u' J, ?1 z: ltrain_dir = data_dir + '/train': F; x8 F- d# \  q; s; Q
    valid_dir = data_dir + '/valid'! o3 ^" N/ ~% b9 @6 R
    1
    % x- Y! s+ l2 F3 N( P$ E2
    ' g! T- e; P! T7 E, s' u3 M3
    ( h& _3 ]( E& g5 |0 C3 p$ O4) A% A1 r3 k7 ~
    python目录点杠的组合与区别, \, _* {( ~- K; k
    注: 里面注明了点杠和斜杠的操作* j% B+ ~( [8 @
    + N3 P# p( E6 a( F, ~% j4 _
    3. 制作好数据源# S2 Z$ J  T( e2 W! S
    data_transforms中制定了所有图像预处理的操作
    + n. {, Y4 b; E' e# Z6 D' [ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片+ y+ C5 u6 S6 _: x* S
    data_transforms = {+ B* o6 L9 D/ d2 U+ k5 M/ T& `
        # 分成两部分,一部分是训练
    9 ^3 J  ~2 b' m3 h    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间9 ]0 }( C& W' H+ t" |* W8 X0 f
                                     transforms.CenterCrop(224), # 从中心处开始裁剪9 U8 u& K- y8 S4 \! M2 t
                                     # 以某个随机的概率决定是否翻转 55开
    " [3 g0 p/ }7 P                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转# U4 \& M1 j9 r  t7 R
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    / N& D$ z7 U  e! V                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相9 A2 \+ g1 K- b+ J: `, B& U. u
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),! L( \8 i, ~; l% N" k; M0 y
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB$ j  `( C. H( G
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的6 s& m, c" m* Y
                                     transforms.ToTensor(),
    & f) O$ _. J2 i2 f5 H" R                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    6 j& @' j# |% \3 V/ V$ M/ b$ m0 `4 o                                ]),1 v* u3 o( J: z3 r" H: G) C. ]
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    5 t7 ?5 \0 @& e9 ]  B! M" {    'valid': transforms.Compose([transforms.Resize(256),
    " T) b" B. ~% C) \9 y: S                                 transforms.CenterCrop(224),) e# [3 Z: o7 e) O8 s5 ^' r
                                     transforms.ToTensor(),% U* _8 d% R; r, [. W8 @7 C
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    * K+ ?, U4 Q+ T. D% ?1 u                                ]),( E4 U5 h0 o: h. s; r
    }
    " b9 H! s+ X: G' _* j
    ( J' v# k$ t/ s2 f9 z& C2 Z* z& G6 h1& Y" o( M& @% Q2 T$ I
    2( w: `  K9 z6 q1 H* l# K
    3% k: @0 g8 A% ^' a  D" r
    4
    - D9 p7 E* Q% Z, i, A2 x1 B0 s" s56 U, p: \$ E6 Z# j* S' @3 N, s" C
    6
    8 Q1 S: `% |1 C% k2 S5 \7
    - U, k( N2 d: e% M! T. m8; @. P: L3 j8 T; U( z% _: Y
    9
    " J* ]" d: Y6 q/ k( s( n' h( ?10* M  C, b( `5 j
    11
    9 E' N( `, Q( ~/ D1 z, J/ K9 y12
    3 R( ~; L' Z+ R, ^( Q13
    . m0 z1 F* {7 s. `! ~14
    # i& T# t1 f. v15
    ( @. c* j, y/ G, {. L. C16
    ' v/ s1 w; A3 f175 V+ J# b/ |4 \  c2 V
    18
    + `! h2 f6 K5 z8 U+ n19. |# Z' y( H/ u/ @  n9 @* t. D8 s
    20( c* M' @; }6 x1 I. }/ p
    21- W$ p+ E7 _( E( p, x* o
    batch_size = 8
    ; w7 T' t) s2 [" z5 qimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}: `* s: W; z% j% u: D
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}; O1 L$ G; w2 q
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} $ n" \# U4 J+ B
    class_names = image_datasets['train'].classes  Y9 o4 \+ Y. W7 a: \7 T2 v7 E. n
      ~. J4 K: L8 Q/ u
    #查看数据集合
    % b0 c) m: ]; {) P5 \* Oimage_datasets
    , d4 A  Q0 w3 z7 L- F' C! x( R4 m7 P) m* D; C! Y/ _
    1
    + C' `" Z  n8 a( @- r2. [3 T; w) J( m/ p( K  F" m7 T
    3
    " H2 d6 e- ^7 @, b4# A4 V+ }  A- y# ^% k6 R# L  |
    5
    * e; M9 k8 w( k, C0 R2 \  }6: a9 r. z! J2 e. V" A
    75 E* H' g; d9 H* T" U
    8
    ) Q$ w3 ~  j8 u% a8 f9# y3 z7 r' Q' _# p
    {'train': Dataset ImageFolder
    " @$ u& y4 v: l     Number of datapoints: 6552# X0 t$ O, N% J/ N* L: v
         Root location: ./flower_data/train
    % s- m# l( |+ ~1 [/ I& V     StandardTransform* l' I$ G* q* }6 r
    Transform: Compose(
    ( t& _) N2 [! t' t1 c; D2 b                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)$ x( W0 @& q1 {) ^& Y6 S/ K- o
                    CenterCrop(size=(224, 224))
    1 C% M4 a" i2 \/ }7 Z) R* Q) z                RandomHorizontalFlip(p=0.5)
      {2 s3 F: L" L                RandomVerticalFlip(p=0.5)# F' x& r# y6 |3 }$ c% D! x
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])7 r% {5 g. J+ G4 c" Z4 u
                    RandomGrayscale(p=0.025)+ m2 f  D' f% U/ q( l
                    ToTensor()
    7 \6 |3 I+ F5 D( }8 j                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      D* V  `; d! U            ),
    . S& H5 \% [- M7 w3 c0 H 'valid': Dataset ImageFolder: n5 k6 y' T& \2 s' ^. A
         Number of datapoints: 818
    8 b+ Q$ Y) S0 B6 h  I+ u" s     Root location: ./flower_data/valid
    7 ?+ f2 @& _' P& j     StandardTransform8 i1 w7 _7 j( A4 m# g; j/ |
    Transform: Compose(+ C6 A' u4 R) ]$ u: C( @7 Z
                    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)5 x* k- v6 F0 y% v/ C3 L; f
                    CenterCrop(size=(224, 224))- A& u( K0 u. c/ U' M
                    ToTensor(); }, P$ T  Y4 `$ ?+ R8 G
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]); W6 N: P5 Q5 ^0 b7 s% _' e9 X
                )}
    ! R7 e" d2 T1 A3 _4 A- Z- S
    7 b+ T+ D; g: d1& U% p) O' Z) b* H# M- O( l
    2
    2 w7 h" C5 M4 y3 F% u. {3/ u: P1 q8 m6 ]
    4( u) |3 w8 a+ h. B% W, P! I
    5
    . k% p- x2 d+ D' N1 w2 r' U, F' W6
    0 n3 Z( l" `# x' ?7/ m5 W5 l6 i* C; P
    8! C( k( [* C0 c" S
    97 W: K: v, T0 D; q, r7 ^6 F
    10
    ' S# I% {- I& Z2 y11! ^9 K$ ^& e4 d# [
    12
    7 `( k$ W$ ^2 O3 I3 h2 Z13
    , Q  L! L- W1 _! L0 q8 ^! {* v143 A4 B9 G8 L  M3 X6 F
    15
    $ `' X; s: @* W) n16
    6 A/ n' T4 z5 G; v17" X+ Z3 G: {% c% @% f  B
    18; q4 h# u6 l' n) B! _, x8 i% E3 `
    19
    # j; R: f/ |$ E9 C% h. x3 R20
    5 V* r6 Y$ e9 J0 K2 v5 a& x21  k, E. i9 m" ~3 q% t% X
    224 [  w; k$ r; E) I9 x
    23
    0 ~4 c9 Q. ~" M. j. R: y241 d( W: l. Q+ B
    # 验证一下数据是否已经被处理完毕
    - A$ L+ A) _- w9 ?; ydataloaders* Q" C4 z. S! t& b" z
    1
    " m5 q5 b# J, k' N1 b2 x. ~1 Q2* z! x% T1 L$ I  M: p0 O
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
      c; A/ k* k/ @9 { 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    7 i) c# H: q* q  z4 p1
    # L# B  n) ]7 f' H: L2
    5 i8 ^  Q8 G  V  D4 \/ J2 p; pdataset_sizes
    ( l, @% J- X) O' P4 T2 O/ T1+ M. ~# V( [) I- s2 C
    {'train': 6552, 'valid': 818}2 x9 c: q& J$ y# v" R1 S
    1
    6 L5 q  L7 \' @读取标签对应的实际名字
    ! x1 u, c- |0 R/ L使用同一目录下的json文件,反向映射出花对应的名字
    ; h( `7 R6 r% D7 w5 T+ i% U- K0 X, g" `/ c- ^# D$ H8 h1 |, H
    with open('./flower_data/cat_to_name.json', 'r') as f:2 X- }7 A- v, v$ ^6 Z3 Q, O, X
        cat_to_name = json.load(f)) R* W1 N! x9 ~2 C
    1! x$ }' M* D' [( y; l- u& ^) J/ }! B$ m
    2
    * i2 f. s8 [* `cat_to_name
    ; n- B* r% U0 d' y, R1- R' e6 Q. h8 i+ D. E$ D
    {'21': 'fire lily',/ a* }+ ?+ @5 Q+ V& T  F( b
    '3': 'canterbury bells',7 u* V7 t' I4 A$ X, P! e$ ^
    '45': 'bolero deep blue',3 E! ~: l/ _2 N. J3 Q4 q, q) z5 ~
    '1': 'pink primrose',
    - Y0 ?5 E) f' [4 q1 A( a+ ^: J: P '34': 'mexican aster',1 L: `5 L0 ]: O8 a* m
    '27': 'prince of wales feathers',! q2 b  S4 X* I9 s" e! }3 \
    '7': 'moon orchid',$ Z& t# k' d8 e0 ?# v
    '16': 'globe-flower',( @" Z' V7 j0 z; d  o4 R# l- \
    '25': 'grape hyacinth',3 X: h2 D/ p( b& K  S
    '26': 'corn poppy',
    3 P- Z3 T' S0 v/ D1 o '79': 'toad lily',3 s/ x$ f. f" m3 Y" Y! B. L, W9 U
    '39': 'siam tulip',9 }5 u2 l+ ^# u. K
    '24': 'red ginger',
    5 d" i5 Z- _3 @, X) Q+ M1 R$ C '67': 'spring crocus',
    8 w2 P9 Y1 B' x5 h0 W2 k2 O  X: N. Z '35': 'alpine sea holly',6 K" ?; a. _# N$ Z
    '32': 'garden phlox',( k8 J4 Y- G, k8 P9 _
    '10': 'globe thistle',
    5 R' R  Y* ?/ `$ o0 Q& z '6': 'tiger lily',
    7 k3 v6 I+ n& M. x: e '93': 'ball moss',
    4 B1 _; c4 t; N1 H6 P0 \ '33': 'love in the mist',( d2 |1 T5 M5 ^# ^, N$ T8 d; s
    '9': 'monkshood',
    ' f; d* F9 d. n  o4 X# [' m/ O: U '102': 'blackberry lily',
    ' D- Z5 Q. V6 f: e '14': 'spear thistle',
      i( T3 D& D7 i! a$ x0 N& O6 S '19': 'balloon flower',
    + Q" q  P" a( h3 K0 C '100': 'blanket flower',6 @& j8 E& s+ |/ ]# |) p
    '13': 'king protea',
    4 e: h) l8 l+ O '49': 'oxeye daisy',
    7 g" ]% c! C$ F6 E& S( N '15': 'yellow iris',* T8 ]1 n, c: \
    '61': 'cautleya spicata',
    1 l: n$ s9 [2 w8 i3 H: B9 f '31': 'carnation',
    ( a7 G; V+ f4 R4 }& T '64': 'silverbush',
    , g" t( h4 _6 }+ G '68': 'bearded iris',! y- ^- R9 R0 B& V# d# d
    '63': 'black-eyed susan',! s& [5 C  P% W2 S5 x/ K' {3 |: @. j
    '69': 'windflower',4 [. J0 M" q$ A2 I+ T/ i
    '62': 'japanese anemone',  {% v1 z2 p, I
    '20': 'giant white arum lily',
    4 s: p, j, [9 @7 o '38': 'great masterwort',: M, _2 ~/ l  p% X& B/ P
    '4': 'sweet pea',6 [  f5 F3 i' s! g" m8 t- O
    '86': 'tree mallow',
    - ?9 ^. q4 H& b. M+ Y  C '101': 'trumpet creeper',) \4 _" }  W! f, |/ [% S! i
    '42': 'daffodil',( W$ |* x! A6 o5 N& ~% H# Q
    '22': 'pincushion flower',
    $ \, ~" J9 B; h6 S) t; D, p5 \$ ^- @3 O  t '2': 'hard-leaved pocket orchid'," c  d, S, C" e& u/ S) y
    '54': 'sunflower',/ W7 l7 j2 \/ `
    '66': 'osteospermum',
    0 W. x  P/ I4 O; i' i '70': 'tree poppy',
    ! o" c3 @  f7 D; Z# U) I7 ^+ e '85': 'desert-rose',
    4 z: l/ G! [8 G/ T- a5 k  D$ G3 Q '99': 'bromelia',
    * b% @* ^, M2 X9 b3 j '87': 'magnolia',, q& d2 B& b9 @  Q/ o4 Y$ _- a$ S
    '5': 'english marigold',
    ) q3 K8 E1 ]2 f, P# m6 t- Z/ ?! v '92': 'bee balm',) D4 e, q5 A/ U  \8 Y
    '28': 'stemless gentian',# ^2 w% ~7 R6 \
    '97': 'mallow',8 c" `3 f& g* G+ K2 A
    '57': 'gaura',
    " I& [3 v" ~( |; C2 F '40': 'lenten rose',' ], z: {2 U  |* L/ F
    '47': 'marigold',9 `% B$ a. Z. }) L/ q$ B
    '59': 'orange dahlia',: ], v" ~3 U) V, [* I% ]6 `6 K) \
    '48': 'buttercup',
    # j, M0 K' E$ ~( w '55': 'pelargonium',; K8 F2 \3 y  C" J& f' {9 P( U& v/ x
    '36': 'ruby-lipped cattleya',
    2 _5 K; x3 Y+ v6 B2 w' n# [ '91': 'hippeastrum',
    ; |$ C/ D" w8 w+ @1 T '29': 'artichoke',. U. l* {6 n$ W+ z5 W) d
    '71': 'gazania',6 a7 F5 i. S4 [' g8 y  w4 g
    '90': 'canna lily',
    % G; @, H/ ~: E$ R, j# ~7 s '18': 'peruvian lily',
    1 ^: z  k5 i* g- U% S '98': 'mexican petunia',
    % h2 ]8 T5 Z( E# A '8': 'bird of paradise',
    % o7 Q; Q$ D6 C, @6 W! t '30': 'sweet william',9 B1 g8 ?8 D7 y- F) I% Q% m
    '17': 'purple coneflower',/ l5 \  f$ j, w4 `! V. L
    '52': 'wild pansy',
    ) T/ D. M) X! ]; L; P0 H$ {% B! ? '84': 'columbine',9 @7 q( Q/ B: d; P" a# [
    '12': "colt's foot",
    ' V2 ~' k2 n4 T) d' v! ] '11': 'snapdragon',9 }: i& }9 G$ O* I( l
    '96': 'camellia',
    1 C1 Q+ {+ |/ F9 e! I, { '23': 'fritillary',. }1 k' O6 |5 D. {! b
    '50': 'common dandelion',
    ; _* I: i6 D5 z2 H) H' {+ ~; ] '44': 'poinsettia',
    : |8 t$ C' k# K" _ '53': 'primula',; s# U% @5 J+ V) `& y
    '72': 'azalea',
    5 r! L" W8 o# D8 Y0 L% {1 P- G '65': 'californian poppy',. ]% q: j3 R1 O3 X( F4 w7 T9 L
    '80': 'anthurium',9 l% R+ [- ~7 D/ X
    '76': 'morning glory',! s! t8 ^% w8 N( W0 z
    '37': 'cape flower',: {, @2 x2 t& [3 J) J
    '56': 'bishop of llandaff',$ c& o: Z6 {% _4 Z) `7 i
    '60': 'pink-yellow dahlia',
    ' L9 V6 I* s1 m7 Y5 d/ Z '82': 'clematis',
    ' B; ]* L: B$ g9 u& u '58': 'geranium',
    " g, K0 V9 F5 T3 q& z8 o '75': 'thorn apple',
    4 h" k! l. V; m9 A) @ '41': 'barbeton daisy',
    3 f1 _$ L; q$ Y. t '95': 'bougainvillea',) P! j2 \: w/ ?5 g0 c
    '43': 'sword lily',
    4 M) Q8 ~/ T/ e( O) i3 ? '83': 'hibiscus',9 ?: @6 q5 V  \+ T, C$ W
    '78': 'lotus lotus',
    1 G' F- P1 }6 B9 b( g '88': 'cyclamen',2 t. P3 p3 ?& u) b9 n
    '94': 'foxglove',
    / D; _/ ^. z" W, D- U! A '81': 'frangipani',
    & h3 y% {% w, {3 I '74': 'rose',
    6 M* v0 J& A' g- y '89': 'watercress',
    1 C. R2 @4 {$ t. p& c/ ~( t '73': 'water lily',) I0 t$ I+ p8 h, x. H4 y# o& Y
    '46': 'wallflower',
    % v7 s( o5 ]9 X9 L0 S+ ` '77': 'passion flower',4 V9 G/ \, \6 m1 N2 |5 H9 \; \
    '51': 'petunia'}
    ) d4 [/ |2 I" B) e# E0 ~4 X! J" B) W  L/ I' Z& C
    1  O+ Z; S+ z; y- h, |" u: h
    2
    / U) S% E$ P1 R: Q3
    ' |) U8 A+ _0 Y8 E44 f3 U1 ^$ P- M/ k- t
    5
    ! W. A* ^7 Z4 j3 X; o6 k1 ]6$ N- Z" s# z0 b0 `" e# }
    7
    ) ~% a( z( q* e0 J8
    + s( W$ m0 t' Y; g9# C) Q* w% d$ ^
    10  m. x8 P* [, w1 a: Z' `
    11+ d' G% m9 C2 y" T: S6 t+ d8 I8 O
    12
    * n3 j6 J- ?# B# D* U" `7 J, w13' C; @  e/ Q2 t. c8 P! D
    14
    8 Z  D- i7 H" S: n0 R0 }; Z6 d15
    - O3 _( ?8 k5 k4 \16
    $ K/ [3 g/ F9 j7 a8 O( Y17
    . a+ Z& {  G. F  V189 b6 ?% n3 x8 L$ k
    191 Q! V0 K, D2 [5 `; n
    20
    ( @, Z- I- u  Y# q, w21
      }2 v8 I  y% {4 x9 M/ P$ t22
    + i: s) f8 v  x- p  @5 L23
    + Q; y4 c! X! L% y245 b9 o, W2 a+ O5 d
    25' R' R( Y& r; G8 e$ H' ?
    26! O: E& g; T: F4 n! y2 U
    27
      O# X, A: Z6 e& X/ \28
    7 S: v6 P' G$ M: r8 C* z29
    7 B: w) R4 I9 j& E( o30
    4 `+ l& ^( e% F8 `) P+ f7 }31/ G; v3 ]5 _0 v; [
    32
    ( P5 y( |1 K' w  S! \8 e33: K* J4 R9 P+ M' j+ W
    34- K) S7 j8 J; z& i
    358 @9 F/ A& ?& t" i- s' d  }3 Y
    36. S7 I9 ~6 a2 c8 V7 @
    37: ]- o3 Q4 ~/ T, s* {
    38
    ( [7 B6 D: W  ^+ I; I+ N, d397 [5 E% T0 E- s  P" q
    40
    / a1 G4 @2 W, k5 F4 ?- f  I41
    : Q. n" ?. }, Z, [% a42
    3 h# F3 h! _0 N; p. C43
    % E4 g2 o. A' t  n) _44
    ; w. n$ Z  G8 A( X. |455 Y* V2 U" j. v' ]" y1 ~, @3 r
    462 _& H7 G5 [8 d" q! C$ B4 _
    47
    7 O( |( A! v, b/ O48
    ; x  j8 m5 Z* P# w4 q: q491 z- j: C" N* U) Z3 @" |/ k& T; B9 B
    50
    0 ~* \& z, t* J  n/ f( @512 A+ c) h5 y0 q& [& q- Q9 [$ x1 r
    52
    0 O" _; V7 I/ j" c/ Q53
    ( K( P$ V, b9 p' M54
    5 n9 M. [8 U0 @55# \2 q5 d; m  \- M& _
    566 g. _* W# h3 |0 a3 ~+ @
    578 j+ c+ m8 G1 _2 ]
    58* a" o" T8 n9 ^' Y1 ^0 Q( v8 O
    59
    ! K# N+ ]5 p! B9 R606 s3 D& W( R! B# p0 p0 M2 c
    614 w" |9 m* ~9 |0 a3 |
    62# Y- ]! O9 i# {, U
    63! Y9 C( C. n8 h; Y& ]2 i0 y
    64; [6 z: X$ @4 }
    65" W  H$ P4 n0 [' K' n: H/ w, b0 V
    66) E: [. Y+ c) |! |( E8 g  `6 _, P1 J) f; R
    676 `- R3 j- r/ x$ f7 u; A2 n
    68
    4 F( w6 t# V( e$ |6 i$ S69# X$ t; A+ F) w/ M7 _
    708 y8 `# D$ ]' c" f* |. N
    71% K' p2 a1 W& d8 _) m* I+ l9 h8 i
    72; [+ q2 F3 N! X7 e' N( d  ]- p, Q
    73
    7 \0 `' ]* {. B% g5 r" q5 M0 i749 S* j  g, x3 Y0 F, |, g; Q
    75
    - f5 w  o8 L% G% u0 ?" ?# D* O4 h766 i- x) e  W+ A7 ?3 }2 m# _
    77. O6 X, w! V4 U' [
    78
    6 t2 T7 i- H# K3 r# F/ `5 O& [6 c79
    , n3 ]6 O/ N  I& M1 `80" O2 Y  h" A; D; V8 P: g8 k
    81" {& Y) _: X5 S: G5 Q8 Y. y& ~% n* L" {
    82
    $ S* K! k8 L, k4 S8 j2 R& x( H7 m  D83
    - `* |+ G% N2 _" V84: W% ?# z. \* p! Y+ o
    854 Y' ?6 _4 f' c1 b/ r$ _0 ?2 O
    86
    6 u9 b# C: o3 S! s; e87
    1 ^9 |3 ]' Z7 J2 W( `5 ^7 e88
    * `& \5 }3 Z: Z2 }! e$ n% ?89
    . U  J) r% m5 T8 g+ a7 t/ l; f903 P: \7 Q1 G4 Q( m5 t* g) A
    919 j+ u$ A$ @6 S2 H" B
    92
    . ~* x- u# `7 c6 n( _93& D  G$ T5 I7 S4 A& i  b; z/ a
    94
      v( i1 a; z2 V* {95
    , c+ i2 @) K) |0 l3 i4 f/ ?96
    4 L/ p9 E% u+ S, S$ p/ l( M1 v97
    9 V8 q. [7 U' O3 S3 O% C8 Z98" d# X! w% O  T* k! w; P) D
    992 s6 E- r7 X1 i, q7 P4 [
    100
    3 L$ J4 t0 U  ?& d# \/ E6 q+ a101
    ( ?1 S1 D2 X  r1023 P6 `* [  a" c3 q& [1 N# M7 |* t" o
    4.展示一下数据0 E  p7 g( E6 L5 b/ B# h& \, v
    def im_convert(tensor):; v7 \5 n  W; ?
        """数据展示""". `% h, w! v( N" W; J2 n, Q; Y
        image = tensor.to("cpu").clone().detach(): r) Y1 i# r7 i" Z
        image = image.numpy().squeeze()  S5 N! T" c. Y% L" o8 [
        # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    3 M5 [! a  p; C9 f    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)7 |  l0 i) ~, C+ t, _6 x; C
        image = image.transpose(1, 2, 0)
    9 I0 g2 _$ K  W) b* w    # 反正则化(反标准化)
    & ]! U& n( G8 k# w; N) E    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    5 B4 I3 X5 [' Z& L  `3 u  P2 t; o8 g1 i0 M" k3 N
        # 将图像中小于0 的都换成0,大于的都变成1: k& N; F6 f& _4 q4 T, X
        image = image.clip(0, 1)+ h  n1 g1 h. T+ K
    ; |  x$ ~; @6 n! ^& w% t, {. j( [
        return image
    1 _" A" R- T7 u+ M8 M$ N1: S" Z) A% k5 k) t
    2) [  L7 i+ `! S
    3
    % P8 d/ q% b% ]4% i/ }$ D- Y( P% K0 b
    5
    5 N& o" `$ c  Y! J  v6
    % _9 z+ j- f3 {: V& U7) i' N7 _/ }5 B
    8! j5 a$ i- y/ B0 H: T
    9
    ; Y6 l. p/ \- A3 [& C10
    . y  p' M/ O, I- ?& C2 G118 P1 s) V, G# K
    127 A$ l& [$ @0 R# v; k
    13
    : b. C1 e/ Z- o: G$ f# @5 \14
    , ^8 I$ m5 x, F. m, Y- k* m# 使用上面定义好的类进行画图
    8 |8 K# d+ e% t1 j( j. H$ B0 D: Lfig = plt.figure(figsize = (20, 12))
    5 F* o" U: o4 K+ R' s: Y+ zcolumns = 4
    # B. `" D" h! w5 q8 G6 L9 L5 ]7 a0 v5 Grows = 2  G# ^2 p0 y1 \$ h

    0 V: y& ^1 z1 Y3 D# iter迭代器0 e1 r" ]) v6 C1 a
    # 随便找一个Batch数据进行展示( V6 B, `' t' T% k/ e' M9 `0 F- L4 U
    dataiter = iter(dataloaders['valid'])
    + \: \. k! Y7 l. ~2 k, Y" R% Vinputs, classes = dataiter.next()
    ( Z7 |& A# c" L( R3 }! q8 ?7 y6 B) v, F; _/ }$ ]
    for idx in range(columns * rows):
    + R& U( Y& g( H+ F    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    . J% J% X, Z. a5 I    # 利用json文件将其对应花的类型打印在图片中# Z, X+ B  G, ]2 C; ]
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    8 H" _" m4 J$ f) k+ c7 F8 v    plt.imshow(im_convert(inputs[idx]))) Q/ |! b, o: Z9 G/ Y# d4 q
    plt.show()9 i+ s: r: M8 p$ l/ y: q: e
    . T5 b1 E' ]% L; [5 a; ^
    1
    / o1 ]7 [: n0 V9 q* T9 M( x2
    - m# |  ]( R/ k. z* B! n, d- u3
    # t2 H8 H/ b) }$ s7 k1 J: _6 d4- n6 L9 m- m  Z( A% q% P$ Z
    54 S( e4 c3 {7 \7 q( A9 ~
    6: s8 J! I' b! A6 m8 M1 g# h
    7
    0 S( e' ?9 Q. l- W, l2 K) o  _$ N! t2 k8
    ' Q4 [& ]* @! S' G' u* K4 {9
    ; B$ S* f7 E" W10
    # m  \6 O* s' p, u- F# C11' q4 t0 K" r; j$ d4 V, V' Y+ [/ }7 C
    12
    5 b) W6 ?" ]# ?13% d$ ^1 t; n" G3 g( t
    14, X) h: c- @: f# I! \2 V
    15' y. u- y/ z. M
    169 ]" }4 ^1 {3 t! U3 y8 v
    & H; h$ i% O2 @$ u2 }; D% \

    0 X" x& A3 S6 t( \/ f* B5. 加载models提供的模型,并直接用训练好的权重做初始化参数( R, j5 ?" @; f
    model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']! w' A: B$ t1 g6 [# v% B1 v2 o# j
    # 主要的图像识别用resnet来做: |" D4 l& N  x- |, r. X0 q0 s
    # 是否用人家训练好的特征
    7 M3 S; l3 u; b3 E* E6 q0 J2 xfeature_extract = True
    ; p: [7 G5 C7 A% w1
    ! }: @' @7 ^/ m0 X7 p2/ Q: A2 h) r0 l
    3, K5 ^  f8 g; I' Q) \' z
    4
    7 {" _) n& W* o/ f( ?7 j# 是否用GPU进行训练& ^/ ?# O0 H0 \+ e% m7 G5 m
    train_on_gpu = torch.cuda.is_available()( O2 L) N( u6 _* u  @, I

    * i- x# r1 z: p# @7 `if not train_on_gpu:
    : ^+ J- t9 d  j: h, Z- Q    print('CUDA is not available.   Training on CPU ...')
    ( y0 E. d* r* Lelse:2 o  h$ p/ t( o8 t+ ^% e! t
        print('CUDA is available! Training on GPU ...')
      |% S1 ?" ^) g' Y+ B  }
    : E& h$ n/ ?6 [) z8 Ddevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    : J  ]4 w9 m2 b' B+ p' o1
    / Y* u2 e3 N+ J) E2
    ( ~  O6 M& L/ e6 e  P. G; I. `) a5 j3
    - p$ p, m; f5 Q4
    3 Q+ B3 _) g# Y9 [# a$ d. C4 l2 E+ Q5* h5 ~7 Z# i) f5 [
    6. o" s# N" H$ ]8 J" L8 R3 q
    7
    + ^+ U, ^( O8 g: l$ a1 U5 O) C8* w/ b$ @% M  e9 |+ `
    9% @9 g# M3 ^; m( w5 k+ \! R( F5 d4 j! u
    CUDA is not available.   Training on CPU ...+ L3 c5 z8 p. L2 a
    1  z  _, u# O3 T: n0 h
    # 将一些层定义为false,使其不自动更新
    : I+ i2 T* \# R+ P" K5 @def set_parameter_requires_grad(model, feature_extracting):' n7 h7 v8 x! m0 [6 k
        if feature_extracting:5 g7 Y. a0 ]# P( O& r
            for param in model.parameters():( ?, N- |7 a& [/ T
                param.requires_grad = False& W7 t0 w& r  v9 ^( o, [
    1# O' J, M. f# G% m
    27 e, \3 _# l9 N3 k2 Z( S
    3
    : L) R8 L  ~; Y$ }( V4; }4 c2 ^4 I& {; B! j
    5
    " H) u- [; P# ?. q  e* s" O# @1 R# 打印模型架构告知是怎么一步一步去完成的
    3 G% J6 m( C+ Y3 b; ~# 主要是为我们提取特征的( Y+ `- L% Y$ T, N% E9 P0 ^

    1 w6 J! e& i; R5 L( \model_ft = models.resnet152()
    & `2 P& O8 S; Vmodel_ft2 M; y, c9 T! |5 c! y+ @
    1
    0 ?+ D5 ^, b( Q3 Y2 G& o2- r5 q* V" g; c7 X
    3. s+ S* T6 O8 x: H
    4
    . W5 ?$ q" Z  }; X! f5
    * @8 e/ V$ ]& c$ i3 jResNet(
    5 B* @( j8 `! W2 y1 J& o  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)5 G1 \. _+ H/ W. D* |# [: w
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    9 ?) @  r9 s- _$ P/ h  (relu): ReLU(inplace=True); ?" D1 f4 L) p3 l
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)5 t8 \( ~7 u% J, _5 n9 z% X7 y+ m
      (layer1): Sequential(: \/ h4 o, _: X8 R* i/ u, n. y0 z
        (0): Bottleneck(4 S1 h' D! M# O5 k- [
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)- V! N" \& o, j. Z% F7 N4 u
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)8 m7 d6 T; y0 Q% A5 j
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    * i& \1 N8 [) S      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    / p" P  s. H8 f& z/ u      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    # x7 Q' J) q2 l7 {      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)* [  m$ Q! ~! ?' p$ B( I/ m; f9 h
          (relu): ReLU(inplace=True)' m* p1 Y( Y7 H+ ]1 q1 {4 m; a7 W
          (downsample): Sequential(% H3 F, F$ k  K# c0 F$ o# I1 v8 o# X. i
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    + ?" r/ A/ L; N7 X, }        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) m- D! n( x4 L# s- f0 z4 c! b, ^
          )
    & P0 L3 y0 K- v9 T    )
    2 [: v# [/ P6 R2 q* ?中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。. @9 G' M+ i& Y4 a0 p. N$ a3 |
        (2): Bottleneck(
    ) C* K7 c7 E# S) P* B      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)3 ?' `) H3 t+ H0 E
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): e4 s5 }* @9 }4 \$ Q$ z0 j
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    , j% Z, i8 y" a& H      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    * Q' K  O: ~8 J7 V3 [) O8 v' l      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)8 o5 D0 Q1 p: L& @
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)9 @) Z2 A8 G5 f
          (relu): ReLU(inplace=True)* A" j3 y) H$ u$ P
        )" G. b) }' v$ m
      )! R- e4 D5 O5 s
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))* G' @: ~% R- |- P( Z7 a1 u! {$ w, L# N
      (fc): Linear(in_features=2048, out_features=1000, bias=True)! |# W* g2 j- h' |
    )
    * {+ n1 \# x& X" A3 y; z5 [" u& O
    6 C3 Z7 T/ z( u. i9 m! \4 F+ s0 L1- a% b2 |/ ~# S) L: B: b- E* b
    2; A3 r! _# f) N" U- ]3 q  Y
    3& \2 U9 o$ k- D  A# C0 Z' v& ?
    48 d5 r& j" G- [4 L; j/ `8 i; g* w
    5; U$ V) }' O. P% d+ a7 q
    6( N6 F! K# I/ [, N' M
    7: r, f( X. V! I& L4 q) N
    8
    7 p- q7 b$ K9 W. a0 c92 y+ V+ I6 E) U
    10
    5 n1 R  w. D2 ]4 O11
    # C0 j; H: ?: r1 p, Q/ ^12
    0 `# l: N$ z! y, G5 y- Y13. n0 L# ^/ z% r- c$ w
    14
    ! e! Z2 Q5 s+ }) Z156 m4 Q5 I* y0 i
    16
    4 z$ O  n3 g5 Q% @) f- y8 P1 l17
    + P1 m" t6 Z: u% Y3 K( t18
    / ?9 i( P, z/ L* m; \# ^' K% D19! R* w  s2 W& b! @/ \% {* t
    208 d* n) e: I) K  o. \# f% q' m
    21* f9 u& h+ w1 |; C7 P1 a
    22
    1 {. N& d9 _1 U) B' }23
    $ V8 g# h7 Y/ V: S7 S8 b1 q24
    % m2 L0 W: L4 n25) p; o+ Y7 V1 A  k* k
    26) Q% N7 k$ c% j6 n
    279 s% G  G) T" \3 @, s
    28
    $ w3 e( C% J& S$ u1 V29
    5 L: `6 r& [: L6 Q8 Q+ T# B30
    % S2 i2 f" r- C- D31
    . W3 r( r* B; L3 _! z32
    8 u) S! G( L' o. [  u337 B% a" C( E' D+ k" {6 N2 p$ p% }
    最后是1000分类,2048输入,分为1000个分类! O6 W3 p' T/ {) J8 N8 ?
    而我们需要将我们的任务进行调整,将1000分类改为102输出1 ~3 c5 c, N4 F9 i/ f2 i) `: a

    3 m" \; }. m/ e  J4 G; E  N6.初始化模型架构
    1 w+ r$ t2 \& T* I步骤如下:& |* f/ B4 a  Q+ w6 ?7 f7 |( p
    - }' Q- Q( r4 l: g) b! R
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数2 \+ \, M; M6 H  c0 h
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)* ^5 N, O2 |" a  a! y% @
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数" i3 d0 B9 u- B
    官方文档链接; ]3 b8 n1 Y$ D% _9 `
    https://pytorch.org/vision/stable/models.html
    % T2 j7 z6 j) i' K5 o. H
    , J  }5 u! j5 n! m3 i4 Z# 将他人的模型加载进来+ S. z. J0 |. S
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):" K3 \; p/ Q: r" M1 `3 J3 h4 S
        # 选择适合的模型,不同的模型初始化参数不同
    " i  L% q4 B1 e/ P$ z    model_ft = None
    ' Q' ^$ P' {! Y& ]    input_size = 0
    + M, ?+ s' h- g/ B% f/ Q, C! a( Z: L! f9 w$ d
        if model_name == "resnet":; e; [3 ~# k4 h* [
            """
    2 s7 e, M) E/ g9 o4 Z% v+ x: Y        Resnet152
    2 \# H% r( P$ T5 m        """
    # X. ~; d6 n8 ^) c5 X( ]( @" e# u9 c1 W7 o5 K1 U
            # 1. 加载与训练网络
    5 l  J  M& b' F# t; t) l        model_ft = models.resnet152(pretrained = use_pretrained)! v7 }4 t0 a4 S, Y; k% V6 o7 |
            # 2. 是否将提取特征的模块冻住,只训练FC层
    ' P. J& I9 V) E+ f' k$ P5 o        set_parameter_requires_grad(model_ft, feature_extract)
    0 p/ n* z" @" L2 c2 ^- b7 A        # 3. 获得全连接层输入特征
    " g6 x6 J# a% r" }8 p! q* m" x0 K        num_frts = model_ft.fc.in_features
    ) P1 h5 {- _0 [9 u. J- V1 L# ~  `; G        # 4. 重新加载全连接层,设置输出102
    1 D% Q+ O: R7 D        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),2 l8 k& s. x1 X6 Y/ F: ?; u
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1, e$ b5 X, a4 \! k; O
            input_size = 224
    3 M2 S% w+ a3 y- \* N$ y" ?; Q1 B
    5 e. t4 K. h4 N, w/ z+ j    elif model_name == "alexnet":
    ( d( z8 p* ]4 j. [7 J        """, d7 c% s  y6 [
            Alexnet
    , q5 u8 u7 a' R) x; V% ~" L        """
    + l( X# q9 N5 C/ B& p2 u* V$ j! {        model_ft = models.alexnet(pretrained = use_pretrained)
    . A9 a  T' \* w% e        set_parameter_requires_grad(model_ft, feature_extract)
    ) l( C- V3 w$ |  A- h1 d* }. Y. _) H" e
            # 将最后一个特征输出替换 序号为【6】的分类器
    8 J2 m' y7 b1 X        num_frts = model_ft.classifier[6].in_features # 获得FC层输入, b& }; X: L0 S% p' y1 \
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes): H! S( W9 g& ~5 ]! H
            input_size = 2244 x9 d" U: D  g2 V7 X+ E; U% a+ x
    $ ^) f( e* X" A5 M# J
        elif model_name == "vgg":5 \2 Y1 o# W0 X( k; x  o9 T
            """
    8 s0 e2 F+ K7 w' n  |+ Y        VGG11_bn3 T3 W: U; z( S1 T- F1 m7 L
            """% F  b5 T* z8 ]: U
            model_ft = models.vgg16(pretrained = use_pretrained)1 g8 Y+ P  z: q% Z, M  a" j
            set_parameter_requires_grad(model_ft, feature_extract)
    ) B( Q, Y) Z% x6 v        num_frts = model_ft.classifier[6].in_features6 j; S1 v2 e" q$ a& Z
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    # G$ J( ]6 P$ z0 F) |        input_size = 224# d" x/ f5 c7 S7 Z( r/ }; i. n

    0 ?* g$ y" I: A$ J    elif model_name == "squeezenet":$ Z" ^, H) g8 w: u
            """
    2 c  t, r1 E0 Z! N        Squeezenet
    8 I' U' b' c* _        """, U, i9 d8 O  B$ J3 a
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)- x) p( {8 U" ?/ Q6 w- R% J& R
            set_parameter_requires_grad(model_ft, feature_extract)
    ; H+ s, S$ u& O2 c        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))% s" N% ^- q$ o" l2 ^
            model_ft.num_classes = num_classes" e$ _7 P( Y2 F% u
            input_size = 224
    / U6 K3 {5 ?6 Q$ P! Y( w9 k/ I5 q, S0 o& V
        elif model_name == "densenet":
    . [) Z' V. G7 e# J, `9 ~        """
    & e3 E: u2 o$ {0 e: Q$ _        Densenet
    / h' a! O$ g; C! p, P9 y" d/ g8 z  @        """* Z3 T& ?( t# c2 [: x  z  j# S/ l
            model_ft = models.desenet121(pretrained = use_pretrained)
    . ]1 r  @: [& q( z! _7 f        set_parameter_requires_grad(model_ft, feature_extract)$ |* ]. H! @3 y) h6 U. g/ x8 i/ S" E
            num_frts = model_ft.classifier.in_features
    - l; O' m# [# Z9 o6 n3 Z1 M        model_ft.classifier = nn.Linear(num_frts, num_classes)
    % O# m3 S. u$ L- Y7 h" H6 u4 {        input_size = 2248 g) i0 J/ y3 X

    - p  @: a, I8 j2 R    elif model_name == "inception":
    & ~2 F6 A* Y. o* h' X        """) j( b: J& {. ]: Z" W, j
            Inception V3& O6 }; b2 m- \' n0 h' n
            """
    $ }! |! k* D! j2 D: `        model_ft = models.inception_V(pretrained = use_pretrained)6 E2 x* L2 c* B5 Q" |
            set_parameter_requires_grad(model_ft, feature_extract)5 B0 l3 J" f, i- w. l5 n, \- ~

    , A" b) q. x- w+ v* Q        num_frts = model_ft.AuxLogits.fc.in_features
    1 X4 P% b: u/ I8 F        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes). p$ U; Y2 R& J2 }5 M$ ]

    2 ~# X& L, B5 _$ W; {( h6 `9 y        num_frts = model_ft.fc.in_features
    $ X4 ]( ?8 [0 `- z        model_ft.fc = nn.Linear(num_frts, num_classes)  o: J. V( Z* w3 j# t+ z7 z8 M
            input_size = 299
    * V; v' a4 n! n7 F( i; |  W) [. M! Q
        else:
    8 v: A0 G, Z) v: A        print("Invalid model name, exiting...")
    ' C: ?* u" n- Y# P2 \! ?" e3 j. f        exit()
    $ m. ^$ u5 C; G3 s1 D7 e; c+ f# Z% I$ O+ h6 r4 W4 @) G. a& b. |1 M+ A
        return model_ft, input_size& F9 B# i. i/ g

    - E+ d2 `  }5 S2 D* x3 N1 B1
    ) P. C- a& V/ f2 u2
    9 `, i" ?8 Z/ U( K7 v3: q% L% v4 ~& c6 l: d4 v" i& U
    4
    ' K! S. l$ E, q2 r% I5
    - ]- @5 f! i6 L- L6 _+ a6
    * d7 d# E6 Q+ {' Z7+ J/ `4 ^, z3 o4 F' j8 d
    85 f) i2 K. s6 o) a; r7 B
    9
    ; u6 S( x: g: v10) h" |& c4 E' ^  `& z% R  H
    11
    ' Z; o8 t/ I, `9 `) S4 r" f# O12: p0 T1 U9 c' e
    13
    : ~2 |* Y" i! w- y14
    5 c/ z& R5 B' j: O- ^, q# Y15& ?" a- z0 z# a7 L
    16
    + `8 T3 |. r/ ?  A7 K# S2 d% E6 I3 E$ ~17% c, y& f: {- o5 c/ K) v2 R
    18
    + p3 j: n# g" b5 S, h199 a" x* n% n! X9 N/ D( [. N* }2 ~
    20
    # n+ D6 i$ I+ t/ j2 L21
    , _$ h1 g: m7 c22, g. ^: k% `" Y7 X! ]
    23( @" Q0 e# M) R0 p! T4 k* \
    24
    / i. d$ t/ w1 b7 i25
    5 k/ Q7 t# \- m2 j: e262 O; V$ `5 S0 E5 [2 c" h, B: ]! M
    27' c2 i, V  H2 r( o% O
    28. j. J/ A. a5 ?- Q9 ]8 R" l
    29
    - y  r* g, ?) }3 c3 t" U, k30
    8 J  M* Y4 `7 w1 D- C31- T: S& D/ E: D8 ?. s, k( T
    32
    2 m" z& t/ L; [+ t4 m( z1 H2 |333 F. Z4 |/ y% W# ~# @9 e& n7 @
    34
    $ f3 ~9 M: m" Q& v- k' k2 ]* ~) a35
    1 t& r- R6 D9 \4 S1 y9 f% W36  ~/ B, r' e3 k7 h$ B# }( Q/ L
    37  L6 \& q1 T; m. Q, a! x
    384 p" n3 T" }% y
    39& I; E' y  w7 H0 [: f4 K2 X
    40
    * D' L0 q9 s/ m41+ y, E1 @1 P, ]/ O9 R! [
    42
    / [- U. P4 j# l; M5 b% j43( S& {' c4 o$ v/ q
    440 m5 n. F1 Q; B( w# S' s' b
    45. F, u( e) c& Y. |
    46
    8 K, d. k* y; b$ K; f+ M7 R' C5 [47, W! Y, I  u( Q8 O0 R
    481 P% i$ C5 J0 a9 q
    49
    6 C2 L7 L% n1 C7 }50/ D, B8 d$ D' |0 ?& E
    51
      H/ [/ z( g6 H3 z7 v8 S! Z52
    : O; y2 K, R8 [4 X+ ~53, }! ~& T' W2 H3 Z
    54
    : |- U% }0 W& ?# h  A6 W+ ~55  E4 n( K( u. P; f  ?2 d' }  l6 e
    564 F* @4 n; g6 W$ K* V0 u: s' z+ W
    57
    / B" n0 ]* `! B58: N* H6 d* Y; `& Q0 e  e
    59
    : y/ K( {5 u* T60
    6 J( k- `) b( |* |$ C4 N  o61& u, \) z5 C  \: ]8 N$ i
    626 Z7 }6 [% |) r% v% q$ P- Y# a
    63- N4 v; x( D  }; Y1 w% X
    64
    - Z; C5 @, K$ \0 X/ W- g2 }65
    / c; |1 ~/ B' e. [5 V6 {  j66
    ( @5 s5 T1 h6 Z/ T67
    8 i4 `2 j/ x$ f& ~$ C7 O6 H68. H. B6 E* e+ s  l
    69- \. E! }! \2 t
    70. b: W2 r% ^, a9 b
    714 r! X9 A3 w5 a* {, @$ X: y
    72
    8 u9 K' l6 w4 k& U7 A73" h$ a; ?, l! a+ F0 M/ z( |5 y
    74, ]4 y- x! d9 U( G5 ^% ^4 W. F
    759 C% O6 X1 e2 z9 P9 y
    764 G& U3 Q- e* R0 Q, W# j  e
    77
    $ U8 F' U5 D5 Y% h; N78% E  x& b& q' Q7 k9 F
    79
      Q9 m! N& s' M( W80, o5 i* `1 h4 J- B  _
    81
    ' i( p4 @6 @9 l9 d% r82" X# p( [( f& ~3 ?( i+ o( S/ p
    83
    2 b' Y+ ?% ]8 i: i2 l0 f% M7. 设置需要训练的参数
    * U3 n/ J$ f* }+ K9 b0 h# 设置模型名字、输出分类数
    ' q0 E3 j" R' Q; V6 Gmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)( y- l3 `# w; t! k: Y9 O5 M

    : g, J/ {& @. z) G! C$ ?6 j# GPU 计算
    5 L$ d0 Z- O3 b' z) N7 Ymodel_ft = model_ft.to(device)
    , K. ]  {) f+ ^7 v; T" q. T
    - \& t* n# s0 K8 G5 ?( j# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取8 @, k- A2 m% B- z, S: y! C, _2 }
    filename = 'checkpoint.pth'
    8 D- Q! b; ~3 J4 u. Y7 }$ A" _5 F7 D$ D
    # 是否训练所有层5 W  X2 {. d( G! Z3 G% y4 M
    params_to_update = model_ft.parameters()- I/ ]/ O% r6 ]
    # 打印出需要训练的层
    / P7 x" q" ~1 q" G& f6 Cprint("Params to learn:")
    : G0 o4 U2 u0 dif feature_extract:
    9 ^+ Z1 B0 O4 ]; U8 V    params_to_update = []
    " s5 t7 O) o0 [8 i1 \    for name, param in model_ft.named_parameters():0 r0 q2 X9 @; p6 b. e8 O
            if param.requires_grad == True:! \1 W4 R5 |  A  m! g+ n6 [
                params_to_update.append(param)
    ; H: l# V# a6 c            print("\t", name)
    : f% J- K' F- g- b* celse:. G, l" g, Z* i/ w) S  G1 N
        for name, param in model_ft.named_parameters():
    ! F; j$ N1 n  o; T+ U. J" S        if param.requires_grad ==True:
    # @" w. W% k2 K3 J% e6 r            print("\t", name)- T1 `1 o- W4 W$ E. V) p

    5 p  q0 t# O, @0 O; h- k1; t' W. Q& C: y% ?# ^6 G3 X
    2
    2 H% z% B3 O1 O6 G9 u+ d9 L3
    9 f, R  |( p; h! @# d6 v8 J40 Q- k& E: w5 e1 q
    5
    / A6 p* i& I2 {: P: X6( S8 J) b/ C; V& f/ T/ e5 _" @
    74 e  e+ \4 z1 W5 _& u" `7 |% v
    8! x/ N+ A' [) s9 B, ~
    9
    ; e; g0 ?0 V3 u/ I10
    # F! d4 q. |9 I  s11, D+ t$ r5 r1 G
    12
    ' Y2 \; q' o/ R13
    - W1 ?% C/ n0 |& ?9 M+ `14: p5 p; o4 N2 k3 r. a
    152 F/ {! c) Y/ n3 n  b7 n9 x/ D% y
    16
    7 d1 ^9 y; O1 I; A178 w+ z* [( D  A" {
    180 Q5 X) z# Q; f
    195 H7 Q  i1 B/ ]' K8 p: q5 ~) m
    20
    # ]9 R# p& C! P2 B; q* l. W% L& c21, z. o# |; V5 g" {
    225 A6 V) K, `& m) c: d
    23: k4 j1 U& D! a; z- s
    Params to learn:
    9 N( [7 g2 o* G% l. p, z6 g         fc.0.weight% I" L' Z9 p8 Z" F& s* r) c
             fc.0.bias
    . a( M8 V% t8 Z! _17 |; I, @4 k1 y5 @- U! ^7 L
    2' r% c* z$ l" F( e7 f
    33 {' B' F4 [, P( ]5 _' _/ j
    7. 训练与预测9 I/ g8 K  ?+ w+ N6 S
    7.1 优化器设置
    2 H% ~# N: N" R' n- G0 z9 c# 优化器设置5 y9 ~2 ^, u& j. l
    optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)/ Y1 B. a6 ]- M5 S2 _- w# G3 E
    # 学习率衰减策略
    4 m$ M! _" y! j2 U+ k* [3 i7 Tscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    4 E( ]" m2 k2 ^$ N3 |" c% r6 V7 R4 t# 学习率每7个epoch衰减为原来的1/10- K; o( U4 K# n! ~" h
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算+ }8 j  t$ M! b0 e
    8 C- }$ v# \% ?3 _8 q) M  |
    criterion = nn.NLLLoss()* M  L9 s! N" D& @  ~
    12 S+ I- e1 y2 I3 K" l- a9 n
    2( |3 E5 A6 G# }% h  N# M
    3
      H+ `: z5 a, u" o+ l# F40 L# {' V) Z5 {8 X5 E3 R* O, w
    5) V( P. y- S7 j9 R! Z1 A. n
    6
    * W+ \& W- l7 w4 Q7 ^9 o9 ]78 P/ K. h& B* |+ N7 D: J# [: J
    8$ m. [6 d+ }' `) s! j/ M  U/ I
    # 定义训练函数1 J, B; a3 S$ |
    #is_inception:要不要用其他的网络
    ; i. E+ C" ^2 Hdef train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    2 g* W7 i" J: Y$ {5 S/ w% I( I# O( p& }' K    since = time.time()
    & B% }6 u" }' L# i    #保存最好的准确率
    4 C1 j/ K* S  Z) Q6 K( U$ B    best_acc = 07 d6 l+ ]! w6 S& Z: y
        """
    % g! y( @' Y3 i3 e+ l! Z! u0 c    checkpoint = torch.load(filename)
    : L- A; q) J4 {; s    best_acc = checkpoint['best_acc']
    0 a9 c$ |$ M2 `0 l; m" J  h    model.load_state_dict(checkpoint['state_dict'])9 t" m! y2 x7 e5 q; T1 D
        optimizer.load_state_dict(checkpoint['optimizer'])
    5 P( |: y/ P! S0 F! e    model.class_to_idx = checkpoint['mapping']; r5 y3 q; i7 ~* T0 C" N1 R6 P8 c
        """# j4 `: n0 h/ l2 i) N
        #指定用GPU还是CPU
    " F4 k9 w# K; o2 M; I& |    model.to(device)' r1 s& l7 h  z0 g' o4 d
        #下面是为展示做的
    . j: P4 ]1 X) D% X    val_acc_history = []( G- w% B8 U3 n
        train_acc_history = []1 g- o! |5 n& Y$ O; Y4 H+ D, ^
        train_losses = []
    2 p: \9 y( E6 ?2 a8 V    valid_losses = []* X# @4 G" t4 k  ^6 ^! Z
        LRs = [optimizer.param_groups[0]['lr']]5 G/ ^& z8 r2 p' n5 p
        #最好的一次存下来0 c. t5 P6 t+ {+ L0 K/ ]& ?
        best_model_wts = copy.deepcopy(model.state_dict())
    ) \4 E3 A" ^/ [5 ?7 |! k$ }5 _
    - N" b1 T5 U0 B    for epoch in range(num_epochs):
    ( X/ S  f( z# J- E  y0 L        print('Epoch {}/{}'.format(epoch, num_epochs - 1))8 X6 D. g4 H0 V
            print('-' * 10)) e- a9 F; M: {: w! I
    : p0 w4 `7 H5 S6 ?% A
            # 训练和验证; v+ U6 P7 S, g- W& f' k" w
            for phase in ['train', 'valid']:8 B' M  }2 t; E* ^
                if phase == 'train':  P* {0 y9 s+ h: {$ M
                    model.train()  # 训练
    4 k2 Q; D; [  P# _% k* ^- O            else:. Z: _+ }. p/ |* B
                    model.eval()   # 验证: K! {. N1 g1 f7 d  ?+ b
    % Z; q* @# n  f+ p) i
                running_loss = 0.0
    / c2 O: Q8 V# t" Y; F" I/ l+ L            running_corrects = 0
    ; }: a9 _, g2 U9 p7 R" C/ F9 D9 P* \" g
                # 把数据都取个遍
    - v' u* C% [6 S            for inputs, labels in dataloaders[phase]:
    * ]. S$ b/ D" a' Q8 ~" @+ w                #下面是将inputs,labels传到GPU8 U: A5 i: U  }5 a4 y
                    inputs = inputs.to(device)
    ( R3 h& Z8 h2 l" O7 d                labels = labels.to(device)* z2 T! o) x) T( ^3 P# m% [

    : c3 H1 H/ I9 c- j7 s                # 清零
    & x# G' K9 r& l0 H& w& G* y+ W4 N/ s8 F                optimizer.zero_grad()
    ! @9 v$ U# u4 G2 h                # 只有训练的时候计算和更新梯度+ |) z* l& A8 }/ P
                    with torch.set_grad_enabled(phase == 'train'):( @, r8 B4 V' Y4 C  K3 x5 n$ E
                        #if这面不需要计算,可忽略3 N* y, V0 y4 D) |. Q, b& W; |7 J
                        if is_inception and phase == 'train':2 `$ ]2 N- z* Y) g0 x3 q# H
                            outputs, aux_outputs = model(inputs)5 K; }& A$ L+ ^( u6 h+ q0 J$ E1 o
                            loss1 = criterion(outputs, labels)
    " H' {9 n  f" R' T0 O1 K                        loss2 = criterion(aux_outputs, labels)+ A0 n6 q; s! s* p1 q
                            loss = loss1 + 0.4*loss2
    " d4 I& q, y  E" a1 {# R1 _2 e                    else:#resnet执行的是这里1 k' r! I) o! ]  o8 F4 S
                            outputs = model(inputs)
    5 e; C. |' o& Z3 @                        loss = criterion(outputs, labels)& z0 `/ @  d  ~  V- m
    8 n7 o" X* N* N: q+ ^
                            #概率最大的返回preds
    6 E* f- W+ N) _4 ^                    _, preds = torch.max(outputs, 1)
      u8 C1 F: ^, [/ w* K, W7 f  l, b
                        # 训练阶段更新权重
    : o/ e  v' T+ H8 f& N3 e. z: K                    if phase == 'train':
    ' w9 n6 R$ n1 S# j                        loss.backward()$ ]8 U7 j# @$ o3 P( F4 T0 v4 f
                            optimizer.step()
    / ]) ?6 \6 V7 m+ {$ i3 W8 d
    : }$ r6 R. }# `" k* r4 U$ n: }% O                # 计算损失
    & U4 `* p' F8 o9 @  T6 c                running_loss += loss.item() * inputs.size(0)" {) G+ G( v" V
                    running_corrects += torch.sum(preds == labels.data)7 N# B1 [% A* ~0 o( `0 A

    * \9 V8 a! @+ F- R2 D" Q3 H            #打印操作
    : @3 E4 A. w4 Z            epoch_loss = running_loss / len(dataloaders[phase].dataset)
    7 t! G5 M/ A" c2 d9 Q; ^            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    * b( m) j( [, a9 Y  z- _3 l
    8 B8 r! H5 j' @( @( z! X! B1 y
                time_elapsed = time.time() - since# y# H$ P% ?, V1 A& D
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))! B% E* g1 @  L7 C
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))/ B5 p. u7 y3 l" @

    % w0 y. g/ N: A$ W
    6 c* Q- x$ U1 H5 [4 @4 p+ a            # 得到最好那次的模型
    ( l+ u/ N8 l0 v' `: v0 A            if phase == 'valid' and epoch_acc > best_acc:
    ; x& x- ]0 k' }1 a                best_acc = epoch_acc' h* _: i! I/ n, M+ A( Q
                    #模型保存! f# w' d! Q  ~6 Z6 m5 j0 e
                    best_model_wts = copy.deepcopy(model.state_dict())
    ( V4 [- y# }' }5 }  U                state = {
    * A1 h- |" K2 b  Y) V6 t                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数  x0 |4 C- f( D5 q+ C" V
                      'state_dict': model.state_dict(),; E/ s5 s" Y0 y7 o( Y; l
                      'best_acc': best_acc,
    1 S, B/ F+ [+ g' {0 m& C8 z3 X                  'optimizer' : optimizer.state_dict(),# }3 T* J! Z3 x* d0 B5 B( Z
                    }
    3 M0 e; E# G+ k8 V9 Q" T                torch.save(state, filename), S* Q# x# O) @
                if phase == 'valid':
    % ^* {, c6 Z* w7 s; b) I* N. K- ^5 ]+ s                val_acc_history.append(epoch_acc)
    . a1 y- C  w+ ?& R! G                valid_losses.append(epoch_loss)
    6 H9 m  f$ b. G0 r. \* r3 J2 w                scheduler.step(epoch_loss)
    . B0 D" }4 L  R3 F( n! e% _            if phase == 'train':, C1 v+ M3 [* g0 w* R9 W
                    train_acc_history.append(epoch_acc)+ f% u4 @! @( |- e7 J- W6 H7 ?
                    train_losses.append(epoch_loss)
    1 t5 y' |$ @1 C0 C0 o4 |
    & R- }: n/ d3 j6 n) ^        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    # Q$ h! i" Q. n: Y' l        LRs.append(optimizer.param_groups[0]['lr'])0 x, q4 d3 C$ ^& A
            print()
    ; }! S9 W3 p/ ^- w/ I: N2 A% U
    5 c) ^2 I# w4 [6 H' {4 T: Q    time_elapsed = time.time() - since
    . J( x/ x1 o+ m( }* z    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    ( x: Y& g* Y2 j9 L& |6 p    print('Best val Acc: {:4f}'.format(best_acc))6 Z) [- _8 I7 W' ]! w- Y& N! K
    8 K  g" u( f: _$ h0 b6 \
        # 保存训练完后用最好的一次当做模型最终的结果! ]( S; ]5 `+ D! A2 \1 _
        model.load_state_dict(best_model_wts)
    ) d! r% u2 s* {# l1 Q    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    5 c* d% R. g2 [; W- o3 @* R2 N5 Z9 i
    , z* g- Q; [, [
    . K9 G7 p/ X1 X: d16 S5 J. y/ {+ _, U
    27 d3 ?; P) v; T7 t0 \
    3  J/ f  t9 _; l5 o6 P
    4
    5 m0 V/ K* U8 }2 J$ f1 w$ O5  u1 E- Q. I! x2 C/ h" g6 P
    6
    * U5 Y) S6 e- ^# f/ {! z7
    / c/ ^) Z' ]7 p! y85 y9 l) v' z1 x* ^! V
    9
    5 v$ V7 y; {7 b3 Q10* ?; z8 P7 O4 H
    11
    ; j6 k8 \3 D5 T, H+ c! @128 a7 G' N# b) P9 Y6 r
    137 P; j4 ]/ I7 M+ a$ ]" ~
    14
      R4 C1 ^7 ?4 E7 [8 b4 {: {15
    / i# X+ y$ @7 D# D16) X) A. a+ \! \+ _2 e
    17
    9 W; J5 I2 \' g* p7 b# {/ W/ r18
    3 l# Q0 k( f  _$ m. Q0 Q4 i' t19
    8 k/ ^8 `3 b, z& v& o207 I- l& A8 E8 V4 }7 u" o
    21" z( ]5 |' n6 p8 l
    22) h% _$ C4 g( ~6 |# o6 F
    23
    ; {$ @0 S/ R! f+ b# {24% R# L4 g% l1 a$ w. E' l
    25
      U: T3 ~+ I0 a- C+ I* o26
    5 H/ b5 U7 w. ~3 h27
    " l/ g; `; G1 D' ]# Y3 E$ ~2 i28( d$ P/ w6 M/ E$ Z; G
    29* h2 ]7 y# p8 \" C# {& [" u
    30
    . W- b% O7 F3 F; F, S. F31
    4 A8 d* D! W( d, {! J% K32  @# S; `3 k6 ~/ g1 O. s+ U2 C2 W
    33
    6 t4 [1 D; O* q7 X4 r/ R34/ Z; \0 @1 @6 B. \; Q% E% h. k
    35# r: c! }+ _$ x6 r# B% g+ t7 V9 U
    362 @2 p( n: A# |4 x9 h9 k1 h
    37; v& d5 F- A/ |) t
    38
    # d/ D2 m7 a& X$ f39
    , W# ^. g. E7 z4 s/ h0 g/ X8 u) q, x5 S40" Z, v* L, X/ Q7 b) M( r
    41
    4 X0 }& t* G* L2 u429 |; V! @+ W. e, l6 l; P& P) c
    43
    # y2 x) E4 Q  Y+ r5 V44$ D* [2 w  X2 G9 I/ a; m# {
    45
      y( V7 x% W+ }3 \4 E0 ^2 `46
    % W, ~+ b1 h5 E1 S- B47- b$ z5 x$ A+ k
    48
    0 Y$ n2 u/ D2 G% H% q. [49  X, |( `+ }: {1 B! _. C
    50
    + K: a* `  @0 \' C: `. w51
    + r- c7 a4 `2 P9 R2 k52
    % ]  F# }; B% o8 s53
    : \; m/ V0 H0 A5 N  d; F54
    - q4 R" A5 n0 R& X3 {% K55
    " X( D% a5 m0 S( t! B" ^56; N# O' l0 R6 x5 Y; g$ v% v4 _
    570 A! n5 e2 I# ?5 C
    58% L8 h% x' y, U$ D! g, {: ]
    59& Q* A( D0 b( p/ e
    60. n4 F" B; I1 [* y4 O7 Q
    61
    7 v9 B* J' D# B# O7 ~62
    " @+ ?) p' w, E+ k4 }0 K9 W" ^5 q63
    6 h  I% k9 }# d" Y( V% T1 `/ c64
    6 C$ c) i7 z2 d8 _9 Q5 \, G" e* \65- b- c2 w6 z+ v" _; N9 Q
    66
    $ w" @! T1 |8 H% X- ]; E67" y5 V) _7 n5 P  e" ~
    68
    / y- L) D) _+ _4 \69' R! ^- o! \1 }
    70" S* i+ M6 I$ `' g' i9 {' W
    71
    3 R, z  i. R6 S8 N9 @) Y; j) p72# M+ i+ T8 F3 U1 _% h3 {0 q0 \
    73/ J7 s1 B$ r5 D% E
    74
    * k6 N9 E& c/ q75
    % y3 Y8 H& ~# B" G6 d2 x763 y$ o; V/ |7 e8 `: w6 b
    77
    " }6 l9 e9 M+ [4 X78% ?$ B# {( ]4 T7 f8 L' |
    79) W- P% O. F. P. y7 D. {0 h# j
    80. ~# I- K; ]& i
    81: R8 P1 r. P$ k# t7 e9 ~# m
    82
    1 \7 `2 h0 |  F$ ?4 f; @; ?3 d- k839 K0 V% V+ C) I- O1 b, [  K& \
    844 i# W  P- m  m8 J! T' J  P( x% Q) a
    85
    ' ^+ M: Z/ I# E. B# F86
    , M4 m4 ~& P+ w9 C873 y1 k1 R# ]9 B
    88
    # E: K' G" L! ^3 E6 q4 c89
    + L- _: f8 V/ ~) @: N904 l8 F* c7 x+ N& B
    91% S2 t& d4 h+ P' l( K  c
    92: r$ l$ Z- X* e% u" |4 |1 j
    93
    , E9 l1 U" ~3 m& {8 Y, D94
    - D' q& H) x* J8 S4 Z- ~' D# ~95
    ( O4 S; [' `9 F7 B96
    % B& M. l( `. v97
    ! i% k% m4 e+ J* A' f98
    ; U) K6 M6 J- u1 g2 E& o5 f993 U+ ?& _. C+ X# X
    100
    9 P5 b: \  Z6 j7 @! ^101: i! A* A' w' i7 L# ]
    102
    : ~5 z1 I- K. p0 W, @: [6 o2 H: a1031 k0 W8 c) q# C8 T
    1045 v) X/ x) ^* c0 r4 D
    105
      t* ~5 c3 E& T* m. n106
    4 q7 j8 d) b2 Z( V" l6 s. i9 a107" u' ~9 r8 t/ N3 U
    108% a9 G( T8 q5 x! H6 U2 x
    109
    % i/ C* q0 ~) b/ _; y! C! S, ]7 }110
      {: H; E+ H9 p5 x) O% h2 z  }111
      v/ v3 V' L) w* f8 j* Q112! E- S& I8 `% S( M7 F2 B% ]
    7.2 开始训练模型
    & P3 E- i1 F& d% @2 T我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    + ]6 D5 W( t1 u  w9 [3 @
    5 b# I6 ^* r: U3 L5 K#若太慢,把epoch调低,迭代50次可能好些2 D9 X6 t- P" ]; w% V, `
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合+ v* A6 t. r" d$ T" {# f
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    ' P$ G* c* o+ I7 V( I
    5 L* F6 |8 G; H) t" ^7 M6 l1
    7 h. w! p; I1 a, P/ y2
    2 v9 j4 T/ ~/ D* `% t5 G( r3, U7 N6 n. p; V2 }) G' E% a2 c
    4
    2 T$ b% t1 n/ X  m! @5 a7 |Epoch 0/4
    ; a7 Q5 Y- k  J: U5 x2 X! Z" p! d+ Q0 {----------
    0 U* ]0 c* S8 c, C; a. eTime elapsed 29m 41s8 X+ t! K, U; A7 \6 s
    train Loss: 10.4774 Acc: 0.3147, a: ]# l9 p0 {
    Time elapsed 32m 54s2 S7 o' F. `7 a5 U  C$ X3 ]) \
    valid Loss: 8.2902 Acc: 0.47196 N0 _, X1 i+ B
    Optimizer learning rate : 0.0010000
    6 @1 R6 _9 I3 o9 p+ {. |) H. ]0 ~7 h
    Epoch 1/4" v! G) S7 B7 ^: `; O
    ----------
    1 f: P* Z6 p+ E, o* C: S8 VTime elapsed 60m 11s( b$ o8 w& v& M/ z/ F9 }
    train Loss: 2.3126 Acc: 0.7053" F- u4 ]: |3 C! Q! n' W* g
    Time elapsed 63m 16s' `# I* M  S$ R: r7 f2 J) I: u
    valid Loss: 3.2325 Acc: 0.6626$ ~1 i( D  q. L8 B* q
    Optimizer learning rate : 0.0100000
    4 A3 M. M2 r; w; v* u% o+ f% R1 a" k
    Epoch 2/4
    & t! z3 P3 _4 B; w& h8 }4 g6 }----------# c8 x" [, Z' S3 G
    Time elapsed 90m 58s: S7 [# \" T5 {1 Y( `
    train Loss: 9.9720 Acc: 0.4734
    / K. P# \3 _: O* O2 ?Time elapsed 94m 4s
    ) z+ ~" D. P6 P- \: u+ E5 yvalid Loss: 14.0426 Acc: 0.4413
    3 c" B' L7 A) X$ X0 ]2 T8 D' sOptimizer learning rate : 0.0001000
    ( E1 ?2 m* d$ x. U
    4 Z2 v9 w- T; E5 G& L; [Epoch 3/41 o$ k+ R* P& {, w/ [- ^7 M0 F
    ----------4 ^) K! b# ?! a" Q; o" I* g2 k, J$ M' f
    Time elapsed 132m 49s
    ; ]: m$ c: I. R: M! l; Q0 O) rtrain Loss: 5.4290 Acc: 0.6548" s, T# X+ {$ b3 t2 c0 |
    Time elapsed 138m 49s( C! m& t7 H1 X+ m# A9 \) b7 s
    valid Loss: 6.4208 Acc: 0.6027
    ; S5 C/ N2 M4 i- r; XOptimizer learning rate : 0.0100000% ]! q+ }7 j$ P3 r; f

    6 l0 r1 c8 x8 C% ZEpoch 4/4% k- U/ S6 S2 D% p; W
    ----------) n0 G% _- J3 A) X
    Time elapsed 195m 56s( ]; w; B. V# A
    train Loss: 8.8911 Acc: 0.5519) y0 z/ U; v# p, G
    Time elapsed 199m 16s) ]3 T2 l/ i1 N: N- C6 t$ t$ G, u( Y
    valid Loss: 13.2221 Acc: 0.4914
    # q! \7 K+ H1 g7 f! FOptimizer learning rate : 0.0010000
    , Y/ ^' A7 h" S
    8 I% T9 D: _8 d- J2 I9 JTraining complete in 199m 16s/ ~1 |0 j9 [8 i( \- a/ c
    Best val Acc: 0.6625920 F9 K" T! ]# @0 y
    7 g9 L" ?+ @  ^9 f
    1
    4 T' m" j* d& y% O) G3 X7 t8 [6 H2  c2 L9 A! Z7 ?. g! r& D
    3( I4 o/ m5 P0 l
    4- X- V' T6 b! p. S8 s- P" O9 n
    5
    6 `' X. Z5 J; ~8 s67 [+ X/ D9 \: y" S# D. h
    7
    ! \; A; u5 n1 p* l, Q: V86 E+ l' M0 Y  N9 ?7 o4 i6 E- ^
    9
    7 e; K: ?3 @+ i10
    # F9 x# y9 H/ H, q" U115 \4 B# |, L5 k& h
    12- F( W+ I' ]0 P* v! f' R6 _
    13$ I. S( g2 [4 s8 t
    14
    ( ?% m- b) C( `. }+ C8 n9 ?15
    8 Y$ Z  |& P) n16
    $ Q7 O& P% k. A& y- S* {# h17
    9 d* y9 _2 v9 w4 A" ?. h. Q1 P' h; X9 q18
    % u# U: g: ]5 [& P' k( g8 m3 K19
    * h, r5 e9 T1 o1 q" e5 s% K2 i: [& y20
    7 K+ J2 ^2 c4 D, f' a0 p21
    & r- t( w& R( h. v0 u3 N22
    " h* f5 K0 f5 `. s" a5 {* V23
    & z# [( {1 O# M  b24% Q1 N- T, A8 s
    25
    ) i. \- R! g9 M) M2 }  A, W262 c& \, j5 Z% G5 c! P
    27
    ) T2 Q& S8 n6 p" R" e285 P! e5 w! x9 i
    29) y$ R8 x3 d( o6 T8 y* k
    30
    7 w: p6 `8 X7 ?4 m: u5 o31
    ! \4 z. v: |' x5 r. e' P! c' j  v$ J8 Z32
    5 a1 S0 }( f5 h* J33( M' B/ F7 u) x  K; i
    34/ z$ K9 G. `$ p0 {! ~' }
    35
    & O  o6 _  C8 r* e7 t/ D36
    2 i$ F2 O4 ]/ x3 g' }. x, W37
      {6 ^% u& u6 V& ?: @' X38
    * h( }( @2 H0 L* g3 p/ k" s2 L39. g& S$ j" j" z# u' s* H9 @
    405 a* G8 ]( U8 Q9 w; A. w7 V
    418 }$ U6 {' m$ f$ s! t- @0 L  O$ m) I
    42
    , \& R# y0 F6 s# Z# X* u7.3 训练所有层8 ]* r( B9 e$ f) Z
    # 将全部网络解锁进行训练
    : I/ y+ R1 ]7 Lfor param in model_ft.parameters():
    7 f( m! K" j2 I, Y. K1 I    param.requires_grad = True
    ) O% g7 a6 x* t2 l" o0 Q5 D, z+ e+ l7 s3 c# |
    # 再继续训练所有的参数,学习率调小一点\5 u  x- L/ x0 F( w/ b
    optimizer = optim.Adam(params_to_update, lr = 1e-4)
    3 g; Q) ], a2 f3 w7 O# tscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    % y# H! r" y  \% Y4 L9 Z0 U# e
    ) |+ ]( }' h* S4 ~, C$ b+ y# 损失函数
    + P* X( n( F, P# v. f8 l9 b/ r, jcriterion = nn.NLLLoss()
    9 _/ |' F; v% _1
    % L7 D; |4 m8 _2 w4 j* U9 [9 E3 ~2
    8 w" \$ P8 `! X9 B4 Q" P% ^38 C$ k& ~) Y. n& P+ v9 B# a0 e- @
    4) K3 U5 x5 T3 Y+ H5 W! I. ~; C
    5
    5 @- h/ m8 B; f* j; U6
    ; D! G( c9 w6 M: J7
    7 V' y$ E+ ?& E; q! q; w7 A0 ?$ |8
    6 `! P& G$ H0 K1 v1 ]9
    2 C  B- Z$ Z# w( x* T1 e3 Y/ D10
    ( f. V: X7 @2 l3 u1 Q4 ^# T# 加载保存的参数
    9 S. u6 I- i9 m. u3 M9 r, e# 并在原有的模型基础上继续训练' S  @" k' R; l, R6 {* T$ B7 }
    # 下面保存的是刚刚训练效果较好的路径
    % k2 b9 C0 W# {1 Pcheckpoint = torch.load(filename)- U# h$ C9 J% r0 w' D. P
    best_acc = checkpoint['best_acc']
    7 P- C+ b$ [+ q" o- c4 u) Cmodel_ft.load_state_dict(checkpoint['state_dict'])
    5 U# Y. V7 B/ ^/ hoptimizer.load_state_dict(checkpoint['optimizer'])
    2 a( r& v* F4 M% k7 [$ X1
    + c6 c/ R- E5 v. Z/ z7 m2
    1 x# R3 i/ s$ v6 G3
    " Y: k8 a$ n  b0 U: _/ [47 c8 x8 `/ N( _* k7 l2 g6 M( J
    5
    + n( T; A6 ^. w6
    3 r" F% {& I1 t% f7
    2 O/ t! J' R! V开始训练. M0 P2 g8 g9 F9 D
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    . P) r2 A( t) \1 I9 r2 R4 N/ _5 i5 W3 X! ~9 I, R" t$ W
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    . W' A6 E6 H9 p( [1 A1
    5 H4 l7 u6 n' Z" g: `" xEpoch 0/1
    3 r7 s- {" i, U5 B- k----------8 L& v- ~7 z; ^9 Y, q, C
    Time elapsed 35m 22s
    1 x: l3 n2 e; p, e4 C' `0 e' B0 i8 ttrain Loss: 1.7636 Acc: 0.73460 Z! j+ h$ K( e+ i" m
    Time elapsed 38m 42s& ~, k" o1 p- `) O! `
    valid Loss: 3.6377 Acc: 0.6455
    ' M* k& N1 S2 z9 j4 w  T) b* gOptimizer learning rate : 0.0010000& N4 X' g, C; z' Y
    & a9 Q( H6 |+ L+ F# Y. K9 I
    Epoch 1/1
    & K1 \) K' e* C6 U----------2 |9 {5 J# K. b" n+ M2 j
    Time elapsed 82m 59s: r9 h9 Q8 @2 o0 o
    train Loss: 1.7543 Acc: 0.7340
    5 \2 s3 j! n' B% K- ~Time elapsed 86m 11s
    8 v# h) |! x7 Y. M2 F+ x8 Xvalid Loss: 3.8275 Acc: 0.6137% X3 k3 }& l0 I7 A) l6 ?" u8 o$ \
    Optimizer learning rate : 0.00100004 l" Z9 A- ]2 I0 T4 t

    $ J" v1 l0 c* B4 a  P, y- lTraining complete in 86m 11s% {) X; p& W' H7 q7 G
    Best val Acc: 0.645477% Z! [  V# h7 B, Z5 @

    3 O( W- ~: J4 ?0 _- h1
    % P' d; r; Q3 g. c; R2) i! e, m) g3 j  e
    3
      P4 t; ^2 Y+ O8 G4
    , _9 d- f8 s3 f0 l1 X55 r* G1 E6 A5 g6 f% f5 j  A% {0 K) s
    6
    ! E: ]+ U: m& r! t6 S, C. \! m7* J* x; t9 M* n+ I: F
    8
    3 i8 B# I' r6 [4 L; M* u# D7 f7 R9% |. I/ K' A- x$ E5 F/ t
    10
      c7 Q3 Q, p" T& `! X/ U- v" A11
    / Z) m$ Z! @2 C6 Y+ y12
    / E6 D5 u% f* p, p13$ Q8 D, C6 a- X) m/ O
    14
    7 o- O/ x( g- x15
    6 U6 @! o# Q9 G$ ?' h3 G7 s2 h$ q$ k2 V16, s6 y. c: x# ?' G
    175 K8 n0 m; _2 ?; }
    18
    % a* X% `( D2 Z! X+ E' A8. 加载已经训练的模型
    ( s3 J. `  P' L* g) d# V  h相当于做一次简单的前向传播(逻辑推理),不用更新参数
    % I0 Y% e  |( m! Q; M9 Y# s' ^
    . O1 `! w" |& D5 h! N! w3 r7 `model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)2 h# R( P+ Q6 ]* G- I5 P

    3 S, Q1 ~2 p/ k1 m8 d0 c# GPU 模式% @. L( {# q& Y3 L
    model_ft = model_ft.to(device) # 扔到GPU中
    1 ^/ Y0 r& n! x5 i( M6 R6 }
    7 [8 v; j# _8 v* T! j4 \' ^# R# 保存文件的名字
    + @2 q: _! a+ q  l' u$ Lfilename='checkpoint.pth'# r/ \4 X6 V0 p
    $ r# G5 z: H" _7 I- i8 D
    # 加载模型+ y- ~3 i; n: ^1 ?1 b! g4 ^% f' c( b1 }
    checkpoint = torch.load(filename)8 H7 \/ {% ^/ ~4 |5 W! o8 V
    best_acc = checkpoint['best_acc']& N; T+ g. b6 k1 e) n: y( z
    model_ft.load_state_dict(checkpoint['state_dict'])
    " \- d+ U) S- G' [* [, M' J1( u' Y6 c8 b9 w( \4 t$ P
    2% `# {8 o% e. R; M. p
    3
    , z( J2 ^* I9 v. Y' k9 D6 f; ], t46 t2 Z/ w0 E! e3 ^% d
    5+ D: }: O. q; {* _7 C
    6  L8 k7 c% @! W$ Z! S( @6 u# u
    7% h5 ~6 Z  ?. _
    80 V' {$ H8 a) M- k# N7 {, i  A
    9
    ( |' Q3 E' D, [8 |; \7 B7 x10
    : g, z+ y$ l: i9 N+ J11
    1 z8 M. t% I% {4 k' L12+ ^5 K0 }( s+ ~/ \# a% q
    <All keys matched successfully>
    4 r6 S, {0 Z. ]" v3 I5 V+ v18 ?, O) ~+ n0 e" I6 j3 ]
    def process_image(image_path):( H% [) H7 @! l6 M- ~: N) L
        # 读取测试集数据
    " B9 W+ v6 i7 ], V7 m    img = Image.open(image_path)
    % k+ y5 a& a4 y4 G4 U/ }) A  u( G    # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    - q8 Y- v. T' u* x    # 与Resize不同
    . n5 J, @3 D; m    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    0 V' I/ D" J: q6 \: u3 j    # 而且对象调用方法会直接改变其大小,返回None. D7 N- [# i" w" u% I
        if img.size[0] > img.size[1]:; ^, O! g* b3 r7 n
            img.thumbnail((10000, 256))
    % i7 h6 z; x* H8 y    else:
    & R5 O* |5 E( ^        img.thumbnail((256, 10000))
    . L; e! z5 @2 l5 n3 ]- I6 i5 [7 r; [2 O
    , M' A2 h4 [8 ]$ w7 M) M    # crop操作, 将图像再次裁剪为 224 * 224
    ( g, R. m/ S/ l1 ~    left_margin = (img.width - 224) / 2 # 取中间的部分
    5 K, m+ q9 k; ^- `    bottom_margin = (img.height - 224) / 2
    2 ]" A6 @& s' h0 A- ]    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度8 u0 N- N8 d1 P
        top_margin = bottom_margin + 224
    ) C9 l( Q6 Q8 S- J. z* u2 n" A8 n( ], @4 A
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    0 \! M. {# ?( i8 w: l* {! u* k" T/ b( m) u/ D1 S
        # 相同预处理的方法
    , @+ i% j% T' O+ k( Z8 J    # 归一化
    / ]7 }5 L- o3 O$ N    img = np.array(img) / 2550 s+ c( c) |$ m; V/ v
        mean = np.array([0.485, 0.456, 0.406])
    # q! J  @, @0 r5 s, E2 e5 o9 M8 H    std = np.array([0.229, 0.224, 0.225])
    & S1 |  y# x9 l* L! J3 w    img = (img - mean) / std
    & o' H& y! P4 s# W: ~2 x( i1 L
    1 W9 V5 C6 S9 E  p8 i9 \    # 注意颜色通道和位置
    - F+ b1 d$ H6 d    img = img.transpose((2, 0, 1))' u; Q5 Q: U+ B
    3 _- T, w# _3 H8 W; E9 ^2 V, H* @
        return img- V  S1 y. d$ \& t( \) ^8 i

    ' k- M  I* g+ T0 }# ^def imshow(image, ax = None, title = None):
    ! b' f9 q) e* k  N- ^  [    """展示数据"""( R5 V6 Y6 ?; d* I+ I3 E, [3 r' i- k# ~
        if ax is None:
    1 E# Z7 h5 }8 L! {. z% g" _        fig, ax = plt.subplots()6 s! z# ~, O) D" Z

    ( t; x0 w$ J" B( ?* t6 m1 }# Q    # 颜色通道进行还原- z( ?3 Z" Q- _0 H
        image = np.array(image).transpose((1, 2, 0))7 T6 W, I( H( C8 E
    8 }# g5 e! _9 e
        # 预处理还原
    2 M4 n" {% x/ Y& P: |% ~( G    mean = np.array([0.485, 0.456, 0.406])
    1 k, c1 [1 N; U' y" Q6 `" u2 g    std = np.array([0.229, 0.224, 0.225])
    ! ~: C, ~- D9 V/ Z2 u7 y9 R- T    image = std * image + mean% b7 `0 V. R0 |
        image = np.clip(image, 0, 1)2 E& h+ q( o5 h4 u3 b6 H8 o( K

    0 x( a  \9 w' y    ax.imshow(image)
    + D" [3 M' S: d4 Q. e    ax.set_title(title)
    ) N8 W% ^4 @# D: Q0 ]" Y$ ?9 ?$ g; t; h4 M+ G
        return ax
    ! v3 K# o$ G, F3 t
    % O4 i  b. a0 @/ G1 I) Simage_path = r'./flower_data/valid/3/image_06621.jpg'
    " p, N, X7 {2 zimg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    6 U3 r+ F8 V# A0 P! z4 wimshow(img)& J- U& Y  K7 R4 \

    % X& k8 s& @0 |* r& [0 s0 t1" T9 p) i0 R" a) c/ Q
    2
    % Y1 o0 {6 ?5 }1 P0 ^4 t* c3. ^8 G. T' Y8 |3 n0 |/ }5 A3 q
    4
    - u7 F8 Q$ a  O+ }- D# V5
    * E) {5 _- g3 D. V2 i6
    8 O8 e; O, c0 g4 D9 ^70 H3 ^/ L5 H8 t/ `( y& L
    89 @, e. u+ t9 B* m( l
    9
    # B0 @7 H6 f. \* Y, K% H* W5 D102 I9 h; P9 n5 F( p  X: x
    11" e' J4 W5 m* y' H* G2 T
    12
    & O& F5 t5 \) r& B: e13
    / G5 M  {/ \9 s14; i* G' L: t8 ^% E' Y/ K  X+ d
    15  ]" `2 ]: j! F4 @5 h0 J
    16
    7 \! R) w! L9 t17
    3 d& _2 |/ p# ~7 A" k; H# X9 T18
    + x* A; J; {  ^2 B19, z; @* {- _' `/ w
    20+ x. y6 U1 Y* O" L5 O
    21( n+ r  A# o: f& M3 p: S* X* R
    226 s1 U2 ]# y2 u( A
    23, R- }; o2 v- M' d; d$ b
    24
    3 e' T1 [( g2 I2 L25' ^. v: E( L3 ?0 C% a
    26
    7 q3 `2 k4 O" T# Q# v/ l. Q27
      R2 N$ w, T/ e8 K, H5 z7 e28
    " i+ u. ^( U8 R8 ^294 s1 P6 r: Y$ U1 A2 T0 r
    30
    * _/ S  ~/ R% \# V. p7 p31* b9 `$ {4 T- W& a4 O. N3 T
    32
    % s2 y) T  F( E6 @33: a( E# U2 j1 ?0 W
    34
    0 b$ p8 A/ H7 P9 h0 Q35
    % W. h$ c9 g* f36* H% I( M. k& w5 |: G+ `, J
    378 n3 k6 @. g5 x& O' D% c
    385 p# \% I) l0 {
    39) S7 M/ S. l+ K/ G9 M% c
    407 ]: M" {0 P5 o5 Z/ T& y
    41+ m% U5 w4 v% i8 ^% ]3 b
    42
    ' r1 @; j- _) W" W- b- w- Q43  b2 r8 U! E3 |. t
    44/ |) f$ B# q# e+ |
    45  f" L% ^$ D5 `& |
    469 G1 e, V, g+ `6 F, Y2 K
    47
    , e4 F; `# Z) w/ j7 o+ Z48& J* m* @3 v3 s& I* S
    49
    % C- [- W& j, B% w# a. H50  K0 e. j/ d- q. {, j
    51% Q" C3 Y1 [: j9 j
    52, w3 ]5 C1 S# t
    535 g$ D, ~/ k. B2 D
    54
    : a5 ~* `$ S9 B9 d3 }<AxesSubplot:>
    6 [- x1 U9 Y" u2 V5 @: R1
    - r! T: ~7 x9 z. Z3 V' ~, O, e3 v" S2 E7 ~! ]
    上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    $ _# r) Y1 _/ y& m  r
    ' _( w. O5 h' S, w. wimg.shape5 V6 P) V- Q" l4 v# a7 p
    19 [. }8 S- L/ h9 R9 J
    (3, 224, 224)
    ; b; q  Z9 V! n; [9 w7 }' ^1( l4 k; h. N3 y8 A/ z
    证明了通道提前了,而且大小没改变/ W# G, U" i2 c  l6 |9 N" c
    0 A+ h2 Z% k4 v. b8 A$ B! f
    9. 推理
      C5 w( q. I% C2 @* B0 Kimg.shape
    : K9 {! x! `3 }* d7 k; V% J) a) X: r8 U7 }( ?+ e7 M
    # 得到一个batch的测试数据7 J& ?9 v* U5 H0 J. T2 o. I
    dataiter = iter(dataloaders['valid'])
    ; |: G0 I; J* j1 z2 p/ Zimages, labels = dataiter.next()
    : A" {7 i0 @3 {; K% N, F9 }& p# f+ f- r
    model_ft.eval()$ u9 W! u# e  y$ |

    - K2 p/ N) p3 l  K7 t$ I4 Dif train_on_gpu:# n  b, Y, Z/ k8 H6 \
        # 前向传播跑一次会得到output
    * A. V5 p9 D- n1 f* J  t    output = model_ft(images.cuda())
    - ?  H" J# T) q7 r6 xelse:
    / I9 }- c% @& t    output = model_ft(images)1 T, Y  `; d* z6 l
      ~* i$ e# F8 g0 r1 ~$ j7 M; X
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值' ]1 Q* }7 h9 m- Q: W; j* S1 k
    output.shape
    4 T/ Z7 y% F$ I5 t4 f
    $ D" N8 r9 M5 j) m2 s& K1
    . T  Y! ^. z7 h# l) T3 b( E* I2
    / ^' w+ F( }1 ]! R$ q8 I) P' W3 J. k3
    ' G1 O% C; f' {4# _3 n  E( x4 ~
    5
      V" _: y* d+ [9 w2 T! w2 S" Y6
    " i& u( E: }0 q  j- j9 u! o7
    9 z: V! S$ s) u2 k8 K5 a8' z4 o3 c& F7 I1 n
    9
    # ?" H2 K. F! c8 F& V( B10+ f2 e, X" c* L3 [( @
    11- @; W" R1 n* T
    12
      a+ E) C% Y3 M- m4 r13
    6 W6 V" A$ j* M+ A14
    " k' d7 p7 `5 W. W) ^+ t15
    8 m# p5 L. d5 v161 s( F' P6 H* o6 _" t
    torch.Size([8, 102])
    5 t& l7 X% }7 B# A( _. a1, l7 q4 e$ K$ g
    9.1 计算得到最大概率  z8 G+ @/ n# o
    _, preds_tensor = torch.max(output, 1)' @' C1 u6 I! f* K/ H

    ; S. j& C" p; w& Zpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    * l$ K  P7 T/ F; S* x: i1+ d! k6 L9 e1 v9 M
    2
    " `6 F& }4 `/ \2 W7 t3) \0 w. \7 |0 t$ a& C+ k
    9.2 展示预测结果$ E, [, s" o  M9 `
    fig = plt.figure(figsize = (20, 20))8 B  g2 H% k7 e8 T6 m0 o# V
    columns = 4
    3 Z8 b1 u# i9 Brows = 2
    - w# c  N' N/ x9 x% G( t: s
    $ U; Z$ }: B$ Zfor idx in range(columns * rows):
    1 D( p5 O  s# y/ b$ {- H    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    & P# ?4 C3 V/ W( S3 G4 b    plt.imshow(im_convert(images[idx]))/ L3 ]7 \7 ?& A
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
      X" u( q, `0 e1 P$ ^9 [                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    # ~: S0 Q- c: ?- o% Hplt.show()
    + }7 a, ?! x% _( q* w# 绿色的表示预测是对的,红色表示预测错了
    1 N1 |  y# i2 `; J+ \1 R1
    ' }4 S1 r3 A0 i# E% r" [9 l) |: t$ a3 k26 X$ R& D* q- T1 C6 U
    30 U$ _' S; z6 G8 b. x% j4 \
    4# K" t# u+ v+ O8 I1 g/ `& D
    5
    8 M1 e* G/ M1 u) Q7 A5 D, L6
    0 p0 f9 T+ o% M1 z, H: q7- w0 T$ [- G% N) p  j. J
    8
    8 c4 O1 w) q5 y$ a# o# M9
    4 u/ C  \4 h- }4 n; v1 M% ]. R10+ m7 m* w( r+ Z3 D3 |
    11
    0 [  c9 V* [* c% T: l& K& E! V+ u! N+ x9 ~, t

    $ M0 J1 w* D/ o* m; N& {$ l' s  q
    / b/ v; S" I# j  T. _* {————————————————
    * c7 a- ?: F, p版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。- T. P4 _5 K3 Q2 U7 @+ Q$ i
    原文链接:https://blog.csdn.net/LeungSr/article/details/1267479400 J( V5 A5 J5 H! H* i! B+ O

    & M3 m. ~; e. n$ M  m! t+ q1 x4 i6 u9 T( g" K
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-6-14 22:39 , Processed in 0.495702 second(s), 50 queries .

    回顶部