QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2742|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例0 A: b8 _: E) w
    & U- d$ x# n$ i
    文章目录
    , {/ K! W9 m$ e卷积网络实战 对花进行分类
      L2 s- z) q: C数据预处理部分, t9 t# s" |; A% X7 U1 m
    网络模块设置
    8 \6 {$ V; i2 p" U* m# K% N网络模型的保存与测试2 k' {/ y1 x8 T) v. {2 |
    数据下载:) z% Z; X, z/ A( R6 S. q( n
    1. 导入工具包2 ~) X  L2 S- B( _
    2. 数据预处理与操作
    9 K8 E8 [4 @0 H3. 制作好数据源
    , h! n) X! A9 I0 c读取标签对应的实际名字
    $ u' Q2 y7 K* X9 u' ?2 I/ S  Y& Q4.展示一下数据
    ( s4 d) l; ^& D6 c; y% ^$ {& d+ f5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    & _0 m7 t0 X* T" a7 i( s1 {3 \6.初始化模型架构# J2 f, I7 `7 V5 \0 |, d! G
    7. 设置需要训练的参数; B+ w# {* I9 E1 G  r
    7. 训练与预测
    # Z1 I0 E* q! p- i: y" r; O7.1 优化器设置+ C' E6 C$ A+ W" E: V9 T
    7.2 开始训练模型( E4 M( m7 ~9 p4 h: _6 r
    7.3 训练所有层! N9 `" s0 q# P1 P
    开始训练
    8 x2 {0 K0 d1 }: _; P! [. L% _8. 加载已经训练的模型1 ^" r' i0 i/ B  B
    9. 推理
    / Y1 o* Z& `, G4 I9.1 计算得到最大概率- X- _% \; i/ A
    9.2 展示预测结果
    # d( j) |, c9 D1 m  E& L& P8 c写在最后. ]) g! q' m, i
    卷积网络实战 对花进行分类' \& A8 E8 n6 B8 R9 a0 e
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    * ^, t( o' s3 F+ ^* h; n7 ?
    8 J" t5 M- Z) d, r/ z在文件夹中有102种花,我们主要要对这些花进行分类任务
    3 _  E# O2 r2 h) ^2 L文件夹结构8 A  V4 e/ L3 \5 o/ y

    # M5 Z; u, s4 }! C- R7 h" A5 r/ tflower_data( V: r0 i# J# T4 R9 |* p
    ; p  v" ^1 K; P# J
    train' v7 q6 G9 e$ L0 r7 P0 w
    : W1 Q5 i6 \* g2 y8 {& e
    1(类别)
    ) a9 w) P. ?8 G% @, v2
    # _, u, `& ~, [2 O3 Axxx.png / xxx.jpg
    : R9 u  Y% i; d, ovalid
    7 k# o; s. `; T  G
    4 E& R: q5 E& v0 i  W/ p/ z( B主要分为以下几个大模块0 t; w4 @8 B- I. r$ n7 m  c3 l/ F% J7 m
    3 r4 ^9 g  f3 T( k) ?7 V
    数据预处理部分
    1 k4 U' u4 S9 }数据增强% h# P' }( ?% y% G1 }+ K) ]* `
    数据预处理
    ( V- h4 ^3 K& S0 m网络模块设置
    - q* z5 o4 w1 v; I! J! Z加载预训练模型,直接调用torchVision的经典网络架构" U6 o; @' x0 ?- E# g
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
    ' @4 ?$ M, ~: f9 D) X9 J* c网络模型的保存与测试, k; |8 `# t' E* m) T; b
    模型保存可以带有选择性" ?3 {- s: m/ L- F( _! U+ h
    数据下载:
    . i, b) ?; R/ U3 X' f" Phttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset3 F) T) j2 P* O: u
    # E0 J. H0 p! T7 y. i  s
    改一下文件名,然后将它放到同一根目录就可以了6 \5 O5 Q# \7 i' o( ~8 R
    + f, ~1 W/ u0 D* p& C9 T
    下面是我的数据根目录
      L0 @, j/ J3 w; e4 \& z$ x/ |, n: j- p) c/ w) }
    # @" W$ Q( [/ m
    1. 导入工具包
    7 |$ K; W" T( s9 b# C+ l, b# Eimport os; q. \, A" t2 v! ^; A& R/ i
    import matplotlib.pyplot as plt
    2 y' o4 S% ~1 _# 内嵌入绘图简去show的句柄
    8 i0 z+ M! W# t: G( x%matplotlib inline . Z# A  h! U; \6 C  }  g. l
    import numpy as np8 T! q* Z- u+ f( I! b  C8 B
    import torch
    - ]5 L) k! f" Z! X! Mfrom torch import nn
    : ^4 ^1 m# U: I: W! Q' `. I9 V3 f/ H
    import torch.optim as optim
    9 [0 u& A2 E5 t, z! n' pimport torchvision
    ) v6 |( v. D8 ^. yfrom torchvision import transforms, models, datasets
    + R1 D& B0 {% x- E$ x2 g0 a
    , e$ c; ^$ o# o5 E0 Iimport imageio
    6 K( w5 J/ @: J; {: X8 g' v1 Z+ pimport time, I. H2 l1 l9 y9 {: o. N
    import warnings
    3 T$ `: G! `5 }: q6 e6 o/ iimport random0 ]) }4 H0 y( `1 [. t( z
    import sys
    : |& t# ~! K6 ?7 e, ^import copy& d3 p* j- g% _9 J9 z3 X. l. d
    import json5 B' {- `8 G- A+ a) y& W
    from PIL import Image
    2 _. Q4 }" x( `% Z
    " X/ n% L' G* ~) q8 o; i- C) u# ^8 |$ k5 [( y! a9 q
    1
    8 j/ W+ E. {! s  s3 u2
    0 w( r/ G- k$ H. H0 d  q: T$ t3
    3 b, ^* [/ O3 ^* r, @6 g, L( I41 I9 C+ Q" C, Y8 R+ |! h( r+ O* z5 J
    5
    ! A9 C" B4 |  r( d' c" }6
    2 a, y9 @" V8 L! L+ g7) m9 ^* F% y  P, d- v) N  U
    8  R! M* |2 n& |/ {& X. C; Q' z3 p& |
    9
      ~( L; _/ X3 s0 ^9 i  c- Y10' i, E8 Q+ N( ?* l; V9 w" Z
    11
    & U2 }- a7 d* l  G) C) H' G12
    6 U& i0 ]7 U* I3 M5 `13* w/ P8 S: q5 j/ N/ e9 X
    14
    5 V9 c3 d8 M0 l8 V/ D* `15
    + \# {8 ]; W7 u16
    , h( G6 p" D+ H2 T8 [! i* i& G$ I17
    6 y) X# H$ d) e% T/ I; B& h18: }) i% I1 A5 t8 j% J( o9 M+ l  |
    196 W2 K- l# {6 |5 L) q" m/ a: u
    206 ?8 Y, C4 u; @. ?1 l- A
    21$ |5 U/ u7 @4 ?
    2. 数据预处理与操作5 s% _+ N+ L; H9 V6 p2 q$ u0 K
    #路径设置5 O/ H$ T3 v0 N; @9 q+ D7 e: L
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录/ q' l+ W+ i: P5 R/ `
    train_dir = data_dir + '/train'" X. s* L2 s3 p2 Y: N
    valid_dir = data_dir + '/valid'
    ) L* p3 C6 }0 Y, a6 e- s+ v1
    + N2 C$ z( Y1 h& u* ]$ ?+ Z2+ i4 G. ^8 n8 B
    37 e/ j* F* ~3 m/ [: Q3 b
    40 d6 V4 D+ Z# j5 A! ?7 Q) {
    python目录点杠的组合与区别3 I8 B% g2 U8 x. b4 ^/ b$ Y
    注: 里面注明了点杠和斜杠的操作
    ; o# f. `# \! p* D. p
    , I* @" s6 i$ e6 R& G* s3. 制作好数据源
      t1 L# Y* M  h( S; Y1 p4 Edata_transforms中制定了所有图像预处理的操作. w% x$ o9 o" h$ s$ q, T
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片& b& e0 s0 c9 ?1 P" p' C, }
    data_transforms = {  C6 _0 F: i2 L  |
        # 分成两部分,一部分是训练5 |+ m% Y4 g4 ]5 i, w# I% S
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间5 `! V3 Y# v; V, D& l
                                     transforms.CenterCrop(224), # 从中心处开始裁剪
    + _" o; E, P3 H' k  `6 e                                 # 以某个随机的概率决定是否翻转 55开
    ) l! Z9 Z( t2 _% C- e9 y                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转6 R& K8 b& Q; Z% e2 U7 l& H
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转. X$ Q/ G% S! G( k. Q
                                     # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
      z  o$ u& S3 L2 V                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),2 `; ?# F+ Z- ?5 F6 ?
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB( t! U9 u, B4 B- u  _. V
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    8 k0 d$ u5 D+ U) O6 h; F; d                                 transforms.ToTensor(),
    / _4 M- `$ w2 U. R& w: o8 h1 ]                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差3 ~" i: R2 G/ `; d
                                    ]),
    * Q( u( K5 |: R5 \    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化  z2 B' o- Z! p5 C& {
        'valid': transforms.Compose([transforms.Resize(256),( m/ h& @2 P0 ^4 t' `4 ?$ s9 G
                                     transforms.CenterCrop(224),' v/ a* p4 V6 I( G; C, u) U
                                     transforms.ToTensor(),
    1 F6 n0 l5 a; Y! i! G5 b# F                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同- O; e" g# x1 g! `/ u6 \0 n3 E
                                    ]),
    ( S' U8 c8 v6 u6 E}2 h1 y3 N7 T0 l' R5 L8 e) a3 E
    7 H4 J. A2 @5 a$ r. C1 Q+ ?, `
    1
    * h* v% k3 A! f, r: r/ k# w2
    " [% M% r# m8 r: [/ ]: s  f3' n, H1 e2 A9 a7 o& {2 m
    4
    ) j' S5 w6 L$ v; Z# b3 E5- E' w5 l' ~9 O) b' H1 p" o
    6
    7 K2 v/ C! T: x3 O. z7; q( \9 a2 D, x
    87 e$ s# t8 I3 P$ v8 N& c8 r
    9# D+ s. \/ V) p) C4 K' t) I
    10; T. S* \' q7 {' {% N: k+ t
    11) ^5 H! C  x3 I8 ]2 g, B; n1 t
    12' e! d7 l5 }( x. e/ ^& J$ R$ u+ n4 [
    13
    2 b4 \9 n1 v( w* U0 t( x+ j14
    3 j& N3 K) R9 X% G: N- `, ~) d3 z  Q15
    - G! K; Z& S! L16# W- L; g2 O4 Z/ t; U' U
    17
    ; K5 v6 ~+ @7 m8 b18
    4 t/ i$ i# ?1 P( g& p& U% O19
    $ A+ p' e8 ^4 [" I0 z20
    # J4 O. Q; q" V8 u( f$ p21' o2 Z! ^' O* w2 O2 m6 S2 V& Q6 t
    batch_size = 8
    & P; g) X4 ]$ f7 X1 Dimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    8 m. B% |, L! qdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    * l2 S4 W( E) ]  Hdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    3 w  z6 H. u( N! X$ o9 ?class_names = image_datasets['train'].classes7 n$ Q' S5 `5 Q4 n

    ' V4 q9 @( v9 B7 y; Y+ s#查看数据集合" z. H4 x- s9 x) f0 e( u
    image_datasets
    9 K) N  e9 ]2 I# h
    1 o7 G; B5 z2 s  @. |  Q% m1; u1 ^' f" B1 J) K2 M$ F9 D) Y/ W
    2; g( B9 X- L4 M2 \4 v" \
    3. L& b2 B; c( {; ?
    4( u# Y3 a6 U2 I3 P
    5
    ' G7 D; r- y2 {6
    ) X  Q& R6 s7 o7 G8 W9 m7
    / t" f" f  ~. M3 l# n' a8
    5 N. ?. _$ d2 g: Q& s9
    1 c4 c0 N' Q7 j  c- }% a, t9 C{'train': Dataset ImageFolder; n3 S" X, {3 _0 D* l
         Number of datapoints: 6552) C- Y  Q' a0 Y! i: {  s3 S
         Root location: ./flower_data/train5 D3 E5 W9 u+ n" v$ [
         StandardTransform8 P; {4 U  n: q, y# A  f+ s5 h' b
    Transform: Compose(
    $ \% t- T0 `/ b! i                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)" q1 Y7 M: {! W7 X, |, s
                    CenterCrop(size=(224, 224))
    ) {6 P2 ^! }) X8 [5 V  G7 H                RandomHorizontalFlip(p=0.5)1 V% f2 x: |, ^: B
                    RandomVerticalFlip(p=0.5)
    ( p4 f4 L, B" ]8 e1 u                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])* }; [2 x+ A; {' |4 }0 R
                    RandomGrayscale(p=0.025)
    9 b6 ]+ ^; P. E1 n( P- z                ToTensor()3 M, c1 K$ b! S. I- e5 `
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ! U, q2 H, X* {0 l8 M            ),
    ! f9 H* J* c: P 'valid': Dataset ImageFolder5 U$ P/ @3 f' u) M3 _
         Number of datapoints: 818! Q8 Q+ H2 v7 T5 R* m+ k+ k
         Root location: ./flower_data/valid
    , [) P6 w3 d# V2 p  m3 f, R     StandardTransform
    3 h5 b" }- p: A9 f4 F# c Transform: Compose(
    0 q1 R( e9 J; e) ?+ m; K" l5 R                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    / V) e; _. x2 W# N. X+ p% K: k                CenterCrop(size=(224, 224))
    ) m5 d0 Y, q# X0 {9 ~0 |                ToTensor()
    % @. Q) v* L/ z2 S  d                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) o! P  H& ]. U- Q- a7 k$ z
                )}
    6 {" u5 B; ^: `  s2 G" |2 V+ \
    % }" [8 p0 U' z1
    * c% \! D+ x5 x- S2
    7 N3 N7 ^# P1 {32 B2 O$ g: R# Q0 N# V2 o% Z
    4
    & X- F7 c' B7 O# [5 h+ B5% I0 j! G5 H, S/ b" {9 H2 R7 ^4 S
    6
    " n( F& e9 P( b3 m: B& k8 B# a7+ @5 v& O9 I9 }* @: z3 k* o. G* d3 G
    8
    - m$ X, @8 E% t, g5 v9/ T3 S8 {  p/ _: j# j! M7 j  ~  P% a
    10: y' i1 ]2 Q5 k3 H7 \
    11( F+ e& J' W* L7 ^3 _0 t+ E
    12
    7 f" U+ N" r! v' z# v13" W; W' ~4 |, ]3 ?
    146 n& I8 E. k! Q5 V6 `# E
    15
    # g7 u; y9 _* M+ R, `- ~/ n167 f9 ~& N2 Z8 U# ?7 R9 e
    17/ G' u" {9 H# p( }9 o
    18* I* L# N8 V% s5 P/ B% ]
    19
    3 T# b% h$ C; B+ I20
    ; t: {% Q" n0 K5 C3 W6 w  Q" C5 x21* r  w8 k! R3 o$ T  p* [
    22
    - \. I' u6 \1 e6 S7 C- Q1 y23' o1 A; M- E9 O% G4 v2 Z" J) B
    24+ O4 K/ ~" O/ t* D; c& L
    # 验证一下数据是否已经被处理完毕
    ) o/ o$ K7 H; j1 j. X4 s1 G" qdataloaders
    " }6 W: F, \  z14 ?2 K8 H- V- @9 ~9 W3 P
    2) y2 L; M* J' ~) A" s5 {  q
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
    , Q6 c! I4 x6 P: v- i: x7 L# E 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}1 n1 A2 q( [2 M6 e3 {. ~7 G$ D
    1; @1 g  ~0 Y, ~0 _( y$ P0 ~
    2
    , m' l8 _3 l- Q* h: Xdataset_sizes
    9 x! m+ ?. M8 d% {4 V3 M6 v: k$ F4 ~3 p1/ j; L/ R, ~# i8 c2 F: C" Y
    {'train': 6552, 'valid': 818}/ M" D) d& l  ~8 X
    1
    7 r  z2 M1 i8 T% z4 Z9 Z读取标签对应的实际名字: D5 C  w9 }! `: M3 ?2 r
    使用同一目录下的json文件,反向映射出花对应的名字
    0 @  S2 z9 K! q7 B0 y2 S
    ' c. X1 R! P+ r" ?$ B1 cwith open('./flower_data/cat_to_name.json', 'r') as f:& Y, \! x9 r+ I: v
        cat_to_name = json.load(f)3 r+ I) o: ~7 b2 r+ G9 r' l
    1( v( f* T" g9 l" t  r' X5 \
    2
    # D: _! @! v( N9 {8 h  Ecat_to_name
    # i% [9 e% ]; `. c9 t1
    / S& C$ R5 C" F{'21': 'fire lily',
    3 r4 i. M" f' x8 U" f1 `0 }) F '3': 'canterbury bells',
    $ \8 b; E3 J/ r; u* H '45': 'bolero deep blue',
      l, e- v  y: G '1': 'pink primrose',0 i( L& V# J% e$ {2 f
    '34': 'mexican aster',
    ; r9 [) @( e/ K '27': 'prince of wales feathers',
    ! @+ S% n6 Y) U- B% [ '7': 'moon orchid',: E1 o$ L4 O3 q. r# [
    '16': 'globe-flower',) U$ x  m$ L- m- m: E: _
    '25': 'grape hyacinth',
    % i% K# a; {- y, l4 L '26': 'corn poppy',' q4 M$ x9 b' ~0 N3 P$ o
    '79': 'toad lily',
    - F, K. \7 p+ W: e. C '39': 'siam tulip',
    " ]* G2 k4 n# e0 D '24': 'red ginger',
    5 ~4 N/ U- b5 m( x, b '67': 'spring crocus',
    , L5 z% N/ [6 ?8 M '35': 'alpine sea holly',
    2 J0 y, o' _0 S8 G '32': 'garden phlox',
    # s3 N$ p+ b2 M '10': 'globe thistle',
    3 Q* L; x' e0 K; i+ t) J$ u '6': 'tiger lily',
      B7 r# _+ F5 i# Z+ o '93': 'ball moss',8 C0 F2 D* @8 F( p1 p! N
    '33': 'love in the mist'," w( K8 K# V) S1 |
    '9': 'monkshood',
    4 k7 \+ n4 ~: O9 p '102': 'blackberry lily',) Q4 r8 v1 O( w8 V" X; i, p
    '14': 'spear thistle',
    * C% C9 A' m1 c '19': 'balloon flower',$ i8 |& k: M, K, p) @# t( F6 @, O. r' K
    '100': 'blanket flower',
    + \+ I1 h* L! z: ~# z5 }7 S '13': 'king protea',; m* t) F* N0 l
    '49': 'oxeye daisy',
    2 \2 K' _5 p$ m9 D4 f; ^ '15': 'yellow iris',: O' ~) T% D# A" p
    '61': 'cautleya spicata',
    - \! W; ^4 m* p$ V5 O  u& U( V '31': 'carnation',
    % t4 [2 P& ~. J+ h '64': 'silverbush',( t% z& ]6 G1 l0 {& @% Z5 k: u
    '68': 'bearded iris',/ d/ f8 C& J+ P7 E
    '63': 'black-eyed susan',
    ) K# y7 Z" C& h! k; o '69': 'windflower',' z: M9 l; U0 }* r: N1 P* Q
    '62': 'japanese anemone',
    & v6 m& j+ R4 E0 Q '20': 'giant white arum lily',
    & I* Z7 P! P; S( ]1 R% G '38': 'great masterwort',3 ]3 o1 ?! k1 e- B# O5 n/ |% Z
    '4': 'sweet pea',
      }, |; e. G! E: c3 E$ F7 M '86': 'tree mallow',) z* W2 M( j1 a; \
    '101': 'trumpet creeper',
    1 K: ^; t  [7 f' S4 a '42': 'daffodil',* J, g5 o& a3 C" b- B& g9 {
    '22': 'pincushion flower',
    + i+ Q2 j4 a% k8 o  I* P '2': 'hard-leaved pocket orchid',
    6 n, |, ]4 J$ _- {9 f% D '54': 'sunflower',
    ' O1 ~9 Q' |: ?2 g1 M '66': 'osteospermum',# w  ?4 J) \" L5 A9 G
    '70': 'tree poppy'," s8 r% ~& F9 `* k, a8 `" _
    '85': 'desert-rose',
    ! h* G# [% v7 j0 ] '99': 'bromelia',
    1 U# T# g8 f; M" ], C) P! z, v '87': 'magnolia',
    0 }8 V) l/ s0 p' i7 U# c# E '5': 'english marigold',
    / |4 P; j2 w: p) y# ] '92': 'bee balm',
    " m3 [5 }6 m1 C) }  J7 j- p '28': 'stemless gentian',
    # x$ U& ~0 r5 t! Z) e5 ?7 G) ^ '97': 'mallow',
    ) L# P) u* f. P7 f; l( R9 A '57': 'gaura',
    7 {; g3 J9 c$ R% P+ } '40': 'lenten rose',
    1 Y  W' L) A" r' ?- }  l '47': 'marigold',
    3 d5 q) {& m7 B, C6 ^/ z! ~6 | '59': 'orange dahlia',6 o. x0 w  I  ]" s% W( l5 g
    '48': 'buttercup',
    1 ^* c$ O- ?* x. n" n) z  b4 X '55': 'pelargonium',' c$ N( k; F& |3 H  k0 r
    '36': 'ruby-lipped cattleya',- Y0 E! B. o% M% W; z2 A$ B
    '91': 'hippeastrum',
    ! }% B' G. {1 I* n- ~! T '29': 'artichoke',
    7 f6 ?! B. ^, G '71': 'gazania',! |4 s  B) R) ^. q- I4 i9 x9 z( A
    '90': 'canna lily',9 m+ S4 i0 t8 j
    '18': 'peruvian lily',
    . Q/ {  o3 S4 ]# c/ ^; v '98': 'mexican petunia',' ]) b9 A0 M5 a* w
    '8': 'bird of paradise',
    $ q* \/ U, D0 l. G '30': 'sweet william',! a+ ~  w' m: `4 G1 O, }3 C
    '17': 'purple coneflower',9 Z; p* k# ]9 I+ `
    '52': 'wild pansy',/ T& b  n7 _) c7 c
    '84': 'columbine',
    7 y* P& m# {6 P' d '12': "colt's foot",! M2 v5 }. ~& `6 s& ?! y; S
    '11': 'snapdragon',/ e5 \5 n8 c0 p% Y$ j4 Z7 e
    '96': 'camellia',* D7 `! m% @( h) j( Z" Q' L  R
    '23': 'fritillary',0 E* R. L7 ?! _+ x- T2 [
    '50': 'common dandelion',
    : S* v3 @& s" y8 t& x6 R8 D '44': 'poinsettia',) E4 z# r# E, ^- {5 ~1 S
    '53': 'primula',/ f( H, c( w$ g1 y, U
    '72': 'azalea',7 a% O4 f* d, v
    '65': 'californian poppy',8 p9 h( K5 Z: V0 I
    '80': 'anthurium',6 i  p, V; m; W1 ?4 y
    '76': 'morning glory',
    0 R' T2 E5 l* u' g, i2 j; w '37': 'cape flower',2 h2 [' q, Z3 a# q4 i
    '56': 'bishop of llandaff',- y9 c. b) B8 H+ D5 D8 @( W
    '60': 'pink-yellow dahlia',
    7 F; \1 V; f. L  P' M4 g '82': 'clematis',3 V) a, m( I1 p: ~
    '58': 'geranium',
    ; T) q) k7 `) g '75': 'thorn apple',. ?2 J6 B) U8 Y; t
    '41': 'barbeton daisy',
    ; k9 S0 J: e1 o& [. g( X '95': 'bougainvillea',! z! i1 l/ K9 W4 ?
    '43': 'sword lily',
    3 f  T' b$ F7 I6 Z '83': 'hibiscus',) X4 \. B# d9 t8 n
    '78': 'lotus lotus',
    / v9 W; h5 D& y '88': 'cyclamen',
    5 s- q) w  w2 _, s; u '94': 'foxglove',
    8 ^4 C6 A+ z# ?+ R) @ '81': 'frangipani',
    + U' t7 W! N( x- m '74': 'rose',* g& S+ ]0 U& |& g  O
    '89': 'watercress',% p/ S2 o! J- y8 j
    '73': 'water lily',, v. t0 m, C/ n  B3 r) c; d7 n- a. z
    '46': 'wallflower',
    ; D' P% c# M0 A6 H; ]# b1 S '77': 'passion flower',' Z& R- ]& X# c' U/ ]
    '51': 'petunia'}
    ' s" Z) }9 ]% O  u( q
      W: t. O# |& G13 p7 R; a' @4 K: ?3 a
    2# X1 y' A4 W  A( @
    32 a0 u4 x$ P) k6 F6 k! }
    43 E: J* u4 C; j, e; P- Z6 ~' F
    5
    : k* {9 d1 V5 L: j* \8 ?6
    1 S: X5 F2 I* P% D" c( R79 z' X0 g! l; ~: @" i
    8
    . Q' N3 q, L" E2 ^% t9- e9 u8 f1 \. y. C" x5 O7 ]
    10
      o) Y# `" f/ O# {( l/ d/ ^11' i2 @3 q. P0 s
    12( j# A- Z5 \2 G' A. P- O# t; B7 |
    137 C: Q: c! H( c) A8 j3 \
    145 Z& ^; @6 C$ C6 a3 g6 X( I$ S
    15
    & ?( N* T6 K! }& ?16- r; |3 \# r( j$ D  e. {  `
    173 e2 |+ P& ~5 t( |
    18
    3 i) N8 Y# w2 H. t19. q3 U2 U  X. j8 o2 X0 W$ j
    20  J( n+ ?9 ~: R, ~
    21
    # V. S8 v6 n) i1 e+ M9 ]8 z. L22$ V' ?- s+ s) D& n- u: b  m
    23
    ) |1 f+ r7 x* \24* U. R* t% s1 W* W
    25
    6 C, Q8 \+ b9 y, F* v26
      _+ ~& L# S, O. N8 `; ]27
    8 o3 i' [7 T4 d/ Q282 P6 `; x" Y8 G" F9 W
    295 Y. q7 e8 v0 n) t# `- _
    30
    7 L" f4 ]3 ?7 [31; {) m+ E/ O1 v" h* ?, h5 b# F
    32
    3 g, K# }# X  k3 z9 O33/ w  c' S( w2 I( I
    345 B1 T9 ^: c1 d2 x5 \! `
    35
    ! M- L5 c, a; r) R" ]4 E" B36& h+ i. h4 f7 q
    37
    ! U' i5 [0 x8 \. y" v- x2 r& c0 l38% I7 |- P5 m# X* i
    39
    5 J  e, v; b1 @+ J# Z' ~1 Z) |40
    2 A( Q1 n  ~4 g1 m8 ~8 t41/ d6 i' `- ~6 _$ N; n+ D! e( T3 i/ }
    423 D# m* s. o" K1 D: d  X
    430 g& a% J1 j# e5 q; ]6 b! M0 c
    44
    2 k' P0 X) N; v45: a+ X- Y/ k9 ~8 m3 Z; N2 V  Z+ s' A
    46
    % X; m+ [/ V' d" Q5 I+ G: J47% X( `  H$ d# Y& \
    48
      V3 u. A7 m: c/ _+ Z49. M3 B1 h- p. |( z
    50% S* {: y  p: W% B3 h
    514 I2 a; I2 E6 `  J
    524 e0 p  C$ P1 z) L! T/ r5 A
    53
    % z( W6 Y$ R) l5 C) j. ?8 Y54" p% i: ]2 _+ t  S
    55
    : L( u9 Q8 @8 G3 f" R562 a9 {0 b; h+ z9 k; _3 C
    57" w2 V( A" Q. n, L
    580 ?. t  _3 h' C2 [+ Y3 \3 W4 A
    59
    1 o- b. M) f/ u- d! }& j60
    5 P5 A- P1 P4 p  y7 I61
    # C% ]. Z0 r, c8 x4 n% m62
    + d3 K" |0 M/ k' C4 G63
    9 \% J3 r1 F/ E64
      B- P: a# a& d6 I  r2 S- c$ S  Z: O65/ z/ v/ R: n9 o* n
    66; h+ D) Q9 A& I- j* B. P' Q& @; v
    67
    6 }* R$ K4 ^+ q8 C2 Y" J680 C! `* X' a4 L: w# \: b' L
    69
    " x+ x: e" y/ g5 x9 M( {5 F70
    $ }& o3 Y' W9 V  t2 t: b71
    , J* f9 w" {4 l3 b724 L5 R) n  `1 E8 j
    73
    ! }3 F5 T  o7 F* b9 G74
    9 r# W: [3 ]+ G4 _75
    % a7 Y! v7 X" B- ~; H2 f/ f+ T76' b* [8 D+ S, L% m( u. _3 o
    77
    7 i, l) S4 r/ K! M78
    : R/ [: f8 j; A. F79
    3 |6 A- b* P; q) Q3 Y) `7 \  S6 @80
    4 p# u7 h5 K1 s3 k1 W( C6 l81) S' F( p' l  h/ W2 ^; `
    82: X% Z) _" y$ ]: W4 [0 J
    83
    3 W; [9 T# j# z! Z2 u( n4 f84& r! k# V( O0 h# T. k$ {
    85/ K, f- b4 f1 ^9 u# E! l
    86
    0 H% K9 M2 h. f# T87
    5 Z8 ~! o  i% K) d88" E5 u9 T+ @2 V. D
    89" V7 x' R% \5 M% G1 I6 n
    90
    1 |5 Q. _; L- U91& s0 F- H" R" c" L% t
    928 D. Z  N% U4 C- m* Z
    93
    0 A. g- W  j7 y4 y4 u# p94+ B7 V+ @0 L( L0 ?( ^
    95
    $ K4 w% f' Z2 A3 B96
    0 J3 n% b1 y6 _4 C& A' {97
    - b8 M, ], }( ~2 ^" t4 i98
    ! b8 e, F: C! Y99
    . n8 U" z; e- v. A100
    7 M! s& M& H' G+ v101$ n; V. j3 f3 K2 |
    1029 ?, X- i' E, [- q
    4.展示一下数据! W1 W2 d4 \; Z, i$ h0 K( I0 F
    def im_convert(tensor):
    ; y7 l% }; S' A* q0 k    """数据展示"""2 F9 p7 t8 M. |: z
        image = tensor.to("cpu").clone().detach()
    $ U$ g1 L1 w" j  i. y1 u    image = image.numpy().squeeze()5 z, a$ ?7 D  x
        # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    : i% K+ L! j# t2 F) H- S    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    ' |9 {( B0 q9 h* h7 {7 M    image = image.transpose(1, 2, 0)
    ( V, g) x- c8 \  H    # 反正则化(反标准化)6 B& P: _/ E9 v
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))7 C# u; t- F4 \% T- V5 ]5 P

      o3 U; q3 {8 {8 ?, D    # 将图像中小于0 的都换成0,大于的都变成1$ \2 p* }- `9 C7 W  B, [# B
        image = image.clip(0, 1)' {2 n$ s" T' l" L
    ( w% V; I* b2 d1 W$ k. c- S: F5 m
        return image2 |+ m3 C4 J+ D5 Z. g- f" n$ n7 o8 n
    16 M( o1 p$ B' n# Z: p( G
    2+ a% c! E, \1 q6 P5 y( o
    3
    ' y3 T+ ~4 O/ u6 E, s4
    ( I6 s8 H6 {+ T% I; N! h5
    * P* K7 L+ {+ ~5 g6
    & a- R; ^. M' Y$ M4 y! A) A3 l7 s7
    0 ^/ y. O7 Q* f5 b8. m: N: i% T" M. P$ ]' n( L
    9
    . x7 c4 x3 x% @# y" k" i2 Y( h5 E10
    # P( |) M0 t6 H& v( |6 \11
    1 O+ ^$ y3 K6 `: K" F12  L/ |  b9 z0 ^- x; w9 I, `
    13! ^8 F% J+ n( I' \6 ?" V) r
    14* B4 Q0 @3 y8 S: _2 N" R0 A
    # 使用上面定义好的类进行画图
    0 g  D- N2 M) T6 H8 pfig = plt.figure(figsize = (20, 12))% y/ J5 p6 B/ R5 b
    columns = 4  d, E$ m5 G8 T  g* M$ ?) H0 V
    rows = 2$ O+ T: l6 f0 G  h" k$ ~
    ) \. w1 g0 ?' n- u& Z
    # iter迭代器( w3 U. W8 o; {
    # 随便找一个Batch数据进行展示8 j7 P+ I1 {# q$ W
    dataiter = iter(dataloaders['valid']). C" K6 g( T2 i# y, @
    inputs, classes = dataiter.next()% [; }0 @2 r' S9 \
    4 e1 U% r9 v* B3 \  y
    for idx in range(columns * rows):3 F, k8 G8 n9 D( l( Q; h
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    1 q% O' a8 z3 O    # 利用json文件将其对应花的类型打印在图片中" z5 P' _9 r0 D5 O
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
      \: F" F0 |, C: E; M( G    plt.imshow(im_convert(inputs[idx]))
    9 k4 h3 [7 l+ C; L& oplt.show()
    & O& E* x5 U- n6 x7 s; I3 C/ W" Z  Q. Y+ W* \
    1
    , ~6 m9 w" W2 a4 K! ?  Q2
      O6 a8 H- _. e+ f4 A3
    , q5 P9 t' k0 E4
    * q6 V/ i( A# L$ `/ `5% b" d7 u5 w# Y" Q6 p  P! \
    60 Z. F2 a8 G. r$ m+ q/ O
    7# f, i! D9 r, L0 Y. I8 i* Z
    8
    % R8 p/ w  @, U9 b9
    5 C4 y' B' G( O3 ~* x2 ^10
    4 N/ m* d6 ]/ P7 l0 e11! |5 x  n9 {# _$ {8 W' U6 j3 s! s" H
    12
      x4 \4 l# b3 T* F5 C, ~1 e" |13
    * P& q: B0 i" }14
    $ j. W; y' F' p15
    1 k4 a6 V, m- B+ N6 C% {165 o5 G& a" W# R- m, f

    5 X$ K7 \/ B) A$ }" z8 ]* L; x( e& ?' p5 H9 H
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数/ t' U9 I" h7 l' h7 [  I
    model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']! b! I' q& n/ K
    # 主要的图像识别用resnet来做: A4 X/ o# [4 j  @8 @
    # 是否用人家训练好的特征% T9 W( V5 F7 n! }. B
    feature_extract = True3 \  n( T( B1 E3 m3 [( E
    1
    - _  S7 Z6 [; \. Y: \: V- ]. I; W2
    9 _! h, B5 e" p4 p6 \0 f5 m3
    ( A8 F* H  C/ |# w% ~  w4# a8 B" t% g9 [$ r5 A0 G3 X
    # 是否用GPU进行训练2 s9 r* r4 _/ H3 [; E' K- t9 z6 u
    train_on_gpu = torch.cuda.is_available()! ?) ?' D+ }$ r2 i! R6 Q( e- N9 e4 i

    : t/ T0 X+ j/ |& |if not train_on_gpu:
    $ B: r- {* Z- L. [! `9 A- z    print('CUDA is not available.   Training on CPU ...')* N* _1 P; F$ M' U  x$ K
    else:7 g. ?7 V/ p% k
        print('CUDA is available! Training on GPU ...')
    / w$ }0 c& K% B# ?! l3 j  G* C7 h3 C5 k* b0 n; n
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')9 Z( L# \' Z$ h7 j+ V& i! K
    11 Z  u# l! u( U7 t  d9 E& K5 u
    2
    , ^' v7 w: @# c: t: [0 `5 d3
    # a/ q* J# E- N5 F8 S4% X1 L, A4 [& W* Z, }
    5: m3 _# k+ D( b: j
    6
    $ q% n# n' ]1 W77 q# ^3 k, k/ K9 z4 I$ D
    8' v" c7 r7 N8 L0 y& q8 D% s
    9  u  H6 n8 K6 f
    CUDA is not available.   Training on CPU ...
    7 m$ r& o0 i  L3 d2 G1
    7 s4 \; Y5 i8 ^5 b8 L# 将一些层定义为false,使其不自动更新/ {- V% d* ]* c8 b( q8 ]" B3 f
    def set_parameter_requires_grad(model, feature_extracting):
    ; M3 Y* P: E. s    if feature_extracting:1 [+ y% O. \' U
            for param in model.parameters():: _, C. k2 _( i
                param.requires_grad = False1 `! {" h3 N8 K8 }; Q
    1  A  n7 r  n/ `
    2
    8 C/ A: C2 E  t/ Y( t3. S' ^+ h; |- U
    4
    3 F% I4 p5 J$ r) j5* E5 Q9 w. D& m  w
    # 打印模型架构告知是怎么一步一步去完成的
    ; M9 d* v1 L" ^; I, @* w# 主要是为我们提取特征的, w& L' r" F  C  Z' G7 U! G

    5 J6 E0 [- P& ]5 T0 V& Qmodel_ft = models.resnet152()
    0 A3 B) t  I9 e+ k* Lmodel_ft) a3 q+ p$ k7 Y6 d# p# K6 b) n% `2 w
    16 u. v* \- z  u( ~
    2' T+ [2 E/ C* i
    37 N% X$ B2 f4 G: i  @4 m
    48 l. R5 \/ t$ I3 P! J/ B
    5$ [! g' T8 ]( A" @
    ResNet(
    0 m# F: u8 t3 N/ G  i  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)# T5 v4 W# [- j2 [( I3 j# X
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True); d! B. ?4 e' t4 S2 e
      (relu): ReLU(inplace=True)' _; _# P6 L& J) y' Z0 }
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)3 N; i/ z5 |, y3 G7 [
      (layer1): Sequential(! p3 q$ c. X3 M0 z! `4 G
        (0): Bottleneck(- \. U7 q9 }( q, [+ h
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    + q( H+ Y  T) n4 \      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    8 r# d: r/ K% u3 g& Z' s% b& H8 f0 R      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    ; h+ p2 s; N9 [: Z$ W3 T      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    9 p3 x$ H( \5 ?8 t* g      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ' k0 o$ K+ D  K. C3 A# f6 W      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    2 j2 F1 j' L5 i% E1 K      (relu): ReLU(inplace=True)% M( U9 C) S$ q# c
          (downsample): Sequential(
    8 l: K- e2 W8 g. l  t0 O9 I! t        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)0 U" I3 j0 y* D- Z
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)* h+ |0 D! j4 K
          )
    & W! b" {4 P; `9 Y% f# p    )' v- l! m7 P: g* |. c# N
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    - G0 D* L0 c5 {6 d3 p    (2): Bottleneck(
    1 G, H- n; n6 Z" @7 r) B' {      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    5 G, F& M  W$ d$ W      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    0 t( d; j3 F# c      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)9 N  I& f+ o' k
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)6 m5 i5 Y2 ?' s3 T+ Q- N: G7 m7 e
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ! a/ {% t6 |) s( y# m! x! }      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ; e' x% k& \, S8 }: f# p' \2 L2 n% h      (relu): ReLU(inplace=True)
    0 @3 A( C) w  ~5 T, V1 f/ B    )
    % H( k- B; t/ ]  )
    1 J) t3 N- X( K9 \( s# t/ B  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))9 y. K, L0 U$ k5 P7 i0 `
      (fc): Linear(in_features=2048, out_features=1000, bias=True)! @. p0 y4 p! L- e4 s- e7 X
    )
    ! \: h3 @) A" n% Z3 X. q
    $ T" Z% \) l0 q1% K+ U6 Y2 ?# E% e( ^! N
    2# i$ ?1 x8 ~' z
    3
    , |: M! e1 K2 f: M, Q4, ~' S$ a' A, r
    5
    ' @* ?; ?# L$ I4 Y% W  F2 d* ^6
    - z( A* j& ^4 E: ^: ^; ?$ ?# ]# a7
    2 m+ j( O* p* `$ [' \/ I9 {8
      G( z/ N# _) X# a9 g9
    : b1 n, t, u, R10
    . i, J: o" M" m3 J' x  Z11
    ) p$ N% F5 K0 ~) p% }3 G12' z: N$ s) a6 H# j0 L$ C
    135 {1 J1 q) w# }( T  V
    141 E& e3 E& M' S( ?/ M+ e
    15$ {0 f* S- d/ G9 U& ]9 I
    16
    " I  b. ], J$ R5 N) U% W17
    ) N7 U5 d  C& Y: H* s  V+ h2 ]& N18
    # M4 {/ C4 S( C1 A2 @19
    + T( D3 F6 ~1 p' h208 C+ b! C+ v4 ~0 n" u
    21, r( ?( ~5 u, Q2 Y9 b; v
    22( }; q9 X4 |0 m: N
    23
    ! ~( O$ @) u/ }  T" I24
    7 s' `  s/ \* c2 z5 ^; C# c25
    # }  O% I) K! p- @1 I26: g  y5 P+ A" p/ g# R' E
    275 @1 v# y1 E# f. Z( g+ J) [' ?* m! Y
    28
    % {7 }3 }0 T* N+ p# y4 D. [29% p( l1 e! j' E
    302 l7 [# x# X7 o
    31
    - U$ L. @' m8 ?: x32  }1 ?3 V. v3 e  ]9 K1 E  X5 l9 S
    33- w2 F, B! W$ C7 M4 q
    最后是1000分类,2048输入,分为1000个分类1 P3 |; C# b. f/ k$ w6 `
    而我们需要将我们的任务进行调整,将1000分类改为102输出
    . r9 l( R+ d' Q( W" B  t- w/ w7 F$ U0 S5 H: c3 u  p. z# }4 i
    6.初始化模型架构
      b4 K" i, U: k' A/ x& ^2 A) `步骤如下:, K; D2 q: @4 Q4 T% Y

    4 w$ b2 _8 h4 m  S; C将训练好的模型拿过来,并pre_train = True 得到他人的权重参数+ W' ]$ J; ]* s+ E( P
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)2 K6 E* B+ r/ I# F$ T: u
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数+ P1 b+ }# @4 E/ N$ k- |
    官方文档链接
    5 j  r# |/ W. o( S/ U; m, @https://pytorch.org/vision/stable/models.html
    . ~6 y% H, J- B2 W/ x: S1 j! q, l( _
    # 将他人的模型加载进来, D" z5 D6 b/ ]1 k
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    % x7 j& h3 S" ~7 ]9 w    # 选择适合的模型,不同的模型初始化参数不同
    ) |: R$ l+ l3 y$ J: L    model_ft = None
    " m$ _" a: L' {* i1 m0 o    input_size = 0% q4 _0 f8 P* n

    7 O  [& C: W% {" O9 z0 v    if model_name == "resnet":
    - l* d3 U$ y# {- G% c' _        """
    ( G& o2 K. ~, ?        Resnet152
    * c0 s; @6 g( D) N" u  W; n3 c2 F        """. s* e5 [7 ^+ d$ w: }, e# a

    % r4 j4 C; D! N7 U# k2 b( w& D1 c8 E        # 1. 加载与训练网络
    + M% y* f' A9 ]9 N$ N& G        model_ft = models.resnet152(pretrained = use_pretrained)
    ; n, C3 a# a) b        # 2. 是否将提取特征的模块冻住,只训练FC层
    6 K; S" Z% o; R& C0 K* Q        set_parameter_requires_grad(model_ft, feature_extract)
    ; |( u8 s! N( |7 n) ?* r        # 3. 获得全连接层输入特征
    ; ^  N5 o4 d2 l. u7 `  `  q        num_frts = model_ft.fc.in_features
    . b! i9 @; X  P2 z. z8 x  R9 F) J        # 4. 重新加载全连接层,设置输出102! [" M9 N# u6 |: |; y
            model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
    7 x8 \! e& `0 U$ X8 q                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
      E5 d% I% x3 Z- Z3 C" V4 [6 @) ~  P        input_size = 224
    & c6 o$ B. g) j) N5 R( ]  p  ?" Z* }. ]
        elif model_name == "alexnet":
    " G1 F4 c& p8 n; u+ |( R0 p5 ^1 D        """
    # R8 q/ |" p0 `% h% e5 I        Alexnet- |/ ~2 N' q! E. l! R. B( S& G
            """
    $ h1 N( Z. D. W. D$ Z4 H6 I9 w" F        model_ft = models.alexnet(pretrained = use_pretrained)
    ) l  }9 _/ h* H6 t6 p6 K+ S+ U' s- \. x        set_parameter_requires_grad(model_ft, feature_extract)
    9 W% d% I: q+ I3 v) ?
    / V  \4 w8 f+ O        # 将最后一个特征输出替换 序号为【6】的分类器
    5 b& \' }4 r3 m' z3 w' R+ |% ]5 I) N        num_frts = model_ft.classifier[6].in_features # 获得FC层输入
    , D9 Q6 w3 i& e: y% [        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)2 U# T0 _8 i2 Z4 V+ h
            input_size = 224
    # M2 Q4 [- \; K% H( o! R+ P7 y9 O2 \" B3 N
        elif model_name == "vgg":
    8 w, Q2 @2 `+ M6 X& ~        """& ?; n" Z' U3 q, v7 G
            VGG11_bn
    8 A$ ?( z# p8 U        """
    ' R# T  e! q! v2 N2 \# E; w        model_ft = models.vgg16(pretrained = use_pretrained): A+ F4 Y8 i2 d0 a5 C$ P0 }
            set_parameter_requires_grad(model_ft, feature_extract)3 ?, `( X$ W3 M: V. F8 \; W4 D8 H3 ~
            num_frts = model_ft.classifier[6].in_features
    & ?* F, V* J& ?5 p+ f! X; T        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    ! C( R8 m+ k) h        input_size = 224
    ! ^9 p1 G! i: Q6 Y; T* Z- a8 f' T
        elif model_name == "squeezenet":
    $ f; Z* p- Q) A/ Z! p4 g# O0 C        """7 r, I% E8 F* m, a
            Squeezenet9 t% }# M# r. _7 f4 I5 c$ C
            """
    # `( i4 v, X2 H6 q' p4 x        model_ft = models.squeezenet1_0(pretrained = use_pretrained)7 X9 k8 M4 r. O4 N/ U
            set_parameter_requires_grad(model_ft, feature_extract)
    " t6 k4 A; o5 P5 Z        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))+ P5 @4 R8 }# ?1 D
            model_ft.num_classes = num_classes
    : u7 U. T, u8 L7 ?, g        input_size = 224
    ' l7 w! F5 E( z; |$ L. J9 T/ _3 d) }/ s9 K
        elif model_name == "densenet":: {! K8 v9 @0 C" Z' M2 f) f
            """5 X" `3 W" Y2 a7 t
            Densenet
    % _1 Y( H5 K. N3 }! W, T        """* g3 p/ }. A. p: s- l3 d+ d
            model_ft = models.desenet121(pretrained = use_pretrained)
    / h3 s$ W: M* G5 g" _1 b        set_parameter_requires_grad(model_ft, feature_extract)
    8 l2 d7 O( Y/ W5 |        num_frts = model_ft.classifier.in_features9 O0 Z  s5 s& _5 g5 X0 Q
            model_ft.classifier = nn.Linear(num_frts, num_classes)3 M. Q- G8 t! Y) O" v* {8 B" K3 s
            input_size = 224
    ! |6 _# G( C$ Q/ d) t: t
    : [3 q6 l+ |) F: |    elif model_name == "inception":
    4 }$ R2 o7 b4 H! z' H- T        """
    % H. M( t# n& _. _' u        Inception V3
    / q5 ~* X5 p- L% d% W: I        """. s# R* |5 p+ P" g# r, o
            model_ft = models.inception_V(pretrained = use_pretrained). r* b9 |/ W" \5 M& P
            set_parameter_requires_grad(model_ft, feature_extract)
    4 J5 X6 E% y% k- h4 K9 B7 A; v
            num_frts = model_ft.AuxLogits.fc.in_features( h0 \) x; A2 H. U7 }4 s, p: V
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)6 A% o; F: c0 Q" e, d3 m8 X
    . M* c# G, o2 A" |2 w) p8 V+ |
            num_frts = model_ft.fc.in_features' t- e9 r0 t6 M1 x' m
            model_ft.fc = nn.Linear(num_frts, num_classes)
    % ~8 N' X  {, O        input_size = 2998 a& w( }/ ~" c& I  @# }/ p" O
    1 ~: ~7 Q# J. k+ b
        else:1 `  v9 @/ w! A8 T9 I9 Z
            print("Invalid model name, exiting...")
    $ |& N0 P4 [2 ^7 }5 R1 M8 `/ c% l$ l7 a        exit()
    , t0 P# C2 n3 f& d/ A$ |* p8 p- B& O1 w5 N( e7 U1 Q
        return model_ft, input_size8 X0 Q* O- R. B4 B, |

    ! T, S* |' l) Q8 G5 ^# p% a4 k1
    ) u' F1 `2 _9 j- z) s$ A2
    : K8 m9 ]2 S* h2 P- T7 ]  h% v3
    + S# {: K* U+ p7 E# O; F: m4
    % U% w- y9 X+ F5
    : ~; p1 R9 d1 Q$ z6
    & Q4 P' w) A6 w2 K# Z+ t7/ v* w# \: B+ J2 ?
    86 G5 w5 ~' Z1 |1 R5 a
    9
    ( o) t8 U) d) j$ o10
    ' j1 ^# \  d' ]$ ^; [112 Y  X0 X7 d2 J( O/ H! x) y5 T) J
    12
    $ ?9 `& R9 x! k& e4 L. e13
    . I" }# G2 x0 Z- N) S- a14
    ' ^7 x2 [3 Z7 c& e5 ^, w" r15
    . `  v& B. C+ L- P( E" i3 n( O5 Q16
    $ ^* A! ~  T. f/ n" {: ?  g17
    5 n, P4 f6 Q% q( I7 I* f18& `4 q7 I7 J& r, H
    19
    5 K' s6 w$ k7 y8 h20
    ( i0 R, R2 S, N- {$ ~4 i: [21  L7 B& p+ B3 T) [& T
    227 R6 Z8 A: `# ]. n8 x0 r6 w8 M
    23
    : \3 d9 \* H8 |24- U! V8 ?: d1 R0 ]7 P& e5 N/ z) m
    25" r* K" Z' ~! g. ?: I: U6 Q: x' \
    261 j2 X& K7 R  F* G! J% ~, |0 B
    27
    0 V. v* ~6 L6 [- N28
    / u* G' x5 a% O$ @  g29
    - m" ?6 N3 ~% r30" y8 e( }. J" C: N' h& g
    31
    8 ~" b& y" r. R) X32; M  Z( H' A8 ~% d
    33
    ; U* ]0 l6 w5 x' U  N' T% W34
    % @7 ?4 }: N; a* `5 f, ]9 C35
    & s0 K2 G6 O/ S9 L* d363 i' R+ D' I" c5 o0 s$ J
    37
    5 E. @* y# b9 R38
    # H( T3 Z0 p& ]: M+ a/ }: T396 s: Q& z5 m3 k6 A+ X/ |
    40, f' ?: X1 j9 L! g; {& H
    410 l* s9 F6 r: w+ q, {, y
    42& _0 J4 N* g; h# j. V
    43* R! B+ H) H" B' E* i
    44. [3 x4 A0 Z1 @+ `* L1 c
    45
    - N* |. s. Y* I* A7 K4 P46
    . X" _$ O- f% T8 f: _2 P47
      e9 G. N4 z) _0 c( O. O6 H0 ?# o" S481 a+ e% K5 N- Y5 @# X; z( Q2 e0 C
    49) U5 B0 C, s8 y" D1 @# S
    50
    & ^8 D. [5 ?0 e4 q  z3 N0 j513 p4 P! q* `+ I* @# a1 ]8 T) f
    52
    2 E9 H* H" b" M( D( P: {  n53
    ' o! n7 n, G, M% O, q0 S3 j54
    7 a) Y6 X* `! k3 d  @% K55. c/ [% T$ S& q; @
    56
    2 R0 J/ r6 ]2 Y570 i8 u# w, o! j( f
    58
      ?$ X& g: s7 o" j! `3 T( A3 R59  `0 |* U; T* u( h9 w
    60
    * \3 w. N4 J3 X- k  i% f% \2 i61
    & |  s! A( ~/ L9 w0 }' j' z627 ~* G( v7 b/ e8 A
    63
    1 b; E- b) h1 M3 b+ K' d* D64
    4 Q# e' Y9 c# g" y( R4 |65, z) n9 J. X# e1 O; T2 A
    66; Z! q& `0 \4 w8 ]4 ?0 q4 L
    67- U# Z/ B* S* R
    68" m0 W4 m0 i6 ~: Y( z+ ~
    69
    # @% D! o  }4 P! c0 j4 Y5 e70
    " `  r1 [7 @' b% |% q711 ^0 x- W3 a) }
    72% w; P9 I6 h4 Y' F$ ]
    73( M3 c' |& g" G9 A
    74
    0 U: J' P1 k$ O+ r/ H75& X) d1 P( M& `3 n& B# E( o3 s! G
    76& G- f& `+ w4 J5 `  \
    77
    3 I0 ^# u0 n( G. d# |5 U0 X78
    3 ?, c9 l6 {6 y$ d79& a) [. j5 S3 H9 T3 L+ z4 C5 O
    80
    $ y7 a6 j' c: j# q; m+ ?81& q+ g: N, _; ~  b! `/ P* M
    82' @/ s, H6 W' s8 H
    83  e4 K( f9 x+ q! p$ F; P
    7. 设置需要训练的参数. }5 W6 m' N( C; P7 X5 B$ K  J7 I
    # 设置模型名字、输出分类数2 o# ^. U$ U2 I% [: O# P
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    ! b& V9 K* j3 j" r( S1 u/ T7 C( W
    0 s& U5 f1 c7 g1 p. H# GPU 计算! m# C5 O1 ~5 c* O$ @5 p  ~3 d' O
    model_ft = model_ft.to(device)3 j& ]( j9 ?3 g
    4 I4 Y# o  f: f* m2 I+ U
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取. W2 x2 Z) a+ ~# B$ i
    filename = 'checkpoint.pth'+ ?* h3 g+ h* Z& n

    5 i' q9 ]7 j, H+ @# 是否训练所有层
    ! Q. i$ J+ Z( t' m+ nparams_to_update = model_ft.parameters()& z. P: @6 f- C: q5 P
    # 打印出需要训练的层7 q4 b4 ?5 D+ h' d8 R" s3 Z
    print("Params to learn:")
    # L$ B# V" d8 |3 ]0 Cif feature_extract:: D1 n. s4 U0 o* _: }. o
        params_to_update = []
    - L" J/ z- j  `9 O9 b1 L    for name, param in model_ft.named_parameters():
    ; ~8 o" l' a- i* G0 B3 b$ @8 f        if param.requires_grad == True:
    0 I; O! I+ y- `3 i, ^5 s+ O            params_to_update.append(param)
    7 M( O( G+ G: D6 r3 r' t            print("\t", name)7 Z0 b- r! C% G% F' n" d! `
    else:: n6 `  C. T- W' e& _3 O& F
        for name, param in model_ft.named_parameters():, V  x9 w' s" F8 ]+ r) S
            if param.requires_grad ==True:2 w) F) w3 D! L! N
                print("\t", name)
    7 m3 [' M. T( }, ^5 i
    ' P- Q) @4 H+ q; z1
    9 d8 u) s- m: ]% U7 K2; X5 i, T7 x' T9 e4 g$ d$ x& ]$ p
    3
    # g9 n  Z1 ]/ i3 g, e3 B4
    . r( {( n; \; R6 W, ~5
    ! O& q9 S$ u( H6 O$ o1 R6
    2 G$ d# k( d0 Q! w5 ?7
    ! Y5 d( b4 v" M, s" q7 b8
    ) p$ b% n9 ?' I! L9
    9 q. ~: }' P3 Y/ _# {' E10
      E" J5 H* k, z" }) w( s! }11/ k6 J$ s# X1 x4 s, p4 o& b7 [
    12. X2 q5 u4 Y6 [+ Z$ Z7 T
    13" ?# E: y1 [4 [; R' h; m
    14
    + c$ F  P, t5 I& \* }; x2 o1 a: d15& N/ k) P- m2 A. n
    16/ ^$ s* t, n+ o6 o# m
    17" h( d: x" A. M
    188 O- L/ J( B0 [9 H& J  g3 k* C
    19. h% c+ Z3 b1 P" N/ a$ c) f
    20
    : [' Y0 b' ~% F* z4 c21" e0 X' ~* T/ }8 N; ?+ [
    22
    : d: B& }8 w3 V: [9 T0 T9 W. s23
    , X/ R7 d, l9 z  Q7 d, k  Y: e& kParams to learn:/ M5 W( `! M8 t1 H' B
             fc.0.weight* s' r: ?- T0 d( z* c! `- ?& _
             fc.0.bias
    6 n3 H& C3 D( q4 y5 c17 ?5 h2 ]' d: e' X% i3 O2 F: d
    2" y- w* G' k8 E% V( r5 }3 U; ~
    39 L2 v" x; N. o1 o- w
    7. 训练与预测
    ) Q; c% q+ N. s6 e% @$ ?$ V7.1 优化器设置9 T, D- s7 u8 w& c4 F
    # 优化器设置
    ; `6 j2 h4 q' u- b3 G; goptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2): J8 V- _% u; o2 x3 M9 y
    # 学习率衰减策略
    " R4 t3 b8 }& Y( zscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1). w! c4 A: J& k
    # 学习率每7个epoch衰减为原来的1/10- r9 B. ~; l* Q4 D  b2 z
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    ; t; r# P6 T! w  S+ [/ X' ^
    " I0 J% {' J2 I2 z! [criterion = nn.NLLLoss()9 @1 i( C) R9 u4 G7 k  ]  a
    1" ^# {; }! _2 P5 e7 I7 W
    24 ]: D' \+ K- ^
    36 v! G4 {; s9 o
    4
    ( e, @5 T! w: m( [5
    0 t, T" M- Y8 f6
    * o1 F* q, B7 @2 w) i" ]# e% E7
      p  j( B. g1 U2 [, q5 J87 c/ I8 V% [4 v) T8 `
    # 定义训练函数
    4 u0 M& H* a3 o" E! e1 C/ Q#is_inception:要不要用其他的网络+ W% z, O7 m( Q' d' r) e5 o3 Z
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):% C2 K+ K$ a0 S! f; B( O' F" w
        since = time.time()
    * D) J! Y, O- J7 s* t    #保存最好的准确率9 b& o/ S3 y6 @0 }) U2 a9 V6 O( q
        best_acc = 0) E( g/ E8 T7 z/ _
        """
    , R3 o4 D% F9 l% F+ n, c    checkpoint = torch.load(filename)5 [* r6 r5 R: X, V/ B7 y
        best_acc = checkpoint['best_acc'], h9 T# z; A" ~; @. |: t7 R
        model.load_state_dict(checkpoint['state_dict'])4 v' b: V* t1 z* z  _: _
        optimizer.load_state_dict(checkpoint['optimizer'])
    , k  F, o4 S7 Q! t2 m2 @, G    model.class_to_idx = checkpoint['mapping']
    0 J7 C, a8 _8 `5 `, q- x! q5 g    """' ^' u% R& b; @; D" M
        #指定用GPU还是CPU
    ) g0 }& ]& p9 x- G! w5 J3 N    model.to(device)
      f8 }' u- D$ @% @* l" R    #下面是为展示做的" X! L- \. b  N! X$ P3 G
        val_acc_history = []4 Z) M* S7 _& @5 E% t0 o! j7 U/ F
        train_acc_history = []! M) Z# k& e$ O% c
        train_losses = []+ W  F, f6 v1 _( ^1 x" Y1 J* I" `+ H
        valid_losses = []: _: t" r# I+ |$ y7 M& r
        LRs = [optimizer.param_groups[0]['lr']]
    ) N% l$ i# h# T) i4 c    #最好的一次存下来/ g4 G$ c6 V( a/ |( A/ j
        best_model_wts = copy.deepcopy(model.state_dict())# D! B& K$ A& z. ?6 N5 g! P% t
    5 b' G9 `: j) G9 ~& D# t/ E
        for epoch in range(num_epochs):
    ; i9 _8 ]7 O, a1 {9 \3 z' s        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    8 V* T4 J+ p) f$ `+ @/ u        print('-' * 10)
    * g( R8 c, i. i2 r+ A2 t9 t
    + l: h! A* `. Z, ]        # 训练和验证
    $ j, w# h3 R. k5 ^0 h# A        for phase in ['train', 'valid']:
    4 T$ v+ _4 {& e/ \; r            if phase == 'train':
    : N+ U* A/ p" J# E5 m                model.train()  # 训练5 y: q: r( m( f+ Q* Y9 q, p
                else:5 Q( X* _, v: v# w* P3 y1 A
                    model.eval()   # 验证
    5 W( o3 u& @" U# ^. _
    9 Q0 Y4 P# {& s/ |9 S0 s            running_loss = 0.0
    4 i" ~' ~7 S8 v            running_corrects = 0
    - f1 c/ M( G$ r! }8 _3 b. i0 s* v5 y+ h( X
                # 把数据都取个遍
    0 q# l8 D5 g6 m* K* u3 O# I  z2 t            for inputs, labels in dataloaders[phase]:
    , ?9 e* z  U3 k' h                #下面是将inputs,labels传到GPU4 x2 P2 _0 g! r! f) u6 I- J
                    inputs = inputs.to(device)
    4 ~3 D" S" ~4 {. v/ z                labels = labels.to(device)
    4 k- d2 V: D# ^) |0 P' M( F3 y1 X! C# z0 `
                    # 清零% s% f& W) }: j+ ^% z
                    optimizer.zero_grad()
    ; j$ x0 H( T9 z; Y                # 只有训练的时候计算和更新梯度
    2 \& Q7 A* v# m- ?' }8 \4 I4 b                with torch.set_grad_enabled(phase == 'train'):! Y* U4 T( h2 q/ n4 s3 h
                        #if这面不需要计算,可忽略" }) y3 {0 ~/ Q
                        if is_inception and phase == 'train':
    / X, ?# o0 |9 d8 q" a$ k8 M                        outputs, aux_outputs = model(inputs)
    - F' v, l1 O# ]+ x5 [; U, [                        loss1 = criterion(outputs, labels)
    : R1 B- {# Q1 V! k+ `* s                        loss2 = criterion(aux_outputs, labels)
    , h0 X' n4 @1 Y5 J# p                        loss = loss1 + 0.4*loss2  \/ H5 m! J. N1 k
                        else:#resnet执行的是这里
    / e1 h+ R8 k3 g5 A3 s( ^/ u% o$ f                        outputs = model(inputs). a& W  e1 ?7 R
                            loss = criterion(outputs, labels)
    3 q4 A' _6 ~' ~$ t6 I: [% K  E8 k6 P# F. i0 W% ^% K0 K
                            #概率最大的返回preds
    . y  W& G0 ~) ~( W; ?7 h                    _, preds = torch.max(outputs, 1)
    " p* s# z' C! N2 Q$ N6 f8 U! h/ F2 ?. u
                        # 训练阶段更新权重
    ; m4 k) F- y4 t7 Z( `                    if phase == 'train':. i( Z) c* T3 G% B
                            loss.backward()9 r$ x4 U  R4 A& A
                            optimizer.step()
    5 M. T% l# @4 U  A. U5 o
    & b, r" t2 @( P3 p5 q! ]5 l                # 计算损失
      r' u( \& J; |% W3 }                running_loss += loss.item() * inputs.size(0)# @3 R/ Z8 c# M9 G  n1 ^
                    running_corrects += torch.sum(preds == labels.data)
    4 }+ t  Q0 u) R
    9 g' W7 a0 E! l- w4 b7 r4 \            #打印操作
    . f4 [7 x9 G+ t; U            epoch_loss = running_loss / len(dataloaders[phase].dataset)
    . `" ~6 r( N: N            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    : ^# x! ?2 m# O! j
    3 w9 O: g& r2 E: w. s$ {" g; r- U/ N4 w9 y& r
                time_elapsed = time.time() - since- `9 j& V9 O2 b
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))- Y" @( o+ G) A) ^8 V$ e: r
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))8 E5 {% U% X8 L7 z5 b2 b8 Z
    5 z; E* ]& w0 g# P7 R
    + U  L- {, a. C( ]7 |) J
                # 得到最好那次的模型/ i; R0 ], F) U) v( s
                if phase == 'valid' and epoch_acc > best_acc:* n5 C+ r6 E& [; q  U4 `6 |- m1 |8 }$ t
                    best_acc = epoch_acc
    ; {0 V6 Y7 @& _' F                #模型保存
    4 d9 f! D( r' x$ z' w' H# H                best_model_wts = copy.deepcopy(model.state_dict())
    & `: r9 s4 c; @/ j/ ^( P                state = {8 S3 M1 W6 z' w& t% g
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    & }) |* g* `( P- c                  'state_dict': model.state_dict(),5 Q( z3 }7 T; [, x
                      'best_acc': best_acc,( D" x8 x/ T0 K/ k' n
                      'optimizer' : optimizer.state_dict(),
    / d, p" d$ J! t1 M" j3 c1 h% t8 Q6 J                }- U+ x' O# l- ~1 J* h) ~7 q
                    torch.save(state, filename)+ q1 P) r) z/ @# R
                if phase == 'valid':* w- v5 ^. f% o+ T0 e
                    val_acc_history.append(epoch_acc)
    8 u1 ?' i/ e+ v! U* D# }1 d) K$ `                valid_losses.append(epoch_loss): ?5 d+ Y# `& c
                    scheduler.step(epoch_loss)
    8 e4 d$ A4 T4 z            if phase == 'train':# p! w2 _2 s( y- F; _7 t
                    train_acc_history.append(epoch_acc)! }# M. n/ y" Y$ `5 C+ q
                    train_losses.append(epoch_loss)
    - l5 J0 j  V, h  o. u% S: h* f" l4 a4 a& J* y; p) x9 ?! s
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    & ]) ?3 `, a; l' l7 U        LRs.append(optimizer.param_groups[0]['lr'])
    / x1 u) x4 j/ j4 v1 b9 B        print()
    , A; V' k; Z3 Z7 J5 F) b
    % M2 K2 W4 x0 C$ T    time_elapsed = time.time() - since2 p3 T) E* P# @
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)): f1 K8 m1 i+ a
        print('Best val Acc: {:4f}'.format(best_acc)), ]& b" i% [- e: a

    0 v% d0 d: A" Y9 |0 j6 G    # 保存训练完后用最好的一次当做模型最终的结果
    3 j9 f( h) D8 C    model.load_state_dict(best_model_wts)! e+ |, l0 ^  V8 o: V
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    ; ]" F+ u- O# \7 l4 Z2 v  P- g5 ?* E$ z

    ; [! j6 [  u9 R: \4 S  k1; u( ?3 y+ D7 R% ]* _
    2. `0 V0 u! g) C! Z) D
    3" m2 l; j5 [* \; o. B9 o
    4
    + F3 O  Q, ?& X; e: F* L5' q6 h+ R2 I1 `9 m" l/ `/ ]; w
    6
    : k. U; j  A0 I( x7: J$ l3 Z6 v3 i! u! h3 a% W
    8
      F8 H5 ?2 m  K7 v9
    + k9 H1 l5 q9 T# M! l10  Y/ _; N; J% C2 s2 z: z
    11
    0 B3 i2 |: ^. X: e# G2 V: ?12
    . _1 I/ Y# m* W5 i; H138 W, i0 p& m4 d7 _& [6 o
    14
      z5 E# r0 o( P15
    . x2 `( v4 s7 Z& R4 Z) T6 b  H16
    + D- `" _% p( e0 i17
    ; Y0 Q* Q) k1 l4 I6 f" s4 U180 \) U& @, b- `& T( H
    19' e' s" H0 h% A& z, L
    20: u- T& S7 z& `. g
    21
    ( f% B" [% }9 j+ E. e22
    8 P' G% K3 m7 t% A& s% _23
    ) @3 E7 ?5 [5 t! Q8 ]24
    0 t9 V+ V+ r4 ?. \25
    % Y+ e! e: Y% j5 d26
    6 w& s1 v+ `. O* |2 W; Q; g27
    - Y9 J) H  W, `7 J( q9 o28
    9 O- f9 i; ]5 }) M+ M) d4 ], Q29
    % G0 g( j& o1 r: t! u30
    : Z! Q$ a( Q) e; j31/ [& a0 [' K5 t: o5 h# i
    32
    6 w) c+ T8 g2 P) b" \' g! r4 F33
    - g. e! i( t6 ]& H- p; H0 s344 Q( C% V3 P! h
    355 y5 [" {/ j2 |6 P4 ~3 `
    368 @7 b) l9 q; l+ c# y
    37' {; n/ B( y$ N* U9 @$ a! U
    38( C& E; s. D3 `6 D  a# Y: _
    392 D# G! W; }; y3 J3 W
    40; w3 L  a% \0 \7 |" Q
    41* |4 [( e, D* `6 S
    42
    9 p4 v8 f4 e- v/ D2 Y43+ s7 G* f) X/ v/ o* X0 w6 {
    44
    - D! R9 y: B8 o/ _! r+ E' J. {# \45( ~* w: l6 }5 [" }
    46
    7 h5 j- R5 T) {, Q47
    / y2 [0 N* A( u9 ?* A* d48
    2 q0 O% o9 Z) z49' x: f) q1 F1 q8 V2 i' |- J
    50
    : |! @& F5 U4 |7 i51
    ( ]; l: c; R! K52
    . h. y) T! Z/ q539 @/ K4 R3 r7 ~0 h8 K2 y
    54) ^4 g3 j! g5 e+ @2 X- m+ ^; M& n# ^
    55  t: z4 k  M. t* Z6 q
    56
    % O$ R% Q7 K0 R$ t: T# f57& i% M$ f2 H' k0 {9 N' K
    588 J7 p  Q$ L' T+ r; P+ m
    59
    ' }1 d7 t! p1 ?$ d/ j60
    " z3 e1 p+ [# k& v  V; ^) ]6 t61+ u2 |1 `7 E" T4 O
    62
    % ^2 X; e( p. t& Z4 o, h( o63
    % c7 U$ N$ \' m( X* m; F4 H64% w( a1 E' a. x! T4 p3 `8 j8 M
    65* c  g. J, I: ~6 B' U7 G$ w
    66' P# h! `5 g2 \2 z: s  g  O' s
    67  Y* w" o+ T7 D' f; H( q0 H' Q
    68/ S& w! j9 |5 X4 C
    69
    % r/ X$ `3 R! j- m! C  P70
    / g: Y* T: q" w9 r9 x: L71* U$ e& \& q7 L3 p2 ^
    72
    ! f4 w7 U6 u/ [+ ]6 T5 q+ t( H73
    8 I3 r1 N; I) h7 u8 D0 r74
    ! P" o$ @. Y$ F. ^- v" i75& l! P) s- Y9 D+ ^
    760 d6 C( \3 g5 S9 g! u
    77
    + ~" E1 _& z8 @0 a+ |  H78
    7 g. l4 l% m, w2 D% {79
    ' p; d# O8 q5 L/ I1 Q% R9 e80
    1 y. Q. T+ y0 i* {; x) b" \1 \81/ {( O" i; a+ }4 Z
    825 \( `+ [* c3 W% h
    83+ K3 d; @5 x6 Y* g
    84. v) p  m, w* B# d9 {% ^& _
    85) |5 R2 h- L  V8 X
    86) n  s, R# m  z+ J
    87' b$ h9 A; ~8 {% j
    88
    3 Q: {& K& W; y89! u4 c, A7 G2 C. {0 X) P/ I
    90
    9 b2 r4 B5 M2 ], Q! c& V; _914 O) V8 ~) ]7 P' l8 R; Y  g
    92
    9 X+ l+ |! g, l4 I# U- k93
    ( ?" U% w/ [: {' ]/ B5 J94$ d0 H' ~# r; c$ @
    95/ Y' d: @1 V2 S: P4 \% n3 V
    96
    + y, V' I% e$ A8 l: z* A# d* b1 y973 H9 t. V  Z9 c& x
    98) ~4 \. @: l# W
    99
    ; P) W" e& L( L0 l  x6 A- E100, G; \1 ]3 I9 x# K8 H- z! W
    1018 L+ @4 D0 H$ W0 o, {( A2 I
    102
    3 K' n+ R6 S' R! o/ F. I103* H' C& y; y8 _9 F, p/ E9 w$ r7 Q
    104
    5 x9 N9 N& I- T0 C! j105- c. c4 h/ o% I6 V. p& A
    1064 e8 x; \4 ^. g. D& T5 @
    107
    0 @! L0 D! v6 H8 F# R  f1087 B. m9 n0 v9 A* @5 @" e7 M
    109
    & T# ]4 V0 L" U1 C1 q, p110: [2 s  I; Z0 A4 r
    1118 E5 Y+ \1 S0 Y. N
    112
    , u; J$ u- y6 D, \9 `7.2 开始训练模型4 x& Q# D/ M' E% P5 E0 G5 i& a
    我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    9 S& ?/ i5 k! T1 S' @' A/ i. m! V5 @! K7 ]- v& M% {) [
    #若太慢,把epoch调低,迭代50次可能好些
    ! J! ?7 _/ m: G" p% \, X% ?#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合  _) Y4 z# _" N) [
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))1 z1 }3 K: C9 |4 L$ v% s

    6 O' ]5 ~+ v+ {: l. p, w' o1, S- H+ ~% F9 u4 j# Q
    2; x. g; H2 r' V* ~7 [& @, k
    3
    - b% _5 o# W. M6 G1 Y47 Z! i% R4 l2 m4 B) q$ @/ I
    Epoch 0/49 q4 p7 C5 I2 S1 z. A  Y. W3 t
    ----------7 x2 L( \8 H6 S) f! ~% y7 D
    Time elapsed 29m 41s! ?9 o, B: C0 Y/ M# O
    train Loss: 10.4774 Acc: 0.3147: K9 D  _( A* w7 A9 G2 m
    Time elapsed 32m 54s4 ^$ C2 V( O+ r3 X
    valid Loss: 8.2902 Acc: 0.4719  V9 M, N" ^1 W; _8 s7 X( T. k
    Optimizer learning rate : 0.0010000
    , s+ h; w+ y' n' Y2 h+ @- v* p" Y$ @' [% U! k/ q
    Epoch 1/4: S* u4 }5 j# S" Z
    ----------  w, Z7 D8 v( _: I6 m2 e& }' K
    Time elapsed 60m 11s
    ; X8 c7 r4 e. c# b% m5 Z# ~train Loss: 2.3126 Acc: 0.7053
    : z9 T$ S1 A# j3 dTime elapsed 63m 16s
    " Z4 A+ @! Y- D3 m/ kvalid Loss: 3.2325 Acc: 0.66269 c, F5 v- F5 c  T) W
    Optimizer learning rate : 0.0100000+ B! w) H/ K/ s9 @. ~0 N

    : `: K4 V9 J5 m% }. o& F, b" {9 U) gEpoch 2/46 `9 e) t& R+ X2 J: G. Q3 |% G
    ----------
    , {6 c2 E( r9 ~. D8 h2 i) {Time elapsed 90m 58s
    + F, F3 C& ?+ z& A6 vtrain Loss: 9.9720 Acc: 0.47340 B9 D% X) I  ~) o8 n
    Time elapsed 94m 4s6 M6 Z* [) b) \2 O" [3 ^
    valid Loss: 14.0426 Acc: 0.44138 O) n; y" L; C1 \3 C# E
    Optimizer learning rate : 0.0001000* i2 p( A3 I( f0 S" T8 D

    9 A6 s3 d: F5 I' e( B$ e" L4 d  xEpoch 3/4
    $ e# r6 f( j8 g1 y+ K% e0 S2 Z----------
    / m' n2 H/ C7 _, z% H3 B/ ], iTime elapsed 132m 49s
    , N3 L7 U! V' U: e& ]1 ~6 k  ytrain Loss: 5.4290 Acc: 0.6548
    / Q) z+ ^) {+ u( [  Y  ]4 Y/ R! YTime elapsed 138m 49s2 U! K. \* @) B; s9 U! K) J; v9 O
    valid Loss: 6.4208 Acc: 0.6027$ P1 O$ i3 `; F; @! u- K
    Optimizer learning rate : 0.0100000
    ) x7 M' e) ?- D- p  r$ _0 Q+ }) S+ k# `) d1 c
    Epoch 4/4
    3 j& ?/ E( m. y/ a----------0 A+ P1 s7 a, |7 v
    Time elapsed 195m 56s
    7 G/ J9 H8 L  N3 ^( T$ m/ ~2 ^0 W! ]train Loss: 8.8911 Acc: 0.5519
    - X+ O: N2 d* Z- _Time elapsed 199m 16s
    6 q( e2 ~! |. f" N* p% A  zvalid Loss: 13.2221 Acc: 0.4914
    " J' {4 _' `3 L3 Q  [8 N. D) wOptimizer learning rate : 0.0010000
    7 W) @) q0 c- a0 Q' k" H
    ! }  Q  U0 \8 ?( ETraining complete in 199m 16s
    ' ]7 h( U3 ]" {5 Y9 ~! X* f4 }Best val Acc: 0.662592$ T! e! V/ i) L+ c4 ~) i% X
    + y! W- F) }( Q6 b9 h: M+ I3 b
    10 w$ U" j) h+ W
    2
    0 f. D. b4 G, M3 \0 e; W33 Y" d. E0 X; Y# s. M' H
    4
    ( u& p/ i, Q$ c1 H! U6 w, U5
    ! h7 @) O) b5 n6 t1 Q6$ O; ~) x8 N1 p' _1 U# T
    7
      q4 b, Z# V/ j0 B- R85 u" b# }* I0 h8 y% ?
    9# Y( @1 N# g: |6 w/ q3 r& a
    104 d! A; t0 I# a5 R1 W% J7 a
    11
    ; F. C: Y0 w2 e$ n12
    - `. P- i7 E- H$ m13- g9 `8 C& b0 N# P) Y
    14* u$ ?5 {: M; i3 n5 j  W
    15
    ! B8 s) `- i  s/ }2 }16: `9 s, \0 Q, _8 b% t
    17( O8 S- u( p# u. F9 `' s" m& j  K  a
    183 {/ t1 }' m+ r1 `
    190 l4 H. o/ a8 c. i
    20
    # ~- p# [5 D0 ?  d+ ?2 F3 ?+ R21/ M, u4 a/ P" |* r; i- _+ Z, `
    225 U6 U! X3 X' ?. E5 C) ?1 Y
    23
    ; J; V! n- f  Y9 Y247 {4 B+ g# z; {( F. q
    25
    / b, F! c1 b' S$ }  o1 |261 L% b5 n* k: ]
    27: z6 e" ]7 h) y  S' A
    289 W- o) S0 z1 Y5 ?1 E; T& i
    29- D1 y( L3 U& _8 r
    30& K8 a- n; _% u- e/ q9 ^) @
    31
    ( P- N% [8 M% m! A32
    1 ~+ l. Q. @# Y6 k2 k! |33
    . S# l) d$ Q' w. w* |' V7 E34
    ' _* b0 {& D4 w4 |; ~+ R( M35
    ( E' x: j7 o' W9 ~4 e( J  ^36# @3 X' A+ A- _5 F
    37' J# P: v8 P! T3 M2 N4 `& h3 p
    38
    3 I' Q. w  A$ c4 S1 `5 i393 m' F1 ?. C- c1 x2 J& i
    40
    1 n7 J7 q+ i7 A41/ N2 \, @# Q4 C5 I2 y
    42/ h7 w3 `' q" f' s2 b2 \+ K
    7.3 训练所有层. ^, G7 T$ T, F; ]( L- o- Y
    # 将全部网络解锁进行训练
    8 X) A& u) Q' Q; H9 H8 N' Pfor param in model_ft.parameters():1 U6 Y( S- v# R
        param.requires_grad = True
    1 Q( M7 b/ J4 `& _9 [* j. Q4 m* l& j) d  Q5 A: U
    # 再继续训练所有的参数,学习率调小一点\3 H4 K0 G6 C  H
    optimizer = optim.Adam(params_to_update, lr = 1e-4)
    6 l- _2 E$ j) T1 u+ Oscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)5 R# |! D9 A% W6 S* `: @! {
    + L3 x  R6 n9 G
    # 损失函数
      u- i: V  ?" r+ d( o4 o$ D: Acriterion = nn.NLLLoss()
    6 j* I) e' f& G( G% I  z1
    8 m6 @7 A+ f4 \+ X& g+ ]* T5 O2' |* N% H- S# n* ]0 K: W& j
    38 P8 e+ _; _' r/ \5 W$ r) k0 _2 K$ ^
    4
    6 n7 Z4 }0 ~$ F' w) ?3 ]52 |, V  k( b$ \/ V4 u7 u7 j
    63 v0 ?, u1 ?" _$ S* Z
    7
    ! i% @% `1 S5 F5 T/ c85 y: M  H0 O3 _  ~+ v6 o/ _
    9( v, a2 b: I- K- M2 f5 i+ P
    10
    4 d; R) t8 L* X8 l) {# 加载保存的参数' `* Z; {% t: ?( u
    # 并在原有的模型基础上继续训练; z* y' t4 K' Q! N% d( O
    # 下面保存的是刚刚训练效果较好的路径1 c& y+ `0 p+ G/ x5 ^) D
    checkpoint = torch.load(filename)  I, t9 b, l: T# `. }$ X7 K
    best_acc = checkpoint['best_acc']
    " ~* _3 `% v+ k6 Y3 Smodel_ft.load_state_dict(checkpoint['state_dict'])$ c, D+ \$ `6 u7 i" y- s. d7 v1 @! J
    optimizer.load_state_dict(checkpoint['optimizer'])
    + j& n. h! o% R- r  z1" B, D- N' i. v+ B$ P
    2( P) d9 J6 J  ]2 V: _( q
    3
    . @7 q6 U- ~% m- f4
    ' ^  `+ m6 e+ i) Q. F52 I, a; C7 C' f, F
    6
    9 L& P8 B+ {$ i* U7
    0 W1 ?& d- E2 E; ^8 u8 h% K开始训练
    & v+ t( y' r. c7 Q注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考  K5 a5 B3 e* s0 c' t% R4 w
    . P# P* f. e3 l3 L* Z9 ~( h
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    , r: Z5 b$ k# c4 S& r4 Y6 ~1
    ( j2 ~3 f/ h$ p2 tEpoch 0/1
    ) N# A" x- e8 _2 E----------
    $ [7 A0 ?9 G" {% fTime elapsed 35m 22s
    ( r9 I& T" s5 ^9 P0 `/ p, k# q* ctrain Loss: 1.7636 Acc: 0.7346
    * j  ]; R6 H, I# V" sTime elapsed 38m 42s
    ) ?6 k; D  k: v7 x; gvalid Loss: 3.6377 Acc: 0.6455
    ( q! L* G, J& J0 b  qOptimizer learning rate : 0.0010000
    * i7 a( ?, ~, b9 A
    6 \! w, e1 J# OEpoch 1/1
    ( e  M! k! Y7 y5 Y----------
    9 C# U( V- Q' d' x$ M3 G# JTime elapsed 82m 59s$ z4 B+ T( ?9 P. d
    train Loss: 1.7543 Acc: 0.7340
    9 ?% d- e/ E- v3 V1 l9 yTime elapsed 86m 11s; f% G4 t* G: D( N
    valid Loss: 3.8275 Acc: 0.6137
    3 p/ ]4 J& Y% f4 tOptimizer learning rate : 0.0010000
    2 s9 l* Y3 b: E
    2 m2 q- `) c. G* S# u$ iTraining complete in 86m 11s) p- B1 p# c+ L* q4 I9 \7 U
    Best val Acc: 0.645477& Z7 L  n$ Z3 x. l

    ; c* E- G/ M! V) Z1( V) C, O' G# W% T' b# v
    26 i5 U) J6 K9 u: N* C" n3 M3 ^
    35 ~' u$ _) O# Q5 B. D
    4
    4 X# b1 c- E+ u0 j' r1 m: U5
    , g! G! U/ D! R9 s6$ S: z, X+ M( V! H, [
    7% R0 A; j9 k+ |
    8
    * }9 ]6 ]/ p# m6 L( o4 t0 \8 n5 M9
    ) I( r% n  ?: W4 p10: ^  d6 e4 C' I9 U
    11( g7 ]/ J& e9 S
    12, O6 {6 ]0 g& a/ S( x4 I
    138 F5 ~+ r; x( \1 @4 [" G
    14
    $ r1 f0 d  [% @7 O( g. G2 `+ D+ g1 p6 o15
    ! `- F5 A( k0 t161 Q2 M: S( z; ]8 L, r
    17: w4 l. k- A# l& h8 k5 b. D
    18
    : P. e! Y* D& V/ k! |8. 加载已经训练的模型
    ' [6 V# J! `. |( e. L3 f+ S相当于做一次简单的前向传播(逻辑推理),不用更新参数
    ' p" _( o1 j" G3 Q% U% E$ a. H
    $ T: _" i0 v: z( }/ v9 e& R6 Wmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)3 c) E6 c4 f* N9 r, ]6 q) [3 {+ [

    ! K' z# K1 F8 W, h; j) i# GPU 模式: `' {& l/ l/ j1 W
    model_ft = model_ft.to(device) # 扔到GPU中3 Q8 u- U7 R; Z' L9 |8 S/ N  y

    + J" F' ~( K! w  Q* {# 保存文件的名字
    - K' u) A4 Z! v6 Q9 m, ?, xfilename='checkpoint.pth'
    ' w3 I# n0 [9 q; r
    ' K( H5 t. J  u3 r! J  e8 ?# 加载模型
    $ O' o7 s3 r9 U  y. N2 v1 [& Dcheckpoint = torch.load(filename)
    % E7 U2 s& @' t2 o' L; _( o* ?best_acc = checkpoint['best_acc']" H* Q& I1 u; M, B
    model_ft.load_state_dict(checkpoint['state_dict'])& U/ `% ?# c5 c: {7 \$ q$ I/ Z
    1% Y2 y1 U8 J& Y7 H# W" X
    2
    ; m6 U" w% U3 q4 j+ I3
    " m: z( a; r9 \, G  z7 ?4$ ?3 T' A* `, R* J" _) r4 u, H
    5# J4 P# W$ Y: f. |
    6  @3 S/ n; ^7 ^/ T" X* f( t
    7
    " s+ r# E. C# m  v  u: v8
    + I& P2 a% F0 v/ B* P# ^+ a8 |9  Y* j: c# T  k' ?" O
    10
    ! k( {5 c+ @6 S- y11
    % B/ x& O: b5 z5 X12
    1 Z* ?4 J' G4 t4 A7 y) V8 o4 B& w<All keys matched successfully>- d( O' F' D$ b; _$ n0 B
    1& Z2 O* G5 U6 Q
    def process_image(image_path):9 y4 i0 ?( M8 r) s1 I* b
        # 读取测试集数据3 U9 }% L5 F6 P9 m2 o: t
        img = Image.open(image_path)( ~1 N, _$ ~2 l! `
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断. D' y4 G/ k, |8 \
        # 与Resize不同
      ?# y7 \& L+ D* N/ ]    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小+ L( v1 J8 \& S, i( V+ h2 s* r+ }2 z
        # 而且对象调用方法会直接改变其大小,返回None- D3 L& X  {& j+ Z0 M- ~; G4 q0 z
        if img.size[0] > img.size[1]:
    $ U1 L- Y3 C# [! l2 b, S: R3 `        img.thumbnail((10000, 256))
    + ?2 M2 }5 D) o% b1 [& `+ _2 Y    else:5 W% E, s: K3 Z+ `7 N
            img.thumbnail((256, 10000))
    8 e7 s8 i# s" F- n, a! V$ x/ \) x& m3 [
        # crop操作, 将图像再次裁剪为 224 * 224
    5 d, y+ f2 a% n    left_margin = (img.width - 224) / 2 # 取中间的部分% b3 x: |& C4 }5 L" B1 i
        bottom_margin = (img.height - 224) / 2
    ; H2 ^8 ]3 r, i5 O# ]    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度# `8 |3 A  m4 z" `, h
        top_margin = bottom_margin + 224
    ' j1 ?0 w& l* e! `& m! b) w1 A7 Z+ L% v9 r
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))7 x' ]8 e2 ~/ D) k' ^1 Z6 T0 p
    ' x# M7 s- @+ @& D! M0 s( e* _: k" @
        # 相同预处理的方法
    + q) i7 k7 s. L0 G    # 归一化# i. M. u7 r3 O( M
        img = np.array(img) / 255
    ' R( X: \0 G( x' Q6 j: [8 k6 y    mean = np.array([0.485, 0.456, 0.406])
    % ^, W4 I3 C" s* k) C7 A- M    std = np.array([0.229, 0.224, 0.225])
    ) ^9 r4 I! u! ]  c% b' z* M    img = (img - mean) / std
    * o3 a. ^8 K. L  q, V8 \' }8 O4 m0 a
        # 注意颜色通道和位置% K( h8 A3 b4 z4 s- Y
        img = img.transpose((2, 0, 1))4 f: H0 ?" ^5 _- W0 i8 h
    8 c! i- C: ^+ x+ u5 }/ b# x
        return img0 m& Q  v% l2 i3 w0 c/ `* Z

    8 i5 g" k2 l9 X: k) Wdef imshow(image, ax = None, title = None):* P0 j, P: |8 J
        """展示数据"""( Y" |4 k3 g; h% q: z3 g
        if ax is None:/ I6 ^' y( C/ J" v9 Q
            fig, ax = plt.subplots()- d1 `5 Z7 p, W4 ?+ c& C
    * f  U3 Z" ]+ m0 `/ x. L
        # 颜色通道进行还原
    $ R* W  B' A4 r    image = np.array(image).transpose((1, 2, 0))* K6 l( x  D$ R' C3 k% c0 U* S5 @

      T; ?4 Z. A+ Y  Z. h& Y    # 预处理还原
    ; S9 Z5 A5 A' K3 g/ `! Y    mean = np.array([0.485, 0.456, 0.406]). c% ~9 M! n$ h4 g. f8 |( f1 I
        std = np.array([0.229, 0.224, 0.225])* Q5 c2 x& l2 v; l0 h& \! P1 A
        image = std * image + mean5 ^+ S4 |( T& V7 g" R5 B( J5 V) A
        image = np.clip(image, 0, 1)3 J9 d) F4 F. r6 U: U7 [0 ]) B7 n
    9 V5 b# x  ]% j" r( s/ d" ]
        ax.imshow(image)) x2 V4 b% M: R; ], G4 u
        ax.set_title(title)! J- ~2 x- Q5 A2 G- R
    2 L& k4 B0 U; `: c3 Q
        return ax  K1 m/ d) b# B$ @  [1 }! P2 R

    6 k, Q0 a3 A, w! }image_path = r'./flower_data/valid/3/image_06621.jpg'
    ; B6 [5 A  t, T' m4 v4 I" Rimg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    , {6 ?1 A, n! T% J7 a3 T. _imshow(img)
    8 X, S' E- \! y, }$ w1 Z8 q* b/ o2 c( q* L
    19 r( z0 V: h7 O% H, m9 Q3 T
    26 G" K/ P  W- \: Q, U
    3
    ) z5 ]: d- D5 |" V  `7 R4
    & b+ ^6 h0 p3 n3 V2 `5
    ' ]3 Y% r0 F: T6 L/ W& q" l* f6: i9 f& H4 g- M9 W' E8 V! `
    7
    8 W8 h4 ^4 [, W7 P7 q4 s$ P* V& m8
    % g  n2 m8 M2 Z. N. v( O93 `  |; l3 d5 ^" i' q+ q' _
    10
    ) S, |5 R3 W8 B, U* i118 p  ]$ f5 ^# [4 y) H
    12
    , a  {' u; o1 V# v) x13
    * j9 U7 c2 Q, T/ Q$ b148 h+ K/ O8 n; o, _7 R
    15% k/ x0 h8 p  J
    16; L2 s& }: _$ ~3 Z: y
    17+ m: k4 r; o7 g) ~: g' W
    184 y" |  J  u, M9 E) |! |) s9 p
    19
    6 N( ^; u! e3 H9 f2 q( i20
    " y' Q- {; x8 T+ j& C. W6 h21
    ) W% A# a: a- c- y0 B) P22
    & @7 Q/ ^" I' }: [231 O' i/ m: r5 k, S( @) {
    24
    # y; K; ?0 M( A  ]$ N% O25
    ; b' n3 t% }' F26
    4 i/ |9 G9 t9 N' B. M+ ~1 h. y% U27
    ; T8 K: Y! h. ~' P28" r& k. r+ E! y% z+ w/ w
    29
    7 _8 f* o+ ?2 Y8 G, P- ?30
    % i& [8 ]/ p( w5 }31
    $ Q) P' p6 X! t# C32! V6 n' [/ j# I' U: [3 r
    33
    ( L# ]% m$ w. ?- \1 {- ?34% J( o4 Z: e1 L6 V
    35/ t- u3 a. b: F  M
    36' k0 l; n( i& R4 d
    37( f. s' W7 I$ h4 Z- a
    380 p1 d* J* u5 Y. i
    39
    9 N5 T$ ?% |3 Z- S40
    9 X5 b- b8 i& V5 P* s, D: w41
    6 a# w( c- V/ H422 B. @  v% ?6 I  W2 \
    43
    1 i8 i% K1 R$ s5 `( U4 Q44: i7 E# U8 J/ c& B
    458 k5 e+ r2 W9 U9 Z8 H
    46
    ; \  a# `. N7 D, I% O8 ]6 {; b478 }' Y$ U. }# O, Z) G* j+ `$ @
    483 A9 P* w' P7 Q
    49
    # }8 M2 A# K2 ~, }8 A8 H50: h/ L7 b9 y7 c1 z3 O) y
    51
    3 `2 Y' U8 {! y  W52! s) n5 R3 U/ k
    53* N$ b- r* s2 t0 ~
    54
    " ?, H; ]5 E7 ]; c- A<AxesSubplot:>
    6 P6 y+ S% k* J6 ~, c: L$ l13 c9 z; z  ]" U) ^0 j; k; e
    7 i9 b6 g6 n1 V7 ]2 d" D. I
    上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    # w( g9 j  u' n- p  G5 X; _/ L
    $ F) {7 _  q8 i( Dimg.shape! u" M) f9 j3 U/ |
    1
    3 k  y1 ^# b( e  D$ }% p  A(3, 224, 224)8 y* ~; W2 h/ i: o
    1% y" b2 x  \! r, |
    证明了通道提前了,而且大小没改变
    / X+ q- E/ B; Q( y1 ^) }- s0 ?4 G5 Y# n7 K% Z: H
    9. 推理
    ; \+ v3 L6 J  ~% R$ Q; b& V* kimg.shape. F+ X& `! C7 s6 v' n1 L
    7 P' {" i, d5 p* l# Q$ v: L, Q
    # 得到一个batch的测试数据$ a( {$ v; K; A7 [+ I
    dataiter = iter(dataloaders['valid'])
    3 N! W5 W6 z! T% c& }/ J; Timages, labels = dataiter.next()
    4 M8 ^+ c# E, l% a; U' L$ u  D5 d! ?, ]8 Q% O; ^6 N
    model_ft.eval()
    6 u; T1 M" ?" F" B2 k8 f: }& C8 L: P, e" {* U& ~5 v
    if train_on_gpu:
    8 N( m; {4 M1 M4 w, J7 n    # 前向传播跑一次会得到output
    ' `! Y3 z9 U8 `    output = model_ft(images.cuda())
    - q( K5 R6 t" ~1 N) ?" Welse:
    9 F; s4 k2 T, a+ c) @8 c    output = model_ft(images)
    6 w0 ~( v; x1 }  a
    4 e. _7 z" t! C# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
    7 y; Q& A* ~7 ?) voutput.shape, _, l: f; {2 _$ ?$ ~% w9 n
    + w0 [4 n# h) n' t! i; p/ \6 f0 a5 q8 L
    1) k( c  Z- S" q' b% g
    2
    5 j- |+ S- T$ D$ _3. ]1 [& ~: e$ x% j9 v+ V/ S
    4+ m4 g" H, d1 Y( d
    5
    ; J* L( W+ Y$ f9 \. a6
    / ^: W# e; i; v, m# q: Q! `7& {9 s% ]' \  x1 A. H; @1 _0 l
    8
    3 h$ j- ^) b3 X) u* h98 v0 J0 y" O, A& d
    10! d4 p& k1 t3 {
    11: G' c. ^  U2 E& q4 y, t3 P  D6 q4 ~
    12- _0 ?8 L7 e, n4 w% D; q
    13
    ; {/ m! O  z5 f6 B8 ]14
    0 A$ N, G# E) o9 F8 A" L1 e6 v15
    9 e) x) v! `0 e- g9 y+ X% e167 v$ Y& [* z# U; @6 i" q2 `
    torch.Size([8, 102])
    : O( J, o. }) M. ^" ~. \! o1- P% n  l+ o$ u1 }0 k( a7 d
    9.1 计算得到最大概率
      C3 ^: r( O* i5 v_, preds_tensor = torch.max(output, 1)
    ) F6 b  u% x3 X6 v1 o/ m5 _) {
    + V0 I7 B! q% J( j% Xpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    & G' q7 t( y0 _3 _1 O1, Z# g6 ?. ]2 m7 }6 M- y
    2- m4 p5 J. t& c9 }: O8 N: h
    39 J$ O* x4 [  l% U, |5 \  C; I
    9.2 展示预测结果# b4 H" h% i0 W7 `+ D: N5 F
    fig = plt.figure(figsize = (20, 20))
    $ [' a# w: Z+ k4 F8 bcolumns = 4' R: q7 |: Z0 |& v' t0 S
    rows = 2
    . \* a" [$ z% r, }4 l5 ]+ W3 u: Z
    for idx in range(columns * rows):( T3 \4 c: Q+ d/ j* o, r
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    * M0 M$ a) f/ R5 V& z; M$ U- p    plt.imshow(im_convert(images[idx]))3 t5 j) ^) B# c4 B) [9 U3 y9 q
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), / C  _3 t3 |1 c/ @5 `% u2 L/ V, ]
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))" f9 p- H% o5 x, ~
    plt.show()9 s4 X6 o; B- ^
    # 绿色的表示预测是对的,红色表示预测错了
    , M+ @9 v! u7 d8 L18 q9 U1 N! u. @4 s
    2) I( i6 t. {0 F% f+ D+ h
    3
    ) z) o7 z, i( K1 r# q/ P4
    % `  s0 [& R0 W* C5+ {6 |( E7 P- W6 @" o
    6$ E, v* o9 L7 q  c
    75 ?+ x6 |# Z, b
    8
    0 C1 v8 p7 C! ^6 }$ p96 }* O& n" ?/ w' z/ ^/ V
    10
    0 g4 y& e# c+ O( G& p! a11
    6 m: S* a( F  a% h5 W6 d, Y6 t, ~2 X
    ' B$ _& X+ s$ X; O% f
    ; y7 u# j$ d$ U  z3 H4 p" H! @
    ————————————————
    3 r( E. i4 ~; A" `- C版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    2 ]0 j7 z: d. E原文链接:https://blog.csdn.net/LeungSr/article/details/1267479409 @6 L6 j* u$ E9 X5 ^

    + e% i8 _+ ?: d+ E8 E1 e
    4 h. f9 W% ^6 M: \* W1 g
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-5-26 05:07 , Processed in 0.537627 second(s), 51 queries .

    回顶部