QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2335|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例9 @7 s8 u3 T" n0 l1 M

    , o* N7 C( z' @- I- F文章目录
    . v  `- [8 d/ ~' x9 n( p卷积网络实战 对花进行分类0 k2 z1 _9 W: b4 L1 z2 u
    数据预处理部分! s. a8 K9 D  V) O# n! u! Y9 {
    网络模块设置' t6 K; U5 I6 K) P
    网络模型的保存与测试
    " A& ~4 c- j, `0 h4 |数据下载:' J, ?/ [: b! H8 X
    1. 导入工具包
      V4 N5 u4 t0 Q4 p1 c: o/ q- r2. 数据预处理与操作
    4 V9 ~9 h: b; s3. 制作好数据源& A7 o2 f9 U" H
    读取标签对应的实际名字6 Y6 Y6 f/ @, U% T' G' N
    4.展示一下数据( e( K% ~0 s4 E: |
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    . R. l, X& m& o  y. V4 V6.初始化模型架构
    ! z5 ^* j; v9 Z/ Z/ [% {7. 设置需要训练的参数
    0 b2 H& n6 j1 l% a7. 训练与预测. @8 z+ Q) ]) p  R  @  A+ d
    7.1 优化器设置1 y, o: w. d  Y  o: c
    7.2 开始训练模型+ y  b/ ]1 d1 P7 f, z
    7.3 训练所有层
    ' F7 A+ d+ M/ C6 l& W开始训练
    / I# p8 a* G# M% C% W/ T8. 加载已经训练的模型
    ; E" u5 l1 o0 j9. 推理
    ) v. Q0 c. ^1 r# q9.1 计算得到最大概率% y- }$ M7 q/ c* k* e( z8 t* J1 F  B
    9.2 展示预测结果0 Y0 k9 {7 n9 K2 x; x5 [
    写在最后
    ! E, e$ H0 m2 M+ ~( e: \* x卷积网络实战 对花进行分类
    % ]; S! Z: g$ _; f; q本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    0 V# ^& a5 g; P/ h9 i3 o! u0 n5 V/ i9 `/ p, v* b' D
    在文件夹中有102种花,我们主要要对这些花进行分类任务
    , V$ w7 q# L3 _) d& z文件夹结构! C, t. a# j4 p  |' X
    9 }$ A! ^. {' M1 {9 X
    flower_data# X# A3 a  B. g
    3 [0 u' L- B1 Y9 p. H# i& h" e9 m8 m
    train
    ' f$ Y* F# W% J! P" }/ q; t( `; J: n6 s
    - f* |# U* s6 H. K1(类别)
    1 U1 B" z/ M" D+ f# T% [2
    ; o. c9 [. t6 s/ P6 _! vxxx.png / xxx.jpg
    1 H  C' f2 F( R# m0 q( U+ A3 evalid& g1 j* \9 E/ D* T

    ! h  y' e$ v/ J2 c4 V9 \: L  |( L  s主要分为以下几个大模块+ @. ?1 c1 w' K/ F0 y9 q: K

    * [5 D. R3 |0 E8 W数据预处理部分
    $ g/ Y9 N+ U$ c+ F0 x, J数据增强! }9 f; l+ ]3 q3 O% G( R; @
    数据预处理& O6 o: x. T8 O9 |" h5 p
    网络模块设置
    " x- g) b( @4 B" A: u, t0 B加载预训练模型,直接调用torchVision的经典网络架构
    9 @) W' [7 k! A' N! i% j% i因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务& S8 w4 }) ?& T( m. N9 t  K1 m
    网络模型的保存与测试
    / M/ R3 I( h' K' ?4 q, @模型保存可以带有选择性
    , y2 ]9 @- z# s数据下载:5 M6 P* q) ^$ X9 B- y$ Y) a
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset7 r6 h! d% _( X, x6 h
      U( J! z: O7 i3 S1 n
    改一下文件名,然后将它放到同一根目录就可以了
    $ [* s$ l, \8 d' D' E
    * e4 R7 n/ ]: \+ N$ J下面是我的数据根目录
    6 Y9 |( ^5 ]; z/ i1 x6 ], I" _! L1 H6 t& s0 b6 i+ t8 x/ H9 x$ ^% q# Y* R
    : ?+ A) t/ x2 z6 D- e' a
    1. 导入工具包% q+ Z  l3 v2 e  C. q+ U, ~' w+ r
    import os
    ' O! v; }# R' Ximport matplotlib.pyplot as plt
    ' E2 ?3 @$ k# G% s# 内嵌入绘图简去show的句柄' ?' G, x; q5 k! J9 a" o3 ]$ Q
    %matplotlib inline . [# b/ Q2 u. m2 g' }6 J& g
    import numpy as np/ [( k$ ^: N( _& ?
    import torch' a1 Y( }+ E0 l: R7 P# k
    from torch import nn2 q. z* e+ n9 @, `

    ; u+ s& t9 {- h" Z, D( r5 U$ U  z3 Oimport torch.optim as optim- N! K9 \6 q3 J% O: a
    import torchvision# c" n7 V: G9 A3 h1 I" o6 A* |" H9 w- f
    from torchvision import transforms, models, datasets
    5 m" Y4 F, t" k5 l; `6 H( ^# z- T1 `! I! l& o3 ~
    import imageio) O- }% g. ~4 p# o1 W; ?
    import time
    # ^/ |, p, f  Dimport warnings+ Q8 t, q! Q0 z; X, }3 N
    import random
    5 ~! H' x6 S1 K4 J' C! j* ximport sys
      z8 X! ~% k1 H6 ^! S" y1 Nimport copy0 X- o! ~& ~: U, Y  @
    import json
    ! [& x- l3 k- Cfrom PIL import Image) P! h: k" f' u  o0 h% M

    5 H; [" Z* k6 l7 l, G) p$ x7 H1 T# @2 V1 F$ ]; q1 [" r: a; u
    1
    3 S; X5 W* ?) R4 R, X- m2
    ' O! {0 G+ W8 {3* E; s: f6 h, X6 [0 l4 l
    4
    6 J1 D: @% @/ J- Z: h9 f! ?2 B5: e4 J* b8 F" P/ Y/ y4 G/ O
    6
      Q: o  h7 ~3 N# H7
    ' |+ l9 l$ o7 ~4 y) p" A87 X$ x( T0 o& ?( t& G5 [
    9
    ; I$ |9 s& o' H# [10
    3 e* A1 E. a5 O! T$ ~( d' k6 c, i  F11, f6 ]. F5 \' |! C
    121 n+ Q. Z! }. U7 e; y
    13- @/ h$ S$ x2 q# F5 @2 v" C+ H
    14
    5 b3 W( d' J( y: J! Q& U15
    . _) S  e- w6 Y- g- y  B16
    - @. h# l+ i) D2 R$ @+ D5 N17( \# [3 x+ H4 \; p% |$ s
    18
    4 f) q# _2 l- C3 R; b! a+ u( |; Y19
    % G/ H1 \* z& ~* q3 i20
    / Z2 T: Q' S) \8 }( Q+ A* z. o+ T21; s, S: A+ p! A# |" ]( {
    2. 数据预处理与操作
    / W* |6 X. D; ~2 V9 r#路径设置
    , a% ~$ N, B  ^& A% ~7 mdata_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    3 k  t' h  d' C% T2 Z8 Ctrain_dir = data_dir + '/train'% o2 ~- f+ s( }* C  S
    valid_dir = data_dir + '/valid'/ R/ G4 E8 P; b  y2 M1 f9 H/ c1 X
    1
    - I6 i. W  e) i, Y2
    , s; T% D/ ^/ S9 J; \$ ^8 E* P34 K0 s' |& Q* g8 z
    4! M9 y: t# K9 P2 n4 S4 E0 Q
    python目录点杠的组合与区别0 J, d$ L+ c# d" ~+ X. g0 r& Y
    注: 里面注明了点杠和斜杠的操作8 y" m- L* W& i) y( |
    % a3 C  W- V8 q- j
    3. 制作好数据源
    % c' M* b9 M, E! |, U: |( @) Zdata_transforms中制定了所有图像预处理的操作+ G! Z( G- e3 Z) [# x" P2 C
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片0 t5 A# D2 u: O, c) d+ G
    data_transforms = {
    * [' [; y0 c  B8 N7 a    # 分成两部分,一部分是训练
    3 y) v  p0 Z7 @! g" o) s  [: i+ y9 d    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    . |  q% C, M5 c5 M4 u                                 transforms.CenterCrop(224), # 从中心处开始裁剪
    % N4 l6 j- ]3 L- j- w                                 # 以某个随机的概率决定是否翻转 55开
    8 |* H# V) W2 ^- I' N                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    7 ~) n  B. t; D$ ]) U( M" ]( K                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转! \. B) H! d3 d2 Q1 ]/ @
                                     # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    * P# R, V' n& d0 m7 i4 a* D                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
    * l1 B  v( p, n1 \7 A& H5 I4 k                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB3 O2 b8 E7 R- @: [( u& k  y
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的& C' @: D( @4 K' i
                                     transforms.ToTensor(),
    + q8 W! y& H$ l8 }% O: a                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差' I2 u6 u: F/ N6 x# Q* y- |" Z
                                    ]),1 A) R& p5 h4 W: Y) @7 I; A
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    2 B1 Q" q$ X) u2 y; r6 x    'valid': transforms.Compose([transforms.Resize(256),
    ; [: H# z+ m* \$ V' K) ~6 w                                 transforms.CenterCrop(224),
    ! v: \% m5 i- P2 }; \% ?6 j; b8 U                                 transforms.ToTensor(),
    ! X9 J) L# t6 u" _  S* K                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    # n% G( q6 _! v5 A: }3 b                                ]),
    ) T! {) |4 K* S; N}
    1 x/ q+ Z1 G3 p* Y* A5 V' U& q$ E" }6 x+ ?) x+ Q
    11 p! k' ~3 A, h. \! z
    2
    & a# B/ W2 B5 ^, _6 i+ m3* T7 {  u4 s4 R' D% T% v
    45 C' L1 }  T% c% @
    5% I# s. c$ j, ]% f: U5 J
    6
    ; q7 I9 g: G% o- n, f2 I) h) l7+ @/ ?8 P0 L2 I* H% v
    8, ~7 k- S+ o. Z
    94 E0 D2 G8 x5 F/ z( Z: v2 s
    10
    - Y+ y' _! t+ a; f+ i. z. ?  @11! J, }* W: D4 k6 M& X$ T5 f
    12
    6 M& c# d9 _  W+ `13
    7 X+ }5 ~* h+ w2 [( N14
    : J. V- @5 N  ~- {5 c2 {, C) q* ~& O15
    & }/ @% m; h' b- Z5 P16
    + x. V1 }& W) K* \$ ~$ g* n' y17" y) Y' E. N2 L/ G4 F
    18
    0 P# m: {9 N8 p% e( ]; a& Y9 F19
    ( V* \: z. y2 D! W- Z$ O$ p20
    ( q1 A+ @& b& s) ~- P5 B21( G: W# J4 |" ?2 U
    batch_size = 80 b* N  w* {0 m. E" Y1 T7 ]
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    $ f4 d& z( ]- S4 Odataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    " T' w# S  K8 B; ?6 v3 o2 J6 Ldataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    $ I* Q/ Q! |! `& j2 Uclass_names = image_datasets['train'].classes! M2 J% Q/ ?% j
    * d+ {8 \& [: W) A6 ^, o
    #查看数据集合
    - F$ c, x% L3 \% k1 Timage_datasets
    0 D' F/ o1 V3 A( X2 z) |, x# z1 F5 o- B% m4 k" X) G
    18 U) c. o( Z# r! e) Y
    2! h3 L7 a" R: }3 ]# L  F4 x
    3
    4 M: b5 ]5 X. n- j) ^7 \4: A% p; q# Q5 G
    5! y; w6 `: f+ C* c
    6
    / }) B4 Q* E) d; B1 _7
    0 j( a4 r" {) O' g% ^8 f  `5 M) E82 T7 y2 d6 l5 M- }
    9
    , g4 |, U8 x( {. o- @( l: j0 A7 N{'train': Dataset ImageFolder
    9 ^/ d; {* ?$ D% G     Number of datapoints: 65529 ^9 b+ c! c6 b
         Root location: ./flower_data/train1 }2 U4 b; \* E$ t) [0 V  R$ q
         StandardTransform
    6 |; C7 N' J- ?. A8 M' b6 ^ Transform: Compose() U! s; q9 G% T4 k" S( p4 k0 @% e
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)) R5 [6 y; {* }) e1 ?
                    CenterCrop(size=(224, 224))
    # [+ e0 p& r! G                RandomHorizontalFlip(p=0.5)- k3 f. v& ?- u6 y( y$ ^3 i* U' M
                    RandomVerticalFlip(p=0.5)
    " Y' ]% e4 v0 l3 E% ?                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    5 ^( D! a) S2 W+ K. b                RandomGrayscale(p=0.025)# O: y2 I1 x% r1 D. f& s$ b
                    ToTensor()
    & G9 R% ~9 x. l, ?3 g$ {                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ' t) o+ f9 \1 F            ),
    " ?# K; x/ a7 @7 D/ O 'valid': Dataset ImageFolder+ j3 j6 z% M0 c3 Y% W
         Number of datapoints: 818
    * [/ i2 H# T1 y$ j     Root location: ./flower_data/valid
    7 H% t# c) |* O, K/ x; q     StandardTransform
    ; ^8 z8 R4 H2 e" j  C Transform: Compose(
    $ r, C( W3 i/ {. d3 X  z/ G                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None), v6 S! e- {" D% ]3 U* i
                    CenterCrop(size=(224, 224))6 e; j* ]3 h: L# x/ j" Y  Z0 o* m" H
                    ToTensor()' c$ m/ H, q8 O7 u
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    & j5 I3 o  e( T& X- }            )}+ v* V0 h4 K: O+ B* L6 n, ]

    # d2 q4 Y' w( Y. K. ^( P1
    ) r# ]) n, V! L) I+ v2
    2 R, I7 K! F$ G( j8 G3$ R( h2 T3 h/ z' Z& Z% A
    4
    5 t) f# U* q6 ~; I: M/ d$ _' U57 o  ^  s5 j% Z
    6
    : I6 Y. D# M# @7
    5 l* a- p, v; e3 H7 {% ]! }  \8& L9 _- l/ L$ U+ a! v6 {  E
    91 _2 I3 I4 n" m# W0 C
    10
    7 @/ S" Z9 I3 W" B11
    $ Z( m, _2 Z0 t. C12
    ( H4 N- U$ P) a. }* @13
    . ~) W/ s. h' X  L. \14
    + ]' ?# b% N3 Z6 P& `15$ h* s! B1 `7 {3 |
    167 s- l! S2 O2 ?
    17
    8 H! H6 L7 S) Q$ u182 u3 R( |9 ~$ ~2 k, s+ k; H+ u' d. L( z
    19
    5 b# W# Y2 M# \9 d20- G5 [3 s. t/ P6 z, E; j
    21
    $ g- G  U1 M+ Q: g$ J; X22# }& y$ {; d; T: a; P
    23
    ! ?% |. g1 f( R! }. U) w24
      f# k/ Z! l  d% G# S* u; w1 u$ ]# 验证一下数据是否已经被处理完毕, o* Q5 q6 ]9 J& o$ b
    dataloaders! m6 P) N% j3 K& M1 R' @5 f
    1/ j* w: {4 f; u+ W$ L
    2/ E* c5 ?0 \" w+ g
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,' ]% Q4 w- N3 A( P; z
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}5 h9 ]# u* g1 n2 ]! R5 S. M
    1
    7 p5 e, H! f2 `3 q9 T- ^2* t" D# i; {! Z. Z, d
    dataset_sizes
    4 Y8 i) T8 D! z$ V) ]17 m, Y2 X' I2 V
    {'train': 6552, 'valid': 818}, v8 e# P9 B* \; t8 o% O
    13 i7 J/ H# s! V$ G3 R
    读取标签对应的实际名字
    " T8 B5 B5 e1 n. c* n% ~5 K) N使用同一目录下的json文件,反向映射出花对应的名字/ J1 H9 t$ q5 _5 p4 z$ E' v

    + X* A" A5 \/ O+ F- Swith open('./flower_data/cat_to_name.json', 'r') as f:
    - P" q! x/ P5 v( O; m* G8 ^/ a    cat_to_name = json.load(f)7 ~* `+ g# E& G4 b
    1: P- O/ Y9 ^9 B4 \- m; {
    26 s( I5 u( Y0 v/ |
    cat_to_name9 y- y% ?% q% U
    11 m+ _& {# R1 d8 `
    {'21': 'fire lily',
    / J! G% x! ^0 f, G$ } '3': 'canterbury bells',
    ) i' a0 X' c) e" S '45': 'bolero deep blue',
    1 A- M. j$ \+ |, V '1': 'pink primrose',
      p) E! T, f& [7 [ '34': 'mexican aster',1 r8 d" M. e/ n$ i8 W; I$ ~
    '27': 'prince of wales feathers',
    ' m! u( n7 j" @$ M '7': 'moon orchid',
    % ~( r* J* o, h& D2 G6 N '16': 'globe-flower'," k) P# K- R' i
    '25': 'grape hyacinth',
    0 S& v/ K. ~- U% a+ }6 N; N '26': 'corn poppy',5 v% N' e: }- d, V0 |. J
    '79': 'toad lily',, |8 y4 K0 \) r6 ?9 `- @
    '39': 'siam tulip',+ D5 B. k' h1 u, o, B4 J. x# ~
    '24': 'red ginger',
    ; E) Q+ A* _! X6 a- Z '67': 'spring crocus',, t6 c3 P$ c# ~8 J( r
    '35': 'alpine sea holly',
    - D) F' h5 A8 ] '32': 'garden phlox',3 k6 P/ c1 q" L
    '10': 'globe thistle',
    9 a5 n, Y8 h# |+ ~1 t9 T6 B" M8 N+ ? '6': 'tiger lily',! E% Q( P9 S1 a0 [
    '93': 'ball moss',, U3 j* G/ p4 A0 z1 }
    '33': 'love in the mist',2 o6 m! L4 w9 g
    '9': 'monkshood',
    1 @: ?9 f( C0 {) s/ a' k9 W '102': 'blackberry lily',
    9 c4 F7 C+ v( L7 P$ r$ v '14': 'spear thistle',/ G1 d- J& @% A
    '19': 'balloon flower',
    0 q( z5 f  v; g1 ^5 D- b '100': 'blanket flower',5 A3 Y& g  p! w) B* o( q8 l! q
    '13': 'king protea',
    ) ~& a/ \( J, M" }3 r$ P2 a '49': 'oxeye daisy',; ~( s; C- }! i3 f8 B
    '15': 'yellow iris',8 b- z% \# {5 z' M  g
    '61': 'cautleya spicata',
    7 o: V8 d/ {# P  R3 j* T6 j8 Y4 d2 T' b '31': 'carnation',
    , s2 w; [, V7 H0 p5 T- e '64': 'silverbush',
    # Q/ L' _( M$ ~  } '68': 'bearded iris',
    # m' I5 Y' g) h$ |2 @ '63': 'black-eyed susan',, l( X$ c" V! w* c$ K* I. x1 ^
    '69': 'windflower',% i' m# q; C& A3 b7 p4 I) ^* z2 A  A
    '62': 'japanese anemone',$ O& Z( P) J9 H$ w; G
    '20': 'giant white arum lily',
    % Z2 P" R/ l  \6 [6 ]0 Y' ` '38': 'great masterwort',
    3 _6 j3 }9 ?4 e6 A/ j7 h '4': 'sweet pea',5 ^5 E2 h; l+ F9 a
    '86': 'tree mallow',
    + O$ x' q0 Q  u2 D '101': 'trumpet creeper',( Y4 R. }( C5 _, _
    '42': 'daffodil',
    2 W6 v) C0 v7 U6 Y '22': 'pincushion flower',0 p( J4 B3 R7 M6 \9 M
    '2': 'hard-leaved pocket orchid',: ?7 D' J$ o" ^: Q" T; S( b% y
    '54': 'sunflower',# C& {1 r7 Z5 V# N  n- K+ q
    '66': 'osteospermum',- V5 Q2 q( z/ W) Z5 A
    '70': 'tree poppy',' X& I' {. V6 M4 }3 Z' I
    '85': 'desert-rose',' b& o( d& T% B9 a3 W
    '99': 'bromelia',# v4 x: L1 U( \+ ~4 h
    '87': 'magnolia',
    0 L9 V" l! c/ s9 r" p '5': 'english marigold',
    . v2 r; H  ~  B( p7 M# N5 o- E '92': 'bee balm',7 L) i/ U2 V+ S$ S/ f
    '28': 'stemless gentian',: h/ h) W( D7 m
    '97': 'mallow',
    " O9 B/ a* Y. b, o, c6 n2 [ '57': 'gaura',5 P$ }' j6 K9 N! L
    '40': 'lenten rose',: G) n: t& w0 q, [2 W9 ^
    '47': 'marigold',
    & ?0 V- ?0 Z5 \: Z4 y: g2 g '59': 'orange dahlia',
    + I0 i$ K' u- @+ U '48': 'buttercup',
    * ]. h8 x' P8 L3 ? '55': 'pelargonium',# C' ]; O# e  g/ V0 T
    '36': 'ruby-lipped cattleya',3 }8 Z+ {# M& l/ S
    '91': 'hippeastrum',0 G. P( r7 t7 t& g, K  ~
    '29': 'artichoke',* `- b! K0 L5 x& R
    '71': 'gazania',6 c" Z. `: F8 w5 z
    '90': 'canna lily',2 e+ @0 u# v* T, r
    '18': 'peruvian lily',
    # K5 ]9 p; q$ Z: M# i '98': 'mexican petunia',
    + N3 B2 p, F; |2 r& R '8': 'bird of paradise',: Q# Z: d6 }& |0 B( F2 \
    '30': 'sweet william',
      S3 X  J. Z9 ?* A '17': 'purple coneflower',5 k8 V3 I! t  l
    '52': 'wild pansy',
    $ \9 i9 L5 [1 D# U& M5 m '84': 'columbine',5 Y- |8 B& f5 ^8 M: K4 B7 H
    '12': "colt's foot",3 _" U: S# S* k
    '11': 'snapdragon',6 f! R* r2 z: D4 ?9 H6 Q
    '96': 'camellia',
    5 Q5 Z3 E6 n6 @' U7 d '23': 'fritillary',2 O" T- C1 {8 H  Y" t: J$ f0 X
    '50': 'common dandelion',) z& o: E, F7 Y2 i$ ~# L6 c( ]
    '44': 'poinsettia',+ n& L+ P; E( b* s
    '53': 'primula',
    7 g$ r& O# K1 B '72': 'azalea',
    : f% Y+ v$ i1 C '65': 'californian poppy',4 ?1 L7 c. x$ t; o
    '80': 'anthurium',  ~; Y$ E0 K7 d( A  p
    '76': 'morning glory',
    * n/ s$ o2 G; e" N '37': 'cape flower',8 u# q; \; H) ~: h4 N
    '56': 'bishop of llandaff',
    8 n% L& |0 R8 C6 L '60': 'pink-yellow dahlia',- x/ ?1 H1 r6 _
    '82': 'clematis',
    # K( e$ l4 S- S. [4 ?. t '58': 'geranium',
    / @5 I& _& n( K( R6 w '75': 'thorn apple',9 h) c/ n& ?* q7 V$ e1 e
    '41': 'barbeton daisy',
    1 W4 O1 P7 h# W# D6 r '95': 'bougainvillea',
    9 r! ]/ |9 t6 z) @" ~ '43': 'sword lily',
    8 d4 r0 n4 v) x8 Q4 h3 W3 D, R '83': 'hibiscus',
      C! J, z5 k. ?% m7 T( g '78': 'lotus lotus'," Z# z5 ~5 a! J* w/ j# e: }5 Z% F5 h+ E
    '88': 'cyclamen',
    + J; C) f- C. t' w: n: l '94': 'foxglove',
    ) b* ~% [- s" S8 C6 V6 T* o- i6 ? '81': 'frangipani',% C  |5 G* C# @' n" Y) l3 v9 j- u& B& _
    '74': 'rose',
    ; |2 S* o8 T# }  o3 F; _1 Z3 c '89': 'watercress',
    . \" _8 {; d; B3 j, C '73': 'water lily',
    ( e( C" C5 g3 G3 d1 o: \ '46': 'wallflower',6 H5 f  Z% i) V
    '77': 'passion flower',
      w- I4 m" N+ B1 L) z& \ '51': 'petunia'}9 Q, r, s! V! M0 G$ ^
    % E7 N3 J+ A* u6 ~+ x) ^; B
    12 e+ Z8 j. B0 F4 e/ U9 _# F/ ]
    2! e# Y! E6 V* f& d& ^
    3& l* A2 T4 c: ^6 Y  u8 Z% r
    4
    8 ^2 i: l( J" f. z4 \5+ O1 {  a, U% @0 Y
    63 D+ d% N, V# N& \- L! C$ N% {
    74 k# @0 T1 Z; X) e8 u: E5 R
    8
    4 `  ~9 [2 h! N4 [2 Q, e9& z5 F2 V( v% S( Q4 w
    10, Z, s" p) X) q  b
    11
    2 ^5 Q: O: `4 B12
    ; J% R& p+ b- v/ B7 G" ^' X" n13
    9 Q7 H! z- v2 L14; v# D; ~4 ]9 f/ m
    15& h7 q" H& \% c9 |
    16# `. A" r0 k1 k6 X
    17
    ( X9 x2 V5 Q* o; T+ f, a% W! {18, f0 R3 s) y9 [! N6 @
    19
    - Q1 @. E% ^# Q) L8 ?" ^20
    5 Z7 g2 O/ }8 \( w/ j/ {212 v4 C4 H& l4 e4 o" F% [, b# F
    22* W. y) A8 T  {( F2 H& t
    23
    & K' |# j# d2 l7 w6 n, ]/ X! Y24" ^; [$ u; X0 B: n% D7 c6 I# C
    25( ^( s- F9 y8 e$ W; S) n' o
    26
    ) B* @2 M- H4 U  D! L( q% V27
    % }& x' ]% D, W/ l- X+ l/ c" R28
    ; M7 G8 s7 [  p5 j9 Y' T294 F# o* ~8 |' k. Q; l* l
    30) T$ h" M6 j; I3 S: n: j
    31. h% k* O: ?( ^
    32+ J0 c. Q6 {+ i. X2 p! h2 v
    33. T1 P# x# u+ w+ Z
    34, m4 g- B. t1 e- Z4 e9 g# y2 ]
    35
    # b; {; O" w- s( f36
    + S" K0 s$ a9 r; t4 c37# ]( e2 q0 O+ ?8 D! f, z+ m; E8 G
    38- _. o5 V& \9 D! V) f
    39
      Z& |% x7 P5 Q40
    & g' o3 W- H/ U5 r$ j0 E41
    : p: ^4 \' R% f! R" l42$ j( `8 z- w, K  u- x
    43
    / ?( w5 f, {& T, Z" s1 v44/ T* l3 B+ N+ A" c
    45
    ' i8 L' z6 D! q5 v$ e46
    5 E4 Q* i1 V, E; b  S47
    , H+ W! {. W/ a3 W; T" w/ B48
    + X) J- [0 S8 E' o  B# `/ h49
    ' P6 {, ~; H3 {4 [3 ^50' t1 f5 g' S- y4 C; B
    51  e7 W7 B3 g2 \2 Y. s4 _; z8 N
    527 }/ B$ F( `( @8 U! R5 N
    53
    / K* |5 }; G* q8 U54% y5 V) R0 c* w/ U. ]# O3 J
    557 @1 h# ?. B0 r  {8 Q  `
    56( Q. U2 l& v# i. j6 Q
    57
    - x/ f+ J4 G. a! o58* s+ O. h/ i% F. H1 |2 m. }% G3 t
    59
    / P! k2 b/ U) i: ~60
    9 K6 }) o2 p8 i' r" I" u  f61+ p( u$ ^! c" [  i3 ^1 j: L9 d1 W
    62& b9 W4 X4 ~4 S: x
    63! z, _( r9 X+ P4 e
    64  {+ Z; i, |. E6 c- j
    65! m8 u% V% w$ C( c' O2 p: a
    661 \: {+ i( A% j! T; P2 Q1 k
    67
      s4 a# v6 c, _4 z6 h( |4 Y3 }# ^68+ Q. S% K& q& k. K7 l+ a1 z, z
    69
    7 R0 }" l: {  {707 u$ [* Z& q5 i& \4 h/ i
    71
    : ~* A) p9 [2 H. j7 U: s9 _72) m9 ^2 s) i: w  f0 _9 f; D+ Y
    73+ x9 v( j# ^; b4 e! S0 j6 ~- ]
    74
    # t* s- K) A; P" K1 x) G' f0 t75
    ; N; n7 r* q$ P# U2 ?& h  P76
    # u5 u8 Y" C' N# h$ [77+ z/ Z0 t! ?& N# ]1 [" y
    781 \! O3 Q5 t' Y# o7 H5 p
    79
    ; k+ {& l" b4 E1 A& W' f80
    ' S- b( P; X& u2 l815 f6 [5 @9 v) h+ ?3 @) u
    82/ ~8 s+ v% h4 {4 |9 O1 z
    839 `) k+ }8 W3 S
    84  o  |( w$ H% D3 q8 l, b
    85
    9 _2 s* L6 O/ y86+ o3 e* {7 l# W1 s: y7 A! k7 f; Z/ G
    87" J- V! X; A) O/ w4 D* i" q7 f6 F3 Q
    88. r) [: L) @9 e
    891 U2 t. r; [+ V+ \# p2 m8 v
    90* }& L2 A8 [, M; @1 R+ I1 g# Q: A
    91
    2 s. J+ H1 j& z6 H+ n- k92
    / Y4 E- |" Z; G1 a93
    6 G; @) r& v! v: H94. f$ \( ~# Q+ c" A* I  d, [
    95
    ; _5 p: M* v7 @8 ~% V2 d/ }96
    9 `3 R5 ~; n4 }5 V97
    & ?+ Q: c* K& p98
    # D+ g% m( Y# d. s4 `, [99
    6 {7 }/ @% T2 C+ A) Y2 [1 ~% y) w100
    ! f6 f0 J  y* C: h$ t101
      `. o/ Y/ r1 K3 I102
    # ?3 D2 \1 Y* `( v/ S" g4.展示一下数据
    # w0 }" o7 q. Z# _def im_convert(tensor):* ^: p+ q6 t3 e) m! O
        """数据展示"""; a5 H  K- F, ^* M3 Y& n2 C3 B
        image = tensor.to("cpu").clone().detach(), r/ i0 `0 ~2 U- y
        image = image.numpy().squeeze()
    6 E$ q/ n4 I( L* x- [    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图9 c5 K7 h/ M" x9 n& ^
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    ; u% w; f0 n$ v) s# r    image = image.transpose(1, 2, 0)
    . f7 ~) S$ ~4 C5 s3 h$ W' W    # 反正则化(反标准化)5 B2 c& d( Z- p1 b
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))- ], v4 g, X0 G/ S; }% ~( M
    + S% P; Q6 U) y8 [; d
        # 将图像中小于0 的都换成0,大于的都变成1
    : T) T( W7 [5 S/ O    image = image.clip(0, 1)4 b  [: @! }4 S% h/ L

    & w& l' u. w- }% c    return image
    0 P2 o: |$ y4 |4 v1
    3 p/ A9 k4 t# y0 P0 |# R2+ b: ]+ ^* B. Z1 W) q
    3, F" J- F+ P* j2 C5 r( b5 e
    4& J2 j9 ?9 t+ B; [& h
    55 r  y9 L1 g8 z0 B  v: Q
    60 U7 ~, Z1 D% L9 V
    7
    / ]' z8 K: x1 U: a. @+ ?8
    . V) i/ \- \: P+ O) I9
    ; E/ ]( K9 n/ v7 j* {+ S0 ^10* d$ f: B3 w* P  N
    110 g' V' I9 H# r* ~/ ^
    12
    1 }5 L; e# d2 X139 O0 L5 i; _+ ]( ~- x. @! u) E
    14
    $ _' C6 ]" b/ O! }; `# 使用上面定义好的类进行画图
    7 t; T& W* I$ A# efig = plt.figure(figsize = (20, 12))
    " K4 `- G( f0 j- t% m; I5 ecolumns = 4
    % Q% D7 w! O9 K  E: t% _rows = 2
    4 ^0 P/ |0 N9 D$ V# M" y4 R' l% v/ s+ ?& f
    # iter迭代器
    ' e( e/ s" ]! h8 K* [; L# 随便找一个Batch数据进行展示2 U- H. K- z$ X' D0 J5 g; _- Q
    dataiter = iter(dataloaders['valid']), I: o5 R6 A( c. T3 A
    inputs, classes = dataiter.next()' C: k+ R7 ~2 f- K& o6 ^! y
    : c# }- M2 Z2 a4 |2 s* [
    for idx in range(columns * rows):
    , [8 K! k0 i( X+ h" |    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    % W7 f+ C' v" s) Z; [0 ?( P+ R    # 利用json文件将其对应花的类型打印在图片中* A2 d( B. ]; ~$ v5 c% u$ B
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    ! d) G6 I! a/ S/ V    plt.imshow(im_convert(inputs[idx])): z9 Q3 m7 c0 K
    plt.show()1 k8 D& i  P9 Z: q! i" t2 v
    ! d) u3 T$ X: K# K: c
    18 K) G% c6 g+ A0 ~1 L  }6 h; ~( j
    2
    8 T" q- A% ^) b( i+ k) [% d3% X# p4 M6 X; B; b" e" h/ t
    4
    1 j! A" p1 M) o+ T9 b2 d1 e5' ]+ @) g3 x( b' P/ e
    6$ P5 M2 d0 i& M4 h1 W$ b
    7
    - B$ p+ [8 w8 ?) M8: H# v  B. f0 l* U8 u
    9: C* T! J5 E  i, i
    10
    9 G, ^8 s. j' c- ^" c9 M11$ L* `5 O& I. F5 U* i& J
    12
    * o  w* p4 y( u2 Q- `7 L6 M13
    $ x+ q: H$ W, P0 A, O" K141 k; g7 l" ?3 r
    151 k7 z0 P2 f8 O& k: `  X- Z- H' ^0 b
    16' p4 b# [. ~& t& D0 u4 H

    4 ~" S" j6 h* ~; x$ g' V# \- q3 w
    2 H) @4 n, r: @! v5 V; d, E5. 加载models提供的模型,并直接用训练好的权重做初始化参数
      R3 f) e6 L- a3 c1 ]! [( xmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']% }+ U  F0 n9 E( S; V. W
    # 主要的图像识别用resnet来做
    + T6 d+ u  Z; r, Y7 ^# 是否用人家训练好的特征
    + {+ h# @- D( w8 r0 kfeature_extract = True# p, |( U) i  ~. M
    1
    1 w" j9 _( Z4 O  {% f21 k5 z0 G% Y7 j+ Z4 [% U* z1 A( T
    3! \' G# K( s; N% X
    4( Q+ j8 Q% o: k
    # 是否用GPU进行训练
    ; e- W$ T; n/ T; ^% q5 Xtrain_on_gpu = torch.cuda.is_available(); ~2 U9 a7 }% B+ Y! |0 i7 R' B- p
    9 l& E+ ?7 x" I  _" \/ g
    if not train_on_gpu:
    # F( `( ^- a4 O$ j& \$ o  y% ?    print('CUDA is not available.   Training on CPU ...')
    3 _2 K. \9 L0 }else:
    1 I+ c1 V/ i, O9 U( s1 G3 C    print('CUDA is available! Training on GPU ...')
    6 p/ `  ?6 ?# L0 P4 p# m
    5 n0 u4 S  }' u* E/ |. Kdevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')+ i( l" D( `2 ]* E3 Q! |4 _
    1
    ) A3 ^% [4 g4 [/ ]& {1 z. V2
    0 D" G7 t  d- v! F9 W$ p6 v3
    ' S+ K' R+ a) z7 K" a4+ V5 a& y4 V- m
    5/ l' e8 J0 o1 Z# k3 r/ x9 o2 \% E
    6) N) A9 v/ s+ r' X, O: ?" z9 a# T8 w
    71 d1 J( m1 f' M+ f8 V( A: D
    8% e" \9 V/ B# l2 C& G9 k/ y
    9+ \# |& w0 h8 B5 p7 G$ m! y) x2 ^4 r
    CUDA is not available.   Training on CPU ...2 W0 g+ w6 p( U6 d; w3 b% ?
    1+ A2 b. X: e. b. k- ]
    # 将一些层定义为false,使其不自动更新7 w7 |+ i" h& U. V: U4 H/ I2 h
    def set_parameter_requires_grad(model, feature_extracting):& k, O6 V1 T6 }$ _
        if feature_extracting:
    ! y! c3 d+ e# Z0 Z# t  ~/ J5 T& U        for param in model.parameters():
    8 v3 A0 N6 e$ I! d) c            param.requires_grad = False+ Q2 i% z, k2 d5 D4 K
    15 J5 o  z8 d% F
    2% r4 Q! h( K% o) u9 f' X' c
    3
    2 N7 w" H, g5 Z. |1 _2 z' K4
    $ a' D+ ]# V" o' W% F" B$ I3 U5* R  e2 J6 ]! A/ I
    # 打印模型架构告知是怎么一步一步去完成的& G( n+ s) R2 ~0 N
    # 主要是为我们提取特征的
    / R2 x( Z( C, S5 C7 q+ _, P7 t3 w7 w/ u7 Z- `
    model_ft = models.resnet152()
    5 P! Q( l& o; e! X3 dmodel_ft0 c# t, E4 P4 q9 r; |) @* N" e4 l% p
    1  b: T* X+ U) a) W6 \
    2/ T( z; d5 F) ^
    31 k, j7 n9 H' [. S' I3 C
    4% N# Y' K) r* u. F& w
    52 H. Z( S4 J% J% m+ U  n
    ResNet(7 V' N' K9 \! V& A4 _6 F
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    # U$ {7 P" N7 O: T" O% F) `  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)8 z3 R/ s* B1 ~* K4 Q% h$ d
      (relu): ReLU(inplace=True)' V9 b* A/ K; y" J) v
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    2 n% o' B4 U- c7 S4 V- D6 y. M1 I' P  (layer1): Sequential(, ~5 G* X) y# E* w1 t
        (0): Bottleneck(
    4 n$ D# ]- u3 u1 Z# \/ ~, y9 v      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)0 X! V9 z+ x% H5 J( `/ h
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): @" C, e+ T) k' w3 T" t( n
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)# U( J( V4 ^6 Q! f) F
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    6 n/ F0 x6 P- x- W( h      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    # S: v  v; q5 Z( e9 v0 m/ z      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)' j+ ?6 _2 K6 A, H0 V! B; S9 ~
          (relu): ReLU(inplace=True)) Y. R3 e2 }1 ~8 D$ h
          (downsample): Sequential(
    # E( I; J5 i# \( C/ {# q# U        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    1 j) d9 D' z' G9 q) L+ V        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    4 P# e" e% K1 G6 B! [      )+ k" i: h5 @# d# ]' h2 p
        )/ }4 ~, O1 b- m  b, [
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    2 h/ M: q5 k+ t1 a3 Y* z    (2): Bottleneck(7 ?' S7 D4 F) d) C
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)7 `) P9 T8 I) s1 O* t
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    5 ?) t4 m# Y9 ]( d2 c+ }      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    0 U* F7 n" S  X% B' r- Z      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)8 Y5 f& D' f% _2 S) J# A9 n& B- g
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)  H5 m, V; c% \: y* O4 e! M
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    $ e! W' [4 x: ?      (relu): ReLU(inplace=True)& F' g9 Q! M5 D+ z! V: U
        )
    6 j9 X3 d, _$ Q/ u; C/ y, w* B  )
    # _: B3 S( b* U( b) z, [3 j  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)). l. u' Q( _6 b
      (fc): Linear(in_features=2048, out_features=1000, bias=True)" ?2 T& ?0 p; z! u
    )# U* G3 d. a& d2 _, n6 U" ?

    ; }% p8 t8 [: a$ j" b4 g" S/ i13 h& V* K0 N$ \$ t5 b0 M3 u3 l
    26 ^' o6 z  u6 F: i
    3  H  Z  X( C. M8 t4 y0 z  c5 Z
    4
    ! P- [9 }  ?. d& m5
    9 S$ j6 X8 H/ g& v+ S* N- p6! h( X8 V8 K: `6 ?+ F, J
    7
    3 H! }4 s; ]- K  a3 Z/ H8
    * _" j' Y( c6 M$ N' d9
    1 S: a5 K+ z" ^9 i10
    / ^0 `/ w5 w$ H, ~, ~) A8 Z11
    1 [. l1 q6 G. c8 O, ?2 M3 U12. f5 b; B1 z5 R' b0 _) C
    13. M4 ]$ R# E. v7 ^  P/ h# b* R
    14
    % H/ b& t: H! f4 {1 ~# L5 d8 j+ F) w15% _$ b# E6 o; p( R2 z2 ]
    16
    6 P7 h/ A$ d# ^17
    $ l) |: |8 S& G& ], X18' S* J0 x4 X1 E; _
    19, c, B) Y$ I. z4 G7 ]: S# j
    20: R2 a5 ^" G9 M
    21" B5 O/ m& T% y2 {& J" R# [
    22
    3 o8 ?' q( Z4 l: q- W( u# X23
    7 Y. d7 g  n# n! J244 u  q  E% L0 [9 M- ]/ q
    25. l) i5 J0 S* p# q* k. S* B# P
    26
    - ?! ~8 c( P" M1 B; s27/ ~0 I6 a3 G7 k# X# Q( \9 b
    28: l* a% F, o/ S0 B# E: B
    29: e0 }4 B2 O6 ~( k# h
    30/ x9 Y  G; W( x8 N# d6 A
    31
    8 I2 Z# T6 I: g! B32
    ) H7 @5 r; A! |2 i; M: H0 O33
    ; @* ^9 |8 x  ]+ Q0 O最后是1000分类,2048输入,分为1000个分类
    ; N3 R7 T/ i0 }4 q7 }( i! x1 n而我们需要将我们的任务进行调整,将1000分类改为102输出" J3 E- L6 J' X) V9 a

      b$ @) k/ S2 x4 c- y3 w5 \, t6.初始化模型架构5 G* w$ H8 y' V+ \4 H. o7 Q$ a9 V
    步骤如下:
    & s8 [1 \" [0 L1 }- ^
    - X8 a2 z6 g4 Y  i9 U$ t& F+ m将训练好的模型拿过来,并pre_train = True 得到他人的权重参数8 |& }- Q3 }6 ?. w
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)9 _! G) y* r% Z2 v0 n' S
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数) Q3 e" c% X/ \/ b3 X# P3 Q  ~
    官方文档链接8 C( a4 S# F) n# p1 ?5 A: q7 T
    https://pytorch.org/vision/stable/models.html" e  o1 n  g0 v) v

    " J# t* H1 w1 y& H# 将他人的模型加载进来# B% R3 `- Q8 K! P$ S
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):- y9 p: p7 J2 B/ c
        # 选择适合的模型,不同的模型初始化参数不同0 |7 ^/ s( C+ t2 w$ |* N! f
        model_ft = None
    ' @& e1 g& D" n! V    input_size = 0% _3 M; V/ g& X0 t( |' e

    & T. i" b" j+ j    if model_name == "resnet":& K5 j; U( C8 z0 F9 M  W( I" Q
            """
    $ K1 t6 Q7 g5 o$ w$ M        Resnet152
    ; v4 C" ^7 T& G* U+ R8 `        """7 U  J! f$ U% w7 z# w' P
    8 I' `* w. W% W$ J/ W. d
            # 1. 加载与训练网络
    , N& u1 {7 W* s        model_ft = models.resnet152(pretrained = use_pretrained)
    " A$ m0 g# r$ D  E2 F        # 2. 是否将提取特征的模块冻住,只训练FC层& e$ L% L: T' h* E
            set_parameter_requires_grad(model_ft, feature_extract)
    ( B2 z% j- C& e6 V        # 3. 获得全连接层输入特征' {$ S( F8 y, l- @  ]
            num_frts = model_ft.fc.in_features
    + n3 O% j9 R: h6 \        # 4. 重新加载全连接层,设置输出102
    / k0 W2 s! U; ^! w; W        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),0 r- g3 J0 W+ s1 g3 a, K
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为18 s1 U! P# J9 e* G* E) B
            input_size = 224
    . l8 i5 [# J) S% M, n7 m
    $ _5 G8 G4 Y4 v+ X' R7 L    elif model_name == "alexnet":
    1 Y! j6 l, j8 W4 X0 [, O& W3 _        """- D" M% U6 M" ~
            Alexnet
    # h1 b2 I! f6 l# `0 W' g        """! r" z& U$ d! |
            model_ft = models.alexnet(pretrained = use_pretrained)
    / ?: A3 u; `. c9 s  P9 S" [% o4 ~        set_parameter_requires_grad(model_ft, feature_extract)! x2 p1 M+ u' T+ v8 P! R

    ' K; F& w- ?7 `" s& _8 Z        # 将最后一个特征输出替换 序号为【6】的分类器' s, J" a7 ?0 `  I; K  |
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入: ^  |- C5 |5 ^+ j5 s* b
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    3 ~3 e4 ?, E% ^5 M$ N        input_size = 224/ Z0 e( k! U0 V9 [$ ^

    - @* m5 Z& \5 @/ M    elif model_name == "vgg":% k" c2 F8 Y. _# \# ~- c
            """
    + F9 M3 T$ @6 ?2 e( a! \$ f        VGG11_bn
    2 G2 l$ i# F, m1 E1 [, w        """
    , K- t8 U# M/ Z. ^6 x" h1 z        model_ft = models.vgg16(pretrained = use_pretrained)
    4 _2 {( ^0 t: Q. B8 g, t( G        set_parameter_requires_grad(model_ft, feature_extract)
    ' W. j$ b- D7 G% \        num_frts = model_ft.classifier[6].in_features: e8 q0 ?4 l& r/ V" @
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)" m5 e) m) A* M  _& Z( l+ H
            input_size = 224+ u1 b* _4 j2 T- Y, i9 y" ?
    4 {( y+ y* ~: E+ {" E: L/ n
        elif model_name == "squeezenet":- }+ h7 C7 N& R$ R- N, @
            """6 k' A: h5 s& Z" o* C* M$ u* X. @: w
            Squeezenet
    9 S4 O0 B( G( P        """
    , M( C& q: A( D4 r1 a" }        model_ft = models.squeezenet1_0(pretrained = use_pretrained)( m8 Q( S; n- }: v) O8 s* @
            set_parameter_requires_grad(model_ft, feature_extract)
    / F: s) l  y, u) B% q        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))7 z5 ~, k1 q2 B5 K
            model_ft.num_classes = num_classes4 N6 C! k( `/ x, E7 _, O( K% U3 b
            input_size = 2242 @" c& S$ ?5 p, T/ t! [
    % y) r" C* m6 g8 a  U
        elif model_name == "densenet":" d) g; x0 ]: l- C5 r
            """
    1 A; [- g' Q: V( m. A. ]+ S        Densenet) }3 m0 V6 ^2 S, M; q  T
            """
    : Z( B' L* ?2 O7 ^8 ~        model_ft = models.desenet121(pretrained = use_pretrained)
    5 N9 n6 O( b  v; a        set_parameter_requires_grad(model_ft, feature_extract)$ m- |$ Z$ [7 h6 C0 ~: m1 o0 C% }
            num_frts = model_ft.classifier.in_features- p, g, E4 M. Z+ a' g5 [9 |% V
            model_ft.classifier = nn.Linear(num_frts, num_classes)
    : W% a! U, d% }3 l7 e$ V        input_size = 2243 A; n% i: t0 a5 j' s
      N) S8 r% R" b& X& W/ A7 T" Z5 R
        elif model_name == "inception":# l' j& U: O8 W* C  R8 w3 E# Q5 O
            """  j& o; l! f5 o3 m5 f" E5 B
            Inception V3
    / j2 V2 A* }3 X8 Z( {& {4 c        """
    5 k; _, s& K- y& S        model_ft = models.inception_V(pretrained = use_pretrained); U; E: p% j* y4 F
            set_parameter_requires_grad(model_ft, feature_extract)2 l! J$ v: R# D. d. p" K

    ( J7 L( ~1 X  C7 |& D9 v        num_frts = model_ft.AuxLogits.fc.in_features! }# s8 N2 |9 v1 Q3 f% x4 k
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes). d0 y, |% ^% q' v7 [. m. F
    - X' s' z" s' b/ l4 [" U
            num_frts = model_ft.fc.in_features
    $ Y/ V; Y  a3 s4 z# U: @        model_ft.fc = nn.Linear(num_frts, num_classes)
    : E$ M7 ?" q0 ^+ K        input_size = 299' `! k9 P# A) d* U  l+ q
    6 J5 \+ p6 _  V. J: L% m
        else:
    # X' x( x- G- b+ \6 J4 ?7 l! O        print("Invalid model name, exiting...")
    % E4 Y4 l5 V! @  F        exit()
    1 E9 m/ [- c' y0 g
    " v& a0 s5 H. {, U1 y3 ^, g    return model_ft, input_size; s6 g# C; z8 T8 Q. i# [
    9 h+ z6 Y# T9 ~+ g# |4 ~
    1
    , `& z: i7 I; v/ E1 f2, D3 y1 b( T9 m$ L. a3 b/ m# @
    3% w) Y. n' t  U/ v' e
    4
    . I' ?: s: h5 P! Z( v% U: `5
    / o1 z$ S9 n7 Q6 q6* R$ C3 Q5 s- n$ H) G0 y$ C* i
    7
    7 ^$ k/ \2 {1 O, A) N' h! i" d8
    ! F, U: M3 |2 }, v9- g, l3 W/ A; x/ h$ |# v
    10
    ( ]# Z3 e( \7 E- M0 W2 D* r11
    , B4 c+ D* E  Z4 Z8 y+ _( s120 ~) H/ r2 t& }5 t! T" L& H
    13" s4 t4 ]4 V% e7 ^% C% L- S' B0 R
    14
    ) r0 g. [' x9 S: \- F; h6 C5 q15
    9 I5 S! v! R4 V$ _- I16" F4 _* a" x) W+ G
    17' z( w+ S3 X) J% l1 b- u9 |6 L
    18: L" R4 B0 C# _3 b" A" r# `3 W
    197 \5 O+ F# c( z2 N5 Y# k$ o
    20
    & e6 [4 E: h4 h7 n) G1 N' p21
    5 u! I, B& c2 M5 h0 E22  m: L: a5 k  Q% L5 H1 @
    23( ]/ B3 B1 @9 z) f9 a
    24
    5 T5 b$ M: Q  B252 A- E+ m9 ?+ t  ~* r# s
    26
    2 A! K, Z" t1 b% P! o# |" y4 E+ W271 j) O6 F4 q5 _9 h6 E! L5 J
    283 ]% }: D8 h; ~- ~0 d  s4 q; B
    29
    7 X* `) S/ \, k& L: ^' _& \30
    4 g9 E0 i* r. f6 n31
    , ^$ P" o5 P6 m4 C9 a( ~, M- i- c) a329 ?! C! u( D2 _
    333 I0 w1 {$ n0 r+ `: ?9 a( t
    344 H6 @$ i8 J+ m+ N0 i+ G5 a
    35' ^2 n- F6 L: Y& x3 W
    36
    6 {; q2 W$ T: u0 n: F& @37
    7 K* F' o! h& I1 c38
    0 D5 I( X; j' y1 K# }8 L39
    - d2 C, o! y4 h40
    6 x$ E" _7 _2 G9 P% s6 G( h41: L% B; a6 o8 S* P: C. |, t% U
    42
    ' }/ h1 G$ d. l* e43
    ) ^; h$ I' v0 {44. p  S2 s; C# Q7 R0 Z' D  h
    45. r& A8 K9 ^/ f3 u# o6 s+ C7 Z
    46
    ) h' a- B/ t/ F47$ S/ O+ b7 s! D% Q" k; b7 c
    48
    : F) z, T9 H2 S0 l6 ?& k49
    3 Y, S" i6 O8 b" ]2 V$ r50
    ' h: ?! t" H7 R) K% d' z% }518 r  W, P- I" D- t% J7 s. k
    52
      t  `" x1 @" n; h$ e4 |7 p53
    8 ]2 |7 `1 W: Y54
    / ]' v, O( Q- ^55
    4 v- \5 _' B5 r6 |% _$ D/ `* Q56+ N( M7 x7 K: W4 m: h- y# A
    57! S' H5 u, P+ ?* r5 ^4 c* O
    58
    + v2 G$ d7 X# p% D6 q59
    9 r! Y" S, N  F7 O/ z60
    ; h1 W& m7 t0 i/ u" I* a( u: i% J61
    % g2 b7 `7 M- [2 @62
    + P0 W! S6 y. x9 b% H; W5 \636 W- x- ^. B; r. I
    64
    ( ?& h8 L8 Z* }( o+ I) s0 @3 j/ [0 `65
    ' s" y) O3 I* U3 W2 k, S" {66
    ; ^- {+ }; j; U) B2 G678 Y- o0 X% Y0 ~/ ^, ~$ k% x
    684 N; L: T4 i% I
    69! Q/ M* y( q8 A0 D8 M
    703 G0 m. i7 q2 _6 p# v3 U
    713 A6 |3 `- v: S
    72) p( N7 Z! x+ ]1 A+ n, }
    73
    9 g" v+ x) N/ l8 |  O$ n" K2 `74
    ) C' ]" [/ |5 \/ c3 h5 Q  i75) u- \4 _& R1 j8 t% r
    76
      E! E) f4 M4 \% w4 y% F9 ^$ o77$ L2 z( T+ Y: H
    78! e3 `- T6 ?" h( P# t& M" U
    799 V# C# o, I2 `! V# U8 e, y( l2 K
    809 v" s) H& L/ U. z7 h
    81
    , f' N& I) Q, {5 C: ]: B) E82
    / h7 B& u, @4 g6 s83
    . C* K3 r! i) Y0 V. V* h7. 设置需要训练的参数* [2 y3 j( C( ~% c; d  r  h: A$ H
    # 设置模型名字、输出分类数
    8 j  i3 m0 f2 ~: Y3 [model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)/ W5 }. a* C/ P: R9 I1 T% I7 r& h
    & n, m2 z) f) D; j  t5 q' v
    # GPU 计算
    " s1 G" W/ x  @' ~model_ft = model_ft.to(device)" i. h4 |$ X( k" T, \* q% M

    ) G; |9 W3 Y1 j3 w" j4 Z. Z# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取) E! N% W- L& E8 t. ^8 l: G
    filename = 'checkpoint.pth'
    / i: \# M7 C+ s* V1 e5 q% |/ ?+ g
    / t) r0 a0 h. @$ y# 是否训练所有层0 M  n2 c3 i: Q2 s" [7 E
    params_to_update = model_ft.parameters()1 j% Z& E8 v# v0 m6 h
    # 打印出需要训练的层
    # e) l) i# D3 P' ]- _& f% bprint("Params to learn:")5 _  V( I+ j$ p+ r5 V
    if feature_extract:
    / n+ w) m8 e- u. J. i- U% q2 K    params_to_update = []8 q% i  U  X' v
        for name, param in model_ft.named_parameters():
    ' M5 U' s# C5 D& S        if param.requires_grad == True:  m2 l# ]. t+ R4 Y$ [, ]$ q4 w
                params_to_update.append(param)
    8 O' ~$ d" i$ ~4 w6 f  b" Q" E) d            print("\t", name)' R) }1 g+ n9 L9 e) c
    else:# `3 E9 f( ?6 Y, x5 J# a2 T' M+ v) r- r
        for name, param in model_ft.named_parameters():
    6 r8 u! f) m2 ^) Z0 E$ ^        if param.requires_grad ==True:. `8 n8 Z: y, C. l- [" F( j
                print("\t", name)+ E6 }0 O' _6 ^: x5 w1 M# r' y

    + Q7 S" e4 G1 s. f$ R7 p( J4 W1
    0 t$ l/ L* d# X  C, r  i2
    2 d+ ^& e% I& t5 V7 \' W3
    % @) y+ L1 _+ s( T$ l49 u: a" F" T2 q: r8 V& R
    5
    ) Q; K) `% \( ^! s. X7 l" \- Z60 D" A- J# U3 l" \* S' `5 u
    7
    - A9 E$ Q: A2 x# ^% z8$ ^$ b9 H6 _# t8 v( D3 r
    9% W3 B9 _. B, R) |0 Q1 I0 k6 y
    10
    # @( d! |* b+ H: d# t11
    : w: o) q$ Y  H12
    - `2 {+ F4 q9 x& ^13
    , l" ^( i2 P. F% K* F& i* t$ u4 P  Y148 A' [3 O7 A- p. l* y! h
    15
    * w! e* u. i" u! S165 d5 ^4 s: b3 b, W1 T
    17
    ( Y0 H# |( R/ j- y5 L6 L+ B18& m) K# j% d' r7 h8 d3 f. f+ h
    19
    0 o: N/ g5 f% G20
    ) O9 P& j% d4 U3 h5 t9 s, `1 j3 K21
    : b: Q. x4 U! e0 u. K( p5 p22
    " [) F% A* T$ z1 j$ c# {" i" Z23
    6 m# f2 S9 X+ p  U8 f+ mParams to learn:
    . ?( i; k. [. L" Z' A& Q+ g6 b         fc.0.weight
    2 w! N# q1 v; n$ ~0 m  K) X" N         fc.0.bias
    7 k% b. H; A! Y0 r7 k1( m/ y; Q4 a4 h7 Z7 i( m( K! ~
    29 s- X' M( _! g% P; a
    3. H& V' c6 g  t" n( b7 x, e& o1 }# y$ j4 }
    7. 训练与预测
    ; v  d, p1 z, }7.1 优化器设置
    ; `0 r, \. s3 [2 Y9 p- }3 R9 Y# 优化器设置% _: ]4 ^6 C0 U. A9 ^! ]3 @
    optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    ) A7 e+ O1 |  @# S9 [% b) `) y# 学习率衰减策略# y+ `5 }" C" m: l- z$ z$ b' U
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1). e" H# K% X9 o9 o+ l; y* k
    # 学习率每7个epoch衰减为原来的1/10, `- j# S7 A2 C3 k
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算% y& v4 m4 O9 c  @5 T8 {
    ) `6 a1 ^- i' J( I) G3 f$ l
    criterion = nn.NLLLoss()
      }1 R" x! K) c3 k9 T  k6 t1/ B! B7 b! ~8 Z& y2 g! z
    2; W' v% n9 c& ~- ~5 ?
    3
    6 q# s' Q) _- Q; B& ^: x% r0 x4
    + B' G4 X5 W" q8 E3 {' n" q5
    9 E7 R7 m8 Q; k- b9 G2 q4 v6, [* q3 i, d: }( U# u7 e6 b
    72 A7 H% E6 }: ^1 z# v' ~  `/ c- |! h  b
    8' Y3 J# ^& j! }& X% e
    # 定义训练函数
    - ]- U, x# y! @#is_inception:要不要用其他的网络
    ; ^! O0 g1 A1 {: g( ~* x4 idef train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):" j# k0 H. f. H' C8 Z: _
        since = time.time()
    6 ~4 \' y2 r5 A. c$ {' R" U    #保存最好的准确率/ V7 Q1 v$ |( S" w- o
        best_acc = 0
    & V* N9 u8 B  U9 `: }8 d    """2 F2 ^7 D/ }( B) ^$ U6 D8 R
        checkpoint = torch.load(filename)$ c' t! c0 d, Q$ }6 E# R% z! o
        best_acc = checkpoint['best_acc']- T( t9 T7 S1 N4 u, h( [
        model.load_state_dict(checkpoint['state_dict'])8 V" C) U$ M# g, s) t
        optimizer.load_state_dict(checkpoint['optimizer'])8 `. o: q7 R: ^: C4 Q9 S
        model.class_to_idx = checkpoint['mapping']5 ~( Z  Y/ D3 n8 ~
        """! a* a  q3 A) ^. p: h- b' i
        #指定用GPU还是CPU- k  j( ~4 V/ l8 `) G  M+ C( O
        model.to(device)- l1 j( L' n7 J. v4 C3 L2 q
        #下面是为展示做的% V& k% \1 D$ w0 t( M
        val_acc_history = []6 V8 {2 H: q4 N
        train_acc_history = []9 B5 O- J  M9 p6 Y( ]
        train_losses = []5 c# V8 r0 E8 A9 f! Y% A
        valid_losses = []
    ! [. \2 m/ ^+ Y! h    LRs = [optimizer.param_groups[0]['lr']]
    . ]. z6 b6 l, N* R; n0 ^    #最好的一次存下来
    ! z6 M, ]  D0 P- j# U7 ?    best_model_wts = copy.deepcopy(model.state_dict())
    9 k0 \( c) u; N# K
    8 {6 P3 ?) M7 P1 _% X& r  `    for epoch in range(num_epochs):
      A4 `: O* X" S3 i# J        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    ; i0 a4 n0 C! {0 o8 \/ S8 _  T        print('-' * 10)1 u) Q9 [7 A8 u& d. e& O" u0 D

    / r/ i/ H9 }: I4 j. L* O        # 训练和验证$ Z4 ~, h  d9 M% K
            for phase in ['train', 'valid']:$ C* U, y) ]% d. g! w
                if phase == 'train':9 d  \  b4 n8 A) y# p9 A7 X  f
                    model.train()  # 训练
    5 o- D. }3 ?% g2 N8 A8 S            else:
    ; t- e% O, w( ^: [# C9 T' ?                model.eval()   # 验证
    8 `+ v) ~$ O% l+ J  t1 N7 k% s* e8 N
                running_loss = 0.0
    4 C- `+ `, P0 s! C" Q7 u) k* y2 G            running_corrects = 0
    & ^# W1 b7 Q) G0 e. o7 Y$ _
    ( n; }# t, M- w3 N0 d            # 把数据都取个遍
    $ S0 H# y3 _. z            for inputs, labels in dataloaders[phase]:# K  x( x  F$ g, U7 k4 ^
                    #下面是将inputs,labels传到GPU! u. M9 R! a9 i
                    inputs = inputs.to(device)8 p7 i3 B0 u* [$ }
                    labels = labels.to(device)
    % ]4 @  q3 m. x, |
    8 w3 c  ^5 `0 v6 B+ b                # 清零
    0 i: Y: L% C; x& @                optimizer.zero_grad()
    . g/ Y& n7 S. y2 `                # 只有训练的时候计算和更新梯度( a, M4 y# b8 }. J
                    with torch.set_grad_enabled(phase == 'train'):
    ' M6 e! K: j  A4 X                    #if这面不需要计算,可忽略
    4 P5 y9 ?& B& a" i* R: ~                    if is_inception and phase == 'train':
      [; H  U* U& e; `# p8 ^                        outputs, aux_outputs = model(inputs)
    8 L5 j$ c3 E+ x8 i                        loss1 = criterion(outputs, labels)
    0 c- L2 K; H; L                        loss2 = criterion(aux_outputs, labels)7 e$ a; o3 f0 y; D
                            loss = loss1 + 0.4*loss2
    * e% [/ U: _0 K( h: [                    else:#resnet执行的是这里) F7 v' X$ G  q; P
                            outputs = model(inputs)
    ! ]) E& Y0 I! p* d- O                        loss = criterion(outputs, labels)7 N5 z8 G  T2 K
    8 E3 Y5 h" }" w% ?% Q
                            #概率最大的返回preds( v  ]/ c( C! h( S# @3 p5 j
                        _, preds = torch.max(outputs, 1)
    9 w! f  I- l7 n- M8 y% j' }
    3 ~$ i: n1 m9 W2 t- V( C7 Y6 |                    # 训练阶段更新权重8 d8 ], B) C7 p( I
                        if phase == 'train':
    ; @% w( x, {! l$ k2 j2 d                        loss.backward()
    0 w5 S2 }! S( D! V) e                        optimizer.step()
    $ z7 Q; Z' l+ k& ?. i, b
    / K0 s" ]* f) D* D+ C$ k                # 计算损失- b& z4 e1 G2 j' A7 I) O* L. s1 ]/ ?
                    running_loss += loss.item() * inputs.size(0)3 o" C) h- @9 N* j9 L
                    running_corrects += torch.sum(preds == labels.data)
    9 V3 b2 K! x- c  i% z1 s+ C* Q3 i
    0 W( f9 [, e) h+ O2 L! V% T5 f            #打印操作
    ( ~* L, w+ `2 k7 H2 n8 s' H  w            epoch_loss = running_loss / len(dataloaders[phase].dataset)
    - M! N0 e; o8 m) X2 v            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    9 F( F$ W0 r9 T4 Z8 e: ?8 f
    4 C+ [7 Y3 i$ l' }  t
    / g) n/ f3 u, e. n            time_elapsed = time.time() - since
    8 m( K0 s2 G; q: d/ Z3 H% u$ V            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    / z$ l  i6 q$ I( B  \$ ?3 A) D            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))) b3 J9 H" U' M! d4 K
    , V* n7 S- g. C# `0 [' [# B0 p* q
      G$ o6 h9 ]- N3 i; y: C+ ]6 P$ t
                # 得到最好那次的模型
    " p8 I1 n  v6 d1 O" T            if phase == 'valid' and epoch_acc > best_acc:: l  [  j& ~/ t4 ]
                    best_acc = epoch_acc
    8 j7 C( T- J; P$ l4 z5 o                #模型保存
    4 r% Z1 d. E6 M                best_model_wts = copy.deepcopy(model.state_dict())
    / n& ?8 |& R: b& M                state = {
    # a; c& G& d2 Y, [0 s3 u                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数2 M5 R* l1 n5 v& o; b) s
                      'state_dict': model.state_dict(),
    0 p" a+ Z5 V3 q2 `                  'best_acc': best_acc,
    - k, n. ~9 ?' k1 b                  'optimizer' : optimizer.state_dict(),1 ^' x; B$ |- u
                    }
    . I- o& M6 B: c# u% H5 l                torch.save(state, filename)
    8 v5 `( b6 S# }! G$ Z* z            if phase == 'valid':" `7 f2 r" L* n* [
                    val_acc_history.append(epoch_acc)
    , x, Z: S2 h2 `+ s1 M1 o$ m                valid_losses.append(epoch_loss)
    % y  R9 z7 |* O8 f/ F                scheduler.step(epoch_loss)( X) c4 s. Y: }- D( V8 m# M
                if phase == 'train':
    6 g3 i; ~, E' Z. y6 X                train_acc_history.append(epoch_acc)- p/ U* x2 D6 U0 c
                    train_losses.append(epoch_loss)' D3 Y, S, l* n" W' i
    * G& d" [( s) h" `
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    $ p% x% ~$ M' h" ?) e9 s9 M. y        LRs.append(optimizer.param_groups[0]['lr'])  x- l$ L# |5 E( z, ?' k+ r: b
            print()
    5 B# ~7 g2 C$ C
    ! Z6 s$ j% N% d: o/ l1 b$ t    time_elapsed = time.time() - since: j5 Z# ?+ L+ a+ q6 K& R
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    $ s- K7 S' W* u! Y5 g    print('Best val Acc: {:4f}'.format(best_acc)); }4 G$ y7 }, \) x3 b) T4 U

    * n& M9 z. V' X/ m1 F* W    # 保存训练完后用最好的一次当做模型最终的结果
    ; l, W9 N9 j7 x. C5 H    model.load_state_dict(best_model_wts)- q9 Z5 S) L# G+ A5 m' p3 i
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    6 F+ b2 |" b- y5 ?8 g+ m" g' a1 b1 g2 n! N; x0 l7 J" H' f  y
    ! X6 l1 u4 F( H9 V7 `
    1$ s4 b0 x- w! U3 \% b1 B
    2  I! w8 Y' i% f
    3
    6 }- o4 `3 b0 X' W4) t, F5 k3 W- Z  _8 X- O' q
    5  i& J3 k0 ]( F2 y
    64 k1 M1 m' i" h( W$ M* k1 D, T: C
    7) Q* Z( j' U, s8 x9 M1 I  n
    8
    & d( T! ~, F, c0 B9& R, ]' a# z- q* N* G
    10( w$ R5 ~" ?9 \6 a* V
    11
    : x; j4 A* Q6 ?9 M8 e12' C) Y* q# y& E0 G# ^. z' \  [
    13' f7 c" h/ p$ ]3 u. ]6 I
    14
    * g% W0 @1 Q: p9 c15& e- O1 _( D4 d6 D7 m7 C; h/ `, R
    16: V& b; [* ?: y! [; ?1 \
    179 {0 v9 v% I* E; [9 g. ^
    18
    : a/ y; h0 G7 V/ j2 }; n1 h( O19
    $ n4 o- w6 ^8 J' Q2 u8 p20
    $ L4 K% A* K% b& v' w: X21: i6 R- k& R) A
    224 n" C; t" j" y/ v# e9 N
    23
    $ y2 M' M2 E6 }+ A/ r; v' p24: l/ g; ~: ?! Y( e: Y
    25  y8 t# R/ q4 R1 Y+ n( @4 S: w9 I: v$ z
    267 p1 y, v1 ?7 K' t' e" I8 N
    274 P! ]" J# L1 i3 f% \1 S
    282 y. Q) B* b' D, m; u) }& l
    29
    ( Y7 y! X. t2 }& D: e. ]30& q* G* a  {6 {
    319 r9 r- n! N0 c, v( u" ~8 h$ m$ Q
    32" h+ T+ k) W0 y9 N
    33
    - I- S) ?/ S. @! A34
    1 ~( X, E# w& l5 O) o$ G35
    $ J3 c  k. O1 e368 }/ X, c2 l- x* u% ?
    37
    1 U5 _/ B" t1 _5 w. }38! n) u1 Y( c1 M9 |) |% R
    39
    ( ~. z, Q2 v+ U- }# Y0 i406 Y. `( q( M4 K5 M, `% Z$ U
    41
    5 ?& [' g* z' M* t3 [422 @% T0 z5 j2 l) u: z
    43! J5 l  v4 J/ f
    44
    - {; d) J4 J5 p1 s0 r45
    2 l% v' Z$ {2 r* K46+ {; |3 Q6 L0 H% x1 X; d
    47
    6 u0 d2 x8 `& F5 s481 l0 d( y4 |. V% D
    49
    " w8 c& W- K- }50% ^* B4 M9 }% k* Z6 O* n1 ^4 p
    51
    0 N4 u; _' W* ]- T7 t521 ?0 R, H# Y8 p; {: Q- k0 V9 X
    53
    : A- I4 ?8 Y+ W9 V+ f! q$ y54
    ' n: ^' z7 R8 s8 G& S" x558 E4 J5 i$ b* L5 w( I' G
    56
    7 @1 C7 E  {3 }* i, [573 R5 @' [9 T2 _( O2 N$ v' U& T/ D
    58
    / ]% @- C# u& O; i59
    ( t6 g( T0 n: C4 _+ f6 `+ o60
    ' l% A; O6 V6 U. ]61
    0 P( u- K+ @2 B+ K4 V3 o6 p62
    ) i. r; c; I2 y- \6 q% |" @63) j) d7 Z) t3 y: o9 P
    64
    / F/ B2 ]5 v+ f% {; x2 B65
    3 D6 I8 m4 |8 ^3 g( u66# K# l/ q# s) C, k; E
    67
    1 u  j7 O" Z' s. J% c5 w68# X2 R! o7 o% h+ W8 n
    69( K" \+ }7 s$ |3 @7 m0 j  f/ n; C
    70% a. ?$ b+ x: |' k
    71& i1 p( F; [9 _8 i# x
    72
    ! v$ i1 ]) \" g  [7 P6 R7 y/ c73
    / ]$ A( P8 g1 {2 w3 V  A74( W9 f# }- Y2 q0 E2 a
    75
    - p( ], Y5 U- k0 y+ ]* S76
    ) @* t2 t# i# v3 u3 O% o  U( m77
    2 N0 [+ y2 m4 k6 w* x: R6 @' w  ~# f785 r: g( ~+ g8 ~% D
    79: Q4 @  y. [/ l
    80
    / ]3 U8 e/ B' L7 r3 ]/ ?81
    3 P0 d8 r' _+ ?' j5 {# Y82" b/ ?, p% B, H7 W4 X  D
    83. }, Y- Q5 Y  C5 W, x. i; }2 {
    84
    / I5 z! H( ]1 o! H  z; T" L85
    8 t8 F( p9 f$ _6 w& f- o% Q  _86
    5 S: I& P* G. i87" F; @8 k! X4 w: R) H7 l) j( S5 h
    88
    3 x4 d: @: d# {# k1 T; \89
    / n% k( G; O1 @90
    : U6 g: |# [, L* l& n0 S& T915 l6 P3 r2 `, c' a
    922 y* E) x9 o! g/ ]! v: v" g  \0 y& ?. g, N5 S
    935 w' \+ X1 H8 z0 k1 r5 P8 v1 C
    940 B$ \: l( m8 k% K
    95
    0 _" p4 U* L9 d% p96$ |/ h* e2 M; G7 c- _
    97; k4 K) e3 y1 F9 j: A
    98- a2 V( G, g* ?( c
    99
    / [/ }9 @# Z; ^# s; F5 Y- r100
    9 Y# `! d# I, B) Q101
    0 }3 s% I  b8 m" O3 k102
    ; K( m3 {# _8 c( q5 u7 x103; a) G+ v, e  O, d; t* n( z
    104
    , Q2 m% z6 M& O$ F1055 _4 N& z. z5 v2 M
    106
    + s3 g. U1 Q6 J- G1 n5 l2 y107
    5 p( L  Q$ K' d- e8 I1 A2 \108' t; {0 j$ ^7 Z. \9 K, X& W
    109
    * T( A' x0 L+ J& N& k7 [110# R. P2 O" F, J9 e- E; c: ]
    111
    & a" y! ?0 I6 \% W: V% I* u2 X& ~: r1126 M' g/ ~9 S# O9 F
    7.2 开始训练模型) e% Z$ r) P4 l. a' m/ E) s- U
    我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次8 h) J' Q" \7 n( U9 t( a; A- @8 ^
      `4 A+ d1 A& c1 W1 ?
    #若太慢,把epoch调低,迭代50次可能好些& Q  B, z' j) Q, X+ Q
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合. {4 d8 t9 P" k7 T; ?" \
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    : {2 ^  G, p+ e) C; I% B5 H+ W& Z  a# e2 ?5 R3 R4 C0 y
    1
    * F: N9 x* W" P1 t2  R  C# e1 j& f( o9 b  N% j
    3. z$ F# _* V: Y# [" C/ w1 D6 ]
    4' K/ L" ?1 Z- H* k$ Q% d
    Epoch 0/41 u  D* [, q- T) z6 t8 n' w+ |6 N
    ----------
    6 c- X% ^# q9 TTime elapsed 29m 41s
    ! x4 d7 A, A% [train Loss: 10.4774 Acc: 0.31475 t5 @! J5 N' z% O; i+ n* H# S9 h
    Time elapsed 32m 54s
    / _# m8 v7 l0 |" u- kvalid Loss: 8.2902 Acc: 0.4719
    ! S" O* W* G/ N% ~Optimizer learning rate : 0.0010000
    / ?- T! d! P% ?$ N/ X. T
    3 _; l2 y7 J! a: n4 N( SEpoch 1/4
    5 E+ W/ P; g+ u. I8 n% }----------
    4 Y5 B& O7 P' h; d0 [0 d6 r( xTime elapsed 60m 11s1 U2 [7 {9 T4 a# u( z4 R0 q
    train Loss: 2.3126 Acc: 0.70532 ~2 L: Z4 H" W
    Time elapsed 63m 16s. j0 b+ E7 V6 M
    valid Loss: 3.2325 Acc: 0.6626
    5 o/ p: Z4 k' M! d- FOptimizer learning rate : 0.0100000# w; a" W: X# \* U6 i  u5 w  C
    5 [" ~% l; X3 T* t, r
    Epoch 2/4
    6 c: H( _/ Z) U: J----------
    " c( F1 K0 a, O: e9 W% `3 bTime elapsed 90m 58s
    ) N( N2 g- l# e( \& r- v6 `train Loss: 9.9720 Acc: 0.47346 @: L+ t6 F. m$ g' B! X. |2 Q
    Time elapsed 94m 4s; B5 G( ]& w" d
    valid Loss: 14.0426 Acc: 0.4413, G! q1 o7 b, L4 E5 X! k1 U; U
    Optimizer learning rate : 0.0001000
    5 V3 ~# @" R# L* {3 \: h+ W& W5 j0 j' K' B
    Epoch 3/4
    , N% C+ {, b8 c4 g----------
      i$ T1 Q0 R. M$ G* k6 iTime elapsed 132m 49s" o1 H- x& t8 @7 O6 R0 q! C
    train Loss: 5.4290 Acc: 0.65489 h! o* x# f% C5 O  [
    Time elapsed 138m 49s! {& W+ M3 \0 v; S  X7 I
    valid Loss: 6.4208 Acc: 0.6027
    0 }. r# s) E$ o) ], FOptimizer learning rate : 0.01000009 o! }$ {8 s3 {3 g- l9 E- t

    1 g- f1 p5 M, Y9 I7 |. O" QEpoch 4/4
    . ~' E. t2 _5 M! p: y' c8 p, x& K" f----------4 I; E0 e8 w) W0 Q2 {0 S7 j
    Time elapsed 195m 56s* Y- ~) ?0 `* ~; u
    train Loss: 8.8911 Acc: 0.55190 q! N) G) }2 u; X
    Time elapsed 199m 16s. }! L% |4 s0 m0 z( y' g
    valid Loss: 13.2221 Acc: 0.4914
    " v' \" K4 O1 IOptimizer learning rate : 0.0010000, a; ]; m8 v. S

    , Q- Q  L! l' ?& m& s/ C" T3 cTraining complete in 199m 16s/ \6 m) }+ d" B) ^
    Best val Acc: 0.662592
    : B; u' z% l4 C: {( ]! P; e" K' e0 [
    1
    ! m, ^* ]: z; u1 E! d& X2
    ) B+ m6 M' F* w3
    : d! j3 T. t: t/ q8 O/ i! Z  B4% [7 d" P$ b1 N; x" J$ h" N, Y
    5
    6 \. w0 G) ]; }2 D) C' i2 Y6: m7 b1 Z0 U# P& x
    72 }! ~( _* g1 _/ k+ Y& a
    8
    " F# L, L4 k% r: [$ W+ A% K9
    3 w* W# n% e4 X7 p10! G7 t5 u) g6 d
    11
    3 A, r$ ]) o) Z123 T3 H, D- X! ]4 Q
    130 D7 `* c% D8 N& C8 f1 ]% @
    14
    : |' [/ {$ _3 c/ Z; f* B15# l/ X3 r7 M) S. Y2 Y
    16
    3 P: }' _- k; s/ Q# N7 \: T" [17
    / {8 T9 b; `1 g4 G2 \18
    + Q8 }3 ~& {# {6 W( S9 Z19
    0 N1 p8 Q4 J8 @2 m8 j20
    - L, |' k" ]5 v% }21
    ( w. ]  p5 ]( K  G! Y: o222 J( U2 d9 V( X& O8 ?: V" [% G
    23, K' C' d/ H6 \
    24
    ( s% Y% i. p- K7 r3 B7 P  v: b+ u25+ c) T' R4 {( s' b0 N0 s6 F9 w0 \
    26
    - L; P# ~! n! y6 W  n# v27
    , Y/ F" ]( i8 d* \. a28+ [4 a0 ?2 n& m, U* ~
    29
    & n1 k6 p1 g2 u30# u" _: k) H) L8 {4 O  O: M7 D1 z
    31: F3 U* i& d0 y. u
    32! x1 D- y- g$ b. e& i2 u, @
    33
    ( `( J: |( R( F+ w+ i! a347 l9 k1 u) F# k+ P; Z
    35
      W) X+ F% w  M$ g36
    % k; J  b! A+ `9 D+ L  U4 `5 ]$ t37
    7 d0 O8 _/ U7 s9 M  x( d38& \9 M. l0 r: J1 L& I, P/ K
    39
    : C! k/ |6 d! U% S; u4 h+ r. \0 d40: u8 b1 z. y# B8 ^1 k6 Z& k: b
    41+ ^) K; ~# e4 B; D4 t& x: L' E
    42
    7 H7 H( M  T0 H* c, V5 h: t  Q7.3 训练所有层
    3 o( h7 A/ h! _$ T# 将全部网络解锁进行训练
    $ y& F  ^2 t+ x3 k3 [for param in model_ft.parameters():
    # l) ]- G5 j/ u+ q5 c% v    param.requires_grad = True
    % C+ O* o% m: {6 ]: m: R5 A. a+ Y1 a3 D( ~; }: y. P: Q
    # 再继续训练所有的参数,学习率调小一点\: V' {% r# ?1 q7 p, c
    optimizer = optim.Adam(params_to_update, lr = 1e-4)
    / d# _4 n& j! z2 z# B# N' k) ~* q3 [( i- Zscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    # I: [" m' l9 T8 ?. e9 L2 Q& T/ _/ T: w1 F4 g7 i6 h8 k
    # 损失函数% ~5 y/ a7 c' `5 D- u2 @
    criterion = nn.NLLLoss()" Y6 r! O5 L6 D6 ?3 B' V
    1) }2 n5 `$ U6 X
    2( ^6 V  _  _+ j1 ~! O7 C
    3; [% l1 @5 R4 h3 }% ~6 r* U3 g- B
    4
      V* m# _1 h- T- E5 a6 `# X9 ~5; I5 l) t6 O) K/ `
    66 R2 S3 i# Q: @7 h- Z2 h  s
    7, w3 @( g9 \: w1 r
    8* |, t+ j' j- J
    9
    / O& F' f" m- O' Y" D: _, |10
    8 C# Z: t. @0 ^8 D3 C- ~) r3 Z# 加载保存的参数5 j3 B2 B" t8 H- b
    # 并在原有的模型基础上继续训练% f9 s* }- z  F6 g
    # 下面保存的是刚刚训练效果较好的路径
    9 {& l/ S5 D' w3 m% C3 n' lcheckpoint = torch.load(filename)
    0 M' G! u" i3 {0 Y2 ibest_acc = checkpoint['best_acc']
    ( ~# |7 {2 S& r4 V/ [; h4 Jmodel_ft.load_state_dict(checkpoint['state_dict'])
    7 d$ R1 M  ~' ?* {6 Xoptimizer.load_state_dict(checkpoint['optimizer'])) L; ]) H2 j5 }  R3 R! ^: D
    17 a3 M5 r1 l* x
    27 g  U+ w$ X" T
    3$ t: J- C+ H4 o9 Y/ p# G
    4! J( C4 ~  ^, c
    5
    & z; K& O8 }" E/ a60 Y: S' X8 R& w
    7  V: P, C  v% Q4 P' e, ~: x. {
    开始训练
    $ t, x0 U% c* H* D* p2 c8 ^+ B注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    ' \2 w& m# ~7 O7 ~
    ( n$ i; |* [+ o# a; Q& p/ D) ~, Lmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))6 k. C9 \/ m5 v
    16 }1 r& C$ e6 \- h* L9 `) u
    Epoch 0/10 M8 T5 N# b: |1 [8 l
    ----------! H6 w# ?: q' }6 \
    Time elapsed 35m 22s
    9 p% z; G" F- V4 Jtrain Loss: 1.7636 Acc: 0.7346* U: s! Q4 S+ h
    Time elapsed 38m 42s
    1 v' z7 A. M7 ]/ U7 ~1 o( e' J9 y- ivalid Loss: 3.6377 Acc: 0.64555 ]8 Y( T) h/ `( i1 ]
    Optimizer learning rate : 0.0010000
    : y6 v; O/ J% B! T  ?& |4 Z' V1 H( e$ L8 o( I$ A9 T
    Epoch 1/10 W* N+ g5 Y& y1 Z5 S( z
    ----------
    " ~/ \! v3 |, Y* ^' m" BTime elapsed 82m 59s8 N6 |7 ?  z/ m! o% ?5 O
    train Loss: 1.7543 Acc: 0.73405 c- ]7 z: T" u/ m! K! k, T& H
    Time elapsed 86m 11s
    # N" d% I# X; C4 N; }; Nvalid Loss: 3.8275 Acc: 0.6137
    / M1 E5 \6 a8 H6 i# K% NOptimizer learning rate : 0.0010000& s' f" X: H0 Y' m

    % j9 J0 D" H8 N; |7 J2 k8 WTraining complete in 86m 11s% M! @( v8 [$ l. S- \' r
    Best val Acc: 0.645477
    ; J' R8 v$ H, K: U* a. v7 O; r  l* s$ c1 h2 u9 b0 l
    1) z' S! P3 H. i$ O5 Q; R5 i
    2
    ! L7 D* D4 b; w% ~3* a& J8 }7 M2 \+ C5 r3 n" |
    4/ s8 Z8 t9 H1 A% g* ^
    50 |: g, }: r) L$ x! U- V4 f
    6, q0 T! [$ F' ~) e; a" l; ~
    7
    4 ]; n4 ~6 X9 S$ O. v' M: [8
    , j1 [/ D4 p  F- V1 E9
      M  i1 t( w& |* g10# \- V* [) d0 c8 a$ n0 x+ ]+ ?
    11$ O6 x/ R1 h- J) U1 O1 D
    12
    % Y! F2 c% M; ?; A$ p139 r# M9 r$ i! q: `1 u5 {
    141 @3 I: |+ R& L9 d
    15' @7 K; e0 j# [4 F$ ]; I% P$ j# d
    167 x  z/ m. d4 c  |/ r( V
    17
    % j. ]- [- d) e; N3 e; U& v8 A4 I* g18
    7 Q8 n2 O4 _8 v  K( ]8. 加载已经训练的模型5 ~# i# F# I& v- _6 q7 j  F
    相当于做一次简单的前向传播(逻辑推理),不用更新参数
    3 h9 C9 r$ l9 ?6 \6 p' i6 I7 A  [
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True), `, x/ M' ~8 y0 @: {9 C6 D

    ! Y5 ^) g- b* v1 @6 l6 T. V  `# GPU 模式
    6 S& ]/ g" D" e4 ^- E* cmodel_ft = model_ft.to(device) # 扔到GPU中
    " Y2 i8 G8 `# X
    & \% C8 C* D* L! B5 ]# h+ X" _# 保存文件的名字4 ], T) t4 T  Z8 f) ?
    filename='checkpoint.pth'( r8 r7 k( M4 {9 G6 E

    3 n3 Z* g2 N& D7 L' o0 \: A# 加载模型0 [3 I5 ~& _. I% ^! b( j# {& D1 ]
    checkpoint = torch.load(filename)
    ; y+ x. V) M0 J$ ~# x5 ]best_acc = checkpoint['best_acc']8 ~) M9 E7 D% Z6 T. J* F$ [# m
    model_ft.load_state_dict(checkpoint['state_dict'])
    / d: ]7 t+ D- C( v1
    . m& j, p9 E& m' Z0 x- ], F2
    % a0 b. p7 _4 `! N- Q- B3
    + ?; v" S9 {% B: a- m7 Z* K4
    5 |, s# P1 x; U3 a( I5  X  O8 F# e7 o. }  D" l! F
    6" l9 P+ G% c+ V
    7
    ( W  i7 E* n3 q/ f* U, I88 C9 {9 [) S0 v, S: G; D
    9& t6 c% q0 ^  V8 W  n7 c: X! C
    10
    . t. C) b- N9 W119 [7 ~6 m# X1 \7 s! Q- T( T
    12
    4 x3 s9 ?- Z6 h0 T  [% n* B<All keys matched successfully>- Y6 u7 Z$ D# p* X3 l/ W3 Y
    1' i- L, W# ^0 U5 |
    def process_image(image_path):9 f& d5 C. Y2 i# B" V
        # 读取测试集数据
    2 N3 v' A/ M' f# `    img = Image.open(image_path)! U7 U) p* U# i
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    8 {! M* V, `3 o) S2 E+ B    # 与Resize不同' d# D/ v6 Z& Z1 m4 Q
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    - V( C( r: g0 v  z    # 而且对象调用方法会直接改变其大小,返回None! x) a+ P9 T$ W& Y6 M
        if img.size[0] > img.size[1]:0 z+ b7 {! I: P  v3 T
            img.thumbnail((10000, 256))
    9 ]! I4 w4 M, u( F* F, [    else:
    . b5 a/ Q$ Y' S        img.thumbnail((256, 10000))) B* k: k1 t( `( h3 E  V2 [5 d  `% m
    * j( v& k- d, x; u5 M1 |
        # crop操作, 将图像再次裁剪为 224 * 224
    . y3 q4 Y5 q$ W& e  Z: j4 w    left_margin = (img.width - 224) / 2 # 取中间的部分# v+ @' j% j8 }) G$ R0 O
        bottom_margin = (img.height - 224) / 2 # X4 |' N% V) s+ m$ q) {
        right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度! I2 N6 m  E2 ?5 r
        top_margin = bottom_margin + 224
    / L- M9 M+ n& ^5 U  f9 ?+ ^0 ]  P% [& [. j6 t8 a9 N
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    + Y. b4 j9 g# Q
    6 a) w1 T5 M# m( V, q    # 相同预处理的方法( X, Q! h- _2 y3 y* b( N
        # 归一化
    : v; r. v! ~+ A! h2 ^' Y- @    img = np.array(img) / 2552 Z; j1 M. o/ J4 h+ B+ o) Q" C' |
        mean = np.array([0.485, 0.456, 0.406])
    $ _7 T: M4 g# d( q0 f    std = np.array([0.229, 0.224, 0.225])
    6 r; B' H+ O, k3 ^, V    img = (img - mean) / std1 h( G- c4 I% H7 P0 p/ @  f

    8 h( a- M  `# |( T    # 注意颜色通道和位置3 o. E, y/ u* u- ^" p+ U/ t' L
        img = img.transpose((2, 0, 1))
    9 J7 F9 Y7 X' J, m) B8 |$ q( F* C% K4 L7 A- f6 @
        return img
    3 D# i/ n$ M0 m; ~, U7 v# {5 k% q3 U+ A9 W) r% k& \6 A4 g
    def imshow(image, ax = None, title = None):+ y+ f2 }5 c% D
        """展示数据"""
    6 _) j. w# s$ s* J: _7 ]# F    if ax is None:
    * o7 y; h) ~! [        fig, ax = plt.subplots()" F% w" w7 E' h5 C/ W3 d+ e) d
    7 _6 T! s3 y- `. g1 Y8 S
        # 颜色通道进行还原
    ( w$ P- m, }2 n0 s' Z/ G8 [4 }    image = np.array(image).transpose((1, 2, 0))- J: ]/ l; G3 e$ C7 _0 G3 O

    , t9 p5 N8 n" P6 T2 P    # 预处理还原/ _3 Y; L8 D! K$ @
        mean = np.array([0.485, 0.456, 0.406])
    2 B) j' |; ~9 X( Z/ S1 L7 u    std = np.array([0.229, 0.224, 0.225])2 I' s, H! Z( K& Y8 e
        image = std * image + mean3 d; g" R, s( X/ _
        image = np.clip(image, 0, 1)
    : u4 G/ V3 N2 w) t  K  ]2 q; f& X" {6 L7 F1 ^) D4 R3 d: O7 b3 U
        ax.imshow(image)( w& M5 a0 u) ~  ^; a1 u1 H
        ax.set_title(title)
    7 U7 q: o; k4 J' E- i+ \2 ]3 J! Y% q( {) n0 V' e9 V2 k
        return ax! k3 x; d( B, V8 l" e

    , X0 ~4 O! B1 P( @image_path = r'./flower_data/valid/3/image_06621.jpg'
    " r! @: q5 X/ G- |. Simg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    + v2 B" ]4 |; q6 Bimshow(img)4 w& I# ?' f; O: [& q: e! }
    8 f: }- t5 y$ F' ~8 z
    1
    7 o! l4 T6 h  \21 Z% v0 R9 m9 R
    36 Q# |1 D: ?; q  {+ x7 E/ J& S5 ~/ {
    4
    , A* t6 O3 h0 U! f( Y$ ^5
    ) `/ u' e2 R$ R6 n/ g9 G6
    , ^7 l  ]7 K- F5 m6 B8 f* |! R( c- ^7- W( o3 w% F5 W* S5 F
    89 f- Z- U9 H' R( t
    9
    " j6 [* i( E% ~7 |& l. B8 w0 C10$ o$ J# v% T4 j/ D& w
    11
    6 u/ y8 z- R+ z12) b& q7 ?2 V) w; @
    13
    2 d% v2 g, U* A% V( K148 Q+ }+ K6 e" q8 N& J+ i/ W
    155 u3 m& v: ]4 m* ^7 ]" z. V
    16
    ! ?  _/ h( M2 e* L! v17$ w% M% l) I. q3 u+ y
    18
    - v- N9 T, m; d2 M  `+ d( j19& k+ n. M) X- K0 U2 @, ~
    20
    / X& c! w& L' B2 q/ m& b1 m21! U/ D# t  {1 I
    22! o+ R* \3 ]. [8 x5 ~7 p+ v4 q
    233 N' M2 p% i* |8 s) c+ Y
    24
    " ^4 `1 h; e) q+ @; X: i258 ]' s: ~. A' V) L2 R5 `2 }/ i
    26' x/ N; v4 _2 R0 l3 {8 t: h
    27) G, z2 {5 n/ a+ B/ G, e
    280 a2 I- C, Z. t( ?1 @$ X! P
    29; S: H" C/ L/ ]% v
    30
    2 H/ ~: M( |, B# }! x5 w316 U' z. E& T9 P  }9 K
    32
    ( x: o( _& u0 V" u. I335 [! F- F: n7 I9 `  |9 M4 T3 Z
    34: o' C0 i8 Z& a. j
    35
    8 a/ M& K1 D) ~. C4 h2 P  I3 r' J36
    - j/ e0 x  m8 d2 \& x; y$ _37- z% |9 _7 r! X4 w* W# a- j
    38; Y' `, }) \8 D  y2 T1 z
    392 m1 r6 I  g$ ]# T2 ^8 O- ^/ e
    409 @$ P8 w8 E( l) y5 [4 \
    41* R& p8 \% v' \
    42! r$ ~% Y; ~- f0 @8 F
    43
    5 v8 C% G. D& N+ u; c44
    1 B1 _/ b, |/ w2 @6 {. \45
    $ w6 \/ K& ^2 m4 ^, i; a46- ?- {: _, ]* N1 w
    47
    ' `9 {8 o2 @8 p. v) H487 M  N+ O; ~$ I! q0 z4 v2 d$ B
    49
    2 Y% e- E8 V4 m( i50
    1 o" g3 |8 z% ]# g51/ }( j- u) |& k8 H
    52! p) }+ c  }2 W/ t, {
    53! h* ^3 Z4 F0 _( ~) g
    54+ ^1 k6 t3 k2 i8 A1 H- i
    <AxesSubplot:>
    4 o1 `3 O: Y, j1 G' Y' Q4 N2 b1
    1 H5 P3 N1 o& @9 _1 k/ m8 V# L7 m
    上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    " u0 k7 `6 X5 C' ?4 {
    0 X% ]2 `, ]$ ^; K& x0 dimg.shape
    ! m5 m- f5 H' r/ x5 z& v1$ K3 `3 Q, B2 k# J- a
    (3, 224, 224). c$ \: O+ Z$ i: n( _1 `- n9 r
    1" C  \; g" X/ s8 M, A* L! s0 V
    证明了通道提前了,而且大小没改变
    5 N9 w8 S( D8 {( Y3 [, \( |' d( @- ]/ b7 v/ j: P
    9. 推理
    , G. x* w( p9 J+ h$ l2 n$ \7 ]img.shape
    9 b2 B7 t; p- t* [8 J2 X3 `  {1 R' J# O
    # 得到一个batch的测试数据
    $ o/ F+ ]5 Y* g( _$ A8 H4 `dataiter = iter(dataloaders['valid'])
    1 }% u' v. P  I: n1 p/ d* Rimages, labels = dataiter.next()# K' ?) L  c6 t5 N
    4 y* a$ ~0 w( I: B
    model_ft.eval()' F( I. M+ f  M) K7 p- n

    0 |9 o* _, x; Z4 V/ }- A% r0 F0 Cif train_on_gpu:' H3 q5 f# \& h+ D; {: @
        # 前向传播跑一次会得到output
    # @5 _, h, G( c! j& h( x. b  ^' h" d0 t    output = model_ft(images.cuda())
    / E+ b  V% w4 l* X3 ?else:( n" N1 X  F* r) ^; N6 M
        output = model_ft(images); H6 e! |8 [) g5 X! O( Q* s! e: L

      O/ e1 T' }+ `+ X# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值' s+ u  Z# D& D1 K. d9 K, `- Y9 U
    output.shape
    2 P' f, l/ i9 V9 ]# D
    9 s$ l2 n4 W# G! U9 q1. h+ L' P! ~, d1 o5 a
    2* ]: U: ]/ n- I
    3
    5 }: d* }" U# q' M: ^4
    ! R, g. a7 P" w) m+ e" @5
    3 D- v6 v+ X$ |' Y/ h, N: @6" T4 n5 H1 l" Z$ {
    7
    # J# `3 X5 A2 b7 R% V! [7 r5 D81 a, q& @% P* N) i. S
    9' p+ l' L+ p: L) G( t
    10
    6 L. l' @7 U! ?- y1 P5 i% ~: Z11
    / D1 n9 h: u# Q0 m- v" V# a: S12
    , P3 E6 {/ ]+ a135 }" w" W* H9 u/ r- q+ f+ E
    14
    ; Y: k( [9 S7 `# x% X2 l% Z8 \  |9 b15" @# k* Y/ \' {% w& {7 E
    16% c( ~& Y! K4 H2 n# g( v/ U
    torch.Size([8, 102])
    ( D5 f1 W" K$ O: e1 i( d* m1
    1 W# D: l; A! M' U( |1 `9.1 计算得到最大概率
    - }) a  w* _* E1 X6 I1 M" u2 R3 B! N  L_, preds_tensor = torch.max(output, 1)
    2 I. M; S1 s5 C- v+ k. k+ N2 E
    5 |& j& B" u* i! ]; Q: Apreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    , B2 g5 s, J8 q) N, G1+ e4 }$ N& n$ s4 U
    2
    ) v+ w9 _2 _7 i( i/ U35 a1 J2 g# T$ Q2 y
    9.2 展示预测结果
    & s# S: k/ w( Q2 A& afig = plt.figure(figsize = (20, 20))
    ; r5 _- C3 m0 K# @/ F" |3 lcolumns = 4  t$ ^4 [/ W2 w, [  {8 A+ l) h& D; y) L
    rows = 2- M% t' W% y) [( H" @5 |

    ; [5 W! E2 ~) \for idx in range(columns * rows):
    ' |3 `$ a( T4 x    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    6 n6 i% m) N6 F( B. Z    plt.imshow(im_convert(images[idx]))
    ( i. Q) [" q  s0 {4 G  U( A    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
    : [/ I/ n0 d' |/ \1 |6 ]# P; {( T! K                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))5 a4 B+ l) v7 j
    plt.show()# G7 ~7 z# k1 P  t7 n
    # 绿色的表示预测是对的,红色表示预测错了* d- v/ Q/ }- Q# ^
    1( V- H7 }% l0 M7 M( i
    2
    + ~( X! R; |; x+ C31 W4 A$ V% Q' }' P! x/ h7 u8 l
    4
    $ ]4 s6 C* _- J% `: V; o# ~6 Z5) d& K9 A6 {+ d; e- I7 P7 q. f& u- {
    6" u& x6 R+ L. Z! a* L
    7( J6 x* M; n1 o
    8
    " R  d6 e- ?; N$ q8 g. X99 L( f7 z# n6 X! A9 M  d- C' k
    109 m" j# F# q: P7 S, I6 z( o) k
    116 H. P( t+ @  Y, {* x* q' D- K

    ' ]8 k* W- B( M0 E; X
    * o: }+ L1 l9 w, |4 ^3 o0 }7 x" ?  l4 Q% z2 P- b$ W. ?; k
    ————————————————" O) l- @! V* T* {: W
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。2 E4 `/ m0 L6 t/ ~
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
      b" R8 q- S* j& O/ S2 y( a) o! E9 D( r$ Y

    3 K* M; v( r! f, ^5 Q
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2025-8-4 12:48 , Processed in 1.328087 second(s), 52 queries .

    回顶部