QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2743|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    ! y/ A0 C) g4 p. C1 y* {) N0 @. R6 U5 s6 W
    文章目录
    / V- H2 r, ~: Q  E, [+ E; c+ r卷积网络实战 对花进行分类
    ( O) \7 \4 j7 }* a+ U! \数据预处理部分
    ' N0 d0 l4 i) u5 G网络模块设置
    6 n! @3 B6 Q) R4 _# s/ ~' k: L网络模型的保存与测试
    3 G/ v, @' v% Q/ U$ o数据下载:
    ; u1 k: J9 r# K0 e1. 导入工具包4 \) G6 k) X" x, |: q
    2. 数据预处理与操作
    8 Q( h6 L3 h9 l: O5 e3. 制作好数据源
    , ?6 P* h7 W, ?读取标签对应的实际名字. S+ T" {! C+ W( n
    4.展示一下数据
    5 g" W3 q1 L) y) c5. 加载models提供的模型,并直接用训练好的权重做初始化参数& v5 \6 y4 J; r" y# v3 T* |
    6.初始化模型架构- E, M- R# J: w( V+ |# {( D
    7. 设置需要训练的参数7 Z8 a) A* B0 O" U" {
    7. 训练与预测; y2 ~4 E( _( K! [, h
    7.1 优化器设置
    . |" x# o! q' s7.2 开始训练模型2 K* o; {" X+ R' L3 B
    7.3 训练所有层
    $ x' D- \2 D1 I8 U2 n开始训练/ F/ S0 J1 ~% y  }9 k8 t
    8. 加载已经训练的模型& X: y# @6 B: W# j
    9. 推理
    5 k3 B2 o* b9 K1 ^. |9.1 计算得到最大概率
    3 D5 C# F" r' K0 u$ n9.2 展示预测结果' h7 _9 N) \3 F9 P3 [2 K. a
    写在最后% c, ~5 h9 n( n" f! n! b
    卷积网络实战 对花进行分类* L$ w8 s2 p) d& E% n# [# P+ b' ?
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    9 G" T/ z6 Y: h, O) l& \" u9 V, b& w
    : f! l; u2 q0 G- c- V在文件夹中有102种花,我们主要要对这些花进行分类任务. E: |/ n+ h# `
    文件夹结构
    0 U5 u: p& \& @" f. F, {$ T2 m+ H3 m9 C+ F* d# ]8 u
    flower_data
    " l1 n' |6 e; n" r/ K7 f3 S7 R
    ' P" f/ u1 }! {train- }$ ?9 r. y0 {- c; K

    1 s. E3 c* w2 ?1 d( M1(类别)
    ; P* u' g2 ^, n; c0 {0 A2& C2 A5 d8 Y- Y) D
    xxx.png / xxx.jpg
    5 \6 d$ [+ h' o0 o& ?- dvalid) t; {9 V( k) x0 r* M6 g+ \

    0 I) x5 d2 x* e- a" f( c) ~5 [6 W主要分为以下几个大模块1 \1 W1 S/ [+ C, r6 A

    + ~0 d+ T# \: {  R数据预处理部分
    - ^$ V% k8 r# P数据增强; m0 @" ^$ y& h$ m  ^7 k7 j
    数据预处理
    . X8 |8 x5 k) a4 T$ B' _: A/ j网络模块设置
    % y, w5 S9 n0 s8 x( A, Y! F( W加载预训练模型,直接调用torchVision的经典网络架构
      G, [2 E" c6 A2 l, Q因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务) m+ K5 H% n9 j- _6 O8 p4 a% E# E
    网络模型的保存与测试
    9 u! c! E! T/ K( F; A/ K9 k模型保存可以带有选择性, x" O# Y% ^) }( Y
    数据下载:
    ; R/ E! T+ ], K( p" G4 }9 zhttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    ) W3 V: R  J) s& G% Z- z
    & B# V* O0 U% ~+ Y% `2 w改一下文件名,然后将它放到同一根目录就可以了
    / |+ u0 j. A4 g( I
    $ ]" W8 B8 r  W5 B( o$ q, @下面是我的数据根目录0 v* c6 |) S: x$ s  i7 `! H

    # `* y; R+ F( v2 M9 @2 z3 s; q( @" s8 ~
    1. 导入工具包7 S* U; ]) p$ z* o
    import os! a7 M$ p! k! ?2 a
    import matplotlib.pyplot as plt
    7 D7 e# n7 x+ D% w+ X2 u3 _# 内嵌入绘图简去show的句柄& k4 B9 `* d; C
    %matplotlib inline
    - E/ Y# S0 r9 J# w2 L! Dimport numpy as np
    / P2 j) P/ b' L7 m! A+ Wimport torch
    ! M: S8 W8 I$ R1 y0 {from torch import nn
    0 O% Z1 {# s- [. R5 s) O. D( |& \5 u4 `. ~0 [! K
    import torch.optim as optim, q1 r0 R0 B' c  B" z
    import torchvision
    ' f2 D  g; d- A# t" s0 ?, ufrom torchvision import transforms, models, datasets
    ; P$ a! R4 E+ u* F( _( {; d4 z) f- S2 S2 O
    import imageio% k6 T0 _; U+ q; c  j3 G6 F
    import time  i! }3 S1 S7 c* L1 z. S
    import warnings- z9 w# o& x( Q. u
    import random
    6 _( I# U1 T4 w1 U" w' r* L8 eimport sys
    " b$ S. s7 ~" f- aimport copy: F1 j2 h) m6 g* G: J2 f+ h
    import json
    ' C0 y6 u1 k! I% m+ kfrom PIL import Image
    ) Q0 P4 i8 s3 O  J
    " |/ ^( g2 d* x+ x9 n
    3 f1 I4 A% }; B$ p* B1 a15 R. j- n( {: G+ x+ C( x- t: W
    2( B! w) l+ i+ Z3 n) R
    3$ k' {2 E  N' _) R4 X3 `
    4
    9 w: N% _" n& s) }) `% U4 I: b5
    8 G8 j: M. N" M/ _61 h/ _! q% H5 a* W; b* _: a" C
    7, V: f; C) l* b# p
    8
    & g. m, {( l8 h. K! h( G6 E95 x  Z' I- `4 S/ h" L" d$ Z
    10: U0 b6 a2 U/ ]7 C) f
    11
    + R7 `# z- b7 k, I6 ]+ j" S# v12
    - }% C6 ^3 _) {1 }9 e13
    # k/ n$ E2 q% n, ?; a5 W1 g145 ?8 `7 T$ r1 W" ~
    156 [& i8 M7 ?8 \/ h( ~- K
    16& j4 Y. _2 Y, x! r
    17: q0 _8 T% F: G& ^$ t# y
    18
      J: D, E2 S; t8 `; }, c! B19
    4 Q: s9 r$ N1 s3 U4 a20& v. |6 C' R: M4 T. o/ T
    21+ ~5 s, y8 z5 k6 c. [
    2. 数据预处理与操作1 [$ B  c4 K: o( _# d9 _. A' l7 r
    #路径设置
    , L7 O8 O! |$ O4 i4 n$ }8 E" |data_dir = './flower_data/' # 当前文件夹下的flowerdata目录5 E  U* n4 c& _2 L; U# X
    train_dir = data_dir + '/train'
    5 h7 C3 \6 X. j# uvalid_dir = data_dir + '/valid'
    7 P' f  t7 Q! Q3 e, q1 X; Y+ v1
    6 L3 h3 G# O* U5 U7 G2' B6 g5 K' a/ s) A# X
    3& A, ]' W7 h' k4 U6 `; d
    4
    ' X$ L9 R: L7 e3 h; w! Npython目录点杠的组合与区别5 A* a5 j% g; J1 b5 a8 u& |
    注: 里面注明了点杠和斜杠的操作
    * t+ Y: q# E( x  |  C' p# s9 B6 ]2 D! i! m  A* Z- _, Z
    3. 制作好数据源
    ) ^8 U2 D3 p0 Kdata_transforms中制定了所有图像预处理的操作
    5 m0 R2 v3 q+ a- h7 K) ZImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片0 d; J1 d6 S8 G% ^1 S
    data_transforms = {! x5 i3 f7 `9 }* \# h' a: U
        # 分成两部分,一部分是训练
    - R0 R2 @+ v" g& s    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间' J! R0 V: ?8 I
                                     transforms.CenterCrop(224), # 从中心处开始裁剪; L0 z7 Q) t  R! v, Z; j
                                     # 以某个随机的概率决定是否翻转 55开( q0 E" @% u6 j6 ], K. {! ]
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    , L& C+ l0 x6 f0 r                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转* J6 w# v- Y1 X6 @! d( l+ V, d
                                     # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相9 R  d( y& E% r
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),* K2 E2 Q! r5 Q' G6 u
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    & J$ Y: x/ A9 w% P% A                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    5 r7 r" a: r3 O5 e  t$ x1 X2 A                                 transforms.ToTensor(),5 M/ [$ R# z& c  T* _
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    4 x) w9 M8 X9 ~2 }                                ]),
    $ ~: P/ ^* \( r    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化6 K" w2 L. f. W0 J" w
        'valid': transforms.Compose([transforms.Resize(256),
    ) n" q. `. w+ f! C                                 transforms.CenterCrop(224),) \( n; z! ~( s8 w$ p2 a. S2 X! ^
                                     transforms.ToTensor(),
    ' f6 N9 w! E4 x" x1 H+ @8 H' W                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同& C4 V  C1 h7 N3 ~
                                    ]),
    1 T+ ]1 b& p6 ?8 i- n% |}
      R# Q# W' G" I4 l; h, c1 P! W, d# v0 d
    19 {+ \4 H5 o  X% y
    2+ Z* m. ~* X5 h+ q$ a1 j
    37 Y7 v2 O+ p' A& d' N0 P# g
    4
    / p( U' S! C1 K+ |5
    % c! h- T5 C8 S5 o6 i; P6' O) H# G" o. z3 x2 ^
    7
    / y3 ^. C5 H+ F5 Q" M3 i8
    2 C+ t- g0 K* j2 M9
    ! }& S) p1 C% ^/ O8 r6 }10
    4 q9 X! |8 T7 x* h11
    8 S: m+ ]$ I# I2 M8 L, O3 Z12" w# k/ T* r! ]% g: d- x3 P  C  d
    13
    3 J. V& `3 n# t. H" b3 E7 k3 t2 P% I5 E14. D8 h/ b/ |# s2 z( f
    15& g3 |7 b1 g! t5 I' h9 L
    16
    . P1 J- H; y0 I  k' d7 S0 H, V; c173 L' N0 o4 v4 c# Q' T- v
    18
    1 j9 T  v8 G' f. ^2 G4 v, K19* x$ ?+ S9 V5 Y: S
    20
    : O! n  C) b' G2 U21/ W5 y( \& U* T7 }; L7 N, \+ P
    batch_size = 8
    9 j. Q- g7 ]8 W8 F7 U- T# aimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    8 `9 i4 P7 e2 q) `3 I1 e/ T8 b5 sdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    + c2 o! J$ Y, `0 S2 j) i. `5 ?2 gdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    6 t) k! N" j" i, M) t, D6 \* vclass_names = image_datasets['train'].classes, f& O3 E; `( ]! \4 }1 O

    # p2 b/ K, g; \) t/ L#查看数据集合4 w; v& p; j1 ?  `( T7 i) n& }
    image_datasets
    # l* Y% \7 Y2 T" C& g  J% Q* V
    1
    3 n. ?% ~) v8 y' ]2
    % ?+ i( Z5 G9 J3
    " F$ k& J% i6 [6 N0 ]4
    2 u# R0 c2 Z! V$ j. d. ]56 K% \' o+ M. f) U( ?
    6
    5 G$ w0 b2 _) {7 }9 h$ t; T7
      f7 b' p! |3 _8 R4 L+ f- S& j8+ F0 q$ n" W' p/ v% X; e
    9
      ~8 z" f- n1 M5 Y1 d* ]{'train': Dataset ImageFolder
    / `4 ]3 O5 j( [' j# l; J     Number of datapoints: 6552# N! Q8 b2 u: S; C) t0 u! |& [* Z
         Root location: ./flower_data/train& L& y) O- h, I: q- u
         StandardTransform
    4 G, z* Z8 A& O$ A Transform: Compose(" \" ?( e) g; \+ E8 d
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
    8 W6 b; G$ k5 s1 v                CenterCrop(size=(224, 224))* W( r$ G; y/ m7 ?6 m
                    RandomHorizontalFlip(p=0.5)
    * M$ a3 l# z  k# o( V( ]: J                RandomVerticalFlip(p=0.5)
    1 I% b2 G8 g$ x, _& f; N6 y                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])  I1 l2 i! V! U, f
                    RandomGrayscale(p=0.025)( |. d% ^0 c  \3 W
                    ToTensor()8 b6 ?& q; A7 X& t2 r1 p1 D2 M, c
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ) q& X8 H6 R$ H            ),
    ' n( |& I0 ^0 J' ^# q 'valid': Dataset ImageFolder1 {& N; a( P$ X0 ^2 T% C
         Number of datapoints: 818
    9 y, o& P$ y/ w' S& e# j; W" n     Root location: ./flower_data/valid
    ' W. B1 K2 ?, z, j& g7 o* \/ @/ z     StandardTransform5 l% y# x4 E0 J  Q6 \2 h
    Transform: Compose(
    ' U& _/ E: S9 L$ s1 z                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)' o  R+ C* m+ E: p. U& u
                    CenterCrop(size=(224, 224))
    9 i2 K( T7 q2 F# B4 g( ]: m7 j                ToTensor()
    ' |# r, e) f$ Z                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])8 V; t% O# ]% b2 ?
                )}
    7 L- a2 K) l  N# M7 O9 c9 M
    5 }. p+ A' o' y8 j& k" S7 K10 W3 j" f3 L4 |9 P5 {$ P& b8 Z& ^  B# K
    2  l2 g. Z5 Y% ]
    3" e7 i4 i2 @' S1 M9 V1 X( R$ @
    41 P1 r  S7 B6 M+ b& q3 U
    5
    7 [/ t! ?1 Q* B; [6
      C% c' a1 a& }2 ?8 ^8 t0 U7 {- F7
    + ~  i& n* O8 W# \; G80 P% s6 J6 k- b7 n" t& b, O
    9
    ( W+ |% v- _+ S, T10
    $ [4 I5 N* u  p) |11
      V% Z! r' v# J  W+ Y' x2 B) J3 n12
    ) m/ q, h& D6 M. C; X13
    , ?2 ~: W) U- |6 r7 x1 T14: K  P2 L7 X, B5 w8 e0 [
    15
    : ]* r7 \1 m* ~, K% F16
      r% O4 C2 y4 X& y) e172 q3 e% t: F2 O! E' B& D5 z
    18# j8 P/ k+ Y8 Q1 K) H. v
    19
    / M/ Z+ g. {( e0 a20
    5 Y9 Z% X3 l' {0 W- I21
    ( S8 k1 ?- Z9 G: \+ v+ \/ N7 H22
    4 q+ e: R( c3 x" Y235 }% h3 D3 }, V. C
    24
    6 I' ~1 N# y+ Q2 r7 R: f# 验证一下数据是否已经被处理完毕6 c( M" [3 H/ j
    dataloaders
    ; `- H+ k! h) `$ p( U1+ R, E' S) z: m
    2& u; T- V$ X& ]% Q) w" `  p3 O* F
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,7 S$ v# _0 R6 u" V' x3 |
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}9 {- K; o8 G3 U( o2 Z! g) g
    13 S) Z; X# e: Y# p
    2
    7 e. ~: F* {6 V- |* \6 |dataset_sizes
    . k6 K" o2 o4 {9 z% @5 x1
    1 o+ J9 q) c1 l5 j) D+ _; g{'train': 6552, 'valid': 818}
    ; e& H% {) j/ A% ]; E6 v9 I1
    $ S8 X$ G) c* i$ K读取标签对应的实际名字4 z' N- O2 [( x! }1 B
    使用同一目录下的json文件,反向映射出花对应的名字8 o: v. P& ?, {
    ; `- V+ `. ?4 P/ Y0 T9 }5 W
    with open('./flower_data/cat_to_name.json', 'r') as f:
    ( {! @4 M: p/ {& y    cat_to_name = json.load(f)
    9 b; S1 [/ {- ~, ^8 P1/ x! \3 p7 G3 O% Y4 f
    2: \' `: m' z1 W% N% y
    cat_to_name
    + S1 p/ \( u7 a1
    : g, o8 a: \$ K! h, J( N/ n{'21': 'fire lily',% b5 {3 T* G/ s
    '3': 'canterbury bells',
      q' f1 p& i% |$ x  j" ^ '45': 'bolero deep blue',1 G. T" \) A) E% \0 v
    '1': 'pink primrose',
    & N) ^2 n% E' V! L/ N. Z. h '34': 'mexican aster',
    , g' Q# a3 h# Q0 B  q/ S. p '27': 'prince of wales feathers',& n$ \( {; C4 }% P& H  o$ ?
    '7': 'moon orchid',/ p" x) u7 S3 a8 k! m* f" k
    '16': 'globe-flower',
    $ }3 c9 G  y6 h/ ?. Q7 f5 T '25': 'grape hyacinth',& @3 J7 W! Q8 F2 r- P4 y& d
    '26': 'corn poppy',
    ! J: q# ~5 y6 \$ B6 _* j) K( ^" O8 \3 Y '79': 'toad lily',$ l5 N. Q- R. D" _; g9 I, l% }
    '39': 'siam tulip',% M) d$ E4 f2 a( c9 P3 S- p; [
    '24': 'red ginger',! w6 `. ?+ N. X! k) e
    '67': 'spring crocus',
    " ?: Q7 i& J. D+ }1 h '35': 'alpine sea holly',
    8 |% D# Q' |' \5 _+ X7 l' j '32': 'garden phlox',
    6 P- k: A) T- h4 f- G) e" I4 n '10': 'globe thistle',
    0 h$ L. w3 s$ ~" f '6': 'tiger lily',5 I+ H* |4 C/ N; o
    '93': 'ball moss',# V: g# q0 o1 C
    '33': 'love in the mist',4 t9 P: m1 _" F- U) }9 ?
    '9': 'monkshood',0 r# }2 m6 K) X/ O
    '102': 'blackberry lily',' S; j! K; b5 I+ \+ `9 s8 q
    '14': 'spear thistle',5 v, }: q# {1 @# B
    '19': 'balloon flower',
    8 K. _$ h8 g  f0 i, M! M: B* P& f/ M '100': 'blanket flower',$ F- {  z$ d7 n
    '13': 'king protea',
    ) J& D0 ]8 U: q6 y% O, b! V '49': 'oxeye daisy',
    " n* U6 o9 T( d '15': 'yellow iris',5 y& S0 ?: V  @: ]0 D1 }& [
    '61': 'cautleya spicata',
    5 h0 E0 c2 V* P8 @$ Q8 S '31': 'carnation',
    / {+ U+ R$ I( q '64': 'silverbush',
    5 j* `- @) t& w  d1 ^. Y '68': 'bearded iris',9 w- N" ~4 z( p( w; R6 L' \
    '63': 'black-eyed susan',' Y2 s9 y6 S, N# T7 Z; z2 A
    '69': 'windflower',5 f9 |' @- q& M7 p9 k/ f
    '62': 'japanese anemone',
    ( J& g) ~( }+ M '20': 'giant white arum lily',
    ! C& \0 \+ G: V4 N '38': 'great masterwort',
    ) I( t- @% H5 Z" ?+ m. K4 Z '4': 'sweet pea',
    # D5 [& L  o9 D; X, F '86': 'tree mallow',6 z: y" p, Z  X, F: l! R
    '101': 'trumpet creeper',
    ! Y. |6 k% k' k '42': 'daffodil',
    . n' T: M, x# l1 c$ c, H8 z6 } '22': 'pincushion flower',6 P' t% [1 G* c
    '2': 'hard-leaved pocket orchid',
    1 }6 _+ {8 ~# y: d '54': 'sunflower',  W3 |3 l- s' u: D: r& S: a
    '66': 'osteospermum',5 l4 G2 A' N- j% f/ @
    '70': 'tree poppy',
    % @8 W0 o6 t& M* u' N6 ` '85': 'desert-rose',
    % c. |- U: v3 K& e: A" q. e1 _ '99': 'bromelia',
    - Y2 Y' K* a/ t7 H '87': 'magnolia',
    : G3 ]6 b9 ]- I" S) R '5': 'english marigold',
    + C+ L& Z8 G2 D3 u '92': 'bee balm',
    ( G% J" j& q& w9 g '28': 'stemless gentian',
      H! y" A5 F1 B8 Z7 U( j '97': 'mallow',
    $ ~" k5 [1 f9 z: T. I0 b  H9 d1 `/ v3 k '57': 'gaura',2 d) H) p3 x0 Y4 K- X) \
    '40': 'lenten rose'," L1 \' _0 W( j) S* Q" g
    '47': 'marigold',8 i& _& q5 m! |
    '59': 'orange dahlia',
    / Q% P: e+ w$ I, {' `6 w* C '48': 'buttercup',
    . g+ I' e3 i6 J7 F: [6 D '55': 'pelargonium',7 h6 h, z+ D# A4 @+ ?9 p
    '36': 'ruby-lipped cattleya',0 W/ c4 r0 |0 A! a/ g3 P
    '91': 'hippeastrum',
    0 I' l# v. y) h6 e; [ '29': 'artichoke',& |5 B0 r( U+ O* w" r- T0 U! |
    '71': 'gazania',& |+ ?+ s" s% V& l- R& O: k8 a. L
    '90': 'canna lily',
    ( T  |( M7 D+ t* X3 ~: C '18': 'peruvian lily',
    ( V. k1 m* b5 u '98': 'mexican petunia',
    : P2 {4 O- a' p! M, \1 v2 o; \ '8': 'bird of paradise',8 g' V. Z; N' c- v( A3 y8 K- Z% l
    '30': 'sweet william',
    9 t/ e. U0 n& O '17': 'purple coneflower'," I% d- w( z. H" p
    '52': 'wild pansy',
    / W* o- M9 h. i+ B5 Y '84': 'columbine',
    ( {% g" _: B, B1 A. b '12': "colt's foot",/ [  P# l4 h; S" x4 S
    '11': 'snapdragon',
    1 v) o  [# }% v% ?$ ^- a '96': 'camellia',2 w) ]& p! p  ]$ q" s" R, w
    '23': 'fritillary',
    9 D  S% S5 f; A; s  o* z '50': 'common dandelion',7 L. y$ @- I. }
    '44': 'poinsettia',3 _# _3 P0 y1 T7 D& w
    '53': 'primula',$ v4 B+ l0 u! j6 m
    '72': 'azalea',# T! [& f, ?; v1 d6 D- K
    '65': 'californian poppy',
    7 [* c5 z8 {6 O* D '80': 'anthurium',
    5 Z$ i9 U9 m/ J+ Z' v '76': 'morning glory',4 a1 r9 T/ [: e( l' r
    '37': 'cape flower',8 M1 U8 s1 V9 z6 A/ O7 u* x& v' r
    '56': 'bishop of llandaff',  u$ \, C3 E8 X  I# @
    '60': 'pink-yellow dahlia',! S# U: M' Z, x: I
    '82': 'clematis',
    # \1 m8 \! d5 w '58': 'geranium',4 S1 }+ [; g0 T: l7 h3 w/ `
    '75': 'thorn apple',5 o! }( g  D( E. b* C' [" g
    '41': 'barbeton daisy',
      c" D" }  Q& x& t9 Q' E '95': 'bougainvillea',
    8 P8 L  Q; ^2 T4 c4 q! X* L '43': 'sword lily',  o/ @* o& B4 O' }9 a8 u
    '83': 'hibiscus',% t) S. d9 E  x) P+ G& }
    '78': 'lotus lotus',
    5 m. y3 s# D# h' q$ z9 ?- f '88': 'cyclamen',
    % g" v# F5 }# b '94': 'foxglove',; j" A7 P2 Q& c- v2 l8 }
    '81': 'frangipani',
    : ^: M( S& v6 x* G+ s '74': 'rose',
    3 t0 d% x; K3 C1 R6 ^ '89': 'watercress',
    . {  e6 N1 E/ ?- k: p3 a '73': 'water lily',% E2 K' c" v5 ?3 I+ j, O  h
    '46': 'wallflower',
    ' p, V+ |1 A; U2 d/ b+ B '77': 'passion flower',& F3 ?/ r7 ^, E' r
    '51': 'petunia'}: I% _. R  P, l6 M, V

    ; O2 k9 a0 J; G3 W' N( b+ m1
    0 P1 R7 u5 k6 n" L& r2
    & p. y" V% P& [) ]* E3
    0 s! _3 [7 Y! V- v$ @4: ?3 U& K  n! W; t
    54 V' _4 ]( l# U, [* o4 i
    6" U4 u, r% E0 x8 a# T
    7* G4 @) @1 @1 K9 O# E7 @
    8
    ) S1 i5 B* F$ P  D9
    ; `  k+ ~) M3 E) l' \0 ^  x$ x2 o10
    1 q" [) @7 M8 H+ x5 d( {11% I4 S* b' u  `) A
    12) H' H5 S" P3 T  K6 m' Y: J# s
    13: K$ Q: h# }2 U) q( d0 O7 Q
    14
    8 `" j- K, t7 M- h7 O# P% ~$ o. q153 `, w" m3 N4 ^$ E9 g; @
    165 V6 M7 Q+ M/ s+ _; p7 s
    17# M9 E' |- |4 [$ i# m2 U
    18/ W. p0 P2 g; G
    19
      l$ H  b  o6 x+ X200 r" I9 J" r, Q( E1 R* s2 e
    21
    6 e( k* i" X  v  P' q22# h2 i* E: j% c
    237 S' A# t6 A& A) i# T# a1 l# c
    24
    % R0 M2 X; J2 C25
    " g: @- m4 d/ t2 S1 W9 y5 q+ ^26) B; |' i+ s- T) w: b  e2 _/ I' q" o! |
    27; b( H9 s6 y6 c( r$ K! r  d
    28
    / g! h* F+ {/ d3 Q" _1 V29$ a5 t- l; \& u8 c" T3 _
    307 q3 ?+ \# ?8 B9 u$ [' M* x* ]
    31
    " z8 G, R! ~, J1 O0 P0 t32) x# l5 C+ H1 g* n2 g- a
    33
    5 E2 y1 o+ r: _4 x34
    * r/ l, d! G. j% t353 I2 |* |: h& Z4 x$ c- @+ b
    36) D& _3 |8 u2 U5 C( P/ c
    37
    ! U( i% T5 Y( R38
      i, t: s4 b0 }6 n391 b7 E; N: C" I9 B
    40
    % i. N9 j( B: g" u3 Q. ~8 D8 c4 |419 P. P; E8 ^# P% K8 u- y
    42% |; M0 m( N9 F3 m3 r1 z( t
    43
    7 `% ~: Q' V* }# Z441 @9 s( `& _$ y2 E3 Q! e
    45
    , ^& @1 E, I0 K) }7 z46
    * q5 y9 e4 Y% b$ }477 ]: U3 Z- C; b5 {# D( t
    48+ m/ v% S4 R8 U3 M( O
    49- D: T- S: n1 W' i
    504 f# X, X  B3 e9 E6 g% T* v
    51
    2 k9 }/ _- E- C52
    - M1 y; s  m) a9 i  {) z8 @- L53. p/ h' Z1 ?/ ?) E1 A
    54
    % g$ P) j( c. ?55
    ; ]0 ~0 \6 L+ E" Z' y) f6 n56
    ( k' l- M! p2 O+ x5 V& j578 X. ]& s6 T3 u0 b) e  r2 ~
    58
    ; r% o1 F$ ~3 H  N59
    * ?. M2 z5 {2 Q  q: N6 k60
    0 v7 y2 q7 g" [( Q& {1 w61+ J7 p3 u& A! B9 ?- M3 V# z
    628 U: h( p6 f5 x' s: ~; d5 n
    631 f& ~5 p. i: D3 N, `
    64" ]' M( o" ]5 p
    65
    & Z* ^7 w; n$ v7 p) T- e+ u66. a' S9 C6 @  O) c6 v
    67
    ; h) n+ q0 ?% p4 i68
    - F7 L! `6 c, N! b0 D69, r- ^6 `) U7 _6 _* ]5 D
    70
    / E5 L* z* n( ]; _' G& R71. q& Y( A7 T% F; E1 X5 r
    72
    + z$ d0 z9 ^  ?! w735 I$ w2 l: A- c" f! B
    74
    ! C6 i) a& o9 R/ }' s75
    $ I: R* V) [$ U. m; t- E76
    - i: D0 I( a1 Z77
    ! w7 W  W) j- r9 H# P! O780 R4 i# H/ F2 A
    79$ m/ _! m& p( ~; r/ J( y5 _7 c
    80
    9 l: g! L) ]/ V) u& D# m819 r+ e: E9 S2 S5 P* F
    82
    " V; g# ^) C. o83/ C* Z! J+ x. o; h% {) V
    84
    9 Z! D9 z2 ^5 H3 X# u) S85+ a+ b; @1 P0 v$ y( W: I
    86+ q& [7 i2 v6 D: n; |
    877 `- E$ m" h9 p8 S, T+ v8 X
    88
      D5 p5 }# }2 Y/ c' O+ X) o89
    / t0 `& Q+ P; ]90
    + ^; |4 i* t+ A( G91  Y- u6 ~6 `4 F. G4 {2 @3 b; T
    92: Z# E' E% c5 N  U* X6 R# I
    933 f) y3 N4 H6 ?8 O
    94: U5 r, \: V7 v
    95
    4 o; ^% u0 j- a+ b5 U' S$ ]4 q7 ]965 U: ?4 u. n: d" ~# e
    97
    6 S& f( i7 V$ [! [+ J" N2 `98
    ! L' u3 W, c$ }" ^; x8 q9 y99; S2 R; \$ i6 F2 P3 [  q" ~
    100
    5 ^+ T: B. c* W. l$ n1 E101& S: }# f( a& W0 G; g
    102* {% ]" L% ^2 M; f# j& H
    4.展示一下数据( t; m/ Z9 q9 [
    def im_convert(tensor):
    ' I7 v6 j7 g+ Q    """数据展示"""
    . Y+ [+ Z- V2 D5 f    image = tensor.to("cpu").clone().detach()
    ) E: H# ?% K3 X2 m( W    image = image.numpy().squeeze(), U; g; z% _/ z/ Q  {; \: t
        # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图! M' M, M. i' w5 a- S
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)7 j2 ^+ {! P1 ^  E  s7 m1 C% T5 f
        image = image.transpose(1, 2, 0)0 n$ m( J& c$ R4 h; {$ x, `9 u
        # 反正则化(反标准化)
    + G5 q1 `1 a; A0 m1 ~  |    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))9 X/ ~+ H' a; ]4 `

    8 Q1 E+ j: m, V2 Y3 T) s* g- P7 d5 c    # 将图像中小于0 的都换成0,大于的都变成1
    2 `2 s/ C: {& p$ A    image = image.clip(0, 1)
    : k1 e! j, f+ G1 [$ p1 G5 ]1 C  K9 R  I; K
        return image
    $ `% i( l7 L) S1' c, G4 J0 V  W' @' O: f5 F( \
    2
    2 m; Z0 P, g( x* D0 }% u3  u: k- R& d. T4 N
    4
    3 ^0 M# L& f: d4 h5* I  J' t* j1 j9 ~. D! k3 F+ t
    6
    1 |, w  B& F# G. m" R4 G5 s! K+ H! O( V7
    # R  H2 m" A4 b) j1 Q6 i8- d9 \% W3 W$ R* e# d
    93 d7 r/ l# T! N# f7 u
    10- X/ h' I3 @* T- i% T5 {
    11- |  T% c9 ?6 G' a7 F4 A
    12, P  ~. _1 F7 c+ m
    13
    % r5 t6 H( f  |14
    3 _. d( ^" P% x. P: @( L# 使用上面定义好的类进行画图2 p2 U3 L1 H" M
    fig = plt.figure(figsize = (20, 12))7 C2 a8 Z5 z% q; W
    columns = 4
    1 J* G/ _( e$ x' D! K+ T* Mrows = 21 }# W- D  E  t' w" @3 t; l1 \! @

    % ~4 {! j: h" @( O# iter迭代器
    # z; o; c$ d- a1 ~7 D* T, S# 随便找一个Batch数据进行展示
      j( J2 f& X# cdataiter = iter(dataloaders['valid'])
    1 F" i8 G* B  o4 dinputs, classes = dataiter.next()* E+ U0 L6 K2 ?

    % T7 t; o! I6 N; m9 A4 n2 yfor idx in range(columns * rows):4 U! ]8 Y6 }4 e+ J0 e% q+ s+ M9 s) K
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    0 O% A& W6 H( z    # 利用json文件将其对应花的类型打印在图片中0 q9 J4 X* R2 D8 r
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])) J% b/ `0 u) Z8 a7 K7 D3 a0 ?6 [8 K
        plt.imshow(im_convert(inputs[idx]))
    - {+ V0 y' F0 o7 T- O) U9 Aplt.show()
    9 ~2 n+ l# B4 {
    1 h) G, r; Y' P+ Z1
    7 ~6 T+ J$ S) Y4 m" b2* u& O* I; s% T# Q$ n
    3
    & G% u, T0 X4 c' K( I4
    " |# ~/ P7 O# d7 b$ B3 v9 b: D$ z& X5
    6 S( p; X: D$ i- u# I: ~/ H6. V( p! r  Q+ C! W6 D3 P6 U! ~
    7! e$ I/ w4 n7 i# b- `" X
    8
    ! o& d4 o6 |1 h3 @9
    7 Q) d. g5 d4 X+ d) v* E10
    3 Q2 G, a) N; E" I11
    % t$ P( K# @! C1 T1 ~# M- G# t12# t4 o& [% k5 }  [  ~' [+ y
    13$ X" d. B0 J/ n9 h
    14. K5 Y) g9 k6 A% q! p3 j7 |( c) e& T
    158 A* u; z" ^. Z1 o
    16
    " W9 h- U1 A) j) _: J
    ; ^/ J9 I- n' r2 L0 @3 Y; a
    0 g! D- Q4 }; ^. S" h% W, B: r* u5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    - z  ]) h- d% P- @/ C: `model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    7 c# E! D( D% C) W2 s# 主要的图像识别用resnet来做# I' i0 p+ j( r0 H7 N
    # 是否用人家训练好的特征
    + {( ~0 {  A/ j" _feature_extract = True' M+ R2 b& d# n3 R6 f- P
    1
    5 _" j8 N' F5 Z! K1 W2
    : l" V; @; o8 w8 ]6 V+ L32 _# J+ o) ~- @: r
    4' N" w. f1 B% Y) U3 t
    # 是否用GPU进行训练
    ) F* S( ?8 t3 L, v7 i; r8 I5 h: S) gtrain_on_gpu = torch.cuda.is_available()6 c& y; c' s4 w) P& Y) L
    4 ^1 c6 |0 k; ~+ g0 o. j
    if not train_on_gpu:8 t: H& @/ f/ s# `- E6 n( M, {
        print('CUDA is not available.   Training on CPU ...')
    3 h6 d$ h, \6 c# t- Z- A( Belse:
    0 w( w- e4 q  f% T: {, I  }  |    print('CUDA is available! Training on GPU ...'). [9 k  U7 _$ |

    6 F" `! I& t1 i" _device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    ; T7 I8 B/ F( ?+ J9 _( n1 d: R1/ ]( S( T$ ^3 J& P3 Q# N: O
    26 `/ H  i. o6 e$ U% ]% U
    3
    3 X# ~5 b- x, Q* ]& f. \, e4
    + w" P1 F; M5 K( C. q  N5
    4 D6 U5 e2 [1 F! B  @, _5 m6: }8 h  `8 u$ M, r
    7
    0 M/ }3 l1 _* P1 a: z4 z81 S4 U7 W/ n% a2 c) {9 e
    9# Y8 ~. l! D! Z/ b+ j1 ^8 D
    CUDA is not available.   Training on CPU ...
    ; k  [, {; s# s) I9 c2 U8 y1. }: Y0 K# m/ x$ J$ C, L" w
    # 将一些层定义为false,使其不自动更新5 V% W( j. M* N: j) R) o
    def set_parameter_requires_grad(model, feature_extracting):/ G2 r- s5 \0 E7 Y4 C- F
        if feature_extracting:  |! ?$ A6 r/ ^7 y3 f
            for param in model.parameters():% r% a0 v! B6 y3 o: M
                param.requires_grad = False- G: x9 e& e* r
    1) N0 W& G/ f' h9 @
    2* V/ n; j4 V, @$ C
    3) _9 d, K% w0 Y4 S* O8 r3 e' D1 x: k
    44 G/ C) L- N  s9 i8 P1 z# {, d- j' Z: }
    5
    & O9 e( a3 [/ O! F# 打印模型架构告知是怎么一步一步去完成的
    " L  @5 N  U. o% i1 e# 主要是为我们提取特征的- G% f, {2 G  @0 m$ U
    6 x- O# [4 N& L& D/ H8 @
    model_ft = models.resnet152()
    ; \9 P5 V1 A$ o& ~1 k% H$ Hmodel_ft
    9 I# Z( R/ L/ X1
    - ^  o7 [, r$ d3 V4 n2) _# K) ]) E9 n; F. E: h# Y0 ?
    3
    ; \0 f1 B- J& O+ [4
    . K7 X- `) o$ |, q3 u$ A5
    - [% Q4 Z$ d# f6 _4 L! z+ cResNet(
    1 `  F4 d! G( i( u% s5 e0 {; X5 z2 z  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)7 ?+ @9 l- f3 `& R
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    # R; I8 b8 a' r7 N  (relu): ReLU(inplace=True)1 x2 H- w) L. l# [' J
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    - `( s5 H" o: }- S( I+ p4 q  (layer1): Sequential(
    * z1 k' {: u; D9 f8 P% R9 k5 n" ?3 g    (0): Bottleneck(
    # S1 \2 j" e1 u' a: h: T8 m      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    6 B5 {( s7 J( t' A: I0 X/ |      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)2 G8 A9 z6 A7 ?8 f3 ^2 \
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)* o, J5 B" q. i7 K8 l: Z$ f
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    % ?' n& k) V: ?) H* p  E5 i5 h      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)9 F( q- j) R; J4 }1 }1 v7 M
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)+ w  Q$ ~1 I4 v  M9 T0 i
          (relu): ReLU(inplace=True)( p7 W7 k, W  i, I2 f( V4 X
          (downsample): Sequential(9 ^6 b1 O4 C; `1 I' ]$ ?1 ]$ k
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)- c/ Z5 b2 H- l( ?9 N$ m
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    0 u$ U4 n% I1 _( J      ); I0 s9 @5 N" N+ r7 e
        )
    - i* r3 y  M0 a+ h: q2 I中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    # W# R) Y" u% z) p( z- B    (2): Bottleneck(
    $ P# k+ d2 V! K9 V! L3 x( m      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    2 P% i' g, Z6 P7 _6 {$ E9 y" |2 S3 Y      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True). _2 e5 Y4 d- @0 g+ T- C
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    " r+ C' k5 ^1 n# W% W7 V      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)0 E$ `! ?  K- V$ h
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)) J) f( m' |' u, h7 E6 `
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    6 Y+ Y, b2 W, O/ V4 ?$ r      (relu): ReLU(inplace=True)- ?- p0 G) Q# @, {" N
        )
    ' z+ j0 K1 E2 L8 o  )
    " y3 h& |1 T7 O# |  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    . Z0 c- P8 R$ @5 {/ P  (fc): Linear(in_features=2048, out_features=1000, bias=True)
    6 i1 E6 e) I& `4 G)4 h# y9 \3 d# X4 t5 w- }* y
    5 P8 z3 \9 E0 z9 x6 A
    14 g/ I& c3 V* d- A) Z( F- t0 a1 C, W
    2
    / z- M- S1 p3 s% N' }3
    3 [7 ~5 q# \; Z; l44 S: D. q, V+ y& b; }7 l) ^( }
    5
    ' w/ f7 J. v% b& K+ x1 T6  _4 a) x; p' J
    7$ ]/ t/ l9 b/ N% g+ r" v$ `: `
    8
    - H& P, f% }# a/ @0 s7 U* m1 w9
    . T4 G+ S5 ], C6 X+ E2 L9 h% _10, c4 N; W7 M" a! v
    11
    1 P$ B9 H& @- ^12: U. e4 D9 |/ r. I) ~/ b6 e3 q8 O
    13
    ; g9 x) ~; v/ i# l5 l9 G. h& P: t14
    5 `( B) M7 y$ U6 a  S) h0 ^" g15
    0 n+ x2 g" |, O7 |- ]7 R: `9 B16
    1 g  Y& l7 Q# w3 Y% b# }5 f' b171 p* K8 I( ]8 O+ p
    189 D# d; A- v0 T& O* i" x
    19
    6 E  \& l. ~7 h# r9 s4 e/ F8 Y20
    " ]  U1 N: f0 i5 h0 J3 S213 q; X0 p9 o/ ~. v5 e+ O1 t3 [- d
    225 [5 g3 j/ m/ e
    23
    ; o& |" X& J7 [- B24
    6 Z: j' X, Y8 Y$ u2 I2 y25
    / t/ {: {7 G% s  f  x4 i. i* T26* v4 l$ E* }- \1 s$ l8 F0 N
    27
    - e' F/ _# U/ h28
    . q! w6 e& T5 t, O) E1 G( x29
    9 A7 V( X2 A% }" A* b30
    ! ?% f- q* k6 {; j31
    7 T* i- {( @+ l32
    & g0 z4 Z/ b3 @* d33
    ; p- g4 O8 V, K* w( y+ O最后是1000分类,2048输入,分为1000个分类
    - D% k4 Q/ x( v9 @& s而我们需要将我们的任务进行调整,将1000分类改为102输出/ c; M' g1 {% V; O" ?& O3 e* c1 w
    5 a+ I7 W6 h/ N5 m  @2 O
    6.初始化模型架构+ n) y6 u8 C7 l/ \) J3 t$ r
    步骤如下:
    , T2 D6 \! U+ C1 k
    ; A# E# g6 `# d+ ?* m# {) _将训练好的模型拿过来,并pre_train = True 得到他人的权重参数( u2 F2 o8 m8 @: F' j9 u6 {
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
    & J/ ^! S6 I  Q+ U无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数" f+ p% i! p4 i# N3 J/ m( Y- O# X
    官方文档链接( l' I, B$ e/ }
    https://pytorch.org/vision/stable/models.html
    ' A/ V0 E% T7 |" N( a" s- i3 F, F& B$ _/ X& s  s) L
    # 将他人的模型加载进来
    $ x/ W! b' C4 G; j7 [7 T8 Y3 Vdef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    - z& e2 }4 y# Q    # 选择适合的模型,不同的模型初始化参数不同: l1 m: y9 y' G. T6 Y- m6 s
        model_ft = None
    ; N0 a. _* T5 c* T    input_size = 0
    6 w6 d/ u- Y$ o4 |9 u% b6 @
    ; {2 `( S  ], H+ L' G; p    if model_name == "resnet":
    0 H- T' Y0 j- _        """6 W' T6 l/ H3 w% Z1 V
            Resnet152
    # _- q* I1 N) F+ j# c- W        """$ T$ F1 U8 m4 v5 b1 E6 `) i
    9 M1 k  X. H! y. B
            # 1. 加载与训练网络! e5 u$ s2 v$ Q+ u! Z3 M5 \) [
            model_ft = models.resnet152(pretrained = use_pretrained)
    # ?* V0 w% P: `$ m; e        # 2. 是否将提取特征的模块冻住,只训练FC层/ w# q9 {2 A& @0 f' _4 `
            set_parameter_requires_grad(model_ft, feature_extract)
      [. x+ Z) D  X  Q% v% S        # 3. 获得全连接层输入特征% H2 l9 c7 Q+ l" O# S, [+ M( j
            num_frts = model_ft.fc.in_features1 T" l8 Y) ^3 O. ~; R9 e" s
            # 4. 重新加载全连接层,设置输出102
    8 g! A. ]- U6 |( x/ ?; @        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),8 N' f& ?; B+ o  [0 Q" P, O0 s" m
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    0 l! v" Q( w/ j4 J1 H, q& f" Y/ o9 ~" G        input_size = 224
    % x/ s6 e" f) I2 G
    , U" B7 h1 Z$ c" |6 ~; D- c    elif model_name == "alexnet":7 S5 \) q( o+ C; X
            """
    ' g! P3 }% R9 a; t9 [7 d        Alexnet
    + g. [$ X. V+ O+ G5 E        """# [) v4 [1 U, R' S6 ]
            model_ft = models.alexnet(pretrained = use_pretrained)9 z2 L7 v% q5 P$ ~3 |% Q% ~
            set_parameter_requires_grad(model_ft, feature_extract)
    , R) N, L0 ?. J1 i+ s9 m$ h) {5 C6 o2 y" b
            # 将最后一个特征输出替换 序号为【6】的分类器" A8 `) \7 `/ Y
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入+ W. y! O3 E) S+ \  L/ \- R
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)& \; f4 u3 y( {
            input_size = 2247 _+ }- s$ D4 h; x6 K; q, \

    ( a. c7 c4 f6 k  l    elif model_name == "vgg":5 `, U/ D& P! i+ d! ?' o% W
            """. q2 H% d: O! H8 m3 W5 p4 g' z+ \
            VGG11_bn
    . ~+ Q' F1 y  @( u        """9 W; X+ C( v: Z+ ^: x
            model_ft = models.vgg16(pretrained = use_pretrained)% F# k/ L5 ^* C: ?  g1 A4 u8 f: h
            set_parameter_requires_grad(model_ft, feature_extract)
    ( I% V6 i, d& g0 a2 ]+ o        num_frts = model_ft.classifier[6].in_features
    ) v' O6 e( `" f: K5 E- ~+ ?& U* l' A        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)% ^0 B0 L% C  w" q5 `
            input_size = 224
      k4 X1 _$ P8 P- x5 ?" X9 Z( n7 M% T8 a7 `6 T: p
        elif model_name == "squeezenet":
      X% i  t! b7 T        """+ w* E# r1 H/ [: s. z1 }/ D7 v6 u
            Squeezenet
    0 P; V. U2 c. p  A        """8 p# Z0 U# O% m4 Q
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)
    ' y( l' d) l( h5 E4 |4 v        set_parameter_requires_grad(model_ft, feature_extract)* l: A1 k9 _5 f8 Q* W, K
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))$ O* p% W0 [2 l8 m& d
            model_ft.num_classes = num_classes
    9 B) r' g" u- a# ]# D        input_size = 224# o" e5 U4 x. d( O  K7 H9 E
    # [! u) o' m& {+ z8 Q/ B
        elif model_name == "densenet":2 G, z) x. [* Y9 b/ T0 o( V1 x( {
            """
    # r3 T5 ?5 @3 H" P        Densenet
    * w5 M7 p5 ]0 G        """
    0 g5 v. v1 h( {2 Y        model_ft = models.desenet121(pretrained = use_pretrained)% Z( Q& y2 i* a; }. D; I8 J+ R
            set_parameter_requires_grad(model_ft, feature_extract)
    : c  w$ p! v7 k8 u9 w        num_frts = model_ft.classifier.in_features5 U/ T- K$ A/ W- `$ I! d4 j: _* `
            model_ft.classifier = nn.Linear(num_frts, num_classes)+ v- @2 A" {7 f$ k9 G
            input_size = 2249 m% s* n7 t8 r0 z8 {0 C8 M
    - m6 p, I+ _0 V4 A
        elif model_name == "inception":
    ; U( i! q1 S7 Z1 ^* _8 ?        """
    ( ]1 n. H% t) Q9 l' R: c+ P9 t( W        Inception V3
      r  F- n: x: k- U3 f2 i        """- ^1 [! N) U6 ~  n. u
            model_ft = models.inception_V(pretrained = use_pretrained)' q& m1 W2 s' u( c. |5 B( ?4 x
            set_parameter_requires_grad(model_ft, feature_extract)
    - u) y- F4 h( [1 Y4 L0 T0 U& q6 w
    3 G; c4 B' A# l6 }2 v4 e6 l        num_frts = model_ft.AuxLogits.fc.in_features: ]. W* S/ W: d! T
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)" z) ?* L5 L  C! m7 R! |
    6 E3 {! D; R; E! p
            num_frts = model_ft.fc.in_features
    3 v( h8 I+ H/ J$ y; U        model_ft.fc = nn.Linear(num_frts, num_classes)" ?% R- t* Q& p( y$ e
            input_size = 299
    1 L. _2 g  y7 N* k) s( Z: T3 V% e  j/ q6 z3 D
        else:
      ~2 x4 _; e; g3 N9 ?9 f        print("Invalid model name, exiting...")
    8 @7 e9 ]2 _; `) s        exit()) O1 |8 [8 E/ ?& s

    : r/ I9 u/ m( i    return model_ft, input_size9 C9 k3 k6 t: f. H" N

    6 p, U. W$ u5 {. f% j' B! h1
    . Y1 @/ @8 C, n- h1 O, X$ x29 q* H& @6 |3 i2 _8 n
    3. Z- U7 ?+ `5 l" ~) F0 Y2 \$ o
    4' u  w# D2 C2 y* `: M' B$ z
    56 W8 k8 o- q: a) |
    6+ t; R" f; y0 Y" s0 }8 i
    79 J7 M( J9 v1 \' {( c
    8$ g4 f2 I# `, H2 d4 Z" I* A( P7 N1 X7 A
    9
    + g! @5 x5 v9 i/ F3 i/ ]10/ ?* W; o# m% q. Y
    11
    : h: m5 j2 b: e129 Q  \  `! _4 C6 A# t9 G
    133 ~' z3 X- o; _$ G  K$ M# [; x
    14
    0 c0 x; D% [1 J  a2 h' c15
    # b* \+ i& I1 I1 d161 ^/ M- ]+ h8 d
    173 a- `6 w0 w6 k8 i
    18
    - k6 F8 q2 W8 u6 Z8 e) I; B19/ R+ m  Q3 S: l6 g7 `0 W
    204 F  K3 h1 l4 r5 v
    21% Q, {1 q! I7 `
    22
    " c( n- V+ w$ l1 m1 M23; B0 N4 A& j3 q6 ?8 U7 _8 I
    248 Y# n2 C9 F4 r2 a
    25
    + _, _' Q0 J7 c1 N. U$ r: h26
    / B: g& q0 B+ N2 U- p0 H5 S2 F27
    ) m: E9 i% P& r" K28
    , K) T- y/ M& Z+ p  {0 e! Y29
    6 D7 c% _. b" H. [, l0 |- D" p307 d: v. b% R+ i8 W: A; Q+ h+ g
    31
    . t  z; P& A" x" D* w32
    ) U/ h/ P0 Y+ G. o339 \; T- M" E$ B6 o) I
    34
    ( R* U/ h8 v' @5 a9 @35
    7 Z/ m  @$ _( x; Q36
    ; }- Y0 I3 P( U7 c' n0 S37
    * x6 _+ I& M" X: |4 i3 [; `38; @) H0 Y. w7 N+ K; d6 s5 w
    39' H3 m: a4 m) Z6 Y+ N7 j0 Y
    40
      L3 G7 P/ ?; k* l+ r41; u3 d! ~2 D; t# P; D3 B
    424 q' Y6 V9 v7 @1 Y& N3 z
    438 E6 Z" ^5 v. z8 W( j
    44/ K' |. J$ A5 w" ~+ A+ B( `2 ~
    45
    - L  K( \; g8 v6 Y4 ?7 M9 e46% H! e  \  h; e
    47
    - a* d/ _* f0 _) f+ `! Z48
    * V% O. ~; _" M: t" Z* N1 y2 l49
    " T- d0 v6 m: i! l8 g( Y0 s* B* S/ O500 @# ^9 Q. s; n3 r
    510 `+ w4 r5 ]4 A9 R* B
    52$ ^' v9 t& o' y: P* G# S
    53
    & Z7 u7 S# E8 u2 `7 R- w5 w54" ~7 }. L8 n- y' ^" o) G
    55
    ) [& n  ^9 u5 X, @) `% C56
    & i  D3 S0 O/ j( N2 S0 W57* r8 F9 x. P. s1 y5 L
    58
    ) y4 ?( u! M5 j$ s1 W6 V59
    & \, K6 \- w. A: f" u60
    , h7 m- `* s: _$ \2 `61' h# o( o  G9 i% T, I7 j3 C
    625 P% B0 o" z3 p1 Z# N
    63
    ) Q! a% L7 x  N  m' Z+ {& I7 t$ U64
    ' v% K, S5 R/ X$ ^# [7 B6 r: ]65: o# \4 p0 T$ b
    66  [3 r2 D$ a+ I+ b* l9 y' `
    67$ h5 D# E$ t' V
    68
    3 ?& r2 v8 e* w# ~. d' e& W9 m69# ~2 V# u! m  f, X
    70) I2 z, D# R* W$ B9 R7 y3 W
    71+ j, F6 T9 E1 t! X9 d: y
    72
    4 i0 O$ x; j( @0 A% {7 i( ^73
    4 `/ G) U% _; C0 C* H7 R74
    2 m$ W8 E! Q# ?7 |7 @5 G7 M75. u/ q# A7 D1 m- f
    76) Z4 f+ o0 s* {' a4 W) q- p3 N8 D
    77* a9 J, V% s) L5 l
    78! @2 I9 k  S4 [' k
    79
      x9 {8 R; S& V80- W* W/ _9 S* B: K/ J4 S
    81
    - G* l. r; n) d9 i7 n1 o* p82! ?! A9 I5 W  a- ^
    83, |+ F9 `5 f; O
    7. 设置需要训练的参数/ @; G+ N9 i. p( s. ~
    # 设置模型名字、输出分类数, O8 h1 ]& B) I0 ~7 \2 K- p
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)& M0 b4 V- \* _# t- N

    . m. P# o8 P8 G, p9 g8 E8 d# GPU 计算- K5 R/ `6 `9 ?* w
    model_ft = model_ft.to(device)
    3 i/ W: d% E+ H
    $ O  E, n# U3 P* k+ X& k! [0 k# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    3 M( z; v# G2 D$ K' {8 xfilename = 'checkpoint.pth'( u7 @- }& {2 K# F9 ], F
    $ l7 U4 {/ f5 r5 L6 v5 d
    # 是否训练所有层
    ! n$ W6 a% @* {( P; `0 Yparams_to_update = model_ft.parameters()
    ' n5 q+ t9 ]7 T- |# 打印出需要训练的层% ^* ?: j6 Y0 ]6 g6 e/ s/ U3 S& [
    print("Params to learn:")) r" M$ g  j2 w  y5 T: H* I: O5 O
    if feature_extract:) d, J1 y; ~1 Z  H( L" v; m
        params_to_update = []  m( d! ^+ z+ B+ ~: e, s1 \
        for name, param in model_ft.named_parameters():! `. T0 C! `# L! z! T, l
            if param.requires_grad == True:
    ' B9 ~' z- |: Q# d9 V            params_to_update.append(param), e: {. ^( |, {
                print("\t", name)
    3 b/ D: [/ K! X9 H# R; Z% _3 Oelse:
    7 Y4 g4 w+ q! K" p0 \/ ~2 [) U6 f' c    for name, param in model_ft.named_parameters():# f2 \( c$ a/ \. A  b6 V, H
            if param.requires_grad ==True:3 U' v. F& l" S  M6 W! F3 e
                print("\t", name)' C+ y' T4 }+ F( W$ |+ S0 \! y/ r

    + n/ N3 `2 z/ q! N$ F1 u- K- c1# m# M7 `7 z9 ~4 C
    2% T/ R7 `! ?9 h  ?7 E: }9 }
    3! T, e) M. j, g5 p, s* I
    4$ j9 D- O* V( R) a. o* E2 Q
    5
    ) M' P- b* b  z  [' H6  g1 N' Z1 a, t& z( J
    7
    & N% |2 ]  |3 S# x: L) A" V8
    ; ]. Z; U8 I7 H! P5 {3 w9
    / b8 u/ c* H+ K& n& J) B10
    7 t" o6 Z1 a8 O4 F) H3 M' n* m11
    ( i$ K9 ^' u) I; g1 ?12
    3 q1 l$ O" B( M1 o) T% u# t# k13
    % o6 Y0 J; H# y8 F14' M4 w3 }% i; o
    15: t# _8 R- q1 ]% p$ T
    16: y& B' Z6 ?- d# |) T% [. t$ D
    172 T& l9 O5 Y: x( q4 t& E8 S4 d0 y/ v
    183 W$ W/ @# f( E* d) U( F
    190 [$ y3 y/ K( B* C- q3 v: w8 Z
    20& u+ L  L& @, L, @. D8 w- I
    21
    7 j6 f4 o# z1 }2 _( b22! {5 o6 x* l+ \" V' g# P
    23# q# z4 t/ K. X0 Y
    Params to learn:
    4 @" S& F5 o8 ]7 ~- A' o1 K         fc.0.weight
    : p) X. T; A2 d4 E6 M7 g         fc.0.bias
    ; l: f' O4 I( }15 R5 B: D1 @6 {, `; v
    2
    ( o3 H8 y0 `% f* z7 N0 J3
    $ r0 f' l/ P: T, D7. 训练与预测: V. r6 Z1 {6 z6 H" B
    7.1 优化器设置
    - q4 H" R' M% A$ H- p' S% C' n# 优化器设置
    5 i0 d4 G# F* v; M7 Q9 |optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)- j6 N& h; n: M$ \5 d, C+ _+ S" M8 P
    # 学习率衰减策略
    " k5 |9 y# ?9 H: g" a+ A/ d1 Xscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)' i9 ]6 d% z( }" t  c( l
    # 学习率每7个epoch衰减为原来的1/10
    ! x# v5 Y8 e# u5 D, I) |8 L8 z7 ^# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    8 ^( D0 [/ A) `9 ]
    * E1 L1 C" H* R4 v% ]criterion = nn.NLLLoss()
    4 _. n8 T3 z5 j$ p2 k& N1) d! \/ d. U; i7 B7 Q6 {  X9 Y
    2
    . Y4 q  z& @0 F7 r: {- \" }) i3' Z( ^! e' |. X
    4' h7 B4 o$ m# V- a( K1 O
    5: l9 N' U. \/ ?* v) K  @+ p
    6
    6 k  m8 F  X6 Q2 q9 k7
    ; n8 J( x  r) J' a: n8
    " x3 x" j: @5 k  `7 P7 C# 定义训练函数4 v6 B3 m6 r9 j
    #is_inception:要不要用其他的网络
    8 x# V0 |4 V! @- U, I# ?: [0 Ddef train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    ( w8 g9 Y/ Y$ ?' ?    since = time.time()
    6 \5 I! r% t8 F( D* s    #保存最好的准确率" _- y% |! T( k- V8 u3 U
        best_acc = 06 p9 A# b# S+ F' }, j
        """
    4 V! X* a; z8 X  v: b- w; Q; C+ ]    checkpoint = torch.load(filename)
    3 j. ^! c( r5 ~, o2 j1 q    best_acc = checkpoint['best_acc']' {" m$ x1 Z2 h! W, S' D- u
        model.load_state_dict(checkpoint['state_dict'])
    , c3 O% M5 V. X9 E    optimizer.load_state_dict(checkpoint['optimizer'])
    % D& Z8 ^7 o! R% ]! @    model.class_to_idx = checkpoint['mapping']
    1 R  l  T! H' V/ |    """" h5 k, f! o3 v5 v' s
        #指定用GPU还是CPU% j* w+ e0 ^1 b) i/ e3 }# r$ w" e
        model.to(device); B, G5 Z' L' x5 e
        #下面是为展示做的
    9 }3 W7 T* A' H3 I    val_acc_history = []
    5 J) L; ^/ b1 d    train_acc_history = []
    ; ~" X' K) }1 p    train_losses = []6 u$ V7 v& |$ b% t$ f
        valid_losses = []9 X  [4 c5 V+ U) l( e# Z! U+ a
        LRs = [optimizer.param_groups[0]['lr']]
    # H$ T- o8 A3 |; t* F. b    #最好的一次存下来
    * K9 ?7 S# C8 q7 W: _    best_model_wts = copy.deepcopy(model.state_dict())0 l9 p4 G, d" E8 f

    ) {5 f% G$ q: \# ]. b2 x    for epoch in range(num_epochs):4 t' G( Y. \+ s1 q' g3 q
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))8 f8 O  V. ?0 k2 d% r9 G& ~) U
            print('-' * 10), K) c: q! W3 P( W
    / z7 n4 I8 w9 g: \+ D% h) g1 F
            # 训练和验证- |! ]4 M# S  c  g9 H
            for phase in ['train', 'valid']:
    . k( a' @" G1 S            if phase == 'train':. |9 k' j' y( ~8 Q$ K) \
                    model.train()  # 训练+ Q' s  D' Z, s4 k- Q) A) B8 ?9 g
                else:
      O! i3 k; d* H! g, F4 q  ^                model.eval()   # 验证
      z& ]! s3 C# }( R- f/ |# G8 A+ |8 p9 G7 K
                running_loss = 0.0, V9 T2 U0 c  x8 P8 P& F% |
                running_corrects = 0
    1 X/ U. S) y. q) O6 X8 n5 ^
    3 q& T7 u7 `& y2 y( o            # 把数据都取个遍
    ! {( |; S: E. x' O0 A0 j  t8 x            for inputs, labels in dataloaders[phase]:8 K- ?0 p. u7 e. C2 ~! B
                    #下面是将inputs,labels传到GPU
    0 [& f  X8 Y/ z3 L4 m                inputs = inputs.to(device)
    # u: O& L) ^& s' C                labels = labels.to(device)& `$ I* ]6 T; P" D

    : O2 D0 q, W8 p                # 清零' [0 i# e' ?0 o. v
                    optimizer.zero_grad()6 k8 U; }6 ]- ]& ^+ r' c( e6 V) k
                    # 只有训练的时候计算和更新梯度
    ) n7 p( V' ?) }                with torch.set_grad_enabled(phase == 'train'):
    6 O+ a( D5 h3 h( {5 m$ M5 v+ a                    #if这面不需要计算,可忽略
    1 _; P/ u  f" {; H' l* N1 s5 ~) i                    if is_inception and phase == 'train':7 `& [! [) v* k& c
                            outputs, aux_outputs = model(inputs)4 y0 r7 z* P0 Z; `7 b
                            loss1 = criterion(outputs, labels)
    6 |  F+ I5 @. a( h7 ~! e) U. P                        loss2 = criterion(aux_outputs, labels)+ `/ a5 M  \! }) G/ x* S: ?  |6 G
                            loss = loss1 + 0.4*loss25 M: ~6 H. O  s- A5 I
                        else:#resnet执行的是这里
    8 u& ~1 s3 a, `% w                        outputs = model(inputs)
    * W6 F- Q7 ~4 P6 c& P& C# f                        loss = criterion(outputs, labels)
      O  S' B5 r$ S6 K& J9 X% u: }2 l1 w' M' V+ }! E
                            #概率最大的返回preds
    3 Y' Y4 R7 ~) A  }" {+ J                    _, preds = torch.max(outputs, 1)
    9 y9 E  v' L3 [! n& A. W+ P* R4 [, i8 ^* q
                        # 训练阶段更新权重. Q% z0 E  z+ b/ {8 |" y3 _
                        if phase == 'train':
    # D/ z7 G7 v( p/ D* X7 f& y2 r                        loss.backward()% t8 \8 n# \" ^7 D* F
                            optimizer.step()
    : C! r1 G' g% S. P+ s8 v* ^# s3 ~+ f+ E7 Q( X
                    # 计算损失
    1 X! P& P( X- F1 H) N                running_loss += loss.item() * inputs.size(0)6 P" c1 _6 W0 b. `2 U1 R7 c
                    running_corrects += torch.sum(preds == labels.data)% q; `' S2 R. H9 D7 |

    # [9 Q8 _* {5 ^) \            #打印操作9 Z( k* O3 I, W' u' h+ G% j
                epoch_loss = running_loss / len(dataloaders[phase].dataset)
    8 e, W- t6 y" s% q: C            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    # T8 w6 W( o1 |5 f1 o5 G
    5 v$ }, r+ Z' \5 g; M! t! z: E; h* {2 F3 G+ }( P, p7 U
                time_elapsed = time.time() - since2 D' R9 x% {' Y$ |" o
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))" h& \  r" d$ x  R/ D$ I* T
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
      P/ B! U# z3 s3 C! z& g" D, f1 n" }. ]
    3 x& @7 V6 d+ m: {/ |) n: U- r' R# r  r2 g/ m" Z. I
                # 得到最好那次的模型) O  ~) _, l: M4 m* Y) Z+ L
                if phase == 'valid' and epoch_acc > best_acc:
    7 v) E& p3 O7 ]( U2 @, _                best_acc = epoch_acc2 H% u- ]8 _2 @  j4 w0 s6 D
                    #模型保存! ]$ j3 r: |0 r% \
                    best_model_wts = copy.deepcopy(model.state_dict())# h1 K" a! U6 Z
                    state = {
    $ a1 D' ~% M( u, b5 l4 t4 P                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数+ V+ n$ K( \- M8 M
                      'state_dict': model.state_dict(),
    9 N# D, g5 [$ f" i( D                  'best_acc': best_acc,
    % O8 y$ W0 P; q7 c                  'optimizer' : optimizer.state_dict(),
    8 a8 ~) O7 ]( s. s, y& {                }
    ' m, b; c! h& F                torch.save(state, filename)8 {! B+ ]% \+ N9 I
                if phase == 'valid':
    ' W. `9 C! R7 E4 U, |) e                val_acc_history.append(epoch_acc)$ _% e$ j- i  t1 L" S
                    valid_losses.append(epoch_loss)  F/ l: g. O( U: ~" r6 h) N
                    scheduler.step(epoch_loss)
    1 p6 x! R% X* h( }! G3 d  X, ^  y            if phase == 'train':8 \0 H$ |3 m8 K, ]$ i5 s& |
                    train_acc_history.append(epoch_acc), y1 m7 V; V" N% t; X; }
                    train_losses.append(epoch_loss)5 ]4 i# J9 H' e0 T9 j4 m
    ! ~" k- S0 O" k8 D" M3 B3 R* Z
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))8 L: K/ W  Z" ^
            LRs.append(optimizer.param_groups[0]['lr'])
    % k- x) g6 s" e+ l/ y* L0 `        print()0 G) o8 h: k3 l- ?  z) t
    6 _# l/ B  W; p9 s/ m
        time_elapsed = time.time() - since
    0 P, G1 t$ }! H: C! p- ^, q4 s    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))& S4 Y' m6 z- L/ j6 g
        print('Best val Acc: {:4f}'.format(best_acc))& _. _. y$ W1 r$ i
    $ x* ?0 u+ a- |1 W) O# S# _% U
        # 保存训练完后用最好的一次当做模型最终的结果' o/ l3 A' f$ G5 `0 d1 C) A
        model.load_state_dict(best_model_wts)
    0 J; d% c; }6 Q! x    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    % `4 |$ _3 d1 z- i3 H# x
    ' `! ~* {3 T, q7 a/ h# k  c7 k7 w  O$ \
    1
    . J& Z  ]3 s+ F1 d3 R2" T; y- Q8 O( C9 t- T  ~
    3" B' f. G4 }% ?; l( B0 @7 ?
    4
    ; H+ W4 U- a) {+ M& O: [! b5. y. i$ p3 h* {# T; b/ {8 t% y9 u
    60 j  W- Q! z; E+ X: W0 [. W
    7
    : I- r3 e7 H2 e5 y4 ]: s( K' Y6 S# Q8
    $ \# c4 p5 J/ `! F9 w90 T- G, U( v. I, }  C
    10
    6 W6 U* g/ H! E11
      B  X/ D2 T' a; L( g, ]0 k) A12
    ( h4 q- [  ]7 P1 B5 b0 ~% Q5 {2 q13- O  X' o3 e6 p6 u+ D0 f
    141 H( j0 Z3 o/ r7 a& x& T6 H- e
    15
    : P) p' H7 }. f- [$ z16
    ! |4 m* ]* \+ E& }; T$ G4 T( D; x17
    % z$ c3 c3 ^8 n# ]18
    - v. l( z. |( [19
    1 U" w- W6 E' N202 c0 D- a; [& r+ Q/ E. J  R
    21
    # e; ^7 n. ^. U; U$ N1 _* _# t22
    " {% t4 V/ S/ W+ ~1 _6 ?- p, p' b2 ~23
    * e( ~0 n- C; e; w6 b24
    ( R9 L5 Z2 ]' V$ _2 p25
    9 V/ Y! b# K5 Q7 g! [260 ^" Y) X5 @* o8 L# ^
    27
    & S  |" p- b+ c7 j& o9 w! ~" C: e5 O288 C4 [# n! ]! s7 n
    290 P7 l/ B2 U1 k1 V+ F0 k: f
    30
    2 o' Q6 D& G: V5 H( \: V31% m+ v& i2 `1 Z8 d6 l6 ~
    32
    ( A- Z, p9 J4 h* Q0 f3 w. j1 H' k33) }: H" A  c3 y( F( i2 {
    34
    1 l+ d% \9 n8 s4 d8 N# w35
    4 ]8 m/ k6 F' X7 V. H36
    5 h7 F3 m* o1 V% i/ H. @$ F' P37
    : B, h! B8 L* ]. S% I$ y! G! [38
    / n6 \/ [) S+ d39
    6 {' @. l" X* U/ r: i6 Y; D. I1 @. A402 S6 ?& W4 u4 ^
    415 V& d, t7 a& x; G9 k; e
    42/ o; Y) c) [  V. t5 @+ w
    430 m7 h/ X4 r  O# ?: p( l
    44) r5 j& t, O4 A: C3 G5 v6 @
    45
    . {; ]1 M8 _- s& j: S# x46. K* V7 {% Z( f) h) g, D* i4 |
    47
      T+ [3 w) ?! T- W48$ ?6 n3 J  \! E0 Z2 K* B9 B
    493 `  t3 C- \2 f- w
    50
    1 w& q  o3 V% X) I: z1 X$ S515 L2 f* S8 l& a$ }' D2 q. D
    52
    3 S9 \# }2 ]! S, F% W4 S8 o53
    7 v# v& M- q" N* v; ?# k) g8 r/ c54- i& J% y* h( c5 p8 R) z
    55
    ) ~* Z3 L, v( T% l% b* ^56
    + i+ U  _5 A: \9 d9 N: n- W57
    ) P4 X; A5 F2 X+ |7 ?& \6 }' P58
    . L" J2 V" k9 ?59
    . L- x) x& }- u; S# M* T0 A) b. u  z60- z9 p5 p- d6 A2 {0 U2 E( T" i
    61
    ' c1 l: R/ {( O9 X# K+ T625 ^: U3 w( S+ N7 j
    63+ z+ ?+ \+ I: x$ [8 R! ?3 `
    64
    0 U; `6 |! T" {$ C7 B$ D65
    - y8 ?6 V4 S" @/ |. H668 M% q( U$ y/ o+ ]4 C4 z: f
    67
    1 s. N+ ?/ @. k. P/ R2 H68
    4 A2 P5 @- O  c4 j69
    & r, x/ l- O% `$ O" [* ]70
      @1 y5 a4 h3 M& p, z1 z71
    & s% z& U! Z: S7 N5 P72
      y7 ~, N6 h, N/ Q73: P7 t! |" D) a+ q
    743 {- d7 Q! j$ C7 r' d
    75& y' c. J, z1 L2 O
    76% J  x+ q+ ]6 L3 W' S' [! t( \
    77/ \4 y( q& W9 \( V; j4 q4 Q$ v
    78
    ; C0 K+ O3 l* K1 z# `79
    . R, u& y5 {* S; k! z3 K1 y$ Q2 {80
    * E( t9 G8 D& {1 R* o( s81" A# v; _7 T+ ]. c" c/ ?$ ^; r
    82
    # `( h/ t5 C5 S* ?2 L8 @83
    ( E  u  A( ~# m; u  p% h84
    9 _6 d; [; Q# f9 s85# M1 D8 u) b, Y7 ~
    869 \5 _. P) ]+ U' _
    87
    8 L6 m: ]) G5 b- j4 ^0 o6 q88# J+ \0 F. z" x) g, q
    89
    6 N2 G/ }4 Q5 `90; J8 {) g# Y$ ?, X: k3 `4 X, R
    91
    0 B" D/ P$ b5 S( F# g92, e  [) L: L8 [) w
    93
    4 H& C5 \! v$ e/ e94
    ; F4 K& h. i5 \: _2 x( g4 P95
    0 _3 X( N/ s7 O+ t2 B: Y9 y96
    2 i4 B7 {0 ^: Z& N3 ]0 A1 L97
    8 [$ Z# n4 S/ B5 \8 g7 T980 X6 }: x. Y: n) i: i* {
    99, _* G7 O5 n0 S- Z+ y* L2 h
    100
    : {( S5 j. p# p7 t101
    0 y8 A1 r2 R. C" L/ c; l102
    5 S' u. @) Y0 ?  n: q- V) z1031 @4 B: q9 t' Y1 j9 d8 G9 ~6 p& B6 l. Y
    104" F, C* p; ~& X) D+ [
    105! D+ c5 O2 q* K  a8 D$ o  F" [6 N
    106  k& w7 ^" c8 X' S+ h$ K
    107
    * H; r+ k" B" e  M108
    ( _( I+ z3 l1 \0 R109' N4 I4 Q) d, a2 v: z
    110" D- U  K3 d3 I  w- @9 q% F! E
    111) L: d' ~3 b9 w( m. K6 {
    1121 g% J) h8 q4 m/ |5 K
    7.2 开始训练模型
    & f4 d" Y1 V/ B' \我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次' l) M4 A& j, F, J0 z6 D4 N  Y8 z' I
    9 F3 }& x- F. R6 z$ Y. X3 c
    #若太慢,把epoch调低,迭代50次可能好些
    . `8 O) D+ _2 x# N  n& I9 y0 y( ~5 O# |5 J#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
      q! e& i! v, K4 n" jmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    ' I# `3 @6 p* W! t3 ]- y4 J& [0 C5 b6 u" b9 K) m
    1
    " Q& V, S# E& |, p8 s2
    - W/ U! a  C8 q, J3* A5 a# W/ B- f7 C0 Q
    4
    & _! y9 ?$ G+ uEpoch 0/46 i1 q0 w+ S) N2 G# v
    ----------3 C( t0 [: h5 r  u+ L
    Time elapsed 29m 41s
    1 x) H% Q% x- C6 d( o8 X( ^( @train Loss: 10.4774 Acc: 0.3147
    $ r, Z3 j" S( }Time elapsed 32m 54s
    ; e6 F" K8 `+ p) I; dvalid Loss: 8.2902 Acc: 0.47198 q; _1 x2 n7 d2 F7 ?6 q8 Y# c
    Optimizer learning rate : 0.0010000
    ! E: |4 s. O. [/ h7 D& E, U9 ^, a& {1 L& R
    Epoch 1/4
    2 v  W& i. C3 j+ }; @----------
    4 o/ e, @4 G# x' Q# ]" F9 J- `Time elapsed 60m 11s* ^4 p3 k! h: w& l4 ]- h  [
    train Loss: 2.3126 Acc: 0.7053
    * s6 L; x3 ~  ]0 m  f3 }' y) [  oTime elapsed 63m 16s
    # b& c7 {# N2 o5 Wvalid Loss: 3.2325 Acc: 0.6626
    2 N1 O; r5 Y* d5 vOptimizer learning rate : 0.0100000+ y% b$ p6 T0 x6 r2 b6 F  @% s
    $ Y9 o; ]( ?7 X- G( q' l
    Epoch 2/4
    / o. {. a* N" l- _5 H3 E----------! ~% Y' C8 [( c& A4 M
    Time elapsed 90m 58s
    ' J/ A) o* w+ g/ T; k4 Ktrain Loss: 9.9720 Acc: 0.4734
    ' f8 l$ p$ a: I" u3 u6 O) xTime elapsed 94m 4s3 ~/ `! ~4 ]  A) W( x
    valid Loss: 14.0426 Acc: 0.4413
    ) @, g! X8 u1 i6 {- m7 q6 QOptimizer learning rate : 0.0001000
      y0 A  |( P9 h9 q, \3 t  {2 x* p% Y) [: C
    Epoch 3/4
    ! N& |; G0 o% @2 B1 J----------
    / h! j. K& r$ I+ [* d) bTime elapsed 132m 49s
    ; J+ G. B, q! s0 `, c9 A1 h0 {train Loss: 5.4290 Acc: 0.6548
    . ^2 X! G5 L/ N" QTime elapsed 138m 49s
    # s* f9 t+ Q) ~. ^: O; Mvalid Loss: 6.4208 Acc: 0.6027
    , S  ?5 Y1 r* j: gOptimizer learning rate : 0.0100000: G" g0 ~$ N/ G
    0 g' R+ Y; `0 G, o9 V8 |
    Epoch 4/4
    0 K0 ^: b7 L9 K6 F2 E+ n! ]----------9 u. p& S1 u7 }$ I3 D8 o: D
    Time elapsed 195m 56s
    * f$ i- r  y' b& [$ g: htrain Loss: 8.8911 Acc: 0.5519( Q' D! s# Z  V7 M: C  U4 C
    Time elapsed 199m 16s% H* b" e- g) o" K& B& H
    valid Loss: 13.2221 Acc: 0.4914- G4 c- U7 T$ b  m1 p( A
    Optimizer learning rate : 0.0010000, D  G! V. Y0 O# D' E/ l
    " l+ {" I+ ~3 V1 t8 d; U
    Training complete in 199m 16s
    0 `: |: m7 b4 a" ?3 EBest val Acc: 0.662592
    0 s8 S" X7 J, R) P" `1 \  p. {* v2 W' b- ~& f1 l
    1
    " c1 a0 T7 b6 c$ w) q* F: g2( M( A# K2 W' u$ U/ {
    3; w9 _+ D$ Q# _. O& M
    4
    $ p2 y/ X/ A; l9 ~5
    * w2 J# ^7 L" Q! e6
    9 O+ I; Y3 l1 w6 S* Z. p7
    2 x2 q6 K8 ]) z# t0 M8 q* B$ ^8- E7 @" e  r' F% j
    9
    4 g8 d* f8 u8 \104 d8 v' m' e4 L
    114 D1 |* ^. @/ t% v
    12
    4 b2 H5 n" w) P# s! J9 T5 l13
    - S( H0 Q1 o" e- V143 |; P. d0 `4 a; r! X& _8 |
    15
    5 j3 R, P3 u  N% H5 n5 I2 c* H  U7 A; d16; F2 Z0 f: i& d' y9 \+ X3 M6 \
    174 a4 R7 O+ }# Y! ?& P  \! Z% [
    18
    ! ^  X0 \# r& V1 f19
    6 E7 V& p# ]2 I( `( d& q0 h20
    : X( O3 K% p) Z3 s2 l216 ]$ g! `  U$ t8 B8 ~" e
    22; U: e3 O/ J( x' W  g5 E$ N$ A/ v
    233 m  I6 ^- M( t* ~2 W
    24; R/ g; `8 [) t+ X# r
    25
    2 [2 z2 n1 L1 {9 g( s" f9 ], M9 n( k26
      l: v2 ?9 t% r" o1 D9 S27
    + x6 f8 L4 v9 Q+ P. z) {0 ^283 G' O+ P+ S' n+ f. X7 }* N
    29
    - J1 V/ k1 [/ G, Q0 ?  l( d30
    % i8 v5 S5 M* @31, y* [- b9 ]9 K7 m: Y
    32
    / F) H# m! o3 `, j33
    0 Q, T3 W3 Z/ R: L; G4 c! Z34
    9 Y3 o' c4 x' g$ R35- f; F/ `0 d6 Y: e6 u9 J/ I  k
    36
    ( B( x$ ^2 W; x9 h376 y2 E/ p( m: L! Q) E- U0 b
    38
    ( u5 G* o' C$ C$ ~# l0 w  W2 c39+ ^# m- u* A$ d8 I2 e' b4 @; v" X
    40
    2 J. B7 e9 U8 Z* N0 b41: x: N# u7 q; d
    42
    6 J& y0 G  i$ @; x/ p. H/ Q% f7.3 训练所有层
    ( _& Q1 G* v7 ]* P# 将全部网络解锁进行训练& g. S( g) f+ \
    for param in model_ft.parameters():
    8 u  O9 n* g, z, g7 _* o    param.requires_grad = True
    8 Y1 q4 Y8 q( X9 Q  f6 u: k  O
    # 再继续训练所有的参数,学习率调小一点\' l5 H' J" S6 T. y, V8 g/ e. P& ?
    optimizer = optim.Adam(params_to_update, lr = 1e-4). p( p# [* M7 R/ @% y
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)3 q# X. x  x/ Q; E; Y

    + L; P2 y2 y) P$ ^! F- w# 损失函数
    & n7 m+ f9 R. `8 Ncriterion = nn.NLLLoss()
    . f4 R1 }/ N9 o5 u( t- ^1
    # G2 ?; m* M% c$ J. }2% T8 w1 [! |; x5 ?- O
    30 H3 A! V# r% K# X
    4/ q! _# z: ?7 h; |( Y9 ~) z
    5
    & \0 m/ ^* }1 W# N- w5 O1 a; h6/ v$ Z, I; U- O: }
    74 R5 T& @& T1 i) d$ }
    8
    $ t. r; G- A5 J" ^, D9
    4 V) z* P+ v7 }+ @# m3 g' Q10
    ' n1 t) M; s2 U! W6 h# 加载保存的参数" _1 h0 V2 |, v# g, k2 V# Z
    # 并在原有的模型基础上继续训练
    8 w# q& [, C& _9 |' @# 下面保存的是刚刚训练效果较好的路径
    3 _" x% ~2 ~$ z8 A& N0 m8 D! echeckpoint = torch.load(filename)+ [' C' @* c3 Y4 g6 K
    best_acc = checkpoint['best_acc']
    ; c5 e# x( l" W! d4 X8 u! q2 Hmodel_ft.load_state_dict(checkpoint['state_dict'])
    % {- y7 M" E/ U) Xoptimizer.load_state_dict(checkpoint['optimizer'])
    ) k; }) z* {& k, S) x  ]7 m17 _9 U/ Q$ G5 z4 o+ ?% n0 n0 \
    2
    ( e" c4 M9 G8 X$ c1 r' W3
    . v# l  k; v6 O* W46 l4 J3 S- o* P: F8 ?; `( d+ o
    5; ^' a7 l9 p' i% A
    6* G7 ], Q) k2 T
    7
    + J! l7 A/ v* V3 M3 E$ G3 L. m开始训练
    : m- w* O: r6 t. N# r4 y注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    5 i7 {7 v5 K  u. h* o$ H* q$ c+ P% r
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    9 D/ J- m: t$ P' s! C1
    5 C- E; v0 ?: n8 U  qEpoch 0/1
    9 x1 y3 S3 h1 x% _3 l" {----------+ X: U- R3 F: g8 c+ Y, c# c
    Time elapsed 35m 22s
    5 J: }) z" W, Z. x/ Qtrain Loss: 1.7636 Acc: 0.7346; _# t! g% y$ Z" i% t" J, S$ \
    Time elapsed 38m 42s: f' u* O" M: a4 ^/ I% w
    valid Loss: 3.6377 Acc: 0.6455# |; y% Z0 d, m0 J& z/ p% g, }
    Optimizer learning rate : 0.0010000
    0 Y. e3 o2 ]. \5 R- D. n( u7 ?. O! l8 ^# V) F" K" P* }
    Epoch 1/1$ E$ b- S0 [7 o
    ----------7 }8 X- X$ m: N
    Time elapsed 82m 59s
    ' p- k0 y  y3 ~) u8 {) l, d3 U) L4 Jtrain Loss: 1.7543 Acc: 0.7340% c- f# ]: C6 T( c1 K
    Time elapsed 86m 11s
    : {: q) ~* a. u  N5 U, z$ ]valid Loss: 3.8275 Acc: 0.6137
    , E& A+ E9 U! c, kOptimizer learning rate : 0.0010000
      B& k& P$ m  ?$ t, C" I% f4 H& U' L4 P: w! y" X- _2 L7 r, f5 J6 x
    Training complete in 86m 11s
    ' I8 ?8 w1 J7 c* f  HBest val Acc: 0.645477
    & x. x6 h9 d1 d& b/ q" G7 B! A: ]& n  F3 j: n  p1 T) O
    1& l+ ~# m8 G* e) t3 K7 W
    23 ^% ]% b5 @9 ^$ Q- _! k0 N
    3
    ! C/ w( b* E6 p$ j4
    7 x  U/ t% r7 j( Q7 }7 N6 ~. N# p54 n: ], M( W5 O
    6
    3 y4 N, k6 C# [, q$ V" [+ L, i, A7$ _1 S9 U' p4 h
    8
    " y: E  L/ d+ u% m/ ~5 S  [% \9  o' m% W; B$ g1 ?0 t3 V* Y& c
    10
    / D6 }& a" _: s8 E11
    % w0 w% _" j( l) }' a) K  w12
    5 @/ J. @9 a& X$ r! z13
    $ ?# G0 X2 o; t. f) ]( ]: N2 W, ]14
    - O" i3 H% A; p& m( Q15
    9 f+ [  I: r) [0 M* y9 V% x* E168 Z- Q/ M4 s9 w: E( s2 {* J
    174 b# h. O  n& k  s5 b
    18# n5 |. z$ s) |8 O- W( k, X
    8. 加载已经训练的模型
    . B2 O5 m& z$ D, ~' w相当于做一次简单的前向传播(逻辑推理),不用更新参数, D7 a; C' Q& k! E& Z3 d" }

    ! p- @, i$ f0 emodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)- W7 _0 K- w" E4 `; Q! f
    3 ?3 i, N4 Q# i/ W/ Z, R7 g) F" I6 `
    # GPU 模式' F# A+ X, Q4 g
    model_ft = model_ft.to(device) # 扔到GPU中9 _+ ~& E0 M+ n, h

    * V5 X, u# R( I: J4 r# 保存文件的名字
    ! T! F1 Z/ |( Cfilename='checkpoint.pth'+ i' b, t) Q5 X% N: i8 V

    . k; t% T6 U: a$ p; }3 z# 加载模型0 l( B) W$ @/ ]! p# h) v1 G
    checkpoint = torch.load(filename)% F& A6 D7 ^. A, a; ?: n" Q
    best_acc = checkpoint['best_acc']
    2 `5 a- _# w" M$ f! ymodel_ft.load_state_dict(checkpoint['state_dict'])5 {- A( G" s( [9 @1 T
    1& F  S; c5 r  U! ?. J& s
    2
    7 q! m& ^& t+ q& y3
    # w# d4 `- e' o! Y4
    " f2 N  h& m7 W; F5, r, z( L1 V8 g9 x6 |* }
    6
    / V6 U) N% g6 S- K7( f2 g( G9 G* t$ [# L
    8
    0 h$ ~  k2 _0 v: ?9 C9, X# }3 ?& b0 B( K, ]  }
    109 t7 ]$ }8 b. b; [
    114 I3 B. M' h. Q' b
    12
    / q+ `+ c3 J, Q1 d, e% b6 `! _* z<All keys matched successfully>
    $ F. s6 Z, T: ]/ J1
    ' ^: o1 h0 F  i7 ^3 z' S3 }def process_image(image_path):
    + S1 y9 A( d' @; x4 Q    # 读取测试集数据
    $ m# Y7 J7 F+ x4 ^: L+ E+ ~    img = Image.open(image_path)
    + r! @/ v; V% P! S    # Resize, thumbnail方法只能进行比例缩小,所以进行判断4 c  t& o7 p4 s% D7 Q) R
        # 与Resize不同# j. v/ x+ k  ?! V8 ~# G
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    & m1 }5 U+ i9 G$ `( H    # 而且对象调用方法会直接改变其大小,返回None) s0 H# E# G4 ~! @
        if img.size[0] > img.size[1]:
    / {( j) R; J. x1 J6 u1 B        img.thumbnail((10000, 256))5 d2 h; o* z" T3 v
        else:
    5 I+ y# R  o6 ?7 p! O, B; L" R/ c        img.thumbnail((256, 10000))$ W1 z3 F; q1 z6 ~" z3 a9 K, e* N6 I9 a
    ! L0 j8 Y7 j) g, s
        # crop操作, 将图像再次裁剪为 224 * 224
    5 a$ H/ \' N; `4 q2 R    left_margin = (img.width - 224) / 2 # 取中间的部分
    4 h8 N1 @. h: Z    bottom_margin = (img.height - 224) / 2
    & K+ ^! S6 U( o6 S' @+ {5 d    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    ! x+ i1 I$ N# {' ]. K0 e    top_margin = bottom_margin + 224: t9 [9 K  L- P0 e. m
    # A, @( z; Q* G
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))+ z6 Y$ V) k, D' L- R% ^

    * D& j+ u# Q0 t2 I    # 相同预处理的方法
    ' z( x4 g6 n* c* L* T! T    # 归一化
    * g2 k) X3 M& w- q, n. i: E  H0 G    img = np.array(img) / 255: t9 P3 a) v3 Y' `
        mean = np.array([0.485, 0.456, 0.406])0 l8 n8 q! l- n, \: `
        std = np.array([0.229, 0.224, 0.225])
    1 Z! v# y; q9 `) G% U* T  G8 Q    img = (img - mean) / std0 o( t& r1 p# L6 f. H5 S  v  U

    , \  T( G' M* Q" m) k    # 注意颜色通道和位置$ \% e* F6 c. _; b. Q( ^: e" n
        img = img.transpose((2, 0, 1))
    2 l3 Y' C+ x  Z  n5 L. J' V' O. A8 k3 M, a5 h7 j
        return img
    0 ?0 k* t! \$ q0 m6 y
    * G3 @. v2 x0 A- Xdef imshow(image, ax = None, title = None):1 `+ J) s$ Q* f  B# `1 v& @; M
        """展示数据"""
    % l5 N% _2 ~4 h. a1 P2 A3 r    if ax is None:
      A5 K2 h2 y8 i) E, L2 W        fig, ax = plt.subplots()5 z2 R, Q! I2 w2 e; N+ [* @
    + ^7 U3 u6 M  E+ x# \9 k( y
        # 颜色通道进行还原
    8 B5 r5 I" u" t" Q: s2 I    image = np.array(image).transpose((1, 2, 0))2 [7 K; W( S# ^
    ; ~4 L* ?7 e; N
        # 预处理还原
    . @* X3 t* z7 d: ?. P    mean = np.array([0.485, 0.456, 0.406]), c  h  W4 X* `  X
        std = np.array([0.229, 0.224, 0.225])
    1 `- {3 L; g" }4 g    image = std * image + mean( W& z- i7 P' F; y% A* ~/ A
        image = np.clip(image, 0, 1)- V' o6 A/ M8 x3 X3 C3 A

    # Q  ~+ e* I7 {. y    ax.imshow(image)6 P, r( a6 ~  B  q, K) ^& A
        ax.set_title(title)# f9 K3 U: X1 O6 h
    ! `# _6 W8 ?9 R: Y9 t# b
        return ax
    * V' f6 Q+ S/ e* v
    ' b: `9 X  ~( S9 N/ `9 B1 T3 Vimage_path = r'./flower_data/valid/3/image_06621.jpg'
    ( R# a% @. s" y# ximg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理  m5 P" j$ R$ a, L0 S2 k1 c8 o
    imshow(img)3 X$ E/ o, d" b/ q  R
    ) I$ }! @% K9 C7 A, t7 x
    1( C: ?+ P& H  Y5 u& _5 {6 O' M
    2, W- x8 F: _  o( ~( _5 m1 `1 q
    33 H+ a0 o! z# w) b+ i0 R% t0 T
    4% p$ ]- @# @4 H2 S
    5
    ( M: h. e4 r+ C  [1 k6/ e9 k/ a/ h: L* ^7 R  O
    7; N8 m6 C0 u% d9 N# [8 _) P
    8
    4 ?  l" K: g. d1 G9
    1 r' T# a% u5 c2 _# Y) R7 D4 \10
    $ a" k/ I/ n/ V8 B( H11) ~" }- X3 m! A& ^+ D: u. [
    12: Z. C. j7 B( H
    13
      i) l( K7 {3 E14
    7 W  i! B9 t' U/ k8 K1 h: ~15
    1 _, \0 D# s" ?, Q6 @* O' B16
    ' O& ]7 [- x" w' Q, n7 \& l+ b+ x17
    & U4 I4 U0 y( |- A' d18
    7 |+ R" R3 `: N* v- u# F6 Z* N8 J19
    ! \  t& i6 _: P/ W20
    0 \* _. Y: c6 Y* ]217 l/ E2 }, b# H$ x$ T
    22
    ) T  }, N+ C0 S  B' E; C# {' \23: J2 ]5 L$ s, ?8 w$ ]/ C
    24
    % O/ c/ q  S  D8 P" u+ U$ h25
      t8 B  |# {5 \' k  A261 h! K; I7 I$ G
    27
    - o1 b3 t8 o* J28% K+ P$ \& O+ I
    29
    & P1 [. u8 [6 P! i; A30
      R7 H: d/ B' S# U31
    2 `1 U$ z3 k. z9 T9 C32
    + u; P" B. R! c' }$ X4 s338 T' [! i- H( Z- Z/ _$ T9 B
    34) d3 {/ q/ j, i0 C8 g( r+ p
    35) b8 j6 p  i8 t( _, L" r
    36
    - e8 [+ w2 `& i. J- z! U37
    9 J" E5 \$ W$ `! k0 h, @% b' E38. V2 x  x+ n0 a& z  f* e: h
    39
    ; e1 k3 b, y6 U# d40
    . @$ H. x9 s2 O2 i41
    6 p2 k- R6 t* m7 b2 A42
    ! J. ~( R9 X4 `2 K! l43
    % B/ d9 W* m+ O' l44
    3 Q, H. x2 P) ^( _. i1 I" {45
    9 i- ?$ t2 M6 P8 n6 Z7 _2 i46
    2 s/ l6 ]: A) `% C, w3 k( |47& d" k) a& k, e, d9 p
    487 T' j" E- _3 V: D* |2 Z
    49' \$ b) e; D; V7 i  Q  ~& b. W/ y
    50
    ; u0 g+ m: s1 ^9 F  F, j9 C) }51
    ; B0 p7 h3 ~# i/ M) u* C522 C0 A2 q, y: m/ l
    534 i: m+ Z( W) j9 E1 d  U8 N
    541 M9 ~: y- ~5 K7 q+ y- C) i
    <AxesSubplot:>
    9 q$ t) K- v: u) @7 G1
    / j( V4 h1 G2 N# k7 i
    % G) t- w+ ]; i% K上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确. @* R  U. ^5 t
    8 C9 n4 W' ^8 Z. e6 n: Q  F
    img.shape
    0 X5 `7 i9 M* z  r; [0 }8 v1
    " \% K; d: _& t5 G(3, 224, 224)3 U5 S2 J- [+ P7 C" t: I
    1/ Q/ D7 U4 [* ?
    证明了通道提前了,而且大小没改变
    4 |& m" x$ u- w( z* D( \2 e0 I  r; v& U0 F
    9. 推理
    9 P1 h. c+ |  W9 Y/ Nimg.shape/ {/ m9 d- u) b* M1 y3 f9 l4 c( C# u
    4 s: ?* a# z4 f5 b
    # 得到一个batch的测试数据/ }4 J1 ?! y# x) e6 [1 s1 k
    dataiter = iter(dataloaders['valid'])
    5 r2 h5 m! H+ W0 v. X, ^. E) Vimages, labels = dataiter.next()
    + Z& w, @1 N, @5 Y3 G* d. P
    1 w) ^. n0 X" |) D1 Gmodel_ft.eval()9 s" J8 r' ~+ X, b( g$ j
    7 t! f5 e& u# o& ~/ Z
    if train_on_gpu:' @6 D; F( f0 J6 I2 }# u' ~
        # 前向传播跑一次会得到output
    ' _) K6 L6 J$ E* v) k- g7 r/ c6 w    output = model_ft(images.cuda())
    $ L# O+ K0 C! belse:
    1 [4 K) H, S3 j4 D* Z8 B    output = model_ft(images)
    % ~3 L% ]/ {* W8 {' e: B) E! J7 I5 i' G) V; _
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
    . m6 S  H1 Z) @5 Q4 F( {8 Aoutput.shape( g( ^  Q9 L' o8 V
    0 M4 \/ h5 e+ T6 S. d4 H
    1: v/ M( P; C3 ~& @& v. e% d! A- n
    2
      V' W3 c& ]) l, o+ }8 x3% M, e& |- k% \) h5 G6 t
    4
    2 m' }& M3 z: S, w  t+ Q: K2 W5- k- w7 ]; p, A  e" _
    6
    $ ~: ^! _4 U: m1 L5 w: t7
    ) @& I) ?0 T7 E8 ?8 P8* ~" I; k' a4 F5 M7 B. T7 o' g7 A
    98 F$ y- k' h' Y* V/ |8 E
    10+ Z9 l* }# u0 F" g
    116 B0 a9 w/ f! C6 [, A% a& g
    12
    ! K" w; W. I/ Y9 W, P6 z- N13
      G9 V, F( l) j9 C) r14- I" a9 g$ ~5 ?4 e% N6 p' q
    15
    " b; N; D6 J# n0 g6 i16
    7 u  m$ f/ V1 r7 y; ]# }7 Qtorch.Size([8, 102])3 v5 j; X) p* _4 Q2 d
    11 Q8 r0 l. ]& ?+ W; k! m1 s$ K1 w
    9.1 计算得到最大概率1 Q% M# G0 ]. V( N1 r* t
    _, preds_tensor = torch.max(output, 1)- \7 Y* i' a1 W) i2 u0 M" Q1 h3 k: S

    % u% q4 m; K) `$ c0 N9 |' {! Cpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    - L+ n$ B7 H+ e9 x8 l/ Q& W1  Q' E( m4 Q. h( K0 p
    25 S" S' b( t) h6 X. h' Y
    3. I' d3 H7 o' p% g
    9.2 展示预测结果
    5 N! F8 c  c; hfig = plt.figure(figsize = (20, 20))% I7 n" ^" k0 Q. a: C, s2 g1 \
    columns = 4
    ! T: I, v7 k5 R1 Z- }/ @" m# xrows = 2# _& R% y# B* w" u( Q* N! i
      T0 m$ p6 u; o$ J
    for idx in range(columns * rows):
    . k! V, ~; V! c* c/ _    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    - K7 F! V; ^1 W2 M+ e    plt.imshow(im_convert(images[idx]))
    ) W3 U0 C* y% m; n  z9 E    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), 3 H7 o. n0 Y9 L
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    1 i7 a9 z* ?; P5 j) jplt.show()! M# b! j3 ^  F9 ~/ e
    # 绿色的表示预测是对的,红色表示预测错了2 r, s5 @$ E. w4 B; V) I
    11 y8 b( n9 m* \3 [) X- G9 V) }
    2* r1 b: q9 B6 v. t+ |. `; U
    3
    4 ^" z8 S5 T9 w9 l46 M2 y: @% o* s
    5
    " q2 z4 r7 z+ D- g1 C% ^6
    ! @# V! Y- E3 \, Q# `7" z% n* ~1 S1 ], z
    8
    ; M5 Z9 O7 Z0 S# v! f9. A4 X. f% I, i! ~
    104 s: V) q! |( F
    11
    2 m' f2 f0 M' O4 s4 p5 C* g  `6 ], d4 C6 \4 G

    ! b7 g1 z" @) V. C9 F5 B1 p- _4 ]- b& R1 l/ |! k) o" l
    ————————————————
    & p, D* F$ F) x4 W: P& b- @" d2 ]版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。8 m1 S4 j. P9 D3 C: m6 a! p2 C9 U4 a
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940& [+ t: {& ~. _2 k6 K0 S
    ' _( \. B8 {2 `/ T3 G: j; M
    2 u% o6 Q9 M& d7 }8 V5 Y% I
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-5-26 06:34 , Processed in 0.524711 second(s), 51 queries .

    回顶部