QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2715|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    + e- q/ `2 ]$ `6 m
    6 D" [* [9 J: i文章目录  d/ b# {1 W5 v# w& I
    卷积网络实战 对花进行分类/ J% s+ N$ E2 y5 F/ G
    数据预处理部分9 w2 D- V+ E; _$ t6 ?+ C& @  J
    网络模块设置' ]9 z3 d1 w4 V0 o5 d
    网络模型的保存与测试# W* P3 @+ r1 ]: N8 e9 S
    数据下载:7 e9 o/ s3 S2 R! k
    1. 导入工具包
    # N& X  b  c& [7 h2. 数据预处理与操作
    & }: T/ [* D, [6 |4 r' j% c8 {: u9 q3. 制作好数据源
    4 d  p) _4 s6 b; I5 P6 U1 ?/ C$ z读取标签对应的实际名字+ T1 H. ^, t/ R; `6 i0 v
    4.展示一下数据
    $ M6 y! R/ b0 g# Y2 b2 h5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    ' x. }- A3 x+ M( a6.初始化模型架构+ k5 l! }  R/ ^) _. [; n
    7. 设置需要训练的参数
    4 o$ t4 `8 p: o) s2 T, W0 _7. 训练与预测  n4 s$ n* ]. z# Q
    7.1 优化器设置
    ' N2 p4 }+ h( R0 N7.2 开始训练模型
    1 k" }5 V' P0 x+ X1 }2 e7.3 训练所有层
    , D/ r2 y$ R+ U开始训练1 V8 k2 g% k7 {
    8. 加载已经训练的模型5 z* Q% E9 L9 g) N. K2 M
    9. 推理
    $ }5 j% T) Y! ~1 l  W9.1 计算得到最大概率
    ; u$ J! B+ S, a9 i! Z" F$ H- \2 v9.2 展示预测结果
    9 y# z/ Y% P, r( i, |/ U/ ^写在最后3 n+ b- `) V! ^& |5 R# U, e
    卷积网络实战 对花进行分类
    8 b. n3 X* h; Z# ~" G2 D本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能( ]2 X& M& E  E* Q+ n) L" b
    ; x# [6 |& K) x! g  {  o
    在文件夹中有102种花,我们主要要对这些花进行分类任务- t8 g( R# L* S
    文件夹结构
    - F# Z  P" O6 i* D( j. R2 Q
    ' Q) u$ @+ U9 wflower_data
    3 W" M; D/ z! D/ C8 k  e  I
    2 h$ M( Q9 I  ktrain9 j6 q2 ?) V* s# e& ^1 ^1 b

    0 N7 T5 @. P+ `( l, y1(类别)& K" L& ?7 b  J2 N% a! X: \. u
    2
    ; ~" g& W; \! c  ^$ W. uxxx.png / xxx.jpg
    - a) P  y6 R" j4 O3 Ivalid
    0 V" ^+ ~* G  e  f; Z6 E  O, O% j- [" ^0 ~
    主要分为以下几个大模块8 u5 R8 L" m# {: r( u0 Y
    ) t$ z/ G' ?; s/ P0 m
    数据预处理部分0 [) W  b% p+ D- r+ F. }0 p
    数据增强- a( F  @6 P' o# O- }
    数据预处理4 F3 i" U9 j1 u) K! L7 Y5 l
    网络模块设置
    3 t5 A: u& q4 X7 f! V# |& x加载预训练模型,直接调用torchVision的经典网络架构# K' \1 m( M( G- J) h% E8 M
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务0 |) D) g! t4 C: Y# x
    网络模型的保存与测试
    + a  S# \, R6 H- ?$ o模型保存可以带有选择性
    % J" q& G% `1 J+ B: p数据下载:
    : L: e, K6 ^0 {' m3 ehttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    2 E! V* j- g+ V' i8 W8 T' P- H% m% m$ D! ^- P
    改一下文件名,然后将它放到同一根目录就可以了
    $ ?+ E$ n7 T, O! S( v; e* H/ h: M5 V9 T5 d3 d4 Q
    下面是我的数据根目录
    1 p. B+ h2 J: }" n# g. \1 b7 E: A5 _6 j: x$ @
    - G) ~" `$ }! q* v
    1. 导入工具包" E' |, o3 C" r6 y
    import os0 ^* Z# v0 A6 Y2 w, F' G
    import matplotlib.pyplot as plt9 [" p8 \' ~. }0 r+ m
    # 内嵌入绘图简去show的句柄! l. O0 J4 m5 K0 H0 S/ u& v5 S
    %matplotlib inline
    ' G" Q. P1 `4 ~import numpy as np1 x( y+ d4 Q9 M3 U7 [% p; {
    import torch! U# t2 f3 x, b# a0 I+ ~
    from torch import nn3 z; D! r. D3 @9 O& ~  @

    5 q6 x6 e+ v% [  w$ Iimport torch.optim as optim2 o/ u* e: m2 k4 \' H+ X3 j
    import torchvision: C0 r: \: H" y
    from torchvision import transforms, models, datasets
    " U" B# w' U1 H$ |6 k( j; v$ W( f/ B* P3 B0 I
    import imageio5 z6 K1 |" y+ P3 `/ \8 t) y
    import time
    ! C7 |' I. p5 ^5 u9 g7 G& pimport warnings/ h# y2 p# @2 Y! \, r$ U- H- M4 U. J
    import random8 y# \0 e7 R0 g' Z
    import sys
    0 B& ]1 c5 c% a. A/ \8 V" rimport copy
    & w1 d  R9 i, O1 t* R3 G0 |import json
    4 G9 x+ \# f* h# c1 i: I% T& pfrom PIL import Image- m0 ]1 `1 B, H8 ]5 u6 m
    3 Q6 m, f* y# \1 `
    , h3 [, ]6 k& V6 }8 z! z
    1
    % [) y& {: E- [0 s$ W: {0 C2! w. l  W! i: V, ~
    3
    : X3 z" I- a; P8 N+ Q4
    5 k! {/ U7 z- k+ n4 T. }2 m, J! `5
    : K! A: z: C8 ^4 b& u: k. X/ N, ?6
    / z# c) R% J2 j5 R7& a4 ~1 N6 |" C! l/ y
    8
    3 A: Z  Y% ^0 S' F% s# d9 D9" _4 `; T2 u+ C% ]
    10
    " j4 v. I' o/ S8 K11$ y) y% U' j% w0 Y4 H9 w
    12
    / Q7 X) }- Z7 Y, w130 n9 P: r, X+ z" M% t5 r
    14
    + K2 D; H/ I- d. C15% |0 M* @; `$ Y, ]3 o5 }
    16
    $ S" {! U! d4 x, Y- c: l# L7 j17
    7 V, x" v6 k6 |, A5 b3 G18
    1 U% J/ m5 B$ k19
    * r$ U1 `% n% s7 e20
    7 {4 W1 O+ Z+ h: y: }, e# |21
    5 ^) J1 f' [+ d/ Y% b2 x2. 数据预处理与操作8 G/ }: X6 T: n/ G
    #路径设置
    ; [3 \: ~& B/ Z8 R5 O2 y! Z1 ddata_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    $ M5 b6 i( @! v6 f  \train_dir = data_dir + '/train'0 j/ @% ?* M6 Z+ y( c) D, X
    valid_dir = data_dir + '/valid'/ G7 G9 |9 q+ Y3 U' t
    1. s# s. A6 |* W8 }; P* S
    2& N* H/ l' C- E; q7 p" a
    3
    + Q- h5 M' S& v41 b3 O% b9 ?, B/ W  d: D
    python目录点杠的组合与区别
    # @3 n6 }% |0 n6 e注: 里面注明了点杠和斜杠的操作
    ) Z" N% p( M) P5 B4 a, k0 x. y* [' A! O
    3. 制作好数据源* m3 K& t8 h+ `
    data_transforms中制定了所有图像预处理的操作
    , A# |  Q: G& S1 u$ oImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    ' p% d$ Z5 V8 m/ K1 F0 ddata_transforms = {
    ' C! X; q' m& c' d9 s2 \# ]    # 分成两部分,一部分是训练
    : W, q0 A# K- B, w& ?    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    9 r; P& G8 r/ W+ Z; @8 a% T                                 transforms.CenterCrop(224), # 从中心处开始裁剪
    * c% T, Q$ R- D/ I( A& v4 I; q1 ~7 k3 {2 f                                 # 以某个随机的概率决定是否翻转 55开4 x' C9 r; V9 ]# x% ^  t( u5 i0 w
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    4 o! E' `% |! m& e                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    ! z4 x/ D8 G9 h* f, }                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    ) r& q2 T& m8 p. _7 O, {                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),, F$ f9 n: x8 I8 R9 r  a
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    8 _& c# h1 }3 D9 {                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的" P" P) C+ U8 W5 C) D1 x, u5 _
                                     transforms.ToTensor(),
    7 a: M% G& X6 {                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差: }9 a: q! z/ E( W7 E" Y. g* Z" A
                                    ]),: r  E0 x6 h; Y
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化: X! p3 Q( {, s2 m+ f  b& [* X& |
        'valid': transforms.Compose([transforms.Resize(256),
    ( n" d" w* }" x- d3 i2 {2 r                                 transforms.CenterCrop(224),
    - }8 M( g( A0 N) r0 z4 J4 c3 I                                 transforms.ToTensor(),
    ; K5 G$ k, q6 g- V4 ^4 G  ~5 W# B                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同' U3 w% o! B* e
                                    ]),
    5 E1 T6 X( y6 T6 `: y}
    : n$ `; M0 o% t
    " k/ q; {- g& {+ A/ O. z( x7 r1! i! c7 @' X- W, |* _
    2
    1 q2 k: j% X; p, O39 ?& e  ^" V# Z& k
    4
    1 y0 Z8 R1 H% o. ^; w- E5
    ) T" v8 T+ O0 k! h9 `# F$ P5 O6
    ) Y( u2 o) y) ^3 z7
    # m) l" }& q" k. h+ Q2 ]8" X( ?, A- {3 p2 Y; t% I
    93 p+ y. _; H0 v2 H/ I, f1 p. ?& i
    10
    ) T+ G8 a5 g( R% k0 Y11
    & N; h% x1 n1 J+ S12
    ! T9 o8 e) p- R% k3 q9 r13
    . l$ y) R+ d% h3 O( D% w14
    1 s7 g$ \' A. Z3 c0 ^8 U& r7 X" J15" u1 W1 B7 C0 s6 W7 P) f
    16% c) x3 T1 J! d- k. Z% l
    17! }( x+ y1 P6 i! }6 H# v1 E% w
    18
    % u9 S2 V( @: [1 `3 n* N! f19
    % f% i/ ^% J/ I  y( x/ x9 ]20
    ' C& q* c5 N: B% K& n" u% k" l21
    9 V( r! J0 ^2 k5 k5 \batch_size = 8
    1 i7 Q6 p$ c* u' C7 Limage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    - H5 s% d( s) y& h+ wdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    4 j4 t* W- B+ O7 D3 M: X7 Qdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} # o% k  V( Q: A' ]+ X4 S
    class_names = image_datasets['train'].classes$ O, o2 J4 i! d4 J
    : w3 `- a) n/ m# k- ^
    #查看数据集合
    . r! p' f5 e8 m8 ~" D8 K% qimage_datasets" q- n5 h1 ~5 I5 x

    4 t& r9 G6 a( B$ A8 |/ F1# s( `5 B2 X8 Q" T5 C
    2
    4 W( m, s4 i, H9 G; j4 d' H7 D3
    $ B; c5 Q) C5 W8 h* L1 f) u% [2 s4
    * e8 F8 L7 [( `4 Y5
    ! A6 t, E; L; O7 a4 R7 |  j6
    : q3 P* n8 C, o) o0 d( d7
    + A' u$ d  N! \$ z, A- _$ N2 j) G  M8  z, ~/ o; I) Y
    9
    ( \5 h9 T" |% I& M- M! J; i0 P, ?{'train': Dataset ImageFolder
    " x8 p( K' I0 w     Number of datapoints: 6552. g9 h2 f$ U* c5 ~' k3 g( w, A
         Root location: ./flower_data/train
    " w5 H$ N" T1 R8 T! c     StandardTransform
    ' t9 |+ }1 ^0 {2 d Transform: Compose(+ ]% p2 H7 o) H4 V$ c* h3 d
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)0 D" P4 t7 D7 `" y, m
                    CenterCrop(size=(224, 224))
    6 d% X. }* |. ?                RandomHorizontalFlip(p=0.5)
    9 ~2 _! G( l. f, x- M6 D4 H- W                RandomVerticalFlip(p=0.5)0 t1 n4 G# [! L0 P6 E
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    5 X% z  t. {5 p! y5 I* t                RandomGrayscale(p=0.025)  k. S5 P5 N: B$ m: w5 H  d, t% u
                    ToTensor()- L! ^7 w0 g/ l7 s, a7 C  U
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    - U3 ~9 J# w. r( }8 b- n# G6 _            ),
    0 y: \" _& {7 ^1 E* \' U- _  S& Z! m% l 'valid': Dataset ImageFolder
    . r. y- p/ A6 i& Z5 r     Number of datapoints: 818
    : c; g! y( V# N     Root location: ./flower_data/valid( J% ~( t; L+ g3 Z# d  N: V- w
         StandardTransform, N* Z  ^. S5 E- v, r
    Transform: Compose(
    3 D) i" i/ O! X6 E- M+ s" v" L                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    * ^/ |4 ~  A( s2 I* [: V0 d3 f: U                CenterCrop(size=(224, 224))" T+ b7 ?% @% x8 p4 j' Q3 Z1 f
                    ToTensor()" \7 h) u! F! k/ |, y' A$ T
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ; I, k$ b) {& C* S5 p! q. u' Z            )}
    - ~# ?6 N! R! a% K0 T2 Z) r& }2 {: ^7 M# Y% J) ^6 f7 g
    19 O6 D/ o6 @2 E: Z$ w
    2
    2 u( \3 r" f! ]: r) w7 J. F0 v3( D2 v8 I6 m7 ^$ @, d# }
    4
    - I  h4 i) W0 B+ H5) P! `  }. Z& r
    6
    . J& ?( q. m* W, p2 ^7
    % r- T& \9 ~; o5 @% q8
    1 s- W; i7 r+ ?1 s. B, M7 R: N9
    1 T8 B  e) P5 f' h+ S6 z& E10
    ! X& R2 g% v1 A3 E+ m7 g; u11
    # b2 `# _7 W; c120 L2 c! C6 D/ D% S% O
    138 }: I' Z) [& e: c+ J6 X" I
    14) B3 k1 ?" V& n; w7 A2 v0 ~: l
    15
      }# S# j5 n% i1 S  A* b16: G4 L. X7 A' z3 a: L" e. M
    17! F- g# R1 B/ p" R+ T- y: T
    18
    1 ?6 U2 F0 }  u& |' _0 Z19
    ; Q  ^& r) i  D9 a. C200 x" g9 I" s! S1 o9 q1 U/ E% K
    21" X' ?! r8 x$ |
    22- ^4 u$ B! K7 o- s% V! \$ B
    23# c: A( e$ g; Z& R3 g# k
    24
    0 q$ }8 N8 V4 O- w- b! U6 `# 验证一下数据是否已经被处理完毕
    - A2 \% D4 [. n) s( `dataloaders
    0 R$ R  N# @. Y: P) F5 n' `" R& D1& U' b; q  J0 ^* a- v
    2
    : s2 T$ I) D: `% q; F* L* M- i{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,* K7 h) ?$ n. p( }, d: t
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}+ |8 R7 P$ n9 z0 i' m% l
    1- }8 d1 l- H9 X3 N( M
    2
    3 [4 N: c$ m; Qdataset_sizes
    / k, B0 a( A1 |$ s2 I1( Q2 h# F3 [8 Y1 R4 V
    {'train': 6552, 'valid': 818}/ o- x8 t! s, R  H6 d
    1) k4 _1 p7 k8 ]+ P
    读取标签对应的实际名字: \: r7 j0 g& N7 Z" ?
    使用同一目录下的json文件,反向映射出花对应的名字0 b0 V8 W: i  y) k/ i8 |& `& ~9 R  j
    # c9 ?4 g( E" o- Y- B% P' S$ h
    with open('./flower_data/cat_to_name.json', 'r') as f:
    3 L/ H# Q9 Z+ ?; W7 t/ G4 E    cat_to_name = json.load(f)
    ( H/ Q" W/ J, B' k/ g2 f1; Y9 U2 A$ p% p% X3 K+ T3 C
    2
    $ I8 a: n" }3 icat_to_name9 k2 ?8 v8 ]7 \1 N
    15 J/ \! q! b9 \. D
    {'21': 'fire lily',
    ; E9 K# ^. ^# z3 e '3': 'canterbury bells',
    0 t7 Y4 v" Z- W8 S5 c '45': 'bolero deep blue',$ }" y, Z+ ~; ?
    '1': 'pink primrose',
    ! K/ c' k* F- W+ a4 x '34': 'mexican aster',5 y% m1 ?/ E; D1 V  H! j
    '27': 'prince of wales feathers',5 b9 A0 l' S$ M
    '7': 'moon orchid',' Y3 f+ l0 g3 y. X# Q
    '16': 'globe-flower',! n# I; v0 d& U2 S
    '25': 'grape hyacinth',) X* _+ r  ^# d* h2 G
    '26': 'corn poppy',3 t1 p) t9 W( |1 K! o
    '79': 'toad lily',+ i5 O8 W& I4 M/ c! I# g
    '39': 'siam tulip',
    . s" S4 f- J* O# b0 v4 |4 J '24': 'red ginger',) v; q/ G! i& V8 a3 o& Y& ^& w
    '67': 'spring crocus',( \# j. R7 V8 O8 ?! }( \! H3 B) }: M
    '35': 'alpine sea holly',! I, y9 X2 ?7 P- Y
    '32': 'garden phlox',
    8 @  N  ^: d- }; ?  b) u! r '10': 'globe thistle',: b2 W1 c1 w4 t; }1 t
    '6': 'tiger lily',9 x8 P8 y' h8 J/ G) l; K1 B
    '93': 'ball moss',
    & N6 k2 d7 R, ^/ Y, b0 u4 g% v '33': 'love in the mist',! b; \  r, e4 h# }# G; B
    '9': 'monkshood',
    9 P1 H$ P) \" {* y3 e '102': 'blackberry lily',6 a# [5 F& s6 t6 s
    '14': 'spear thistle',4 }6 O+ f1 A. C* R2 L5 l: |) Y
    '19': 'balloon flower',) j  P4 t3 l& m) L" i& u$ A$ K
    '100': 'blanket flower',
    ) l3 V- S7 d  S3 C '13': 'king protea',
    ; ?8 v+ A, Y  X '49': 'oxeye daisy',: K' r- X9 I" }" R  T% R
    '15': 'yellow iris',$ F8 D. O+ W1 s$ |
    '61': 'cautleya spicata',4 Y  d4 x6 ], k, l% Z: |* n1 ^# T1 @
    '31': 'carnation',; A  f8 l% i+ R& l5 c, W: |$ N1 M( i
    '64': 'silverbush',- y7 v; \" ~) O( D
    '68': 'bearded iris',# ~" b& o  {1 Z0 ~
    '63': 'black-eyed susan',
    8 ]- V$ a7 _8 }- M '69': 'windflower'," E' @- u8 R7 {) S$ m
    '62': 'japanese anemone',
    ' ~% {; v, l- k, K% M. k '20': 'giant white arum lily',
    , r! k* a% E% N! _  x. u '38': 'great masterwort',
    - g: h# }# `$ o, [; _* u; u8 ] '4': 'sweet pea',
    $ c: ]# s7 D8 b5 r '86': 'tree mallow',. Q3 U7 @( J  i" D- s$ ]
    '101': 'trumpet creeper',
    5 d' R8 h8 C$ o! \' \) T$ ^. ^) X' h '42': 'daffodil',; e1 |* F4 Q5 }+ a0 ^2 \1 A
    '22': 'pincushion flower',
    5 q$ P1 I1 S$ P: m9 B2 v '2': 'hard-leaved pocket orchid',
    3 t5 k0 L7 b/ f '54': 'sunflower',
    % A& m8 }/ _/ B5 w' z4 _ '66': 'osteospermum',
    . [7 ^; S# G! p. t+ H- R '70': 'tree poppy',
    6 F$ }) e1 A$ t2 g, a; J, z '85': 'desert-rose',
    2 z# ~0 C: ?) O( N' G '99': 'bromelia',6 e' }3 r9 Z0 W( N9 _& C* o
    '87': 'magnolia',
    7 k7 ~) u, U5 ]7 R) r( U '5': 'english marigold',
      l  ]/ J& }/ [8 p) q. ?/ l '92': 'bee balm',. m, m" d3 C+ S' j8 ^/ }# g4 k
    '28': 'stemless gentian',* H# C% Z3 Z% I. E, }
    '97': 'mallow'," |) Z$ E* q' m
    '57': 'gaura',9 [8 E" ?# s6 N1 _0 Y  N9 N7 \
    '40': 'lenten rose',5 L/ H% d/ X: [( R5 K( @, \* s8 d8 Y
    '47': 'marigold',* Y* h/ l! F/ n5 m# H3 }
    '59': 'orange dahlia',
    * h- s2 T5 }5 r '48': 'buttercup',# L$ Q6 \, P! u& ?! w* H, n1 @' U
    '55': 'pelargonium',$ a; p6 ?- V8 m& [4 C
    '36': 'ruby-lipped cattleya',
    4 B3 m9 K6 X( O$ Z '91': 'hippeastrum',  h6 s! u3 {* X6 N+ u
    '29': 'artichoke',
    & _4 w/ l# G, q( [5 r) r '71': 'gazania',, w) h# o+ _6 Q% l* ~, a
    '90': 'canna lily',
    0 W1 w, e+ r- J1 c '18': 'peruvian lily',
    # z  n- J) \  A2 w( {( o '98': 'mexican petunia'," p1 \/ F  s, r/ M8 y% b
    '8': 'bird of paradise',( a! S+ P1 [. T: X5 ?9 A
    '30': 'sweet william',9 K7 r) @# e6 g+ V, T7 ?
    '17': 'purple coneflower',* ^2 W* D9 M, p; k( Y$ W
    '52': 'wild pansy',
    / E) T! P6 d5 F. h& A! r '84': 'columbine',
    ) o& Z9 ?6 \- p5 R( C! B- [ '12': "colt's foot",
    0 S' P* U5 r  p '11': 'snapdragon',4 d. Y0 Y5 P( b: W
    '96': 'camellia',( F: Q) K4 b% ]2 \
    '23': 'fritillary',7 S! r: l! @' @
    '50': 'common dandelion',
    ( t2 U: |: r8 v0 _ '44': 'poinsettia',& o  H6 q* {$ s& ^! A
    '53': 'primula',' D! f7 D3 ^1 [$ g) n
    '72': 'azalea',1 N3 c0 F# V! t" \" k: Z. G$ ]8 ~
    '65': 'californian poppy',
    7 x1 R3 Y2 f# ~ '80': 'anthurium',: d5 B1 H( l5 ?8 c: c) M; P
    '76': 'morning glory',, M5 V6 U4 j" |% {& }2 V2 Y
    '37': 'cape flower',  r: X6 }+ q2 r% C! t9 E6 Y% ^
    '56': 'bishop of llandaff',
    $ Y$ H, H6 m8 L$ ~  a0 o '60': 'pink-yellow dahlia',
    9 [3 y. u/ Q9 C '82': 'clematis',
    ' W5 a% Z) e9 H; M4 h '58': 'geranium',
    7 y3 d: x: @. A, C% B, V '75': 'thorn apple',' \; \3 x* L* M2 l9 l; B
    '41': 'barbeton daisy',6 U$ V" _. k- m) e2 o! a
    '95': 'bougainvillea',# C. k0 a0 [- a5 N% l
    '43': 'sword lily',% V$ K2 ]7 z6 x+ q
    '83': 'hibiscus',. e6 s+ |7 r0 z2 k0 f( ?0 z
    '78': 'lotus lotus',
    6 N/ g2 C- `" ]! C8 d4 o& d9 E '88': 'cyclamen',; S3 ?# ?% h" P3 r: B$ |; U
    '94': 'foxglove',
    3 D9 Q" ~' j6 m8 q) b  D '81': 'frangipani',- J9 e% w0 S! t8 x% I. q
    '74': 'rose',6 ~  [0 M1 O. \+ @+ J- G  i- r
    '89': 'watercress',
    / e" |; y$ H+ G2 n1 ?' P '73': 'water lily',
    ; w0 z# @3 B$ m8 k$ e* y '46': 'wallflower',
    , V! n  C6 ]  \4 f0 Y9 k& b6 v '77': 'passion flower',9 y; j& `: e; B; r( T8 F) x
    '51': 'petunia'}1 X  }2 X3 _2 ]% ]7 X5 J7 l
    ! M. I; o; Z) D; G0 f  ~
    13 x- P7 l* g2 Z+ h( x. M  e" B2 p
    2! [  G8 K- W4 Y' d# @# X/ U
    3) o# f$ f5 N5 X1 w# K1 Y
    4
    " q4 g6 d0 C$ Y( m) v5
    7 H4 R4 t/ ], t: k( o! M2 N( Q& u6& @  Y$ h6 |: G# d7 I& U4 {
    7  F! y+ |& q9 J. r" a! G
    86 O% g2 d# g( G+ C& |
    9  {, r. C; @* c: u1 j' R0 [
    10  u, ~3 c' k( h2 o0 M1 F/ d+ ]
    11
    % x1 |2 Q+ L3 d9 \, A- u4 S2 ~/ e. h4 `12
    8 B+ ~$ L+ s4 e9 |" A0 g4 e13" r, P' e+ w) t2 N  H* t9 i
    14
      z/ A2 e( S8 p$ f9 E- D4 O15
    " x! t' Y# t% b" t: ~16% Z  E: l/ F9 T) M) ~' B
    17
    5 V4 y/ ]1 ]& Q+ e) Y) s) Y8 G& B18  h; G$ E3 m7 X* v& Y
    19+ W; |' D' K# B) O6 m2 ^
    20! ^5 H9 t9 a* w% M& _, Q
    21
    * r* Q, y3 S- J7 `4 i& }22& _) D1 w) d, a6 g4 O6 F: n# g2 c
    23
    # W" Z0 @9 f: z0 a* q# S24
    2 ?. P# F+ r3 |25
    9 s% p" ?" _, f# [  o; p) w262 Q; V) d7 D' ], I$ u+ \
    27
    ) Q9 S1 x2 H$ k( T' O$ F7 A28
    % K7 d9 U* j4 h/ H/ ~9 W* C29
    6 M+ S7 M8 N' Q+ |30& u$ L; ?  ?& `- _; e9 j2 h
    31
    5 O6 a, g, S5 y7 X3 N5 A32% E( J( ~3 m/ E& [3 n0 T$ Z
    33
    0 f$ Q) S5 s$ Y. B! J: K* D, b3 Q- x34
    7 N8 Y8 R& m* a. U' Q0 S0 w- [35
    & {% j* \3 s" x2 Q/ \# m3 l36! ]' L8 k$ c! u8 H1 v% [! q
    37
    + g' A+ X1 N4 T+ x5 K+ q382 ]/ g+ l" z5 h
    39
    + T8 B7 D# Z3 U" T/ n( n! s8 D40+ s& s# X& I( H! r6 Q6 w
    41
    * Z3 T5 j" L* y( Q. ]0 N8 H42
    / U7 v4 S3 `7 f43# t+ J, f' V5 K1 V9 e3 b" C3 f
    44
    ) `4 w7 M1 R5 v1 p' J0 W45
    6 v; g" m2 p# l9 m  P  l: l46
    ( r! c7 l1 w( Q" p9 B2 R& u475 a7 v$ ]) t! w# Y7 {: H' q
    48
    3 ?* Y" [9 |4 Y8 \3 D; b9 W49" W5 q! r* S  j' r, g
    508 t  ^) `% {/ c. V4 h/ b5 o$ L
    51
    9 K! s$ W4 f$ c; q7 k+ A- K- o: i2 ~52; Z! [" E3 h4 L/ B* c( }1 o* ?3 Q
    53
    ( u, X2 B9 j, Z5 s544 t/ O4 E6 y/ u$ A6 u
    550 ^1 K  \0 H) K, C6 `/ m
    56
    / ?3 p+ G. T6 q+ Y4 A3 v57
    ) K, S7 m! U% |' X4 [/ |& ~58
    % _) @: _; q* b# {4 ~* p597 y; F) z1 ~" }  U  }# a6 C
    60
    6 ^1 s) Z) ^0 ^4 H$ j61
    , v5 k: W5 h  V/ q62
    , z1 s$ H! {. h. ?" v0 ^* F) k63
    5 F, }" z% W$ ^6 u$ q* A2 f# `645 y& ]4 ^( \3 A3 e1 k2 ^/ b
    65
    % O$ M; t! Z+ o" Y( c, J9 E66
    % I. Z3 `7 Z% a4 m67: w- K' x7 \, y6 K" A( l" l
    68
    0 m& Z: y4 D' ^7 _0 `' G5 @69, T+ X0 o3 S3 w; ]5 l- `
    70
      `6 Q' q+ C/ f1 P, B7 Z$ S" r71
    ) \0 b1 o# x7 f! r5 o72
    1 A3 Q6 F# A" v: k% z- ~73! t' `/ I+ ^6 l' H0 G
    74
    7 o) l4 }( ?4 b( T" `& q75( p& {7 P, \( x; z: B( `
    76
    % L! d6 c" A! V3 L: h; V0 \3 i77
    ! a) b; k! n5 K" R/ k78
    8 x2 {' J" o" D  [  x79
    % c4 n! d# V5 p( y80
    5 m' R/ X* Q' M* X* p81, M' C/ w7 c% g7 y5 a; }
    82( R3 _! k% k* ?& R5 I: D+ t. J
    83# F; y& H' t8 z& }* V" s# A
    849 I# L0 z8 G# l, }3 `. }7 j; h( ]8 w
    85
    $ _" p7 f9 j# C" j! l6 b# F* o$ ?& l86+ }" d* H! v* O- N/ I# x* }: O
    87* J% z) f  [/ N+ I5 M" e
    88* \' C% W! L* i7 H6 T2 l
    892 s. m8 W+ @+ l* o# p% ~1 f
    90% w: \. {  U' _5 B/ [8 O
    91* U" ?3 u0 K0 T2 |' I
    92$ C: E1 D2 k  @  {0 x+ g  U% B) e$ A
    93
    & X* j" _! n/ A7 U: R# M94. U+ @, w$ C1 {( J1 ~1 S
    95" ^) ^" i. A2 X/ y7 s1 C- j
    962 }, O& k9 H5 v2 s
    973 F* C  e0 m( v- |
    98# c9 `) ^1 f4 m
    99
    & ^( a$ P9 y& h/ R5 _100
    # m; |* m& F+ |2 y" q  _1 E101
    5 K% ]$ V# H) n1 h4 Z6 i102
    % @9 Q- T% \5 |, s/ d% _& X! Y4.展示一下数据
    5 d& r, ]1 P6 C1 E" T5 `, sdef im_convert(tensor):
    1 |! @$ W5 Y+ u2 U, X    """数据展示"""( C1 w6 p5 _, ^9 r( `# x7 k! [
        image = tensor.to("cpu").clone().detach()
    $ U9 D/ P9 [. W5 ?3 w7 X+ q    image = image.numpy().squeeze()
    9 A/ X& I5 _1 a; W- T7 U    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    - F: Z% Z/ G& P$ L3 P$ j    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    - e. m7 {2 C) f! h; E! x  S    image = image.transpose(1, 2, 0)
    ' K7 `+ ~2 L5 F) @6 I* w' O& \    # 反正则化(反标准化)
    7 H5 U1 e  z+ S9 v( U2 K    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    ( d1 n! }, D( y3 L$ L* R6 b8 Y$ \+ \" U
        # 将图像中小于0 的都换成0,大于的都变成14 [* Q0 A7 G- F( E7 E: c
        image = image.clip(0, 1)
    3 U. G% D! W$ _& U6 ]0 K! ?% T- {
        return image
    % G5 P9 B$ v! e% [, m( W1 ]1
    : i6 N. w9 V- Z  r; c2
    ( J- _; x8 v  K# E0 z34 j* I% }1 T/ O$ a3 P6 W
    4* s% F. j! Z) Q) G; P8 \- k$ U" K* p
    5
    " D. z+ I$ h. z9 ?6, X/ d' P+ m# a+ O8 x
    73 y9 G1 J( {' R0 _& C4 j
    8
    3 M9 l& p# {' T! p+ e: u% ?5 h& D90 _5 r- }4 L. L' `9 f+ f+ b+ e
    10' R1 c4 ^- s2 n5 e$ Q- P6 W
    11
    % W3 m8 K3 k) h, J128 i  W6 G# y% _( E
    13
    # p0 c  P5 t. G8 N* _* _$ L1 w14
    5 p8 P. @" ^4 i; Z! Q3 V; U# ?# J# 使用上面定义好的类进行画图" A, `2 N1 F7 d  o
    fig = plt.figure(figsize = (20, 12))
    8 e  O* _2 U' \3 J5 h6 Z) K# Rcolumns = 4
    5 ?- I4 \' n  @/ lrows = 2
    , d9 `8 n1 d0 h* p! `6 {$ r! T  f+ i, i/ q
    # iter迭代器6 Q% h7 `6 }" n# R" B
    # 随便找一个Batch数据进行展示
    + K9 L, o* r  {# }% Odataiter = iter(dataloaders['valid'])2 e& W) `9 Z& G
    inputs, classes = dataiter.next()% q* `7 Q- {& u6 t+ g! j# E
    ; Q+ r1 X1 q$ Y6 a* t9 T
    for idx in range(columns * rows):
    - r2 V7 c# J- J& O( l$ Q    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    9 w3 Z, d7 g7 k5 j    # 利用json文件将其对应花的类型打印在图片中
    # l. e6 A  |, G. [    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    0 W# |5 i% p5 H5 i7 q- v: I    plt.imshow(im_convert(inputs[idx]))
    * X3 M; P2 L4 V7 H1 ~# zplt.show()& \5 _  o+ f9 S. I0 j, c: r6 v

    + p( `- Y- C. I1
    0 @4 [: E8 K4 |6 q+ M3 l8 J2
    9 {- D: _. b! j1 c5 g3" P& T( \6 T- n  G
    4
    1 w8 \2 Y9 O2 ~5
      ~& Q% t: _6 b0 |) z6 \1 x6
    0 D$ |/ d6 _* z# U0 i; f7
    ( Q* Q& E7 j. X0 K, J* X! q8
    ! }0 X1 G4 }# V" [93 c, N0 j/ L- p/ L6 t
    10, l* _# l" W8 H: N$ u. d
    11
    6 ?4 T( J+ f5 Z: L. X- f12
    7 r6 z5 O# q# F6 }13
    7 B0 ?4 \1 N; T. T. B0 S6 W. E14
    + u+ @% @  b1 j! X$ {2 V3 q2 u7 l15# [9 Y6 Q+ m8 e% F$ S
    166 w* a  o/ S7 P6 i1 o8 e; R% B
    1 K8 y& |6 K9 t6 B  U5 P1 M5 x
    6 x) h  ]# g& N. d  d# [
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数* H6 V6 P( }) t6 C" x. u1 H7 r
    model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    ! G# `7 p% J/ j% _/ j& u6 o& m# 主要的图像识别用resnet来做. z2 J; f/ E2 ]: M" G* E6 ~
    # 是否用人家训练好的特征% q8 p8 x2 C8 S3 j6 w
    feature_extract = True. ]+ L/ f) x" y6 b' Y/ J
    1
    ) b$ V4 `3 h# Y6 F4 F: s26 A! c) S- T6 A5 A/ A1 |2 ^
    3
    / U' c) _/ D. b. W  `3 O( T4
    ; I1 s; O3 G! k  a( z1 L# 是否用GPU进行训练
    , }6 L: L) ^4 h, _& Etrain_on_gpu = torch.cuda.is_available()
    / B3 @: @2 C: u! O# I1 t" `4 V
    4 B3 Y1 o) e1 _$ qif not train_on_gpu:
    - g7 ?, I' l) Q- ~' t    print('CUDA is not available.   Training on CPU ...')* C* W8 Y8 n0 C
    else:( `4 w- y1 G5 O* Y' N: ~  K
        print('CUDA is available! Training on GPU ...')
    , R, r0 R, Q8 @' T5 {' W0 m; @' b( U4 D( U
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    - W! e, |7 t' n. N& j& \1  y* Y7 i' `' i4 g3 c4 X" S
    2
    ' l6 U+ y3 U8 Y) l( d: A3
    6 \2 c; M$ d8 N- m* q4
    7 E1 y" z! ?# N7 ^2 c5
    & b' p1 N) m" |( j$ H# c6
    * B& E) _" m& L- `& r; P9 j7
    8 a: y& g+ \4 ]8
    ! ~. C9 Q1 |$ N% ~/ D6 B2 n9
      M% A( \. {$ b# }) P6 GCUDA is not available.   Training on CPU ...$ X0 r2 y' ~( x" u
    1+ u) X- |# W0 k1 X
    # 将一些层定义为false,使其不自动更新& k8 k6 L7 t  W7 ^" C/ X: a, p0 f
    def set_parameter_requires_grad(model, feature_extracting):
    0 E; x$ h9 @0 }    if feature_extracting:
    ! U% y% }1 p0 z/ A        for param in model.parameters():8 l( ?9 X$ N  Z1 p# n( w
                param.requires_grad = False: I( P, Z, B$ ]! D$ l
    1
    " M* m# A! N* d3 A# ^2
    . w" x6 m7 O' d& Z5 r1 n0 U( P3
    , O0 S% J; S: u9 O, F% c4
    * C) d3 [' W) h7 b- ^6 k* p53 I- |# a8 z# B0 |3 |0 A
    # 打印模型架构告知是怎么一步一步去完成的' u- K% x' D2 `
    # 主要是为我们提取特征的( Z: \- p4 m( B8 [: c6 q

    . d8 u; w# c7 E! R; \/ n4 D. X; mmodel_ft = models.resnet152(); P' o; U3 `' P, u
    model_ft  M3 q2 f. k5 s
    1
    % c  r& x' H+ q$ A1 P: H2
    $ N2 F& W, T- v" k3$ U0 l$ w3 o# V8 E* N+ f
    4
    / ]" X% s9 n+ \/ X7 ~5+ [- X1 w3 S- o, t
    ResNet(9 @' ]8 b) F. U% [) D# c6 C. m
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)8 O# n: o; ]# k" y( T" H% u
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)% n( [  D4 c( |3 q4 Q
      (relu): ReLU(inplace=True)
    4 W9 _% u( v) c. n7 }# w  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)$ K+ G) ~# p* Q, K' k# k" z
      (layer1): Sequential(
    ( @; c- [5 E+ [  Y4 `    (0): Bottleneck(
    ) u6 j: P2 F# K, D5 T: C: B      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    # R& i6 N& ?  v/ {0 ~( I# U) O      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    $ I4 [2 ]. t/ I7 b5 J0 f$ X      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)! u) W6 r8 x  Z0 n2 E
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)/ x3 Y3 o5 D4 r8 @9 p0 x
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)$ |7 Z/ E8 M& I* a
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ( K+ S! P1 q: K/ n, F      (relu): ReLU(inplace=True)
    " T( f7 o, c4 J- Y      (downsample): Sequential(
    + v8 k& F1 Z' _        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ( N3 H7 J! W8 x' g# D/ d        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    6 L: r5 L2 j: J. k      )
    7 d) _; [7 ~/ ?, Z+ D" E  H; S# y    )
    6 a! b' ~- K9 ^2 {& T' ^中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。; M: V$ }! }- z* P& `  _& b( n9 J
        (2): Bottleneck(, w2 x- X# \+ k
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    . h4 J$ c5 t; V& |- `9 J) B      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    " j* @0 s! b# K0 Q1 j0 Y& q      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    + b; x7 n* F$ a( i* B% O3 n      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      u% @2 f# x) Q4 s' m3 E5 w      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    % ~3 _' S! P  q& u) V" S  f      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ' U7 R5 g# v$ N# a$ h1 M  g- [      (relu): ReLU(inplace=True)
    " n% i3 h  q4 p- _% j2 @  g0 g3 m2 h    )% z. }/ r+ `/ j0 _) H
      )
    ( F* \9 Y5 P: d" }  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))  _. y5 M$ \0 h9 C
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    1 @; P% R4 f  o)' R+ P- w7 w9 O. q) z6 }& V

    - ~3 C8 f3 M* n1* O; _& o+ Z' U
    2: _8 ^) `3 M3 \) D3 A% F
    3
    + e7 w6 q; w1 H: Z42 G, w1 V9 ]4 f& i: f; W
    5
    ' A5 F! ~2 _: k2 G6# ?1 S  t; S6 I" z" Q/ @5 B
    7
    . R, J% N" ]7 l5 T3 s) i6 l8  u6 l; C! O+ o2 @+ w: D
    9
    4 L8 A& j; [- Z+ H% f0 v10
    & m7 C* B5 f' U11
    1 C1 j9 _( g8 E5 C" [: R12
    4 i  ?+ f  d* f% N: T: G" c13, l  Q4 h4 E4 r
    14# p- _. \8 \( ~$ }" i3 }$ r
    15
    " n1 I4 d4 `% H: e* k$ N16
    5 l  M$ W; t; }+ W175 t3 V$ H& c1 {: s0 P4 b
    182 E- A) ~* ~* C* w; z4 J3 U2 v, ^
    194 X1 K* C8 x9 R& H& c' W
    20
    . H' Q( m: ~3 G. _211 h7 @; s. _( g$ M5 Q4 ~2 S1 F
    22
    & ?/ G4 @6 L) l  u/ i" h; h23
    ( i, y8 c: K& a7 g  k$ o5 D  ?* G6 b24  _$ s8 N8 M6 M& Z$ _
    25* u6 r/ s, `! c$ ?2 p1 z+ u/ E- H
    26  H1 N) e* q; }' z
    275 A; d7 n* G. R3 g& O0 i' J
    281 r. G' O; A: S8 O  h
    298 ?9 }' u# ?8 p" S0 S6 i$ `' A
    30# b) }( N; B  i# z2 {3 c$ s. o7 \
    31
    3 H* a; `! T8 C$ A+ a32
    $ W3 r* V" f$ Y6 K/ K334 n) x3 F$ K& \1 L, s
    最后是1000分类,2048输入,分为1000个分类; G! T4 x7 N: d
    而我们需要将我们的任务进行调整,将1000分类改为102输出0 n5 ]* u5 p; A$ u* _

    " |, T4 w& ?0 ~6 T  p8 q% b( X6.初始化模型架构; K+ s1 L% ~) J
    步骤如下:/ P, W( L6 {, s: g
    5 q8 S1 L9 ?% k" ]; k
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数8 \7 F: T* `2 Y9 o
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False). u2 B& j* a% N; y$ _
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数$ X# o/ t1 K% d" o8 m0 e
    官方文档链接" {# ?4 e  ?2 ^- P
    https://pytorch.org/vision/stable/models.html
    5 F& b& O  R7 {3 m8 G& P- m( f6 X3 J! |
    # 将他人的模型加载进来' X1 i( O" ~) r! h2 S# H! e
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):3 O! k! F2 V. J6 u+ f
        # 选择适合的模型,不同的模型初始化参数不同( h4 {5 u8 H1 w0 v
        model_ft = None" b  @3 q& S5 e$ @7 v
        input_size = 0
    " M; R8 t. J2 e
    . b; \! B1 S& @& d    if model_name == "resnet":
    ; F# }9 k9 g7 Q+ Q* R6 [# Z; W        """
    * W  z9 }" f# f1 t. {( o        Resnet1528 I4 I8 D4 }6 u$ y$ l3 Y' c" v
            """
    6 z9 v7 K5 i2 q6 [* P( t
    % O; R: `  f( }" A- ~  f; b4 f& G- x1 N        # 1. 加载与训练网络* p9 N  Z2 @2 ]+ d6 P) y0 t
            model_ft = models.resnet152(pretrained = use_pretrained)2 f# n# D9 n  x/ S
            # 2. 是否将提取特征的模块冻住,只训练FC层9 t* d6 p: z. k  j
            set_parameter_requires_grad(model_ft, feature_extract)
    - F3 p5 L7 S. B8 A$ O2 R( T; {        # 3. 获得全连接层输入特征1 n/ }3 {: s5 I3 b9 O
            num_frts = model_ft.fc.in_features' Z( [) }' P  m) q- w
            # 4. 重新加载全连接层,设置输出102) ]. b5 u4 y% y4 y
            model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
    $ o3 j& F- S- q# |/ Z                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    ) }' \% O' ^* O: ^  R6 J        input_size = 224  G+ C% Z; I! }  n" s

    " \' [& e) K5 ?' v! g    elif model_name == "alexnet":
    " C, b! o2 l% }. F$ [2 O3 G$ v        """( w5 ~% v8 n7 k. B6 b3 b2 E
            Alexnet
    # d% U  {6 q* s- K) y        """
    0 y: w9 O. G+ A% X% w# {# V        model_ft = models.alexnet(pretrained = use_pretrained)0 X, e  m: @# [$ n
            set_parameter_requires_grad(model_ft, feature_extract)" ~3 x0 ~4 i# a) R" @9 W- Y

    + ^; b/ Z) Y# K) y5 p, `/ [1 ~        # 将最后一个特征输出替换 序号为【6】的分类器
    , e  b/ s8 ?* c0 v( [/ M        num_frts = model_ft.classifier[6].in_features # 获得FC层输入+ u: j" M0 {9 ^+ \/ M' e& X5 l
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)" T: S* ^  W1 V( ?  L  [+ _
            input_size = 224
    $ G# Z- x  `6 c' @4 k; v( {  B5 p9 B1 ~
        elif model_name == "vgg":
      C9 H" |9 h2 q/ M        """
    " Q# y5 K& q. P( {        VGG11_bn
    7 w; Q% O- F: Y" T        """
      A6 G( e3 @" V8 i7 }( A& o        model_ft = models.vgg16(pretrained = use_pretrained)
    - F& [2 e) P0 p: \, V        set_parameter_requires_grad(model_ft, feature_extract)
    / \8 N- X+ P9 \        num_frts = model_ft.classifier[6].in_features
    1 V3 k7 ^4 o. b/ O4 n+ x        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)* Y) |5 N8 P8 C
            input_size = 2247 p! M0 M+ w. k; O! q
    % i% ]& A; B4 c0 T( z
        elif model_name == "squeezenet":* j, T% J0 n; O0 L5 N6 _7 e
            """7 k$ V$ B/ U6 |
            Squeezenet
    0 m1 M/ v* x, c8 ]        """
    $ j0 n; f; x7 a% Q$ T8 L        model_ft = models.squeezenet1_0(pretrained = use_pretrained)1 G' y+ ?/ \0 u/ {6 m
            set_parameter_requires_grad(model_ft, feature_extract)2 w% e: C1 Z7 i- @  o
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
    9 B& X( y6 n) y+ F6 O, N" L6 @. F3 ]$ a        model_ft.num_classes = num_classes
    5 T  t/ O) w, @" P) s+ c& O5 z3 U        input_size = 224
    7 g: M, ^* ^& x, U7 T5 n# H. Q/ O5 U5 K8 Z- k+ y
        elif model_name == "densenet":
    4 a$ x% @$ [- @; P5 w: Q4 r        """/ a! |1 D5 N) @
            Densenet
      G; V# @" m5 L2 Z8 s) I        """2 ^  Q7 m0 N3 P$ b2 [
            model_ft = models.desenet121(pretrained = use_pretrained)# N& d# h1 q; A- [  }
            set_parameter_requires_grad(model_ft, feature_extract): S2 l6 `5 k2 {) a
            num_frts = model_ft.classifier.in_features
    ! U8 [0 n, f+ H! t1 N6 k5 h; a# h/ E9 c        model_ft.classifier = nn.Linear(num_frts, num_classes)8 B, k1 U0 M7 f! n" S9 Y) Q
            input_size = 224
    6 g1 A4 R" B) P% B" Z  q
    6 S; @+ L" H4 ~% I8 L. H    elif model_name == "inception":& [; f; {# j7 H7 ~) }" n: Q: L
            """
    6 Q% k$ v" A# H% X/ v/ S, v/ S# L        Inception V3
    2 x# r0 Q3 C  Y0 V: l% Q+ q8 G        """
    9 s9 V* K% h4 ~! r( ^: `        model_ft = models.inception_V(pretrained = use_pretrained): ]) d0 l  V" m9 M7 i
            set_parameter_requires_grad(model_ft, feature_extract)$ i2 \  m; w1 M8 e: r+ |8 @

    ) t; {7 {5 x1 c6 `* h        num_frts = model_ft.AuxLogits.fc.in_features
    $ p$ V+ \$ X, z- P* A        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)& {, p$ u; A) E, A* \9 @

    ) C4 Z# N2 e! w( H7 F        num_frts = model_ft.fc.in_features' q! d# _8 Q- f6 G9 j0 I! D$ H
            model_ft.fc = nn.Linear(num_frts, num_classes)# @; @6 q% f1 t' H
            input_size = 299  f6 C/ a& u; Y- C3 Y$ N

    8 O8 h3 J' {' p$ ]  L! j  a    else:
    . Z; i* o0 o' v3 _        print("Invalid model name, exiting...")
    ) m" [+ T' i8 l' r        exit()
    ; k1 Q3 r; t, a4 f' X
    , _4 N0 j; |+ Q8 D7 Y+ s4 ~    return model_ft, input_size0 a8 w# s4 N7 k6 ~9 M* c0 l
    ( Y6 p; ~3 K8 Z
    1* M' C& W3 U, e
    28 b) W) b0 W+ `3 E% ~" D$ U4 t; [- {
    3
    4 l( M& d. X3 W8 B4
    + O0 ^- J9 Z) G$ t1 z5- ]7 R+ a$ H8 U: a6 J
    6
    $ Q9 v: U1 o0 @; b3 z0 `7
    / u! b! p% l) u) f, ]8
    8 E# }0 y" Y- n1 w5 z96 L  ~' d* p( ^8 h
    109 y. g' W# X5 D2 e8 C% M
    11/ M* K' s" q% g- U3 X' t
    12
    / f5 p# q3 ~& f% I8 D8 q5 o13; J4 j3 R$ t$ \* f% w/ H+ T
    14
    . j+ x5 a2 ], J. m153 V, A' g( g# y7 W" o
    16
    # y& R. s& }% O+ G17
    - |6 H& X6 ~% K& U; M18
    + R! O3 [) o- A* L! J! Q' G19
    ! x( N1 C) |" A0 v' _# u( h20
    ( V7 z; R) Y, c' ?8 Z6 O21
    5 S+ ^+ j% ?8 y1 V' k222 ~. |7 _9 n8 A" G
    232 k" f$ f/ D6 b1 j! g+ m
    24
    - g. p0 c7 V( m$ z. E+ U25
      C# Y% `; D" t' C26+ v9 b0 i9 `/ o: k( o8 F
    27% q* Z: P) ^. r  ]; O( V
    28
    * D9 C) }) E6 l9 z+ Y( W% V1 ~29% `4 Q; e8 @% Y+ f
    30
    , i, ^% f0 ]8 e& c" l6 U  W- y4 i3 p3 ~317 a2 |4 i; X! X5 E+ ~! `5 T
    32$ N# x) t5 z' G5 f$ I, ?, [
    33
    7 N3 N8 {! w( ]9 g% v+ C3 E341 j) }7 G& ], e9 g
    35
    - ^! H( x4 w2 C! \0 m4 ~36; D' G" E, m  N  j
    37. x2 f/ R) C0 ]- E
    38' L# [, }) c4 g  B5 U6 y
    39; G3 t1 _7 q6 v! n: z! f
    40
    * F/ ?" y( B3 s& n/ Z41
    / o- ^3 ^6 q( U; \+ A4 v2 C42
    " R8 l5 {0 m3 l. \43' U( q/ b. X8 I7 u8 ~4 S
    44  h; J, ]) d/ x+ S9 `) X5 n0 `
    450 g; C- k6 r. }: \& s
    46
    ) N1 E) W$ j% e3 j$ C# M' b; E47
    2 J! f& P% K" [5 p( s48
    ( L; y" I4 ^8 @# q. {, r49! D0 a5 w* }9 X! \. W2 ~
    50# @& G3 J# a) E6 v4 @& a4 U
    51" }1 I; I. K) B+ x$ _9 ]' j
    52
    6 L% q7 G  q+ `2 ^/ M53
    8 n" b# l' ^* ~) n54
    " y& d, X, p1 @$ h* W$ W* o; S! G( M553 Q% y9 D" h6 g
    56
    - y  _! {  Y- K  `- f# M57
    0 @, v/ X: T; T6 u) H7 h7 N( o6 q58
    3 I/ w& w6 b/ @1 e592 A' j9 W  Q2 U& a
    60$ o9 j- g2 J/ y% E3 n2 ^
    61- I0 q* N8 w- `0 g( A: r
    62
    0 K9 R: S4 D+ w" H5 v! V63  S! e) F: `5 v+ z
    64. w" D; L7 S  M/ O  e6 P* f
    65
    8 c7 {: w5 o& S0 _664 e0 z9 w. g$ @
    67
    % i9 ?$ q( X; U; U" h. y2 o& s% q68! t  E$ `3 d0 o& B6 Y" F
    69
    8 e: x, q. I1 F2 x1 w70% ]1 J+ [" x: G( N7 K& O' S
    71, x4 t; r/ m3 n
    72; i" g6 U) y/ W; \2 A2 |* S
    738 T  Q: ~. n" b- Q, w
    74
    , ^5 t  y8 a; c) O  A- P75
    & j: T+ e* \: f+ N/ z76
    1 R& f: V; w" N3 d2 |77
    , T7 ~9 }( s) y78; `; H, g; C- C2 X* W
    798 Z/ B/ f: C% _% k# s% V' o
    80
    ) Q$ V( S. W- s9 [& Z81
    , X& v2 j- p+ {2 l3 U. F5 o824 D5 c$ ?0 x# E9 x/ V/ ~
    83( ^+ `9 C' D0 |, o! a, z2 Z
    7. 设置需要训练的参数
    # I3 h! W( w2 x1 n, q1 L# 设置模型名字、输出分类数  g" I3 {# P5 l7 |" M* `
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)$ U# Z+ k! ]' F
    7 B5 I, m3 `( y% I
    # GPU 计算
    7 u6 p) b1 a8 H* Z/ }model_ft = model_ft.to(device): l' T; V8 B+ S
    & f$ y' a7 ^5 f, q- S
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    7 A/ o' S3 A3 \1 ]0 h" \, jfilename = 'checkpoint.pth'( e+ s3 d9 K- `

    5 a) A' W8 A; R; d% c( |( D0 R# 是否训练所有层& b) G4 ]% {# s& N5 b
    params_to_update = model_ft.parameters()  h4 B9 G" a8 w+ l
    # 打印出需要训练的层% R" _) H, N+ S& v) P6 ?6 i" J+ |6 R; C
    print("Params to learn:")/ i# Y# O% g3 h, E: u+ G9 u+ F. O
    if feature_extract:
    ) ^4 G( F) ^* p2 D    params_to_update = []: I9 Q* ^0 K; C
        for name, param in model_ft.named_parameters():
    % ^( F; w/ ?* U; ~3 q* m* S        if param.requires_grad == True:* k2 D7 \6 U2 U$ S
                params_to_update.append(param)6 ?6 i# E+ D- I4 L
                print("\t", name)
    0 ~# {9 G/ r# i. ~( [; x6 |  k9 Y: m  u  Xelse:$ g0 d' S5 @# J4 l6 U/ x
        for name, param in model_ft.named_parameters():$ D% d9 v  y# c/ C1 C( N
            if param.requires_grad ==True:5 V5 S9 `5 h; F3 S* f# q
                print("\t", name)+ f4 s7 L- g! y- A

    * K7 A0 b% L: g" O* H) V$ k- {. t1$ ]2 F& I4 C7 n  x" K* P/ m6 }
    2, n' s) R: e4 M3 G( P# D% Z  `
    3
    , C# d1 w$ _- y! y7 O! H4
    / ^2 I8 D% I( P( W! N" j5
      ~6 l9 v# f  O6
      S  o/ i8 b/ V& r0 w4 ]8 I7
    1 ~. W& [/ @9 J7 [1 ?; e83 o  W' j8 U, r9 Z
    91 a; {( A* g: o4 J* U+ W/ {
    10" G8 T0 m' S( G- v7 ^; w3 U
    11' N/ O9 }, [, F+ {" W( m
    12
    " _& h6 X, P$ T& `13
    % u$ H" c. A8 y, i' V/ E, V/ V141 o2 p, k) y% o7 r$ Z
    15
    $ Q0 \! @3 ]$ c  H16( H# A; K1 D1 [3 b& D+ o9 t* |6 N
    17
    " n5 T  F, b' v18
    $ x5 P! e5 t# f- y4 d' A19- r8 {* |7 Y7 K- n/ u6 \8 p
    20  [6 |: B6 d- p4 [
    21% k6 a- {! W) T
    22- k  ?& G8 m9 q# v' ^9 B
    23- Q: `0 W$ S$ v1 J. i4 C
    Params to learn:
    8 P. i( x) h+ l) e         fc.0.weight: A3 [1 c# @% \( Q5 Z1 X
             fc.0.bias
    ; K% b& V2 q4 C9 U  o; O1
    + R+ M3 t- Z! @2 p+ F! r28 f9 x7 g) ^% ~
    3
    ! o& d3 o  F% ?7. 训练与预测  ?, }9 y- I% @$ o
    7.1 优化器设置
    8 d. h: O, s9 Z2 x0 Q& w  P# 优化器设置
    # Y4 Q6 j) [5 o0 Xoptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    0 d8 l" w5 t$ v2 K3 W2 D0 o# 学习率衰减策略
    , d6 d1 `3 V5 L. ?: n- U) ^; [scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)9 V/ z. }& [5 o* O0 J
    # 学习率每7个epoch衰减为原来的1/10
    * n8 r3 H! n; G# V, _# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    - O% \/ _0 `$ E% ?* x2 n  g2 n: m" h( u% O& c& }/ _
    criterion = nn.NLLLoss()* i/ O; S3 n; O. H2 A4 j6 N
    1
      \1 c( r+ A& Y4 A23 z  J1 T- i5 a
    3
      X' P  S# H: A/ l4
    4 P' T) t0 X/ t( O! Y; P/ l5
    / ~" A( G; {. y1 x  N6
    ' ?' x4 c4 u$ o5 `2 ~- h7$ W* w- x2 r6 W7 A1 d
    8
    * a+ T# w8 e3 H/ n7 M( _# 定义训练函数4 _8 W7 K" V: D5 E% T# g; @
    #is_inception:要不要用其他的网络! s. R7 g9 |: t+ G0 `$ z
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    ( \. v; j' f0 C0 V0 z# P, j7 C( G    since = time.time()
    9 u( T- ?  b8 H( b; M0 ~    #保存最好的准确率
    1 g" T3 f  v, Q6 W9 `. {    best_acc = 07 d2 @: B4 y5 ^3 N% M6 U* e  w
        """, e: Z) G+ b3 z! \& U" F
        checkpoint = torch.load(filename)
    ; T5 K/ `, N5 I7 J: ^% q    best_acc = checkpoint['best_acc']( m0 K6 r% u$ {- F! `
        model.load_state_dict(checkpoint['state_dict']); W6 }% |1 C9 n1 n7 h- R
        optimizer.load_state_dict(checkpoint['optimizer'])6 ?: ~4 }! W7 K  z% D8 J
        model.class_to_idx = checkpoint['mapping']. P6 g4 t8 c# N7 X3 e$ g! g
        """
    , S$ N: B, F( i) ~    #指定用GPU还是CPU
    % s% U6 K1 X0 a" ?1 s4 D' A* P# c' ^/ x    model.to(device)
    1 H/ Z/ h8 z( U3 e4 y! E. o4 F    #下面是为展示做的: b+ S% o1 |6 o- B) H( O1 l/ ^* v
        val_acc_history = []3 n. O! R/ p. v
        train_acc_history = []
    5 K7 }0 T; j' ?    train_losses = []" p  _# k: B" `
        valid_losses = []! s- C2 J! x9 N/ n3 h+ ~: n$ {
        LRs = [optimizer.param_groups[0]['lr']]3 o, C7 ?3 O2 H. w# \. f
        #最好的一次存下来
    2 _( ]' h% x# S4 W0 L; L    best_model_wts = copy.deepcopy(model.state_dict())
    $ l0 P8 u! s# i) w* l
    1 A. P6 J9 J" ~    for epoch in range(num_epochs):
    1 R: T2 Z- M5 V0 e/ G        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    % f1 n4 N& v$ Z/ v) q3 j: p        print('-' * 10)
    9 Z4 L1 f/ w8 G8 W) |7 s
    , J4 n! E+ S7 ?0 A) s  ^        # 训练和验证% B+ Q  D3 N% H5 a+ p# D- @" I
            for phase in ['train', 'valid']:) s  G* s; ~8 v5 f- Z+ L- w
                if phase == 'train':
    : a7 m2 z8 ~1 W" r9 e                model.train()  # 训练, w6 O( ?. k) s) B
                else:
      k  U/ V+ D5 _( f                model.eval()   # 验证
    3 e+ L3 J6 {$ X2 p9 }: ~
    ! j1 V' G# k3 V: `$ C1 s* l            running_loss = 0.02 r/ V- q$ m- ~9 V+ A5 h$ e
                running_corrects = 0
    ! r! I* {& v" g9 n% }. t! f* q2 L7 P3 A# S3 {: G
                # 把数据都取个遍5 z9 t( x# \5 e, e0 n% J
                for inputs, labels in dataloaders[phase]:* C+ `3 v; s! X* F: d
                    #下面是将inputs,labels传到GPU7 \! u0 y  C6 k  `9 U( O+ j
                    inputs = inputs.to(device)9 X  U8 d/ ^1 c; }3 L: E$ u
                    labels = labels.to(device)
    , y% d$ \: a3 ~# F* a- m  @& F) r% Z" K4 j, a5 J2 d
                    # 清零* R( J, |  E, `! }, Z1 e6 ~( S
                    optimizer.zero_grad()
    - W; N, D. i3 i8 h, M                # 只有训练的时候计算和更新梯度
    ! P7 I+ E- _! Y) w3 A                with torch.set_grad_enabled(phase == 'train'):
    # F% D% |6 `1 H# Y+ r2 z0 B                    #if这面不需要计算,可忽略2 ]) m; f5 t* I) G- X! V6 k  e
                        if is_inception and phase == 'train':
    / k" M/ w2 m4 `3 W2 U! h* D                        outputs, aux_outputs = model(inputs)
    ) W  _, r- k$ y8 x                        loss1 = criterion(outputs, labels)) \9 _% M0 x5 f* f7 L
                            loss2 = criterion(aux_outputs, labels)
    ( R$ _7 g3 A- P$ q7 Y0 T                        loss = loss1 + 0.4*loss2
    8 t0 J0 e" S; F' S                    else:#resnet执行的是这里
    ! B3 D) D/ \; e                        outputs = model(inputs)* O& h9 G6 z) @
                            loss = criterion(outputs, labels)$ A9 F/ B6 f1 [  J
    + a, }7 w. Y/ m. N, K
                            #概率最大的返回preds6 T. F% @. y8 k4 P( ]) f: K' w
                        _, preds = torch.max(outputs, 1)
    + k& Y' m5 F  T% z8 h  v" z$ l! T: k
                        # 训练阶段更新权重& F  Y# V' l7 r6 w/ o
                        if phase == 'train':  D% B6 @9 w8 O- r( c
                            loss.backward()) }9 M' x9 {0 |* ^
                            optimizer.step(): O; d; E4 J1 e7 K/ R! u
    2 e7 n7 |6 z( f) K& A! i# h2 c
                    # 计算损失
    . {' T" f5 a7 {# Y                running_loss += loss.item() * inputs.size(0), {; l( X* d+ n4 K8 C7 P3 c: R
                    running_corrects += torch.sum(preds == labels.data)
    4 G7 ~" q' b9 V$ |1 i. ]0 u
    3 b- T& l5 A8 h- u' O9 m( l            #打印操作- b4 K; E, ^) L) q
                epoch_loss = running_loss / len(dataloaders[phase].dataset)- J4 V. ~# Y! D# y
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    # w' S- F% \5 W4 ]8 y# x
    ; A0 F7 ], D1 F! R' c4 v: c% @* e' w3 P- ?0 T/ a2 I! v3 y
                time_elapsed = time.time() - since8 `& `9 T# U; @4 M2 h
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))( ?! j; I8 z7 o. c
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    8 C; A9 O/ {( K* Y/ u4 D9 ?; B: f$ b- o+ W  X9 A4 c. m' P+ J1 @4 C
    1 X' M9 r; {: d7 V0 e" F
                # 得到最好那次的模型
    8 K, u9 c  ]0 l/ u$ G4 ~            if phase == 'valid' and epoch_acc > best_acc:
    * a' F6 j# a" F, g                best_acc = epoch_acc/ O( Q- P& O+ ^
                    #模型保存
      F2 ]. V! J5 B3 u; h; u- b                best_model_wts = copy.deepcopy(model.state_dict())% F, ?$ Q2 i& d/ y/ j2 B& ~( u. m
                    state = {
    & f* k/ q0 h9 s+ [1 L8 e                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数" r' ^8 U" {6 r& E
                      'state_dict': model.state_dict(),
    3 I! q! _3 A& I+ Q" h1 a0 O                  'best_acc': best_acc,- K  W; k2 e9 N1 x
                      'optimizer' : optimizer.state_dict(),4 ~6 p+ D3 L; W5 b
                    }" v) d: k% q2 x, _7 q" P
                    torch.save(state, filename)
    # `. `; o* ~4 S8 j& A8 x* J            if phase == 'valid':
    7 d. N$ O. h0 R" A. ], R                val_acc_history.append(epoch_acc)
    4 I' d! J  d* d; v8 A                valid_losses.append(epoch_loss)
    1 D, l( \& W9 l5 n& o0 ?3 r                scheduler.step(epoch_loss)
    $ E1 \8 s5 O: P2 h3 a            if phase == 'train':
      D+ \3 ~$ l" e6 I) T/ H                train_acc_history.append(epoch_acc)
    * }$ V  t6 M- p1 f1 ^( z7 w                train_losses.append(epoch_loss)
    2 ?4 f- C- @; }
    ' H9 P: C* q& f. l        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))" J/ V, X, m: g$ n7 l5 g3 `
            LRs.append(optimizer.param_groups[0]['lr'])
    + q3 V- X% t* @/ d* `% }: X        print()
    + r3 ~0 o/ f3 z5 T) d' ^6 B, E* f3 L& I
        time_elapsed = time.time() - since
    2 Y7 D& S1 }0 D0 ]3 ]    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    6 w/ t/ R( t6 ?  d9 q0 b+ O& U0 B    print('Best val Acc: {:4f}'.format(best_acc))+ _! j$ e, m4 i, n$ F0 I

    1 x4 y$ L6 q& I9 Z+ J/ G- U    # 保存训练完后用最好的一次当做模型最终的结果4 S9 p2 m% _: [, A; L$ O4 L
        model.load_state_dict(best_model_wts)
    9 C, y5 A* g% N; t4 Y    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    3 T2 c8 l- }" h9 o0 p1 H
    - M$ N, _& @& v
    , {/ P! G' E7 _1 ?1; c7 b* v$ \8 ^3 N" S9 f# z
    2
    0 t+ y% L. f. \& d3
    " f( l" u5 ^5 X1 Z$ ~' S& z( ]9 R1 h5 k+ U4! n  X& M+ U5 f1 u. C
    5
    / N" G' M+ c" ~( _( W- {7 Y- [6
    ! w7 Z% f+ q! n% c7: c# t) Z& |1 ?
    8
    ( ^% s% F$ f, |( e$ s) g- ~9
    * R$ _2 Y' x6 e# z4 B10
    # D: [2 T* ], y11
    8 |: B9 y$ T! E1 J2 r+ Z/ V$ k12
    % B" H% S* a8 |. v135 O" \$ F, n  G& r
    14
    7 j6 a: \9 T. G: B4 s15
    $ H; i) G& C% m9 z% g16
    . l" o0 k5 X$ L) R17
    6 m3 C1 R. S& j) q! M6 ]18
    $ t. n: x1 J, m+ \1 ~$ Y19
    " q% Y. U: B4 B2 c* c4 |20$ F) ^# r6 U* K% C  Q: k
    21
    % i# t+ d2 r- [5 K, v$ i22  W- }4 M0 H5 ^2 y( Z  d5 s7 [
    23
    3 z* k/ ^  t7 p& ^3 O  I9 U5 _9 U! J24
    " P+ k# E7 g7 B$ u( p) h+ c  M255 v: y0 B8 |$ O: O* N. e
    26
    / D! U. R; w9 M* J/ i' J27; l, [" e6 U+ y" x2 u7 D5 d* R& w9 E! ]
    28
    0 T* A' m' D8 ^/ @% m/ q* ]295 i) |9 N& ^: G9 _
    30- E$ M* \1 g- p0 _
    31
    ' ?) o& v' p) Z% c329 t9 q; i, \/ p; }
    334 U0 n3 B/ o- q2 L
    348 h( B/ s+ h, t1 i# e+ P6 ~
    35
    & r5 p1 y8 l3 Q7 ~  P7 C36
    - e7 C# U" R4 N$ {# x37
    5 l0 [3 U4 J+ q( x' w38
    + Q* K2 [  k* ^- ~39
    ! p7 o: ]) j( s40: c' G5 r* a) x: J7 B$ t
    41+ }2 o0 G! Q1 w2 s1 M, e
    42
    8 k4 v3 |& a& j  i4 s, R! L: U43
    : ]3 N: u" D) D6 b; T44
    * f% R6 }, L0 e$ t+ O% z% G45! b2 L. p- G2 Q8 i4 L; }
    46
    : y8 l8 x0 s7 J. l" d8 j/ `& h47) h7 ]7 X+ o- _0 }0 F
    48
    % q' P8 c7 N! H# c, l) N0 C4 L49
    8 o, {% L8 h$ |. r50
    . L& y+ a7 g& U51: L, T5 R9 x5 g9 ]
    52/ I* Z: x7 O# Y) \+ A: {- j
    53  q- ^% y( a0 p% ~2 {% Q
    54
    ) V# j* }# c6 v55! a, x' i/ j! w$ R7 v& Z: C' a9 H7 h
    56" U& c/ O& |/ J' M' S
    57
    : ?0 H/ t* F6 P$ N. @% t3 l( ?* w58+ \2 D% P. B0 v% y' z
    59
    * n- G3 ^! p! r: U" `1 P0 s4 f60
    + T8 {  S* ~$ W" r! }+ E61. w$ j5 n' a9 M5 t$ l4 s
    628 z/ s+ z% R6 g; f# Y, o; y8 F
    63/ o! G8 t! D" H) x& M
    641 M) @1 X, a5 v: ]  f; H
    65
    & G# l7 {" O5 V+ s" M2 L66
    # P3 c% R6 y: j* \* X2 r- R% P2 P674 Z) B+ y7 |% O) I1 D
    68
    - Z0 W8 P% c4 p) `# ^8 R7 Q69# S8 m! ]) K& \' w- ]8 q) I
    70; ]- L5 n" k4 R8 J
    71
    5 {4 E2 @9 `! g: \" @72
    $ x0 i8 H4 }/ s. Z2 N7 v1 X5 Y73. z$ t7 Z- j5 c
    74# |3 u5 b$ N; j' j, s$ U1 K
    75
    % j* d- o! ]4 N7 M$ q, ?9 \76
    2 e9 \# @: D0 X* X77' \+ f1 f) A% S
    78! n  O' E( x8 x$ `9 C% M5 ]; p
    79
    & B; h+ L$ Y" W' S2 ?80
    3 a* {: _- v$ `: }; C; W( [815 L4 |( y: `7 [: Z( f1 k' s( D$ }! J2 t
    82
      s. r  U" u0 E; J* M. [$ r' P83% J# G( ?5 y+ d( R5 B4 a
    84- A6 a' b& C1 L6 d. P
    85
    / ?, c$ Q3 N8 n" j' L' ?$ D: E86; w* j$ G8 i8 ^8 [' ?
    87/ g  B) X8 j: J9 q& M8 k6 J
    88
    ' L  ?3 |# Y4 H7 r& R' ^* f89
    5 i; u) S0 o+ Y- Q7 X3 Q/ l90
      ]4 @$ y/ w$ }4 ?! B1 _: n91
    - O* \+ x& K  Z& }5 \; X8 q92
    ' T4 _8 h. T+ L( o93
    2 D' a9 v; n/ b7 C94
    4 _% p. h8 b4 l! g. k8 W95
    7 e7 v! g7 d8 @96% f2 I# @5 u- k, y. _' M: k
    97
    ( O- i8 {& }# O8 j. }6 b989 t" s; @, G3 V1 M) B  l! y
    99
    : w8 _+ C" |) k% W/ Q100
    ; ?  h% D6 u/ K; c9 [101
    4 A) X3 `# w" w2 x0 a102
    & B) u5 n5 L$ @2 N  B6 R103
    ; l6 B8 h  [# H7 z104
    . Z7 c. v% `- N0 e% ]105. i0 H- x  E: T" r
    106
    # W2 X5 W3 W( i6 N7 \3 p107
    . m7 T( P# n/ h1 c. k3 }! k8 r' L1082 p. m! Q  z% w; `1 G5 G, V
    109
    ! j( C& `  a: N$ a6 K+ o110/ ]  f; C, [7 H: J: |
    111) r% e* |) ^, ?9 v
    112
    ( Y, e- k4 w& }7.2 开始训练模型
    / |$ W8 f  @4 j- p# G我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次2 A9 M6 M8 _) R& w# Y

    : [' F5 `- m# ^#若太慢,把epoch调低,迭代50次可能好些
    8 [. w6 [) D  v( m6 }# d" M( X- a0 [#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合4 S/ ?; o5 h+ a0 @" Y0 G
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
      Q0 X4 g5 o  Q# Z+ ~
    & y( g$ B# U0 j1 a1 ^2 ~) i1: E. ^" Y+ f$ @/ O" @$ x
    2
    4 P7 V4 d) H4 x4 q2 X2 _# w3
    $ H& p! O. ], C* t* d2 h/ D4% X( h' Y% q% A: M% N: i
    Epoch 0/4
      o0 ?  b3 S" ]----------- `7 J+ x" @9 X7 k
    Time elapsed 29m 41s" _3 ^" K$ Z1 c* g9 y3 ^9 q
    train Loss: 10.4774 Acc: 0.31479 `" n* B! R, t
    Time elapsed 32m 54s  d) L7 R9 @3 e+ \; c+ f* Y
    valid Loss: 8.2902 Acc: 0.47199 u5 c0 C' a  `1 n' C: R, M
    Optimizer learning rate : 0.0010000
    8 y  w" w% J; u$ A
    + \8 C5 C5 G  O) s; Z) v* pEpoch 1/4. y- I! Q, V+ R/ W6 s5 j4 Y
    ----------
    % m% `. C- t" k6 vTime elapsed 60m 11s
    ; B, Q! v  m. x. q3 s( y) U# P0 v+ Ltrain Loss: 2.3126 Acc: 0.7053
    : e1 u6 q  d" l% A  x2 cTime elapsed 63m 16s
    2 ?3 s& ^" v9 O6 @valid Loss: 3.2325 Acc: 0.6626, ?) L8 }3 _0 h/ M& A( j5 H
    Optimizer learning rate : 0.0100000
    # m% x1 u3 N& D" h( x7 P
    / s) M% u1 h! \2 j5 `! rEpoch 2/4
    $ f+ K/ [- V% \0 H" ^  k----------
    * k  d0 v# _3 }) [2 jTime elapsed 90m 58s
    9 k3 n; i" }) r! s/ m! W* Q# ktrain Loss: 9.9720 Acc: 0.47340 m' G5 }: w3 I1 p7 @* X
    Time elapsed 94m 4s3 W$ p* E" W5 `- G' _# u$ O
    valid Loss: 14.0426 Acc: 0.4413! V) i1 \! y" ]. n3 }* d7 o
    Optimizer learning rate : 0.0001000; v& s. K9 H/ O/ k
    ( L, O: o) F' ^- d# D. k
    Epoch 3/4. w; d0 ]! X' C( a& h
    ----------% A, ?4 G% M0 u8 B5 A8 P) P
    Time elapsed 132m 49s
    % g( M. _# i; E2 |0 itrain Loss: 5.4290 Acc: 0.6548
    . {* `9 s9 W# A* Y1 a; M; L; hTime elapsed 138m 49s
    8 }( ]. O1 x* B/ c% M$ Pvalid Loss: 6.4208 Acc: 0.6027
    " y4 P* y1 R5 N7 o. E0 @# t/ ]Optimizer learning rate : 0.01000003 d) G0 m3 |' w$ S  I% c& ?7 E
    : ], z( w. s6 H, M7 }: L$ P
    Epoch 4/4/ v  F7 D+ N" B: a- {  Q
    ----------
    + k+ m" z0 k3 X3 J9 T0 mTime elapsed 195m 56s4 W; m' s+ ?* V$ i  p0 I0 Q& c
    train Loss: 8.8911 Acc: 0.5519' i. u1 n! i# f
    Time elapsed 199m 16s( H# ]/ J: F- P6 E' W3 W  v7 `
    valid Loss: 13.2221 Acc: 0.4914; C1 `8 i# n# P5 l: J7 Q
    Optimizer learning rate : 0.0010000# F( _- T" M& p- U- l
    0 A' ^. j4 h8 R  Z$ h' Y2 F9 f
    Training complete in 199m 16s3 q9 B) z( ~) ~& }0 y9 e% m
    Best val Acc: 0.662592
    8 v2 S$ y% V* S& P1 c
    $ R! V+ J: @& j% ~18 f  d8 b( O6 F* K
    2
    3 P! z3 n, w: _  Q3
      h, q; [" Q, o' p: _9 o4  H2 L8 h3 t& S  Q. H  E4 U. k! s& W
    57 D; j; S- i# w) K
    6
    + B- w: K+ E6 X. r) z0 F7! Z% }: z3 \9 g+ }
    81 B" v9 h) o# Q" h
    9
    8 Y+ b' }. P) N  K10- `; I+ z2 m& L4 z
    112 Z& D# x' Y# X- j
    126 n" S/ R6 W+ @% |( S
    133 V. T6 o7 a" E; w
    148 B! U2 u! u" f. l* y
    15
    8 L2 ^/ i7 S$ g1 d! h9 Q16
    3 b! |  K2 k& @# z17  O! n+ O- g) p8 \' L
    18
    - g5 ^5 u7 l+ u/ S19! s' {" g% B! H7 \6 O" @- J# r6 A( i
    20
    7 |- T+ x5 `% m) Y/ ]. u! T8 D21
    9 B( p0 \6 @- q5 I3 X7 A9 _22. g' E5 `7 [3 Q2 j+ W3 i7 I
    23
    * a3 h3 h# }. p( L7 h( Y6 J24
    2 N; W  Q+ |- k7 h7 {258 {5 [, `5 {( b  e8 W- ?' D
    26/ J: s7 [% g2 c5 @$ R! s
    27
    ) c" h4 b5 |6 g* K% j" Z; a28
    3 A5 V: w% f# K# o1 Z0 h% u29+ R0 U8 o$ U9 ?" S% |4 Z1 ~
    30( c' P1 o& l9 ?  H' _( m! r+ n
    31% h: d" A1 W) I9 n# i# Q
    321 H1 o% h; J1 E) C7 W. N+ a+ \0 p
    33
    0 {- G, V0 F& w+ Q3 Q34
    % B: l* w; Z- W3 w; y' l5 v35. C! {" i5 b  }+ v' C" u5 u/ o
    36
    ! q0 j% |1 b( \7 e37
    " D; q% f2 J! D; m38
    + n" |4 a8 n& `$ m% K4 A39
    ; S0 G' L+ G1 X4 N) n! l  l' }, z7 F401 g: O; u; O% _8 c: `$ g  Y% z7 K
    41
    8 L* {: n1 Z0 M9 D42& c$ N" ]* f( Y1 T1 N$ Z6 F
    7.3 训练所有层
    ' k- k" r( W2 t1 e7 ]# 将全部网络解锁进行训练
    9 Y4 Y9 y7 M  E; j- ofor param in model_ft.parameters():" m9 @5 |5 ^( F& N5 ~& z
        param.requires_grad = True
    9 R; ~! O. h3 q- Y  d! q6 F" j- a& n3 S; |/ E
    # 再继续训练所有的参数,学习率调小一点\
    ' m' @7 R: q: r3 X/ d  xoptimizer = optim.Adam(params_to_update, lr = 1e-4)
    4 b0 h) {: q: i1 t7 ischeduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)9 P! {- V; d$ H: Y

    9 |! X4 ~8 C1 V# 损失函数
    ; }7 N0 G" Z+ P0 h, _5 ?+ h. dcriterion = nn.NLLLoss()
    * ~4 W1 C1 n9 W. y; N) R0 ]1+ F% N* Y- D7 v. _
    2
    ( a8 I" P5 n' ^$ M- U! X- N3
    : D& s% d* h( h/ Q4/ v' q. V% Q  R3 w7 k8 I
    51 g7 X' x. G! `5 w8 z9 J9 \- `3 w# s
    6$ z! E- h5 h. ~+ I$ |( E
    71 {: q( D& N3 l
    8( _- f) A& W& o0 L  @
    9
    9 |# a: C( Y6 n4 j6 g' P10& e" {5 @& E( F; V
    # 加载保存的参数, r- X( ^# V2 p
    # 并在原有的模型基础上继续训练
    % x# t7 j2 r  E9 |/ k$ `* t# 下面保存的是刚刚训练效果较好的路径; F- g  S; e. v5 x/ t" P) i
    checkpoint = torch.load(filename)0 ^  t5 v9 X$ C( k; G& b
    best_acc = checkpoint['best_acc']& v, R7 H+ s3 W/ I; x; d9 h
    model_ft.load_state_dict(checkpoint['state_dict'])
    9 j  z# u5 E, u( V4 ~4 x1 d" |8 }1 l) s7 l7 boptimizer.load_state_dict(checkpoint['optimizer'])4 t- ]7 x  Q' B3 B- P- w
    1. A/ t9 T7 _: A$ P( C" q1 Z4 [
    2$ V/ d; d* o8 ~- @
    3
    4 v7 v! |$ K; [- ]* J3 x$ j. \4
    3 c6 K8 X& k: O6 F& W5( Y/ s' p* D% w" d7 j
    6
    : S0 u# N2 r6 n# v0 a7
    + I: ]! P; O& {5 g" T$ @开始训练7 [5 x: M/ e, z4 H' H7 k& X
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考: F3 j/ u! v- U6 T

    1 z6 T3 ]% t1 p& c& amodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    7 I" N( L' S  a1 w4 c1 t1( Q! j* m# D& o
    Epoch 0/1
    4 \' i* v2 y+ t! b; {) c( w. _----------
      d9 J5 n( H  `% x' uTime elapsed 35m 22s: Q7 J9 D6 k6 k
    train Loss: 1.7636 Acc: 0.7346
    ' W0 [* N/ T5 H( p8 ~Time elapsed 38m 42s5 K+ E5 z( O0 Q- C& m
    valid Loss: 3.6377 Acc: 0.6455; E: x/ i( d  r2 D, s2 S
    Optimizer learning rate : 0.0010000
    ) v" f& d3 A# n" D) O0 m
    ' G, u9 J# d9 y3 y5 A5 V' hEpoch 1/1& D, X: O" O  {- j, m! J8 J
    ----------, X' ]4 I: T" [& x; }
    Time elapsed 82m 59s( }3 @, a3 ?  B! o2 y9 B4 p# A
    train Loss: 1.7543 Acc: 0.7340) w1 B0 w/ S8 y5 y/ Z
    Time elapsed 86m 11s
    ( V7 `" D( r& B+ q6 w; O+ y, Dvalid Loss: 3.8275 Acc: 0.6137
    ' A. O# f4 R, Z+ @5 I3 F* |5 SOptimizer learning rate : 0.0010000
    # f' I, o' l- ]  Y; F8 _2 L- b+ j6 Q7 q* M8 Q
    Training complete in 86m 11s* `/ U, \9 O$ {1 O
    Best val Acc: 0.645477
    / C) j- U. ]$ V- b3 z9 M& g/ S2 b; A5 V9 V
    1& g% I" {8 t* W
    2; ^$ i, O$ c0 x8 Z
    3
    . x; }3 |6 _- u# `' R5 ?4
    ( x6 X/ b/ U& E$ F: g' K5
    ! i  ^; l4 f/ g0 Z) x6
    - [1 N& C( g- C& M7
    0 ]0 f' D, H( U- Y8. I2 d5 Y2 F5 n+ x' `
    9
    8 i" S# M$ c& ^0 C10% J) b% z# s9 q% X6 M4 D. |
    11: m3 \3 a5 H4 O* v8 I( K
    12+ v) U2 @+ @$ n/ M5 X7 K) l. j! [- l
    13: H4 w, y) R. N- A  x
    14
    6 x0 o( m! m+ {2 h- C* b/ S3 h15
    # q1 l7 G% w4 U% P$ \8 w16
    / z3 q% R! F7 L8 y17* D( n1 `: [8 b& L  g4 L6 K, ~
    180 F" T- z1 `3 _. l# L
    8. 加载已经训练的模型8 ~5 X/ g% d1 h( F( c0 J3 c
    相当于做一次简单的前向传播(逻辑推理),不用更新参数
    * ~% ?6 j3 ^4 t/ j8 {& d. y
    , M4 k% t% d; g6 A5 I' nmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
    # I/ c9 d5 N* U7 S" v
    ( ~5 D7 {2 O6 a# GPU 模式
    6 r# ]5 z/ v. H2 @8 ]2 ]+ I4 e: d& umodel_ft = model_ft.to(device) # 扔到GPU中/ @& L8 o+ N! X; H5 {
    5 L! ?& f) T7 J5 ?+ L# D
    # 保存文件的名字- }  Z) W: F: a1 M8 j9 ]: n1 o
    filename='checkpoint.pth'8 t: C' e; ~2 C/ r
    ; M" N! l5 W: s" r
    # 加载模型
    ) @6 W& }! \9 T9 g0 T. A1 lcheckpoint = torch.load(filename)
    6 @% m2 D) }  A/ ~- Gbest_acc = checkpoint['best_acc']; S! p3 [& p  }) B$ S" ]% ]- {
    model_ft.load_state_dict(checkpoint['state_dict'])
    3 k. S$ m) c4 K8 O4 U2 S' Z! g1
    * U6 ^8 N: P9 Q6 g2
    - t+ `2 X! j1 J# h. K4 r7 y/ h3' T; v: K: S, ~, l) I: V2 Q
    4
    $ F" R4 s7 g- b; G0 T0 z5
    ' J- s6 C6 D- u65 s  N* J; Y8 _3 y$ v7 @7 P- t
    7
    2 Z% V( Y$ G! u! g. U3 G89 _0 K7 k: s1 s6 ^4 H; |
    9
    2 V! H+ n# i' [8 }+ ?/ I  f5 C4 Y10, D' \! W" `' v- |# t6 W. c5 u
    11
    0 l8 R+ ]0 f1 Q7 E% w2 g12
    " J% D9 a  }) F<All keys matched successfully>: ~+ Y; K4 r; l  i9 N1 F. K2 Z7 c
    1- Q6 w% [4 u  T& g6 C3 A; U* k$ O
    def process_image(image_path):2 s) ]' a( Q3 W, _2 T* n3 J
        # 读取测试集数据
    , v4 t9 M; z6 b# `+ M    img = Image.open(image_path)% b, ~8 q- f! P! }3 D! M3 I
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断: x- z4 i# i+ d# e+ ?" ~1 g0 V
        # 与Resize不同
    & f9 R' z% _! N  d1 S    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    6 v" ^: Z  A0 l2 @( g  |, S    # 而且对象调用方法会直接改变其大小,返回None# D( Z, U* W) ]4 O
        if img.size[0] > img.size[1]:
    " l( e  A' d% a" m+ K4 d. e        img.thumbnail((10000, 256))) t3 q' @8 M! H, |- |+ {
        else:; C* i& l. h- f% @  T, K
            img.thumbnail((256, 10000))
    ( l! B; R  m4 @2 T6 X9 a
    1 _: Y# X+ q" v% \& ~2 R    # crop操作, 将图像再次裁剪为 224 * 224
    & o0 b& F" S5 u+ z' @- P- q    left_margin = (img.width - 224) / 2 # 取中间的部分
    ( M/ y5 F8 o3 `    bottom_margin = (img.height - 224) / 2 2 I3 X9 F: y4 t2 k
        right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度/ i5 _& {* U7 @3 _. i  \
        top_margin = bottom_margin + 224
    2 X7 [6 m- y# w' O$ I- t
    1 V, \8 Q2 B  `7 O. }6 ^    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    5 _# l# Q1 Z. P  u5 D4 B) r8 w& x# }# ]5 p7 e3 u% ]
        # 相同预处理的方法. y& {" i) s! N; y
        # 归一化. D, I1 N5 u/ W0 f
        img = np.array(img) / 255
    4 U2 v5 J2 y7 ~    mean = np.array([0.485, 0.456, 0.406])* e' c$ m7 C& r% ?. j6 S1 o
        std = np.array([0.229, 0.224, 0.225])
    3 h3 t# R6 W- \    img = (img - mean) / std9 H7 t$ W- T& S6 i' a; i' x9 u  k

    ' v9 b  [' e- N0 i; D- e: a    # 注意颜色通道和位置( G; ~: r$ r5 W. k
        img = img.transpose((2, 0, 1)): y8 y3 s; V0 E! E; B

    2 W0 K) m% Q# h1 o; X% T    return img0 P6 @8 P! o: M! H

    ' U1 l$ y, F" c7 x: {. h1 {def imshow(image, ax = None, title = None):
    ( Z4 @9 e9 k% n+ B& y    """展示数据"""3 F. K# [/ z# X" {
        if ax is None:# q: [8 x( m0 o0 R# |
            fig, ax = plt.subplots()
    ' O6 n. A9 T/ S/ Y" Y% o9 r  N
    6 L( J, c% Q$ R% }9 W    # 颜色通道进行还原/ d  ^" L: j0 h$ M- F8 s$ X
        image = np.array(image).transpose((1, 2, 0))' \$ Z0 M6 p2 }1 a* F5 s/ _
    5 C0 \5 M( X& r2 T) A" s5 M6 ~7 ?
        # 预处理还原- f5 j2 [& E7 z& L' O+ }2 Y5 e& V
        mean = np.array([0.485, 0.456, 0.406])6 c$ h7 c- v4 Y; T0 N
        std = np.array([0.229, 0.224, 0.225])8 K' V) U1 j8 D$ q! D* i6 u
        image = std * image + mean
    1 X' P7 j) }1 [1 e3 D9 u8 a6 X3 Z    image = np.clip(image, 0, 1)  q) F& x) |+ k- v3 R7 G

    $ n/ O& P# Q5 t: [+ b4 o/ G( R/ L    ax.imshow(image)
    1 P2 `  f* Q! \, h* j' c3 v    ax.set_title(title)
    4 _$ g: d: O9 j2 y- P. @4 A% F, U5 z: m2 N4 U
        return ax7 A6 G/ z5 J. |0 b
    - s0 @  f$ v+ q+ R$ N
    image_path = r'./flower_data/valid/3/image_06621.jpg'
    # x! v% E$ f0 Uimg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    / s6 O9 }7 o7 n8 e! f/ timshow(img)- m2 F( @* j( G3 d) Q( n' F

    : W' q: I* ^  f4 g* V# f; Y1) N& E+ r) m5 C+ b$ K
    2) }* S: i; W( z1 a: V4 U3 z
    3
    5 O7 Y( g- h, h3 D: m. @4
    . z5 W8 T/ z, \1 d+ l4 M5
    ) f4 ^4 d0 A6 P3 Y7 [0 p3 j' @6
    $ ?0 }( X- c% m8 e+ U- K7
    0 e' h; u6 r4 |! d( j8& `( V2 l0 s1 H( j, d/ X- u
    95 @4 U' n8 m% ^# p! ^1 {
    10
    5 `( j+ y" ~! w! w( U119 @2 t' H0 N* o& \0 K0 t- ^4 w3 q* K" C$ ^
    123 J- s7 s; c  \
    133 B# O) v7 l& g; |1 [  \6 H
    14' `: K. O* a# |3 h' n# h
    15: P) w, x/ u2 X4 e: M
    16) I6 g- R  P, b. T5 _
    17
    , ]( b+ m: W, j8 u" c3 [  d* T8 S$ @18' G/ k( A  n2 h" B& z: p+ l9 d
    19
    : |" O: f. z0 ^4 m5 @/ U20
    ) A9 I( k; O" p& m& v1 t21
    * P% ]' I; ~9 S% l4 q( w22
    % {  F8 A( {. u# ]- P* N- _23
    . l5 p3 T+ V# s" t# b* G8 o24
    1 c2 k; t$ [) s& {, C% v) O2 e! G+ S, G25
    / \$ D+ T) L$ {/ d' [  O! Z265 x. I2 R! g; e
    27
    ! M! J! e; v8 S1 K8 q- n. o28
    3 ~2 R% J: \4 ]/ r( Q. p29
    0 |3 j2 N  D/ d' _30
    * X( E; c& k- r9 S31
    8 L2 S# Q0 ^) h. q& r0 X' s" J32
    0 w9 k4 l8 U% R9 x33
    8 n9 P! C: c5 d9 I& F0 X34
    . P/ A% ]9 ~% a% H/ @35
    . l5 V3 ?2 z0 T4 L+ d" o5 e! _36
    4 d+ c! ~8 S" c/ a1 C5 {37
    0 R" O5 j# [- X2 j3 Z3 N38
    7 ]# ^2 |. Y  e39# R0 D; G: K# i8 Y
    40
    ( K& S) O/ x6 }1 }7 D41
    7 _/ ?: S: s3 a42. {: c# u5 Q1 B' e; V3 b5 {
    43. v: I- L- K0 a1 [$ E( w3 \" f
    44' H8 B& v- M, ^( l! K1 l) _
    45" ?5 P, Q8 i2 m' m6 k
    468 G! W5 }( c, U  L
    47
    & [8 K1 y& w# X7 K/ @8 Q486 ?5 M# T" X( m1 _1 j" y1 g
    496 I/ Q1 x$ e, v
    50
    0 Y6 ^" e" U3 h$ d5 m51
    $ g4 c4 S# ?9 f( g% \3 d/ T' e" U3 c52
    * e3 o$ S( _* s53* w- u; b2 u+ d' H$ {
    54
    - e0 u5 G3 P( `$ k0 k( ~<AxesSubplot:>0 R# ?- b* l+ L, w- W3 Z
    1' X; }' h. D: H: ^# S

    $ c% x& A' z& n$ {- f% D1 @, k7 z上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    $ |( H- @* `7 H7 B$ Z( z+ m. o
    img.shape
    $ w! q# c' K- R* i7 A1 k) e( s/ ]1# v1 ?% M* K0 e3 }# C1 l% D
    (3, 224, 224)
      a- C  T4 m& p6 C1/ O6 E; e" e0 p7 s- u$ N/ L5 e
    证明了通道提前了,而且大小没改变4 `+ k# E! u, B/ h; N0 i

    ' Y# k: g; m! ?8 X9. 推理
    $ g# H* y2 u- eimg.shape
    5 Y- C; }. G! Y% `  U' }4 u0 c. I& J/ D/ k3 m* Y5 z
    # 得到一个batch的测试数据
    9 |' a9 D0 Z4 x, }dataiter = iter(dataloaders['valid']); v7 Y0 o( A+ P, @
    images, labels = dataiter.next()
    0 ~  R. L& h2 b4 g8 [8 e1 C5 K1 n. [5 W1 A+ a. \
    model_ft.eval()
    $ D3 r0 P- j7 M. {1 Z( o. c$ _) I6 g; Q, v2 \/ x
    if train_on_gpu:
    , b' z& [  o$ z  T    # 前向传播跑一次会得到output
    - r# n$ u  H& R7 C0 r    output = model_ft(images.cuda())
    0 M5 j, H7 ^$ K- m' l' Aelse:8 p4 O" u) M6 ?: O! M, K) O/ c% Y% ]
        output = model_ft(images)3 ~: `/ Y5 c- T1 q0 N: `4 Q! W
    ( U1 o& g" n8 L/ B' K
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值( ?/ I  b/ p' B
    output.shape) _" @6 q9 S* X' m- U3 V. a% Z

    9 L7 S/ P' T* G2 X7 {7 D3 {( p1* N9 P8 V" E0 p, t7 _$ d* g0 t
    2% g8 {9 M' X* a8 ~5 {  j5 ^' u. S
    3: R& C) l" f9 z
    4
    & n' @% ~& R% D1 Y5: p" v- k1 u) L, E
    6. p) d" h, B/ F' U, a
    7" N3 r  z3 Y5 S# s
    8! k, |( w5 b5 H" }+ y2 D
    9
    % [  |; P- K7 V- e# H- Y! y10  m+ t' d4 S! L
    11# R$ K" ]5 v: O- B+ Y
    12  A3 m7 \! ?+ J4 n$ ]* b, m
    13! z3 g+ p9 V! C. G# J+ x4 S7 Z+ b
    14
    8 h3 Z8 {& Y: L# m6 n. q7 `1 `15
    0 P  O6 e$ d" y  N( ]# O: {16
    0 p1 {+ n( D/ Q$ v" w  Ltorch.Size([8, 102])
    , R7 ~# P+ P+ @2 F; l7 [1% J3 j$ K0 D  M; [2 V$ A9 l' K4 w
    9.1 计算得到最大概率5 g$ ^9 z2 f- n  p, E7 q
    _, preds_tensor = torch.max(output, 1)
    % r% t, e; k) |1 E' }" F' K- J! r3 Q/ {  n* n) {
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    6 z# m6 I; ^3 C6 _1
    5 ?5 H- {% i; f9 W29 @+ a' a. H7 a( I1 _; }# w
    3
    6 v, O+ p" k- O( g8 L9.2 展示预测结果+ L% d0 Z* n1 C. b+ F4 l& ~
    fig = plt.figure(figsize = (20, 20))
    0 }  ~; \- x7 |/ ^columns = 4
    . Y! K2 @7 g. O+ `9 ^, `5 d, hrows = 2
    2 M! T9 L1 C5 B0 B
    + {& w7 q% T+ n# h% {for idx in range(columns * rows):
    3 w  T7 B4 c/ W    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    & q% @' q* N4 w/ t; O4 g5 b- F    plt.imshow(im_convert(images[idx]))4 B+ k, I( y/ F: d$ s
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
    6 p( X/ ]: G, f  J, p6 b                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))" [4 d. ]  @. V& A
    plt.show()
    + M' ?; e$ G, a- h* ]4 k2 `- K( [# 绿色的表示预测是对的,红色表示预测错了
    " ?& t9 w2 w# ~8 e+ N1. O; i3 v) q3 c# _: H/ K' U
    2+ `! w8 ^: T6 s& g$ Z# s& q+ u
    3
    ' o3 x+ t) e2 f, j; u4# _& V6 e7 B  R1 b& M
    5
    7 b' w6 X5 }3 j! l4 _9 ]( l' b6& z( W# j! t# }* {1 d) K; g+ N' @. |
    7
    0 D1 C5 Q* ?" }' ?+ {8. C( F8 X$ o, V. X# X) l# x
    9; K/ U6 y( Y! x& d! n4 k1 _
    10* V4 ]/ i8 h/ l/ x' y: a
    11
    & V6 D: }) q( V1 C
    9 s8 W( c: m0 B' g) B! d4 y0 T; {9 l/ _; P  B4 E1 c
    2 `" t8 q# I  A! _5 Q
    ————————————————
    8 f4 _1 I/ F% W# t版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。5 Z+ x3 J2 q+ E% ^
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940& ]( H1 N$ w1 Q7 _% X2 F

    ! u* j% b0 P% X# j
    ) |* q( F$ `6 n* M. ?/ `( x' ?/ u
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-12 14:07 , Processed in 0.548550 second(s), 51 queries .

    回顶部