QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2721|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    % I9 P) A6 i# L7 e0 w
    7 g: ]: z  c8 p: I" E文章目录
    ) N) N8 ?5 P, |; X& r卷积网络实战 对花进行分类
    * l, i1 C/ h- V- H数据预处理部分$ o/ O1 N, S$ h4 h$ }; R
    网络模块设置- r/ H/ z( ?, k
    网络模型的保存与测试
    0 M. I; U; q9 l, z7 c# W! d" H数据下载:0 r' q6 U+ ^7 n
    1. 导入工具包
    ; e" r% \! n) @& J* R' f3 f2. 数据预处理与操作
    ' Y& H7 B% t: D( Q4 w+ f- _3. 制作好数据源1 q1 R3 F9 V+ |. |1 }( r
    读取标签对应的实际名字
    ( E! `0 E" Z# m4.展示一下数据
    4 u( Z! `" |9 f8 H/ \, z3 g5 b1 C5. 加载models提供的模型,并直接用训练好的权重做初始化参数  }/ g* m- [& b; @, ^
    6.初始化模型架构
    # p! t8 ?  T& E+ ^* L% y, ]  K  o3 D9 |) `7. 设置需要训练的参数0 e. M3 u4 S4 Y5 R, P
    7. 训练与预测  r# _2 g( w: d- d* S) o4 R3 }9 s
    7.1 优化器设置4 d$ t7 O& g1 O8 l1 g+ T
    7.2 开始训练模型: q: \& O( o: f
    7.3 训练所有层
    7 T/ k, ~/ s. s1 O开始训练: b3 w+ F. e3 H0 f( \# z7 {+ {
    8. 加载已经训练的模型* h0 P0 ^* Q9 \
    9. 推理
    $ [9 }. ]$ Z  T3 S1 n+ a9 G9.1 计算得到最大概率- T6 ~; x" H1 m# F' e' [2 e
    9.2 展示预测结果
    2 Q" H4 ^4 O/ k1 P写在最后
    3 y" h& Y& q  `  b4 W卷积网络实战 对花进行分类
    2 L$ j. g; Q0 C7 o) R本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    " P" e2 p1 e0 p# C* j7 [
    $ m! q+ M% j* C+ B在文件夹中有102种花,我们主要要对这些花进行分类任务
    / A5 _4 x1 Z; S0 u6 e4 N文件夹结构
      u( C, N" c- `
    5 M2 `' d  z' `; ~/ nflower_data  `) F3 w( |3 i" v

    5 Y8 D2 a$ Z% ^8 itrain) D, h7 a1 W2 a) x

    / c) F/ K/ s4 ^/ ^) d! v5 K2 z1(类别)6 s; Q$ A9 c( H$ ^' O% \
    2* t. r3 B" m+ p9 [' b
    xxx.png / xxx.jpg2 i% [7 t1 V6 ?
    valid
    ) g/ w/ r) }* g. [" P5 m! j% D7 J2 m' ~4 _7 Y4 l$ Y! z# c1 d
    主要分为以下几个大模块7 {" s% M0 J5 e" p: V; ], |& q
    - j: y7 W8 h  p' A* x
    数据预处理部分
    ! a# i, h2 O* v$ K) J( s数据增强1 d5 u/ {% z0 v( r
    数据预处理
    6 m" Z# n: @! `, T# w5 b网络模块设置
    5 P0 D5 r1 ]2 j5 k) k: \" h7 h0 g加载预训练模型,直接调用torchVision的经典网络架构( s% P9 z+ O2 K' A1 p1 h$ q7 m
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务( v3 f% _" C1 f- [9 u( N1 m
    网络模型的保存与测试1 O$ F; [4 X" h0 @
    模型保存可以带有选择性
    % F; @4 }+ i  ?5 m" W: t数据下载:' O1 Q, X# H# @: X1 e! M
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset) u, H8 ]: q2 Y3 z5 e- U6 {

    - y" |! p8 A# x, J; D改一下文件名,然后将它放到同一根目录就可以了" ], o' ~) E- e" R$ L  J
    2 a  Y$ i2 q/ E' ]: O
    下面是我的数据根目录' x0 b3 `: i+ D

    - {- y- Z1 k7 f/ F8 q7 E4 S5 m9 I0 |4 e! d
    1. 导入工具包
    " i& K* K5 L' k$ S5 ]8 ^  Simport os
    ; B0 g/ x' {& b; B* h4 nimport matplotlib.pyplot as plt/ s# ~9 S! H- t: R- k1 P+ }: R" y
    # 内嵌入绘图简去show的句柄
    : c$ j4 ]9 C0 x2 K%matplotlib inline
    8 L: W  r' j- d7 ?( W* ~import numpy as np
    6 g; y  z# R5 I+ }import torch8 t- @. E; ~* E. g1 n' H8 q
    from torch import nn
    , Q: p" {9 }; C# y0 X+ I3 L. o, @+ d+ D; q
    import torch.optim as optim6 B- O8 {5 i% W: \" n9 m
    import torchvision
    ) e4 j% U: L1 W* ~$ X* |from torchvision import transforms, models, datasets
    % }4 S9 R! L8 l/ t: {* Q4 o- I3 b' P" r
    import imageio/ {+ D9 v) m: ^6 j
    import time, ]' v; e+ a4 o9 d6 N; @
    import warnings
    6 h, A  v4 ^1 ^  M& z0 ]import random# a" b( @' |# J' ~9 Y& v; Q
    import sys$ f* N; I6 ?; [& Z3 N
    import copy4 j5 Y3 h3 g+ h* j$ i
    import json6 F% r. r, n  y; O
    from PIL import Image& s7 {. h! k* X3 N+ }% I& S- l
    : U0 J' _# B9 W! e/ _

    ( P  b* Y- g  l$ M$ Q$ H1$ V! ?* m0 I* b: G7 t: J5 R
    2
    # K2 K. W9 c; l0 Z36 o3 q  T% l- w! u/ D: y' C
    4
    / m- A% x$ U* q! [. y  B5
    ! |0 w7 `/ x$ v& U1 s7 [7 V  R$ s" o" ?6' t7 ?9 v" }  w6 f* A. x+ A
    7
      [7 b$ i6 c# D0 f8
    . @" m  q# z1 j: L+ A, r$ E# j' _9- c1 m; N: v. v' S" M
    10
    2 |- e8 z: k* b- P  \2 e; u; O11: M& Z! T; e# h' {
    12
    , h) r' z" x& m8 e3 Y13
    + X  O6 h, M; \142 l% m3 s5 S6 Y& C8 f, q9 l% j
    15
    ' Z* m1 u& D4 I162 ?& Z; h/ f8 I" b
    175 o8 E& f) G$ i  w* q7 |4 @
    18
    7 w0 j6 D8 G! g2 u+ `4 g/ A# h$ Y0 H19
    . W5 m3 j2 J# V) s20; H, }" @+ L& A5 s/ o
    21
      V5 C$ G! E' R& R9 |3 c2. 数据预处理与操作3 D/ F1 O+ h9 @( x: u- W" _2 g9 K
    #路径设置
      \2 C: b6 q8 z: bdata_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    9 m& r; |$ A# |3 }train_dir = data_dir + '/train'6 C2 K7 t9 q% {4 ]8 W1 B
    valid_dir = data_dir + '/valid'
    ! m4 W& n$ z8 N( K! O1
    ! ^- |: J) f( Y) X. B( S2
    % V$ g0 j. }" l9 \2 E3( h- x8 a! X  f$ L
    4
    : Q' m) r1 ?8 A+ m/ a; Spython目录点杠的组合与区别
    ' z& f/ Z5 Y1 z* @1 D注: 里面注明了点杠和斜杠的操作( E$ ?( M: J6 f) m! Q0 L

    # }  [; r) x1 w% @3 N1 A( {' l4 Y' o3. 制作好数据源
    8 ~$ }; E" _2 D' Sdata_transforms中制定了所有图像预处理的操作3 V. v) f* o; n/ g' c
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片- @+ O: D( V" \1 f
    data_transforms = {
    2 c( Y  }  j+ T: y! n    # 分成两部分,一部分是训练, \. N+ c: |, V8 D0 I: l+ u
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    7 I# A2 c2 C- W5 f* g                                 transforms.CenterCrop(224), # 从中心处开始裁剪, ?& j5 V  v; B' [
                                     # 以某个随机的概率决定是否翻转 55开
    - i8 U- C4 ~, e, W( H- R                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转" c. a% d4 N6 p1 e1 X7 o' M1 }& R
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    + p/ I2 M( k" m  T, }2 j                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相( Z- U( h6 h1 ?
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
    ! E" U8 M  c' _/ \8 U                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    1 _) W: z2 |* Q                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    3 r' h# N; T% t4 [$ L* V  `                                 transforms.ToTensor(),
    7 p* S* _4 ~& b0 x                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    # f# C+ D, `0 P! q; Q% h                                ]),
    ( n* Z) i/ l) C    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    : v4 E' S" E, }7 H    'valid': transforms.Compose([transforms.Resize(256),2 u: g) I9 U2 Y) V
                                     transforms.CenterCrop(224),+ ?; @) i# c' W
                                     transforms.ToTensor(),
    3 q# U4 y; U) c7 c6 T: |+ \                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    5 B" q" T0 Z6 x' E$ {7 e                                ]),: k7 F; w6 G/ z: b+ i; E
    }, B2 t0 [) v7 Q: v2 I, K4 m/ o
    ' @- E- B5 M7 d2 k* O- i# e( @
    1
    ) s' N0 I8 j' I; }  _8 [! Q2" u+ h* f/ L1 v0 O& _1 U1 ], \
    3
    # M( w9 O* o2 x9 c4
    % N  X3 @+ Z6 n% E5
    ; q' k5 J7 o& b" J  c' v6/ x) J/ b+ M0 t2 b7 r* v
    7
    7 z( z" B# [+ u8 A: x8
    ( O  r; s% ^- n" f9( B" [* ~0 W  \! w
    10
    + X  ], l3 w0 c11+ P7 _; ^. n- _* q) Y
    12
    6 s% R5 x$ i$ n% ^- J13- D- d9 G4 Z6 V3 R
    146 ~+ b% w! u, h7 L/ R+ x
    15
    5 H' X, i2 F: ]1 _163 l# S8 \' `; a/ n7 r
    17
      O+ o4 g' ^% j18
    1 l2 W) R, Z  U1 n0 d% N/ v; J19! A# ~, S2 D7 h! i7 e0 l6 `3 V
    20# z' c+ B5 I9 F; G/ ~, W0 W& G  b
    21( n+ @( w! a8 s  D: V' g6 E
    batch_size = 8
    ) B% |' L5 c6 M' k7 e2 O/ ^9 dimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    ; s3 \( N- \2 k8 [  adataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    4 N" p& |$ ]3 W0 ?8 t# ^7 \, Xdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} 5 U: B  e$ ]& H7 V
    class_names = image_datasets['train'].classes
    1 X6 S2 z: ?2 Y( o7 \0 h
    ! n9 z4 q6 Y  G" y#查看数据集合% W3 w. h- J7 F' H) p- J1 [: L% U
    image_datasets
    , G8 n1 Q. ^# u5 `$ o: s
    # d8 `8 ?" x0 k9 Z0 h# Q+ O' r+ W16 q* v4 h. s, K* b5 g0 N: y
    2
    3 N+ L: e9 P$ @% p; ?3
    ' Y* c4 R$ x! k- t4. i! l9 Y0 i; p- \( B2 F9 h$ M$ Y
    5
    $ ?& F, z7 C# |: D6
    3 T; x) Q1 g* A7 S: \$ t: X, R8 L7
    8 x. g% W/ K4 B" ^( m8
    2 V# E# T6 i3 x4 l, }9
    ! X/ k5 U4 k% G( X( A" I& ~{'train': Dataset ImageFolder
    5 @- ?$ M! [9 e3 z) K) }: [% f     Number of datapoints: 6552$ o$ j; n' ]" v. N$ ]/ `
         Root location: ./flower_data/train1 |! ?) F( [6 L4 f7 |
         StandardTransform$ c( X- }, C+ {1 Z: K/ U
    Transform: Compose(1 F5 z' n; N% `$ x
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)( G0 e# [8 }- |# S8 Z+ B* ~
                    CenterCrop(size=(224, 224))
    5 W/ {- L: u4 S" V3 b! Q2 R/ t                RandomHorizontalFlip(p=0.5)
    . ?9 b5 s% a/ `5 [                RandomVerticalFlip(p=0.5)5 u- E% T& E6 W' P
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])4 e' H0 l, E* o+ _# I% d
                    RandomGrayscale(p=0.025)
    & ]6 u( e( \4 e6 R                ToTensor()
    , x, i* c- L* }) w/ U                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    2 e9 t/ c( O5 Q3 L3 o# ^& X            ),
    8 f# r4 Z3 |7 S 'valid': Dataset ImageFolder0 V7 C& s+ g6 D( f
         Number of datapoints: 818
    + D& i4 q5 x: j% c1 \     Root location: ./flower_data/valid
    % |  ?- \$ m' q- b6 O     StandardTransform
    5 w- @5 C1 X6 k5 \0 o9 [; i# ?4 D Transform: Compose(
    $ D8 e0 D" t* ~6 z( k. ?                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)9 ?  ~% U$ T3 t. Q; b& S' I
                    CenterCrop(size=(224, 224))
    & {7 c6 u7 b0 K& y2 G6 H$ }                ToTensor(). H/ H; m7 ~6 ]8 }9 u& [" \0 v
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]); C3 s3 e! s7 n8 y" V
                )}
    8 o  _. j, h! \5 Q9 h/ u: \$ E" [% p8 o% q6 r, G
    1
    2 r" v# y# j, t4 F( Y2
    , l8 r. N$ r2 T1 _3
    & y( ]. Q1 p" S+ C5 y4
    ' t- E/ l" `8 U. _# r5 O; _5! Y  b6 T3 Y; u+ t% K1 n
    6
    5 ^( F. r  n6 `5 s4 i6 k' P7* f' k1 W- P. n& |9 T
    8
    1 ]8 ?5 t: \! e7 A, e9 T9
    / I  K% S" g; N  w& v# h10
    & _, I+ S& I! b111 F& y# g6 h  r+ _
    12
    , V3 @; x# n  K" E3 s13$ [9 J( c2 g( e2 v0 p' S8 I/ e
    14
    . z9 P8 C2 i5 E% D( _/ @15/ B# t7 L. ], M5 o( j- `. J4 m" N
    16* u! B5 h4 q% u+ u3 P9 ?2 j
    17
    # g: U* S5 \4 P& g& |  Q4 s! Y189 D4 z. q0 R* [' \# D$ G- [- Q
    19  \+ r* r: @5 g0 w3 B6 A5 m9 i
    20, w' B) A, N) L8 L( M/ l
    21! |! s& G2 E  J) x2 O5 M
    22
    ' w- i4 u4 J2 }: l6 T23% w! N1 }! B1 w9 u5 Q8 d  @/ ^
    24
    $ S# O  Q) H" V# g; j) c& V# 验证一下数据是否已经被处理完毕
    ( v9 k" l" b' \' Y3 s6 J% [" b7 sdataloaders7 P4 P3 t! }: {8 [- E) ?
    1$ I, J% N" N) @7 c
    2! k! g" k6 ?/ h( F7 o
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,) R9 m# |, X4 x* J
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    $ o! F1 n( w9 a1
    % e8 B7 k& K5 u) {, {( A$ p2 S2
    & v  C, [9 i% a6 c& ^dataset_sizes- C( P4 i/ W6 ~0 O
    1
    % f9 V9 w' G3 K* K! U{'train': 6552, 'valid': 818}* t4 `/ ]" F2 O5 h
    1
      k# v, m3 @9 _$ Y读取标签对应的实际名字
      h8 @6 ?$ p+ \. ]! x使用同一目录下的json文件,反向映射出花对应的名字
    ; J- Q6 t$ C2 _7 l/ \. F; [1 `; i6 C! `' z3 T
    with open('./flower_data/cat_to_name.json', 'r') as f:
    - [) w% `2 p7 d0 [7 _    cat_to_name = json.load(f)7 H1 B; s9 A* ^- Z) b7 _6 c
    1
    # A' {' w2 t" H: [3 E1 m/ ~2
    ( k3 y6 _3 A& e4 }6 Q3 Y, z4 l. icat_to_name% b& e: p" f5 M1 M  I, C# @
    1
      Y; A9 @. @# K{'21': 'fire lily',% f  v0 l! p9 ]' Q# c2 H
    '3': 'canterbury bells',
    6 I# C) X. d5 z- F. A( o '45': 'bolero deep blue',
    1 U) C) n3 `; @: C6 C8 ` '1': 'pink primrose',
    . b5 V+ z! K3 u, X9 F. M '34': 'mexican aster',' Z$ |1 o1 ?+ R# E6 [' ^6 j" q! Z1 H
    '27': 'prince of wales feathers',2 w- ~, s: r6 l7 u/ ], y
    '7': 'moon orchid',
    ) x, N+ Z, B& z4 T* G3 J! z5 N '16': 'globe-flower',4 V6 |8 |4 ~: v
    '25': 'grape hyacinth',
    " _/ Q3 e( _0 o3 F( t '26': 'corn poppy',( F& d4 D  B" Q+ [3 D" [3 P0 l
    '79': 'toad lily',6 e# g" \4 S  k% y+ s
    '39': 'siam tulip',
    % J3 v, C9 p* W1 e0 H '24': 'red ginger',
    ! T- L' Y0 R+ d+ D- G2 z '67': 'spring crocus',( u3 t2 J6 S4 c! F
    '35': 'alpine sea holly',
    7 P  a* x: p( G1 H; D '32': 'garden phlox',
    : x, G8 R) C" i& t8 t. \ '10': 'globe thistle',
    8 R' y3 Z0 b3 A. @" p7 }) ]  E% \ '6': 'tiger lily',
    6 {/ A+ _) C) s/ |2 q5 U '93': 'ball moss',* e# k" F% h) |. r7 K# ?
    '33': 'love in the mist',
    ; L" j) ]2 A" ]1 K4 B+ \ '9': 'monkshood',
    . Z5 Q) e# w4 R6 W '102': 'blackberry lily',0 n7 F( p, u3 ~( v
    '14': 'spear thistle',
    6 ^4 G2 L# M# T! M+ D '19': 'balloon flower',
    8 d: u+ k3 n' h '100': 'blanket flower',
      v' v2 u" a" v) p5 o, s+ g* X- J '13': 'king protea',& S& I, d9 o9 N% u3 o9 ~* L( `9 m
    '49': 'oxeye daisy',
    5 \! I  c9 ~, m% ~% R" Q9 s4 [ '15': 'yellow iris',
    " ?. f  N3 T+ G8 |) D' Y, z '61': 'cautleya spicata'," A9 q. x, O7 \6 a( e
    '31': 'carnation',
      i" j7 x3 O. |. C$ b9 l+ R1 d '64': 'silverbush',/ Z- U2 v( @" x5 H( m$ m* j2 M
    '68': 'bearded iris',
    7 C# E' {  V0 H) B% h '63': 'black-eyed susan',
    ' Y* U. |2 B# e9 e '69': 'windflower',* q" C1 ?, M9 S' A" v
    '62': 'japanese anemone',
    2 D9 h8 j2 Z" e; M1 E '20': 'giant white arum lily',; `( _, Y, R5 j6 R" }8 J  ~
    '38': 'great masterwort',
    # f3 y" ]: f4 ] '4': 'sweet pea',# k4 d) o2 P5 k* O6 H2 u
    '86': 'tree mallow',- n/ X! T) i& Y1 X) p1 S0 W
    '101': 'trumpet creeper',/ S* I; `( D; J3 C& }
    '42': 'daffodil',
    4 |- p  w, w) h. C2 s '22': 'pincushion flower',
    . ?8 b9 F8 I9 d  @1 p '2': 'hard-leaved pocket orchid',
    ; I! l' r8 n- i/ G '54': 'sunflower',
    $ Y7 C+ `1 t. {1 r- X9 r0 w' D) e '66': 'osteospermum',0 Q1 [6 L3 N" m2 m& j9 [- ^
    '70': 'tree poppy',
    - Z. N, r/ v$ S4 R+ U# I '85': 'desert-rose',1 p7 i; x4 w( |& {+ Z
    '99': 'bromelia',. g8 y0 }" ^# |8 l, w, A/ }: x
    '87': 'magnolia',) W- P+ S& }) Y& m/ v0 p- x
    '5': 'english marigold',
    " ?" O/ h6 z( j, V1 \ '92': 'bee balm',
    5 c7 e5 M+ j8 d$ ~* u '28': 'stemless gentian',4 g; ~( m3 P$ A# r* a& n: r$ H
    '97': 'mallow',% V- q+ a" m- {/ ?- A3 L6 ^1 L
    '57': 'gaura',
    ( Y' B( `5 B# d; L '40': 'lenten rose',2 C2 ~( d- M! P
    '47': 'marigold',
    + X$ b5 M, ?1 M% i, Z. r) s '59': 'orange dahlia',
    1 \/ q' l* t1 o  v '48': 'buttercup',
    . u$ c6 I" D( u+ Q3 L '55': 'pelargonium',, f: u- z4 P- ]
    '36': 'ruby-lipped cattleya',
    7 y; P0 J7 \, C# w- y" b  l* K '91': 'hippeastrum',
    / i  h+ W' C9 Z' I '29': 'artichoke',
    ; Z0 I: B0 J, y" M* W '71': 'gazania',+ b: T- Z; x6 C6 j( S2 k
    '90': 'canna lily',
    ( p8 S% G2 |& G, J '18': 'peruvian lily',4 I# K( a1 L% z0 n3 S! b' q
    '98': 'mexican petunia',
      Q# m' B  H* w! V) `8 ?5 m '8': 'bird of paradise',
    1 p" D$ T" {' D! W3 W '30': 'sweet william',
    ( N( \3 [$ |( N" Y+ [ '17': 'purple coneflower',: G# z1 i& [7 s' V0 D' l  q
    '52': 'wild pansy',
    4 B; Q+ l6 R6 L# |$ V '84': 'columbine',* u7 v4 N+ z+ S$ E/ g" o
    '12': "colt's foot",( O0 L0 z, l+ a  X$ D- Y* |
    '11': 'snapdragon',: Z9 N4 P. t8 N
    '96': 'camellia',
    * C3 i3 k* Q* b/ }; V '23': 'fritillary',
    ! y% i3 P% d% { '50': 'common dandelion',9 m6 U4 l: c( A$ V) {- Z
    '44': 'poinsettia',
    0 p5 a" d3 f5 d( T- c$ ]/ e9 C '53': 'primula',; V/ M* w+ f& ~% H. P3 r9 R
    '72': 'azalea',3 e2 a7 G7 Q. U6 f* B! x1 ]
    '65': 'californian poppy',& ~) B, p; i( |
    '80': 'anthurium',3 l( h' W  [5 Z3 i* p
    '76': 'morning glory',3 ~  _: z2 P3 l* ^% g+ x
    '37': 'cape flower',
    ( D4 U+ M7 a* [+ u& v7 K '56': 'bishop of llandaff',/ V5 h- ^: O+ @4 l
    '60': 'pink-yellow dahlia',& L; m. V' |2 x( _& j& d
    '82': 'clematis',
    5 f' a  K# D, Z0 w9 g# ~ '58': 'geranium',+ ~; A  L: o+ f0 B
    '75': 'thorn apple',
    ( u# S" ]  S$ p. t, G '41': 'barbeton daisy',
    ; Y- X% G7 p0 M '95': 'bougainvillea',- w6 i( G' i( f: N$ O, w
    '43': 'sword lily',
    & B0 P: B# p0 s3 I, \+ q# R! [ '83': 'hibiscus',
    5 b+ y! U* F2 g# q! {' @( C8 Q '78': 'lotus lotus'," E. N- T  G4 `1 D3 D
    '88': 'cyclamen',9 U$ T+ _6 W+ }! T
    '94': 'foxglove',
    4 ?( f5 ?  B7 r' S '81': 'frangipani',
    3 t% |4 G0 _# V# r3 F9 a/ b& a '74': 'rose',% H, P1 l9 M7 G' z6 M* A+ h
    '89': 'watercress',* i2 S; R1 Y7 z( o
    '73': 'water lily',
    4 @+ s2 f) j; e" G '46': 'wallflower',, m& b3 p4 g: W( M3 e) n
    '77': 'passion flower',
    5 p9 s: B  R: H  E3 i5 ^ '51': 'petunia'}4 \5 P/ ]" X+ y  b6 e' M
    ! e. v: |5 _2 b" N, ]# c: J1 h. @1 k
    1$ C  e8 j8 |/ ~2 K
    24 h/ H; O! X& q" f  I
    3. L) I, @% {% V. v
    4' o% [# Y8 C, R0 A0 {6 W+ ^( X
    5* s0 }* A' o$ y% Q0 H
    68 z* k6 R0 ~  Y) N
    71 l" _9 Q: T' m& S0 G8 ^
    82 h( I" w. c! q: b
    90 l. j* M1 m4 o; h
    10
    : |7 o. w/ D: i11
    1 b# M/ I6 X3 J! j9 `12
    ' S( Y+ ?7 `9 H8 i: R1 \" u13
    7 z: E7 h$ R( q7 t8 U# R2 `14
    ; }: A1 A9 \+ U9 A15
    * p3 i3 M$ W$ j# J3 X16
    0 H3 o3 G+ P; L% d. [17
    ( h: S9 A8 `4 J# R) C$ g  V18
    , g; A0 s9 J( B6 s19
    5 m8 ]# A: T% N4 ?, d- B& e20) Z$ v/ R$ z1 L; Q4 U
    21" D( l; F- p. Z; _, i! a. N1 i2 i
    22* ~% L! w" V) q, [
    233 y' t  g+ {8 j9 y* ?
    247 B+ V: Z  J- r/ g2 H
    25- q, f& N: D- x. X
    26
    4 m& V3 L  @# w4 }# S' f0 o- K27) t) z( f, J# y1 S& {! [# o
    28: n. j7 n7 C7 x6 e
    29
      G3 ^2 p& N/ g6 C30
    8 @7 }& d/ _6 B; f31' W0 f6 b0 F+ h9 _
    32
    - j% Z( P- `  }2 d% i* V' T33- v! e8 R% ^  o8 O. P
    34
    ( z4 Y5 f; K5 Q- h2 M358 G+ x" _  z3 z6 s; f- y, E- X
    361 Z7 O: k! r# c
    37" c+ r' M: Y2 n4 t3 f/ g8 ^
    386 W5 [. L' E$ {9 @, b% [, `# ?
    39
    / p# e2 d# D8 F" I40
    ! b# K7 f" g  E41
    " [% a4 E+ S8 C2 i* j425 g' P& g% o4 X3 w4 J/ U
    43! {9 a( C/ z  v3 l4 A' A
    449 K' A) j4 q& Z( k% Y
    451 h" l8 y1 A) z
    46
    . v  ?( U, J) i47
    , r" D" r- U( N( I' b" z" @488 p7 M" K- M8 h) `; k5 s) Q# ^7 _4 D
    49* B4 n, j3 Z2 d6 h) @
    50% p" ?8 {0 S# i/ l+ v: Q
    51; _1 J  a. L3 P, T
    52: o# c; K) u2 U* j% a7 {& K
    53. S' g2 E9 G# r0 t
    54. A% {; ~- q( f, B; W
    55
    ' u. @' G$ p; \. @# |$ ^# |561 B6 f0 }$ ], U2 `
    579 p% [  j- C; J" G- }0 t
    58  `# Q4 q# V* E! @: ]8 V
    59
    # q( v2 |0 x, U2 l/ J* d60
    0 b  p- n3 u' D2 M3 k3 B61
    , n7 B7 A' @6 ?  S  ]62, O5 h5 m% k8 b. [
    63
    3 m' [& r2 S3 r645 o0 S3 t- p& b1 s1 d  a
    65
    % ?5 R; }& G3 @6 p% Y664 U* B- Z3 |/ `
    67/ n1 P4 q4 s0 }+ |
    68' D4 J5 ^# ]/ d. ]2 Q- a. v7 V3 y+ X
    69
    , C# [* j$ q/ ~2 W0 y) j70& B# b, |$ B; m3 u3 h3 _
    71
    / E) \, t& @' I2 K72
    1 |: M3 K2 G0 k4 L0 h  G3 Y  C73
    + F2 C: b/ l% h74! M* J. z- F% y
    75! ?- u" P6 u) \, ^+ [/ l
    76
    ! {* M! U7 a" A+ t- Z2 c, C" K77( L9 ?: u- _9 E( I
    78
    # h0 X: O1 _% p  L/ a79
    7 o/ d( _% G: y1 T, x/ A- I# z7 V80* l! P! T4 D6 N0 |: @7 e8 d! A  q
    819 _, d- e* r3 X2 u; p
    82" n% v/ P! W3 p% c# u5 P  \/ z" f- ~
    83
    + N6 t2 B; f- S84
    ' Q8 p( y- P% u& M& L5 g85
    ; ^4 @+ v9 l5 g* M0 w86
    : |' w* N" U( z% P. v7 w1 \5 {87
    1 R4 b& q2 e7 U+ Y88
    4 L" ^1 k7 P& v0 L- `7 ]89; s. T: ]6 {0 b2 u( r0 V
    90! D. h0 l+ j: r6 O7 _, Y1 p2 G
    914 b6 ~% D8 x, r+ w; u! i: f0 f; j
    92
    7 p# w1 }3 V( ~% e5 i7 v& t: O- o93; f8 I& k/ `( y: D5 U  [3 I
    942 y9 z9 x* k" K, ]
    95
    ' {1 r2 F  r  @0 ~2 p  y( Y9 A96
    ' T" e6 ~0 o5 ]) o97
    ; ]9 k) l+ p  Z98. o# o( R% q3 ]
    99
    7 V' t% C! {! I- P: f  m" f100$ N- t" F* T" f
    101+ N  X! S6 N' r8 D/ c# ~
    102! }$ r% \: x: N# g
    4.展示一下数据
    9 S- G4 K5 w  z" Z' ?6 b0 q6 Xdef im_convert(tensor):( W( C4 c5 L( @" c: s+ E3 Q- w
        """数据展示"""
    * m; |8 ]3 S  Z% |    image = tensor.to("cpu").clone().detach()
    8 Q+ j0 Z4 u& P1 C5 _) f    image = image.numpy().squeeze()
    # L2 [3 X+ h" o/ ~/ M( m    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图4 G& r+ v0 w5 A5 a, x' e
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    - c  V4 T3 X8 G& i! ]    image = image.transpose(1, 2, 0)
    $ E- W1 m4 s: S8 S. G2 o    # 反正则化(反标准化)
    ' a. T- E3 B' A    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    ) O: V2 C, _" V6 U& D! }$ _
    6 M/ @+ K# k1 z+ |    # 将图像中小于0 的都换成0,大于的都变成1( j$ ]9 `6 `; T9 I# E5 P
        image = image.clip(0, 1): ]& d# q+ n8 [4 F+ O& |

      R% Q: Z% D. e0 {- g" H    return image
    % q5 m0 i; C5 L5 f5 w1
    1 K* {1 n; C3 p. H2. B# y3 r$ U9 H" C: [
    3
    - I$ ]& E; P1 Z8 u4
    6 L6 y$ u6 w7 V: ]. l; Q5) ]9 d; |" `/ S4 i
    68 O3 `. Q, k0 W# h$ y4 ^9 S
    7
    - Q9 i. B( H3 w3 m+ e8) D/ o+ T) Z$ @% B
    9. @( @6 Z$ N- P/ i& x
    105 D; x1 \9 g0 a1 n* P4 O, c
    115 [/ Q/ _: K( ^: x! |5 r  O1 Y
    12
    / ?8 I: ]" t" K- f6 l13
    ! h* h% m# B( u& r14+ v" I  h: d0 T; Q7 I( J  ?1 ]
    # 使用上面定义好的类进行画图6 B  W0 Q) G) J4 @: J* Z+ A. O
    fig = plt.figure(figsize = (20, 12))
    * g8 K0 ]" n7 }( E: f8 }& \+ scolumns = 4, t/ K4 [+ h2 z; p0 k
    rows = 2
    * k2 W* f0 |; t, X$ c6 G8 n
    ! g) I4 g1 _$ a: o% \# iter迭代器4 a8 }! b+ h" I; A$ H
    # 随便找一个Batch数据进行展示" k% i4 L8 s9 K& }
    dataiter = iter(dataloaders['valid'])
    9 }% U1 c7 S3 H. N  _inputs, classes = dataiter.next()- u' i4 q$ P8 k5 Y: E
    3 M8 m+ k" O9 N5 ]
    for idx in range(columns * rows):' Y+ X0 j0 A& S0 ^* l& d
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    2 q! R; f% F2 }; G    # 利用json文件将其对应花的类型打印在图片中
    + \  |, ?# C5 @9 R+ d/ s+ ]    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    ( o: c/ `9 T+ t  R, k/ F    plt.imshow(im_convert(inputs[idx]))3 `. |; ]5 i3 n+ v8 G
    plt.show()8 s, f  P2 {, _- k
    : g# |; T3 @- }* C
    1
    7 ~& h6 h" ^7 S" g% i- {2
    1 \0 w* ~& `  L; p! f. t3
    + u' x4 e6 G; y% c4
    ! @; x" S: S1 q& a) E9 L5/ W1 i0 b6 N0 i$ C# o
    6
    ( B  v& J9 s; b$ W# E; H, b7
    ' {) g% s- `% ?. n: l8
    4 B6 U3 p) U5 y* j9
    ! I" x# M) v+ z: u7 Q  @* J10, N/ e8 \" e) a& Q1 D  W+ e) j- v
    112 T' @% I# M$ {/ g6 y
    12- q' a4 H. o" U8 o
    130 c$ r! C% q- s2 v+ Y9 m$ _
    149 S: P4 p% E) u8 {: ~# U
    15/ u5 J8 s, t4 ^1 x
    16, M' I$ y0 ?0 t" [
    7 m7 A4 u0 w) D8 ^* l: L
    ) {6 h5 Y/ G7 i- [1 m
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数
      p  d2 L6 W# C# Tmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    9 P. T7 ]0 o9 H# 主要的图像识别用resnet来做
    0 x* l% ~% P- I1 `. [5 G6 Y* s# 是否用人家训练好的特征
    2 |0 m  e5 K) J5 ]9 F9 k' @feature_extract = True( N" A% v- o; f" }$ C6 V
    1
    8 j" J1 `; H/ s2 c21 B8 J/ D7 c% b, V5 ]& q8 V3 E
    36 S( y+ b6 b6 J  q8 ^  r( \4 }
    40 b5 b+ d( [0 H/ d6 v5 ]
    # 是否用GPU进行训练4 e  W/ N2 p) X1 g
    train_on_gpu = torch.cuda.is_available()
    + @& C' f7 K2 P% U2 S6 }+ u6 N/ n1 S6 C/ O! R( O7 j, `  j
    if not train_on_gpu:
    ! B# i& Z/ N  [8 c    print('CUDA is not available.   Training on CPU ...')0 o1 `1 ^- Y2 }% {  T+ j. G- {* q
    else:  ~0 c. v+ `2 o' ]" g( t  E# s
        print('CUDA is available! Training on GPU ...')
    " Z' _8 G% a/ R' x% X) R8 @8 F8 i/ {9 i, O. F/ m/ U" |
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu'). v- w6 |" ?& ?3 b7 f" d$ Z
    18 ^. H1 q; l: L" Q& z
    2
    0 U9 \* w0 e: n: q3: @! X+ j% [) {5 Z% I" \
    4
    8 [; o! d: c! K! F" `5
    : K  o, m# k$ H, g6! P- v0 F- j% V- H- [& z1 J
    7
    2 _* L, g. t  p: Y8
    ( i9 g4 Q( j+ c/ S2 j9
    ; ]6 |4 n7 {4 J8 M4 zCUDA is not available.   Training on CPU ...7 v& g! H2 l: W* c  J) {2 |- g9 j
    1$ c" y% k6 S, `1 K; @2 j
    # 将一些层定义为false,使其不自动更新3 i$ m4 I2 ]+ \- Q! T! s
    def set_parameter_requires_grad(model, feature_extracting):$ Y  P; N$ E" {. L
        if feature_extracting:
    $ c; ~; P0 x" n* o        for param in model.parameters():% d; H6 Y* `/ m% z
                param.requires_grad = False4 d+ d) g9 M; H8 d
    1, {9 Q, K* r0 \& K2 \. |) _# N  k' I
    24 V6 c& Z+ P; J; _8 g9 C
    37 ]; @+ s# P# h  W& x% P) t1 k* V
    4" \) Y1 a( r& Z& X. b! M
    5
    ) b) B7 F' C& n# Z# e' P# 打印模型架构告知是怎么一步一步去完成的. ^  }: k; x# c8 A. [3 r5 ]
    # 主要是为我们提取特征的
    * I' y+ [* ^9 H) f: }0 K3 v; m
    0 n# j' u2 Y1 D$ B, vmodel_ft = models.resnet152()
    0 V: M5 C$ J2 W3 \model_ft
    ; Y/ s$ ^% F, x" _1
    7 i& u. i9 J  ?: b* W' r, i+ v% U2
    9 K) T0 q2 d& M6 f3
    5 w6 q# t3 R) W3 U/ ^4
    / e9 M, L  t; R5. K5 D* D/ k! g% I0 O
    ResNet(
    : v6 R: m  ]. A  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)' C% ]: w0 U6 W* v  _$ r
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ( |& i0 p* h2 U: }; E  (relu): ReLU(inplace=True)) H+ H1 L, V& a  P5 P4 j
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)8 W" E" q0 F7 \8 Y7 U+ u) @" R6 i- `' k
      (layer1): Sequential(: X1 [/ J! O& L# d, u2 a0 _" Z
        (0): Bottleneck(! _) v$ B* I# f5 O
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    1 x# M4 A' K. `- D; X* l      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    # n8 w( E0 p' N3 N% H      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    ; s( y4 ]4 i0 h& i# ]1 h5 ~      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ; k! W& X" \% X2 x$ T" F      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    2 Z* f6 s8 o: T/ R      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)+ U4 X, F6 @4 Z1 y( \: r# k; e( u
          (relu): ReLU(inplace=True), H; i  a& L9 l
          (downsample): Sequential(
    6 L* o  V- C- |  p        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)9 t0 F3 }# P3 l) s7 X6 _4 C9 b5 T
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True). h; o7 W1 k0 I2 A# y
          )
    + B8 o4 t5 \! G8 D  x8 N& ~! K    )
    " ?' A: f$ ~8 z' Q0 N中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    / G5 P" z# C. W5 ?5 }- c# N% |    (2): Bottleneck(# q8 k8 n0 [5 d9 a* E4 x
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ! V8 N4 S9 U! i* T) Y3 r- z4 a      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)( Q. ?9 ?# ~$ f2 Q- F, \  h$ m, M
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)$ ^: g) J4 S1 Z& n8 Q
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    6 I  K% s: v0 G      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    & W! z2 u0 q5 v4 o  s$ d      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    8 D1 V5 v" T. o5 ~! F      (relu): ReLU(inplace=True)1 E6 Z9 O' ?8 c( W
        )* {% o, V( z' c+ x
      )
    ' C; i" ]. A( T% H8 j  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))0 p+ D7 i3 j1 l3 e! r, X
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    % n6 n% ?* l9 G9 N/ k)8 X8 A! d9 \) H  I, K2 z

    ! X) T; E) a. M1* U$ J4 ]1 V  k  E
    2
    8 I! `- ^2 ]$ ^0 Y, d) f: T1 }" y, N39 Q& M. G& Q! X" R0 `3 b" J5 I9 M
    4
    ( b' s  @6 V  H! F6 J- U( P& F5
    / a: L0 K4 q7 o5 d: }) ^1 `+ j2 x6
      d+ U6 ?# o0 b8 x) x9 H0 e* H7
    ! ]) s. Y0 `1 C- |7 W8
    ! a2 b* g3 N! ?2 F' T; K$ r9
    - W: [8 b( v& u2 N% N8 m10
    8 R$ F  f, E, P5 _  @/ O11/ j  t2 o) f5 K' u1 _
    12  J: `# o# X# c* r1 R0 _% ]% K
    13
    + c' R; s" }- p% X$ h5 H9 S14
    3 p/ g4 ~  M9 R* ?$ r4 R/ ]15
    / y# O0 G* r. B- Y) y16
    7 O, w3 G2 h" [7 L7 R4 W! S$ {17
    % X; I. e2 D  m; R6 k- |, K8 r! v18
    3 X, I2 v. [& Z; x19* W1 M' v4 g3 _6 z9 V" A% M
    20
    9 h  O$ |7 R/ J21- e2 g8 ~( Q$ _" o
    22* j# h- q, q) E/ i
    23
    ! ?! ~, e& i/ P: M24/ j: W7 F  p7 e/ h" O! \
    25
    9 u# a. @: X7 E- j26
    3 o6 X' x8 N( Y, d2 Q: k27
    ; D4 x* W, |1 D7 V6 Y: N/ ~8 N28
    % s$ [$ C* U: c29
    ( H# e/ C3 e+ L! Y6 q' V2 G30
    ! [8 {6 n) c% }& r31
    * K; k; {+ |0 b& `- `) E" S32. r1 ?# J& T; M1 t: j: `
    33
    / A4 Y) O7 u! v( A最后是1000分类,2048输入,分为1000个分类; ^4 p6 k5 n8 n6 U/ M3 U( N  O
    而我们需要将我们的任务进行调整,将1000分类改为102输出
    & N4 e/ h5 G' q) l' n1 C
    1 k( p& f8 `  h6.初始化模型架构! a7 D5 D2 m" f- j0 l3 J7 i. q& X7 n
    步骤如下:; @/ F) }$ U+ Y( P
    4 {; Q* Q* i" O. d8 Y- V5 w
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
    5 Y# R/ B  c+ X  `! O可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
    3 H% C0 f3 D! m" S* R' W: |. h# t无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数- V' X. v- F  }
    官方文档链接
    + f" ^) f' c6 Ihttps://pytorch.org/vision/stable/models.html
    " f+ _) I+ O' B
    ; @( A. c0 E0 z! U# 将他人的模型加载进来
    / K0 {& O& n; o' G% hdef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):1 F  H2 u7 a/ N" I% `5 [+ _% @  ]
        # 选择适合的模型,不同的模型初始化参数不同
    3 |+ S5 `: O' x) ?: r  n& \' R    model_ft = None' P' k3 I) \  O( i7 y
        input_size = 0
    ( B+ j$ T0 X3 h2 g* x2 i1 b
    6 r* q& t; p; c% B    if model_name == "resnet":  P; Z" A! |* l, r+ a
            """5 d4 R  t& w' W8 U$ D; z5 F
            Resnet1521 {" D6 g  d' E1 V# W$ {, ]
            """
    ; w* F. |0 Y9 I( K% @( H3 A' o! K0 n3 f. ]
            # 1. 加载与训练网络
      U* H. d  `7 r2 ~# P( m        model_ft = models.resnet152(pretrained = use_pretrained)
    - X( e' @7 z; D$ t+ _        # 2. 是否将提取特征的模块冻住,只训练FC层
    9 k# z& S; B9 |$ y: V        set_parameter_requires_grad(model_ft, feature_extract)" E6 u" p9 T! ?+ C8 }4 Q' }
            # 3. 获得全连接层输入特征1 I, Z# O' ^5 y* j5 s
            num_frts = model_ft.fc.in_features9 x  K4 K2 k; T1 y4 i% G
            # 4. 重新加载全连接层,设置输出102
    / y3 q. k: }( Y" z( x. X7 f        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),7 g& U/ q% g# h0 E3 ^  f; L
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1$ E" c9 T* l: J% X, g) I
            input_size = 224
    + q/ c* V8 f; H5 ~: T
    ! D( |0 R. y% m1 X& v$ l7 R/ b    elif model_name == "alexnet":: `( G/ @, w* N, Q
            """
      N5 S$ H0 Q$ T; h5 n4 H, i        Alexnet6 b+ a7 a& J. x( \3 h/ S2 O
            """
    $ h. f3 d0 Y0 s4 ~" G0 G1 B  M        model_ft = models.alexnet(pretrained = use_pretrained)/ w& p) E2 W8 ]" p* l, H
            set_parameter_requires_grad(model_ft, feature_extract)0 }, f* B2 w! }1 Q6 L" O

    0 I4 P3 |9 I' W1 X& d0 E& Y. s        # 将最后一个特征输出替换 序号为【6】的分类器" O( p) l- f: A# t
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入3 D  M) v7 V/ X/ o% Q
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes); Z& C5 \2 }. s/ |
            input_size = 224
    3 T. s% P- D1 q* n6 }- ]0 |& G! x. n9 o3 U  Q3 S, Y, _
        elif model_name == "vgg":" B# E! H; I& V5 e3 g- I
            """2 r1 ~: s, w" U5 A; m/ |
            VGG11_bn
    3 u9 @+ Q8 r  O        """9 V7 q  Q. {2 J& Z$ ?6 w: _& ^9 t
            model_ft = models.vgg16(pretrained = use_pretrained)0 t! v, d& w  a. k; q
            set_parameter_requires_grad(model_ft, feature_extract)
    % E; |, f! H1 _/ {; U: i2 e" B        num_frts = model_ft.classifier[6].in_features
    4 _& e# D9 V4 _' s, E/ B# b        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    8 Z* t( _3 ~/ v; V, z5 F' d        input_size = 224
      ]& Y4 l  B1 z5 C) q) W6 M- u5 ~9 J  u4 X# P9 I4 A8 W
        elif model_name == "squeezenet":% ~- |0 l( C7 E
            """! S1 J; h9 |% {4 n/ W: O3 Q0 Y! ^
            Squeezenet5 A1 r) q) F7 \* m' J
            """
    5 `' u* J8 I8 S9 N+ y6 H        model_ft = models.squeezenet1_0(pretrained = use_pretrained)5 \( H- c4 x4 r' S: U1 P2 f4 j
            set_parameter_requires_grad(model_ft, feature_extract)
    % x8 V8 x/ f) \# W5 ?. g* l        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))2 f) ^1 }2 P: e0 J
            model_ft.num_classes = num_classes: L. t! {* h9 _4 T- x3 q1 P
            input_size = 224
    $ E! z' L, v8 z7 j" ]2 p0 j: I/ P
    / d6 }' v' {( ~    elif model_name == "densenet":
    % H) l9 J: G8 P        """$ h9 Y( ~7 `* f% M/ D- s
            Densenet
    % a* e. s* c& X4 N4 C        """
    4 @% ^& B' {3 P& K+ W        model_ft = models.desenet121(pretrained = use_pretrained)# [) E0 D9 I! _: p, M# \& ~
            set_parameter_requires_grad(model_ft, feature_extract)
    ! u5 t0 @. f5 l7 r+ k( P        num_frts = model_ft.classifier.in_features
    # s8 Y5 n5 a0 t% M# n$ ^7 o, W        model_ft.classifier = nn.Linear(num_frts, num_classes)
    9 ^6 }8 f7 `$ P& U        input_size = 224- _) P; a+ c- D3 _/ N1 V' U- J: E' V+ q

    / h8 V! v, `, c' W4 d2 r4 v    elif model_name == "inception":
      b1 F# x1 Q& j        """# r% _" r9 U1 R; O3 D
            Inception V3
    / v5 L$ T: ?( i9 L& h* j) y9 q        """( Y1 B$ ~3 X; ], i4 m5 k3 b
            model_ft = models.inception_V(pretrained = use_pretrained)
    - W% t5 K+ _  z0 F  U        set_parameter_requires_grad(model_ft, feature_extract). i6 R4 o5 _9 Z. c! b
    . G2 L! V6 M+ }7 Q
            num_frts = model_ft.AuxLogits.fc.in_features
    $ C) t, J! g# s% M& p- `4 Z# \$ e        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
    ) @8 E4 C1 D3 s% A( N. e, A! T7 Z6 M+ ~" Y: J" @4 T% c
            num_frts = model_ft.fc.in_features$ o' j+ q; I" ^  H- a+ }% }8 a  ]
            model_ft.fc = nn.Linear(num_frts, num_classes)% k' w+ M" J( p1 B
            input_size = 299
    7 s- K$ V# _& N8 I* m& [4 ]' G
    " V$ |! Z. U/ F7 Q. x    else:
    . a/ G$ x9 y! l        print("Invalid model name, exiting...")
    / X9 }  m0 z9 B$ T: L/ g6 n7 i        exit()
    , P7 H  ~1 B! o8 o, v' b8 o
    3 G4 N/ t2 s! Z" j; Q/ @8 v    return model_ft, input_size
    7 W. \; w0 j+ Z  Y; k
    & G8 b. Z6 d: {$ K1
    + X2 i/ q4 J- \6 q2% {4 ^& B: q% I* r. Q, u3 F8 ?9 p) J
    3) k2 H2 e1 j8 |. Q5 k' o3 |
    4! h, C6 r2 R9 n2 m
    59 W0 t; d: b4 A4 @7 \
    6: j3 p1 p# W- @8 C
    7
    & s' j  r' W% R7 s  {& o88 d  h$ X7 n3 T! \$ V& r
    9; u4 e* ^5 t/ K3 l6 n- t
    10
    ; C9 X7 X7 o6 D4 p, x115 u" |3 j9 i. H- o. j) y/ @  \, U
    12& V9 E6 }3 K5 w+ I6 A! K
    13
    # `( D) c2 \5 {0 e14% C# P# [$ g+ X' ^8 X0 n
    15
    1 `- w$ M# f2 C  h, W16
    * E5 f9 d* P6 P! I. D2 M: h17
    ! U, E. o  ~1 E187 R5 S8 v' ~0 R' T- m4 T; h
    19
    ) [* T; e& T) F# t( J* t' [20
    % O- |: A3 ]$ A, D21
    5 ]$ Z! h! |. l7 ~+ L( J8 J( E0 z# L22
    ; J8 M5 o- S1 d4 K$ d23
    - ^0 `$ k& }1 _& T0 I9 p1 z" J24
    " R& D$ j- j. N; p* ?' D25
    3 T9 c7 ~: U  w; B; z& O' L26
    ; z" |  g/ q  g0 X$ Z27
    # f. B1 a! R( x' Q! b% R28  h6 V6 {$ ]! I; }2 H6 `" G
    29
    , ]  c/ j- C; E; D2 z$ W& g302 `* N6 a5 P: @1 C5 e
    31
    * B& D) d! S: |# v% [$ N32+ p$ g4 S2 }" `/ L: A  Y/ |- X7 P+ j
    33" l9 o3 [$ k) n8 w5 m
    34# k# L8 }' f6 d/ \9 q$ \# c" P
    35
    $ ?1 j2 I% C9 D+ `36
    1 `& Y1 V, c2 a37! V7 }2 ~( s+ V3 B+ x9 Z) H
    38  o* X9 V2 t' P9 S: E+ V  o
    39
    3 C) T, J/ P1 C( i$ w2 |  |401 N. a, }6 v! a& V/ B
    41+ p' ?, P; J/ q) j' }! {8 K1 `2 C/ Q
    42
    . Q/ F  ], I" @" _; l! v0 a43: A" g' `* K- ~; J1 T9 |
    44# S2 O) v" U' j, }6 A8 P/ {4 i# c
    45
    $ m# p' c, b# ?" F* ?+ X: ~- |46
    4 ]1 [4 @. ?) P. Z9 [470 |' h4 X4 i! V9 D6 x
    486 k, Q- S) B  c9 A+ T, t! `4 k- S
    49% y4 D1 ]8 m" ~& j
    50" _5 O# v4 w8 a4 ]
    51, H: O  M# |; f  E
    52
    , t7 @% @6 n# g! B, A0 C53' y$ C. h; Y% l" a4 F" m' b
    54
    ) p1 _" Y) ?! U) T55
    - ?2 A& E: c: |: f" a( [) c56; d: T. x: B8 T# ]( C
    578 ?# _+ I; P2 Y
    58
    4 B' y9 w' A+ G( P/ Z- D- q( p- x/ A59
    , w: A% w* I* J- a8 @60
    4 b0 G1 C* U3 Z0 j' }; `- C! t% |61* o! n; I0 s3 I, J+ y. ]% J! x. ]
    62
    ; U: e# M1 P  e$ \/ m63
    5 U* Z! n# {2 |64
    4 r  [/ L/ k6 k9 V) ]5 H) a! D65
    ! @9 C, h+ \$ V& t3 ^5 R1 O66
    & R/ B. X. `( M/ T678 n4 }5 {' S: ~7 @& D
    68
    , p/ t8 n7 _! `) t. _7 F# o69
    % m& ]# h# C# s70
    ; s" S5 m% X0 C6 l6 T. X718 Y+ |3 }7 r8 e( k# _
    72, \' ]7 [" F' w4 w! Z
    73; V+ e2 ]3 `* h- D/ B4 }7 x
    74
    " o0 k; L% c" {75
    # c5 \7 x4 l/ m- R# X+ `76# b5 ^' V. l! J4 f) t' D
    77
    2 j8 j; u+ W* y- v! J. I78
    ! r2 ?) v$ P9 |% x  z: [3 a79
    6 M) U: M. P* t# e# C80
    9 ~# v/ }3 D* [3 G81
    1 w! c! b/ Z: k" ^( r6 |% S82+ \% p- D! \; p
    833 ~1 P! R/ V2 C* B) a! E, V
    7. 设置需要训练的参数
    ! k9 u; C* l8 v$ e5 B0 j2 Z# 设置模型名字、输出分类数1 Z( Z, P" i0 z
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    , |. ?% y9 _" v% G6 |1 J  d* f! Y2 u$ T% c5 n: X
    # GPU 计算
    # ?; z; @' a/ H6 w0 s$ nmodel_ft = model_ft.to(device)8 e& P8 ], N. T8 P" e; |6 T+ R% F
    % H, D( x# b# }: K5 m5 b4 c9 h
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    1 b& f2 N. ?' I. ]$ hfilename = 'checkpoint.pth'5 A& Z1 y8 D& Y# w, Z! D! V* d1 o
    6 M- }* P* h# ?* C7 C4 [) Y# y6 r
    # 是否训练所有层% O. U; F% b' F
    params_to_update = model_ft.parameters()% V6 Q% q% N; V6 b! o2 v6 L2 ~2 q7 o
    # 打印出需要训练的层
    " D# v0 o5 l5 {6 a4 G( E1 ]print("Params to learn:")* \9 U! \; {! @8 c
    if feature_extract:) j, l) R- u9 {; `( u
        params_to_update = []0 h% x+ D4 N$ u
        for name, param in model_ft.named_parameters():& v9 q+ ^( t& r1 V+ T
            if param.requires_grad == True:7 V; d' |( D; T: ?$ }
                params_to_update.append(param)
    : U4 y0 t% \- G0 I4 G' U7 J) @; j            print("\t", name)% v: g$ ~2 t. m0 V& _  l$ }; X/ L
    else:
    7 ?& K1 Z8 _" ~7 D4 F    for name, param in model_ft.named_parameters():
    6 _1 _9 Z) x1 M9 f+ f2 B( U2 x        if param.requires_grad ==True:
    & u3 G2 G  J, e$ }            print("\t", name)+ G+ {+ P0 f9 f& b
    6 V- Q2 B) O) T
    1, B) T6 W# N+ z+ b( T, H8 Q
    2
    9 a  I1 h( w! H4 c: t* s3
    " Z$ y+ U7 N. H5 E5 U# |4, y0 o) |" n  K+ L4 L
    50 ?' ^' l  V% j8 Z% V1 ]
    6/ v* Y3 w9 d% k# X! |8 O
    7' u/ e3 e4 M' o. v' E6 E2 c6 |
    8
    0 c; X0 u. d( t7 \2 ]9
    ; _! P4 _& h4 t7 ~* o& @* k. S10
    ) K" {4 l; P* \0 q3 S9 s9 y# [3 \11- X" G9 v8 B, ~: |: ~  m  K
    12
    & I$ {1 F. O3 D( z7 h1 j: L131 {2 X7 I- S) N8 M! [& m
    146 _6 V4 R0 n  X/ m8 R0 m
    15. K& _3 b$ b1 v% i- i; w, K
    165 T; S) F) w6 J5 m9 i" Z
    17( M- [) ]4 M7 F5 s, Y9 {( T  ]
    18( u* t' ?8 L' g4 n! ?: N
    19
    ( U' R$ c3 Y" O" b20* S2 j8 {: p# F9 |
    217 c( |2 O/ V2 P3 z
    22
    ' t4 N8 W9 ?7 _& ?# `2 S" c' j23
    / p" R& o) {( Q. C0 JParams to learn:& M' {% d3 W! i% M2 s- w- V# ~
             fc.0.weight# m! G7 y$ p& y, {- n- \' N
             fc.0.bias4 f1 v% K9 Q0 H( S' i# B! v
    1
    7 B, f! B2 q# v- D& S8 H2
    - b/ _9 N# ^, h7 H/ @3
    4 ~  Y& P; F( z- g: k6 p7. 训练与预测9 o% P* t1 v* C1 n
    7.1 优化器设置
    * i7 I5 R' a3 l5 k# 优化器设置  {+ D5 u- R% X1 x( n
    optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    ! X3 U! g! S1 G# H5 o0 g5 A# 学习率衰减策略
    2 D3 u* w) ~* j) A; e: X' ascheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)+ n8 j$ @4 |6 f; Y( H7 E! M
    # 学习率每7个epoch衰减为原来的1/10
      Z/ p% B/ I  H+ e# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算& @6 Z2 I# c$ w$ b5 `

    5 ^6 b, Y8 j( qcriterion = nn.NLLLoss()/ R8 H! w5 L0 {; v0 l: s
    1* ~. b2 I' {1 A  d* `" Z9 h
    2
    8 Y7 q! T, [6 `  J6 t3
    " a4 ~# b3 ~- w  v- z$ {+ m4! y/ {5 X/ s3 A7 z/ N
    5
    ! s4 \8 ^2 G6 Z3 V/ N$ k: ]7 ^6
    + s7 c4 G! W% e' Z4 B0 {75 c5 j5 p- r6 d. b  b7 G
    8" x, ?7 e! I# z; i
    # 定义训练函数" C2 b. H0 K* Y: J) J
    #is_inception:要不要用其他的网络
    4 L2 V, S6 m9 g$ \0 O: Qdef train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    8 L9 b) }$ l: V) p' _! S    since = time.time()
    / ?( D8 R1 ]( s    #保存最好的准确率
    ( k6 C) h! q$ G* l8 v8 E, N    best_acc = 0
    * N$ I' H3 V3 D/ L( z1 O8 x( b    """7 F" F, r$ y8 v1 b. l/ {
        checkpoint = torch.load(filename)
    : w4 R9 n  g4 H8 _; X    best_acc = checkpoint['best_acc']/ R% x" R8 H# S; }
        model.load_state_dict(checkpoint['state_dict'])
    . n' g4 v3 N" N  V+ q    optimizer.load_state_dict(checkpoint['optimizer'])
    . A/ d* E- r5 }7 r: }7 s% Z& `    model.class_to_idx = checkpoint['mapping']# K; }" r% k. P. u6 J5 b/ l
        """! n& h' `0 o" z, H9 a
        #指定用GPU还是CPU0 H8 z; O2 P1 T0 u3 R2 I! f+ L
        model.to(device)
      y( o# I+ s" n$ L  h2 x    #下面是为展示做的+ L" s* ?0 E9 u, X9 o' U
        val_acc_history = []- d, m; A7 r6 O
        train_acc_history = []
    9 N4 A% J2 ?+ q; u/ N    train_losses = []$ o! y6 i+ ~  R9 f- y
        valid_losses = [], f/ T" x6 {0 a' I( I. |
        LRs = [optimizer.param_groups[0]['lr']]
    # K4 r9 g/ S+ B    #最好的一次存下来* Y! @* \9 T; L+ Z8 S6 J
        best_model_wts = copy.deepcopy(model.state_dict())7 ?% E3 }. g% Z; S
      [! ~% @  [5 y% P
        for epoch in range(num_epochs):
    9 i1 i8 S+ Z; w$ D8 i        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    8 l# z! d- W5 t- {- l# `7 K9 ?2 x+ d        print('-' * 10)* [! h" F# H+ n$ V3 i0 K* n9 M7 f

    : z) v  R; B1 p* R3 Z# u! e        # 训练和验证
    + _# r$ h0 t1 i  |3 U* N        for phase in ['train', 'valid']:
    2 \3 O8 u7 ]' u! ~" y7 {            if phase == 'train':4 j9 B) x, ?2 B2 q9 Y, Y& h
                    model.train()  # 训练# `  I% ~8 @7 G
                else:
    9 |1 e/ B# q7 L. {3 K% ?' \" O9 t6 w                model.eval()   # 验证6 V, s' v' o" O8 A$ y9 g: ?' o

    1 M. R# Y- K; k* B- Z5 M" k' l* E            running_loss = 0.0' A: |% Y' d! ?" @- w
                running_corrects = 02 M, F  W/ m( W, P. |

    & q; T4 D6 Z# c8 p8 M' p            # 把数据都取个遍
    6 \. h  t/ l3 x* T            for inputs, labels in dataloaders[phase]:
    / H# Y9 Q6 h4 d8 f1 T% N" q3 a                #下面是将inputs,labels传到GPU4 |, c# u: [' ^2 r- V# j* w
                    inputs = inputs.to(device), M: i% L  Y) N+ K' v% L3 w; d
                    labels = labels.to(device)* y& J: e; E- O7 G
    # k3 @. r! m# V& u9 }8 c% y6 S( [
                    # 清零
    ( {# K7 E" c% h  B- n# p                optimizer.zero_grad()
    & ?+ W! J7 z! ?  d- R                # 只有训练的时候计算和更新梯度
    & c+ W/ F( T! N                with torch.set_grad_enabled(phase == 'train'):
    8 l. B+ ?1 }7 x# t7 o                    #if这面不需要计算,可忽略
    7 A1 C' f0 M+ _4 ~' d; t                    if is_inception and phase == 'train':  g. y/ ]. ^" k. t1 v% Z
                            outputs, aux_outputs = model(inputs)& R' b# ]/ e6 B, C" J& S
                            loss1 = criterion(outputs, labels)& x! n# i3 H! y0 D5 {; G* Z
                            loss2 = criterion(aux_outputs, labels)
      H+ {6 K# I' }& [: V/ A                        loss = loss1 + 0.4*loss2' A9 T) q/ f! p6 j7 O
                        else:#resnet执行的是这里' n6 d# T9 Q0 v3 {; c4 j
                            outputs = model(inputs)
    3 Y# D8 a0 L/ I2 J( m& q4 e( i                        loss = criterion(outputs, labels)
    5 B1 L8 g; d! H
    . A! k: c% k0 W% _                        #概率最大的返回preds* j3 @, Q5 c' }6 t
                        _, preds = torch.max(outputs, 1)
    3 k$ K7 z" R3 d1 M
    7 `- \: I3 E4 l' `- _( v                    # 训练阶段更新权重
    % k$ }* J1 u! Z  T/ H1 A3 |$ d( t8 T! V  \                    if phase == 'train':# S* i- r  T0 E
                            loss.backward()+ c6 N" A" c2 N6 `  ^2 x2 `
                            optimizer.step()
    ) G. Y  Q) {) _* \( Q' [4 R/ X$ j8 k& Z# j2 X$ x
                    # 计算损失
    + Z! t. B# I2 p- V: K2 L                running_loss += loss.item() * inputs.size(0)
    7 B) Q, |5 n. W/ A4 m, e                running_corrects += torch.sum(preds == labels.data)
    ) ]1 H, K0 v( B+ ]2 t- X8 M% ^7 o/ }8 h
                #打印操作4 m) ~/ v8 T+ ^0 q3 `% E; P& g4 W
                epoch_loss = running_loss / len(dataloaders[phase].dataset)7 [3 W# H7 l) ~7 P
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    8 I+ Q! Y' d/ |9 r1 h3 u
    6 ^( S# n9 O/ y9 _) L
    5 j1 W9 n1 `. x! t$ q            time_elapsed = time.time() - since
    4 S! T1 [0 Z" w; |. k8 O2 ~+ Z            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))) e0 C, U* d2 e6 e
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))6 l( M! d3 ^# m' M' H2 b/ D
    % y4 q" }! G2 d0 S+ l

    : R2 q8 h* c$ X: C            # 得到最好那次的模型
    / _  X  f2 ?# y8 f  P$ G) u            if phase == 'valid' and epoch_acc > best_acc:, ~3 i$ j# i/ M  `/ Z( W1 v
                    best_acc = epoch_acc5 M- x: O9 I! ?: O
                    #模型保存
    - N8 Q+ E' U9 F  h                best_model_wts = copy.deepcopy(model.state_dict())9 H, z2 O5 U4 r
                    state = {
    8 F: S$ ]4 t+ d# n, t: C                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    2 k1 a& j& l6 P3 Q6 G" {2 i                  'state_dict': model.state_dict(),
    9 s: G) Y' c2 N/ u+ e9 w0 ]                  'best_acc': best_acc,
    6 S% c! r( F5 G- ]0 s3 O5 q# C                  'optimizer' : optimizer.state_dict(),, C0 |( C: p7 V( a
                    }
    " v* H3 y  M6 K4 ~8 f* t  J                torch.save(state, filename)' C4 X9 B- s* g" O& {
                if phase == 'valid':
    9 Q  c" u1 v9 f% F' ~. b+ S; i                val_acc_history.append(epoch_acc)
    ! y0 g6 E  O5 o# u                valid_losses.append(epoch_loss)' B3 c! m& C) a$ V/ o8 I
                    scheduler.step(epoch_loss)8 j* C+ r$ \7 I4 X3 q0 U: x
                if phase == 'train':
    4 N& ?4 s3 [/ K1 Z" B                train_acc_history.append(epoch_acc)2 f6 P  l: i2 s  ?* h2 d
                    train_losses.append(epoch_loss)6 p8 Y8 N, i& [7 X

    1 n* y# I1 s7 u  c3 x        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    ( {6 d7 s5 W- L( I        LRs.append(optimizer.param_groups[0]['lr'])% q" f5 E' b2 R( P' R/ J
            print()
    : O& Z; n5 A3 M2 ]. G- Q: q% w4 h$ H; d: G+ O! S6 c1 b5 |) w; i( \
        time_elapsed = time.time() - since) q; c1 U. }' x2 P# ?
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))% u( K/ F2 m% W; ^% f& p: A
        print('Best val Acc: {:4f}'.format(best_acc))6 ^# E, s; i2 Q9 ~' V' w

    7 ]. X* m- q4 b: A( Z) o    # 保存训练完后用最好的一次当做模型最终的结果
    - L8 p7 ^4 ]! B- T$ `    model.load_state_dict(best_model_wts)
    : L0 L- n2 G2 b) U; f9 \    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs   {5 s" a6 Z9 `' s3 f8 i* |& g4 [4 z
    * H! \# I& P1 K3 B1 R5 C

    - B6 i: z6 l5 ]/ ~1
    9 b/ B& Q$ ?7 q. P' V3 ?2
    & s: i4 }6 }" e% ^35 F8 R9 y2 ^7 |/ ]2 K7 N
    48 O5 ?/ \& \0 b8 B
    5( ~) }/ e& v: A1 a+ ]
    6
    9 S5 |  s4 l8 y" f8 ^2 ?7 J7+ S5 J2 r3 T8 X  D1 @8 k; E! C9 v
    8
    3 `* N$ @  [8 d" e  A0 O! b1 _9
    - J/ l6 x8 p- y& d! V& d10
    1 L3 n+ c6 U2 }11
    5 T. y" |" K  R( j) s' a12( h$ L& ~, {3 g9 \( h
    13" c- ]/ G# W/ v* d1 Z
    14
    ) Q3 ]& C/ y4 y# b& t% R. S153 ]% Y9 a: g0 D
    16
    ! z# ~* _& T/ @+ N' M& d- y17
    - |; y* S/ C+ s, x+ m18% @% w# u$ Z+ k, D: Z: h9 j
    19
    ' r, N7 G1 E6 j1 s8 B5 N207 F7 D! p: q. A4 g$ \
    215 f( g0 k3 v, V: G9 p. }) R6 S/ ~$ L
    22( Q  F6 M  B* F$ v4 X( m
    23# ?. O' F9 M+ K' b" i3 {) e8 b; _& ?5 M
    24
    - T1 \1 z& g/ c3 ~9 o25- a3 c! D7 {' C6 r( ]0 b( f5 f6 G
    26
    9 n8 a. x2 {: K: G* e. E27
    : ^6 Q( o( z! s: A" p28
    - w9 k2 W' q- a. x$ b29
    6 o- H# s% @; o8 d" R$ b4 O, F302 ~. }3 x% T& C3 Q1 r
    31
      D5 S0 s5 B8 B# ?+ q- c32
    ' _' o4 s$ r( {, q33
    ; S% P$ ]5 w( i/ W6 y34) s! `7 A1 D  v0 Y
    359 [8 }! d3 F. }( H; t
    36: P: N- g, Z3 _; ~* |; ]0 b" W
    37+ u; B6 d* }/ e4 a5 W. j
    38
    ) r8 w# J% \: i+ a  ?39/ N/ b: p( v; \
    40
    # i3 Y! U, x: ?$ c1 O41# X. E& e6 H$ a* j' Y& ~: G
    429 G" U" |3 F7 O" ~! \6 o+ v
    43
    ( u4 w& `2 T& [0 V" A* r: v% u3 a44
    * F4 N# V2 t) u7 T8 C) u& H45
    / c1 X9 Y/ F0 A+ Z6 {460 E3 D* h7 I3 m- T( M( g
    47
    " n+ A' T9 K2 Z& O7 r8 j! m+ G48
    0 V2 N: I% D& t7 B$ b2 e2 m498 k" Y$ p2 j$ s; F
    50
    # m% a7 Y! V" T& ~1 i51; `. v- `( C: Q6 l
    52& X, b  C% j5 F7 g  g0 ^. m
    53
    4 I' g$ N) v6 `# _549 _9 l7 \% \1 L. [( @
    55
    5 \3 z: M9 f: A2 M0 P5 }56
    5 B& s' G9 Q' E* [. Z2 ^/ R57  ?! M# ]3 j( @. r
    58
    / i/ v$ d9 y' ~& T8 I59
    . K2 E+ [# q* _+ B( T60  x+ D. v1 ]3 E+ Z/ ]2 s% C
    61
    5 ]! c: |* w, u) T- E620 v& j  ~, _" J9 _/ ^4 Y
    636 M: m, t% A$ [5 Z
    64
    7 `. i& S  E8 ]% o658 E, A. e  k4 u3 G0 M3 g
    665 B4 C; ~+ ^) T9 v  ?
    67! n: U0 t% N# M# j
    68
    7 y0 D: Q1 ~$ F1 Q69$ ]/ }/ ^& Y) m7 _- P# U5 C
    70" Z( c) k3 D6 F  q
    711 s  B; G9 c! a3 q
    72
    * Y+ o! Y& z- B0 q) a3 m& q73
    8 s6 Y' S$ T5 b5 V; b4 a" M74* F& d6 y" `0 B2 w1 [
    75
    " G) b- _# j2 O/ r7 R/ G76
    2 R5 ]  O4 {% {& P9 y# s+ v77
    ) n& U( ]7 e+ P6 S5 B78
    4 T) z0 a) _( |1 X8 J79. h6 l& g3 J) U* b8 y2 o' {+ M
    80" Q7 P! f6 Z) u3 c( [% v
    81
    5 v$ D! k/ c" V9 ^: J; a# c& k7 w82: B% Q& z2 V0 o; P
    83, ^5 \4 V3 G: h  k
    84
      a1 r/ E$ J4 z6 ^856 P( c" s: z. }; M; U; A0 @  M
    86
    5 d: f9 ]0 f6 t1 i" h& R7 [87& b8 ?- w' h$ J" z2 K6 a5 g
    88
    ! l$ \! n3 B4 F/ ^" G89
    0 B* ^! c4 F2 z$ Q6 L90* _% b2 @* B  S4 w/ @
    91
    ' }$ W9 X* F. m0 O92: ~$ N( X+ C4 G" A
    93- R7 ~" S  A! o- s! ^2 S
    94
    $ f3 o& N% f2 n2 f$ E+ e' R95% A& t/ x: z( a5 S# M6 S7 Z' }
    96
    8 w* L6 k) w) S% T: \" K97
    & q6 E* w* D5 `- c98) ^2 x% x3 _) V1 Q$ e+ H& I
    99( s, H6 M6 S" i
    100& y7 D( v' b+ k5 P
    101
    + M# d5 U* {% x. F, d3 [1026 c1 \1 q' \- j" i3 M2 x& j
    103
    - Y- ?: K2 z; u/ Z+ @104
    * q) C6 h! b6 g* D3 Y105
    7 a) H8 \  {& z/ B' @+ [" \1066 N; @2 }" b2 S! H
    107
    ! `/ l7 w! H2 [6 u/ M2 j/ p3 G) j1087 E8 E5 |( W: _6 c4 ~) ~* D9 C
    1097 N; W# q% a5 N' ?, D1 Y) a+ a
    110
    0 L$ k; e# T' B$ [111. s" w7 q' j  @$ `* s
    112" a# _& [+ y+ b8 X* Y0 C
    7.2 开始训练模型6 M% Y2 H, ~+ d6 I8 |
    我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次+ T9 |9 O+ X) \" o, B2 h) R3 u

    5 p: d$ _( y) q0 b4 S#若太慢,把epoch调低,迭代50次可能好些
    , X+ A1 n# w" k/ a/ `" T8 S. @7 t! b6 M#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合3 e, W( s) w2 @1 \7 x& Q+ b
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))8 g  ?. ^. @4 n: }2 X; a' \3 D
    # `9 u4 [. x" c6 K7 m+ i; f; j
    1
    2 J: i3 E% A* H/ e* D" P8 ^2
    " `0 U  U# s8 A" t& I3# f* E' G+ u7 T+ q9 t/ V" L7 G
    4' e) G* f( I# w
    Epoch 0/4
    6 W2 n4 K3 t  w# ?* K' ~----------" V! B  |0 [; V# H: ]) ^$ p( {
    Time elapsed 29m 41s/ V0 r* J. v4 V, s/ U% z  d
    train Loss: 10.4774 Acc: 0.3147
    0 Z# S! m; p6 a: G  OTime elapsed 32m 54s
    % b9 X4 u2 F0 dvalid Loss: 8.2902 Acc: 0.4719! [0 }# O, a/ l
    Optimizer learning rate : 0.0010000- j" h( b( e' s

    ) K. i9 B8 U  J4 o0 i1 sEpoch 1/4, M1 U- {1 k% f
    ----------
    7 t# U) z7 z( O/ S4 BTime elapsed 60m 11s
    / N1 G' h+ f* _& Z: i" @6 Itrain Loss: 2.3126 Acc: 0.7053
    8 q, {/ @$ j; y0 tTime elapsed 63m 16s' R8 e5 a5 U) [- `& P* R
    valid Loss: 3.2325 Acc: 0.6626) Y7 ^% J  `4 ~' x
    Optimizer learning rate : 0.0100000% ]8 z' k9 Y0 p" Q# R, `

    1 @, |' N3 {7 ~, b6 T$ FEpoch 2/4
    ( W! K8 L& v+ B3 a9 e9 L----------2 l: d1 j7 _# N3 B) x
    Time elapsed 90m 58s
      A, r; P# W' f" {train Loss: 9.9720 Acc: 0.4734
    7 ]" [" @4 S0 j! }Time elapsed 94m 4s$ K  Q+ m7 S3 e" t1 T+ i* g, ]! C
    valid Loss: 14.0426 Acc: 0.4413, l& ^# s  P1 w) E4 C$ o' p
    Optimizer learning rate : 0.0001000
    ) s0 D/ \# ?4 [' [. _) P5 d$ _9 b- q1 O/ P- F* C" z, i
    Epoch 3/4: h. q5 k% T8 p8 h5 j" c
    ----------" o6 l. ^3 e( D+ O
    Time elapsed 132m 49s3 b+ p7 J- u$ K8 W" y1 P7 O- h
    train Loss: 5.4290 Acc: 0.6548, P) B* p7 k# l- i' P
    Time elapsed 138m 49s' w3 c0 H) [4 B, F, B( @9 M
    valid Loss: 6.4208 Acc: 0.6027  ~# P: v% s4 n0 O* B3 q
    Optimizer learning rate : 0.0100000
    , G( p! W: u% H1 E; V+ V7 L% s' H+ K$ B
    Epoch 4/4
    " ~9 @2 h" F- d& W/ n) t' o----------
    $ A' v4 q4 g7 m( Z7 qTime elapsed 195m 56s
    9 v1 A* e/ N' J9 `" h9 a5 Atrain Loss: 8.8911 Acc: 0.5519  {, d  @3 i9 V! g3 c5 r
    Time elapsed 199m 16s
    ! d2 Y, u6 ~$ c- w  Y5 J) dvalid Loss: 13.2221 Acc: 0.4914
    - h0 A, Y+ A$ x+ J4 z  x0 ]! Y1 ROptimizer learning rate : 0.00100009 L5 T  F* A' M# L. K# ]& E
    0 O5 m" |$ A7 S) p$ N
    Training complete in 199m 16s
      g6 ]) p3 ^% k& R5 yBest val Acc: 0.662592! W: ~2 K$ K1 A6 C

    4 J- C0 l' l% I3 ~  W5 n6 s1/ B; E+ Y6 x" [3 |
    29 j$ [9 q& c7 W0 \3 m  S. Z$ ]
    30 o: O6 m; a$ |
    4
    5 k; U3 x2 t' i  m, P! I# _2 m) H5
    ' a8 P: d- |3 N. R5 v6
    5 C2 K0 L1 M2 `+ w, ^7
    0 k; B6 [, H& I8
    ( _3 @$ d: s! n, P9
    + O# \4 v/ P$ u: M8 @- A$ T10
    $ s" p5 K/ N3 O1 @; F5 e11
    : l& I1 t" V  W& W+ q7 h& q9 y3 |) x: [12( N4 _: V6 r4 c: D9 c7 K
    138 {4 f. `" `' {$ t, ?0 |4 A
    14
    " @. |. e% g& u& @* `$ u7 A, i15
    % H1 F9 r* K! [  x6 [, E16
    " G$ x4 s1 A- e5 _# N17
    0 u  @5 N( b/ d% ?% H5 Y7 U18
    - ]& K4 x. k- z. h* m; S19
      a: g# L; c4 |" z20$ N( @; J7 k4 L6 Y
    21
    5 g( O' |; v% I8 j1 i. C22$ f! f* s  X2 `4 V( N
    23- [; _: [- E+ ]0 y+ h5 E
    242 ?* X" p2 v. G
    25# Q) C5 k7 [3 K0 j
    26
    2 f+ ^# n5 |+ I# a7 g' |272 }; }5 C+ \  t* m
    28
    ( h( O4 m2 H+ w29
    * ^" p5 [. u9 g30( \/ |2 K( _  S' A# i4 v' M
    31
    ( C  ^( c: ~9 k  T0 @32, A5 n; q0 t+ P9 P$ ]/ k& A9 w
    33+ q0 P- h# x! Y" ?! n6 s4 @
    34
    3 d3 N) ?5 _, I+ D6 u5 [2 @351 d! c# O. S) U
    36( v! [! r2 d7 `( j
    37% T0 l+ ?0 A6 L0 h: j: [
    38
    ; W1 w2 v' o. Z7 w" m. ^* o39
    6 U8 q* q5 I. k: o6 o( A406 k( `8 _$ m2 H! k$ v9 z: G
    41  X  s/ Q+ P, T) G
    425 U# A" H1 o3 r8 k
    7.3 训练所有层
    & T- [+ V( k, T1 Z: b# 将全部网络解锁进行训练
    3 a9 X- R0 Q5 g  B. Ifor param in model_ft.parameters():$ v4 I/ H$ i6 T5 V/ V: Q) T
        param.requires_grad = True
    ! b3 k) t( {5 B' D! Q
      M. l; C0 q' B; B1 r# 再继续训练所有的参数,学习率调小一点\, l0 m0 {, s3 ?# |" b3 v3 f
    optimizer = optim.Adam(params_to_update, lr = 1e-4)
    " ^7 g5 |' H1 o4 |scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)2 A* H" y* \  j% Y
    8 i$ g% |# g8 ?4 ^0 X
    # 损失函数; j8 m, n' `! i! i8 X) K
    criterion = nn.NLLLoss()
    3 N  s& C1 ~, v) I1
    6 W3 ~! h2 i/ f: G! I3 L! Q% m2
    # u7 Q, P# v; T" N6 Q8 {34 M4 e3 V  Y# |, g, q) X
    4$ Z1 W9 {0 A0 A7 s. ]5 g  U
    5
    2 A- y# o; r0 F6
    2 z3 t, V9 {8 x5 `- m7* p/ J: A/ Y" P* k* Y
    8) h# \0 N# X0 f8 s- r2 E5 {
    98 N4 m# p6 t# _6 C; U4 D# [# K
    10
    / ], ~: z) `3 Y) \0 D& V# 加载保存的参数7 b7 i- S2 |" s* S9 q: y
    # 并在原有的模型基础上继续训练
    9 h2 L$ N# N8 g3 G( a# L# 下面保存的是刚刚训练效果较好的路径
    $ U6 l: ^# w5 n6 z) Q0 W6 ~checkpoint = torch.load(filename)6 W9 p4 x% s' Y3 t1 ]' j- U0 p* w
    best_acc = checkpoint['best_acc']1 u: }  l, X( v6 A6 F7 k6 h% j0 a
    model_ft.load_state_dict(checkpoint['state_dict'])- F1 H1 R% b4 Z* a
    optimizer.load_state_dict(checkpoint['optimizer'])6 I- G  E" A9 y+ D; ]/ J6 u
    10 B# Q- m6 ~" ?6 D8 S- n7 V
    2! g4 o8 W6 h7 Q7 ~- T/ N% O7 Q
    3) I1 k# W1 s7 M! e0 u+ S
    48 Z9 Y+ }" S  k1 E
    5
    , F" ~/ d) y7 j+ \) X64 X0 n) d2 k( q# x$ c  Z
    7, t( \' u7 V$ q
    开始训练
    $ F2 k- t  p0 G9 A" V注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考' L) X) i4 x9 P) b
    4 L$ ~9 d  e9 Q0 e! ~2 u/ G% \
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))* Y0 u) p/ G& ?0 ~! C
    1, C# i& Y. a& E* s/ U0 v
    Epoch 0/1) [& v. C$ K3 A% |, Z
    ----------9 A4 c9 U+ i0 g
    Time elapsed 35m 22s0 M- B6 m: t7 k' i/ |
    train Loss: 1.7636 Acc: 0.7346' g; E$ i& v, c" |
    Time elapsed 38m 42s
    ) ]3 x, E+ O7 e! Nvalid Loss: 3.6377 Acc: 0.64551 T/ e. ~# L& m( i" W  D
    Optimizer learning rate : 0.0010000
    + A! c3 B* b4 K8 T' {1 y
    # j3 o9 ~, R/ B. vEpoch 1/1, [: {/ U( P% N4 E% E4 j
    ----------
    $ J5 z3 \; ?0 e, cTime elapsed 82m 59s
    1 z, A& {9 H% B  ^5 V6 {2 l4 Vtrain Loss: 1.7543 Acc: 0.7340
    ! j4 F8 m0 }; ^# pTime elapsed 86m 11s
    ; J6 J% O! m0 V' Gvalid Loss: 3.8275 Acc: 0.6137
    - {) r. b# @3 c, ?+ U/ iOptimizer learning rate : 0.0010000
    4 ~0 G- C" t7 e: W  G7 U6 }% R; n
    " G6 w% g+ s  w7 B  D9 N, [Training complete in 86m 11s5 ~: a+ ~  |- _6 n$ J2 x! x6 l9 ^
    Best val Acc: 0.645477" c! G$ I2 n0 l# e3 d! @

      a2 k2 @3 p' T; R% W+ C! d  |1# g4 X) V' |# f  h$ ]& K2 j
    2
    2 T9 y5 m3 U5 }- e7 R1 `2 U+ p0 b36 x/ w% U8 D) L, d0 w
    4
    ( |7 t0 F& J+ d1 w' w( x5
    ! y2 T: I$ d, N7 x) l0 T60 d* Q. E7 i# o) @* b3 Y1 b1 j
    7
    2 q' r, V+ E1 J* V% Q: n. {& `. q82 y: x: i- e! Y( Z8 L, `/ c6 F
    9! U: r& c7 \6 a5 U( y
    10
      o, v7 [+ _5 {" m/ q* |* O11% M8 y  u3 P- o' Q
    12
    - k5 L+ m$ v' N9 J3 [13, d5 N! r" n( n3 [' Y4 m: P0 M6 i
    143 |) E/ W0 e" n5 n2 A
    159 E0 w( n; w* c
    16( N+ S# v2 h) q. ~) Z4 L" {
    179 A; @: K: ?  e+ w+ |* s
    18. ^) Q& \3 S* b! u$ x
    8. 加载已经训练的模型
    + y' ^, \: z4 [1 F2 h相当于做一次简单的前向传播(逻辑推理),不用更新参数
    ( i# n2 y+ n' b; Y- A: q
    * M2 F" F! y6 d; Q7 A+ s1 Q9 x3 {model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)& I( E, }7 Z5 F1 k3 r) |  h

    5 h! e% I" w2 x3 N' W& r( q" E# GPU 模式
    . @  U4 s5 f* {9 U0 {- lmodel_ft = model_ft.to(device) # 扔到GPU中/ Z' Y3 g) w+ X  j

    / f" F5 c8 }/ \# S# 保存文件的名字
    ' ^# L3 `( Q5 g( Yfilename='checkpoint.pth'
    + E7 x* U0 k5 r- ?9 I
    % A4 y- `# b: I  ~* q% r9 r# 加载模型
    0 ^  f9 {. h6 I" k5 dcheckpoint = torch.load(filename)
    ) c) v$ E8 n( ^* Obest_acc = checkpoint['best_acc']  \, ?- \5 P$ M# a# _
    model_ft.load_state_dict(checkpoint['state_dict'])/ H8 m  s8 R- |) e' w+ @
    1
    " v; g, y3 p1 ~6 u0 S2' M1 s5 O6 U/ B$ e) |4 y
    3* l# D/ i0 o" b/ B5 K* Z
    4! q4 U* A- i1 W$ J; F+ Q
    5
    2 T9 s9 \' I" O6
    9 ?  `. _5 ~" y7 Y7
    1 `" f, u# v1 {- T1 Y- |8  K8 ?. C7 E# z* X
    9
    / A( a* ?# b: L- x2 M" Y& a9 n" j10
    - r9 s* W' L% O11/ X, p3 j, g4 F8 s$ X3 H
    12
    # V7 E3 m8 Q$ y( W1 x% y7 }: R<All keys matched successfully>
    - O8 D# x- L( Q  z* Q  j. M2 w1) `; N2 p! N# h
    def process_image(image_path):% @4 e6 r9 V* v2 R
        # 读取测试集数据5 v* ]6 l. ~2 s
        img = Image.open(image_path); _7 O" U/ n! F& p: s
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断- _. d% H! P- [! m& o. v3 J9 u
        # 与Resize不同
    3 _. {& `* @- }& a1 Q    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小+ z8 ~$ \/ q' S- A, K
        # 而且对象调用方法会直接改变其大小,返回None/ ^- X& n0 v4 Q0 Z6 T, Z+ \
        if img.size[0] > img.size[1]:! J5 n' u$ O# N) X5 Z/ T2 \$ R  _
            img.thumbnail((10000, 256))- V. q( t0 N" `$ d
        else:! H! G5 Q# H6 L3 w0 ^7 p: }
            img.thumbnail((256, 10000))9 Z* H( Y; C& S$ K% C! S2 q% G" x& h
    5 j9 p, M/ [7 ~! D5 L
        # crop操作, 将图像再次裁剪为 224 * 224
    ( _  P9 I$ Q' t% b! Z* G2 k$ q    left_margin = (img.width - 224) / 2 # 取中间的部分1 {, P6 z9 y' q# f/ K' N- D! K
        bottom_margin = (img.height - 224) / 2
    8 ~0 Y3 I" Z$ t0 K1 G/ p1 x    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度+ U+ e- j, [3 o  f
        top_margin = bottom_margin + 224' f/ C1 `8 j5 a/ p

    6 o7 {7 K) T4 @1 Q' l    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    5 O4 R0 n- V2 ], v
    ' ?: J$ n% v2 h    # 相同预处理的方法
    ' F2 \' p. }% P2 X9 I' C; N9 R: J. e    # 归一化; X# B( Q  C( `% f
        img = np.array(img) / 255
    ) q( g3 O- J4 T% \0 `/ q- g! }    mean = np.array([0.485, 0.456, 0.406])
    3 o  f9 \& n. e    std = np.array([0.229, 0.224, 0.225])
    4 J: q# `+ Y4 S0 u8 u7 \    img = (img - mean) / std
    " M& z. X8 M9 s
    & _. E& e# x. e8 M7 N% j    # 注意颜色通道和位置0 R5 A0 ^# \8 D( h4 f* ?
        img = img.transpose((2, 0, 1))7 U( w3 ]4 ~9 o& w" @0 Z
    ' C% ^/ X5 Z  A0 K6 a# l: R0 v& d
        return img
    5 l: {3 a3 V1 U# x4 J2 V: Q/ ]
    " {( q2 p- r1 J! W! ?. M. i! J4 odef imshow(image, ax = None, title = None):/ J3 f! \, _; ^; t
        """展示数据"""6 z3 W! ~2 V7 {" ^
        if ax is None:' S. G# O" K8 a/ M" }
            fig, ax = plt.subplots()
    # S  m* l  ^- v' T0 l1 }- z3 M( S. `4 C# ^0 `
        # 颜色通道进行还原1 d+ c# Z4 _  f
        image = np.array(image).transpose((1, 2, 0))
    - ~( n5 }4 E! Q1 h' Z) b8 ^5 n+ S+ h* ~; r5 \/ a% X5 x
        # 预处理还原2 x, Q7 V( I* _4 q0 H1 o/ B
        mean = np.array([0.485, 0.456, 0.406])
    ! i4 a) |1 u! g8 S8 R9 _    std = np.array([0.229, 0.224, 0.225])
    9 Y; j$ a( E9 E9 c, F    image = std * image + mean- y8 @8 L+ l! r- V( s+ `
        image = np.clip(image, 0, 1)
    8 T' i3 }9 V" t1 S' ^  _+ e  k& w: i9 F
        ax.imshow(image): A* j% j7 r) n# Q8 z9 k( L
        ax.set_title(title)
    - j9 Y+ H" ]# ~8 ]. d8 z! D! o9 K' e, |& N. g. Q: a: C* k- q( @
        return ax
    5 }' O' c/ W8 `* T% t& ]/ n# |4 B6 ]1 `
    4 {5 i# q( m+ [2 n& W% `# o  simage_path = r'./flower_data/valid/3/image_06621.jpg'$ @9 `" u) ?9 K" c
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理. j3 W1 A( z: I
    imshow(img)' [) T0 T* S) Y
    7 a  _9 @. e/ ]* r* r. R2 G/ v
    1- @& n5 ?: ]3 J
    2- ^: |: a. v; H1 w; f
    3# ?, p/ H1 a4 W# E
    44 A/ e7 l( v% z4 v& T( p) S& l5 B
    57 E5 l7 N! T1 z+ S& g4 o
    6
    7 f; L% o; w; p5 Y% t1 [  n0 ^7
    " d* C/ A8 }# h9 \8! V5 S+ h1 s; k
    96 P# @- r) u6 L
    10
    6 P- o1 k1 J+ t$ h! e2 T6 m11
    * ^5 B+ U3 ?* V- W1 a12
    6 E4 J# K. o3 v1 M; s13
    ! S; B; b# O2 O0 Y4 r/ a14
    # |5 U8 V; J  {7 _; }) a9 @' V: }15
    7 s4 V$ w* ?, B3 D+ I7 R; m. d167 l; _+ ?3 P5 [8 K
    179 \; X4 G6 e# P: ]4 Z9 V
    18' l0 q! m. G9 s* W
    19
    # c* ]/ I& K! O) Y& k20
    % @4 U" F& r) z. \212 x3 [, l/ b3 s' X$ B/ D9 b7 T" `
    22
    + H; Z8 m% H1 ?1 d5 U+ s3 G8 f/ l23
    0 N" ], s' V1 u# E8 Z24
    , N4 H, V; o1 ~: c5 a0 F) [4 R25
    / M4 H9 O7 {, X( i  p264 s0 u  a6 M. t" e+ |5 c# I
    27
    ' |5 d6 a% p4 ^. }  i28$ C7 o5 z! s2 f
    29
    6 C! q' @& }# X6 l; m7 C301 ]  x/ m3 t7 c& P1 P# v; X
    31
    : X8 T8 Y4 O% Q  M$ r( I, J32
    % N  O7 C! V3 Q" ]( q/ S0 ^33
    ( K4 s. r# w# b0 H34! L# [9 }* Y0 O, h8 P
    35) f# R4 k' q+ i1 d* }
    36" _+ k$ _) X& o* G* N& e% F5 [3 O
    37
    3 Q% D0 K; ^6 f3 l! L1 H6 T+ c38( w+ V) ?5 ]- R" B- O+ P
    39
    " O/ I0 H* |) O0 Q40
    6 Z' i7 a/ f4 R3 ?  t; Q41; R2 O$ I( c' n. }7 i6 [
    42
    / T& U( j& C9 Z6 f43
    - H. V9 ^7 U& T0 q5 t' x445 C5 D* D9 q& }6 [
    45" P/ I# i" P7 X/ [! N
    460 c, b+ N# o4 v  P1 ?8 [# |$ w4 L
    47
    5 u( B1 f) \) a1 g3 L, D48
    ; R2 g: T. p7 g6 t+ |49# U' y+ {8 {3 u
    50
    . }8 a6 I) l5 L( x0 L51
    $ l  Z! Z( N( K, k52
    0 ?! }# X. _. d. H' L, h$ K531 _" g8 I: [4 i
    54
    . n: a7 B4 {" V+ X2 |3 g<AxesSubplot:>
    5 k+ z3 N+ L7 b5 r* Z; ?1# z0 F4 Z9 [# h

    8 j. g: w0 L  T8 m6 ?. @! F8 [5 B: ?上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    ( K3 O2 k* a3 Z* {5 K6 [+ ]6 k3 ?  j, m) N, S
    img.shape: A& h) v" M1 J: u
    1& k, r( |- b% w9 T9 b$ h- l
    (3, 224, 224)
    9 D$ L. V* }, \( U9 `1
    # u4 l2 ?0 o& Q证明了通道提前了,而且大小没改变
    # P/ j9 x' ~. v& C1 e9 v+ ]2 ?/ D. Q' \- R4 R4 z7 b% [
    9. 推理5 [" v0 }% g; O# [
    img.shape0 {8 |  o' W( B

    5 p; h1 |( m0 P% [; h1 N2 ?# 得到一个batch的测试数据
    & F( L# y: f% b/ odataiter = iter(dataloaders['valid'])  e* t& J7 q3 c1 Z9 e2 a- W
    images, labels = dataiter.next()( D: O( _/ A& A2 `' }) z1 d

    * p6 {. R/ \$ v, j) y! xmodel_ft.eval()
    - S) y! Q1 b- q( b( c) q1 T2 ?7 U) s* K0 z7 ~4 p
    if train_on_gpu:3 K3 F* z2 n' X/ K: n
        # 前向传播跑一次会得到output
      n6 o% y* n( |4 |& E6 |5 W% z2 ]) H    output = model_ft(images.cuda())  w5 n3 p1 h$ E% t3 v, V' u
    else:
    # d/ j8 \1 ?& R% q    output = model_ft(images)
    * S4 z) M4 y# f% w5 @! M. @% A& H1 U; z" p% e/ k' W
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值8 H! ], E3 t8 t- S  h8 V& X9 Y
    output.shape
    ( v  a+ s9 T4 I9 ~: C. a0 x# S/ ?: x5 A4 ?( E
    1
    * Z- s# {1 Q8 Y2' a. d5 c. ^3 o# t8 b& A
    3
    0 @8 x! U% j: @9 V4  v) K( V: F6 \4 G; x% x
    51 e% Z$ f* j) t" s7 {! [( {
    6
    3 ?. A! R8 \2 E  O! m& x  W" L7. t1 r) D' A6 A
    8
    3 M( B1 }5 Z7 \" W. Q+ @! u+ w9
    8 [9 C0 q( `- J10# t( _3 N( w- j' I5 |0 [
    11! R5 Q& o. N0 o- ?) a. [8 j0 s9 n
    12
    ) G6 p0 p  p' s% j; M/ }# k13
      B6 X2 m8 L1 @14
    , X9 S: u; G+ m. d( F  \15
    & a% o& T6 w1 t" c' C16
    , m6 K- q% X2 Ltorch.Size([8, 102])
    & D0 r$ t9 k4 P) c' w" }+ v16 i4 l/ m$ c; m
    9.1 计算得到最大概率
    1 }7 D4 T) G1 c8 r4 f_, preds_tensor = torch.max(output, 1)+ p1 j3 L# A* M& F% E3 R
    $ ^8 L' S3 r+ M
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量# d  g- O! W( q/ b( o
    1% |. [9 L1 b) B: |
    2+ }9 A& U0 K. F  U( s1 X. F( q
    34 e2 q7 O% [9 ~7 k# {3 Z+ x
    9.2 展示预测结果
    - M1 T. @# h* @2 |$ `( W. L2 s# Xfig = plt.figure(figsize = (20, 20))4 Z( x; U. P2 R% c, P4 t
    columns = 4/ _; z0 U3 o0 Y) d. \0 X( {! {
    rows = 2  ~8 k( R1 o$ W7 g! Q0 R) u
    & \0 ]: _0 O2 C1 a6 P$ `
    for idx in range(columns * rows):
    1 x7 e6 x' T4 D4 {    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])8 Q  j- G1 ?) B/ [
        plt.imshow(im_convert(images[idx]))
    , p5 J0 j% j* w! c, n    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
    2 n: H. r# v4 R$ d4 O                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    0 Z& ~1 I0 k, X/ ]plt.show(), o* j0 |3 _7 _# e" r9 w4 E% F
    # 绿色的表示预测是对的,红色表示预测错了
    2 q1 V0 B" h# a4 k4 I1
    ) g/ R7 [$ m2 Y; k+ u2
    * S4 P; ]3 D/ q; S; G( e4 ^37 ?6 q" |! i4 U# C  d  _! p
    42 B& A/ f" l8 Y# q1 l! Q
    5
    ; k" m( O6 z: z, g6
    % G; g/ X# o5 e1 Q7# H4 t4 M/ t# x; e. ^5 P
    8
    / E4 n7 ^8 U; B7 O- b( [. s9
    , r6 F( z9 [7 D, H8 o/ R- L10: _. P  n! `3 l6 g7 S1 v
    11# s- f; n4 k8 F
    , G5 c/ K% @$ k. h
    ( y: L& j* M3 Q- G  m* b

    6 ]$ g' G+ w; q+ T————————————————
    $ M" g) a5 R2 D9 K# i版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    7 |" O' t  F& k/ c8 L原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    9 Q1 D6 z# L7 `% d) Q0 [# y+ ^7 e' T; C% g$ g6 h

    6 Z2 a; Q# p& J4 S; J. k. q) u, H
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-16 20:57 , Processed in 0.552674 second(s), 51 queries .

    回顶部