QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2722|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例5 I' F0 z: h4 J0 ]+ m; x" i
    * a3 F4 W; X! w  f3 J" R
    文章目录- T1 X" b( E1 ]( W& }- `2 W
    卷积网络实战 对花进行分类" `; }' S" \) ?: B/ y4 a% _# H1 c
    数据预处理部分& Z' \: T; w% w" V) T
    网络模块设置
    . B, x4 J: N- G3 [) D( G; [0 T% l网络模型的保存与测试
    : \% O1 m8 i" \3 Y. z数据下载:. P* ~1 _/ @8 c- }" K
    1. 导入工具包
    & I" t- A4 d  L7 a' z1 ?7 _2. 数据预处理与操作
    + A$ I4 B6 }2 B# X3. 制作好数据源
    0 f9 f3 V' s5 N读取标签对应的实际名字- j  O) W0 J' ^. h4 ]( C4 e
    4.展示一下数据
    8 I+ A9 u1 ~) |( e% F. k5. 加载models提供的模型,并直接用训练好的权重做初始化参数1 j) _4 a: k, b8 j2 e/ [
    6.初始化模型架构
    / @; u' }# G, _* P# P* U7. 设置需要训练的参数
    7 C8 s# q/ j( N! _* C% z4 t7. 训练与预测
    # X/ V& [. n  K8 ~# _0 r+ [9 h7.1 优化器设置
    $ t# b# K; o2 n7 n8 o+ K7.2 开始训练模型) O2 m* {, g  c: X% z
    7.3 训练所有层
    & d2 R- H1 g5 Z# ~  f0 K6 o开始训练4 ^& ~2 v  X7 z+ [8 f
    8. 加载已经训练的模型
    * k, C7 T, @& T5 Z( M/ a, U9. 推理: h4 J( y' X8 B9 o! u# y
    9.1 计算得到最大概率- T# f6 `& u* d. l9 n8 a' S0 U
    9.2 展示预测结果
    2 W* j/ q4 D1 o% H, U& K7 p写在最后* F4 o% j8 [* G' g. i* R  E% Y) b
    卷积网络实战 对花进行分类7 z) F2 `5 c( L. [% m
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    4 ]% d5 K' @6 N# V5 O  @3 ]5 g* x
    - k8 c- f8 g3 m7 _  l) Q& i在文件夹中有102种花,我们主要要对这些花进行分类任务
    1 y- f7 `0 ~0 |6 h文件夹结构
    7 V. [9 t, l4 I4 }, I4 I$ z
    2 n5 P) _$ [6 m# ?" U' k$ {/ mflower_data
    , ?. W$ p4 @* n8 q/ u. ]6 W8 P) V* o) j0 @6 ]/ i/ P7 e# P3 {
    train
    : ^# F5 _+ V" k; [& k) f. p" Q) M, u& ~
    1(类别)+ t! F" x! }# o3 M7 I
    2
    5 W# m% @& \8 X1 Oxxx.png / xxx.jpg
    % c$ v3 F# l: @5 M( tvalid$ H0 j( h" z* i5 ^

    % o  a" t+ G( K5 f4 x主要分为以下几个大模块  \+ S& p; W; L* m

    8 i" L/ A' a2 ]6 b- G; J$ q2 v$ }数据预处理部分
    9 p7 X7 X; d( }  K* I: l数据增强) x# |3 k& D+ y( p; n
    数据预处理
    7 d2 s3 v/ B4 ~# Z6 R" }网络模块设置
    6 Q' X4 J$ O$ p7 D" v加载预训练模型,直接调用torchVision的经典网络架构3 v' G/ E7 v0 ?; w& `
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务7 \2 ?' w8 d: e' ?
    网络模型的保存与测试9 }$ F: y% L) l5 {8 R
    模型保存可以带有选择性
      h. P% O" ]3 \3 S8 J; w数据下载:
    2 T- e1 G6 g1 B! b. t& {https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    / o/ o5 V2 B) [0 D6 R
    % M  T' O2 m( k1 y改一下文件名,然后将它放到同一根目录就可以了) T" V, u/ J% l4 b5 T8 d

    * d- S% O7 I, C2 }: j下面是我的数据根目录
    0 P! J: R5 }7 H' ~& U. y% ~/ }& Z! c, q) |2 N* o1 _) W& F) A: @, X

    " k. B* d$ B& m  [. p1. 导入工具包! v& i5 T8 f% Q# k8 K
    import os
    ! e3 D3 Z, f" L/ H& Mimport matplotlib.pyplot as plt5 K7 p! [1 G5 m% b; ~  C, g
    # 内嵌入绘图简去show的句柄
    5 I: s, i9 G8 n%matplotlib inline 4 w% |- q) v6 c0 ]. O
    import numpy as np
    % h* u/ c1 {5 n0 R; simport torch* y! i2 I  B/ r% m( @2 W
    from torch import nn. p7 J5 n) W) x* {4 U) E4 ~$ O9 G! V
    6 h; w$ Q' N, r
    import torch.optim as optim9 l7 M' G7 L6 b$ r9 r- F
    import torchvision* W* u; w4 y7 Q
    from torchvision import transforms, models, datasets- X/ Z: ?8 k, A: T7 T
    $ e2 C! _* H" l7 X% J, W
    import imageio
    8 W- `, N. {% B# k+ dimport time; u$ l; f$ M: ]! ?' `8 x8 z
    import warnings
    ! G# c* H9 a; J. i8 q$ timport random* Y1 T5 o6 h. F* x
    import sys
    . n9 O0 ^1 H0 ]" K1 Simport copy" k+ o/ b6 {2 n  y# f( [. x8 }
    import json& ~- }0 P" T. [6 a9 {6 Y; q
    from PIL import Image
    0 s4 r: P2 p3 q5 b/ s  W
    * `8 y/ a* c- g+ j2 o: T* b# }! C( o, Y+ T/ I/ i& X, a: m
    15 r( @% o  B7 O+ s* u
    2
    ! w4 a( f* s5 e) P3# c3 a7 [. |8 p4 l4 i
    40 Q$ J& A0 G+ s, R
    5" F* h' u1 I/ I+ F1 s& f3 C2 k
    64 T+ y- B6 O( b& C
    74 L0 l! i, k7 s3 Z
    8) l" g9 t1 ?! R5 {7 |4 e
    9* t1 I# g" Y7 Q2 y$ Y2 e+ q7 f3 H
    10
    ! T, C$ A$ m; z+ \4 D& ~. n117 t+ W& l. F. M* y
    12
    6 K& p# x; W- ^" }* i0 e13
    - F- D+ B- D# }1 B, w14
    8 P2 N5 O0 G0 T( _" R6 p( S. ~15
    5 ^+ Q0 s! s$ D% F2 ^5 t# e16  o& J, e8 Z5 q4 f5 V
    17
    / g% e$ O- ]+ `3 y: }5 Q2 K6 a) `18
    0 v4 q$ T4 ^( U' ^2 J) V19
    / `' q" q' r8 S& P) p$ w) ]20# p7 Z  z  u, A2 j' m; }8 m  }. _
    21
    2 z/ S6 I% v7 n% w6 ?2. 数据预处理与操作5 S$ [% i4 ]# H9 b8 _- R6 ]# Y
    #路径设置7 w0 G9 K/ b) b) |
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录$ j0 P' T4 b# o  U8 {8 E" r
    train_dir = data_dir + '/train'7 F/ f6 _1 D' l/ t* w8 \+ S
    valid_dir = data_dir + '/valid'7 Y+ l* p- |; ?
    1
      N* m: N$ B) E8 |2
    $ e. w* ~; U9 ~. S) _5 w3 l) D3
    , j- ^! u9 ]/ _4
    " n4 }. F1 }+ p( I( Spython目录点杠的组合与区别1 k* p. v; I8 z+ s
    注: 里面注明了点杠和斜杠的操作8 @# d+ `& f% p  _3 ^0 T# ~

    - P) I8 `  }# K/ J, U. A" A3. 制作好数据源, P# ^1 ?. ^; D7 ]7 _
    data_transforms中制定了所有图像预处理的操作
    6 T, a) y# x; PImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片% j- N: f  I! E- k: ~
    data_transforms = {8 n& v+ w! X$ Z1 C, G8 ?, Q. l6 x' s
        # 分成两部分,一部分是训练
    & S2 c# d6 W. ]  d7 q9 V- A    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    2 ^2 H! J: r* I3 N6 g. ?) Y7 u                                 transforms.CenterCrop(224), # 从中心处开始裁剪$ q3 \2 T& j( K7 N* U1 q" f, ?
                                     # 以某个随机的概率决定是否翻转 55开' Z  V' ^9 s+ H4 `* ^& o, y6 r
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    ; h! s4 _  d9 b9 ]( F1 L4 \) d  U                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转- m# ?4 k$ m8 _1 T& b
                                     # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相* Y- C$ d; t, \7 L9 ^
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
    - Z3 l: D+ Q; d  `* X3 T                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB2 G9 O2 u/ o+ ]
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的5 q' i4 D+ R1 V5 V
                                     transforms.ToTensor(),
    / y- }/ M6 P' c7 x  Q8 c                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差* V8 J2 \8 n5 Z5 y0 o& K2 b9 G. f' t
                                    ]),/ s  G4 ?' G0 ~# \. \$ E( K9 g
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    0 W1 u6 ^, i  S) E0 q    'valid': transforms.Compose([transforms.Resize(256),
    $ L& x3 L& ]4 W, `9 |  B/ `8 K7 O                                 transforms.CenterCrop(224),* y8 p$ t8 G8 M# c' l" _
                                     transforms.ToTensor(),
    : i+ p, D$ e, [& s& F5 r                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    . {7 Z7 h1 V- v9 \                                ]),
    " G/ d3 y$ E4 D9 l# e5 ?. ]( M! N7 I}
    $ e; ^# |  \  m, E# x/ K0 H$ h" ~9 t8 [3 W6 M  U2 H$ g
    1# }0 H5 m9 }, i. H# C4 h
    27 u$ ~8 K; |% \4 n. z* W, ~6 M  T/ E$ E- d
    3
    / U7 T! y5 ]8 Z! v4# x7 D- Q- W- }9 T7 Q2 O
    5
      ^7 o, O1 G4 C6 |; ]; n8 J6
    6 X& i4 b% p% C' H, P) p7
    5 O4 z6 V1 d2 K- [8: t+ f0 n% a7 A# T
    9
    % m. `7 o( G* i* R* E+ C" E1 k10
    2 w' T; i8 j3 p! }. ]; Y! O11. e; g7 b( S6 A$ |1 z( l1 C
    12) _: b' [4 q" H) ]
    13* P. ]" ^3 X# V6 [0 h
    14
    ; G) W7 z# p8 q( h15
    * S7 L$ [9 u4 N2 k16
    # S+ _0 q% C) x. O$ s17
    - g: q5 f0 F" Y3 M  D: O18; o$ o- l1 ?( D! ]. m
    19) `. {, N) ~; l9 H. M
    20
      i( C; t7 Y* A6 O21
    5 ]2 K0 m, ^/ \, Vbatch_size = 8
    - N* K- z' e+ G. [! Himage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    - Z) `% f; H) g' Pdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}5 F6 [  ?+ h1 P3 g0 z0 M
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} + _5 a" Y: {6 {8 f
    class_names = image_datasets['train'].classes7 _4 [1 ^5 u! B  h1 E; O
    . J4 m: `) L# p
    #查看数据集合. K$ l- v4 v* }
    image_datasets
    9 n7 E0 ~9 D: i( x- W# t9 S# s7 q1 {/ s& \. x
    15 A1 Q1 b; G; G4 ]$ }2 e! b5 ]6 v( [
    24 a% q& b/ N) R6 V* J  R
    3
    3 T8 ]+ X% e/ [# ^) [/ @2 p3 {4- x2 J% k$ v3 d$ f( s* Z
    5" d5 _& o: _6 y; I) U: y
    6
    ' q' W, ^- w5 R$ l8 q7' s# u8 |; u' g( v% [# r. g
    8
    ' F9 c! L$ |9 E# p8 m  f* z# i( o9/ k; L5 I) I& K0 I
    {'train': Dataset ImageFolder& f) Q( q* s/ J# _7 o+ Z" U" c
         Number of datapoints: 6552- q" ^- E9 D) Y& f1 ]0 Z: {. M
         Root location: ./flower_data/train
    2 D3 K# T; m. i     StandardTransform/ I9 h, n, |7 X- W( V
    Transform: Compose(/ `9 V$ n' k! d  d9 \
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)  N5 r0 Q) V* Y5 H3 u
                    CenterCrop(size=(224, 224))% s+ I$ A( S* n
                    RandomHorizontalFlip(p=0.5)
    ) G3 c% }8 g7 o                RandomVerticalFlip(p=0.5)7 w# M5 V7 u' x# l2 V
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    / z; W+ a- G* H" ~: V5 I                RandomGrayscale(p=0.025)
    * w, n) ^9 r  ]9 o, p: o                ToTensor()5 o7 h8 w5 l3 c: P. j* }) C, ^6 @; E
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    8 a$ f7 c9 _9 Y) [. b" D            ),
    2 j" }' ?$ X' D 'valid': Dataset ImageFolder
    - \! x5 P0 F8 q5 G  U* Q, Y6 C, S     Number of datapoints: 818
    4 w9 F4 M8 J  @0 t     Root location: ./flower_data/valid5 @$ G  o) r' I' F7 {0 D6 R
         StandardTransform5 e* y. [  y, _7 `% q# D
    Transform: Compose(
    4 U* W) M' W: X& r/ V& u                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    " s/ Y" e! m3 J* |; E) b( |                CenterCrop(size=(224, 224))
    0 M" R8 c" G" t' a. ^5 ]5 F                ToTensor()
    . e" C, J: d" a& {% L: N! O                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ( r$ u2 L, m. W& B; I            )}9 A3 N8 _" e1 s! T% ^4 O
    * N1 [# u7 H: J3 T' ~: N- k* N
    1# D6 p, V: P6 ~. p9 s5 p' C
    2- ]  a8 Y+ C! N/ S- C
    3
    5 Y) V- \- Y; @$ O. u+ q3 H4# d$ U0 Y. ?( T% z% r# }
    5
    ; ?9 n' b* [4 C( }' |5 E% H# a6# g3 ~4 }7 {1 P9 X6 A- }4 S
    7
    & B9 M* |  x3 u/ ^+ I8( A. M7 U1 D$ t( }5 c0 ^& y
    9# B. O' U3 d  Y: ^8 z
    10
    * R( G% ]* b5 `/ y% V( z* Z11# |6 a! [/ l9 h- r: O4 p. _4 ^
    12
    + b9 d) f5 V! o$ Q0 v5 v13
    ' H$ E! |& f" _% G  s/ z- |6 k14
    7 f/ ^! [3 M- L* N3 r' C157 g( k: g& ]7 I+ M
    16
    + T6 E$ U# c2 R1 |17
    $ n5 z4 U7 q! L( @187 |: b" W& X. h; |( D) o
    19/ e) \8 |! C$ d2 e+ Z
    20
    0 g7 K/ a! Q$ v211 R) ?4 X8 N* e
    22& s: Z* _3 X: G, G& a# T
    237 G$ s. @, A5 E3 p) D& V7 |
    24
    4 G1 |2 ^% B; Q7 N# 验证一下数据是否已经被处理完毕) U4 N! C+ |0 L) h  e+ |8 c- s2 v. |  g
    dataloaders* w) J$ m8 i5 d0 x& Y' h3 Q
    1; f" l9 @. ?- J) H( D( d3 M
    2
    # `$ D0 R7 N" u8 y( r/ Z{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,0 v+ g( |1 ?9 {" b- `
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    ; h7 m' k* F& l, |" s: Y19 r2 U5 W+ c; N8 l
    2
    ( O( f1 Y& l# y2 w: hdataset_sizes' ^  C7 C' h' J) _  ]
    1
    % ]6 M0 x/ _3 R, r. q! I{'train': 6552, 'valid': 818}/ d2 I$ M1 \" d- f; g- C
    1
    : I$ Q9 d  \6 M* d读取标签对应的实际名字7 ?, X& K, S2 v% ^# K- c
    使用同一目录下的json文件,反向映射出花对应的名字* [% R( I8 G* z. p1 b4 L! ~# x

    % F8 U+ Z1 }% L* Nwith open('./flower_data/cat_to_name.json', 'r') as f:
    & y, X8 N2 p: E, V) S$ p    cat_to_name = json.load(f)- n, n5 i2 P8 s7 F7 T/ v+ W( m. c; U6 z
    1) Q' V. D# F/ f  M# s* x+ g7 X
    2* A! K$ h* {9 G" y3 ^8 S  j- W
    cat_to_name5 @, r" {7 S1 E  E
    1
    ! e+ m: j1 G$ o% v, a9 P; b{'21': 'fire lily',  F3 r; o2 W- g' `2 ^6 T- e
    '3': 'canterbury bells',3 O; d: o1 a: f+ S
    '45': 'bolero deep blue',
    9 M  l! s) J3 D2 ]- v '1': 'pink primrose',1 B2 n2 p4 r4 a4 `0 H% @, I6 ^$ ]
    '34': 'mexican aster',) O. w9 v/ R' ]) w& M$ I
    '27': 'prince of wales feathers',
    8 Y& i1 e2 U- J '7': 'moon orchid',
    - M: f" a! k! C6 w' } '16': 'globe-flower',
      T+ e/ }, J  h! a) B# p& c' R '25': 'grape hyacinth',, [- ^5 N# s  t6 w  n3 y
    '26': 'corn poppy',
    * m: W! o3 @8 ^) Q% l# K" _ '79': 'toad lily',6 L! I& f! l0 C
    '39': 'siam tulip',
    , R% P5 ?$ \& o7 o: M& b '24': 'red ginger',
    % S5 f. d7 G) X" N8 ? '67': 'spring crocus',$ w# h1 c5 ^5 y1 r5 a! @
    '35': 'alpine sea holly',
    7 U$ c3 G7 V+ i' \0 m9 `  R8 L' b '32': 'garden phlox',3 m) X6 _1 u; G2 \; J
    '10': 'globe thistle',
    $ y/ s; q: L* M' F. i '6': 'tiger lily',
    , G0 n' O" H& C0 K '93': 'ball moss',
    ' y9 X. E& d4 [4 _$ P '33': 'love in the mist',
    8 r- L  @5 b% \8 L8 p! L# W '9': 'monkshood',. W3 ^1 a- M3 x; J4 T6 h% h3 n
    '102': 'blackberry lily',  c9 b, f0 k+ x+ K
    '14': 'spear thistle'," B1 Q0 ?5 r" ?2 C
    '19': 'balloon flower',
    3 h2 I' U3 m9 ?- J '100': 'blanket flower',. `2 ^( j+ X$ N/ E- R' q
    '13': 'king protea',. r; o. [' [6 q! O* E0 c4 P! A6 D& x
    '49': 'oxeye daisy',# p  Q! I. Z* ]( [
    '15': 'yellow iris',1 F0 K' T/ B7 o, T. H0 s* K
    '61': 'cautleya spicata',
    & g/ ]* e$ {4 c/ h$ @ '31': 'carnation',
    % n! i9 q  t6 @% c '64': 'silverbush',$ h1 T9 I* q7 }7 k$ e# k
    '68': 'bearded iris'," R# f9 \) Y* a3 h5 [+ L0 }
    '63': 'black-eyed susan',
    2 ?4 p' A/ k* L# l- M. A% J: U '69': 'windflower',' D5 [! \8 o# o9 {
    '62': 'japanese anemone',$ I  b$ X0 v" ^* k
    '20': 'giant white arum lily',* C$ c0 Z) R2 R+ x' A/ R
    '38': 'great masterwort',% m7 W! w4 z3 S
    '4': 'sweet pea',
    ' w0 C7 d) U7 O- \5 {, t7 _  [6 e '86': 'tree mallow',
    % f3 S, x$ v: ?# t) O '101': 'trumpet creeper',1 G4 L9 P- b& R6 P. y
    '42': 'daffodil',
    5 G' I0 o. Y' l: X( I  q '22': 'pincushion flower',
    ; D* @- ~. y- w' g '2': 'hard-leaved pocket orchid',6 q! \7 G; }7 ?2 P/ a$ L0 ^. W
    '54': 'sunflower',( t0 O) y; S- S& ~! O) B
    '66': 'osteospermum',
      Z" l% |! V& P5 f9 S. S( l '70': 'tree poppy',
    . q+ [4 C& ~' U- M: S '85': 'desert-rose',  c: f* d" O2 J1 ]& S+ \
    '99': 'bromelia',
    2 z5 x, C: m" H '87': 'magnolia',
    ) l7 y# i- p8 Q. N) U) G5 Q" B" I7 Z '5': 'english marigold',
    2 _) G( @$ u0 }7 V+ R0 M8 q: A '92': 'bee balm',/ _- l- I- E7 H2 B
    '28': 'stemless gentian',0 D8 x1 [( K( j$ t9 [
    '97': 'mallow',
    7 z+ x, v. J+ L; x/ P0 f '57': 'gaura',* v5 E1 @- ?2 r0 z# I' v+ J$ f
    '40': 'lenten rose',9 K% P( P3 n# d. B) T
    '47': 'marigold',
    8 E. S) W" ]0 p6 J' ?' Q3 B% U+ R '59': 'orange dahlia',$ q! m( w: Q0 G! p
    '48': 'buttercup',
      C0 |7 T2 ~/ B0 ]# L: U! B% o '55': 'pelargonium',
    : f1 R9 q; M. d7 P/ C9 E8 } '36': 'ruby-lipped cattleya',$ t7 y2 p8 ]/ X
    '91': 'hippeastrum',+ k( e7 o! f. s- L' W
    '29': 'artichoke',7 P( D8 u4 e1 K! o1 ?! [( A
    '71': 'gazania',
    " a1 c1 t" Q; i& n4 g0 ~ '90': 'canna lily',& g  q, M  a! U- [
    '18': 'peruvian lily',5 s5 S: I9 ~. _- C: p* L" |
    '98': 'mexican petunia',9 C5 S8 e$ u+ \4 w$ V3 f7 j
    '8': 'bird of paradise',
    - s- y# W1 P! y* V2 f1 ?( v '30': 'sweet william',* P1 b, {5 H# ^% \3 s, a
    '17': 'purple coneflower',
    : e" M! r7 @7 x+ ]$ Q '52': 'wild pansy',
      }, r' V) M4 W4 @ '84': 'columbine',: ~: x' X" r7 q. J0 R  q9 z
    '12': "colt's foot",
    9 j) ^# N" U6 q/ K# C/ I '11': 'snapdragon',
    0 V* _  Q+ i9 N '96': 'camellia',- j6 }' b0 D8 p& D7 c
    '23': 'fritillary',1 O( ]! \7 r0 h  i8 F
    '50': 'common dandelion',
    / A# y' T& L8 L6 _  r+ \% u '44': 'poinsettia',1 p5 t0 H8 {4 S! M4 n% r
    '53': 'primula',: L: L( e/ ^5 e; M% W/ V
    '72': 'azalea'," C' Q; V; x; D9 U8 T
    '65': 'californian poppy',7 X1 ]: w) |1 A0 P' {; G# R. u
    '80': 'anthurium',
    ' g* E2 x0 @. ^9 X; ] '76': 'morning glory',
    . P, s( l! k. j$ t, I" d4 ?; D '37': 'cape flower',# `  h# u* c% g$ s; A. o% C
    '56': 'bishop of llandaff',
    - b5 }8 D% X+ \( p' K( d '60': 'pink-yellow dahlia',
    - x/ b) H7 ^' j! g5 I0 F '82': 'clematis',
    ! M& w  ]0 w+ X; c6 z- A$ v '58': 'geranium',
    ) J2 Z; N" R8 s '75': 'thorn apple',/ g3 ], M' ~3 I% {
    '41': 'barbeton daisy',: X3 r$ Q. P; ?) z
    '95': 'bougainvillea',1 v; v# ^" h% J4 A: J/ j
    '43': 'sword lily',7 Q( Q8 _  t3 Y1 e
    '83': 'hibiscus',6 X$ _) d+ E( ]2 q
    '78': 'lotus lotus',% P0 E% x* y7 q& e( J9 F/ \' j
    '88': 'cyclamen',
    & K+ ?1 y" h+ L '94': 'foxglove',+ m' B/ ?7 L* I$ W
    '81': 'frangipani',
    7 Q8 l5 ~9 \) f '74': 'rose',* M$ T6 g" _9 a  s/ ^
    '89': 'watercress',
    " A+ g/ h! G. G. [& x# E" { '73': 'water lily',, Z: S6 g, Q( p' V
    '46': 'wallflower',8 H$ h! {9 {) c
    '77': 'passion flower',6 N( ^( z+ R" g# T% t# ^! g
    '51': 'petunia'}0 E$ K! q, A/ H/ q: c

    ) j/ Z7 e! S6 N9 I1
    , s9 Z; `4 ~: j  P% x2% o+ J* N, ]4 ?; O% d
    3
    . `8 a4 j9 C& c1 F$ s  V6 D4
    0 L: l4 _1 |. ]3 A+ t- ~5, w0 H* p2 j: k8 `
    6
    $ U8 A1 i1 y- r' \. x% k& Q& J7( T7 J, s$ \7 I4 W5 r
    8
    3 m) L! r) M0 O% @8 l95 ~, V" a7 a4 ]: N: q6 n5 P
    10
    3 m- v) R/ K; w$ V% C0 l114 B7 T! A' T3 h1 i8 O" J9 c: |9 f/ a
    12
    # w1 B5 B) R8 o; @/ Z; Y0 w13
    - Z- ]$ L/ l6 [& Z14! o7 U5 {8 d! q2 K9 ~; c: a
    15
    : B0 x8 v+ P( C; p% I16
    : p" t' @7 H9 o) @, ~" m6 T17) B& x% @2 a! J
    18
    5 _7 b8 P$ e! h! u9 ~) j19" M% b% m( X2 t1 G, d; r
    20: ~- o* n$ n9 T6 Z8 B5 F0 f
    21
    " C; W" L; ?9 U# Y. {225 n1 m. G) Q  r" R& s
    23" ^/ G. ^5 y" h! S
    24
    5 Z: }# o" n- o6 W, E' s252 p5 F8 m# j* C
    26
    / l8 G3 i/ A) Z( @* E& ~+ C27  o2 `6 l/ n: N
    288 R& y+ ^4 V: i& o; n/ W5 l  L
    29
    : \- P2 e+ Q8 k- i- v30
    # A  a( H% A" j" b; Y9 L) `' _31/ K( b& P& ]/ N0 N/ n3 Y  @9 P, m
    32' p5 {1 X& a4 \& ^- A; _( O8 g
    33
    ' C, M4 M' Z% w" }# H34
    ! B, M$ Y9 e' f5 f/ }; e35+ w9 i% w8 w- X  j$ c4 r; n  c
    36
    4 D5 Z# u5 Y5 Q37
    % Z9 {: }2 H( S4 E, L38& ]' |/ Q% g1 M; d$ p9 A9 u# S
    39
    : M/ ?( i% _& `8 i. ~. B4 Y, J40
    2 e# e) ]; n5 \+ b! E41" \4 E0 o5 d; {: S' R. X
    42
    ' K: ~" s. F* c9 t4 X/ L2 {6 k43
    - V. v. v* `7 p& A9 y; n! s: ^! N446 s4 m6 Y& H$ c! |. a
    45. d  g4 d2 c( X& ?, n" U
    469 s7 _! {' S. @1 D1 E" ~" J4 F6 I
    47. x0 \  _5 c: F
    48# F3 v8 E/ T5 w7 Q8 [; P6 b/ A0 W
    49
    ! S# A0 L7 n& i2 |9 W0 G: ^502 `1 S0 l& A1 ?  i7 _" V
    51* j+ t: }! k* k  m3 w- _! U) a
    52
    2 Z: S, @3 n3 f53
    ) |5 @, c5 ]' w54
    0 d( K* m6 L5 Y& d. T0 S55$ W3 m" t# b1 W. m8 k7 z$ b, i! o
    56
    , F- ^% y6 }: x1 ^57
    8 n, d% A, a, R; G58' U# y. i0 ~) S8 w- n7 A. i5 v
    596 |) A: d9 C  u0 o
    607 B2 n: a3 @  a3 l
    61
    # t- D: K0 O6 ?2 ~2 j/ _62
    ( |# m4 w2 F, d, u( t63
    - P% W8 n' g4 l0 N  m+ M8 Q% H641 r3 S- X- v" V( h
    65
    3 ?- B$ w& v- U% E, ~664 Q! j0 C' A5 B; s
    67
    - E* i3 b9 U2 j8 K7 x& B# e- f682 x" c( c/ I& N: m2 V, d
    693 M6 o- `- F7 r% M- D& I
    70
    5 F2 ^8 f+ u% o/ V71/ D0 n6 A1 q- x( e7 @# O+ z
    72
    $ E! [. W5 U$ f) E. v2 L" \, |73/ f3 W  A$ d9 {( H  p, U2 c( M
    74
    . m# s) K; v9 J: x1 w; H% d& U* P75
    + U0 t; T! W# @+ \& `. f766 t( \- ~. @5 q) W; l4 W
    77
    1 T, Y$ @( ~* M7 e$ J78
    ' W/ ~* w: U$ p- C$ E. N793 c: V, V7 }0 I( K+ t' t5 d
    801 [  ?' y5 [# V0 |. i5 l. p7 N" b! x
    81
    4 Z: b1 }; i3 J/ V: C8 n82
    2 G4 f* ]' F% ]2 b833 Y, `5 _# P- d) B( C& ^2 y' V. J
    84
    - ]$ e* d1 `5 G- c$ R85
    * ], G4 p& G$ E( v4 `* n0 g! C5 y86
    ' Z4 k/ j, F6 [0 y- |3 P87
    # C  m; z3 ^6 t4 F88
      r% m' F0 F: f! ?3 `# z  ^) j89
    : y! a/ Q# t4 s7 F& K$ W4 Y( @6 H90' f, c. E& x; Y7 f
    91( W3 w( Y% [# m( N0 C
    92
    ( s* z) v/ n& a$ L8 M4 Z- I8 `. B934 a4 _( c: o3 g& `$ |
    94
    1 X5 n) Q6 {/ x95
    5 W$ V+ ~5 d" `, u96% y+ |! n: r9 z* I
    97
    7 Q: H& H% f2 w* i) l98
    ! G0 O3 _6 o: q99
      g+ Y4 d+ @% V$ g' x+ a" x, M100
    . I. [0 Z/ ^* y. L101
    4 p: R& P: [& o) A$ o& A1 i102
    ! y2 E8 J) \( ?7 u/ t6 _' g4.展示一下数据
    : g& r; G0 T8 y+ Pdef im_convert(tensor):$ c/ C6 h) O- U  I9 b: [3 f
        """数据展示"""
    * t( j6 t' {  c2 U' y6 l6 d    image = tensor.to("cpu").clone().detach()
    ( f7 d/ M0 }8 ?7 {9 O; l% c7 Q    image = image.numpy().squeeze()! ]" h1 R) h# d; S+ T+ t" C
        # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    0 j% j+ G9 O8 c* A! c    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    : P& ]) n4 \2 N- Z    image = image.transpose(1, 2, 0)5 _2 _7 z" A( ?  a* ^6 _
        # 反正则化(反标准化)
      M$ ]$ Z9 z0 B, p( p% G8 O. R    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))+ N/ n9 B2 I1 x" \7 X
    % u( _6 h2 a4 m2 R
        # 将图像中小于0 的都换成0,大于的都变成1
    ( t  M- e( D6 z% j: W    image = image.clip(0, 1)! M" f' y: W  z& [0 d: `
    1 [6 n% Y( m8 h0 P+ j
        return image& a- ^1 @$ z; l
    1
    3 B' T4 Z* D, C9 j2% r* f6 j6 k; A$ G) R& D, f
    33 _6 d4 X& N7 W9 f; Q
    4( p8 g, x4 }2 n8 t! U
    54 }2 {! n6 {/ B) G1 \
    62 \' V( a  X# K
    7
    3 }+ N' Z6 z4 B* {4 z83 H' L7 k5 D4 E
    96 ^8 g. V+ A; [& l+ B, t2 S
    10
    $ U$ c: f# z) R3 j" T* {& d11% r: D- K! u7 B# C" x
    12
    ) a" k- L8 {9 b! M. J. L9 c131 y, R' X( s3 Y" B$ i0 n- G% Z
    14. V4 H2 m3 z# F5 h
    # 使用上面定义好的类进行画图) h; f8 a1 u: h6 l9 ]0 W; p' R
    fig = plt.figure(figsize = (20, 12))' {0 j, W& L" K4 X
    columns = 4
    ; h& h6 A* z3 j7 q& ]rows = 21 n6 K3 W9 g5 n5 n$ g" g9 V) b

      |) r0 T: ?- z. A5 _# iter迭代器
    ) h) @, B9 t! d/ o, z0 L; w# 随便找一个Batch数据进行展示3 y1 s  h- v% L7 Y3 U  T' c5 p4 Z8 Q
    dataiter = iter(dataloaders['valid'])* s& o" ^6 y6 q
    inputs, classes = dataiter.next()' b4 p5 C7 s! z: z" }( E+ q! m: q
    / q- E5 E" e# P" _2 i% `+ j' l
    for idx in range(columns * rows):
    9 w+ C: N$ H1 y4 W    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])9 n: h1 y! E0 ~2 V; j
        # 利用json文件将其对应花的类型打印在图片中/ M4 B  J# j* p. v3 G5 H7 \2 c
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))]), k4 D& r3 u% J- H
        plt.imshow(im_convert(inputs[idx]))& U( T8 y/ Q: |5 ^4 k4 ?  ~
    plt.show()9 e& W5 n- b% @
    ! H! [* K' K0 ^: o# H6 X
    1
    ; ?6 U! e$ E) j: `1 C* B6 w2, h4 T! _) w. T7 F' V+ O
    3
    . C3 |, f7 t* d# c5 b5 U" x47 e# r0 I6 ~- j3 O; D
    5  F$ h) Y8 B7 y+ L
    6
    6 s" b8 N* g2 r9 G7
    6 h0 j' R7 [4 e# x9 w# U, D8( Y) L1 E. v0 X' v
    9+ A3 s, J4 D* E" W7 ]/ [
    10: q; m5 ?9 o- U3 l7 s2 u
    119 F: g/ D8 |0 a" c, s& @
    12
    ; f1 J7 H3 Q( A( h133 {" [1 b5 k! w+ F0 G
    14" x1 Z; c) ?6 v
    15
    : \( C: J5 L' Z- w& {" D- k3 y16, o6 E& n3 B: l$ @" {! P; A/ ~& I2 s. V

    2 M7 n: k) f1 }9 @0 e+ ]0 p" y# c  K" V" g2 i8 X1 ?
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    # I7 i+ T4 i3 _" P/ Mmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']) o' P, Y  X( m- J# I. Q) r
    # 主要的图像识别用resnet来做6 a1 @: {( K8 l$ x& H6 `
    # 是否用人家训练好的特征
    5 x" M+ p/ l- @- [8 tfeature_extract = True+ K; p0 z* z8 `
    1
    : ~6 U  j+ D0 {21 X7 b& Y  [$ {3 c6 G1 {' a
    3, a$ k, k' m. h( _. Z+ J
    4% u& c; K# C: p' k
    # 是否用GPU进行训练* H1 r7 }% Y% A
    train_on_gpu = torch.cuda.is_available()
    9 W- n1 W# a& s  p' N  T
    3 H! P+ U5 h* x: Zif not train_on_gpu:# z$ _7 F. A$ _& w  G4 J
        print('CUDA is not available.   Training on CPU ...')% V1 }! w3 [  ^6 j" G
    else:
    # {6 p3 t" b$ s9 X; e. K    print('CUDA is available! Training on GPU ...')
    + D/ F3 D+ [8 d- }& g5 Z* u$ z
    " V3 W3 F+ s! T3 o( o4 s) |& N. o& pdevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    % R# |: K. H  S: T  S/ ?7 \1
      L! X3 X# c- U6 H" }7 O3 r# y) u2
    5 K9 k4 S5 G1 I3 R5 \3" U- N; h$ H. o+ k. [
    47 q8 D3 i5 L* `+ y8 g) x3 b9 q, l7 v
    5
    * o9 F' g. s( c1 S$ G# W( T& P* ?6
    # H: o( e3 ]" U3 f: L7
    ( s5 m6 V4 Z5 O3 i, u3 @8
    ! o% p* g7 q# R$ H/ M* E  z/ D9
    0 U, i' e1 v6 s" C) ZCUDA is not available.   Training on CPU ...
    ( x" Q- X0 B) t6 b" B1# p; M- {8 U( [/ m
    # 将一些层定义为false,使其不自动更新( c2 b. G; H! z+ L( B
    def set_parameter_requires_grad(model, feature_extracting):
    2 b/ j, \2 h+ Z    if feature_extracting:8 Q; |  K  E1 k; A7 O
            for param in model.parameters():' q6 C0 e3 z9 m! e7 s$ X
                param.requires_grad = False
    # Q: J- b# X, ], }1$ Y6 T' I* M; R. F' B- @. }9 Z
    2- o# ~( y3 J' ~2 K$ i
    3! ?$ h, ]. ^4 \
    4% y% A1 E# L( n, M9 n
    5( j$ X- H( \) E2 P4 t5 ^9 h
    # 打印模型架构告知是怎么一步一步去完成的8 K( }9 v! `9 l% D" f
    # 主要是为我们提取特征的
    7 h' n2 B9 ~; {3 |9 \$ f) j+ U
      o$ r# m, ^$ ymodel_ft = models.resnet152()
    , Z! m- n+ R4 D  h% B' t' Bmodel_ft4 I6 u1 Q$ P  g' ]. \- ~4 o
    1! @; ^5 ]0 u/ |1 O9 K( \% u% `
    2( |- o9 P3 [9 V/ I
    3
    + J! R% ~$ f' F) \8 V- h/ X* k4
    2 J; v5 e+ X2 {6 _" W, t+ C- ^; K5" s0 v& H* a; \8 o; J
    ResNet(
    - \6 |$ Y7 M8 k# q6 H  p  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)" s! D; }# x% A) r- T9 u8 `1 N" t9 m
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)4 m/ A7 U# K5 r6 n
      (relu): ReLU(inplace=True), i& n1 o$ w5 F" l1 H* w, q* S
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    2 ?4 V0 o0 d4 I  (layer1): Sequential(9 `+ y; A, g: k1 S7 Z  L/ R
        (0): Bottleneck(4 x0 P+ B  J/ I" f; S# A+ Q
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)5 N. x: I% }# g! y
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)- O$ d6 s. ~: E; \6 [$ F/ Y
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    . y  A( \, `8 T- T3 a8 g) t' ~8 M0 l      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): k6 D0 h: j1 _5 s1 ~  i; _
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    + o  l- A6 \/ b2 p% F0 B      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)4 u- w  X, t$ ]) S, }+ x2 h
          (relu): ReLU(inplace=True)* l3 k4 ?0 C1 V
          (downsample): Sequential(9 x4 Z5 f' X/ S* V/ V
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)4 D/ Y$ Z9 g  {7 T7 [8 d- E# b
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)! w8 Y: ^+ K! u2 D, u6 L
          )) n0 C0 B1 j1 W6 n& T9 _
        )
    2 D; x1 w& |% ]! a6 t3 \中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。2 Q+ I+ ~1 ^3 G5 O  N
        (2): Bottleneck(3 e2 D2 _& `. T, }* l* Z4 j6 a7 J
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      e1 P  j4 T8 L9 k/ ?/ m$ @8 T      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    . `& O! W  W8 S6 o( @& |      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)3 u  U& O4 p; D: L% U/ r
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    1 {3 W1 i/ r+ V  q      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False): F! d/ I3 r1 G' Z& L, ~8 a2 _# V
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ! K. {; @4 G1 o4 j) v, H4 e      (relu): ReLU(inplace=True)2 V6 E1 O/ r1 _1 K
        )/ J( @1 u+ k# P2 F
      )
    7 F/ v( Y. j  \3 {% B4 c' w/ x  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))6 @: K; @0 q7 R; ~7 u6 f3 Z1 Y% g" _
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    $ O6 w  h& e$ U- S3 Q)$ D! p# a' h+ q5 J* P( g
    ) }; o# x$ F4 p6 t' I
    18 V8 k: s% q* N, G
    22 v: d7 ?# M0 Q3 `! j
    3
    - |" `4 r2 D; x; R& |- r4
    # r7 L: e. D( M9 r/ O5
    ' S  w5 o! G0 [; k. u2 y9 G- b% `69 W* ]3 P- A9 _" p
    7" d- a& Z+ ]8 u3 y( }. K
    8; d  F; t7 D; o
    9
    : A. }6 \& ?3 N4 }1 u10. O, ^$ B$ w  k+ A8 H  v2 O
    11
    ; V9 P- ?# k. A6 `$ H( P. N12
    # j/ L8 _# ?3 e8 Z13/ a& ^: o0 h" t- @( k
    14! J, _$ H$ |; p! t. j  l
    15( ^& S- r; Q  R  b% R" |
    16
    ; ~; K; W& D2 N: o2 k17
    $ B' W3 T) z2 z- w8 [. |18& t; k/ U: q& O% I" s: P: X$ I  C3 N
    19
    ; f. g/ |5 a, U5 t5 ~20$ P. [. ]7 f  U
    21
    ; p* e% h, b9 ~22
    7 Z5 R. F$ v# Z$ a* H7 W23
    % k- j- |: x7 k1 X8 x/ h: |24
      |( B& e- `; r: V' ^: e/ @, S0 [1 ?259 S% B$ N6 R5 h9 T. f: r' D# H, M
    26! M+ y1 h% b- a/ [- f* {3 J
    278 b" V3 D4 |) T/ J; m
    280 `' }& w: K/ @1 Z, c' @
    296 H% F+ ~/ p  A$ ~! q$ q+ y. a
    30, n* k3 n' o% v6 I) t
    31/ r3 g- K* e5 X5 o# i! ^2 I
    32
    , f& D* q. k, _2 u: X7 {/ C9 V33
    " l2 v1 n- Y) ]  `最后是1000分类,2048输入,分为1000个分类1 u2 |( s9 k& S1 {
    而我们需要将我们的任务进行调整,将1000分类改为102输出
    ( H3 X1 j2 [5 r: k6 B$ `8 \% D: k5 k% W
    6.初始化模型架构# P! Z8 B( v0 ?4 n
    步骤如下:
    * Z, _3 w* @2 Z1 s+ ^0 ^+ ^6 X9 H% @" i2 W& H6 e$ }
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
    ) x. U) ~; Y2 y0 i可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)0 }6 w! l' l4 Z" E' `' V
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
    , a" i5 a! |) a3 b官方文档链接( B4 ?7 c' ?4 @" B' Y4 x8 J6 Q
    https://pytorch.org/vision/stable/models.html3 ~! i8 k0 B- [# n0 m
    % N, i" r. V7 d, U; w" s% `- Y5 o3 Z
    # 将他人的模型加载进来+ l' [9 H6 b. ]1 C7 t% I6 t9 q
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):) V/ h) |" }9 J% G6 E$ h
        # 选择适合的模型,不同的模型初始化参数不同
    9 Y7 v& Y; ~6 N/ g    model_ft = None2 w! ]7 U+ C( P- s
        input_size = 0( f8 f$ M2 a+ w5 \& C. V
    5 C) f9 \- r% m! p' q
        if model_name == "resnet":5 J' N+ S( B- @) O/ z4 a$ H
            """1 t$ R: x% ~5 Y7 b5 E4 W
            Resnet152
    ! M1 ]9 `2 L! c        """. W5 n# k( B" d& X2 R2 E
    / B3 j& P% Y" a: N6 q) b: x
            # 1. 加载与训练网络& y( t/ x2 @7 g" h
            model_ft = models.resnet152(pretrained = use_pretrained)( R% E: H7 z7 a( D1 r% ^
            # 2. 是否将提取特征的模块冻住,只训练FC层
    " Y) j9 G( K0 d8 S7 Y        set_parameter_requires_grad(model_ft, feature_extract)
      ~% H) ]+ L& ~4 k        # 3. 获得全连接层输入特征
    6 x1 L3 p3 ]5 W. q9 @        num_frts = model_ft.fc.in_features" B8 F, ~0 e2 r6 n3 j- m4 r) P
            # 4. 重新加载全连接层,设置输出102
    - f" e. a. j9 }        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),# ?/ N( Z# g- m1 X5 l' w& H
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    6 t1 f( o: i$ P1 Y( V/ y! n        input_size = 224, g* i' Y0 V+ w. ~1 U

    9 u- r" l; n4 w7 h  w    elif model_name == "alexnet":
    . S3 F2 _! Z0 r, K. O        """
    % i/ }* F) `7 z! T) V        Alexnet
    9 A, ]4 C" a7 b        """$ m% r# Q- s9 I4 Z+ A8 J: F1 r
            model_ft = models.alexnet(pretrained = use_pretrained)
    - v. Q8 j: O" y& y        set_parameter_requires_grad(model_ft, feature_extract)% o! e8 e* u- a2 _6 F- _
    : [: p( Q0 i7 n# {" {
            # 将最后一个特征输出替换 序号为【6】的分类器
    ! t. O6 ]1 _% u% E        num_frts = model_ft.classifier[6].in_features # 获得FC层输入7 X8 j0 n+ L' e7 M4 \
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)1 e8 G. R' E! t+ |! V/ t; R
            input_size = 224; A  |+ S) [1 c3 C$ P" P

    ! r) [5 m- Y$ i; K2 N% i* `3 P    elif model_name == "vgg":
    ' \: v/ A% j5 W' e  _        """
      @* r; B( @+ D) g$ l. s        VGG11_bn
    6 a+ P5 k8 x  b2 U        """
    * E8 J( r1 h& `1 m/ e        model_ft = models.vgg16(pretrained = use_pretrained)1 J7 |6 _  \+ X/ O3 z
            set_parameter_requires_grad(model_ft, feature_extract)
    0 ^1 H3 S1 h4 a        num_frts = model_ft.classifier[6].in_features
    7 x. ]: j2 B5 b        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    * X' B( D( R+ ~8 S- K        input_size = 224
    3 P& F- g- J! h* m, h) D3 O
    ! ^1 x7 Q. L7 h9 n; K  H3 D/ B6 t    elif model_name == "squeezenet":) x) P, h. r. W- s" Y( c
            """
    & u! k( u# V! `) }7 v+ E) I        Squeezenet
    / G0 X# _& m7 z, {, s, w        """- I4 T0 v5 H8 a  C. A6 V
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)- v8 d' p( Q( l1 n  a! e
            set_parameter_requires_grad(model_ft, feature_extract)
    6 |' |  f5 {, l& e( r        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
    1 V' O8 I/ f) E3 U. n' ]        model_ft.num_classes = num_classes
    ( K" t" [& [4 z, E5 y$ M        input_size = 224
    + w; [) w8 n; `( r" [3 U
    7 G# R  t& O3 O" l    elif model_name == "densenet":6 d; ]8 ~& j7 j3 w. r1 o
            """
    " q+ X( k( f# f+ M# t; O* e        Densenet' K* U% J- c: [* A% X7 w7 _
            """
    7 K) e2 i4 Y4 l" ~  l        model_ft = models.desenet121(pretrained = use_pretrained); G+ j: @7 L4 e, G0 m5 U( ]* z
            set_parameter_requires_grad(model_ft, feature_extract)/ y1 o" H1 T% n2 M4 Y2 h
            num_frts = model_ft.classifier.in_features
    % Z% x9 l/ o; c# ]2 ?+ ^        model_ft.classifier = nn.Linear(num_frts, num_classes)9 Y2 f' ?8 E+ D1 t. j- l* k
            input_size = 224; F) ]+ G# c$ x7 d' ^; V  w

    ' u/ ^) I4 m, u) `    elif model_name == "inception":
    + \. W) @% C. O: }/ ]$ L* u        """
    9 l5 V9 X1 W* R/ R& v/ l7 Y        Inception V3  b5 G, ?; Y  E3 j
            """7 t5 y" P6 V( @% U! S
            model_ft = models.inception_V(pretrained = use_pretrained)
    ) m5 ], O. v/ H/ M        set_parameter_requires_grad(model_ft, feature_extract)* |. K2 A$ q! n% c. h# l$ d. J6 @

    - `0 a4 _! r/ [" c  E& [        num_frts = model_ft.AuxLogits.fc.in_features& x9 H5 [3 \$ G. V
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)3 l( p. X( G0 G/ [/ p: B7 x4 h
    : B$ A4 X" e1 m3 J6 C
            num_frts = model_ft.fc.in_features
    / m: P2 W7 B% d( W        model_ft.fc = nn.Linear(num_frts, num_classes)4 a8 e$ Q* U. \2 B/ m
            input_size = 299
    , N0 u) P# l4 G$ I8 u. z  {+ |( A7 b  N" p7 Z% `
        else:- _. t8 w* z, f2 D# h2 `
            print("Invalid model name, exiting..."), L4 @. c, v3 q- n. l. T
            exit()
    7 g4 E: ]8 h! J4 |  F; W+ S5 w) ]* l2 q8 ^
        return model_ft, input_size
    9 A5 Y, J+ K9 J+ t! t# v
    9 f% l6 E6 y- b4 w% }& |5 Z1& Y- L: C" @5 |- J) J1 F1 B
    2! N8 M( C- o6 y5 f  C1 t$ R1 f
    3
    1 _: ]. ~3 F2 T2 d: a4
      J, d* U! c/ f57 j% V+ Z6 l" V5 k/ f5 s
    6
    4 B( ?8 Z. k2 b2 L) g& `7
    3 G6 W% o% @7 @" }: s: c8 C. d8
    3 w& N( d/ S# g8 T4 V$ p4 z9/ E0 @0 a0 z. {! _( @. j3 ?
    10
    ; c% z2 @$ _5 ~* ~& x) x11: z0 f) d8 z  B8 N8 v
    12
    ) T; B% v+ @, l" x( z13/ ]* K/ K( i; j7 F
    14. A% _. v( b% C
    15
    # T* J) H: w4 F4 ]* |16% ?0 O  }* z% q5 ~
    17& l6 [8 B# w2 P; R( K; [: ~$ y2 j
    187 ?1 `2 d2 u; O$ y
    190 }4 a, b0 w2 [: w; e9 y5 q+ v9 }
    20
    ; \  U9 _# d) X2 E1 q219 j* Y9 c* e" o6 `  }& \$ p5 T' ?
    22$ @- V! S& y, A! e: k9 T
    23) r* r' p! n. j, p
    24
    # E, @. ]; ]. z/ D5 X25
    - A. |0 G4 R/ o+ p% ~26; k, a2 }# _7 y) x
    27
    ! i. _+ E) i8 j28
    3 g* k$ D/ b  T- y& L* u* R294 a, q- @) F+ U. s% E2 k* ~+ @
    30' w$ [' V7 n4 A! `' M; B! K
    31
    2 B( W1 k7 c( z. V/ C: G0 b32/ |& `8 e, C, c0 ?  V$ ]# z
    33
    ) K! D( f! l' a. H; [& c34. I4 v( Q0 z+ h  O( B5 y
    35
    , Z3 u; y. Z6 K# T. k363 e- L8 ^- x2 h- x. l5 o3 R2 P
    37
    % p+ N1 f  _; {4 \1 e) k0 q- x38
    % v, M0 @( w  U2 r' k. \8 s39& f1 `: M, }- s0 W5 B( t5 |7 ^
    40  [6 ~0 m; y0 g) S- t
    41
    - [) g6 `( b5 r+ d! g/ ]& _42
    * u# J5 A! E+ O* p3 d" q" `43+ f* m. F; A( H) D
    44
    5 c( @2 L' y* f% b45
    , h2 H9 e4 D# O8 q5 `46. N0 N4 Q1 w7 `. I* Y# O8 U: ]
    47
    $ D) b7 X& ?1 ~! F" ~48
    $ x5 O: y% ~) T- y0 C  Y+ l49+ H! u7 O+ u7 Q* t
    50
    8 I1 \* c7 \( g& x+ C51
    8 u0 o& G; t# V- m& I) S) e52
    & Y1 q2 h  q" D" f$ r8 w1 O53
    1 ~" u* B: }# X5 r" A! D54/ ~- x6 ^6 W3 Q" w
    55
    8 h6 E! p, \- d0 y9 A$ M56% b9 R: u3 l6 Z3 X/ W4 k- D/ w' [) D
    57
    * A: Q' k" U' x- k+ Q, `7 T8 P58+ ?% f2 ]5 m; x' l4 U
    594 Q; S0 \- U- |. `; p
    600 \4 t& {% E. J  x
    61- A7 q/ \0 F. ?) t/ }
    624 b$ ]/ g. O  t3 ?8 C
    63
    7 ^! e+ S* K; _  ?64
      `6 `, D# v# p0 a% d65
    1 Y! {9 u2 i# U0 t- U3 i66; P& P6 `9 n4 L  y0 `/ b
    67
    8 T2 G( K( x7 e3 G# I2 O68
    $ N% @( M. @9 c6 [% ~69
    $ ^$ d( r1 d/ V* m. H70, ^$ r; G( \, r, a2 o
    71
    4 K, M# _& q$ |8 m' T72
    7 w% X5 a; d0 G5 s9 \5 L, Q73
    : W& u$ a) b' b" V3 Y* s744 X# ]& G, C! S6 Y1 f
    75& q$ F$ A) E' F) K/ J7 b4 v& W
    76
    8 S( h! n+ P4 ~77, h% A1 i. m; {7 \+ c. c6 y
    782 z, Y; p  S* g. r
    79& x- W8 n! X6 C2 T/ w0 v# L2 q6 K
    80
    " D7 o# C- r: z) a81) Q9 m7 R: J3 i2 u' K; q' R: K
    820 w, d% R3 e* \  }! L
    832 y* Y0 T' q4 z. ~
    7. 设置需要训练的参数
    , `5 h( G" l% c' u1 H, x- ~# 设置模型名字、输出分类数
    . T. e" Z2 g1 v8 H1 G# j% H# c  F5 i* Pmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True), Y8 d; K5 o. L1 J# \
    2 l0 j8 w  s$ W0 N) l# r5 m4 `
    # GPU 计算
    % y& e! q( @1 |7 t) z* Z8 W7 r. ymodel_ft = model_ft.to(device)
    8 Y$ A( R4 J5 v2 j) C
    4 `/ X0 J" y7 x4 {0 V# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取4 t+ D$ H0 h$ `3 M' {* M$ f
    filename = 'checkpoint.pth'
    8 N8 P7 q: P. \  g' k! e4 R- p6 T4 ^5 O- n. C. e; c+ Q- e3 [8 e0 W
    # 是否训练所有层
    . k; H% s1 c3 lparams_to_update = model_ft.parameters()
    5 p$ P$ d0 x- M8 g( e# 打印出需要训练的层' j8 f5 C. n3 |0 v5 c0 D
    print("Params to learn:")
    . C3 [  S* ]; R+ ~6 vif feature_extract:. l8 Q, N) q6 M# J4 l" r) i
        params_to_update = []0 i' ~! b+ r2 w( h7 R. e% I
        for name, param in model_ft.named_parameters():6 T, W0 V7 G" H6 b4 x' m
            if param.requires_grad == True:0 @3 T; T+ f( Q0 j( }9 u  K- u
                params_to_update.append(param)! h; d) k# P4 j0 {2 M
                print("\t", name)" Q0 E8 f! _' s$ O! N! d
    else:
    % j4 x5 z4 F7 m: z7 Y    for name, param in model_ft.named_parameters():3 g. M9 c( T" P' c: W- i" U0 {
            if param.requires_grad ==True:
    ( O7 N- `/ n. m            print("\t", name): c& Y/ p8 [- }3 C
    ; U7 j8 k; X2 T$ @" H+ M0 `' b- L
    1
    0 Y6 N  e$ s: \1 R$ s' n) U& q8 a7 |2% Y% ]8 \" x4 [% x: b" Q
    3
    ) A7 D) {( |0 p8 T. }47 o# K& W. W; J3 m
    5
    + m9 |/ l3 U% c3 X: C8 t8 }! |62 c& L3 ], m' f) `! a
    7) L/ U7 B$ S! `+ L: Q" \
    8/ C+ L6 z+ p# `  z( {4 P
    9
      S- E! a- V1 R) Y2 F! z. l" t10
    * Y" j+ `) @" f( _" D1 ^$ q% E. y11
    1 u8 k7 _& G1 d/ m12
    / W6 v$ w' c4 s  @. E. o13* l1 N3 A* }/ |: _* a# d
    14
    6 n7 N  ^- U5 g& D* R' c7 C; a15, f7 n  X+ A3 L( v, [9 D
    16
    + J! t: }* Y5 e3 }4 Z  m$ b17& G9 @9 J0 X6 e! W
    18: B) m1 d. c8 t7 z' m! K, \
    19: o" X+ Q/ W2 G0 b0 n1 y+ ^& C
    20
    : w, p1 J: }+ n- L21
    # e9 x+ s, m1 V9 S1 b: d22# K; H1 X, V0 [! B9 E. _
    23( f2 M4 w0 m7 k
    Params to learn:2 T* V' e- k9 X0 C) b
             fc.0.weight5 N: R- _/ K( A! t' [  U/ g& E$ @
             fc.0.bias
    " g+ D# |. g2 R- k1
    - h) v) N; e5 ~" K" B9 b7 Z# R2
    9 a) z6 F, r1 p1 J( `7 s4 H31 }; O9 l: @' t( |2 M1 @- `
    7. 训练与预测1 j' c' z; s) U! S/ m
    7.1 优化器设置5 M$ S4 m6 w1 ]* g/ [7 R- S/ c
    # 优化器设置0 H+ v. t$ V6 H9 I% }
    optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    ' D$ b$ \1 G0 d7 m. _# 学习率衰减策略
    4 b8 o( r6 g  B/ F  J+ d, q5 c! W8 Mscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)! {# N% E. v% Z& W1 T# s  p  K4 I
    # 学习率每7个epoch衰减为原来的1/10
    1 {) t- p/ g! @! e0 {0 X2 o# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算3 s- V5 [. b* _4 L- X

    * d8 P. _9 J9 T! k' X0 E; P/ ncriterion = nn.NLLLoss(), l+ ~1 Z5 ~; z$ P8 U- ]0 i
    11 t* j1 {; q2 @( d
    2/ e/ y& `% `( |1 T) W* ]$ W% I
    3
    . y! m1 u0 Q. Q& |- j8 c6 k( X+ [7 |8 b4! Y. h) N. Z  n% w2 b/ ]6 G
    5
    / Y6 m/ F& x% b8 M/ ]6$ \2 m  z' F# `! h2 l" R
    7; E6 Y. Q0 g" J8 j0 W. N; W
    8
    + F7 `0 L6 m6 e% u' _* G# 定义训练函数6 A6 x3 B& T2 T( Z% K  S
    #is_inception:要不要用其他的网络- O0 b; S% i1 ]! j7 }
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):! S. {6 D- w+ \' `, l
        since = time.time()
    - k9 u! G6 g5 {$ k# j* K    #保存最好的准确率2 x' d2 x( L* G5 k9 |$ }% y
        best_acc = 0) Z7 |. z  |0 }+ p( y
        """, z( u# T" R2 B! a; U, V
        checkpoint = torch.load(filename)3 N# e# |6 J! a" y! u
        best_acc = checkpoint['best_acc']' {# |& P5 |, B! a5 Q
        model.load_state_dict(checkpoint['state_dict'])1 x; |) i0 E( n8 K% @/ p
        optimizer.load_state_dict(checkpoint['optimizer'])* s* t) H# _( I6 N+ b+ Q
        model.class_to_idx = checkpoint['mapping']
    3 X) F' z4 n- Y' Z    """
    1 _2 c/ \; Q+ H& c, U    #指定用GPU还是CPU9 H1 E+ o" j2 _0 X8 }2 F' M
        model.to(device)& _' G# }) n) i- H8 k" c; s  v; q3 K+ _
        #下面是为展示做的+ k( L% b6 T$ N' T+ ~. K3 P
        val_acc_history = []. K/ ~! p# C0 @5 A8 X. _2 {
        train_acc_history = []
    3 ~; C2 G5 J; d2 }" K    train_losses = []
    4 v9 q! g! A; x6 @+ q0 R- k4 O    valid_losses = []: C- U; |& {6 K
        LRs = [optimizer.param_groups[0]['lr']]0 q. f, {7 {, U+ S; x' g1 q6 B/ Q  w
        #最好的一次存下来
    . R! g) d( m  a9 O    best_model_wts = copy.deepcopy(model.state_dict())- J" z5 a( @  q* i
    ( c) i% Q0 v+ [6 u+ t& Q
        for epoch in range(num_epochs):
    0 `6 y/ s+ G% X, |! o) z' L        print('Epoch {}/{}'.format(epoch, num_epochs - 1))8 V' f" x" [: W2 ~
            print('-' * 10)/ e/ u( U9 z) W, R9 G: y
    ; b. a. f  z+ t2 d
            # 训练和验证8 p; y" T- C' ]$ n' Y: n
            for phase in ['train', 'valid']:$ ^1 {# I5 q; J2 H5 e
                if phase == 'train':/ q9 B6 e& f3 m( a
                    model.train()  # 训练5 j( J, a0 [& U0 ~8 z, I
                else:
    ) I$ A$ t. C/ ~! Q' [# M                model.eval()   # 验证, Q. I! D- f1 c( T8 E0 \- q, p4 j

      G1 N+ O  c) w* ^: e3 Q' L            running_loss = 0.0, p- g& t! T3 j8 a. ?1 X
                running_corrects = 0
    ! S+ [" p& V' p7 Y3 I& E! i! {
    8 }8 F* m+ ]) `" P* S- A- _            # 把数据都取个遍
    4 k8 N! f/ N. ]& |  i# W5 G1 [            for inputs, labels in dataloaders[phase]:  z5 _0 s3 _/ E; U9 i* f! j
                    #下面是将inputs,labels传到GPU
    % Z: m7 R6 O$ d2 R6 m6 G2 s& q  R! J                inputs = inputs.to(device)
    " W7 _2 G' k8 I* s                labels = labels.to(device)6 ]4 E( M0 C% Z. m; n+ P9 W
    5 \& p! `5 B* V1 @* I. I
                    # 清零- V8 x; R# s' d
                    optimizer.zero_grad()- v( b" r' S: ]9 x  C
                    # 只有训练的时候计算和更新梯度- J, l6 x+ H' g( o8 t3 q1 J1 `
                    with torch.set_grad_enabled(phase == 'train'):3 @5 P( U% L% C# ?* S/ Z* O
                        #if这面不需要计算,可忽略
    - R- x4 t0 ~8 |+ n% {3 P                    if is_inception and phase == 'train':
    2 H4 @& X# R8 ^" s0 W                        outputs, aux_outputs = model(inputs)0 e7 f8 T3 C2 Z* O7 i
                            loss1 = criterion(outputs, labels)) }) C- r5 b' N1 o/ m; R8 C
                            loss2 = criterion(aux_outputs, labels)0 m; e. V9 ~% y
                            loss = loss1 + 0.4*loss2
    ) c* z3 s% C0 D$ a0 G                    else:#resnet执行的是这里
    ! g6 \+ P9 Y2 Y                        outputs = model(inputs)
    / L5 p; n( p; }/ g+ x" G5 ?. T3 E; l$ @                        loss = criterion(outputs, labels)' H) m3 z9 A4 l# p1 `, m

    . m* x- m, `5 \& }- y- g, r" k                        #概率最大的返回preds
    ) P: ^4 b9 l- p  l4 E+ O0 `/ F                    _, preds = torch.max(outputs, 1)! r9 X6 N1 {6 w, \- Y# u0 I
    % v5 l$ O' K% H  P# ~
                        # 训练阶段更新权重
    & z  M+ d) b' L$ G9 M                    if phase == 'train':
    " ?8 l0 A7 i1 n; Z4 u' v                        loss.backward()
    , d  f/ p- E5 t9 h                        optimizer.step()" O6 c/ r* h/ a* H/ k

    2 f3 h1 j1 x2 n+ L: g0 D                # 计算损失
    ! |6 ]# H/ j* o2 @8 z5 x7 ]                running_loss += loss.item() * inputs.size(0)) ?$ p4 H" F, g7 d! \
                    running_corrects += torch.sum(preds == labels.data): {. _+ w6 K9 K1 Q& @7 ^1 ?; S7 c/ m
    1 ]3 q9 `  v' a
                #打印操作5 {9 n6 b" ~0 G' C  ^' }8 E8 c
                epoch_loss = running_loss / len(dataloaders[phase].dataset)
    ; i% [2 m' }* T% W% H- U4 Q            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)# r' W& p& O: S3 A
    6 w0 k$ q' u# t6 b$ W2 k' R$ F( a

    + e& Y) F: c6 `2 x            time_elapsed = time.time() - since
    ' |2 F4 ?5 {2 [: k            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))# s) W) L! G! a# f$ t
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    , o7 g% x0 t: g9 p2 Z3 |& O4 [
    1 P- {0 M# ^# c( T6 X4 [$ h- T5 |2 I0 d4 g6 Z
                # 得到最好那次的模型
    ( Z, y  D# d2 Q            if phase == 'valid' and epoch_acc > best_acc:5 f1 X, ?# X7 E9 t
                    best_acc = epoch_acc8 p  t9 O" k1 ]% l# v
                    #模型保存' ~+ i' y$ s6 n
                    best_model_wts = copy.deepcopy(model.state_dict())7 L0 V% j& D& E+ l' P5 l4 x( }, l
                    state = {# V8 E: `" }5 ?$ P  g, k. {  w
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数& ^; F9 c/ {/ C# v( K; j
                      'state_dict': model.state_dict(),- Y9 v3 ~; @( ~& l( k0 m+ O
                      'best_acc': best_acc,
    * g  V+ [$ R8 f& E3 F4 D" _                  'optimizer' : optimizer.state_dict(),; W, P& m6 @/ a- ^+ c
                    }
    6 c- M3 T* N( [0 ]                torch.save(state, filename)
    + L+ f, s) i0 X6 g, t            if phase == 'valid':
    & F$ o: J. k# c$ ~- o$ r                val_acc_history.append(epoch_acc)/ {  s# g: m8 f& a/ c( T# ~
                    valid_losses.append(epoch_loss)4 \( g; K7 _/ F& B4 E$ f7 \
                    scheduler.step(epoch_loss)
    : k9 ~+ D7 f7 r+ b1 R, [. X            if phase == 'train':
      {3 I. b9 X' I% U                train_acc_history.append(epoch_acc)4 g  r  F& s9 d' I8 A: k7 o
                    train_losses.append(epoch_loss)0 N) A8 V7 j) T$ Q' f% b9 b# A) ~) M* c) [

    7 l$ t6 D( |1 `4 d0 G, i: O        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))3 k. E& X# e2 X# T& }$ d0 @/ I
            LRs.append(optimizer.param_groups[0]['lr'])9 c$ s7 M; W1 O- Y: @0 H/ b
            print()# A0 J0 W- w% x( G" z8 v8 e

    0 _1 q1 j% j# c3 f    time_elapsed = time.time() - since
    . _9 _% \9 w- }* b( ^    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))0 [, Q5 r# Z5 u
        print('Best val Acc: {:4f}'.format(best_acc))
    2 A* U. g0 ?0 {" [$ N- O
    # N1 w( ?7 U# U  W    # 保存训练完后用最好的一次当做模型最终的结果
    7 F( h4 m, S# S; m4 W    model.load_state_dict(best_model_wts)
    ! [9 d$ i" J1 b    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    8 Q  S+ o! G, J- e: @0 @( C
    $ B6 [4 n8 o/ ?- p  Z' M% J
    ( e: R5 M: O) }  [! z: a1
    ) w) L. K  p. I& q  K/ u9 x  H# f2
    / k2 j& r& I) ?+ j) U8 \' q% a6 @34 S0 s# [$ x- h9 R! r4 Y2 }
    44 \  @& N) j) B6 v, V8 ~3 e
    5
    / \# c( X4 a9 H: P' Z6
    : `; j& S. D5 F) M+ w7 n$ @7
    ; Y* K+ Z4 p& P( W86 A( N! z: ]/ A
    9: m% G9 x+ s/ n. G% v- r9 ^" `
    10
    , q( Y5 Y0 R( Z& k$ v118 X( B" j4 K$ q1 l
    12: W& U* H3 Q4 F* e: n3 D; ]8 o0 h2 l
    13- K% Z+ i, V2 N. j$ t
    14! Z$ J0 L+ }" v9 I6 u
    15, d# r$ _7 p  _
    160 ~" m- y: S4 H& ~5 n1 N
    17$ P  g8 h" w" K- t7 ]4 m
    181 y4 J* I) d6 J. A
    19# ]( T0 A$ y# e7 n& E
    20
    : j: P) r( I# N3 J4 A7 V2 i21& O* n5 K% ?* J- Y
    22
    3 M- N  V; O( Z+ l/ H4 x! K2 i23" M* y4 L8 k& o9 F; q: |
    24
    4 C+ E  z$ W4 O  R8 n! e7 A256 w1 @  n! y- z) r
    26- }+ ?  q9 v, a) _
    27
    ( j1 c+ d# F+ X' u9 R3 H289 c! G  P$ b( |: t
    29
    ( B) A0 c; U$ t" F- v" c30
    , m# N; ]7 K2 `3 w. j+ v31* t' K) p3 z$ p3 M- b* J
    32; j6 e% \4 v' h5 g9 `! P
    33
    2 }& E2 X: H7 p. h34
    7 ^2 _. ~  M! E$ A! _35
    * Y6 k6 _. l% a# C368 `3 p2 @+ [( n, [! l  P2 d
    37% I$ e* Y6 }/ R5 W# ?; B
    38% X7 i: r4 l. ~" q( r$ g
    39
    6 s- M+ ~% I: i- d7 P: ^3 b3 s40
    ) v+ ]9 Q6 k8 F  r  G41* {/ L0 B* O8 S8 A# E" J5 _) A
    42) k7 p  E9 L5 f5 @5 H
    43
    + a( E3 ^  C5 i  _! N44; _7 B1 a4 K0 u- n% x) y# R
    45/ t% K( Z0 D1 K3 w. C2 }
    467 e- {0 e2 h7 q) W" U! _; \* F' Z
    47
      a; ~& b& I5 u48
    . x3 I9 R8 b8 }# F# n49: f+ y/ t7 p+ k* w1 d2 w
    50
    / h# A$ s* p7 `. w3 E, N" `: F51% V( z* i* @( r( Q) Z1 b& j
    52
    9 B  y# z+ b/ y: I) s" v  \) b53
      P# g3 K8 d# g* |540 L3 P0 R' u% H$ L. J
    55' F- Z3 W) R* b
    56" y. c' X6 ~0 y! O0 ?- [6 k
    575 D. k. d6 A2 G- T/ u
    58) j. R3 _, `; k0 q! \4 S
    59
    # r0 {3 A2 V+ X+ x60
    1 l8 m7 b1 ~% C61
    ( E( E& N5 ]* o  b! U0 Y62
    . F7 w: {. G. {0 [) O- d63: b3 O/ M! v/ b  x3 T' L
    649 I5 m: E9 [& r
    65
    1 t* @# \' F0 h  |# j- Y( U66  i% g3 R: z2 {2 N- `6 C+ H% H: a
    67! ?0 Q' F" |" @  W( m7 M1 w# j/ I3 V: Y
    680 z, B% m, g9 |3 a* z6 M; E1 y
    69: m; v7 y. q% p! O8 o! [* u* w6 `
    70+ T& Y. V5 z) B4 A+ X* r* d# D
    711 t4 ^! J8 g0 T+ w( y: X1 ^
    72
    + X7 v/ t' {2 n73
    2 T" m2 C# u6 d. A74% G8 }1 b1 \1 \! O: }% e9 p9 D
    75/ U# N# J9 T5 W/ w3 b& C
    76# ?0 J( j: W6 u/ r1 o- L$ r5 m& C4 P
    77! ]" w- X1 C: B
    78
    2 q2 B) R" v0 \: l: O2 ^8 c79$ @  i0 {: z9 e. m- H
    805 g: U$ R7 o5 E7 g! M+ y& Y
    815 ^! i8 w1 a1 D8 Z" L
    82% Y3 V) n4 q/ j! w$ I
    831 O- b- _) h4 g/ }) r' C
    84* B1 P4 T0 [2 @' B( G
    85
    % ]4 S, H* R# @  V86
    3 ~4 i6 O- g4 r+ i) V87+ L1 v0 U; d; h8 i
    88  d3 Q5 Z4 O4 Z; a5 c/ F- {
    891 T! _9 M/ q& D5 o5 c
    90
    # ?7 Q; N* v6 J' d$ d/ ]91$ c$ s5 F$ }8 ]! D0 d
    92
    ! P! x0 x$ \2 E7 X93, ^! Z; a" V7 l# A9 e
    94
    3 b8 T7 l2 @, o; v8 ?" q5 z5 x95
    ) k! s" n' ?3 [& C) n( D96! H! z( i" m4 `& q2 m+ H# n
    97
    - E# J/ [* I* V: ^4 ?98
    ; G2 Y# x5 y2 i% P# B; n3 _: c- J99
    + }( |, `( p3 X: H& [100
    % c  w( ]0 l* f8 a, T$ |6 r101! ~+ q4 H' T4 O! O
    102( o0 o0 B) N; `: A1 ~+ }+ i
    1030 Q9 N' V! L" a: F* D1 {6 Z6 ?
    104! }8 u1 }# o$ a' }) R
    105
    . l  W* ]: Y0 {- k3 a6 q106
    ; M$ r- y7 _! W7 X% m! u+ X% ]107
    ' P$ `! X, K1 w: W0 I8 i108
    5 C9 F4 y' v2 V1 E8 r109
    ' A, z" i1 i, n9 I110
    9 ?6 s! z# X1 x* |+ i  R8 @111
    5 |% m: c8 B! J$ Z112  u* i6 l% W  b3 n! U5 b1 ^5 u" E
    7.2 开始训练模型
    # c+ @$ ]0 Y, D* ?我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    4 I& K$ O, `3 {* e8 z" X, d9 c% w4 g2 z: x
    #若太慢,把epoch调低,迭代50次可能好些& b& x, g$ K# o. S
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    4 m5 S2 |7 B; t2 i6 R  t+ Pmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    6 l2 @9 P8 j+ a
    $ t/ e. V- P% O7 J9 {& g1
    1 N) U" \# c- x: R% ^8 X6 r' s2 B2
    8 c9 A* g( t( k1 G3
    2 T5 i8 ?! Q' w1 r3 o4
    / D+ J  |0 W9 N* e: Q4 Q% d. YEpoch 0/43 D; X7 e' {3 C
    ----------
    $ v) \) e+ @8 d$ `8 [6 ^Time elapsed 29m 41s
    5 W7 R! z5 A0 m7 g1 j2 etrain Loss: 10.4774 Acc: 0.3147
    0 P0 `' V2 m9 K0 t( R8 cTime elapsed 32m 54s
    6 M) k' X' Y4 r5 y+ avalid Loss: 8.2902 Acc: 0.4719
    ) _' Q* o* e! B1 y$ q$ u) xOptimizer learning rate : 0.0010000
    8 |/ T0 f8 c6 Y% J0 f" S) `$ C
    + j' Y* Z+ k: E1 z9 T' {Epoch 1/4
    * P3 k' }3 ~: f, d. ~! I" _) u----------
    8 o  z/ u& O; j3 H1 ~9 vTime elapsed 60m 11s
    - E4 D1 c5 h2 F4 J% j) s% ?: g0 Ltrain Loss: 2.3126 Acc: 0.7053" ]6 _) ~1 t/ ?' ~
    Time elapsed 63m 16s
    # u  }; Q, T( \9 p/ @7 |valid Loss: 3.2325 Acc: 0.6626
    ) A4 Y% g% e! ?- y& F/ D* I, IOptimizer learning rate : 0.01000000 l  G/ M- S: \
    5 n, X0 @2 H+ Y# J$ L
    Epoch 2/4+ f  O  D1 U" q! L' M
    ----------9 r4 r  e8 W$ m" N6 ?) A0 O
    Time elapsed 90m 58s' e  c: n$ F+ i' \7 m
    train Loss: 9.9720 Acc: 0.4734, ?( ?0 B7 @. J& r3 p7 g0 g
    Time elapsed 94m 4s, a% I% u& E* u# `5 O
    valid Loss: 14.0426 Acc: 0.4413; M2 f7 T. X+ _
    Optimizer learning rate : 0.0001000
    - A' @# {: b( \" L1 y. I9 @/ \( Q8 N, b; {: K9 s
    Epoch 3/4
    ) R! N& `8 g( K1 }# j2 [----------3 M8 m' q& L, A" T" ?* M3 K
    Time elapsed 132m 49s
    4 Q$ e1 e. T/ t, vtrain Loss: 5.4290 Acc: 0.65487 T* l; R7 B# B4 ]( S! ~
    Time elapsed 138m 49s
    6 N* g. K0 U( U' Fvalid Loss: 6.4208 Acc: 0.6027
    ' g, N' B& B, |0 ]Optimizer learning rate : 0.0100000+ c' @* E, ?& r- E

    ' h% `0 n# [; X. K  aEpoch 4/4- Y2 a7 B# V. N8 P* c
    ----------
    ) q  V$ [2 _  V) c+ S/ [8 tTime elapsed 195m 56s0 r2 r+ y7 c+ X
    train Loss: 8.8911 Acc: 0.5519# z) W6 [5 W$ z8 g
    Time elapsed 199m 16s
      f/ c0 Q# L; |$ C( @. {! K- Ovalid Loss: 13.2221 Acc: 0.4914) Z' m2 j4 F  f7 E+ \
    Optimizer learning rate : 0.0010000( g5 `2 Q( g) \2 }0 K3 K
    ; A1 Y0 f4 Z4 g
    Training complete in 199m 16s" L0 M, a1 F3 F6 R+ r6 T2 C
    Best val Acc: 0.662592# x1 Y/ B7 h. t, l& L4 i( ^/ ~
    9 @0 Q0 W" o) r: v# r' i+ @
    1
    $ i& x, C. ?# p" {$ j21 ^8 P& Q! I& ]9 l3 U
    3
    9 X/ R$ ^9 Z5 u( i6 z# u  p4
    5 P4 d0 Z) b' b* @8 k$ p- g$ G5  }3 W, f- b( r+ ~" [9 v+ O0 e
    61 a% o7 J* \0 t3 E9 Q
    7
      M8 J( G- t& t: E, b* \8
    0 F  E( p) U" `( H9) ^# J: h$ J+ B4 }6 ]! B0 f
    10
    . J0 V2 ^# V: S; z11, _: V* r$ H7 j3 f5 x* m+ _
    12
    % z1 v1 U3 I7 W& K3 k13
    1 s* r  d4 |6 \14
    / I3 R" i" R  b15$ W3 `5 D6 N# B0 e7 q7 m# O( O
    16# N% I! [) h; u" r0 x' @- @; Q
    17
    7 G/ X/ Y0 M! X" G2 O. ^. B18
    $ W0 u6 X+ V- U4 s. ~; u19# \9 q1 f5 G% A) g
    20
    * T+ C; ]# `3 k! ?+ O! P217 b4 P5 f# v" u3 e% Y
    22
    / M+ N' r6 o% `' o/ o0 F23  g9 b% A2 W+ z9 Z) k2 s
    24
    : h# d" N% g6 ^$ ]1 z! D+ U25
    : ~% v' N5 |$ c; O6 a& U26
    ) L- j0 E) Y' c# \5 v27
    2 x5 }* u2 J  r1 i9 ^+ |+ b28
    % W- L1 |- f& C! u$ w; m& ^29
    3 h% i. T+ ~5 S30( ]3 p* ]2 N( b2 K0 |/ N
    31. r) b; Z$ G7 b6 E1 A7 ?* F
    32
    5 `6 V$ S2 |: m1 M33
    - p" X+ G, O/ H6 r34* Y5 f6 ~# H4 a2 {: m9 q1 v
    35
    + Q5 k2 T' t8 g8 U2 R3 P, o+ |36
    * b: {5 x+ q6 G' a2 h: N7 A37, l1 ~/ \  N+ B
    38% H+ R- n6 F3 h7 S
    39
    % s- @" T0 @( I: O" g' E' z40
    9 A. w: G. X9 s5 D$ K3 g7 ?9 D( b. c41& [: l( H; ?1 {7 T( n# @0 W
    428 {. n! W6 m; S+ j
    7.3 训练所有层
    & l8 q9 n3 b5 L& ]8 F, R$ O& [# 将全部网络解锁进行训练2 ?. f# c8 v7 v& }- G. Y
    for param in model_ft.parameters():
    8 }( }7 e+ @: E    param.requires_grad = True
    + o- v" s) E: D3 Q3 E+ u* {: ^9 a0 A
    # 再继续训练所有的参数,学习率调小一点\& w/ i: M* r! |4 h/ \( j2 [% j
    optimizer = optim.Adam(params_to_update, lr = 1e-4)
    " R# }# p  q4 z8 {) C7 u  u6 Sscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)$ b9 b' i7 M; R& N/ b
    9 V4 c$ q( p+ L
    # 损失函数
    : d5 S% s! p, o  C( Jcriterion = nn.NLLLoss()5 x. H3 [$ k( g5 F3 r( x$ C
    1) @+ k6 G" `3 P) B. m9 X8 O2 T7 J
    2! x$ c2 D3 X* X- f
    35 h) N# Y' P; o, U! q) j( ~7 @
    45 o% o9 u2 {3 o. D$ p  x# ]
    5
    % f, l0 r9 c" y: ~: ^: y/ M, Q6
    7 c) q( ^5 M; w' z7
    2 h4 [! R: x  N2 ~. y2 V8& B2 S8 W& W  v7 O
    9  R$ ^2 V- W6 {3 }6 v% t8 S9 C7 R
    10
    ) y" a  d% V; J( @$ S6 ^1 `7 }9 u# 加载保存的参数- c2 N5 o/ ?% i) ]
    # 并在原有的模型基础上继续训练
    ! d# }; s0 T0 [6 |# 下面保存的是刚刚训练效果较好的路径
    " k) A, j9 D7 m% acheckpoint = torch.load(filename)3 G2 |3 A. n3 i# `
    best_acc = checkpoint['best_acc']
    ) d5 _# I, q9 Lmodel_ft.load_state_dict(checkpoint['state_dict']), Z( {) V4 V3 g: g
    optimizer.load_state_dict(checkpoint['optimizer'])
    , ]( h  y+ Y+ H1
    * S2 O' o$ O' A* S20 }2 j/ v, a# [& u8 I2 ?
    3
    / t0 d& b  ], A! w! R9 K+ X4
    , J- f9 h  u2 X: Y' ]! {, i5
    2 t, s; h/ ?/ f6
    6 u. V9 N1 T% w# H7
    ' `; {% K  B' a# i7 W6 H3 `开始训练
    6 i/ ?# d* i4 Y. e8 B注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    6 I# h9 ~' @: y% E3 u# ^' k2 S. r0 ~4 G2 c
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    * r: y$ i- e  `9 J7 X+ B1. L3 H2 R& G% B2 ]/ L+ o
    Epoch 0/1
    ( J1 R- K0 J- X+ [----------; p8 n5 _7 x2 O7 [; P' |( x
    Time elapsed 35m 22s9 D5 N5 N6 c" x' o
    train Loss: 1.7636 Acc: 0.7346) ~% b3 ~; e2 U
    Time elapsed 38m 42s
    0 I2 p( H7 o9 hvalid Loss: 3.6377 Acc: 0.6455# q' m0 L& y, n) y. ^2 x/ y
    Optimizer learning rate : 0.0010000$ Q& m6 j! Q: S9 l' y

    , e6 ^/ ]' A% Q  {( O2 X# jEpoch 1/1
    $ Q6 m9 G0 h" Y6 O% G3 Q. A- a----------
    * m+ Z8 h8 N0 O9 p: `Time elapsed 82m 59s; M1 I* L2 x4 N+ S4 z
    train Loss: 1.7543 Acc: 0.7340
    / o( w5 t% x# V- C7 \4 `Time elapsed 86m 11s& D8 @- e! V* T5 J: n
    valid Loss: 3.8275 Acc: 0.6137
    0 D' ?1 \# `$ N/ Y3 t, \Optimizer learning rate : 0.0010000# |( Y7 u0 B3 m( h! ~
    ( B/ n" m8 O, g1 c3 f- e0 _$ A) ^- u$ M6 e" Q
    Training complete in 86m 11s# F  g" f- n& O* U
    Best val Acc: 0.645477- \, _' c7 j! t) g+ n3 {- a

    $ w1 J& w* u0 `+ Q4 m1% P/ Q( }' x* H7 m. n
    2
    $ Z; C; @& H- F, `: T% r& Z3
    7 N/ P; p$ J) l; p  K4
    & d& x3 K: k' `5 G2 m" i6 A52 n8 d. c5 B' d: c, s5 r" X
    6
    ! w2 F) X' B# |3 {9 E7+ H% K( [" Q* \0 T9 z, d( g7 J
    82 Z9 q9 t0 @) D' j# ^- B
    9
    . R1 h! ?. h6 O: x9 ~. J10; Y. w8 R3 h  t+ x' L  _; d3 `
    11; F6 N, G- o4 F7 ^& `7 |, o
    12- P/ i+ s! N  ?$ b( U  ]. a: S' a
    13; D" ~5 B. C4 I  k  U% R( m
    144 e+ `2 @& w! T* h0 K7 Q6 G2 _/ j
    15
    4 R! M. S" p9 d9 Y# j% d16! g- j% M8 ~. c6 T! x" g3 w9 x2 ~
    177 \8 x" D; H; ?# ]3 u- w4 H
    18
    " X8 f3 R2 h0 y/ O/ p" C8. 加载已经训练的模型
    5 X7 }0 o% u1 \4 V相当于做一次简单的前向传播(逻辑推理),不用更新参数
    ) ?7 u. }1 `* ?- x
    9 ^  s/ e% q4 ?+ s. gmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True); t1 S! E5 |& b( [! O- V
    : b! Y! e4 T% d* p2 e
    # GPU 模式& A5 J. H$ l3 L" q! ?, Q* s3 ~& O
    model_ft = model_ft.to(device) # 扔到GPU中* k4 A) f. Y! g& r" ^  E

    4 W# [9 W2 u6 e/ O& A# 保存文件的名字! e1 {: r( D7 q. K) B# b
    filename='checkpoint.pth'
    $ W6 l& d3 f# [; E: O, e1 h4 Q3 V% t7 V, ?+ Q8 A: b2 f- D8 j
    # 加载模型
    ' Y1 l0 d; J# P' s. n  o1 K9 D0 jcheckpoint = torch.load(filename)
    8 \. G% d6 k7 v! L6 M1 U8 m, x. Ubest_acc = checkpoint['best_acc']+ z, ?! y; Y9 e5 @* y; x
    model_ft.load_state_dict(checkpoint['state_dict'])
    , D9 D0 v6 n. |( f1
    " d8 e8 W( a8 f; D% Q4 P2. E5 R* `; n+ g2 ^  s/ R
    32 M, g; d% ?: H9 C& P
    4
      r; R: P: g  u$ V56 s5 o8 `8 q% W+ f2 D
    64 n# r* Y4 D! H8 H
    7
    , O# J, ?9 V2 V5 J! T( Y( N  `8
    0 ]+ K3 a. H! x9( ~5 Z, J, e. U; p# N6 I
    100 o, [, b7 U4 C# G4 c  Y- G. I4 \
    11
    3 A4 s% p! ~' a12
    # u# m1 v. B( I7 U! L<All keys matched successfully>
    & [% |, l+ D0 j  m1
    8 G, I) C; v3 tdef process_image(image_path):
    / [0 I  y* s! G3 F3 A    # 读取测试集数据- F( `/ j/ _: j" B" }
        img = Image.open(image_path)1 K8 B( D" G* h/ B) l# a. N
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断1 u4 W' F  x" `4 n. T0 Q4 t6 ^4 P
        # 与Resize不同
    - z! [* ^0 T' Q' s7 [  A9 Q    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    ' C0 |; X! B9 C8 t0 E% u" P# b    # 而且对象调用方法会直接改变其大小,返回None
    3 x1 l+ b; }( i* Z    if img.size[0] > img.size[1]:5 X8 y) L2 L; Z
            img.thumbnail((10000, 256))$ e5 ?4 q* L) x; i% I$ b% F! b
        else:
    # i/ @+ R9 M  u# V* K        img.thumbnail((256, 10000))2 Z8 w; h# H! i7 d
    , b9 N. J: W9 Z( d% h1 n+ L8 B1 A8 ?
        # crop操作, 将图像再次裁剪为 224 * 224
    1 u) e5 D7 v/ l7 ]$ S! d    left_margin = (img.width - 224) / 2 # 取中间的部分
    0 _' |& b; S: S    bottom_margin = (img.height - 224) / 2
    0 [! L2 \: V: H  L/ {4 U    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    5 t- s& x/ d# t6 M. q    top_margin = bottom_margin + 224
    ) [: F9 ]) u7 t  D) s, k8 K2 g; c4 }. `- `7 w0 e% m
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))6 c' k. {- h3 P3 O. U% |

    : X8 R7 n9 Q- g, G  `! p    # 相同预处理的方法
    + a( \4 ~+ e. ]# Y/ J    # 归一化
    7 D( d$ c$ ?- }& X' B' U    img = np.array(img) / 255# K: p; q- N8 U% o# g
        mean = np.array([0.485, 0.456, 0.406])- p0 L' t3 l1 R$ v- \8 P( z$ p
        std = np.array([0.229, 0.224, 0.225])
    . Q  J$ b# U$ }3 z2 {7 O0 L' f    img = (img - mean) / std
    ' Y: [! J, e2 a, [
    ) S! _+ @% k/ D! F: j% O    # 注意颜色通道和位置
    ! _7 ]! t, v. |5 T0 }    img = img.transpose((2, 0, 1))
    ; W0 ^! V# L+ p# {8 y( |* ~3 }
        return img
    5 P8 w8 M8 I3 T4 `/ p' l$ I
    3 k' a) E# b& K5 hdef imshow(image, ax = None, title = None):
    9 y5 @4 j: ?9 l6 I! \7 y# m    """展示数据"""' M; l- z! E7 C. q
        if ax is None:
    . `) V  m6 `2 c* ^, [  Y& Z        fig, ax = plt.subplots()( t6 r% R- h1 V+ ^" r
    8 E: x7 p+ {4 ~! \, Q
        # 颜色通道进行还原6 k+ w; p0 \1 U+ _
        image = np.array(image).transpose((1, 2, 0))$ D5 ]3 I) O; \, ~" Y9 _3 u8 Q* o
    * o6 Y0 X  C$ j4 r& |( F# P' A% w) @
        # 预处理还原
    / u) \" i; n1 U) P+ x( c/ d    mean = np.array([0.485, 0.456, 0.406])  P; p) M/ p/ b7 h3 A
        std = np.array([0.229, 0.224, 0.225])
    ) J; C0 ~1 [( T4 q$ w. t. a    image = std * image + mean
    * u: v7 v" \1 v    image = np.clip(image, 0, 1)* f4 T+ i3 E: b3 V. ~
    , z% w( `- i; O! p
        ax.imshow(image)
    . b& Q' u+ H% f, Y3 v    ax.set_title(title)2 `3 U5 ]5 |; n0 a2 V

    ' I7 I0 c0 Q( a8 J+ G    return ax
    * o3 L) w3 M( [. |
    3 r7 n5 [. M7 Q7 rimage_path = r'./flower_data/valid/3/image_06621.jpg'/ u, N# E3 Q! C
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理. `" K, p$ [9 B6 ]. l
    imshow(img)  N9 U% A0 J: q: B0 q# t0 w

    9 I: u8 x% d2 s7 l# J10 h. w! y' J) A2 h0 ?" t( r% s
    2. a5 y# J4 K: o' n& ~& K& Q
    3) n/ r) T4 ^( L
    4  p: Q/ k& O# i+ J0 b
    5
    . }9 Z5 R+ {2 ]+ ]; d9 @& _6
    " Q% W  C; c0 y7  ^( _1 M2 G3 |  Y
    8, V8 X/ c" G+ M7 H- w4 ~/ C% v
    9: r, g7 [" V3 h' g
    10$ M: t; G- j  z2 S
    11% ]3 P. O& U# _0 z9 M. C3 c
    12. @+ C/ j( i* C2 N" @: S
    13
    ( Q+ ^. l4 x" K% B# }14  }4 ^9 K2 V7 ]  I
    151 }6 G! k6 x6 X$ ~9 r9 \$ i3 [
    16# ~! G5 t' j  N% Y# I
    17
    ! _. D( T5 C* h; C0 l7 {: X18
    6 A, X, v& a$ D5 _19. a/ e) F7 d( s; X% l$ [! a* L* c
    20
    , u" J% [7 E% x! ]( ~+ h21
    # d) L+ ^# b+ m0 y2 e22
    ! N- ^( [& w+ b$ D% Y2 o# v6 [( q3 a3 v23  J# j$ C. p; \# u* H
    24
    6 T( T: i" G  o8 ?8 d, r% I25! @. y) _, i9 ^3 x7 u+ C
    26
    0 P. m% Q9 f/ \6 _+ c27* @. ^$ G$ b. O5 c. ]0 V
    28
    4 J4 ?( H' G, E$ j) d4 i- A# _/ g, b29  a' T9 L) O- Q- Z
    30
    # R) e* b7 H6 ]8 R" m310 e3 y% K0 i7 z* D) y- i, A
    32
    0 d5 K3 }! m1 |! K4 m1 T  N. F" [331 L2 ]1 L( d3 u. K$ ~5 m1 X
    349 _# m/ W, y2 o4 l7 B
    35
    " ]# ^$ ^3 o  r0 ^/ E366 p* G$ y; s: r
    37  Z6 c2 `6 {4 }% Y, p4 m( {  X
    38( e, _1 d* Z4 `4 R, v
    39
    / _) `: f5 p8 b" f( _7 N8 E40
    # W5 z3 n) U/ W* V, W% B/ H- a# G410 t* w8 g( U" {
    42+ Y% r7 \5 _; m/ G- A
    43( J+ `0 D( j+ _
    44
    5 H9 G4 x- ?& V& s6 W# D2 K45
    0 b. X0 n+ T0 n1 S1 g7 x46
    4 L$ Q- T; v2 J9 r472 n9 \% i, _0 ~6 r8 c
    48
    # f( \! t! Q* X( m$ L# E49
    : |: X* v. X# f0 Z% X50
    * {1 _4 @( ]) f! y- v9 v51
    & o' n9 M" p7 H# ^2 n527 }+ Y, l# k) L& q% D; ]9 s( {! S8 N3 s
    53
    7 y1 A$ w- d7 m$ z3 w549 l3 Q" W) A1 W# r  ~& m4 x. e9 `
    <AxesSubplot:>3 a+ Y- b  _7 P9 i8 [
    1& @5 @3 n; g. B1 X! |8 ~" ~6 q) f

    * C8 }. @9 U$ }3 v* a6 ^1 ^0 g上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确' M$ n9 r/ l) r
    7 {! K9 r  q/ b6 ]" S* ]- q
    img.shape* l+ M0 x# s; q5 j
    1* L( z+ O9 i8 T6 ~- U
    (3, 224, 224)
    1 _- O1 a8 }7 u$ f- _0 ?) ^1
    " A- }- N5 V) L5 D8 x证明了通道提前了,而且大小没改变
    ( j; ?! M- m- n7 D9 d0 y4 P) X2 {# D9 d
    9. 推理' V: o  S' E. S( m5 A& l& _* j; q4 Q
    img.shape
    / l5 y4 T3 _' @: o& Q7 s5 v& _- R0 \
    6 p, V$ a7 i, g/ h. Y# 得到一个batch的测试数据  E; O) m1 D# ^! q. w: r+ ?
    dataiter = iter(dataloaders['valid'])
    ) P' ]% F9 o) Q, `) ^* T; e( yimages, labels = dataiter.next()5 m$ \. Z0 j1 y- d
    4 j* C9 \  J0 k  \' n
    model_ft.eval()
    8 b4 C8 H, ~* i; o3 X5 z. Q( T" t
    1 |0 H$ b' k! U2 a; eif train_on_gpu:
    6 A* P& ~' a# {    # 前向传播跑一次会得到output* P1 M( |; F3 s5 \) E: Y7 J! f" K
        output = model_ft(images.cuda())) K& S- i/ r, }) V
    else:
    ! ?4 b8 h. e! T    output = model_ft(images)
    . A- g0 q2 M/ k5 c: V, R  N  Y, b% W
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
    + O, ^8 g( {' V/ [5 g1 moutput.shape, B6 H8 u( K0 C8 W8 A8 k
    ! t  f+ X: D" ?% d0 G- i
    1
    # z0 H  Y. E7 a6 i$ _. v; u. H& c2
    : A/ F0 G2 u# u* R9 N! N$ l3
    - |0 L0 x2 h- b, e* g2 C' L40 }/ t$ e4 i  E' |1 r
    5
    ; b5 @( D. z5 F69 L! q: T' k# `# g- _. q
    7) Z* B" g* M0 m. _& s
    8
    3 U" M% [8 P9 k) W( |9
    9 I! B& y7 ~8 l$ Q9 \10
    ) J# k$ C/ u1 D8 y" M11
    , H: I$ U* g3 V. v& A, D2 s12
    0 b- P  f1 ]. T1 ?- j13
    0 G& ~: _! u( p( m0 v  s14
    " C! B9 F( `4 O' n/ i# _* ^& H, t15
    5 N& |& z7 g' L. T16
    : |0 v' a2 j+ v) n1 }. M) O9 ]torch.Size([8, 102])
      U  J, n! P: y- c7 f, P1) }9 c1 a# \  q! C
    9.1 计算得到最大概率# Y: s: |/ D0 T7 C, h
    _, preds_tensor = torch.max(output, 1)
    % b! c5 e, Y! o' A9 H
    . a' t8 e$ Y& v% i9 O# a. S: zpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    6 N, N" f! K7 K- a1+ J' `, }  ~& ?
    2
    7 {$ Z' Q6 o, I! O37 i9 D2 S1 H" U9 K7 @- D3 i
    9.2 展示预测结果
    8 A3 |; N3 }5 ?+ r! w" r  Hfig = plt.figure(figsize = (20, 20))0 y2 E+ C& Q  m# `6 F
    columns = 47 V9 [4 F! p( _; T3 G0 o! O' c& d
    rows = 2
    4 U6 n0 }3 e' _* A$ \
    4 n" W) v7 X: ~' H: |+ K6 _; Wfor idx in range(columns * rows):0 {! @2 N* _( m' Z
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[]). B+ ^0 d5 L* ~8 G: F- [
        plt.imshow(im_convert(images[idx]))
    7 i; K& @' n7 K* T9 E    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), - m* m5 ~5 X$ }
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red")), s7 }! N* t5 Z! T" q0 K
    plt.show()
    2 z8 r0 G* ~" S' ~, f* ~5 ^# 绿色的表示预测是对的,红色表示预测错了' Z0 [- n" M$ K& ]2 ^
    14 E8 y/ q5 x) n: K4 c$ x0 ?
    2
    8 X9 F2 [/ L: u0 b6 n3
    - w+ U) a0 |) ~4
    6 C* H( l7 O% `, S3 V+ t+ i5
    7 H5 u7 [# L: n, s0 S6
    - {5 c# ?. J- |& S7
    . W" w" a7 N5 Q- p$ T, J; ]0 g: R89 Z5 K! b0 e4 [! x8 M
    9
    " _1 B$ |# Y  i# j+ m10
    / V4 a( g0 \- m# X116 ]2 z) D% n! U( g3 @+ u0 _
    1 X% j: \) e. X- s
    . L$ B- e9 I& P3 o8 q7 a3 |( f( E

    9 I' \" b8 _: v' L————————————————8 Z- a* F9 C2 i
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。% }1 {+ f% O* ?/ ?8 E9 @6 @7 Q
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    ; a$ Q( U# c' A- Q
    ( \/ f( r" p% q- ?& W
    * E; [( u$ i2 W" O* T# N
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-17 02:34 , Processed in 0.447333 second(s), 51 queries .

    回顶部