QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2714|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    * J: l! S! {- t* u- m. Q! b- x- p& r  K& b6 ^' h1 J8 J* _' f
    文章目录+ }# ^( D, o% g+ W
    卷积网络实战 对花进行分类
      `9 d; }2 K. p8 g& S2 d2 h3 J数据预处理部分% F. j! v, S) T+ B2 B" Q( o) B1 M
    网络模块设置" o& d. v0 O% p- _7 r& C) \. i
    网络模型的保存与测试
    ' }' }+ d3 k( _& m& Y0 g数据下载:* e* M, I" Z: y, T% G0 w  n; c
    1. 导入工具包) {6 ~" n$ r7 j8 C" ^
    2. 数据预处理与操作
    4 ?8 M- k' b: m5 c% [& y& K( ]3. 制作好数据源
    6 p# Q8 p+ s2 `& w. e5 \: k读取标签对应的实际名字4 D# i1 z4 p% A$ w9 v" H4 D7 ~
    4.展示一下数据. w2 h& w8 f3 P
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数& t# ?# O/ `7 g& @, Y! @  E& c8 N
    6.初始化模型架构: ]6 e' A7 Z4 @4 A& z1 _
    7. 设置需要训练的参数/ t) _: O8 m) y6 o! i& s
    7. 训练与预测
    " f3 }4 L0 p* t- s7.1 优化器设置
    * V9 e/ A3 J2 r8 P3 X* X7.2 开始训练模型2 |% g5 c" o* X
    7.3 训练所有层7 w) U' C# r( D9 }
    开始训练0 P, {4 `# G0 p$ @1 k
    8. 加载已经训练的模型
    0 C9 \" U* @7 ^$ l$ u+ Y9. 推理
    % m+ n: E% q: L0 E9.1 计算得到最大概率
    6 e" r: X9 t  w5 _; C9.2 展示预测结果- h$ C7 V4 f  H4 g0 A
    写在最后6 a2 G6 V$ ]! r6 }4 r
    卷积网络实战 对花进行分类; q% Q& V+ {6 r* x
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能4 q; |3 n: q, r7 e
    ) ~2 D  w( R% g7 W, w# c, h" H
    在文件夹中有102种花,我们主要要对这些花进行分类任务) I: i' H# f, D  q2 |
    文件夹结构* i2 I3 t! x! ^6 z
    & j/ x$ W8 S/ }0 G. |
    flower_data
    5 _# i  Q2 H5 t7 K5 W* l* C5 v$ c: \
    train
    5 M! P0 @4 g! R0 x* x0 D
    8 _+ ~$ w5 w# T, i" T: m1(类别)  a4 B! C* s- D2 T8 Q- |4 Q
    2
    . z! _5 K( Q7 _* \. n+ [8 J, `- G  xxxx.png / xxx.jpg" \8 I5 L* F: \# P
    valid: m* T8 L) a+ P- D
    / Y! S: Y% z" N3 C+ l' X
    主要分为以下几个大模块
    ( {' W( g1 Y0 Q5 G8 Y! U6 T* a# y, V: L: h& _" F# R
    数据预处理部分
    # d. m" h: R" E$ `, c' L* X数据增强
    + o( P3 ]5 [7 n  N& @数据预处理7 O: |" |8 U9 v, U' {
    网络模块设置
    5 W* ?; H% L* S4 b加载预训练模型,直接调用torchVision的经典网络架构
    9 k! r5 K1 t% W5 R- Q# U- {! w, M# R因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务; }6 U2 n. r! S. L) B: J) B
    网络模型的保存与测试
    - A7 K3 U1 c. w+ t* |+ f9 u模型保存可以带有选择性
    8 g6 N6 g1 N& C/ D6 Y数据下载:; j7 f% c! |1 i0 w2 U$ u/ W. f5 L
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset" ]% L# S! p9 R- P4 G* A
    8 y' D$ |, z  V6 o
    改一下文件名,然后将它放到同一根目录就可以了
    / y. l0 r/ I% {9 s$ i% S. ^5 w" Z' K) J& u$ T" Y0 y1 m! N9 d. s2 J
    下面是我的数据根目录3 ?* Z* ]" b4 |
    : O; c8 w  D. r1 T

    4 q6 A2 B  N& V( P1. 导入工具包
    # E7 u4 r- U6 |( Fimport os' C. E  a/ v1 m+ i
    import matplotlib.pyplot as plt
    ' y/ _' m- D! {2 L# 内嵌入绘图简去show的句柄
    & q2 M; z+ \& M  d4 }& t%matplotlib inline ; _/ P/ U( k; R. j0 f" F
    import numpy as np% o8 ?( t* k1 u: q; i
    import torch/ l9 ^9 c8 K4 N2 z8 U2 g' m
    from torch import nn
    . r6 S& H* H* h
      s7 y) X& y: [9 x) n8 qimport torch.optim as optim
      l4 T- `) X' L: y6 z, t! K4 timport torchvision" _  H7 T7 E1 W' L7 a" c
    from torchvision import transforms, models, datasets, x) @  @1 Q+ J: `
    0 n; P- E% O: W5 f. m! D
    import imageio9 Q; W  V, s1 K2 x
    import time
    , y0 v6 ?6 r' S4 c9 Oimport warnings
    ) n& k* I0 D/ Nimport random
    $ t: c3 e* \" d4 Iimport sys
    6 T! e2 t$ d* r* wimport copy9 \- `5 @, `, L( W  y
    import json
    6 K# E8 X3 n5 G% V) }from PIL import Image
    $ k& ]0 b$ K3 n" B: J' r1 T8 M. E" I: }5 t$ E( @
    / s7 D* |  V  C8 V/ o5 Q) _- m
    16 |% W: n1 n7 p- y2 a# S; o! s! K
    2! p- c- v  n! ~" o2 O. T
    3/ T7 p# E0 C) l) `: ]: W
    4
    % B* _/ o5 k, @# e6 }5
    , X+ S: O3 d4 p. i6
    7 g3 e' ^+ `% ~7
    7 Y2 h! j% C3 `$ y) {2 O+ y8
    # ~1 ~8 h( E2 {. i9
    " _% D' `2 F4 P$ N( ~1 Z, E% J10! E5 d4 n) I/ ^$ W4 U& h. `; H/ L% T
    11. t2 @0 v% X2 }5 P1 A2 F0 Q
    12) k2 z, }8 ?1 i- |0 e* I7 m$ d' f
    13
    7 @  \' N# v6 a14: E* y! A0 S: p4 `' m
    15  m6 X8 b2 }! G& ~1 F4 D
    16
    : h7 \9 c6 N+ y17" m! G# B" d9 A$ G  I, _+ H4 R
    18
    , f" n3 m9 _" b19) K" F6 c, B" M. `  L0 m! o6 x8 y
    20( W$ D: |5 z! [5 ^6 D% g+ [
    21
    8 M3 m" X/ ]- c+ F' _( a2. 数据预处理与操作+ m! h; p3 \  ]) W0 R( g4 v8 C
    #路径设置
    . X4 e& ]" h7 \% Ldata_dir = './flower_data/' # 当前文件夹下的flowerdata目录" M% w4 I7 p) d6 \- g
    train_dir = data_dir + '/train'
    9 \% t6 {1 W8 a, {valid_dir = data_dir + '/valid'
      C. X8 z/ q8 v4 n# l5 S% D1
    3 l0 [( w# |4 n1 O2. \' F: K% {% k+ n
    3
    ! l6 W! ]& N- v$ f) d4
    4 m& s2 N2 G6 [9 t$ npython目录点杠的组合与区别
    . L( a5 t! r2 Q注: 里面注明了点杠和斜杠的操作
    , a# @, H! m" m
    ; B$ T& F- l1 K/ |3. 制作好数据源
    : F* ]# o& c# p0 ]& ~) l" odata_transforms中制定了所有图像预处理的操作7 q& ]  D3 [* V7 Z$ o3 }
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片" N; N; y  b  X1 [1 B
    data_transforms = {
      |- Q" R+ B7 Q% |9 h    # 分成两部分,一部分是训练; y2 X: e; o( u7 X2 D. d
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间6 k8 r4 Y* i: O3 i% I
                                     transforms.CenterCrop(224), # 从中心处开始裁剪, {& n! T. H0 D' ?4 A
                                     # 以某个随机的概率决定是否翻转 55开2 G' U: F+ U1 w  N6 p  u
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    ! e6 H3 W9 F1 w! I8 Z+ U                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    1 ^) E8 |$ |! j6 c+ z6 F                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    5 E. i4 Y& U, K" L/ k! x$ G                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),9 J  q: z! H" B! M2 i' J+ q
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB9 g, ?5 y: ~. m  j8 N4 F6 ~: e
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    : l3 v: p- Y; Z3 t, [- t                                 transforms.ToTensor(),
    - r- C- h7 d( K; \% y                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差3 b& j- ^( h, z6 v) l5 @! a+ u$ B
                                    ]),
    , I) y! g* g& M( l    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化6 n$ @* b9 v! P$ X; b
        'valid': transforms.Compose([transforms.Resize(256),$ p/ Y2 k, i& U+ M
                                     transforms.CenterCrop(224),
    * k4 t" S, E( ]                                 transforms.ToTensor(),
    : L2 a& a% x. z0 ?0 Q. [& ?                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同8 `4 [* n7 |# h7 f
                                    ]),
    ! @$ L0 l! J1 Q" c. ^}
    4 p0 x' _1 s7 R6 v
    % O& T: b7 ]8 i1
    & R. r7 {  n: _1 `/ S' j# a1 o2
    $ M: D$ D4 n4 K) W6 o5 Z3
    * v# V/ `% S, j' p4' U% W) M) O+ N$ C/ n# ?* K5 {
    5
    7 _& _* j0 P7 |  \! Q6
    / Q/ `) B+ ?/ P/ W( \7 m* Z7
      e' U/ r; z8 J! k' c& `2 q' S8, A6 K4 U. z1 A
    9& A( y* L7 n1 D0 D# q/ g6 ?; d* U
    10
    , J0 s* y) j) X11
    ( Y5 H0 a! R/ t  u' S12
    ; ~( m0 j" }" C13- F2 |& d. u# a" m' ?: z
    14
    & K; c3 w2 w2 J+ }+ S15
    . F8 ]$ ~2 [: `; q" \16& u+ |. w$ S: N  j( E+ |6 b( e  i" b
    17$ i( l* N5 U6 A( U- B
    188 E- H5 E- h. g0 y* E0 N2 n
    19" i6 h8 z" ^& `  o$ K) l. W
    200 @- K+ g: J) C
    214 [. R. A6 d. t5 v- M4 ~
    batch_size = 8
    ; r3 b4 p7 Q1 _$ l2 T* |image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}7 [; ?' l& M( _' h) [! {
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    2 H3 V6 |. S$ j3 E5 ~dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} 4 j* D4 m5 Y/ ]" l. E, e
    class_names = image_datasets['train'].classes6 R* Z7 F  z, Q5 K- ?5 m8 }/ a4 U5 h

    * |# q% {9 c; q, b1 d$ @; R#查看数据集合
    1 i+ S& s! \6 {# `/ S7 Bimage_datasets4 d. N) H& ^7 e9 c: N
    # f. G1 g  K, p' I
    16 b  u) E# S$ e' ]4 D
    2/ t' J8 |3 l( X, |) H& e, o
    3
    / Q) K" D( S7 E% j& q46 z8 B# \# B7 G7 h
    50 ?; `7 _; m$ U% t; h
    6
    - [- E9 x- r8 y4 V# V$ ^7& H$ z& B/ n& G) R7 l2 v
    8
    9 F) }) V8 G0 y9" G. |' H# F2 p3 Q/ {- {
    {'train': Dataset ImageFolder7 ~+ _! r* P9 m% `! z
         Number of datapoints: 6552
    * U* M/ B8 ]/ ~3 o: f* e     Root location: ./flower_data/train) p% C# j: m+ L% t/ L* W8 N1 r
         StandardTransform+ o$ l2 Y4 f1 \7 Q& ]' F8 N* d. j5 y0 k
    Transform: Compose(4 L: X/ x; W4 C; j+ l1 N: e+ u% {+ v
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
    6 h  R4 L# m1 K6 E% b, [                CenterCrop(size=(224, 224))
      z0 k* j1 H; @. N1 A( I. ^" I                RandomHorizontalFlip(p=0.5)
    2 ~  S# [* P! t                RandomVerticalFlip(p=0.5)9 p+ L9 M2 i* x& A
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    , g+ Z! G. u% u+ s. l+ ?                RandomGrayscale(p=0.025)
    ! u: b, m3 v" R% c, a                ToTensor()
    % {2 s7 {6 ?/ u* x                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    4 g, c4 u! j; r; c& q3 V% K' w            ),; |5 v8 {0 m  B4 Y" m
    'valid': Dataset ImageFolder- z: X; w) S& W# o. S1 n
         Number of datapoints: 818
    $ q- `& K' c6 g6 J     Root location: ./flower_data/valid
      u4 X2 F6 G6 ~* k% B/ J+ K     StandardTransform& E6 I/ I4 z" g8 ]
    Transform: Compose(
    ) E" t$ M. e5 R" M) a8 z" D  ~8 [                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    9 L. f: D% u! d                CenterCrop(size=(224, 224))
    3 W/ m: Q0 ?5 Y$ {1 c7 x) \$ T                ToTensor(), [5 `) E+ ?" M' ?. s5 X
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    & m! O. f% K' b0 g; `8 o% y            )}% P1 d$ W. c9 t, o! p5 W

    + A- z5 R( a% O1
    : V' E5 H; c8 O! e8 ^, e8 d' X2! ~1 T& m0 u/ m
    3
    7 m9 _& l# V8 q9 `( T/ C4
    9 Y6 ~+ a) S# v5
    ) E: a6 J& H2 H8 n+ \6: q% D; B. D  X7 O
    7
    * A7 m$ M, P2 ~: E) X8
    4 E. H2 Q( q& H  y+ }; O9, P% B, O) O) e! ]' Q2 o6 ]  |
    101 T% }% \- X$ ^& \3 a" h
    11: E! P  J5 X, O
    12
    5 S& M6 P4 n* a4 w# d13
    . H( e2 H; X$ @  W; o2 [$ k- L  D141 S, G; K3 T! F* `# _/ r( x3 S* D" x
    158 Z3 l5 `; ]/ d) T3 L2 Q
    168 Q5 |' t. F9 J6 Q0 S4 [
    17
    % n$ u5 }- F6 `4 y1 i5 P183 I+ _- z/ h& {( Q2 u- D2 q
    190 K* k# u( {( Y1 P
    20' f. i6 g# Q4 t' W
    21$ j8 h& I& j2 k. g; G) D
    225 v. G6 w9 Z0 ]# Q; e5 ^- s6 {
    23# D+ N2 M1 w% f4 S
    24' j( U" }# c/ H+ J, v; s# S7 t8 n( \6 }
    # 验证一下数据是否已经被处理完毕
      Y% W; |7 m; l. M& P2 ]dataloaders
    ' y& X6 L) Z3 K# J  l1
    ; |9 L; p& Y. z0 k' \# B2/ w+ {+ s" k. }! l# {4 R( j  A+ C
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,, x- C! a) Q9 x  l8 d, y# _' @$ P6 u
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    2 A; J* C; A) U1! M: Z- k* _& N" \2 p: j
    2
    ! \  P% R, y. mdataset_sizes
    # B8 I' O" t4 }4 ?, Z2 b( j9 E1
    * P( b) p7 X  u$ V/ |' J{'train': 6552, 'valid': 818}: P. v4 }1 B- D! [1 ^) L8 _
    1( s2 q7 V: Y: O- O! x& o
    读取标签对应的实际名字9 @% s5 I, y( c% @" [
    使用同一目录下的json文件,反向映射出花对应的名字% z% P# Q' r5 h# J- j2 {

    7 w% P8 u( o% L8 f, s7 O( O( awith open('./flower_data/cat_to_name.json', 'r') as f:
    8 q* Y, ]5 J! {0 V. N- [. L3 T7 \0 V    cat_to_name = json.load(f)
    ! f7 {% N* S( j: U  l1
    $ P" s* ?/ B* m' s2
    9 c$ z  P5 c, C' g, R/ Zcat_to_name
    9 X. {9 y: G' G8 m. u1
    & T1 E2 r5 D7 x{'21': 'fire lily',' T! R0 B  C: h6 l; |8 e
    '3': 'canterbury bells',
    & k% q" w1 w  ] '45': 'bolero deep blue',
    # M4 w$ w/ r1 U* I5 n# D '1': 'pink primrose',) n: A- ^5 J1 R5 \3 |
    '34': 'mexican aster',
    * E% N. y9 y9 X# R '27': 'prince of wales feathers',# P6 r- x: |& [5 v) q: p
    '7': 'moon orchid',8 W+ O1 k, p# m& g$ O
    '16': 'globe-flower',2 e7 x( n5 {/ S* y' k  _8 b
    '25': 'grape hyacinth',7 f2 t6 {5 L2 Z$ }5 v$ j5 y
    '26': 'corn poppy',
    & f% }8 j0 n- j4 L '79': 'toad lily',
    / `' Q  r/ E* _ '39': 'siam tulip',
    9 k  M1 Y9 [/ c '24': 'red ginger',
    : h1 h" V! [: p: y '67': 'spring crocus',) G  F* L- f1 x& X2 Q# i
    '35': 'alpine sea holly',
    & c. H) K, Z) ]% v; {+ h! ]/ u% F '32': 'garden phlox',
    - F0 Y8 I2 n7 F* m '10': 'globe thistle',
    / }, x1 m# M9 M' @- A6 j7 Q8 a '6': 'tiger lily',
    ! V% r. E0 {, E2 Z9 ] '93': 'ball moss',3 C/ k7 f3 i) l& x- G8 v* A
    '33': 'love in the mist',9 z) P; d: f! g+ E/ a3 o) p/ U
    '9': 'monkshood',
    3 B7 U$ D/ z- ~6 X1 \6 v '102': 'blackberry lily',. |! T7 A- I0 ^. e1 d! o1 |# u
    '14': 'spear thistle',# ^% b) C+ A; c9 f$ Q
    '19': 'balloon flower',, i" ]$ Y: n/ O0 }, }# u* _) j/ Q
    '100': 'blanket flower',
    # ?4 O- B9 I- f" e- ? '13': 'king protea',; i4 C/ i$ [0 l6 Y3 T, B/ ]8 v  \
    '49': 'oxeye daisy',; L7 T; _8 g# `% H1 Q( \
    '15': 'yellow iris',
    ' ~; e9 {) Z) v  `" P/ t. S; J '61': 'cautleya spicata',
    + f8 b% ~1 H0 U( H8 I '31': 'carnation',0 c; M, P: {" d0 g0 A
    '64': 'silverbush',+ b* |& s/ t9 d9 }. M
    '68': 'bearded iris',
    1 T' _" Z. J! M8 o4 b9 F8 T7 m '63': 'black-eyed susan',8 h. W, X1 F3 ~( E) ~( N$ r% o6 j
    '69': 'windflower',
    - y3 h& ~7 ?7 i '62': 'japanese anemone',# ?& g- H0 L2 z/ N/ Q
    '20': 'giant white arum lily',
    2 Z  g( v% z9 ~* K" p+ U% E '38': 'great masterwort',
    9 i( n% T+ D8 S+ k& o '4': 'sweet pea',9 P3 ^4 c, w( }' o, _9 a) E
    '86': 'tree mallow',
    ( T3 |" P! T! q" `  M0 Z '101': 'trumpet creeper',/ s8 ?  W8 S, k$ U, w  l- \. V
    '42': 'daffodil',# z4 g' W& t" N
    '22': 'pincushion flower',
    ' U1 s- `$ [* D  B% V: L7 _2 M '2': 'hard-leaved pocket orchid',
    * O) ^, D$ ?2 v/ o" B1 ] '54': 'sunflower',
    $ b: _7 Y. G$ E0 L) x '66': 'osteospermum',
    7 p* r0 X) k0 w8 b* b% R; | '70': 'tree poppy',+ {: ^4 k6 L$ n+ p9 w
    '85': 'desert-rose',8 a9 l( i5 ]. s+ u6 k7 Q* g! ^
    '99': 'bromelia',
    7 P( ?" e- V5 x0 X3 ^& {7 p '87': 'magnolia',
    0 b/ @5 J5 t0 S. M4 n5 f+ X '5': 'english marigold',+ ]8 y. g6 B0 ^+ B7 p  O6 i2 C( Y
    '92': 'bee balm',
    : u/ m- \% z9 k9 n$ ]% g '28': 'stemless gentian',0 b' L% _, q4 k9 V; q
    '97': 'mallow',
    # W& d% N3 h$ h4 s, w+ f '57': 'gaura',. D4 G: _3 v4 h
    '40': 'lenten rose',; @  ]7 B/ s/ B, L
    '47': 'marigold',
    % G1 t* F3 H) z4 o '59': 'orange dahlia',( h! {" K3 a  l: f0 B2 s. i
    '48': 'buttercup',
    / _- P* \9 S" |# G: a '55': 'pelargonium',3 Q) o6 k/ n% z/ t8 n
    '36': 'ruby-lipped cattleya',* r8 Z- u, h/ W- v; d6 m2 l; X
    '91': 'hippeastrum',
    0 h7 I6 R- a4 G! Q+ j8 O6 w! L8 I '29': 'artichoke',7 ^, V/ Y7 b( [) L1 p  V6 P, R  ?4 [
    '71': 'gazania',0 Z' f# Y+ c% V6 r
    '90': 'canna lily',& d0 e: |7 w: i+ H! ~
    '18': 'peruvian lily',
    * k( H8 F1 a7 i5 b8 ]7 g '98': 'mexican petunia',
    0 R! P8 r- j6 ?" f$ T' a: w& C '8': 'bird of paradise',) `" t* I+ a5 m5 Z
    '30': 'sweet william',1 v2 l! |2 o2 B' |1 r
    '17': 'purple coneflower',
    ( F& `' Z3 D7 r2 u  W! d! B. E '52': 'wild pansy',1 [% I/ k/ F6 I* B+ d- ~
    '84': 'columbine',6 o. G9 c3 t& o! Y% G
    '12': "colt's foot",* Z" m# O6 u# I, u/ l0 L7 d
    '11': 'snapdragon',, g4 Q8 p2 q' ]2 A) c
    '96': 'camellia',' J! E* Z% h  x4 l
    '23': 'fritillary',
    5 o9 U  G* p) i- V; T, C '50': 'common dandelion',0 v, l3 B" _9 J
    '44': 'poinsettia',: h( m' \1 `7 `+ s  `: ^2 r
    '53': 'primula',+ u: d# X, h- N# s9 ~
    '72': 'azalea',
    2 K, B' k7 y. o- h3 r '65': 'californian poppy',% H: n7 p. ?. ~, B% b
    '80': 'anthurium'," E2 V1 N$ \. D1 j7 n
    '76': 'morning glory',3 S6 y5 m4 G. a( s7 Z. ~  _$ \
    '37': 'cape flower',
    ( D' n4 Q8 v8 u* z- H- q '56': 'bishop of llandaff',( ]/ X" v9 P; O  {/ W; g; J' ]3 P8 u
    '60': 'pink-yellow dahlia',
    0 O3 W3 }; Z3 f3 T2 @% }: q$ z '82': 'clematis',- L2 g! J! W; {# d4 }: ]7 h
    '58': 'geranium',/ e, N$ W$ B# n: q- J% r" @: g) A
    '75': 'thorn apple',1 C* A( f5 ~- R& \4 K3 b$ [4 q
    '41': 'barbeton daisy',# x0 j5 A) e( L1 t0 J1 t
    '95': 'bougainvillea',5 w  I- S. p4 |$ C
    '43': 'sword lily',
    1 F3 F: k* {) s# ]3 p9 ] '83': 'hibiscus',7 T& k* `" p4 D( W
    '78': 'lotus lotus',; A# R& @3 L* k1 T! X- ~
    '88': 'cyclamen',  _- _9 v" L/ [: I3 N
    '94': 'foxglove',  g* e2 m% g8 h4 j
    '81': 'frangipani',+ ]' f0 k5 r+ [/ z
    '74': 'rose',
    ) X0 ?# D$ O, `, q6 D. O2 X '89': 'watercress',4 P/ M4 h) y3 Q
    '73': 'water lily',
    , ], o; S' |" c8 g2 V8 p '46': 'wallflower',
    $ `+ o( i. U3 g& s" ` '77': 'passion flower',
    . F0 u5 ^7 P0 h2 q '51': 'petunia'}& s* t9 C5 D! P: Y9 t& s* ]
    0 X/ H% s0 A  H3 K
    1- k$ l5 ^! G$ B8 S8 @/ j
    2
    6 |1 d1 ^7 V* N) a- c3
    & T0 a+ r; d9 G4
    8 P" ?# k; W$ E5+ Z+ {% F$ S7 n# k$ Z* G- J& @
    6
    8 p' _0 K: o* }" H3 V8 ]7. j4 ^% P4 ~3 [; ~* a
    8) F' v1 r* j3 @' \) R$ U7 c
    9  c+ q" u8 m1 q! t! V2 Q0 d! h9 u
    10
    + d! s) e  X* a7 U1 Z: T  j. G111 A3 E$ Q$ f6 w' T3 q) Y9 K
    12) S$ n- `5 d  B' c5 I9 C6 ^
    13. ~; c) U; t( O7 t+ |# f2 E- ^# }3 p
    14
    3 S" D7 W  i  G1 B$ |* r. z  j15( |% ]" l! [7 j* C% q
    16
    + C- K& W& ?6 p' C( O( D17# z2 V7 o8 C4 q' {# W1 N
    18
    2 w8 C5 [$ L! z+ ~+ p19
    * R1 m; v% M" `, J6 Z20) O3 l) e) w' B: z9 B
    210 ?3 Q& Q$ O0 U  F7 F' s2 i
    22' \3 N/ g4 v: S- Y: q
    23, o* z0 p8 h3 C5 t! n
    24
    4 \3 @3 l5 b9 K  }$ S% q25
    9 G5 [2 Y' H! t" }0 X* p7 \8 U; I( A26
    2 ?0 C; o( E4 y( B# Q1 _279 D. V0 k+ g: J3 v. c5 h
    28$ P+ u' `. L! @
    29
    : r* N( J/ r. b% Z( w30  m. I4 m) d7 o4 W) i1 g, \
    31# `: l8 |) D8 f/ i
    326 i: K9 T9 o6 r+ Q& j8 ]' P* H
    33
    2 l* p5 ?( @2 h! _1 F4 d% K34) f9 R* Y$ n9 v, r4 a7 A" ]
    35
    . M1 v; A% F4 s+ A7 s5 @( x1 z/ q36" a# O- A2 W. x* _6 n
    37
    . H8 `+ \0 i) R6 A6 H7 p0 |: V38  r' ~, [! e6 I; h
    392 i1 I# {; `5 D% f
    406 y# B7 X4 x5 ^5 y! M1 C
    41+ d2 {8 [5 j/ v) v
    42
    6 p; e3 B- E4 v+ |2 g434 C# y8 f5 {5 Y, m" u
    440 e: R. ?2 q9 M( H
    454 z; l# }( K! `' r
    46
    * s, W3 X0 b. m" e6 `471 o( u; u& @7 _5 G
    48" B! @& Q/ z" p
    49
    : W6 ]- n1 d& M2 p% j: ~50
    9 y& x0 N" f/ ]! S51
    $ ?( ~" H; ?% g5 m# i! Q52- E) A4 g; Q3 w7 _
    53$ y9 x) h! u1 _
    54
    & w9 y( a4 ^5 D55
    0 u# l- U; U5 K$ y56
    ( H' ~1 V7 v0 g: o1 Z6 J571 ~7 E6 W" f/ E0 _! v5 f3 W- `. [
    58" A/ [5 p5 f  D6 T& I
    59
    ' x' I" I5 f* ~( G& k602 v1 J. c' N" z
    61: B( R# ^- w* z7 T' T* j
    629 @: e3 Z) I0 y; F. p4 D# r/ I
    63% u1 f- M0 T: }; J0 }" U
    64/ E6 l- o! c: _4 z: a! |
    65! r. Y( D; n5 |
    66
    & ~( r- [  m" V9 y+ h67
    % h( x6 [" }" Z" E1 y7 i0 h68. X* L% g$ S$ n
    69# _5 ?+ e) J! c, {
    703 c4 U% G- v3 q5 v% ^( y( n
    71
    8 c6 o' V7 X; F* O722 l9 h% V2 Z7 e* U
    73/ m3 f0 G* A" r" l+ N6 b4 C
    74
    / {, s$ V, n  }+ \  k754 w& P  i3 n: r
    76
    0 l5 Y4 }( g" X5 a9 {7 H8 p, U77
    # Y, j! j5 X2 p$ q78
    1 w5 m6 K$ c$ B7 s79) S' ]$ I' X8 C* d
    80' a6 Y  ^! K/ ]# R* I; C
    81
    , ^: i1 k6 R. {3 R2 {82) Z) |# x$ e0 ^3 x7 F  \3 @4 \
    83  s' ~6 [! M, ~! l1 n, B) G8 V
    84
    ; U5 L$ X7 G& X9 R$ L! B: L. c85
    : }* P( k- t4 G9 C. h3 W86+ |( }- e: l$ C" f) i! c; q
    87
    / K) r" y( ]: ^9 w) z1 Z6 k7 s885 ^& t5 o& @4 @- F5 g3 q
    89
    0 T. _2 Y. y+ j' |) c3 h& B: Y90
    & c" V4 J. X# d: O9 e2 e- h  k2 A2 l91# t% m- u! d+ l8 g) e2 @
    927 V7 _* @2 J6 g
    93! E2 v  L2 {9 T& {$ ?9 G' A
    947 _% q. I) ~5 ?' D
    95& ]2 V2 {2 r, v2 X/ ~" ^
    96) T. w$ v' ]7 Q3 z1 j2 I
    97
    2 \+ N& t2 F/ f" s. i/ a5 Q98
    ) ^  R# \1 H$ h* b+ h999 t" i# P( q' s7 m
    100
    6 h. T0 M% ^# \101
    ) P2 m5 X# t. n" ]$ w102
    9 e$ w; _9 n8 y) I4 D0 L4.展示一下数据
    ! k- T1 E6 L& c" wdef im_convert(tensor):! |1 T; e: a( x" {% I( U
        """数据展示"""
    5 S" z% V2 h- k5 d  r# k7 G6 K    image = tensor.to("cpu").clone().detach()
    * l; e5 L8 g3 I5 L    image = image.numpy().squeeze()
    5 _4 Q. \: S/ i    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图, H/ {, X* U  s) m4 F( }7 I$ Z
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    ) B! v' b$ A+ Z. B# J* I3 f    image = image.transpose(1, 2, 0)3 o2 M+ T% P! w1 a! I
        # 反正则化(反标准化)6 o; w7 w* x0 T3 R5 B6 o
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    " I6 ~5 x% h7 l* o( P. H' X7 I% V
    ) E5 Q* Q" q6 B4 E7 h    # 将图像中小于0 的都换成0,大于的都变成14 P8 f) k) k$ C5 z9 \* M
        image = image.clip(0, 1)
    * i0 r3 v9 y% J& i: G1 r
    8 q: O! @! |  q( {8 p! U    return image
    4 E0 H) q9 I) ^5 J- h1# U# A  p+ S6 x1 c* I9 D* N) U
    2) I: Q. o+ u2 C7 G% B2 P- i0 W
    3
    : l: u5 I% l$ L4 g& t+ d4
    ' n4 M! f) h3 \$ [, X7 U  k/ e5
    * G; T% t8 v) K9 W6 y66 x4 f/ s$ Z+ @: H- U( F( ]
    7
    1 D1 N6 B, q7 s3 ?8 G4 P) \8: `) c$ d6 n; T: [" e
    90 z+ F8 D2 Y( v1 w- \( `
    10
    1 E# j7 N, \3 o" Q. ?11( O7 H) D% E0 r7 q
    12! R) n# g0 o) i) ^
    13  {) g2 }8 T& w) m: n
    14
    9 h3 H9 p3 f0 g5 L4 D5 y4 g% |# 使用上面定义好的类进行画图
    - I) }' |8 s9 U/ c4 l% a& C2 mfig = plt.figure(figsize = (20, 12))4 C9 r7 O$ g9 j3 |& _& u$ x3 H7 j; S- Z
    columns = 4* c; S# x# M  R7 {# Y5 a" T8 O: f
    rows = 2
    7 ~7 T  z& ^; F1 x
    0 t6 l, H2 U: |& R# iter迭代器) {' w& }6 G6 C: [
    # 随便找一个Batch数据进行展示& B! q$ p: c: W
    dataiter = iter(dataloaders['valid'])* q2 p/ X2 O, F: `4 o( Q- g
    inputs, classes = dataiter.next()
    ! p* H- K  [, v, ^7 V8 C. }3 J% \7 t, v, c$ {. O% _" Z3 r& l$ r
    for idx in range(columns * rows):
    / ^1 v7 T2 Y, v) o4 T    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    $ p' k6 R! A$ q/ V    # 利用json文件将其对应花的类型打印在图片中
    ; M+ j# P2 j( k0 _0 E& L    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    2 }  _+ X4 U+ j4 E    plt.imshow(im_convert(inputs[idx]))
    $ ]6 P6 M% I' w/ y3 Gplt.show()
    % k$ l  S% v: r4 s' R+ G# x1 J  |# @' B1 @
    1
    5 B1 N  w  n8 q" |# n2
    ; t& F7 f7 p( |3
    ) j2 {5 Z9 h8 U4 @4
    8 ]( A3 f5 Z$ n! f/ J: a5 p5
      ?2 C+ u! K' k* Z8 c6- T% d1 V, y3 V! x: [. R, m% h
    7: P7 m  s" Z: g8 h1 G' U
    8
    6 s  o  }5 a9 a! U& v; a9
    3 }/ \. t3 t7 P. d, Q( z10) N, ~5 _; T, v; J& B( n* |
    11+ f) A+ u6 m/ A# X- b
    122 q4 g2 h* Z# I" U' W% }
    13
    * O; i9 x. K5 k3 a143 P2 M1 z4 K3 x) d
    154 m, u2 C  Z+ n* Z- `, o
    166 W5 g/ u# U' h( @' n7 S. i9 R
    4 }0 f5 G& y; C& _9 o7 z
    4 U4 d; V5 {2 v2 ^) V
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数( @) d8 k% c4 R9 E
    model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']8 z, q( g; L- Y' H7 V8 L2 d
    # 主要的图像识别用resnet来做
    9 ~6 w: [. \/ \" W& U7 c# 是否用人家训练好的特征0 e6 O( b  W# J) g. U. r
    feature_extract = True
    1 W) g4 c6 }8 r0 P7 }1( J7 B% A/ }8 h' d) K
    2
      F+ J" }+ b" V. `) M3
    " D: |% \4 n: {9 D' s. G4
    : D: K& y) D- X8 w* J+ B$ h5 f# 是否用GPU进行训练
    ' x7 [5 ?9 a& G  Rtrain_on_gpu = torch.cuda.is_available()
    + y# O4 u- C- I3 u! o4 V$ X4 F' R7 {3 \, X! j
    if not train_on_gpu:& Q) c) z9 k8 b
        print('CUDA is not available.   Training on CPU ...')
    6 z; g5 z9 ~' s7 E4 [else:
    ; E6 J2 n, _$ }% a" m    print('CUDA is available! Training on GPU ...')
    7 g5 B9 ~1 [) I4 s, G" \
    & \  P% e6 x- _/ V' ^, t. Ndevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    0 B' P' i7 A) Y: x1& {) z- k  i6 ^8 l0 m4 i5 m- u, l! E/ |
    2
    : W/ o# Q$ O; j1 D6 h' g6 p35 P# H, Z" [- G2 f9 @) H/ o
    4* F/ g( q4 ?. e- Y
    5" g( m% \/ P, R4 F* L# ?
    6
    & T  ~" S' [: V7
    5 H9 n( g0 h3 Y+ {" X85 Y" Z8 S* P/ W1 _' M
    9
    $ }$ ?7 Y9 a  `9 mCUDA is not available.   Training on CPU ...! b$ Y" R% M' Z3 k( P
    1; [. ~( o4 e$ r7 U9 ^* x
    # 将一些层定义为false,使其不自动更新8 k: g: D, t( c# x* C
    def set_parameter_requires_grad(model, feature_extracting):4 g$ y! g  @, g3 u& j5 ]/ w. n
        if feature_extracting:. @: Z! `* L0 J- e, O! |
            for param in model.parameters():# T1 G* |; I+ o4 |" b
                param.requires_grad = False2 R' }: K  a7 L
    1' n4 ?0 z, K0 v
    2" ?' j1 Y  g# W" K
    3
    & S) T6 Q. b! D5 {' M* q4
    - p* M5 n- y1 ?  b5
    ! |' ^3 O, X3 h/ I  z- U4 y# g# 打印模型架构告知是怎么一步一步去完成的* H) J% h- f  R6 @. V9 n
    # 主要是为我们提取特征的
    / I( @( d$ y6 M+ A0 ~4 m- ~
    : z; I, W" m; dmodel_ft = models.resnet152()
    5 \' q" J; E' Z' hmodel_ft' n% }$ k& G) M7 H5 g
    1* Y! @' f2 S3 S' k" w& h& g
    27 \7 ~1 q3 c# [6 I# \9 J! f
    3( |4 C( F6 }9 J' \
    4. H; ]1 ]4 H* O
    5* O# h7 p9 h8 |6 D" A
    ResNet(( x/ T( m, L  c  t- }
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    6 j& R& k0 s& B/ z  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)$ C9 K+ k+ {$ g
      (relu): ReLU(inplace=True)
    ) y: N' H- l8 p2 A7 L8 ]  U# T2 l; h  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    ' l) f$ H( k! e4 _  (layer1): Sequential(
    3 a- k+ W9 a. K0 A    (0): Bottleneck(8 W! y) k2 f5 d
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)1 E5 C* W/ W1 f  d
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    0 e* M3 W0 ~/ K; ~3 g      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)( I6 |) W4 B; G* d" ^, ]7 f: J5 {
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    9 m1 N3 c3 @. j2 i" p2 }; ]1 A      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    0 I- Y! I, ]9 h, f      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)1 ~: R: h0 Z% C) }( O8 {6 i/ ?' D
          (relu): ReLU(inplace=True)
    7 ?9 _' h: X( g      (downsample): Sequential(
    " V, T/ Q: s" i4 l# J+ V6 n$ a        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    & w9 F% D1 V* {* g        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)! R% w4 e) N8 [0 {9 x# E
          ); p% e0 z; l2 g; R2 A, \- Z
        )+ f- [9 M% s7 A# c
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    + T% W  F( z4 @4 ]* E, e3 y    (2): Bottleneck(& @  @3 n- G$ J1 X0 i7 t
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    $ l4 a5 W! l; D" W4 Q. o& j% }' ]      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    4 T  ]: a  F6 L6 ]& |      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    . y! w9 Y/ \* C" s2 s- A      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    0 m, B: F. ~4 }- {5 P; m; e      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    & o7 A# P) i3 T' ?) E8 a      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    . U  N. ~# V' |! V2 [      (relu): ReLU(inplace=True)% M7 J( y+ F( p# c; Q& u3 M
        )* I  h( a2 }! n6 w9 q. k
      ); p& c- ?' J( H/ }1 M4 `  L
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))% V; V* X  K# N; e, m, ~4 ]6 s
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    # Z  N3 k) j* X/ b& L! z9 })* O+ e- ?0 G' O8 ^- v0 C
    $ G1 h% ~9 A# H
    1
    3 J' z4 Q  N* {. F4 O) H2
    9 |% `+ a5 |  N34 Q, R. [( U3 X/ W4 S6 s# X+ x
    4
    ; k2 G4 j  y' ?8 U9 Y3 Z' N56 N! b1 [& E  |6 [+ f, Y
    6. N. w5 K) |$ M# I8 p" z; {" g- W
    7
    " G3 K2 i9 g0 g0 }8
    & z1 r. g1 m5 L" G3 o90 Y& _) o$ K9 M- ~/ a5 i" v
    10; v. n/ f$ Y: x! P9 z6 M# w
    115 z+ p: }( M. u. |+ u1 |. j
    12+ X3 A5 x5 x9 @
    13
    3 Q. w4 q  b' X; ]5 `14
    + T- {3 I. I2 B& b15
    , m: T' o" P+ a* l% ]0 A168 K' }, Y( |7 f8 ^
    17
    6 \7 [/ e1 E3 `; W18( v0 w0 b# L0 `* g: |' Y2 A
    19
    , K4 t5 t! F1 }; F# t3 D. p* `4 f20' [4 X5 C2 e' j
    21
    . J) q1 I2 }& t% T# U/ J5 K22, I6 P9 I3 S8 z+ |8 O' I
    23/ ?# A$ s7 v5 f& Z+ ^, U
    242 o+ b! |& ?4 j2 F6 e1 i. j/ ]
    25
    3 a" m2 X7 f% ^261 K+ ]% ~( `6 C3 `1 W
    27
    , W2 z9 Z# C: M( m0 `( C284 c4 W! F& x9 _  }' x8 J5 _
    29
    / R: Y& S7 r9 c& X7 Y7 g# o306 v4 ^# P- K; i
    31
    1 u9 P$ X5 F; m- w5 T; b% X! s6 E* [% T32* ^& T$ x! ~! `, o! i2 `
    330 N0 |! V! e, ]
    最后是1000分类,2048输入,分为1000个分类
    , V& A; _+ A4 n4 [6 Q! U而我们需要将我们的任务进行调整,将1000分类改为102输出
    " ]+ K5 }/ k! b! ^6 Q+ L2 `. w9 w
    " S9 T6 d: ^: ?( J' E4 o6 R: c, r6.初始化模型架构
    0 l2 N% M% X  G5 c  \步骤如下:! o# }0 n. ]2 U* g, H
    ! j. u/ l4 y! Y7 u4 d) @
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数; O1 t$ ~( c8 [1 n, P
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
    ( e: _/ W& S" d; z+ x无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
    ) t7 p; R3 Z( `5 ]! w3 Z! E- k官方文档链接& e$ M# F, e- S& n- q8 o" N7 X; p7 q
    https://pytorch.org/vision/stable/models.html
    . q3 ~; g$ t+ v% `1 b- K: q$ J9 S9 j# D+ W+ w. Y
    # 将他人的模型加载进来
    % }+ \% t) K6 odef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    . C% q. M. ~1 A3 |    # 选择适合的模型,不同的模型初始化参数不同3 J; J9 `9 O2 d% n# R
        model_ft = None
    ; x7 {% `' v) j    input_size = 0
    : R) z; q0 a. U7 O8 W4 W, |6 R* g8 E* c2 J; j
        if model_name == "resnet":
    : M: f, \- K& A9 J/ x) q        """
    - g/ i" [. q" _4 x" A- O        Resnet152( z- A# c" g% D$ P* R. s
            """2 E0 p  E5 e/ ~* [
    & M8 [2 ~, R9 T9 E) q, Y3 X# @9 l
            # 1. 加载与训练网络( w  u. ]) R4 E1 ~: z# ?5 W$ w
            model_ft = models.resnet152(pretrained = use_pretrained)
    ! I' K6 _: T7 x) t$ Q: V        # 2. 是否将提取特征的模块冻住,只训练FC层
    / W% y, G+ \3 T7 S0 w7 I        set_parameter_requires_grad(model_ft, feature_extract)  t- ^- z" l1 ], W+ m
            # 3. 获得全连接层输入特征/ u/ N2 a) ?1 F: B- v
            num_frts = model_ft.fc.in_features
    * n% H1 l( O8 }2 i% Y! }) q        # 4. 重新加载全连接层,设置输出102
    9 Y! E3 ]" d0 y8 @        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
    . J  v% z1 P; l  d" U. w                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    # D" s  J7 ~# D$ x        input_size = 224
    6 E7 v( U( k0 ~" |: I9 g3 n9 ^" n; e
    3 j( H5 f  @$ W) |/ l- D6 c( L    elif model_name == "alexnet":
    0 ^, }4 a, @- Q2 v6 O. T* t        """+ }9 n8 C5 p  _, b6 o
            Alexnet
    6 e* Y/ F, H7 M! A! B4 A$ P/ J        """
    9 M. x, ^# {8 }. G4 F. D6 ^        model_ft = models.alexnet(pretrained = use_pretrained)) F; R  f6 h( ~7 l" H" W' A! F
            set_parameter_requires_grad(model_ft, feature_extract)
    ! v0 N' U( ~% X
    6 T: O0 P" {& Q( A! S& ^        # 将最后一个特征输出替换 序号为【6】的分类器
    / m0 x2 b5 Y# A5 `& ~        num_frts = model_ft.classifier[6].in_features # 获得FC层输入, \  j8 I" j0 a2 j3 o( j* x  Y
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    0 T, Q. @$ d3 [        input_size = 224  v( \0 G3 l. y. m1 x

    # _: {% A( g. |; s) G5 H    elif model_name == "vgg":
    9 ]% G3 b  C% u+ t8 d8 F        """
    4 _! J9 \# }% v        VGG11_bn+ b+ V5 j, `8 g( a  n0 J
            """
    & v3 |$ C9 a3 ~        model_ft = models.vgg16(pretrained = use_pretrained): s1 A9 F! R" V6 x, \
            set_parameter_requires_grad(model_ft, feature_extract)$ G* w, I9 t- [
            num_frts = model_ft.classifier[6].in_features
    ( b; e2 K  ^( l3 Z8 f. y( g2 v        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)* y, a3 K  S# e6 }
            input_size = 224
    0 m; L4 E7 G" n$ p5 a! Y& d7 u$ G4 Y* @" U8 L/ [3 i1 r7 ~& H
        elif model_name == "squeezenet":
    # B7 \+ n4 z6 n  ?+ O        """) I. y# |' H+ L8 D
            Squeezenet
    1 L. @$ c+ k# q5 s3 B        """
    + L/ s3 M  h( H; W9 C        model_ft = models.squeezenet1_0(pretrained = use_pretrained); y9 S& R2 |; c
            set_parameter_requires_grad(model_ft, feature_extract)( D7 W# f* Y+ O
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))' K3 M+ z) c- T/ m* T$ j
            model_ft.num_classes = num_classes# D, C8 j6 J9 c2 K
            input_size = 224
    * J) o! a2 \0 e
    6 x( z; {1 ]: g    elif model_name == "densenet":
    , ]" A- T5 z" k        """" @* A3 N) }5 a+ s  V) o, p
            Densenet6 ]) V4 L  I: H* G4 a7 }- j" h
            """+ E8 E7 f) G1 e; A
            model_ft = models.desenet121(pretrained = use_pretrained)
    5 l9 r7 z. o+ z- E" X3 _        set_parameter_requires_grad(model_ft, feature_extract)' g. B# o+ ~0 H
            num_frts = model_ft.classifier.in_features
    . L2 C! r7 m5 G8 e. i/ F        model_ft.classifier = nn.Linear(num_frts, num_classes)7 i: D( _! F' G4 `, O" n1 R# P
            input_size = 224: J6 j" N0 q) z" V# }# i# L

    0 Y1 E+ z: Y6 P    elif model_name == "inception":* T' B0 U+ h  x2 W% [$ k: L3 P
            """- F  X0 z& d) S3 o4 C
            Inception V3
    # X/ m8 `& f$ H        """
    $ ^  i" _& u  f0 ~  U. k9 p! Y        model_ft = models.inception_V(pretrained = use_pretrained)) r$ C0 W# x8 Q4 c: j" X
            set_parameter_requires_grad(model_ft, feature_extract)- h6 x/ x* l8 _; {/ q, x* [+ k

    + t" W% c" G* b        num_frts = model_ft.AuxLogits.fc.in_features, X- d3 R$ R, ?- D
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
    ' B  m( d; j* o8 W, w7 |) M" i& D' ^: L4 Q
            num_frts = model_ft.fc.in_features
    3 }- x. m) V% [# t. z6 g6 U        model_ft.fc = nn.Linear(num_frts, num_classes)$ l5 c3 q; i8 s* ~" c7 c
            input_size = 299
    ( `5 n4 E, B& a! o4 x" [4 ?& f/ B5 u; W* n
        else:
    1 C* S' h2 L. F. I, K6 m4 H        print("Invalid model name, exiting...")1 S. U0 q, h' s2 `" L: ~
            exit()
    : j% n5 N4 E7 Y  C5 x+ z1 g5 }0 S0 K2 ^- ], S/ _! l$ f7 A+ i
        return model_ft, input_size7 u6 X* X1 C; \4 A

    , Q5 u$ H$ f) C. P3 z2 e+ Q1
    " T1 ^9 r) d* m3 ]; a# d- i2
    " @% u9 {2 T$ P9 `- ]& Z/ m; @/ M" _3
    . o" n) I* g) U* _+ O4
    5 N; X6 F8 @+ n- u/ V( }) P5' r/ _. b; @6 y
    6
    ' q" z* S. y3 P0 ^7. q* b' ^# L% k
    8) _: m& z" F* x  c- W
    92 \$ h' A: a6 {! Z1 q. P
    100 G# |( i0 c/ X$ X
    11" `  |9 l( O! G& ~  m! L
    12
    ' A. R$ }2 y! m4 a$ R/ L7 r$ q13
    ! M4 y2 M; {4 K( C14
    0 F( U+ {7 E) U3 Q7 {15' e3 f/ c5 _7 I! \- r2 L# ~
    16; u( P3 d* {2 N
    17$ i) R, x9 W" l# M( I  W
    18
      f& W, H( V% ], y: R1 c* c5 y19; v! p6 e" g$ z+ H
    20
    ( B/ h, J) b  C& s217 K. c: L2 [3 v" y* d& K5 u
    22, f$ z! a0 O3 y% P: l5 m
    23" d4 Y) o4 q" V" G! J
    24$ {; u$ g+ @: ]
    25
    - @9 V! `/ b" U26
    0 ?# h; n3 }" O0 X/ I8 T% [27& i6 d7 a( E! I7 [9 l
    28
    6 G- [, u1 q1 A  U. H* a# H; U( N29
    ; H/ u9 H) E2 H( r% n& v30
    5 T6 n) h. ^. O& q  U31& f  i" B6 `: O  a% h5 q
    32% z9 i0 J5 m" ]& D6 W
    33
    " T- Z  t, @+ y34
    - j+ a0 n( X% R6 G8 n) b353 S- L: _$ O! `; ^4 b8 ~' P
    366 ^6 {9 z7 d1 J1 o3 _
    37
    " Y2 L  i) I! ?' D1 j38! G6 D; I9 l; |" o9 r2 ?4 p8 O
    39/ }( P  V6 _' m$ w" d7 T/ A
    40
    8 @' T; C  N( P7 @6 L& E41
    ) G' _6 R, }" I7 h42
    , ?% E* `. D* z3 o434 y; _5 ]0 @! z& v; `
    44
    3 ^* L9 F" a- i" B8 ]45
    3 f, B6 e1 `( o1 [46
    $ t& P/ m( ^8 n47- K9 ?6 @; [* B. i; {6 L, B
    48
    3 ?: W# n: c6 [; N) ^49
    $ Z0 _2 v: O, E, d* {, y50) r" k% s/ y& {  W- D( _
    51
    % ?7 ]& [3 M5 H( j% a9 p' s; \2 H/ t52* n' r0 ~; E7 ?# ]" q
    53# X- g2 `9 x' R. j) u% E) f2 y
    54
    # h6 y) N0 b  E: o* G" M551 b5 v" @3 W9 s- [* W3 U8 K3 v
    565 e" s9 V1 x2 f1 W+ G8 w, z
    57
    9 {0 k+ h7 k3 B: S* i: _587 R. y$ k( F; T* B/ l/ b
    59
    1 g) ?" F+ F) ^+ k5 v- L6 b; U, O60
    ! b: |. p; Q7 r  h61
    0 y; t8 w6 h$ ~7 g1 {/ e6 E' M62
    9 k& |( [9 Y& U* n63
    4 r! `7 C1 m0 S6 x8 n/ |9 ]64/ T. k7 B* j6 W0 j" o8 U; F
    65" ?- w% X7 [' N7 R
    66' m) Z8 m5 t, \+ _
    678 t, j9 A+ v7 A, u: s
    68
    ) T0 l! l, b0 B. n/ K6 ?$ O+ q69
    # T' S) M( E( l" P& l4 i+ T70& X2 G$ }: ]. ^# ?; _9 ]% p
    71
    , f+ T3 A8 D3 P9 ~1 G724 B* h, B% `9 E
    73
    5 b& M/ D5 j6 [, v74
    9 q6 L5 M* {3 o: G+ I  G7 s75
    / h' Q7 e0 @) H: a76, b* a' _% ?: `2 a' ?/ i( S
    77
    " t* b2 f  e8 Q  {  y+ A1 L4 \78
    ' }3 a! [4 i1 T; e797 u+ N9 W8 ^( x8 J9 d
    80
    + H2 \% d. C! i5 x3 l4 L5 T81
    ! `' o5 `+ N' T82
    7 h4 y( I' X6 S& V83! G: X6 K7 R- N8 ^* V' K
    7. 设置需要训练的参数+ t# u/ |0 e6 D
    # 设置模型名字、输出分类数1 e* b' Y- q7 I
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    # v7 t2 \( ], k( ?# {2 T/ I& T0 c% N+ |, }
    # GPU 计算
    , S) f4 w. m7 }) L/ N6 k# c: gmodel_ft = model_ft.to(device)% I+ o: W+ d, C
    4 ?! p8 K7 Q7 U% k+ t
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取" ~) E  P" e2 I6 _0 Y; [' b
    filename = 'checkpoint.pth'
    ( o, q0 K" O5 @; I8 E$ T* r1 }
    $ o7 |* [' z' s: k' [% j# 是否训练所有层) [4 T& g: b6 [# l. e# l
    params_to_update = model_ft.parameters()
    % {- _) u# O% d4 S  K# 打印出需要训练的层( W3 [! h! S9 k: V( m) ]- n$ J
    print("Params to learn:")
    ) p- }0 v% O9 O7 z. r0 Y, Y) tif feature_extract:2 Z" r; @' d- O; I1 O
        params_to_update = []9 E- g& d$ |. G" c/ i# O; i9 G
        for name, param in model_ft.named_parameters():
    ' H% N# c* h5 c, F  V1 g        if param.requires_grad == True:( x$ q3 s3 j8 w1 T  ~
                params_to_update.append(param)
    + ]6 D' d: J! F! u8 H1 t" s            print("\t", name)
    8 M6 v3 \# k. e8 belse:9 j1 Z7 }2 @5 |# ^: k' d
        for name, param in model_ft.named_parameters():" t  k, {3 n( x3 U* a4 _% }+ X
            if param.requires_grad ==True:6 C1 n; B3 t- j% r6 @2 I
                print("\t", name)- f. Z/ j2 z5 |5 L
    ( R/ @! s% W1 \% \
    1
    & y  D  ]! V8 Z! V$ u: C' n( r8 h2  C8 m; v4 _* p+ q8 v
    34 o' o. p3 \2 n: ?& h$ w9 ^
    4
    3 V3 y2 A- e4 i- l5- \; {0 h  b' H  P5 w' O1 ?
    6
    5 D1 u) Z' n" E% f5 r70 B+ x' S; d2 {$ P% r
    8
    " y/ B& D2 R3 N' i9$ L  w" d( f5 K& [, m" H# W0 Y
    10
    , q  H8 `1 Q8 D: p8 B! ]11% \! u5 W% M$ g$ Y! O
    126 c3 p. g1 u. O) `
    13
    3 P  v2 g! L: G* c14
    0 E1 R6 @6 x' I- q* q15$ i1 X3 W& L! C3 a
    16
    2 G9 h" m2 \" M% i174 C! g9 ?: q+ l# a; T9 E9 U
    18! L7 z( V/ w. g* Y4 F* Y- C8 G3 E
    19
    0 ^3 N( U# J1 v' b0 |$ _20
    - s' G& ]5 K, `/ b7 _; E21/ n2 c, v7 t5 z' X6 C8 s& t
    22, q" o/ \/ K% I$ g: R3 P
    23
    + k( V5 @; d2 F, R" G  H! [Params to learn:
    9 C7 N  L7 e6 `! N% }. B% T2 E  d         fc.0.weight
    ; s; C0 t: C- Z- T/ U' z         fc.0.bias
    ! V9 i1 S( v9 r/ U' P) H. L1; u! R# @+ U- N
    2
    . g, x* N/ e( ?+ F32 w" X7 F9 t# Q8 P  g: G# c
    7. 训练与预测3 Y- Z6 E- {) H) b. ]* ?0 I
    7.1 优化器设置9 @8 h2 k! W# r0 \( F% f0 i
    # 优化器设置
    9 t4 J& ^1 M8 b+ {optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    ( x. c, Z$ }! _6 j. f4 z& z% w5 V8 V# 学习率衰减策略. a4 ?# P6 y& }2 V; D; H
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)/ o7 w. G. z6 ?* l: Z- P8 M
    # 学习率每7个epoch衰减为原来的1/105 j) j( C1 H/ K3 Q
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算! \- N2 w- ?: N2 P; G( I

    4 _) @1 T  n" m1 W) }5 u1 }7 p/ ?. _criterion = nn.NLLLoss()! ?4 C1 V: D8 Q  s/ m  T8 W
    11 g+ f+ ?( d# i! b
    2
    ! `: |! b+ e9 T4 I% |" Q3
    3 U6 M- |$ N1 R& d' ?7 ]4) @6 m7 E$ T. `- i  A5 @5 r! F
    5
    2 W" n. O$ t' k9 i* `/ s6
    ! ^1 u5 U2 M, ]3 H2 L" r7
    ! O  M- L2 x, W& f% x8
    : K# Y/ F3 W% Z4 G, M# 定义训练函数
    0 D  ^. d2 G( D& H/ m/ d: o& ^8 \% V#is_inception:要不要用其他的网络2 E4 j+ q3 \8 ^  f  f/ j% j
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    9 @, j5 L8 t! s6 a5 w; b9 t% Y+ y. T3 J    since = time.time()
    ' W# \3 Z. M8 A, `/ n    #保存最好的准确率
    7 o" W2 Z# _5 _6 ?    best_acc = 0/ x. a! }9 B4 |- F! b, G  L; l1 ~0 c
        """- r2 |0 C& R( v) ^  n4 H
        checkpoint = torch.load(filename)
    $ M5 C* V' R* d/ a( p    best_acc = checkpoint['best_acc']
    " J) l$ P6 ?' Z2 v( e3 [    model.load_state_dict(checkpoint['state_dict'])# s7 y/ i/ O$ O) D8 S/ }- e" g1 a
        optimizer.load_state_dict(checkpoint['optimizer'])
    2 v$ ]7 K% }% l4 U" T# U0 I% |; K    model.class_to_idx = checkpoint['mapping']  \  a5 c3 b7 _9 k$ b; m8 v
        """
    % A( E  R- J& t5 j" L8 B) u    #指定用GPU还是CPU
    % K! }& k4 M1 C5 {4 u% a    model.to(device)) e/ ?3 @3 f, y: R  Q) l7 j
        #下面是为展示做的
    $ X9 `) e  X! F& m" v5 k% u9 q    val_acc_history = []
    0 s4 N& D  d5 o9 ~! K& n: X) @* F; h    train_acc_history = []' Y  m: V; d! f+ ?+ V7 I! Z3 P& p
        train_losses = []9 @1 M4 x; j8 Q- y/ a
        valid_losses = []
    $ ~. r) W: G8 I$ I, {' X& o    LRs = [optimizer.param_groups[0]['lr']]( \; A8 M: s9 V* |8 _& q# N' w8 ^
        #最好的一次存下来9 h0 |; w! U" _7 d0 m+ j
        best_model_wts = copy.deepcopy(model.state_dict())
    4 X- x' {! A1 X3 D4 b
    : f# D" x( C5 Y1 n3 ^7 y    for epoch in range(num_epochs):
    % t# m/ d. ^1 a6 E3 }  x. a        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    6 U& Y8 L2 b. u' h# D: `        print('-' * 10)- A9 c4 E" Z$ f  c. ~2 p* Z

    ) y9 \6 e( Y' G" |1 H0 k        # 训练和验证& {: l2 _3 y$ g" N3 U4 u
            for phase in ['train', 'valid']:
      y% m- |9 v! q$ P  P            if phase == 'train':
    : X' @$ e; d1 x                model.train()  # 训练$ `% {7 k4 N7 o
                else:
    , J" y% R8 U( p2 a$ I0 N9 H                model.eval()   # 验证# i$ m; {' s& [
    & ^+ h, S6 i1 ?- G0 @# N6 r
                running_loss = 0.0
    2 q( S- k8 A' G$ h            running_corrects = 0! A8 P! l& M9 {' T( ~  p

    " V% w% h, _; o# f* _2 [8 l            # 把数据都取个遍
    $ b7 @' T( T0 @2 P: f            for inputs, labels in dataloaders[phase]:) U' o5 B% ~! t: D1 k# b
                    #下面是将inputs,labels传到GPU
    8 i7 ~( p" P( @0 k4 l! Y                inputs = inputs.to(device)
      X1 b5 W# |% ]                labels = labels.to(device). Y; P* ]  O8 N; g6 U2 [8 h

    9 X+ f/ [9 r8 {3 t                # 清零
    1 C' z# f8 {& j1 C                optimizer.zero_grad()0 {+ p! t8 |" r( W
                    # 只有训练的时候计算和更新梯度
    7 z- l4 ]+ h$ `: O6 M: K                with torch.set_grad_enabled(phase == 'train'):/ m4 J7 d) v6 e
                        #if这面不需要计算,可忽略, X$ T1 i+ E. W: H; l& t6 I4 q
                        if is_inception and phase == 'train':+ @; q8 v- M) M
                            outputs, aux_outputs = model(inputs)
    8 h% J/ h! r. D5 `) o) M                        loss1 = criterion(outputs, labels). I& X% B5 n$ }* N: o+ r
                            loss2 = criterion(aux_outputs, labels)% g5 Y3 ?# n, N% L! z, q+ x! b/ ]
                            loss = loss1 + 0.4*loss2
    % c! x4 ^2 H* Y6 g) H  t4 h. x                    else:#resnet执行的是这里( @+ c- @" r& ]0 M6 G% Z
                            outputs = model(inputs)
    # \( L6 H# K1 r                        loss = criterion(outputs, labels)! H# k; y/ }- L9 M8 D
    $ q: `) U* l7 K+ [2 v; O
                            #概率最大的返回preds, l, \2 m1 s( E) F
                        _, preds = torch.max(outputs, 1)2 X* ]. {, `7 W7 O; O
    # [/ {+ R6 k$ W3 ]# y' W( [$ h8 f
                        # 训练阶段更新权重
    $ o+ S' q; Q$ M# w, Q. n4 H                    if phase == 'train':
    ; A1 O# A( ~# T+ k. s4 ]: K                        loss.backward()& p0 h0 h3 g5 T! P* A
                            optimizer.step()2 T9 k( @5 }8 L+ b0 X0 B" v

    6 a3 c& Q% Z& t                # 计算损失" n( l! {$ b' |5 U7 G, U
                    running_loss += loss.item() * inputs.size(0)
    & W3 K8 Y$ a$ `' d$ z9 J0 E                running_corrects += torch.sum(preds == labels.data)
    . w. K; R9 K8 H1 ~! @
    ' x6 R. G' q/ P4 C) n" Y' Y4 j, z            #打印操作
    8 @8 `3 ^% I! f7 \! y# j            epoch_loss = running_loss / len(dataloaders[phase].dataset)
    ( O* o! g6 c0 I$ \1 r            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset); k! q, V" _  b- g! F: t
    6 M* H3 u1 h- x: i  c# f7 n( z4 K
    / m/ Y1 x! f2 k/ i/ I5 _
                time_elapsed = time.time() - since0 @- U9 s! G: d- e) R3 z6 l3 t
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # _; p- e/ }, p. ]  E, X            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))9 x- N  L( r) y6 e' m& I
    3 }) V  |5 t8 L$ U3 E' o$ \
    4 a3 J1 L% g: z; o/ O
                # 得到最好那次的模型: ^, H6 {1 X, c: t  k
                if phase == 'valid' and epoch_acc > best_acc:% D4 u( B+ y2 X& V
                    best_acc = epoch_acc
    & `2 O/ n. K* Z7 I  F1 f                #模型保存& ?3 F! p8 O6 x* ]" V% ]4 \5 R
                    best_model_wts = copy.deepcopy(model.state_dict())
    ! o2 k  H9 X$ h9 r                state = {/ {) i+ l- }! N- b
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数$ t. O) |- }6 P
                      'state_dict': model.state_dict(),9 @( F3 m. [6 Z- g3 f  q
                      'best_acc': best_acc,1 R* V; e( J, C% P2 D9 ?, ^5 Q" G
                      'optimizer' : optimizer.state_dict(),
    6 I6 F7 @) \. Z% b+ A4 x                }
    # K5 z: |4 A' U& x  x) M0 m                torch.save(state, filename); g$ _. v, U( R
                if phase == 'valid':: k! K  Z) D4 u/ }  A# O2 S
                    val_acc_history.append(epoch_acc)' I9 @5 m& e/ r' i2 p4 A7 z
                    valid_losses.append(epoch_loss)
      {3 _3 u; n" A; I; R                scheduler.step(epoch_loss)
      q9 r9 h  o* t2 L% ~4 v* ?: S5 Z# ^            if phase == 'train':- a7 A5 c+ Q: a0 B' s
                    train_acc_history.append(epoch_acc)
    ) ^# e6 K( Y) Z6 }3 V                train_losses.append(epoch_loss)
    / U' J% b7 U5 D
    $ f9 y. a  Q0 U  Q( y        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))! T1 y# ?9 S$ K* c) D6 V6 A
            LRs.append(optimizer.param_groups[0]['lr'])8 Z/ n: b  V- n# ~
            print()9 b0 f+ p" L9 w) o

    # U3 x) J0 [/ {' q    time_elapsed = time.time() - since
    ( @: h; W6 V  H8 q  R/ j6 V3 ?/ B/ f    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    * L+ N$ n3 p7 O9 E2 Y    print('Best val Acc: {:4f}'.format(best_acc))
    $ I) m6 [( _5 C* `. n$ g& X. w8 S5 Q2 n* `
        # 保存训练完后用最好的一次当做模型最终的结果
    $ R% t6 E5 w4 E2 \    model.load_state_dict(best_model_wts)6 y$ o! a: H5 i. z2 j( ?( j8 E1 Q
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs , N7 X- h" t7 W) R8 L. p; ~$ y

    ; R- A/ r* l" e9 v+ K1 s1 ^, b7 Q4 o, {+ H* k/ r
    1
    $ t' q/ h- x' S# Y8 N* u4 X2 V26 `  K0 U# R$ y$ Q! Y
    3! N& q: M$ D6 y  n- A/ e
    4$ z% \8 S% D2 v& V/ c) L
    5
    9 S0 u: |& ~) ?9 n6
    1 K" V& k4 Y: H0 n3 U7
    % e' o! B* L' m84 ?% ~0 C) T2 {2 P# g
    9' W- O/ w" ?3 ?+ l2 V' Y
    10/ B4 z+ K3 @( h0 p0 V# d
    11
    + ]( ^3 `( e: s. j12
    7 q# v+ Z! F3 Q5 o3 i; N13- D! M( Y* @5 S4 Y$ K$ F
    14
    / M( F5 m+ e+ _1 q, `4 _15
    & s1 O& f3 i' n  r: J5 Z) F: b16
    * @! p' D! i" X5 b- M( |; L17+ \$ E+ _5 D7 n
    18& F  A( T& w$ B6 \2 g: H
    19! C, {& j; Y- n) B: T0 Q$ Y# Q
    20
    9 t! v  t! T, k# q7 G- n7 l) R213 h$ [, `) L6 d& U* X
    228 t! K/ _  ^8 ~- S9 w
    23
    1 G$ z1 n* a8 X$ |+ C+ _245 J! o4 t5 b, @7 K+ b
    254 Q( R# ^+ L& U0 ^- P1 G/ E1 Q0 g
    26
    " q' f+ q1 ]( |# _278 P8 S- @/ o9 Q6 p$ ^' R& h
    28, b# i8 Z8 k( L# w) |% R
    29
      \: ~9 m! z" Y/ o: |302 B( s* u- |, W7 D4 h6 w0 ]. G
    318 o5 B' F' Z/ H; U: a
    320 O1 a" ^) X' H( @
    33
    ! H( W5 a* B2 j5 I0 |7 A) ]. R34
    % x# r! g  E) I6 x35/ b( D2 _$ k3 ?; b$ k% n
    36
    3 J9 q0 ]. n/ L/ V37
    + c; [0 V5 H, O6 c38& l' e% s$ X" F6 F
    39  a2 z' u& P0 |1 m! k. n8 `
    40& [1 [0 j) z# P5 E9 g- j
    41
    / q& L9 p8 w  X3 W2 E* f' E2 q9 n42+ r1 C5 y; {8 n0 w
    43
    ( V2 C2 E$ @2 T, U  o( Z4 C4 Z, M; q44/ b7 n  i7 @; S* Y$ m
    453 {/ m$ S4 `7 j
    46! u6 [' n. X9 j$ v) @4 W$ ?; C& P
    47; Y$ V) }( ^5 c1 v
    48: ~/ P5 Z2 S9 `5 l+ d
    49+ r6 `: C5 o8 a. `* J3 a
    50. R% |- }: ?+ x5 S- i, @$ f- s/ x
    51
    : O# r3 E( W# E9 s3 a- R3 D; z1 c; I$ r52* H' `8 J! k3 _  Z2 G8 B
    532 |* C& n1 J: Z* x, r" Y
    54
    & e* _* ~! }3 F2 _0 `2 J* E55" \* ^: V' F& M
    568 c8 K2 U, p9 @4 C1 C, R  x
    57. j2 N& x+ \$ {8 W% T& n7 U5 E
    58
    6 l/ |3 e( M' m3 T9 C) ?9 X59
    5 U! ]$ K/ a; y. M7 i* e60' Y: C" G: i' V( k
    61
    / \/ U) H* T# |620 d8 }( W/ [, `! ?: b
    63
    7 \& e- U) ^, ~7 I3 I% {64& Z# z. J& W: _5 ]" H* G
    65
    ' m, l) u- A0 b  Y+ _7 W* _$ C66
    2 ]1 h3 u# p' ?- o0 x. _! d7 [7 [67
    - ?, n" b3 g! g/ J' Q68
    ) h2 z. }* G! T% X697 P5 E/ h( _; n; N) W  c, r
    70
    ! r# S. C$ \, e; P713 O! V8 }' w9 `" g* }7 w& t5 R
    72* ]  M3 [/ e/ \& t6 r7 q
    73
    " f# J4 E9 P; E* O- P7 H74: \" _' X  B0 b
    75
    0 ~. E5 q. g$ Q5 m5 z2 |( V- I76) L$ P# w* @0 X3 \
    77+ T" `0 g9 D) d1 C, G
    78
    2 J9 s4 [' Z; I9 n( e79
    - L* e+ @  l$ W8 Q3 {. k9 z$ l* u80. |. n7 N2 f0 N0 F2 i( |
    810 x* L! n! x( _* O( j9 b
    82
    : O9 z  |$ ]4 q2 ]2 T83* e, @# u/ M5 c' v
    84% c: P" N) |0 @+ O8 X
    85
    # a; p; X5 a& s7 ^7 {* ~, t8 f1 q  `86" i7 D/ _4 S9 x  n
    872 L2 p; H- Y9 N6 L! @* i6 J% v, I
    88- B2 d$ a$ |4 T7 ]* z
    89: ^: A1 m+ T  @1 n0 \  ^0 N
    90/ K* g1 F  M! p$ i6 P; ]* L
    91
    5 ]7 J2 m3 N# l1 Y3 {0 _3 @- R92' D9 S) N' r7 y( `; C7 P' [; f6 J
    93
    ! G# ~3 Y) s& m/ t: E94
    & ~: x# h& L7 o5 u. E. n; ?/ k7 R95( M; ]1 U+ d. a; j- {1 E' r
    96
    5 Z6 [" C+ c& ^97
    6 T4 @9 K, ?9 D" Y1 o% I1 D98* e3 b- n0 V" d+ G" Q
    99( _; {; ]" R  {5 X1 ?8 x
    1008 p$ x3 y: k; h2 m4 l8 z) A! s
    1016 k* V' `- [" `
    102  h. l/ [* {+ L. M% j& R
    1039 {2 ~* `9 o- _1 w2 ~  i
    104
    , m; f' N% E1 u$ x. g% |# i; B3 @105
    ) [) e/ b2 }+ B$ Y% p2 J, G106' k# A+ H9 D# u8 e
    107
    % S3 W/ f. P2 ~108
    # L# t* ^$ S  U$ p: G# q* z+ n) ]109
    , z3 ]+ |& \, H: `$ p8 X1105 `+ l: h% t+ l3 G" Z9 U
    111
    8 M" L  h4 X' p4 z$ s2 J# |4 }8 V8 H112  _8 ]# r1 O! ^8 |& I
    7.2 开始训练模型  c7 V" L- W# @5 R
    我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次5 j; t  ?: K7 L4 G+ W- n9 H
    8 `; l, ?9 [9 K5 t, Z9 J/ Z' J
    #若太慢,把epoch调低,迭代50次可能好些
    . z  H+ U+ n8 y9 X1 \& I#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合& R. c% G: e+ L6 O- H% A# Y# \
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    6 H& ]' J+ e$ ]9 O. s% e( I! v5 ~% e9 [  s$ x7 B1 ?8 Z
    1
    9 W3 d# C) D+ P3 i. H( i/ R2
      s( W- b4 B% ~  U5 w3
    % w( r# ~; C% S1 e6 z; s4, w5 w  Z: e7 b/ k5 x8 P6 s0 I
    Epoch 0/4
    # F0 m& G8 a2 k5 z: g/ W3 C----------
    1 {' M5 P0 h8 L& B( u; @" MTime elapsed 29m 41s8 `# D0 t( \' J9 U% {
    train Loss: 10.4774 Acc: 0.3147
    6 ]# o/ C. q( T- W4 f/ F4 c. X5 STime elapsed 32m 54s
    ; }9 \9 K. O) E# D, x. B5 O/ Bvalid Loss: 8.2902 Acc: 0.4719
    ) W; J3 ]7 |6 \3 C' z0 BOptimizer learning rate : 0.0010000
    ' B# Q. {/ I+ h3 K; M1 a# ~) `5 W( A8 ?4 `2 @0 x& z: n; o
    Epoch 1/4
    2 d( F5 X7 K' Y: v' C. W----------
    7 b4 E7 R1 q8 ~2 M. @5 {1 r* G5 @Time elapsed 60m 11s
    2 r) X7 [, J/ D' h# t" strain Loss: 2.3126 Acc: 0.7053
    0 @+ [8 ^  [: P9 V6 R0 PTime elapsed 63m 16s- S; v/ J) M. B" s& T+ y$ B+ w% g7 e9 L
    valid Loss: 3.2325 Acc: 0.6626
    1 e9 x! @3 G, vOptimizer learning rate : 0.0100000
    3 c5 `" `% o: m9 B: p
    : @) x" Z! _6 bEpoch 2/4
    * s- d* ]/ ?$ |) |----------
      `. ~1 T- Q. n) h( Q6 t* e1 STime elapsed 90m 58s
    " i$ {( q  ~4 v; Z. ~train Loss: 9.9720 Acc: 0.4734$ e: t. [& r( u0 h0 m, p
    Time elapsed 94m 4s
    8 K% G/ l1 Z" svalid Loss: 14.0426 Acc: 0.4413
    * O) s. r, L% z; M$ BOptimizer learning rate : 0.0001000
    2 n  M7 Q' V2 U( S* q/ v3 C  g* A8 |! K8 a% Z8 w) D
    Epoch 3/4
    5 s6 G# ~0 k  z----------
    ! s+ k3 d) j4 y$ LTime elapsed 132m 49s
    . G! S  M5 L3 P, a2 U" ~train Loss: 5.4290 Acc: 0.6548
    & X; R5 M0 w  q  s; }6 R6 n. JTime elapsed 138m 49s
    0 c4 _4 M' w! K2 q) I, p# s& ]valid Loss: 6.4208 Acc: 0.6027
    7 D3 @. ?5 J1 q4 wOptimizer learning rate : 0.0100000% x) O" f5 G9 I: M' Q; q

    6 ^, R$ {- S( l4 `: dEpoch 4/4
    : ]9 V1 u2 A8 \) ^----------
    % e0 Q$ D' A6 n  A1 |5 p6 uTime elapsed 195m 56s
    0 K7 P  d8 A- V; ttrain Loss: 8.8911 Acc: 0.5519
    ! }5 \, ]6 ^$ ?2 l0 z7 [Time elapsed 199m 16s% b# J! k' z7 M0 B4 f
    valid Loss: 13.2221 Acc: 0.4914  b2 }( r" B0 S' M1 t6 N4 N
    Optimizer learning rate : 0.0010000  z, z- s, ?: S  p3 {# x3 v

    0 {' d8 _9 Z/ ^; m0 wTraining complete in 199m 16s1 c( j6 k0 a  p1 F
    Best val Acc: 0.662592
    ) r2 L! Q. y5 h. q# B- ?/ j
    8 x- N  `5 X2 P( p) D1
    % V" M, f3 {" @9 _. f) {3 m2 R20 {6 U8 K7 M3 o
    3
    6 A* D  e: _+ O4
    # |* l4 H* r. C% _  |- ~5! O) K9 W. F2 i2 M( _
    6
    : g; H1 I0 Q1 S78 n; ^. A) i# x
    82 I* W5 w$ b8 O  ~
    9
    ' L+ o! i* @3 L5 e5 h4 i10
    $ }  _# e- m# x( J+ Q11
    $ m. f$ w% l- S. R# f3 p; f, k( [( c124 x; S! _, @! ~) q$ Z: q
    13
    / X. U8 K5 k. z+ I# E( _14
    9 w3 {* R! m: Q* I0 e3 Q  D% r( _* k15  ~  k. b: w- ]" P9 y2 ?; C* `
    16
    $ R, |( r- I$ l8 s7 O17% Z& \5 G8 Y$ J4 B; n/ s! m2 {
    18
    / Y0 k; h7 y, v# j2 N" L( I/ U193 q0 K# `$ T9 c% _8 M0 q0 K: ?
    20
    7 @3 ]7 P6 z4 k/ K9 T" y$ P21
    ' J% l# i/ o' v; c22+ L; ^1 I. p! Q8 O* q
    237 O1 Y0 V! P8 w  ~8 V2 ?4 I% H# U
    24
    $ z  P) ?. {6 Y25+ H3 `' B' s* Z! z8 t) b% G. g; }
    26
    & d1 O) y! x1 g& }( |27; p( d$ O+ p2 b- i* d/ m% E' p
    28
    - O4 h7 o5 x- T! T29
    , I8 G; U* R' H30- g' Z* f  E# r. T7 z4 v- |) T
    31! x1 v0 o/ H  C" @6 z* E
    324 |' G) b+ f* r3 @* r7 R9 `
    33) }6 d; n) f- \1 y4 V  L
    34, k, G+ x9 P  x2 t
    35
    ( _, S$ S! [; P+ H36
    # O- g) Z% a0 g; A; j9 I379 \6 f" `1 J5 ?
    388 r' L/ l% w0 J
    39
    " p/ a2 ^6 N* x( q40
    : d, \. D7 h# B3 L% J41- S0 e' b; h1 H/ l+ b6 u/ k$ b. m
    42
    ' G6 }/ C! M1 h. B9 O5 {7.3 训练所有层
    " C  Y$ U  g4 ]+ b# 将全部网络解锁进行训练: A4 z' J! ?+ C$ V5 T
    for param in model_ft.parameters():- \! |5 i( d# }7 J
        param.requires_grad = True' Q7 w6 h: b* s

    8 u! W' t9 l2 V& b/ o2 C+ Q$ R0 `: s# 再继续训练所有的参数,学习率调小一点\; S+ A. J( L2 k2 K
    optimizer = optim.Adam(params_to_update, lr = 1e-4)) T# Y9 [7 Z4 Y" d$ E- t4 c; v
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)' q3 ]) T1 f  t0 Y, {/ R# p

    1 K4 B6 [7 r4 y# 损失函数
    ) |0 `) B/ U0 z0 k; Y: Ocriterion = nn.NLLLoss()
    & x9 y0 R/ n* v) E) G  W7 K3 B$ c1 E/ L1
    ; j3 r" J! p# T2 A4 R" V% b" n2
    # U5 |; d& ~" w) w. w" Q30 J, Z& a: \  b1 Z, l
    4
    2 ?( Y7 @8 i9 _9 E* t  |5
    . k% D7 s6 p  \; t" Q6# u2 }% w0 K# [' c
    71 m) p2 n, `& @7 ]
    8
    + B3 P6 I& v* q2 T) B1 |* x9! c7 V: G+ N6 A0 A) G
    10' S% a: |/ M9 Z3 P: K0 }" Y
    # 加载保存的参数
    ) t, H2 T4 U% X9 v8 X# 并在原有的模型基础上继续训练
    " r$ e; q% X0 r2 q3 }; C" L* C# 下面保存的是刚刚训练效果较好的路径' K3 o' }* e4 f. Z2 f
    checkpoint = torch.load(filename)
    " {; o7 E% d+ r$ X  }) `2 Ibest_acc = checkpoint['best_acc']+ g+ M2 [. S% T1 k5 m
    model_ft.load_state_dict(checkpoint['state_dict'])
    5 V, L, F" W+ boptimizer.load_state_dict(checkpoint['optimizer'])' f/ S1 A& A' A& E1 ]' E# @
    1( d4 V' y( \3 h- J% M
    2& m% G5 A1 M2 X+ O9 U1 u
    30 s* O- D/ X- R$ {( a3 f
    4. L7 B% |' s4 w+ F& B) w
    5
    1 ?& Q& o" x* [# o; G6 ?1 k9 s$ `6
    3 v. ^3 ]: _, `/ m$ }% ^73 {, o9 H6 c' y  f' V1 Z0 Q+ q
    开始训练
    $ Z2 V) l" P( K1 G注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考! K# s: ?( S/ |, C
    ; q$ p  r- T  I4 }0 h) @( A
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    ! G* P7 s  n8 w1
    : j* ~9 o$ v; FEpoch 0/1/ h! F- t/ U3 B0 h. w
    ----------  c% F- r/ W% g! N
    Time elapsed 35m 22s
    ; J% {$ ^: c9 |/ p/ z: ptrain Loss: 1.7636 Acc: 0.7346
    $ O7 b* K% K# g. Y5 |5 q/ r3 |Time elapsed 38m 42s2 G: u9 b  J) g
    valid Loss: 3.6377 Acc: 0.6455
    " @) G% s9 p  l( n. vOptimizer learning rate : 0.00100006 Q) _) d! F* r3 \6 f
      x7 u( y2 H' q/ K' N9 p
    Epoch 1/1
    5 T" A3 B1 s/ B' E( j2 m+ y/ s----------
    ) m' C7 `$ u4 q  LTime elapsed 82m 59s
    ) h5 v% ^- ?' o( z/ F& N; Ltrain Loss: 1.7543 Acc: 0.7340
    ' O4 s+ T* s* _! \Time elapsed 86m 11s
    3 `8 }3 E+ ~4 h8 i! nvalid Loss: 3.8275 Acc: 0.6137
    . R" c& }9 |  q/ `  H/ dOptimizer learning rate : 0.0010000
    # V7 ?% T& E: I0 L) X& @# s7 y  n2 D1 L( j  a" _- r; L  ^; {
    Training complete in 86m 11s% c. _: R' ?! B8 m- ]) O
    Best val Acc: 0.645477) g% G/ g% F. h; ]

    / {% N7 y- {: V. J! _1 v9 f4 T1
    # d5 N7 M  t6 R1 ^) ?2% G6 ]& _+ E+ e
    3: K+ h, F, N3 r9 k
    4
    , y/ V! D9 n1 h: R55 I" n9 A3 e! K  ]5 F' O$ @+ }
    6$ C2 q* u* `) F. D8 O
    7
    # D6 K' E5 ~# b3 R7 T# `8$ G% X" j* g  G; N
    9
    " |# S- K# T9 K4 s5 j& l$ R10- c9 x$ s* X4 q  x( G
    11
    - d, x2 x9 V2 L4 c2 H12. R6 N& T! R7 H( \
    13" B; [+ Q( s+ U" P2 _6 y
    14
    0 R( C5 N9 X' D' Q: J15
    7 E7 x2 |) |* q8 z1 i3 J4 m16) o# l5 E" B0 ^
    17
    4 b' g; x! }% M" @18/ G1 o: S7 k. y& ?& |; }" J3 Z
    8. 加载已经训练的模型$ [( _" M4 R6 t6 w. d
    相当于做一次简单的前向传播(逻辑推理),不用更新参数
    . h& U0 H/ f. ?! t! P
    : G' v! j' N- J4 o6 {* s# fmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True); A9 [& H* `# w5 P, A

    * L6 X1 Y, U# O" f; f# GPU 模式7 f7 T2 ^$ [% O$ o
    model_ft = model_ft.to(device) # 扔到GPU中
    6 |) t& a! @( f  Y. R  U* h: I+ M2 z7 d6 x5 s9 U
    # 保存文件的名字
    7 ?, m/ q8 N; x% r8 yfilename='checkpoint.pth'
    - W3 I7 k: a5 }, S5 ]! O
    2 H; @0 C0 @' \- F" L' ~# 加载模型# J2 ]9 a( `/ w
    checkpoint = torch.load(filename)
    4 S4 y% g- @, T$ j5 pbest_acc = checkpoint['best_acc']; K6 n! |/ a7 Z# x( _+ a" H
    model_ft.load_state_dict(checkpoint['state_dict'])/ X# f2 g$ c4 t& g/ ^; q. y
    1
    0 v; J# F5 S( g8 d  S2
    % h% H5 Q, U" M2 p# P8 ?, ]36 e# b& g$ B4 ?" a
    4
    1 K: h# j! h1 {' G, X% k5/ H( n2 J5 W# x6 e4 R
    6; u: a7 F* r' j9 k+ W0 B
    7% F1 Y7 T. l* M
    8  d$ ?+ s1 k5 G
    9
    2 B* V# y( f2 @1 `5 l9 [& V) w/ A10
    9 x7 f, D% Q8 H+ _: \. |2 {; S4 s11
    ' V6 W$ B$ |& G123 |) h! ~) m& q0 Z- ?* O
    <All keys matched successfully>$ Q8 |( m7 W! j/ `' r
    1; Q, M8 \; \2 }& E! r9 k
    def process_image(image_path):( ?& S7 h" M7 e% j6 T  w& X3 V: M9 x
        # 读取测试集数据
    8 p  e" a" s  p    img = Image.open(image_path); U1 ^; f4 [, o' b# F& M! O6 F/ p
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    ' d  l1 I- ^2 a' g6 U8 T# i    # 与Resize不同
    5 L0 O0 @* V  q$ x    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小; Z" K: ^  {9 }0 r; J) c
        # 而且对象调用方法会直接改变其大小,返回None
    ; _, C# v* c8 ^% ~$ q    if img.size[0] > img.size[1]:: B* m' _0 W3 Y: R& E9 u$ ?
            img.thumbnail((10000, 256))
    * k  K6 g4 p7 d+ Y1 W    else:
    6 v; P" X' ?# p# r        img.thumbnail((256, 10000))$ A' g* o1 ?9 J! y2 s: b2 `2 x7 ^& B

    8 j) B9 K0 H) J  W* J4 j# P% Q. R    # crop操作, 将图像再次裁剪为 224 * 2244 a, p3 O3 z: [; m: g
        left_margin = (img.width - 224) / 2 # 取中间的部分
    ' J6 n! C: E7 i0 y' H; R    bottom_margin = (img.height - 224) / 2 ( w4 d5 N! b3 x0 p  N; Y
        right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度$ Q. X, j- _$ W
        top_margin = bottom_margin + 2240 H" b5 P2 |, g

    , B4 J: |' v1 f2 K    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))% S4 z( }* S& L, g3 V" q1 P  E! y

    1 W7 S1 Z2 H" c    # 相同预处理的方法, {! B- Y& f- n/ t
        # 归一化
    0 V0 _$ D/ C2 E0 @7 i    img = np.array(img) / 255) @% R' ~3 Q6 y5 j: w
        mean = np.array([0.485, 0.456, 0.406])3 B8 Q* _0 b! D) s
        std = np.array([0.229, 0.224, 0.225])
    9 t2 z3 O% d2 x4 \    img = (img - mean) / std
    1 P' Y, W* f$ N# U/ X: p
    ) C1 T1 ]5 y1 I! @. u    # 注意颜色通道和位置, ~: }) H, q  R
        img = img.transpose((2, 0, 1))" o  Z& ~  j. [8 b
    % f* f: ~* N" e0 \+ K" }! X2 P, C( V
        return img, Z0 h( s  R$ U. J

    ( V  i7 W) P( o4 adef imshow(image, ax = None, title = None):
    " }: R1 D2 |5 q  h# v4 c7 H    """展示数据"""1 e# z: [" u  z' n  M
        if ax is None:9 Z7 s1 I2 P. U2 r1 S
            fig, ax = plt.subplots()
    6 U) i. P" P3 c# _. B+ W% s9 A  i9 I8 Y( d3 S  U
        # 颜色通道进行还原; x: k2 L0 X. l# s/ p% J) Q- b* q6 Q% M0 {
        image = np.array(image).transpose((1, 2, 0))4 @8 t9 o; g% @( ]+ K0 G  w/ a
    0 f/ R2 ^1 H) ?" Y% r. e$ v
        # 预处理还原3 f5 f4 t( f: G, a
        mean = np.array([0.485, 0.456, 0.406])
    6 z: ^4 l5 m5 ?( Y& w3 \    std = np.array([0.229, 0.224, 0.225])) F4 P3 R# A; Y0 s: B2 v
        image = std * image + mean
    ' l8 O9 L% ^! r7 k2 Z' d    image = np.clip(image, 0, 1)8 J/ a9 w" N# }8 b
    4 }9 a$ Z' n0 c  o$ F2 c
        ax.imshow(image): d* ]. Y5 X1 D, E* q% |6 |8 a
        ax.set_title(title)
    ! E( |! R) n; q6 X4 D
      z, }1 s  p  G3 f    return ax
    & Y/ H) R7 d5 L5 A8 g0 c
    + s+ Q; h  N3 C; |9 R! Mimage_path = r'./flower_data/valid/3/image_06621.jpg'6 J& X' j9 B# N; a  z# |5 f# g$ t) V" S
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理; A1 Q* |( E2 z1 w5 s. ~. r
    imshow(img)
    3 u( [/ Y% D" ?  U2 }1 `7 d8 i, t! N, y9 c
    1$ ?6 [' t% m6 f, j0 }3 Y1 R& _9 l
    2
    6 I2 n/ X- r% }3
    7 Y* K+ t9 k! a4+ M% {& m! R. D
    5
    * @8 K6 p+ `& L5 O+ n& A6
    4 {, F( X0 O2 f& b7 t# R6 P7 F7
    / |5 o; l: B* k: y8 G, T! J6 ?8; Z. b) y7 N7 _; t7 p
    9' T, d: Y+ T+ W' {# Z2 G
    10' r; o; i: X0 p. _  b. P( o
    115 y; v" a/ b# v$ b  L$ W: z
    12
    9 l  W2 S! c; w7 `6 L; I+ k6 ^13  j, s7 Z( h9 U) ?0 @2 M
    14
    9 Q" e4 \7 B( r7 F) s) f& R159 }7 A" m/ j& Z  G+ e4 M( g- m7 G% w
    16" E& a) |8 U( ?$ k
    17! g+ r! c& N( [7 `, L1 G
    18) ^4 q3 P0 S5 r" ^$ N9 g* M, b
    19
    2 r: P% a# ]1 Z3 S20
    , i4 x; `; Q7 s' _21
    0 _6 p( A/ M# u# M( z9 S+ _22
    " o) o+ Z3 O1 d+ j23
    6 B7 b6 F. h( ]$ H- o1 S; B0 a24
    3 N. l: _4 Q  n) y25
    - M/ z  Z- V' c" ~: ^3 B' J1 U) t26
    ( Q" @. j: o  k& Q. T. b" U27
    ! N) X9 C( ~) k; b/ }+ ^- s28
    8 N, Z& |( j9 d7 W4 ?% D1 {4 `29
    9 [8 R  {: a7 y/ E% C30/ A2 ~# e9 j( ?! j
    31& B' q1 ?- z" T. T
    320 F6 [3 A4 v, q( J3 u. f, h7 e
    33* t" O. @) k7 i6 E9 P: E
    34
    6 i3 V. c" E5 s$ K( r35
    9 a1 F- V2 x9 B; [$ h3 Y6 O36. V# B6 e( a( T5 F" U% l8 p4 o
    37
      o2 T: Z9 S+ q7 l38/ j1 K; S! n4 M/ {) x8 m
    396 u; d) z" b. D) f
    40
    : ~' S! Q' x( t1 X/ D41
    1 }  r/ F2 R+ L' h" F42( m3 ?- c3 h8 ]
    436 R. l' H2 R3 E% H2 `0 i* ~0 o: O
    44  r( s% Q( A; H4 o1 R3 L
    459 X# H. O" e6 {$ P
    46
    , i' }  R# z( S. F6 B47. j. w' [5 K- c5 D$ B3 G1 n
    483 U5 }' J+ w, u' O/ L7 |
    49: o3 ]6 |! ?# \" {( o; G( [. D
    50' t; Q& e' ^, \$ x6 p3 j2 Y+ R9 h
    51
    4 b$ x* q& A7 A2 J& L8 E528 @) ]  o; C$ l& @% ^" _
    53
    , r% N# F" [5 g5 ~0 q* C542 B/ F$ t" Z, B. |/ M
    <AxesSubplot:>0 p1 k5 N1 e# u) O
    1
    4 i( e3 B* g! u" w# @+ c
    ) q0 \% P! }% H1 l上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    6 w; k7 O! l7 |, ]. R' r8 b2 D* b* R" R; ?6 f
    img.shape/ m0 a: a6 T4 N
    16 t1 Z: j( k7 Y1 t- u1 D
    (3, 224, 224). C2 a: I( U. o5 c$ C8 l+ @$ z
    1* c7 z. u  H9 r5 k2 b- W8 G
    证明了通道提前了,而且大小没改变
    " Z6 L& P2 b2 Y3 A0 G! q
    ; ?9 B) U& p+ N9. 推理
    7 L- T* ?' {+ }& ?' S" l7 f% Oimg.shape
    1 m7 s( R* O4 o! j% _( A% i/ U6 K+ M% {0 B' z3 ^
    # 得到一个batch的测试数据
    2 N6 x- p( v+ S) {1 X8 pdataiter = iter(dataloaders['valid'])' f) X& _$ Z4 o$ \3 V+ x$ i
    images, labels = dataiter.next()
    1 `9 n# x+ n: j; a2 l/ y! a/ F$ [" \' ~  E
    model_ft.eval()" i5 B9 A8 T" m  L6 [* h
    ( @9 x; D5 h: W7 H4 u7 V
    if train_on_gpu:
    9 i7 r) r3 i9 C7 d* }    # 前向传播跑一次会得到output5 d2 {- v* Z7 _7 X% c, A. Q  X
        output = model_ft(images.cuda())+ V, A( _1 x: g* e* W- a
    else:
    1 J4 S( p0 q# P  M9 K    output = model_ft(images)9 x5 `3 m& i8 p* l

    5 J; r& q0 o. }9 u& P: @# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值, s- M! o; c) I" w6 l. G
    output.shape
    ( B& l9 J; x# j0 O2 X+ b8 B; y9 r6 l3 m$ R% Z4 z) [, w' x4 a
    1
    " b. z+ O% Z. q6 C& r- S# M8 @2+ D: K' C) g+ _; \+ C, G$ u
    3
    4 Y+ K3 p" x0 U9 ]49 X. R/ K9 \8 U" a# {* e: {
    57 s& w0 t5 N. I% C
    6+ V8 S8 f0 s0 c9 ]) Y! `
    7
    5 ~1 O5 r2 q8 `% u" z$ ], @8
    9 m6 j. ]$ U- y$ M9/ [- r+ g. M' x, j* t& ]
    10* ~: E8 k; o; d' B
    11
    % J+ a5 S) s' @8 J0 Z12
    . {4 R( ~8 l4 i% m136 k& y1 q* [2 i5 U+ _
    14
    9 _( {, E6 X8 o  f150 i( }1 d# s' A# A; Y' P* x  D% |$ y
    16
      N' Q% N* w4 s& v5 ^torch.Size([8, 102])
    # T- h2 ]3 n) F7 o7 P6 m+ ~2 v1/ l6 ]* ?4 W" h; b
    9.1 计算得到最大概率3 b4 f' e7 r1 h" c$ [
    _, preds_tensor = torch.max(output, 1). N# m) o+ N, _/ J/ z- \, \
    1 P" a6 {! ~1 v2 e* ~5 `
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量, U& f, L: S. t  Q! B  W& p  L: a
    1
    3 c9 _  q; l+ o- I7 C2" d( q) Q4 J" U2 h4 F9 I% y: q
    3& L6 R- P  `% u7 n
    9.2 展示预测结果7 Q) u! f9 o/ n" t8 I4 {6 E
    fig = plt.figure(figsize = (20, 20))
    ) A. }  M1 R* Ccolumns = 4- k  Q) a* `7 L+ _( o$ N
    rows = 2
    3 N) `: a1 V, }# e: `% J2 ^# O! z7 |
    - S/ j0 C& z+ rfor idx in range(columns * rows):3 k$ n( l8 g3 A4 e/ y; b7 ]
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])2 w4 h0 r. K) n$ K1 r
        plt.imshow(im_convert(images[idx]))3 d  L  z) }3 I0 M0 j1 U9 Q
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), 5 o7 h# e" ^/ Z
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    & |8 G% a" ]- a) W, E' Pplt.show()
    5 X7 R* v# |5 Y+ e# 绿色的表示预测是对的,红色表示预测错了
    5 \! ~( W; y3 u/ y5 G7 ?1
    + C7 ]2 F: Q0 ]* J2
    $ `- G" G1 z+ F, x2 h3$ A" i' r8 s+ y- m
    4; e! m! q% s# m4 f; ]' O# S/ ]
    5
    % W" y6 ^9 s9 {! I/ T& U62 D, |  A( ~- o' L3 {: a$ {
    78 G+ A& }, S& s9 l( t/ \
    87 h# q7 z5 b& s  {7 D
    9
    , _6 Q# ~4 [4 S9 h; Q; z101 \) c  w  m" W* x. h
    11
    ) k# R1 u4 l) k8 o' ]) V- N( {: ]' n, m  y/ U& ^) H( L: N; |( E  M$ v
    0 Q% ?' }7 N4 G$ x* X+ P
    . M* O9 R/ {+ W$ i
    ————————————————
    " }9 V8 M3 s. n) t版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    2 x) e. {+ F3 s7 s; _) }, f原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    : n/ u( E2 ]% h, x6 P* q$ q4 [* Q. c4 u8 M- k

    ) s3 L* y( @/ X! ]
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-11 02:05 , Processed in 0.312658 second(s), 51 queries .

    回顶部