QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2757|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例; K! l. R7 U$ f8 d
    & y+ I7 S) E+ R; p
    文章目录
    * N( T3 m  ]" R卷积网络实战 对花进行分类- V; I0 @0 W, T' t" i' D( h3 q
    数据预处理部分
    1 U) O% h* g( f+ k网络模块设置) Q  s$ U* E$ ?2 |0 n
    网络模型的保存与测试$ D* r7 v& G" \* z; s8 X
    数据下载:$ G; [+ S$ l* S3 Y! P9 ~. K8 l4 `
    1. 导入工具包2 |% z7 ^3 n/ V' R
    2. 数据预处理与操作( s' m( t$ q8 h* b, w
    3. 制作好数据源
    8 ^* Z5 j! O1 `0 u3 [( b0 \8 K读取标签对应的实际名字
    , ^8 M# |. ^) [+ q- k2 p8 f4.展示一下数据
    * J! T  d+ _4 Y3 x7 l) L; G5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    , W: e9 H8 j; [6.初始化模型架构5 Z( j/ a& q+ L; m% V  p
    7. 设置需要训练的参数
    5 I2 j3 G4 K+ h0 }4 w7. 训练与预测" \( X: K" A- a0 I: ]7 m' U3 k6 u
    7.1 优化器设置
    # X' ^9 `( R3 n" |" ~& n7.2 开始训练模型
    ) H& H; |% t9 I7.3 训练所有层
    * h$ P. k. U  R9 U$ A; Y开始训练
    * a& h- u/ s8 X. n3 a2 f8. 加载已经训练的模型
    # N# e( l8 U; S. b9. 推理
    ; m  s5 z5 d) s' |+ k9.1 计算得到最大概率8 f9 O, H* N3 h: Z" K0 o1 T
    9.2 展示预测结果
    1 O6 {: a& [! R( }# Y# O" v写在最后
    1 `' I/ k! Y; T2 C卷积网络实战 对花进行分类
    * U2 l3 y# ]* _% h* ~# P本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    ) e& {  ]. y6 B2 x4 r' \' P
    3 b0 a! r0 D9 `' M! N3 I在文件夹中有102种花,我们主要要对这些花进行分类任务1 Q+ a$ M" k/ M- B9 ~. W
    文件夹结构
    : [( }( a7 {7 F5 {4 S2 j4 t
    % ~. n* q" A1 d. R$ \( G9 b1 lflower_data
    6 z' p3 s3 x, c0 F2 \
    + z+ X& z5 [* @5 z$ I/ ?8 ]train
    $ w% x& ^5 J; ^4 d4 j' ?
    $ Z; R* R8 e, L1(类别)
    ( `2 J+ [* z( p+ K2
    1 C- ?3 A/ [$ uxxx.png / xxx.jpg
    6 J! _) S  E; `4 n7 Q: d, nvalid9 _* R& R* {! x

    : h% u5 n0 ^6 B主要分为以下几个大模块
    : ^. W; Z+ l1 y% e# X1 ]9 ]
    3 I6 d8 Q; A$ H  V/ |数据预处理部分
    , n2 ^8 \% `9 [# C1 s3 w数据增强' o/ Q4 q! ^1 }8 i
    数据预处理2 u7 h1 R6 W6 ?7 r9 ~! _
    网络模块设置
    ' u% ^& I4 ]& T1 s7 P' E加载预训练模型,直接调用torchVision的经典网络架构
    & F7 I9 H4 l: M/ e因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务8 B) l0 @" l) }$ N+ S% ^) K
    网络模型的保存与测试
    % S0 w7 Z3 I5 n7 A模型保存可以带有选择性
    / H  X9 j4 R$ W: {3 J" y6 K# j数据下载:7 e) N0 }+ y% B; R. u6 I
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    . O5 V7 ?. x/ C+ i% c8 J0 q4 D: P1 R/ }9 |2 r% }) ?9 ]
    改一下文件名,然后将它放到同一根目录就可以了; H. [0 x# T5 g, ~
    0 q; y. f/ w/ s+ G6 z# F% Z& @) W
    下面是我的数据根目录
    # M8 A: [3 e9 ^8 g& C) _( T1 r) l9 }9 X0 J- b0 l
    0 r8 g! f3 A; f+ u& v: m
    1. 导入工具包
    & r* k: m6 j6 ^, B' r' v6 k* q2 simport os
    # w" l; r$ @& Q0 V  b! s) `import matplotlib.pyplot as plt; p) X; s; a: ]2 q5 l
    # 内嵌入绘图简去show的句柄
    0 i  q- r7 a1 k* `/ w$ A! F9 o( O1 \%matplotlib inline
    ) D  ~& o8 w) t) X" Uimport numpy as np
    / ~/ x- q5 H  B$ A2 Aimport torch
    ; E+ F& J  v9 o/ d" ~9 O& mfrom torch import nn
    5 U7 n6 ~0 V5 d$ g! k( N
    , y9 v- s8 x; Z% j% n9 H" Qimport torch.optim as optim
    7 K9 [2 P  l8 U# n. _import torchvision
    * r* C+ e" q  Z! Z/ X& Mfrom torchvision import transforms, models, datasets
    / P" _( W9 `' K# A! E, }* s7 Y; @3 w3 ?+ J
    import imageio( m0 Z6 l1 W" j/ x: l
    import time
      C% T8 [* k1 V; g! f9 oimport warnings
    : o& f5 y3 i/ D  x( \* G# ]& Qimport random2 U. Y$ E0 Z9 l& Q8 k. _
    import sys1 Z: h/ B& i/ l
    import copy  i! {4 v) _7 |$ z; e$ y9 ^0 k
    import json( Q3 K3 _! y% v
    from PIL import Image' m" [+ }4 N( e$ x; f

    # V: M8 W8 b- h( I6 x
    / ?  Y( t. B, b) M4 v+ X# P4 W1
    7 C1 i6 @5 [! l+ d" o) P# Z2
    : u8 O$ B( c0 [, A% t" M8 u" q) l" y3, w2 Z( z8 H8 {- b- o6 b8 {
    46 l/ h: a5 s$ Q7 \
    5: O& N1 F- V# o: C
    6
    4 M0 z% U; w8 L: v2 ]7 T# B$ T! B7* i- ?3 N- X( B2 D; W6 q. E
    81 N) ^/ ^9 O, b4 v& P( Y
    94 l/ m$ A1 x! @- \8 p3 |* G
    10
    ( G0 Z2 i* g6 ^: c; X11
    ( T: \7 l& V5 e6 m4 t" R12" g& z$ L- Z# S* S+ _
    13% t7 u( j1 G2 b( J  j
    14' L4 i8 `/ T9 V( t, u; o
    15
    $ s0 p! ^  J, H16: a: o+ o3 M% d. f
    17
    6 j9 I8 d4 E- M# V$ L7 A18: h" n' \1 K$ o$ R* j
    19! i/ [" r  N( b. e. |
    200 \8 Z1 m7 V1 o% P7 G
    21
    ' L4 i3 |: _/ E* _  Y2. 数据预处理与操作. h* _# q- f1 ~9 T
    #路径设置
    2 J- k& Z& ~+ Y" y# j( m1 ndata_dir = './flower_data/' # 当前文件夹下的flowerdata目录' y1 j" M; }7 \5 u
    train_dir = data_dir + '/train'
    + \' n$ L  g1 ?2 |1 v! Bvalid_dir = data_dir + '/valid'2 q' J2 }7 v4 l( i* D( U2 [
    17 n) a( I; k( `/ [
    2
    * k  n- ]3 ?& ]3% G: k) Y% V* q' \( x0 U
    4
    3 }. W+ U! j7 ypython目录点杠的组合与区别
    5 D4 E2 f6 B! w" v) @) y  X$ g注: 里面注明了点杠和斜杠的操作
    # {+ a9 O3 {  U' N. d7 ~
    , v9 R4 g, m$ v: j0 L3. 制作好数据源
    $ u2 a2 U; a& Edata_transforms中制定了所有图像预处理的操作6 p9 n3 J! z: L: G
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    1 r6 x+ l& \0 @& U  O& \. idata_transforms = {
    0 m: D/ Y4 u! d  a7 d7 f6 K: P    # 分成两部分,一部分是训练$ C8 A6 g: _7 b' V- D" g2 l( p# f
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间/ D& H' i  z+ O( B' S$ ]
                                     transforms.CenterCrop(224), # 从中心处开始裁剪7 |: y. j/ s; T- M% a% [0 |* a
                                     # 以某个随机的概率决定是否翻转 55开1 O& t: A0 u" ?% D  v3 A% m: c" A( @& V" k8 w
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    ; w% P# w2 P6 d( }( Q2 ^                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转3 u: }  {/ z! V( y
                                     # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    # E: _8 N# b0 k                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
    ; k, K  _. Q+ j5 J$ i7 C$ C                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    6 l  M5 A+ \4 j2 ^                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的+ o+ M" D3 v5 L6 e4 N; ^# V0 c
                                     transforms.ToTensor(),
      @) |5 q# _* a7 _3 ?4 r/ O                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    : n% k; b! z- M/ F0 Y7 P" _6 D                                ]),
    . n1 F0 B$ O1 j3 I( v. X% ]1 L    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    7 p+ m8 X, h" l. l5 {4 j    'valid': transforms.Compose([transforms.Resize(256),9 k6 j- q- k! t( A
                                     transforms.CenterCrop(224),
    ; v- ^6 `5 {; u0 u& g9 k; _" Z                                 transforms.ToTensor(),
    4 O3 j& [  q+ a                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    ; o8 I% ~; @3 t* \                                ]),
    . T0 H! O" T  A- f% N/ s}
    - Z- T7 g# N: m* E. g7 J( i
    2 y. n: ~! x  O" L1
    3 W' ]) x6 r  w9 i& Z3 v2
    ( d+ g( e! t0 D3
    ' V& X0 n5 z, b' [9 E( C4 m5 S4
      O! u; }9 g" F* N# ]4 U. j5$ ^6 h* C8 W: n0 R7 H: q6 m* B+ G; O& X
    6
    - d0 f* D2 P1 a' o' O7 ]7% {( m5 n1 I9 z$ v# \. [
    8
    ' C4 T) L4 Z) i) ^% F, b9
    / l% i/ P8 F; k# i, |10, j6 v/ y: h2 t$ u: K
    11
    * j# y# Y5 M- ^( |* m% Z1 `6 @9 }12
    # w7 i/ c3 h: F% F, C13& G: O+ D: i+ @( k( D. @/ l
    14
    # a$ v# p+ Z" z0 M1 k2 l/ j  F- ?15
    1 t: P$ `4 a( g4 A' [7 t16
    ) e3 A2 u  E9 Y4 ^( T+ }17
    6 t& @9 \* }; e& g. E180 [' i2 P2 B" B
    19
    , B# O& w; F6 r& j; a8 \& k0 V! X20
    & H+ s0 X+ ?# X6 `21, x  t$ t3 K% r, R/ k9 D( Z
    batch_size = 8
    ! w: W+ O: V6 Kimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    : H" f  z# o# [& [2 e5 u' idataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    2 q' M. s$ F, E5 m6 Idataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} ( B! y, G* ^$ d; |
    class_names = image_datasets['train'].classes0 x5 W  ^% S4 a( g4 ?# b, O0 r5 H

    7 s2 f% N8 ?" C% s9 S#查看数据集合' V" n: o, k& Y  t% m. M! r! g
    image_datasets
    . U6 n$ g. Y* s  b* `. F  q# I; f3 V  {) @% c. G
    18 W+ {/ n4 `% _: }
    2
    8 p. _( M) D7 o1 w% `2 k0 I8 J3
    0 }! m* V0 V* }* ?" U4
    " j, d0 v( [, |5
    . U9 M, n/ A! A5 h# O8 ]6! X$ g/ d+ I/ A7 Q
    7# j; r/ g! [, E' ?: ]" C6 B
    84 V: M! B! c" i! z. b" Y# l1 x; [
    9
    * f* B. s/ q2 e" D+ R' O" C8 r' M{'train': Dataset ImageFolder5 O9 L6 M9 n4 U  G7 ^$ R
         Number of datapoints: 65520 I/ ]1 v  T! W  Q! L
         Root location: ./flower_data/train+ V0 k( d5 o( g, |" m  N; v1 p
         StandardTransform, K, r0 M2 e* e  K  n) P0 X
    Transform: Compose(+ `$ A3 T, A! M) _0 ^
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)3 b- u, B$ _! A" j4 N" G# k
                    CenterCrop(size=(224, 224))) W; R. Q+ ]  d' Z0 @
                    RandomHorizontalFlip(p=0.5)
    ! P7 m& o7 _. P3 E- L0 b7 W1 f                RandomVerticalFlip(p=0.5)0 K% m6 A% [7 f/ o# |3 h" [" n7 r0 {
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    $ ?4 x% B" g+ x- e5 n                RandomGrayscale(p=0.025)4 x7 \% m9 m+ h5 \
                    ToTensor()
    ! Z& O; b/ |! s4 U- S. }  C                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) s  s( E9 M% K' T2 _9 z- h! p
                ),; m. E$ R! U  X" ^
    'valid': Dataset ImageFolder
      o8 r+ W( W, _# d2 @& ]& }     Number of datapoints: 818  U1 k9 s( f" p8 g/ ~3 W, {( u
         Root location: ./flower_data/valid/ m# U2 F; a" w; F- x$ r
         StandardTransform' x; l/ \: m( F, e
    Transform: Compose() N4 ?- O' z) a# x  |0 X) f
                    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)0 Y  k$ d3 u" d* F6 R
                    CenterCrop(size=(224, 224))5 Z( F7 V; c8 J* Q5 e
                    ToTensor()
    ; z4 }& I; [3 c0 G, V                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    2 B+ Y. f* L6 g- W, T' c, _            )}
    5 L3 Y+ V) h0 e4 I  @! z; W
      v# a7 F! A: M; l; M+ p# T6 G1; ?! k9 f  B/ w/ y' S3 e6 F( Z
    2- l4 a7 I1 C6 N
    3
    ( o8 p8 |, d5 g$ ^- S1 W4
    ( R! W  W# T0 r6 {$ N5  B( c; \. f! z
    6* x3 M7 L% B) @! X
    7
    - P' J  r' |5 }: I& q8
    4 }6 B: ?' }6 P/ D9
    1 X6 o4 Y, Z4 B/ ^2 u3 ?10
    $ m. T. w+ m& Q+ t11
    + y; K7 b+ t$ ^! \2 @: _12, w% R1 ^4 d5 v
    13
    9 x/ J1 v! C/ @3 d14
    . l$ B' r0 A4 O& `( n" n5 ~; b7 Y15
    1 n6 }" }6 _2 |4 \7 M16
    ) n/ b  w- |: e2 b7 |& S17; R; Z( ]& ?+ [. _* c
    18
    % S; f( z2 d# M19
    ' ^9 J% Q# t" X. B! d20
    . [: [# X: N5 @4 u+ m* B  z21
    ) n7 g* d  C0 k6 M9 \  @22
    ' L- W' n( L. C: c! I8 W" g23
    % `2 d" I$ Q5 y( C24! M0 `8 p2 x& Z0 h+ Y
    # 验证一下数据是否已经被处理完毕
    . Y3 N5 N% j' I; hdataloaders
    + |! Y2 \% z$ A7 x. i6 x1
    ) T9 h8 t8 ^9 d* x& p$ X3 }/ q2% T. u0 T) {( v/ N4 X% f. m# a; k
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
    8 W, {6 ]0 {. E- m% @( j 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    . G- U$ p) k, B( N* j1( i4 a6 E. N* N0 K8 x: z, `7 _- r
    28 s8 W/ Q9 e* k% b% u0 X' ~
    dataset_sizes
    0 a  F: p& K/ g9 w, U3 E1
    - Y2 W- N8 n, a: O- z{'train': 6552, 'valid': 818}' @: U% w5 I) K5 B8 i( h" V; W
    1. M0 `5 p" s: M+ e: |/ H- F
    读取标签对应的实际名字
    5 j! }& p9 D2 p使用同一目录下的json文件,反向映射出花对应的名字: l1 ?% P& Z* Q2 e# \
    ( s- M, d( Z5 r9 A9 \# [
    with open('./flower_data/cat_to_name.json', 'r') as f:  E8 K; O' i* [: L5 D) t
        cat_to_name = json.load(f)
    6 [2 J. A# @9 Z; l. r6 s1
    $ H, q( M2 s6 n! u! u9 ]2
    ; ^9 i. h! N4 B; W. Ocat_to_name, @' M  ]# T! b2 D" u. g; X
    1
    / C5 B$ W* Z9 S  ?{'21': 'fire lily',) Y8 n5 e7 z7 n0 P2 }
    '3': 'canterbury bells',7 j, @3 N2 W6 F/ `: p* B/ k, H
    '45': 'bolero deep blue',
    9 N/ h( \1 t  S' G% e '1': 'pink primrose',1 l: n; h( {. H3 d: g7 S9 A2 }. o
    '34': 'mexican aster',
    $ a% S& h4 l. u1 Z% L# M; n; I '27': 'prince of wales feathers',0 q! U: I: l) S2 L/ N
    '7': 'moon orchid',
    # W/ i" {7 m- o. \, ^9 q '16': 'globe-flower',
    ( A- a) B) q2 e  ~) T1 ?" E0 B '25': 'grape hyacinth',$ B) L4 v/ k4 C5 |! T2 V; Q( _
    '26': 'corn poppy',: Y" y' x/ D( H0 V; [8 X; c0 d* r
    '79': 'toad lily',7 `% L& ?6 Q& L' _( {, P
    '39': 'siam tulip',
    $ W( n* ?8 R; G& z7 i3 q6 t4 H- ? '24': 'red ginger',
    ) l  C  |3 Q' q3 _- Y5 t '67': 'spring crocus',
    4 R: `* V; F8 B8 h+ y '35': 'alpine sea holly',5 g7 x% K% H1 m6 P: `
    '32': 'garden phlox',
    - l/ @7 O5 d: n+ Q9 N# G$ E; J '10': 'globe thistle',; o! s; \9 b) y
    '6': 'tiger lily',! v7 {6 H( j! ^8 p
    '93': 'ball moss',
    0 A/ K: N. k& F. s; }( Q '33': 'love in the mist'," h1 w; C' {/ }, J7 \: d0 N
    '9': 'monkshood',
    & A- x4 i- z+ h. e6 g '102': 'blackberry lily',
    , J: R, n3 X8 p6 P8 n' s. K/ B '14': 'spear thistle',
    0 r  z0 k, m' t9 J( ~3 e+ E '19': 'balloon flower',: ?' K! B. l; n6 s; D6 o
    '100': 'blanket flower',
    3 V6 [+ \/ m9 e$ _3 x '13': 'king protea',
    + Y8 H1 \' z' G% p- b) | '49': 'oxeye daisy',( G) ]5 j: n) r. q/ t
    '15': 'yellow iris',% {! S( e4 A( f, J. {
    '61': 'cautleya spicata',
    ) |8 j  {" f8 b '31': 'carnation'," g2 g9 j) t8 C
    '64': 'silverbush',' G* h& r4 o1 k
    '68': 'bearded iris',- T% L" m; \# I+ y: X7 `7 y
    '63': 'black-eyed susan',& T2 o+ c3 g7 c. p& k2 s
    '69': 'windflower',
    ( z: l& R! [- y' z- P7 f& | '62': 'japanese anemone',
    5 l: g1 Z# I- W5 {# _9 y '20': 'giant white arum lily',
    4 r- P( S* l& X' B% J) V  ?6 {+ Y '38': 'great masterwort',4 d! B4 ?- T* _; l. r  I+ P
    '4': 'sweet pea',9 c6 T" J& r0 o) t0 |
    '86': 'tree mallow',
    2 O. S7 ~& M1 z  K. m' @ '101': 'trumpet creeper',
    2 g# g$ B' O0 r2 ~- P0 @; A. C3 }, z '42': 'daffodil',
    ! w* X) K7 s" V( g% L0 M '22': 'pincushion flower',
    . x$ T  i/ ~; t* F '2': 'hard-leaved pocket orchid',
    # R2 M0 L+ s3 P; R6 d '54': 'sunflower',3 t2 u' M( U- d. y0 E9 z
    '66': 'osteospermum',
    ! D; s8 T6 P7 C" P/ k! R '70': 'tree poppy',5 b3 j2 g% u" j0 s/ N6 C( w. |
    '85': 'desert-rose',
    4 b6 _2 ^0 J, A% A. b  Z- R '99': 'bromelia',8 P  u  R$ E0 Q( ~+ F) y
    '87': 'magnolia',
    3 v8 q; {" p9 {, Q6 u6 y  u; u6 t '5': 'english marigold',
    / ?2 B# N" b$ j '92': 'bee balm',
    5 l  {$ [* @, U* [7 {+ H '28': 'stemless gentian',
    " z6 H! S! L8 `2 ^% Q) P: M# l '97': 'mallow',9 j- x+ ~9 Q; X' m; G
    '57': 'gaura',
    ! o- n$ v; ~- x- R '40': 'lenten rose',3 h! z) q1 b' U; `+ ]- Y7 |$ T4 g
    '47': 'marigold',+ N: d5 ?; C3 e9 c& P0 P
    '59': 'orange dahlia',
    - b' g$ W3 F0 O$ T& ?; N) ~: x: Z '48': 'buttercup',
    8 v' \7 `% @% M6 r '55': 'pelargonium',
    ( n5 {0 K- M/ f* f$ W. I '36': 'ruby-lipped cattleya',5 B) [1 u8 P0 _+ ~+ e5 G
    '91': 'hippeastrum',
      s, L1 o) |" b( `' B! G- l1 H '29': 'artichoke',2 t3 h$ }0 \! |6 h0 e
    '71': 'gazania',
    ( K% k# E+ q3 F! o3 L# R '90': 'canna lily',0 T" V+ N# O' h# A4 o* O
    '18': 'peruvian lily',5 d: d* ], \6 E2 g$ \" }- k! u5 M( F
    '98': 'mexican petunia',
    ' ^$ a; f2 k* \' M' ]" S* ~* F) N- ^ '8': 'bird of paradise',% h1 p7 [# a0 F5 o) [
    '30': 'sweet william',
    - G9 x) O6 c" Q% K6 i+ g/ ] '17': 'purple coneflower',
    : _! M* a6 L3 n* e( Y, `( l! @ '52': 'wild pansy',* e9 y3 Z, a. R& j: a, _
    '84': 'columbine',
    + q" V/ e1 A2 e* n '12': "colt's foot",; q2 ^3 g+ j$ n
    '11': 'snapdragon',3 |4 s; l7 w/ o3 z' b- f
    '96': 'camellia',, m+ b; x, v" z+ h2 k
    '23': 'fritillary',% v# U6 m- Z9 M' x/ e
    '50': 'common dandelion',1 o0 J0 L: y6 o. I
    '44': 'poinsettia',
    & V, K/ r8 g( [( [$ @' F '53': 'primula',
    4 y5 {( u: I4 u '72': 'azalea',
    $ x* X# u. L& t3 h1 G '65': 'californian poppy',% V4 P' l- d6 M% E0 V
    '80': 'anthurium',
    3 l* _* Z: c+ W6 H' T '76': 'morning glory',# \/ C# F! U2 A
    '37': 'cape flower',8 n( {: o4 Y: L7 |
    '56': 'bishop of llandaff',
    ! f# R, y, l0 I6 {) M '60': 'pink-yellow dahlia',
    % z' d/ N5 [5 b1 C# ^ '82': 'clematis',
    - V6 D! Y3 l, U& b '58': 'geranium'," ]/ P9 x$ N, V
    '75': 'thorn apple',# U5 M' ]( S% w% |! R% M( E
    '41': 'barbeton daisy',. j- y$ I- R% f+ u
    '95': 'bougainvillea',
    ) i# N2 _( l1 G$ o+ m0 L/ }2 J" z8 J '43': 'sword lily',
    ' O; B' m8 r0 d* |; J '83': 'hibiscus',
    3 }5 V- v/ A5 I& [* \. L) B! x7 @# x '78': 'lotus lotus',5 W- l: a* o6 x) N0 Q) S: Q3 J. X- r
    '88': 'cyclamen',
    - T0 G5 v3 p( s8 I '94': 'foxglove',
    , e; |/ A6 u/ T" ^5 ~7 P& O '81': 'frangipani',
    ) S# F- e  W0 |* G  V2 n% u' Z+ g '74': 'rose',: {! H7 b1 O7 g5 O; ^0 f( E$ ^- k; |: C1 R
    '89': 'watercress',0 N/ m# \8 q3 u% C. u* m. l$ n7 H
    '73': 'water lily',, X! F( N) i2 U
    '46': 'wallflower',
    - J7 i/ s! o6 ~$ `) {0 S '77': 'passion flower',4 \6 p& r- s1 H5 m& |! m8 U2 ~
    '51': 'petunia'}
      O# M% r1 G& H/ w% r$ z- w2 N0 Q+ s  T1 {; v: d6 |# |
    1
    4 Q) c  m& e1 b5 i5 ]2 g23 K: P$ ?) x1 H: A4 s* B( J( ]
    3
    7 s$ h3 @6 |, p9 ]4" k- [0 G& X4 R$ K# Y! `
    55 w/ F% u2 L$ u' t, U' X: P
    6
    ; n) h9 o: F5 O& Y- ^73 ^$ Y7 D3 M- J
    8
    * b. k( d6 A0 a7 u: |  ?8 a9# V- p: Q3 U' ?' |% ]* f2 G8 A
    10
    % t: k5 @. o0 s, L9 e' ?3 R11
    2 ~, |4 R5 M3 }0 ~4 K12
    $ g, J' C) N3 T7 `% s8 m/ c1 c( t13) ]1 V% p9 X$ K( ~2 ]: s
    142 b8 v& Y. i6 T: J( W  x6 ~
    15
    - Z3 u( c' {2 l/ W) m% {& c168 q* h$ p2 l+ e1 K
    17
    , J+ N" j  B6 W. ~! O2 m& h  s18
    ; b1 E+ \# u* e9 S" `19
    $ R4 [" w! t, e203 P4 |: M' c( t' v. q/ I/ t- S
    21
    ; @, k- z, f. I) U22
    6 y( \% O- _+ r- L& W23
    / G) t5 Z1 a. U; ~, C# u7 [7 M24! O9 d7 D$ R5 I% J9 B; Q, S
    25: b, h; S" V, f+ N. y8 H$ n& i
    26
    , E$ c; R+ h* ^7 t' B271 k! e1 n, G, J( s+ I
    28
    & }" v( Z; R, O8 i5 o290 p7 V8 j$ l3 Z7 B. ?5 j0 e& i+ a
    30& y. J3 U; Q$ `/ z; h; I/ `
    31
    5 G0 v. [0 }7 w7 I0 u1 t; l/ l324 |% V$ M- n- ~- m3 y- R
    33& s/ F- N& p; p% r6 H1 |& a
    340 |& Q" y  o0 d
    35
    . x8 \1 Q+ Z. h3 F) }$ S36
    - F1 t9 i$ J: V& W37
    ( B- [6 l/ d3 I. C+ j+ H' r8 Q385 I$ R$ y3 L9 s+ @" ?* N; l
    39
    / a# B4 e' Q+ w" j& I408 G& s% x) J) f# E+ r' E6 y
    41' V% L6 j* X" G- B. x
    42
    " k; t! n0 I! E& v6 z43
    . H' I( k# B* G  S! O44
    - O- D2 n  R6 n$ F" s  x/ _+ G45- l/ L& V; _- L8 s1 J
    46
    8 J- U1 ?- S" t47
    & B$ u/ J( t3 R0 c, z48! P1 _! X7 `* `5 l/ w
    494 s' C) V: u8 b" X+ B/ z3 q
    50
    8 C! F; N5 c4 k( n4 e51( O% |, H5 L: e& K$ l6 c, N
    52
    ' B' U/ C; }  N: e0 E53
    . y% h7 n- _, Z. [) W54
    + z6 z% w( g3 n- J1 F1 a( u! m* r( p8 G55, X* V( b+ `$ c  _+ J
    56
    1 f7 t, j# l9 C, k57" f- i2 X, i' }
    58
    & x: n2 y( b! O' e& V+ y59
    6 {0 ?+ N1 v  ^' A6 V60
    9 q1 X4 D$ H: n9 z6 ^612 v0 J' c# a1 x+ V8 y8 d4 ^5 C
    62
    " |( A$ C$ M4 Z, T7 O) s63
    6 S  J3 H1 q" k64
    / R# N( S4 j  j5 b- z$ P65
    . D4 e5 B  s. N664 _% j3 V0 X6 k
    67
    / }! D& K& d- ^% [8 v% M68# k4 p) y+ c: U% t6 W7 r
    69+ K8 b8 S7 A7 c5 F
    70
    " D. r; E. A. }- e8 T+ u: d* j71
    6 Q9 }6 {: A( O* m- T2 C72
    7 ]! F5 G; w4 i% ?8 o9 z$ c+ S73
    ( i5 J  C( ?4 x* D2 ]# J6 l% a740 c3 @7 v; ?; j( I
    750 D! |' N' R+ \
    76
    * u8 j0 ?6 m( `% h# v& y8 [77
    3 s# C( J2 u3 Z78# a& s  [+ ^7 n
    79- K: ?0 f4 q* J
    80
    ! }; ]8 H9 [5 G1 b' o6 \819 |! B# _, I9 b: T9 _
    824 ^. Z1 g$ f- [/ n- Z  {; B, e* B
    83* E' T7 `# q) N
    84
    2 H+ F& K: I; b/ A0 p9 r& ?+ m9 a: D85
    . {( ?0 ^4 g$ n/ o8 A  M3 Y! Y86+ ?7 g( }% G8 k8 Z  _0 [
    87) T+ E8 a9 n. t
    88
    * N- E3 j1 h1 ^# S+ @! V$ j89& w! S4 Q2 j6 s
    901 p. u' g. E+ b- e6 S
    91
    ( ~6 ^( {+ ~1 g2 `! u3 g! ]1 p* [92
    * Q4 o1 e% Y3 Z93
    0 @) h. J. D6 g+ `7 y4 z: l94
    ) O4 e/ h& q" d3 H0 S) `95
    $ b& S7 E8 ^# x96: p7 D* U* l) c' m0 \  |& _( D
    97
    * o, u) M3 l2 j. |5 b98
    1 r  T0 o* \8 [! T1 P- K, T99% z/ `( U* p8 p* O; b: A
    100+ ]9 Q* w# c' r5 l( j& t
    101. z  Q( |2 B4 u7 ?5 @
    102
    9 m6 ^2 G" c- i9 ?4.展示一下数据
      R% _# c8 h  B8 {def im_convert(tensor):
    ! W! ?1 k3 R% }    """数据展示"""
    & i/ L9 Q% {: q    image = tensor.to("cpu").clone().detach()
    % U0 ~. n# s; o+ x* n) N2 D    image = image.numpy().squeeze()
    8 h$ ^! ^7 P! s$ m$ v, i    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    * f6 n. v  f% J3 }2 O1 |    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)' _2 H9 R5 e- j/ j, }. C
        image = image.transpose(1, 2, 0)* G2 u0 X7 B( b1 E! K* ?
        # 反正则化(反标准化)
    8 d" F5 \4 N" {    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))- H/ `' j# s( \
    2 {. n) H: v5 _& C6 g
        # 将图像中小于0 的都换成0,大于的都变成19 Z0 Z; b" n" c
        image = image.clip(0, 1)
    ! O1 K) x% R' b3 z. C/ X5 i% Y9 o  h. \8 B  S9 r
        return image
    # K- w5 W0 J& l5 X& @: }. s& o1 G1. e; ~; }. j$ m. t  g% C
    2
    2 e2 O1 |6 E' d3 [/ r  R$ j3
    : v" z% v; G4 [/ Y4) q, F9 ~/ |9 d7 s6 N0 x
    5' |: x; Y' m' X, k9 l0 X6 c/ ?; O
    6
    - r- l7 K# d# u' W( }7* J+ z0 e2 a; f! Q" q
    8
    * t# T9 O" Y- O94 {7 n' d/ z" \3 P( d# n4 |- q. N
    10. C/ ~: p$ Q' ~# I
    11) G# p) @9 A/ x0 A. y! w+ m$ m- T
    12
    4 T# M+ X6 g3 t/ q2 F" }- f$ D% e: [13' w  p; `3 u4 H
    14
    / F; w8 o! r% X& J& V0 w# 使用上面定义好的类进行画图, |7 u/ j; t6 e$ f8 T4 A( v
    fig = plt.figure(figsize = (20, 12))
    / ~. x2 x2 y8 G3 F' c4 _# ^5 L8 Mcolumns = 4( O0 k- U' p! _8 g/ P
    rows = 22 u4 l6 g' Z+ _5 V

    8 |7 |- t* F2 P  }& @1 O3 |# iter迭代器3 f: p5 {! V5 b- Q+ Z3 Q
    # 随便找一个Batch数据进行展示
    ) |# ~; X: c; L* m0 F- m# mdataiter = iter(dataloaders['valid'])
    2 ~0 s4 F. z+ N  g: O+ C* @; Uinputs, classes = dataiter.next()
    $ f  s0 `" R1 n4 w$ k1 N* P: Q
    ! T& r* l; }% h7 S! H8 f" b5 ifor idx in range(columns * rows):
    0 o5 K/ ~- A7 }5 }$ e  |0 U    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    ( B% Q" M# W& o8 ^; N    # 利用json文件将其对应花的类型打印在图片中
    : X5 v" t/ Q! s- h( q: q+ ~    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])# f  W+ I3 l$ h
        plt.imshow(im_convert(inputs[idx]))
    / i' A, N  ~* rplt.show(), }3 y0 Y$ Y: u0 N( S. d

    4 {( r  W$ |+ v1) H. O3 N5 F5 F! n9 S: d
    21 w6 I# h0 F8 L, E7 H5 a% E4 D  O
    3$ K, D4 H( V, J8 \% m; o
    4
    : _- Q: s) C- y5
    + U, V! |: Y& T7 U- M6
    ' @# W  k) M1 U! ^6 K7
    ! H; z" k9 Q. D; G8
    ; c2 L- b# ~. T: B9
    ; s% v8 x3 h+ O6 O; }10
    , r( [; o; a. H1 }# `# ]% ?11
    5 v8 q* F# G; [! n2 O/ N! O128 o7 s8 S: C9 y2 |" i+ M; j2 e
    13* B$ B  n+ f( `) d% V
    143 b2 [7 t! w6 N; O5 J% l$ q
    15- ?: M2 G& i( v& E  K
    16
    % Z2 l' m1 _. Y4 t* U, ]) U* Y/ W8 B9 W+ b
    8 c* Z7 d  W; o
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    ) R7 i8 ^+ y: umodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    7 H! s5 X. q/ `7 K! g, B# 主要的图像识别用resnet来做- y5 N$ Z3 [8 f' F' k  d* o( _
    # 是否用人家训练好的特征; B3 x. ?# M# |0 s- E
    feature_extract = True
    5 q' ]5 T7 k$ U7 _" c1( o' ^8 y+ P" {% P8 t' u
    2
    ( _; Y5 y: k: l2 J3
    / q" S( C! d  a/ b4
    ) u  N) p4 e0 E3 l* X4 m" M0 f# 是否用GPU进行训练
      ~8 x  a3 l/ C7 o3 qtrain_on_gpu = torch.cuda.is_available()
    9 o" J% D& F5 c1 o6 P5 \  S& S
    if not train_on_gpu:6 R7 T; |& I- W6 a/ H
        print('CUDA is not available.   Training on CPU ...')# C5 f- t) I7 L% X
    else:* {! m9 K' a2 A9 B
        print('CUDA is available! Training on GPU ...')
    ) @' h! u" ^. O+ |3 R4 |
    / m* b1 V$ I, _: U/ \& p) p2 vdevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')7 C" r/ |. \/ G' q3 z
    14 n* t1 b5 ]! G/ ^! x& L" I
    2( I9 Q$ j, [. ~2 B3 ]
    3
    + s6 }" e5 }, t: B4$ C3 y0 N1 d" \2 s
    5- {: l% B( V* W8 v
    6/ G9 ^- v/ I5 [0 _
    7! L# j0 J& w+ m3 h
    8; X4 e/ G+ `/ l# r9 I9 {$ I
    9: N" v  T; p: y  ~6 s
    CUDA is not available.   Training on CPU ...
    4 a# r- l  Y4 k9 ]' {* Y5 J0 T, \; w1
    ; j  Y; V/ A8 U4 k7 ^+ S! R# 将一些层定义为false,使其不自动更新, N0 J0 X( t8 n, R" a
    def set_parameter_requires_grad(model, feature_extracting):8 s0 V9 ~  B$ I
        if feature_extracting:) m' `" }8 W! X0 |  b
            for param in model.parameters():
    % @4 m1 \- [3 R( x" n" d            param.requires_grad = False
    0 X4 H6 H: E- w1( C4 M' u4 h0 s& X4 E' N6 D
    2* c7 B8 I1 h  v9 u2 B$ n
    3
    - G( S( N. B/ [, H4
    1 H' _) `4 h$ {, a5 m56 N9 m. c9 u+ f3 A& {) s* W$ V
    # 打印模型架构告知是怎么一步一步去完成的
    : ~% m3 S/ ?4 o" |7 [% `% t# 主要是为我们提取特征的
    0 ?7 d$ l8 x7 a! m- X9 w0 Z. m' |8 ]6 d8 c4 W
    model_ft = models.resnet152()( G8 h+ f, z- n. a& `1 |+ p
    model_ft
    8 \2 O; g# `4 ]' b% ?18 w( P8 D! T, }& o4 G
    2# ~$ u5 z, H2 A" o6 m! {5 S
    3
      t. J2 z9 F+ w, y; ^4& ^" {/ D* X/ q% O
    5, h/ I$ _  A- u& V) n7 b- p
    ResNet(( E& T/ l0 I. o
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)* f5 i8 X1 ^3 t8 B( A
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ) W4 I- ~  L0 ]- H& G  (relu): ReLU(inplace=True), z3 x  e+ x( U) }+ J3 G
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    5 T; S* h8 M1 \) I( g; \  (layer1): Sequential(
    , T( @& m+ i$ f( J# g1 \3 ^    (0): Bottleneck(
    : I9 a, i1 Y7 z- \' U/ i  Y      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ! x, U0 h6 S# z6 W) N" U) p      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    + f5 R2 b! r" X3 A; V7 i      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)% J, `$ \& u5 o3 G& [, _7 x
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ' M  Q6 r" A. O9 s4 f( ~2 O  X      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    , W) ?7 G( q* @) I" i; d9 s      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)9 [' N& T2 R1 ]3 q% Y# b
          (relu): ReLU(inplace=True): d* B# }: ~' F( x0 r8 }
          (downsample): Sequential(, y) Z  h3 u7 a% t- I1 ^
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)' Q5 T) ?3 A3 b
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)# E4 y4 q4 O; x1 D& j5 m
          )
    8 W. c1 ]7 J: p/ B! J2 @6 |1 [, ?    )
    - r, O, u2 G6 i0 s7 ]- J, ~6 z; R$ G5 O中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
      y' n2 T& |4 I: }/ ?" V6 z    (2): Bottleneck(# {4 h0 s) F% W, `) E
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    $ P$ S8 m8 ]7 B+ J: m6 V/ R      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    8 y) b0 P; I7 u# g, N) {      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    4 ~7 ^9 h5 c2 o: X7 i0 I      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      U! H4 K2 r: Y. g  s0 b      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    1 e3 ]' j: o6 R4 O8 U      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    . W: Q3 K' C1 X) p  U: X      (relu): ReLU(inplace=True)
    ' u& X# N- `! o' c$ k: ^, O( ~6 m    )
    9 u8 X* x) K( l. Z# M) D/ @& {$ e  )- X; q9 N& w( O1 Z8 M8 G+ W  c
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))  T  f- ~' ^: W5 \
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    5 W- N7 a& k4 f2 A* m)- w3 M1 c- z0 h7 t: m" {' m- t4 E

    . y% y9 m2 |1 F2 g) U1* N' C) p' `& D+ d
    2
    ! m4 n5 D( ~+ o/ B  K( q2 _3
    4 L+ [+ P6 a2 b2 j5 i* S/ b/ \2 l40 D& G. @8 _* F* r& N, N
    57 f' O8 X# T8 o) Y. L( _) t# X
    6
    - b9 j, q& b5 T7
    " L) X: l; X4 ]% {/ l" z8
    0 _  [  t. b. o5 H* [1 l9+ W% x4 D5 u) I0 c/ T
    10
    8 }+ n) X7 s: c4 U6 F11
    ; Q* |) G6 Z& }12
    / ?0 a( G2 N9 y. f& \3 H13- {# D+ T6 p  ^, x" G5 L; m; L
    14
    9 {4 N7 j9 ^- f" P! ], Z& A# x15
    9 W' A- i& x$ P) b4 Q# t16! ?% k5 P7 L9 Y5 d) {$ j4 U+ [
    17
    / A% `6 S9 T4 _9 ^189 F# B) b! K. h. R
    19& Q1 M$ B& U% o2 G! Z
    20
    + R7 R2 v* a5 s. v  u212 D3 Z1 _; }# k* G7 T+ O
    227 w: c/ `3 |- O+ b) n
    23
      ^: a6 V2 {- n24
    0 ]% @" _  I9 Z& b8 X1 S25- X& g' ?: m% ]% Q4 d
    26; Q5 Y3 `7 p2 r- C
    27
    ; k' C3 ~5 z3 f" j) T" j$ ~2 l( e28
    9 z+ j7 `8 X2 s, C6 A9 Z293 \: b9 u1 @& {/ E! u" i
    305 H1 c: ?' G8 K* T. `# f
    31
    : I' X4 O1 _* ]7 \) b$ g32
    0 U6 B0 O" q" p33! [; n. t8 W' I, `% E7 z
    最后是1000分类,2048输入,分为1000个分类- C  c. t" i! ]& G+ c! U) s
    而我们需要将我们的任务进行调整,将1000分类改为102输出' d8 [5 i# m$ ^5 a( t( s9 u

    3 Z$ I) a& i* n6.初始化模型架构
    1 v; ^& ~9 {% Z5 g3 O: @# f1 b步骤如下:
    % t$ a5 j* [) M- n& {# V7 M8 d9 Q
    8 s; z5 U" i3 v0 a, [- U/ T0 R. C1 I' O将训练好的模型拿过来,并pre_train = True 得到他人的权重参数3 Y- F- j# X; B$ m9 W1 t
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False): y, O7 K- X! C2 K  O8 B$ [
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
    - L5 h( ~& c. C- n, K  Z% ]官方文档链接% F1 o" F& o( }$ w7 [
    https://pytorch.org/vision/stable/models.html/ \2 @4 b. [' E/ r& L

    / p# N( O$ U. V# 将他人的模型加载进来6 ~8 J+ F! I! ?0 \9 @
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):2 U* X& A! x- _3 r6 o
        # 选择适合的模型,不同的模型初始化参数不同" ^0 c1 |( x& \+ ~) |: v) ^0 n* |! n
        model_ft = None
    % z$ m1 {( t* Q9 |    input_size = 0/ b+ T; o& ^- Y$ X+ B' Y: V, O- ~: h7 W

    * [9 `$ D- A% ~% N3 i    if model_name == "resnet":3 L' X, S# }8 S4 Y1 L! P" s+ B
            """( C; a, S9 L, \, W5 a, B6 E7 Q. d
            Resnet152- u- Z0 O! t8 t. T
            """$ t- T5 ?& O5 E' X" ^6 a

    4 P+ Q2 {' d+ M5 K* ]* ]& h* J        # 1. 加载与训练网络: S5 O. i- K8 d& ]
            model_ft = models.resnet152(pretrained = use_pretrained)* {. s" P$ N) Z
            # 2. 是否将提取特征的模块冻住,只训练FC层
    5 T& h/ |$ m! |: I7 ~8 y3 d: T        set_parameter_requires_grad(model_ft, feature_extract)
    ( n4 F8 g( O  ~2 `' ^! z        # 3. 获得全连接层输入特征5 C% l5 p1 i  d2 {" i0 }# ]' ^. T- G
            num_frts = model_ft.fc.in_features
    $ Q  a# |2 Q4 F        # 4. 重新加载全连接层,设置输出102
    % Z1 R' b9 x( p7 M. Y, z$ k5 w# v        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),7 u+ X  `8 M" C$ J
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    & Y# b3 T0 l! G& w6 x* @1 l" B# k        input_size = 224, W8 [$ C. Y% S- Y. _

    6 T# d0 k* [0 {- Z  C    elif model_name == "alexnet":" E1 x! K; T! s) C
            """4 f* y2 G1 E7 c1 J! I; g+ n
            Alexnet6 h8 c; U9 @6 i# j; i' X3 Y
            """! l! `) r0 w& N4 ?" U% h: S* ~0 \' }
            model_ft = models.alexnet(pretrained = use_pretrained)% L. ~% Q- q% E# E
            set_parameter_requires_grad(model_ft, feature_extract)
    ) L  K$ @9 i( f4 ?, j) V+ |7 W  e7 Z" q: x4 @
            # 将最后一个特征输出替换 序号为【6】的分类器
    ( Z! k7 E8 Q. O9 y5 Z  H. V        num_frts = model_ft.classifier[6].in_features # 获得FC层输入
    5 z9 T5 [# c7 @; o/ m0 V: R6 J        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    ! ?# d2 d0 g2 C, n        input_size = 224
    # S; v3 D# [; K& y6 r# y) W" t# U' [+ I
        elif model_name == "vgg":
    5 }( h5 L2 A" l5 ~7 c        """
    ! a& f8 i( g: V# H2 n" X: t        VGG11_bn8 ~! u. z9 e% h/ o2 g; z7 {1 z! s) q
            """! e# O" `6 @& U% I5 ^/ U  J
            model_ft = models.vgg16(pretrained = use_pretrained)! Y# c" J/ v$ P% E5 S4 N% y3 ]' ~
            set_parameter_requires_grad(model_ft, feature_extract)
    # y0 I; S  s0 r! [5 Y- X. d        num_frts = model_ft.classifier[6].in_features6 s) Z- b5 `* I! A9 Q* q
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    / r* B/ o0 D3 i        input_size = 224* S2 S5 Z% P8 _4 G; H: Q6 _, N
    * U- Y; x7 `8 t( W
        elif model_name == "squeezenet":4 ~# C8 j+ G4 Y# H' @# o
            """) T8 y8 E+ g& G
            Squeezenet3 ?  N; I: a8 A! M( l6 y
            """
    $ ^  |1 j  R" i5 `, J        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
    ' \) W7 J8 A! z7 j4 `, \        set_parameter_requires_grad(model_ft, feature_extract)4 A5 V2 O; s: `8 |
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
    " L4 L5 V5 N' x, H- z. X# o9 V        model_ft.num_classes = num_classes# E+ v5 O( B8 s
            input_size = 2244 \& H8 Q9 i, ?- [9 e! h8 e/ K

    5 |3 b/ P. J' I    elif model_name == "densenet":
    $ T1 |3 {2 \. |3 ]        """
    9 D( N- O$ _& ^8 d        Densenet
    ' L- ^0 i2 t5 u5 C; M        """
    8 `% Q7 [9 z% ?        model_ft = models.desenet121(pretrained = use_pretrained)
    2 E) v: o, a/ n5 C        set_parameter_requires_grad(model_ft, feature_extract)9 F, w% `! m, K( h" b; c. Q% s
            num_frts = model_ft.classifier.in_features
      I; N( z) r) `2 \  Y) c' [2 w        model_ft.classifier = nn.Linear(num_frts, num_classes)+ ^* x3 D, ?0 j% e
            input_size = 224
    - O9 o' U. z0 ]( }( P$ i  M( }
    ' m, |3 t: g0 B, D9 }9 J    elif model_name == "inception":
    . U4 z8 }( `/ p  v/ y; h1 q- |        """
    ' Z' ?0 s" D' }+ s4 o7 L% L8 N        Inception V3& k& W6 T+ o$ x; k) O
            """# N8 l3 d) t4 r( y2 W% L: n
            model_ft = models.inception_V(pretrained = use_pretrained)
    7 m" J' k( `9 i! Q) [) e        set_parameter_requires_grad(model_ft, feature_extract)) j2 F  e" o/ h/ {0 ~7 A5 h: N
    - a* g' {! K! C( k, o3 k
            num_frts = model_ft.AuxLogits.fc.in_features6 o( }6 }6 D. {# _+ |3 q
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes): U; i3 H/ ^, a6 }/ O% g

    ! A8 L# D2 }4 C" _3 U) C        num_frts = model_ft.fc.in_features
    / R4 ~- W; K, u        model_ft.fc = nn.Linear(num_frts, num_classes)
    " G% w' C# T6 q) a1 A        input_size = 299
    ! ^7 L/ F, K, l
    $ I3 t( r& q$ f; ^) I2 u    else:# C+ V4 `: G* U* {: j& x0 r
            print("Invalid model name, exiting...")6 j2 P6 v* H5 c$ N& ?% n2 C: `
            exit()1 Y+ C6 E! B# I# R3 s& E' L; T1 f
    8 c& {. N: q; f7 w( g4 b* a" m
        return model_ft, input_size
    , O7 ^2 W4 F+ c: R* s+ d' h' o8 B! f! K; U
    1
    3 |' @9 e6 ?1 {) ?  G4 w2) H1 w) R: [. j5 N
    3* z$ `4 _* H8 V2 c
    4
    & V3 R, c' y' ~$ _5( R/ k! n0 ~" j& v1 @
    6! _* a1 E: _) J; W
    7
    # \+ F; m  I# `, U8
    9 O9 O! G; `  {5 x$ b9
    ! J& L4 i9 A) [/ n: v0 ~10
    9 D$ i" U- Z1 [: V, {/ |117 \1 T4 c& d' F% C
    12
    2 B- V# U. i7 g* m13; r+ [# {' m" M5 Z* P& ~
    14- U- J  O1 T- [9 C! P# p% m
    15
    7 @& K9 A% ^1 g5 O, B16; M; Z% R( V4 S$ _& g- ?
    17
    7 m) A& Q2 S1 I+ |( F+ E; g: L. E18; A- Y4 {. C  W, g+ F0 X: J
    19
    6 M2 _0 s6 y; A5 U; p20( \) n9 o& q5 c, G  p
    21
    7 L0 a% N" C$ i222 I$ \/ d; K& h$ _
    23
    : T8 c3 S5 ^5 t) |0 K5 U# {243 Y: ~" D4 N4 S) ^0 B/ z
    25
    & P- l+ K8 \: I- ^3 B, n26( g# _8 Z! U: y: b2 A
    27( r/ P1 W9 `, @' j& F5 ?! }- S( w
    284 }/ k1 @8 R, D" p/ o
    29: K: s& X3 [; _' U+ {+ P$ z
    30
    ) v2 t. |' y& m  ~# t3 g318 n" q. Y9 A5 Z9 ~$ F0 J: J
    32/ s6 }2 V' B( {; L- p. }! J6 s
    33
    " i; |" H6 V$ I: U' n34& m1 u& p( G3 Y/ ^' _; ^
    35
    . g2 S3 z( |& N  M+ R  O36% s; m" o( s# ~
    37) M; F! W0 `# X) G: Q  ^/ T
    38' e/ v- e3 K* A. d+ e3 h/ _
    39! g$ N4 q+ x. Q( c1 q
    40
    ) R  R) Y% |% c41/ u1 {% F$ [' _0 O
    42; o3 G3 G. }3 o/ l, b. H
    43- s' A" L/ R# h6 b
    440 W: [, `9 ]; I/ I. }7 j5 K
    45+ e2 Z' p, e( y- Y- f0 x5 b
    46
    & K1 f: v" W6 e6 n6 D  d47
    6 Z0 l+ f; L3 f0 e48
    ; ?+ e* S$ G4 I( S49* F( J* g! U/ K' L/ l7 R) ?/ Z1 m
    50
    * X; D% {9 C+ f4 k511 H4 |  p! S1 J2 _7 E1 H
    52
    ; M' I! T5 j3 d$ v6 p53
    7 m" h4 i! B5 p$ r! z/ y54$ ?* }8 D/ L+ H5 h/ n. a$ w3 r
    55
    ) C7 y( s8 Z7 D% L$ d5 a56" `% \* _9 S: Z) O" |" ^" s
    57
    . j+ K% C# a, A58
    9 ^2 w6 F% m8 x) p  [6 `591 S; o; s# o" ^5 I
    60
    % j$ A' J1 T) S/ Y6 H& }61) A/ X/ C$ t/ v- M
    62
    2 ~& t0 o8 |7 k633 w1 B0 q5 _! ~
    64
    1 O, {# H! n, k/ @6 b7 r" d652 G/ v0 G4 K( W+ ]0 J* u7 g
    66, V% y. r0 f* {: ^2 T
    67: p. J7 t' E' Y  n/ ^
    68
    ' }7 t* t9 U# |( R5 R- d) x695 u. l  a/ }1 Z( c
    70
    ' X* l5 R  _  u9 G) F71
    6 k+ Z& x! R+ O6 ?# P0 n72
    9 e' U# l+ ^6 v73
      N+ _$ {7 t8 Z/ V: G74
    + s2 T( N+ a5 ^! l) ]) M+ L75
    & {6 n1 n5 o0 V% Y5 l. Y. \76
    1 L: ]1 `5 T. ~77
    + ~* N9 I3 V- a& ~78( G8 T0 l5 q+ S; U+ I  o
    79
    2 v( |+ F0 ^, F9 ]; e80
    8 |5 c) I; |7 T# \% Z81
    3 E1 v3 q7 N" O/ C- c+ s825 r  L/ V% f) K4 J" E
    83& [* X3 U9 e- f, Y
    7. 设置需要训练的参数# H# Q7 i6 K/ L& s/ U' y
    # 设置模型名字、输出分类数$ F' j$ j: R) G: ~8 v! j. m
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)8 V$ k6 E+ g/ b& g# z
    9 B1 E4 x- Z9 W! P! e
    # GPU 计算  w4 ~% V# P+ J" `6 O7 W# s
    model_ft = model_ft.to(device)
    1 X# h4 o9 D- y2 e$ X4 e/ i- g/ E8 g: V
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    0 `9 q1 Q& h. P0 z3 Vfilename = 'checkpoint.pth': _2 J( ^# ^) Z  _% Q
    : x6 B' V  k: ]) y4 y
    # 是否训练所有层0 C8 y5 g2 @) J2 V! C$ p
    params_to_update = model_ft.parameters()8 @- ]/ J. r( ~  N
    # 打印出需要训练的层
    9 K- E* b+ q0 xprint("Params to learn:")8 D: O3 S$ }) b& C& }4 i
    if feature_extract:$ ?+ b/ I) G: U/ V: L
        params_to_update = []; i7 r1 V( N" V. o; c2 |0 @& j
        for name, param in model_ft.named_parameters():5 O; p" _2 q- I7 h9 m, Q( B
            if param.requires_grad == True:
    ' C* I/ M3 `; E1 K! x# O            params_to_update.append(param)9 C$ M! O- j( v* `4 M, }/ y3 m- b
                print("\t", name)
    6 X& n+ Z# w( j' c" [/ }else:+ ^- b# W" b( o8 Z, ?
        for name, param in model_ft.named_parameters():
    9 L' D1 v+ @7 [" U  I1 {8 b  X        if param.requires_grad ==True:; a- H4 ]& \$ f5 ~
                print("\t", name), v& j7 L: [5 i& \' D
      }& a, y+ i. e2 ]
    1; b9 i3 v6 N: H8 U# ^( D
    2
    ; W  ^) r* e8 w5 ~# e* t" r3
    1 n/ }" z& V" X: A4( Z- h7 V& v  [9 a
    5& T( M' \6 j. a# N6 j/ x1 H  O( ~
    6
    0 C/ s% i) r7 I2 w1 _/ u8 P7
    5 n$ G( N" v/ r) D7 e8
    + E4 J, L1 q* E4 k2 {/ L9
    5 Z) Q. O' s+ \; w+ D7 p, n10
    ! G' O4 C/ e8 K11" W9 J4 X; u9 }, x
    12
    # D6 R1 ^6 D7 L& D8 {  ^6 U+ U6 _8 o13! U, x) y& P8 D
    14
    " B+ U& O* |) V- t' l$ E) z8 {15" U- x6 J7 ^2 D9 ]  A- n  F% @
    16. V2 M/ n, ~5 d" u' _1 ]
    17
    / j/ U' x" L" _4 s5 I18
    9 v+ X$ J+ d  T) n5 ]/ X$ h19
    , _( r# j9 K$ H20
    1 I! z/ E8 C& I+ U  J$ q21
    5 {- a2 e7 N+ h$ U22
    / g+ W1 J4 O% q" |9 k* s, S. E/ n23
    ) a! |5 I  k+ l8 ^Params to learn:
    ! B9 l/ j" y# J+ Y1 N         fc.0.weight
    $ h- s1 L# t8 g# X         fc.0.bias
    ( y0 _* c; }- }5 r' o7 J1" y( C  c2 u3 C% n  ~
    2! n. g& K) ^/ S. z: |/ t6 B: V
    3; J3 J6 t: a  S( q
    7. 训练与预测
    4 G+ q  Z  |/ l( H( i4 n  A7.1 优化器设置
    - g9 L2 U/ ?% k  @) X( l  m# g4 w5 A# 优化器设置
    4 s0 b/ d; z, q5 @optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    7 f: I  e# R, C. v# P# 学习率衰减策略% M% t1 y( ~# Y! L1 H8 A
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    : H' u# L8 ^: T# 学习率每7个epoch衰减为原来的1/10
    , p- f+ Y1 m2 _8 O* ^# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    % @; w7 q) B/ g- E8 G. T
    ' Y' V4 @$ j% [; W- ?criterion = nn.NLLLoss()
    ( E8 Z1 {; D: s3 v" i; [. o" ?12 ]9 e& A  n5 w! L$ }
    2
    5 X+ ?8 O& \( Q4 [, o. ~; n3/ d2 L+ C1 }/ [* w! O
    49 [, Z% F. I; T( V
    5, Y# j% g% E  Y4 F% t  @5 n& A8 r& {
    6
    % z: Z  C# ^2 u7 G7% X# g1 e& E5 v. m( Y3 Q; V& K
    8/ K8 P  k: r; s
    # 定义训练函数/ m) |3 h+ S5 e3 r* N2 I0 `
    #is_inception:要不要用其他的网络0 z; L9 _' X! G2 l! v8 n' u
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    3 }' D# L6 R, h8 `+ z6 ]% z1 _    since = time.time()
    $ d+ r' a4 K. B* K% d! M    #保存最好的准确率! l' O: G; O% ~& |2 u0 r
        best_acc = 0
    0 V( h( p2 u  b6 f" s( ~# v    """
      k( z7 `) Z+ K3 S9 k, r$ L    checkpoint = torch.load(filename)9 ~& k: L4 v; ?: W. `; V
        best_acc = checkpoint['best_acc']
    3 N+ [+ n) c5 g: \7 T8 H    model.load_state_dict(checkpoint['state_dict'])6 Y' d) B% t+ h, R: l. c% J8 q& u
        optimizer.load_state_dict(checkpoint['optimizer'])7 [0 u* d  z; t  F  i
        model.class_to_idx = checkpoint['mapping']; ]$ v, M3 ?1 l$ y1 S1 i/ h( ^
        """' E( l. q7 ?$ y" ^8 @
        #指定用GPU还是CPU6 A1 t1 A* w% [, v- P3 ]9 Z4 {
        model.to(device)
    " \' q! M/ c9 p$ A% j- g) n: {    #下面是为展示做的
    / X' C9 b0 p7 z  i    val_acc_history = []8 w; t6 v" j9 O9 I7 S# w+ ~- j* Q& u1 c; w
        train_acc_history = []
    ) a: W, [% d/ U7 h& s0 G9 n+ C. I    train_losses = []
    4 u  g- D6 O# C. F    valid_losses = []
    0 v5 H) y: k& _, p5 `+ p6 n! f    LRs = [optimizer.param_groups[0]['lr']]
    . a! J" K% Q0 r' g4 H    #最好的一次存下来: Q! S# k% |' g0 m" j9 q
        best_model_wts = copy.deepcopy(model.state_dict())
    + w6 m( w8 e# A9 ?# A5 [
    + A. b4 o, ]0 m5 j8 k    for epoch in range(num_epochs):. `0 f, j/ W3 @1 \+ p% A
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    4 w5 @; E# }! `  J        print('-' * 10)
    0 L7 o3 F7 X0 M+ q5 _$ e/ Y# n  F* e& @. }0 h
            # 训练和验证5 C+ O7 g. ~1 R3 P
            for phase in ['train', 'valid']:' @7 k1 x+ ^, b7 h
                if phase == 'train':3 [0 F% Y6 s8 a7 g3 ^- m
                    model.train()  # 训练
    8 @* _% Z* ?! T3 P            else:
    ; i* y! @& A- K                model.eval()   # 验证/ j5 }) h9 k4 e7 k  x- x5 i

    " W) J/ {$ [! t: O1 v6 J( D            running_loss = 0.0
    - a/ n7 {+ l# ~: Q6 s' K! i: {# r            running_corrects = 0
    - c* H" l6 s* U3 g! o
    - b- U8 W$ Y" ^; N1 f$ q, F            # 把数据都取个遍
    1 S* N1 L- `# S2 w8 p            for inputs, labels in dataloaders[phase]:
    " s5 O- M/ P" S  j                #下面是将inputs,labels传到GPU! w$ j- v# d* h  z
                    inputs = inputs.to(device)' u. {; J6 @+ f! w
                    labels = labels.to(device)
    6 e; {: [, l/ P$ z5 n, p
    ( A  c1 g! C+ U0 _9 k: J# x                # 清零
    ' b! ]6 ^; ?+ B' T$ f  O8 w                optimizer.zero_grad()
    ( }. G1 Q1 ]# S. J7 u3 y7 _                # 只有训练的时候计算和更新梯度
    7 F' m8 v" m& u# F+ M- J6 s& B! f                with torch.set_grad_enabled(phase == 'train'):( x  D5 V, w% [7 N/ N6 z+ |: P7 ]
                        #if这面不需要计算,可忽略: e, F; I, T# M" s% u
                        if is_inception and phase == 'train':
    2 l1 F/ }5 l9 h7 `$ Y; V                        outputs, aux_outputs = model(inputs)
    ) B& }1 E1 B; _9 x3 u% Y  D" ]                        loss1 = criterion(outputs, labels)
    4 q* ?; y. q# Z7 C. o* C- S                        loss2 = criterion(aux_outputs, labels)
    9 l. [7 t& ]! T# G; X                        loss = loss1 + 0.4*loss2
    : K' g) G' }% d- a$ ?                    else:#resnet执行的是这里
    2 f4 K7 j* Q- A/ G                        outputs = model(inputs)- y! [4 p6 ^1 F  {0 \) E( X1 V
                            loss = criterion(outputs, labels)" \, {; F/ a( P7 C+ h
    ; _: u+ ?7 i  c7 E
                            #概率最大的返回preds4 J) P5 m, l/ b- X- ^
                        _, preds = torch.max(outputs, 1)4 J2 ]9 b# F5 `0 S' |; T4 X9 p, K
    ; ]7 }* c+ i* k  o9 O: F. l4 ~
                        # 训练阶段更新权重
    2 z  g2 Q; R5 o* }6 O                    if phase == 'train':
    # k4 Y0 x' l/ z. }6 Q                        loss.backward()" ?0 b( u/ x- E  d
                            optimizer.step()2 w  u9 C) U8 ~" ?$ U
    + c3 ]' D7 v1 \7 l6 c& D
                    # 计算损失
    4 I: x; X; u; y                running_loss += loss.item() * inputs.size(0)
      Z2 f9 y6 U1 x0 E1 e                running_corrects += torch.sum(preds == labels.data)
    ' k' L  W* Q9 j5 C) r. c- M% X3 L: q) S3 Z  v" v
                #打印操作
    : X( d7 f. F) J# p+ Z2 F            epoch_loss = running_loss / len(dataloaders[phase].dataset), E) V4 i: o% X- U3 V1 _
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    1 Z: q1 g. d1 ^1 C6 T" T% p3 H6 K5 }" G- l6 \& v/ f  I+ g
    2 b. k2 }$ }$ o/ R% k' L
                time_elapsed = time.time() - since) J4 M0 m( U; P4 V
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))( j5 V# D9 _# L9 n: ?8 c9 s/ y
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))( l; p5 b1 X, E0 }6 S7 }

    ( o5 ~- l1 J" j/ @; C2 p& d8 L" J3 j6 j, j/ D# [) O8 H4 h" B" r' ]: k  S
                # 得到最好那次的模型
      b' K9 q$ h1 e+ f            if phase == 'valid' and epoch_acc > best_acc:
    3 h1 x6 s. F" Q- e9 o                best_acc = epoch_acc
    ' N* u- \: i9 p$ c                #模型保存
    7 M" V0 i2 o( y/ i                best_model_wts = copy.deepcopy(model.state_dict())
    , P* B' ~) M, m2 m                state = {
    * d$ K3 ]. ~5 F* T  G                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数, @+ R0 R/ l! e8 s0 m/ ]
                      'state_dict': model.state_dict(),
    , X0 u. }5 x4 T0 H% `8 }+ `                  'best_acc': best_acc,
    # ~0 v8 A; B/ o% ^0 T# Y* W                  'optimizer' : optimizer.state_dict(),; K9 B# r- \  P8 d2 ^; ~) o  t
                    }$ x, `. H, ~+ c* Z( N* u4 g" q
                    torch.save(state, filename)
    % P9 r! x3 B% K% M6 Z            if phase == 'valid':
    9 \1 x$ {  H) O6 E" a3 P2 P                val_acc_history.append(epoch_acc): i5 ^: @/ Y" J3 m5 Q
                    valid_losses.append(epoch_loss)
    4 v8 J& w# |# v0 ]/ m+ Q1 k                scheduler.step(epoch_loss)6 W* x$ @# e- |
                if phase == 'train':+ A$ u. @0 i% l$ Y
                    train_acc_history.append(epoch_acc)
    8 k! {9 r- _6 I1 |2 h+ f+ g" w0 Q& u/ [                train_losses.append(epoch_loss)
    - {' }/ a% S9 g
      o4 g5 c& k& f8 D        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    4 h8 R. T- P% a3 @        LRs.append(optimizer.param_groups[0]['lr'])
    0 S- C& D2 c: z2 w        print()
    7 T7 Z* @) T# ^  t* u2 B& G
    ( |) q: ~( N  s; {    time_elapsed = time.time() - since$ T; e& x; Y6 }
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
      S$ v- @- G* x( a    print('Best val Acc: {:4f}'.format(best_acc))
      n1 T- m* U- ^: v. a3 ~2 ]+ L) F
        # 保存训练完后用最好的一次当做模型最终的结果
    7 F3 |0 y  u- [1 Y    model.load_state_dict(best_model_wts)  i, O  O# {7 n" H; _3 `
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    + S1 ]2 M  ?" a  m5 o. h8 o: [9 L" i

    ( J' U, g& ~* T% Q/ a3 P( y1
    3 i* Z5 t9 \! Q4 S2
    / n: ^$ p4 L6 m. r1 A: |" {3  H; q2 n! u/ n" I. ~
    4
    0 Y& k2 ]: Z' a& G0 B  F2 M5
    & J" ^* q3 y+ t( d5 j2 `! |6' @* y3 T4 H3 z" M$ a3 D% f8 Z
    7
    1 P: R% P& v; u8
    , |5 }  i+ V; p, x* I/ b( r9% A5 X" _( L) g& e
    10
    ) X8 Y) i' c. V+ u' i11
      Q5 A3 T7 A- r: P12
    & L/ J8 W" W# c. v, N13# O% q# J% ]; F: }& i
    14
    ! x9 n7 p! q4 ^15
    , ^( Z; i2 e  m( A169 t9 ^" O5 G# L; A# u
    17
    ) Y2 Z: p4 w3 F8 f0 r$ Y! B, ?18
    * u! n) J' |% T1 H$ n) \2 s+ }196 C. E, k1 _1 G3 |9 H0 U7 Y" x' J
    20
    % z9 i* P/ d* }5 s/ ]4 J21
    % A7 u7 [- a& T22
    ) z/ V1 s4 ?4 B  ?; T23
    0 }. K! \. i; B9 ^& X0 d$ Q! z24
    ( W7 h2 {  f$ P/ I. H25& D! A7 P5 Z6 u/ C# U+ ?/ Q6 z7 M5 ^
    26
    ( g. [# [% n' G8 E8 ^27( g. e+ T* x! \5 w8 P$ C2 ^
    28" T! f! G+ N* w8 y6 d0 {
    29
    ; c+ D* \! `  q2 ~! i7 m: k30
    ( t/ k% |' k  w; H8 w5 T31
    . L! h9 n6 P0 d' S324 P9 Y$ L$ b) C8 Y
    33
    0 l: K6 J# n; f, ^; ^: B34
    : p. a: Y4 m: N9 L$ }35( [& V5 |' E" d- T
    36
    ! X& K- @( q# x+ n) ~2 |9 [# d373 r0 @/ S7 k* e0 C, N1 y2 j0 F
    38
    & |/ N6 D! K8 M% h; H+ U3 a2 w  o393 z2 r9 N9 T( \
    40. t: i; B0 B6 @, T5 M" W0 C
    41
    2 u( C  L1 Z& d' j: E( L* a42
    ; y. Z# R5 m9 @" T. U43
    8 L+ i5 o! B& a+ ^8 S44( p# Y1 r) Q0 ^% Z5 k3 R. R9 d
    45. u9 `% r, p% W% W" s/ g
    46
    . e8 G7 [! d0 |* T47
    & l/ b; t8 q( ?. E9 J! M$ ]* N, ?48
    ) u2 B; J7 c. J. s; d1 p49, ~: s. o7 ?& Z  C" a
    50
    # {# s" `" o0 l/ f. H51
    5 ~9 ^, ~/ f; u7 w7 d52
    5 n% [* g1 K9 N, p# ]53# K& t" T" O  X
    546 {, c" P0 @0 w
    55
    + h: [* c! A- ~/ Y# {56
    , S; W# ~6 |! |/ z# l- ?57/ N$ W  R4 [( G8 W, P2 R9 B
    58# v. r0 I/ y1 m! {3 o, d, M
    59, V2 v6 U, h8 H6 o7 t
    60
    1 @; X: q& a- g: r61
    : R& F6 K7 r0 O  I7 T; j* i' U  F62" \! G  t+ }) U3 i
    63: k/ N, L, l! S0 \: o6 E, d7 G' j# X
    64( r6 A6 r' C- L# n) ~0 _8 o
    65* p* P" A; K) J+ t
    665 z; J( ]& t; X. p
    67, O% u: V7 v& O2 V' I
    68
    + X# w9 {3 a# N% ]692 `0 `0 R8 u: Q- F
    70/ W& _5 O! h+ H
    71
    $ e# @* p7 k6 X72
    0 @: A  ^2 m) {' ]7 j' O5 L73, _1 V' D2 ^1 i' _/ c/ a# {
    74
    9 ?9 K! Q+ d* H( R+ `+ g75
    ! b% W3 Y# T. w# O( D% X763 @  x5 S  }5 [9 H  Y. O5 z
    77: o6 a3 X1 W1 L/ r' v
    78
    " Z5 _/ x: R+ E2 u% y9 p4 S" n79, @6 [1 m6 G* h5 L8 A3 C4 t+ v9 c+ K
    80
    " O9 e8 b. s) c* [5 ?2 c0 V) L81
    2 V1 h0 l0 D% S824 c+ J+ [/ B) V5 }2 t7 Z7 r
    83
    ( b9 o4 ^; s& g$ f; Z84
      m* R! V) N- J+ o6 b1 X, ], v85
    + F4 T' Z! h$ A- e( L/ I- q86
    1 D' F; }# }) O) ~3 N* h9 x87( s$ L( z/ q* M' |# v
    88
    . Y) L7 R- R2 V+ @, e* O( P896 @9 ?' I% Z* c2 F, x
    90
    8 X+ _# q$ U  m: c0 I, m) N) F# E91: L  o4 S3 @% c# x
    92/ k8 ^, [/ Z# [3 o, H/ M; Z
    93
    9 B1 u3 U& }0 o0 I6 }" V1 k- d94
    " C+ v" j* m/ j8 \# d3 P953 _7 }) `5 c6 D- ?) |0 O& K
    96
    3 q) h  _, E$ Y& z972 ^6 O1 i: Q2 Q# p# ~
    98
    / V( R0 Z; ^% O9 T, Z99
    7 q' S* F6 r% J2 u( h- I7 z1007 ]9 F) E& N1 U& q
    101' y8 [2 ^- F; Y3 A5 L( y
    102
    # w5 y, ^2 `8 g0 q1034 b9 t. k& u$ n1 \
    104) c! q: S. l0 B) X
    105: @, u- `) Q" n
    106
    % I' c# _8 `9 p  a; U' q107/ [1 f5 a) t9 P
    108
    ( r9 J! K9 U6 I; X# {; W109
    ; Z2 ]" D5 u) I% n110! _, a) Q* {  r
    111
    ' b) f; r' I$ |& \112& f- S' i/ x$ ?1 i/ t+ O# g
    7.2 开始训练模型3 P/ t( z2 L9 r( z+ y
    我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    1 K+ o0 f! W2 t! b* \0 C  ]
    % t7 q) J' Z( \1 ?#若太慢,把epoch调低,迭代50次可能好些& E& `" p6 H4 N6 V+ Q
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    ! X3 K5 N7 A1 ~& G. lmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))! @; A! Q" q9 X: X1 ?7 @

    7 e3 `" g, J& s( m" C+ E* n1$ Y/ p; V$ ]# J* b" I% J& J
    2
    ) _- K" S  p" J% n/ C* L6 X3
    5 D1 q% F6 a" n3 \- n- b5 L/ @. T4
    : r+ ?. U5 b+ K' v3 N2 f  cEpoch 0/4! v9 t0 j4 {) j# n$ n# |4 m$ H
    ----------, b' ^3 b* ?9 x; e( ?9 q" R' C
    Time elapsed 29m 41s4 o! ~5 s& ?  h! g) S+ V, ~
    train Loss: 10.4774 Acc: 0.3147  n: f- j. B! L1 Y* Y: ?
    Time elapsed 32m 54s
    / D5 A3 z, k) N: V: ^valid Loss: 8.2902 Acc: 0.4719
    7 ]# H, x9 X0 S0 e9 ^. ROptimizer learning rate : 0.0010000, T/ O$ }4 a& |+ b$ y% E: A
    ' s3 _9 Y7 \* Y7 A; u0 m
    Epoch 1/4
    % V) {, B% t$ a# B; L2 w% V----------
    7 [/ y; t3 n; P6 ]# VTime elapsed 60m 11s
    2 J2 v1 T. ^; v: Y1 Strain Loss: 2.3126 Acc: 0.70530 x/ @) `/ M6 F' }
    Time elapsed 63m 16s# l5 W1 H! @" j$ F
    valid Loss: 3.2325 Acc: 0.6626
    ( V0 n0 d0 x9 `; P) OOptimizer learning rate : 0.0100000& p+ l6 m7 [. f5 d& D
    4 M& C, B+ T# _
    Epoch 2/4! U, c( j9 J" B5 M9 \7 ~' W
    ----------
    # k! n$ t; u2 K) g, S# F8 tTime elapsed 90m 58s3 U6 X5 n: Q1 i  P: Z  E$ Z, p
    train Loss: 9.9720 Acc: 0.4734
      {* x; G- P0 X: x' xTime elapsed 94m 4s( A, `5 E: D4 x) K
    valid Loss: 14.0426 Acc: 0.4413
    ! r, M; r& }* aOptimizer learning rate : 0.0001000) K( U! R" }- @3 {# J6 i+ _
    ! \3 a- A% l. H) T% ~
    Epoch 3/42 u- {  M3 p# L& j9 _: P, u
    ----------; C$ a9 o1 f9 ~5 {) H$ a1 K0 `
    Time elapsed 132m 49s. i, o+ d" s5 z* P, s) _' B
    train Loss: 5.4290 Acc: 0.6548
    ; G. v( N5 x% {, P8 \1 [: UTime elapsed 138m 49s+ }2 @! n1 F1 I4 h
    valid Loss: 6.4208 Acc: 0.6027
    ) W0 h) D$ m5 T% hOptimizer learning rate : 0.0100000
    " Y( s  R6 Y! ~0 @3 S( t- w! f
    , U$ v' n6 B& w6 B# l  p/ SEpoch 4/42 a# k% v0 H* c* F; B
    ----------
    - j9 T$ O2 r! H/ {) r8 u0 DTime elapsed 195m 56s: W& E6 r8 t8 W
    train Loss: 8.8911 Acc: 0.5519% V( P9 k% ^* ~, U) x2 S
    Time elapsed 199m 16s0 u/ }) h  L. s, q( L( b
    valid Loss: 13.2221 Acc: 0.4914
    0 J0 o7 ~3 U- N& A3 |0 x( qOptimizer learning rate : 0.0010000
    / ~* q/ t" Y1 X" j/ _* R
    5 @' J7 v8 s9 m! h. l" e7 ZTraining complete in 199m 16s
    " T+ A) g# V* QBest val Acc: 0.662592
    & }+ B1 ^1 G% j$ S9 d/ N5 d* K, P; |$ J0 l6 r1 Z5 U4 L# {
    1& w4 Q: o* u3 G* L8 {' U
    2$ J( O# R$ Z* e4 o
    3
    0 b4 _$ X0 `/ Q4
    / P7 P5 I* \3 H/ {% a5
    3 M/ W" k$ i! N/ w6
      g' G* s. m6 x2 R/ Y7
    9 h6 u2 a6 X# S# [$ W. A8  U* T4 d1 {0 ^5 ]; O5 h9 v+ [
    9& j9 Y1 P. t7 m2 S$ o2 `* d. z
    10
    / |6 n8 W: r" j  s11
    ( ^3 O% a; r; O$ U2 d12$ d  Y4 H6 N6 d6 [
    13
    $ q6 B  N: _) p9 I0 n14
    * ?5 b" K/ ?( p) f5 C) B: K6 v156 |" l* u0 I: r% I% u
    167 O7 c1 n* `8 n% T  N1 ]
    17
    ! K- W. m( S# P# M6 |) [18
    & j0 S( M1 u' C7 O8 V19
    # V6 A; K% u& k  X% n20
    , i9 t. r* f+ F- u$ I0 a; `215 D) D5 i, B$ @6 B* u
    22  s- O+ E1 A2 j( F8 Y
    23
    . ~* M- \/ F: V) p) Y7 _24! S# n, ~, h8 U
    25
    : ]) e, f. i0 x" Z2 V% l4 r26
    8 o% U5 v* X" b% ^8 j7 I27( j( J- E% }5 O1 M5 O
    28% `2 }( z3 Q" @1 C) S
    29
    6 M) H! |# A6 Z- E2 n1 t30
    6 F$ P  i  u. t6 C31
    2 E& Q! M+ E6 w326 o1 R. L0 _! a# s: \3 T
    33
    8 n; N0 p  p; ?* Q34
    / T- z1 U) y+ O7 \$ h! g35+ s1 _2 k: Z4 u1 w( f& L
    369 l0 q" U5 `- {
    37
    * f4 }8 K$ d' a. T2 t. B38: d( l" W  O+ I$ l0 O
    39
    3 [# u8 k* Q$ {$ U# W) Q; b407 e2 G$ z: @" I/ X2 T( S  b; H
    41' m8 g' M+ M; y: e3 q0 J& @
    42$ h& I- s* I0 o; L% L" x
    7.3 训练所有层. R7 X$ a) r( n1 x* E
    # 将全部网络解锁进行训练
    # U- N4 Y0 v+ D7 n* `for param in model_ft.parameters():
    % S6 |4 y* N2 {$ ?; W$ E    param.requires_grad = True
      Y  E; }8 h$ Q0 i/ {
    5 Z: X. W$ u6 c* f8 e$ c5 Q6 d+ Q# 再继续训练所有的参数,学习率调小一点\
    1 ?4 N; l7 ^; l& T" `optimizer = optim.Adam(params_to_update, lr = 1e-4)" w5 X1 G/ C' M! L0 I
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    " r  l! z2 o- z! J% R
    * J- h, L6 ^4 q1 e/ X* e# 损失函数
    ! C% w. R% l* q1 k4 v& h& ]criterion = nn.NLLLoss()
    * L5 }/ N  u( z, l1
    % S' E) z6 J: ]! s2
    , r, m6 {  v6 ^* E' G( S9 a32 M! u5 S( x( f1 G  _" u' d
    4# s9 G4 A. _. }8 q0 G2 ?7 [
    5& k/ x) W. Q! O! r1 T8 H
    6
    8 _1 B) s2 C: J! p' z5 p7
    # t4 I* R( }: B3 S- C# ?" z+ Z84 {1 ~, Q4 {* R, b+ X
    9
      h1 X2 ?  M! K3 H9 N( l9 y105 m( V' S8 H2 v
    # 加载保存的参数
    # e1 n" {3 X9 G( k0 N: B# 并在原有的模型基础上继续训练4 p# S8 v5 P. h+ ]! m
    # 下面保存的是刚刚训练效果较好的路径% z( \7 ?2 U2 k& G; x
    checkpoint = torch.load(filename)
    . |' O0 Z  `! F) D# Wbest_acc = checkpoint['best_acc']
    " H: O, [  |  W) F; r+ r- h* Hmodel_ft.load_state_dict(checkpoint['state_dict'])
    - J7 C8 `/ W4 k: X8 M2 boptimizer.load_state_dict(checkpoint['optimizer'])( m( ~! w1 q) @
    1
    . I. u1 ^! Y! G# `. u  o2  [  R! P. K" I1 K; i
    3" U6 ^) g8 c1 I( v" [# ]
    4, i$ U: _- U0 P/ t
    5
    - T% Y" ~! z9 ?, J- H6
    # z) @/ l7 }) d9 `; g$ [7
      k. G/ H8 K3 l& X) |开始训练2 T5 b& M' q) Y( x! W
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    ; `# H- W  y0 }$ T0 k# h$ l: K) Q6 i% P/ ?/ X7 T# R
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception")). A  _( Y* C* p& i
    1- s7 M1 ~: }" g1 B( }
    Epoch 0/1
      F/ p, n0 o4 B* G2 M----------
    " Z% m9 O) _: e$ ?% E$ x/ |( ETime elapsed 35m 22s* t$ u7 w! o& q- P: o
    train Loss: 1.7636 Acc: 0.73461 v+ f. }, d3 j  |. w
    Time elapsed 38m 42s
    2 t6 a% [4 R2 F3 V: L5 I6 Lvalid Loss: 3.6377 Acc: 0.6455: Q) C4 q6 c' t# M" z
    Optimizer learning rate : 0.00100003 h- C; j2 S  ^' A+ b+ L

      i6 z( D- B% b* pEpoch 1/1
    $ L8 F5 c" D: i, m* G9 W1 B8 |----------' u- C8 Y5 l& B: F, n, q
    Time elapsed 82m 59s
    ( ^+ Y6 v' q/ ztrain Loss: 1.7543 Acc: 0.73400 g& r7 n  w" x: H! G1 b
    Time elapsed 86m 11s
    2 I9 ^! Y) `8 ?( y' S9 qvalid Loss: 3.8275 Acc: 0.6137& A$ q6 l9 L  ~
    Optimizer learning rate : 0.0010000
    " h0 I8 G9 w2 b% H7 y( K* d1 Y; D6 i) r# i) G, G* V0 m" S& l! H
    Training complete in 86m 11s$ p$ G8 i. D" [: z. m' q
    Best val Acc: 0.645477; ^4 C! P# c  x) u5 q9 _6 e

    7 P1 U9 Z+ T- j2 h1
    * x; y/ W2 A( W! t2
    & A# M: }: ^2 r3* s  _+ |4 P1 H( S! n! N/ p9 A
    4
    ' r5 c7 N! V$ z7 k6 h& Z5
    1 F, c  L8 ~4 \0 t6
    3 W5 B/ Y8 @: q4 M$ |) G/ R7( V2 F: l) V$ d; p; G) Q
    8  Y! i  J$ V( o2 ~: p
    93 v# W2 I" G4 o
    10
    $ L" s" W5 u# `" e3 o9 z# m1 ?; S11
    . _7 c+ J# G' r$ i; D& d. h7 N0 r125 q' w7 a8 `9 v& r) t6 S
    13
    ) v0 f# i9 K* m7 ?' Q14
      u( J3 |( Z9 x. o/ f15
    4 M; r; E/ `3 S- `' q4 ~3 k" A16
    . _) I; D0 Q1 m- v* \17
    6 W) ~/ }2 p; {! k- K" ?5 e# U181 r1 @% j/ A  G: T
    8. 加载已经训练的模型
    8 i3 h% I7 W4 G/ ~1 F! s4 v相当于做一次简单的前向传播(逻辑推理),不用更新参数
    - l6 ?; f9 a" Q" I3 }2 A  I! ]- n/ d* O; b4 G' J* F# i7 r+ y# V
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)* a& d) o7 {+ ~  O  C$ Y
    ) A- M: E; z9 p" w& o' ]
    # GPU 模式
    # }  b- a/ M- @0 y4 H* T3 l/ F% jmodel_ft = model_ft.to(device) # 扔到GPU中$ Y/ p' b4 _$ e: u" i" a+ z
    - k" @1 B  Q( Y1 K( O2 S3 A% u
    # 保存文件的名字1 r' }1 W2 X& L$ I3 r' ]
    filename='checkpoint.pth'/ b' U# f( G( r
    + i3 T* M5 L) _  k* w1 c
    # 加载模型
    ; X* R- Y. D0 V6 i+ Gcheckpoint = torch.load(filename)
    " X, d7 l5 {% K1 h# n. Wbest_acc = checkpoint['best_acc']
    ' l. |" P5 V: e/ M; y" k3 gmodel_ft.load_state_dict(checkpoint['state_dict'])
    & m% C9 M* x- \1
    8 _2 Y  d! j% \1 m1 _7 c9 J2! d: a6 G2 n" K# y4 v' k6 S1 v4 L
    3
    8 l% P  t7 P  C/ l6 |, O8 L4 F4
    # N5 P& I4 {+ N2 ?5
    % J4 \# r; L4 p6 F( ?4 l63 B5 A( T) m  ~6 D- \
    7" l; u3 q/ n) [* Y1 r/ e! u  u
    8
    ) h$ M4 @% A# l% Z7 e; x2 b9
    0 |0 A: Y2 M; E+ l4 }$ c10/ G7 A$ S; J8 L: I' u
    11( F% v' _2 M* v6 I  c
    12! K( A7 F7 E" k
    <All keys matched successfully>& h: D8 f2 g' o* T
    1
    " X. R6 N: t6 Idef process_image(image_path):
    6 e- ~. j! G9 {    # 读取测试集数据
    / P: l5 p) Z( h# C! M7 U/ e    img = Image.open(image_path); ^7 J5 e1 I8 A- p( K7 P/ c7 f3 \
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    ( V  @. R# q  G9 [    # 与Resize不同  |: L$ ]# v" N) i6 p
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小6 e% E+ V& X3 M3 Z( A2 G7 B9 K# b0 m
        # 而且对象调用方法会直接改变其大小,返回None
      ~# j3 \! O- @4 o    if img.size[0] > img.size[1]:7 J* }1 g& h0 P# R6 E
            img.thumbnail((10000, 256))
    $ u: T% f  O/ _" \/ o# x    else:& l3 A' z' }! [! K9 x
            img.thumbnail((256, 10000))
    . k% G  s: ?4 `* B  {# H4 G/ m) V8 }8 \. W& C" m+ T" J( Q# J
        # crop操作, 将图像再次裁剪为 224 * 224- b5 L. j4 a, l0 c, B
        left_margin = (img.width - 224) / 2 # 取中间的部分
    ' q0 i! D$ p4 W8 k+ ?: J    bottom_margin = (img.height - 224) / 2
    8 h; }1 M; n1 T* j$ _    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    8 [+ l$ R7 x8 {  f- e    top_margin = bottom_margin + 224
    0 Z6 s* h/ K& Y0 B- [( }( p7 k( F  |+ X! v' H6 K. G
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    % j) ^1 s8 J2 J; p* U0 q" z# _) i4 A! t
        # 相同预处理的方法( \" o  ?# j! q/ |8 h- o9 z
        # 归一化
    + n/ v$ \5 A7 w. Q6 p    img = np.array(img) / 255' H4 k2 v: m0 J: e% S
        mean = np.array([0.485, 0.456, 0.406])
    . T# s% _8 f& ^5 A- H8 j9 N    std = np.array([0.229, 0.224, 0.225])3 Z' }( L2 n- \1 @7 @7 {& A0 V
        img = (img - mean) / std
    9 m1 d* i6 F: }! z
    ; {# V8 H% U! B  e1 E- e3 W4 ?    # 注意颜色通道和位置
    $ Y0 m8 i* J5 R0 o1 k5 Q1 Y! B1 n    img = img.transpose((2, 0, 1))
    : b% \# J% L4 L# H( e/ I) u  T4 S7 v$ V6 U& h5 f8 ?7 X
        return img
    : P7 T+ g3 l  u% W. t/ b, ~; R  _# i7 i7 I  j0 I4 \( X
    def imshow(image, ax = None, title = None):1 l# S: d& |1 B+ e1 Z" P
        """展示数据""", p  h0 ?+ e1 u, Z' O$ C
        if ax is None:
    6 ?- k4 ~4 w1 g, |) Q        fig, ax = plt.subplots()
    . }- e+ d/ L( r8 M* _: S, H
    % ]$ E( L% K. B7 v- }    # 颜色通道进行还原
    ( d% F& ~6 X) u1 W    image = np.array(image).transpose((1, 2, 0)). `( |. q% p. Y, H4 B2 v* E" w% |: ~
    9 O: t, H1 m  ~3 h% j) P$ o: |: Z
        # 预处理还原1 W, U2 X# I8 Y7 i6 F( `
        mean = np.array([0.485, 0.456, 0.406])$ N% a: A: L3 B- L4 r% B
        std = np.array([0.229, 0.224, 0.225])) _& P9 g  ^+ t( j* [1 n, g
        image = std * image + mean
    ) H& C$ X& ]6 a; I3 O; v    image = np.clip(image, 0, 1)- S2 N4 C% r8 ^8 c8 b8 ~* b) x
    + A' n& Y1 i) L9 O! [1 T3 y
        ax.imshow(image)
    + ]" ]% j' p  O, M7 F    ax.set_title(title)
    $ l; r( O0 r% C! c1 ]  Y, q  d
    * ~, I, v0 @4 F% E6 A" U    return ax
    1 m$ _, F( k1 f2 [) J- H% }2 h4 p7 S2 M* r
    + J! G9 s9 u1 {. n" g8 j3 pimage_path = r'./flower_data/valid/3/image_06621.jpg'6 a1 ~5 _# b5 {! A% K
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理. x; ^9 u/ o! D# H1 B8 G2 g
    imshow(img)2 H% B8 }! m7 s1 F- p$ ]
    ! R. g2 p" g" c5 ]" Q' a' I! ^; t$ B
    19 O! B* _0 f, d; ]
    2& K* g5 o& A1 ]3 ]# @3 w- `! m* l5 O' M
    3
      B2 U8 w& c- H* \% r8 ^( [0 m49 v/ F) V/ i% u3 S: u) {5 ^( Z
    57 C1 k1 P9 w* P$ C9 g+ N4 W6 m
    6- |. I7 O0 c9 {* R' I+ A
    7
    " x- A* u/ t3 O% O8/ v6 i/ j  I3 g) C
    9" G! x8 D5 W2 c1 P! R! F, [
    10
    ! u: g: \' A# `4 K) Q% K11
    ' b' {# G. p6 J$ \( i% B126 o! w; a5 V: j: W. O3 N
    13
    / y6 ^, E8 }5 ^3 Q! i14# K( j5 u% L+ M) W" ]- L5 l( `
    15
    2 x  b8 d) m. O; q5 w% p0 W16+ Q( V0 h3 q3 `, q
    17; `& o2 U! m9 ^7 R
    18
    3 q- f; u1 x! Z6 {19% g% {& T  y6 O
    20
    + a& w! K+ ^% g4 B; g21
    * i  I9 q. T$ ~+ G3 N. X; D& E224 ~) {5 `/ C$ d" e
    23  O8 V& c! ^% G/ p- Y
    24
    $ p6 A4 d. i& [+ V6 b7 u$ }25( X: i6 R8 X2 j3 q
    26/ r1 p+ b% w$ {. e
    276 D0 i; z! u% }) a/ a5 R
    284 R/ F  R- C* U" n/ d+ ~
    29
    , y+ a# ~# x) h% w' G6 L! R309 a* c! P. E% S
    311 L8 c4 w0 y$ a. }# L# k
    32/ K' c4 U! ~7 {" b$ Y/ j
    33
    3 s  |; @2 L: H1 B8 U* c: z34. p9 Y1 N3 L0 D* `9 a5 d$ D
    350 s; x1 J( t$ B# G
    36
    - }9 X- f% J/ r5 b! x7 U  E2 Q37) h$ R7 A. o0 y& X/ f
    38
    ' M9 }2 u  }: i39$ k* i8 C0 V- i0 R& m4 I
    40
    * M% T5 y# I6 Y419 m5 ^* k1 q  ?; ^
    42
    # |# K9 [3 ^) H8 ^9 N+ h  {& v/ L43. X# ~( t$ _* I3 V3 |, v8 V
    44: m! o9 A4 _+ R  K' A5 a5 @
    45$ [% w4 N4 _8 O$ @& B
    46
    ; ], B9 v2 G6 T- i  r474 I2 N, P' a  E$ F/ u( ?1 p( k
    48/ f* C: p5 ?- x% C& C( T7 c
    49
    9 [6 m* V- ^" Q; I: K50
    2 Z0 n+ u8 ^8 Y& w! B516 M. p& x- V4 g8 V& [
    526 s& X; W% y. D, J5 V4 S- E
    53
    & Q7 h) |" V, }1 V  d54
    2 R; E$ D7 Y+ \3 D% c8 s5 Q9 y<AxesSubplot:>5 f! w$ |- _0 U% u, k
    1
      G3 J# e! K! {4 |+ u
    7 o: F1 |1 P+ U# R上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确& ]! q' c5 |" f# k
    + N. s$ w, S" p) P
    img.shape
    5 ?( E# }0 b1 h7 b& w( m) z7 P* o1* H9 E* ^# M9 e% {4 {% ^' y
    (3, 224, 224); l3 w4 n# \/ W% y- F
    1; k8 V$ s7 }2 d1 S) [! i$ e
    证明了通道提前了,而且大小没改变
    ' N/ G! d+ {& s( r( Q3 d% I" @
    0 J" V2 K' Z% t* \- O  d9 ?( D9. 推理. s3 @* y8 t0 Y5 g
    img.shape+ j3 O$ |" k7 E+ `

    / d7 o9 L  T2 ?7 g# 得到一个batch的测试数据, y, F1 C$ A3 B, Y
    dataiter = iter(dataloaders['valid'])
    " u' k2 _/ N; w' f' t! u3 uimages, labels = dataiter.next()8 X& V8 P% `4 B- ]6 i

    ' \' T8 Q3 d. `9 q: |model_ft.eval()# q$ @( Q, ?& K" ]+ N
    8 S' D) r6 B+ h, A% p7 K3 a+ c1 d
    if train_on_gpu:4 ]! Q4 l5 X) l# s* k
        # 前向传播跑一次会得到output
    . ~' B' [9 A7 s; l+ ~( q    output = model_ft(images.cuda())8 T( ^# f1 C, E3 W
    else:
    . Y$ W/ k  _& K( I: K+ F: [8 j    output = model_ft(images)% @- {( m% O9 x9 b2 p) L
    * X" |0 D% g- J( L) @! g
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值  j6 s$ M! p. Z3 E0 g2 A
    output.shape
    5 |  ?: d/ A: E
    , M0 }5 Z4 c" ^9 o5 @16 O/ L- a, W; r: Q6 K$ c
    22 F  j8 n6 Z: e$ E
    3
    1 l5 d/ C* m& M5 J9 t9 R7 }; c49 B, Y! Z8 s6 z" f0 J  m
    5) c2 a# @9 X: ?7 p/ X
    6( x- k: \) C9 J
    7
    # s$ @3 x2 P! t83 d& @4 w8 M& X% a0 s3 R
    9! E2 B/ N4 y& X* j7 w
    102 A0 f6 V9 p- m7 j7 l) M9 H3 I6 v  S
    11
      X0 J) m/ p* p3 f6 @120 u: g8 \6 K. s# f( \* h
    13
    1 {( a" l/ ^1 T+ B- V6 |0 j7 l6 f3 a14
      e7 U: g+ I$ o. s# x155 Y4 [! T# u6 n7 L8 E3 v; l' U
    16
    0 X( `/ ~7 H7 r0 N6 ?torch.Size([8, 102])
    : ?8 ]9 |5 y1 i6 f& M1
    0 h3 L: A7 Q, A8 H7 D2 K) c# ]* ^6 d9.1 计算得到最大概率# Z. @3 w! q% E" x$ |+ N% X
    _, preds_tensor = torch.max(output, 1)
    " D: e0 U, W; E2 F8 `' H- _. {& ~6 m2 u+ b& D& M2 W3 a
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量$ R2 ?. p& @  b6 B. U
    1
    , [6 K3 v2 _8 P) a! w: l1 u2
    6 x' d% _3 a6 P4 e) z( G$ A+ X: U7 _7 _38 |. ]+ q; k) `6 y  ^' Z
    9.2 展示预测结果% [; l; f. Y& c  ^0 A. j8 E
    fig = plt.figure(figsize = (20, 20))* w6 R3 z) L/ X) b& p) r
    columns = 4
    5 p1 r$ |6 S8 b" Urows = 2
    5 }3 C; L3 V; s& ^; l! v! d: `- Y! u# i, a0 G
    for idx in range(columns * rows):
    9 y1 W% ]# X$ l4 F6 o) B" C, g    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    4 _  W; E% I0 L3 @' L    plt.imshow(im_convert(images[idx]))9 P) G& x  F  P( ~( M: x- r5 ^- S
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),   p9 w& h& ~; d5 C; W
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    & s! \2 ~5 l$ v. Eplt.show()
    ; g% ^. s0 |% o: B4 F* t# 绿色的表示预测是对的,红色表示预测错了
    3 s) w" Z$ C& \# u% d1& V9 B, F. J* X' z  P
    2! y% i( Y, W/ ?* \  _
    3
    ' R2 S" Q% N/ ]$ Z+ C5 X  F: ?4( V. l: Z1 d: P8 |) u- c
    5
    : ?3 u2 `& i( j: K. ~5 w1 \0 X6' }0 q1 m2 w* D1 X6 M$ `
    7* u) \" y. X0 X
    8, v) ]5 D5 p2 K( k
    9: V. J: r7 {) e; \
    10, N* F( j* \0 C+ K8 f
    11
    ( G( ~1 z/ Q1 [% A/ G1 ^$ p% Q9 U9 ~( t$ N

    * C: X2 o  w* z
    : n1 \& l4 d7 P) n  c( ^————————————————# x6 v3 u8 J+ X
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    7 T1 `) t7 a4 n1 K* ?原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    0 ~+ ]2 c& N# s6 P1 ?, f) \; r8 p3 C2 S) h$ B, E' x& c

    $ L/ U+ B& ]" E. f0 I
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-6-15 14:48 , Processed in 0.370017 second(s), 51 queries .

    回顶部