QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2579|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    ! b3 t& H+ k" t& I0 M, S( N7 r
    , _" c) z; i- f: W" ~文章目录# n( D* H/ A. ^: o
    卷积网络实战 对花进行分类9 s8 [; E# H9 V( W* n% c6 }
    数据预处理部分
    - {  e3 C2 c+ }8 A% {4 w网络模块设置
    2 z( ~) v: W0 [$ R- v( O网络模型的保存与测试
    / D0 A3 z+ g2 w% {& H) d0 z( z数据下载:% B0 v7 a1 F. h5 ?" H2 c. w! }
    1. 导入工具包- J$ x. t& C9 ~" E; |& S
    2. 数据预处理与操作
    ! r4 B. j- Y. Y2 |1 o! I3. 制作好数据源
    / H6 K4 O3 Y- I. Q. W% e读取标签对应的实际名字
    3 Z% }4 A% @. C& [% u% t& x4.展示一下数据
    ; K3 m6 ^2 {% ^5. 加载models提供的模型,并直接用训练好的权重做初始化参数
      w" r* [5 Q9 Q# j* D; N6.初始化模型架构
    8 V7 r+ G/ H  o; P9 V$ D5 F7. 设置需要训练的参数
    & X6 ^4 J0 |; d3 S7. 训练与预测  ?- n& p7 y" U8 h
    7.1 优化器设置8 y. L6 N7 x+ I0 T" h" G& o
    7.2 开始训练模型( w( n" R3 n& s3 O$ Y. n( e
    7.3 训练所有层, b& @; q1 [" C% p
    开始训练
    : l8 x* ~% {' O8. 加载已经训练的模型' h5 Q, _9 G2 b& ]( q
    9. 推理* W# D8 r) b; B6 c! {" \5 Z$ X& _9 K
    9.1 计算得到最大概率1 u; g# }/ N' W7 L' y0 E+ K) \
    9.2 展示预测结果0 b( m: \2 I3 S
    写在最后
    % c5 Q% Z, F8 A. n9 u; b6 s卷积网络实战 对花进行分类2 j% U7 E8 x; O$ b$ A! I1 Q; z: G
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能$ G" d! x6 m( K7 |) X

    5 _- X+ t1 M/ y7 }. [# J+ M在文件夹中有102种花,我们主要要对这些花进行分类任务3 n6 X4 @+ r% K& q( I* m
    文件夹结构0 B2 i. l; E; @( p, P# P
    6 E4 I% D4 {$ C$ b( e" @0 g
    flower_data
    ' i( l' z7 |1 r( J8 |! Z
    4 w$ c& @8 A0 h* Itrain- j1 I5 T( u0 o* O

    5 D2 g5 E5 w. ^+ S5 B8 I$ V1(类别)
    ( L; V* Y5 z9 i; e# n: G2
    4 D: Y; N( L0 p. M% E8 Bxxx.png / xxx.jpg& N8 X$ u6 _6 R5 H0 C$ h
    valid- `2 j! P  U$ ?6 }% _5 U# l& r
    8 e$ o8 F8 z; a$ g
    主要分为以下几个大模块$ `4 A: s2 q  z9 s" f3 |* ~
    * ?/ J. Q& L5 t$ B6 Q2 K+ z& n
    数据预处理部分
    7 X6 [: X$ }: [7 ]* J" z5 b数据增强
    1 ]! I! P! S- v1 Z数据预处理! C( r, J$ z8 l6 y. N. z  ^
    网络模块设置0 I* Q1 q7 |' u, z+ |( q* ]
    加载预训练模型,直接调用torchVision的经典网络架构9 n$ g6 h; B( C0 B2 e
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
    # y3 [6 z" p/ N6 ]/ b; N网络模型的保存与测试/ z& W0 m' L! C* ]! e
    模型保存可以带有选择性
    2 e5 {+ v) k" c/ F& R* V7 U数据下载:
    / ~! e2 h. d2 D: {9 Ghttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    , L, K3 V5 b# j
    9 y- s5 y/ ^8 k3 c$ o( p# }, ?5 B( z改一下文件名,然后将它放到同一根目录就可以了+ o: b3 R* ~$ x& ]  B4 Y; ?5 I
    4 ?7 l) Q9 N( G9 h7 d  {
    下面是我的数据根目录
    ( e- S3 H0 W! c9 T1 H2 p! B) K$ J9 R  [9 F4 l/ c

      v) i$ p- {! t0 f. C9 n( J4 X1. 导入工具包) B1 T' r0 Q6 @1 Z# S* C
    import os3 |' {5 p" V( }/ d# i& V. k
    import matplotlib.pyplot as plt. C/ l  |2 L6 a1 }( }: s' ~
    # 内嵌入绘图简去show的句柄4 h: t; t+ G- f" O! d5 |
    %matplotlib inline + t8 n# e6 D( |) C7 [
    import numpy as np7 m/ Y- u+ p) U5 j
    import torch! `* _7 V* R0 v8 Q4 h
    from torch import nn
    0 @  G# `# `! b; D8 N6 B3 \3 g6 b1 {# y& ?9 y4 j1 x5 h% P
    import torch.optim as optim7 L+ J, p( [0 M8 m; i+ m6 U
    import torchvision
    5 F: H( F( G- K9 u5 T3 @8 @from torchvision import transforms, models, datasets; i3 Z  k& u" [  b

    ( ~1 o+ k+ V& D( vimport imageio3 F2 c2 \: }$ {
    import time
    " n+ l  ~* y  q# {8 Q  uimport warnings
    , f% n' f6 L5 C/ N/ I  t" rimport random
    1 _! c* k# t% W9 d# V% Kimport sys) j7 g: J# x0 j! G( n
    import copy; d9 K. v0 q) o( f) {* Y. Y
    import json
    * _1 }: p8 a# ^, g6 O8 Dfrom PIL import Image
    8 Y! D5 A7 |( z) ^& C& m
    $ E  M; c' J( B, U/ `* D
    * b, i5 f$ b  d( g; `: Z; n/ b6 U1. k, t9 K6 o7 ?/ @. C
    2! W2 t% j  j( F1 c. l! |$ T
    3! r' e2 l# j! A
    4
    % [2 c4 p4 k  |0 N& r9 L* k: G+ m59 n- W4 p/ [# Q5 u7 n
    6
    8 U+ [6 E4 b) F" N) I9 ~1 \7
    4 d" k+ o# D. F  F7 t% ]8
    # s0 j, r3 S* i. R9 Q9
    3 ?+ f6 D, ^. V% c6 k10( w& w% x# k! t$ U) {
    11* B( w' P0 ]% i7 l: B( q
    12
    7 O4 i3 t2 E# V9 B: j) r: E! Q13$ N1 e! e2 @$ r: J
    14
      M% o  N" I) E! N15
    . E4 z2 C6 s5 J  ?6 v: f- P2 o16! f7 h' i3 T' f6 w, I
    17* l  e5 k+ ^$ R5 O
    18' `0 W+ M; P1 E" N# a
    19" o! L, m8 P1 v( d% N
    20
    ' c7 {! {  N( B5 H) \; v211 n) J' ^/ V" G3 r, b  E6 i
    2. 数据预处理与操作
    : U0 P* \/ I$ N. T/ k6 `0 l#路径设置
    ( H5 Q9 ~  @+ h- a  d, A/ rdata_dir = './flower_data/' # 当前文件夹下的flowerdata目录0 Y5 J! P+ W: w; s4 L
    train_dir = data_dir + '/train'1 V. s) h6 g/ Y0 z
    valid_dir = data_dir + '/valid'# k+ X0 i: G# \* U1 D
    18 J8 |2 K. }; t% w
    25 _: Q& d' }- I- o. ?
    3* O& H+ s" b2 M& M) ~, C  B, B
    4
    7 L1 G2 f2 i  z% T$ Opython目录点杠的组合与区别" Y8 ?, E. M& w7 c( w4 ^5 p& G
    注: 里面注明了点杠和斜杠的操作3 ?4 A* d, U2 v* U9 |3 _$ Q. H
    $ w7 |0 D1 K8 j1 _
    3. 制作好数据源
    : n1 f7 s' D7 `4 I2 Sdata_transforms中制定了所有图像预处理的操作0 i( G$ W5 w) }, q( F6 c
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    # z# y; _5 F1 I! B& [data_transforms = {5 N9 k6 @4 p) F4 H% s
        # 分成两部分,一部分是训练
    : Z  z- [" ~! h3 _; h    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间) e$ U! \( h; b5 i
                                     transforms.CenterCrop(224), # 从中心处开始裁剪
    0 `# k  k! y; X1 X3 m0 i' o& [8 p4 x                                 # 以某个随机的概率决定是否翻转 55开1 p: F7 _9 K$ e3 ^7 k# D4 W
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    1 n* U3 w( k$ n: n- y/ f                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转$ h9 K* }# J' V6 ]; v
                                     # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相4 l  ^- c1 J# U9 X$ s4 h/ _
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),2 b" d3 M' v6 `9 T& a$ @  |
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    0 y; \- q8 B+ P4 Q* d' |                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的" W/ p4 J# T9 K4 n9 [; _9 M
                                     transforms.ToTensor(),+ x7 m3 E! Q5 q; g6 v( n
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差, `% g! _: e$ g) v/ Q
                                    ]),( U) L, i4 V3 f$ Q% W% @$ S
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    ) g/ s& {9 R" M8 d8 X3 I    'valid': transforms.Compose([transforms.Resize(256),: |+ f8 Y5 @( F( Z* U
                                     transforms.CenterCrop(224),, [2 Q4 {4 T; \: g- {
                                     transforms.ToTensor(),9 G: _( l2 n8 h8 K: J
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同+ X0 y" ~, |/ l% e
                                    ]),4 ~2 e: F. [9 S
    }
    8 i. A1 R$ n3 [& B
    & }( r- J( V' R% Y( \0 i- C. s14 F: V& N% K: V% d9 g& |
    29 b8 d2 c+ y) R1 B! b) w
    3; }8 Z# V9 ^6 O4 k
    4
    - M/ u5 o6 @7 {. L58 f: u& M1 R+ }1 D3 t% {
    6
    $ ?3 r  y, K0 h' _, v7& L* }: R& A4 U  I+ Z/ l
    8
    9 V. O2 M! C4 n: f99 ~& p2 f  i2 `, u2 c( T9 x! I
    10% J  O4 A7 Z/ K1 `
    11
    - s! g. b  S/ |  X7 z* O12
    ' g. K2 y2 O; F* [; @; J4 m# `% Z( C13; K" r! n5 @, h* y9 e5 R  o
    14
    7 T+ J& G  `/ g. [% |, K15
    - ]* W; H" n2 Z& J1 J16
    ( b5 c3 O3 z/ P2 A. a17
    + O4 r: L, d9 f) Q! a1 D7 ~( ~, d0 h6 A18
    : O% u2 ], b7 ^& {6 n1 Y4 R2 ^& y19
    9 f( T3 b; g* N# U20# V7 @' @: [' B% q9 V1 w
    21+ P' {5 u" b( _& K
    batch_size = 8" `( D6 v: H- }6 x- t: E1 P
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    5 J, ~$ U- A4 m: k& j8 Xdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    ( k; \' u* y6 N# A* p9 e/ kdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    ' J2 j$ X! v$ Q* }# R. L; s6 D& oclass_names = image_datasets['train'].classes' t  c/ |7 G. p

    " T& f, p: d2 @# W! v  B#查看数据集合
    " F: p- N* _8 X+ eimage_datasets* @3 T, v" X; q& e5 s. Q& y% C4 z2 z
    4 s/ H* m5 S* J# a$ e& C9 u
    1+ Z  k' K. Q) A7 S5 n% \  O
    2! g) g2 X+ A" F8 [
    3& ^' H0 [5 I" p
    4, U0 g8 J# R7 h) |1 H1 n$ q( S
    5
    : t% [; m. ?+ |) l' ]) `3 [6: I" ^; y0 m3 ]+ e! O
    7* y! U* m$ |' M% h9 J
    8
    " S5 v# R0 U7 `9 {97 ?0 J2 s& Y. p& I1 B$ f3 }
    {'train': Dataset ImageFolder( ^7 A. m* t5 l) C! j* M
         Number of datapoints: 6552
    1 N, d( t0 C7 a( {     Root location: ./flower_data/train
      H- R' R: K+ w     StandardTransform
    5 E* V' S/ w- Y Transform: Compose(
    : N; g3 V" O- d$ R: m                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)' ]1 e' m4 s3 [# ~( p! F; P
                    CenterCrop(size=(224, 224)). t4 S  X& L8 x; D
                    RandomHorizontalFlip(p=0.5)# V! _. a, d4 R, m& c! I
                    RandomVerticalFlip(p=0.5)
    , u3 n9 H. Q2 s; K: }' h                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    7 N( Z; ?9 z2 ]& |( |                RandomGrayscale(p=0.025)5 U3 N- ~& ^1 j4 Q. {
                    ToTensor()
    1 F2 _- b5 _" C  W" [                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])5 z4 N4 H, @  f% R7 k1 W: c, [
                ),
    # [9 y7 D" R* |: O( E 'valid': Dataset ImageFolder$ d+ ~% E* C% P* X
         Number of datapoints: 818$ s- ~2 M' I. i2 w+ q1 z, z, J
         Root location: ./flower_data/valid7 q, ]2 W+ F8 \" O7 ^
         StandardTransform
    1 j2 C  |+ W% z7 {( ` Transform: Compose(
    1 p2 b% d7 {  U" b. M( s/ \/ w                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    6 P2 x; m) w6 Z# z0 b) @' g                CenterCrop(size=(224, 224))/ e* i! N; q4 I: W& I
                    ToTensor()0 `8 ^- l  |  }
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])% Z# J/ Q( \# x1 \  W
                )}
    : ~9 }) P4 G* e/ N. A7 Y* x! S' T% r" v& B5 Z9 S% B
    1( s. q: R9 K. c( L) K/ I
    2# c. c( k9 f4 d( {$ Q
    3, M. e" T9 h2 C4 \* I; O
    4
    7 Z5 U; J" k2 e5
    8 ~2 @2 g, c) o5 j( L* l$ M. D$ n6+ u" a1 R7 ?$ N! a& E/ Z
    7
    : q6 |! o* x  \$ G5 ~8
    ) t0 u2 F$ k3 o& a& S) f9$ [, L& S* C0 t9 H! w: n+ P
    10
    + K, [) K' o% |" F* V" X( O+ [5 |: Q119 X1 X0 m9 _. P
    12& @+ A) c3 X7 q& `2 G! L' J
    13
    ' q; x/ ]1 k4 D& G14
    5 a5 c8 w3 ~0 s8 j, `% L# ?& J152 K! l/ T, r* L* S; }% b8 u  l
    16
    0 ^5 _  w# c1 {17& m# D+ S. F' x
    18
    . g: w( t3 A, M19. R, n3 W- o7 z) P0 _
    20
    . |: v, o; s) R% K! T& k215 k! y8 F# [; i8 I
    22
    8 X- X2 ~: f) Q, z7 l% C& Y23
    6 d( K/ {0 u& o/ B248 t, q, b9 U' S) R2 J* s  F
    # 验证一下数据是否已经被处理完毕
      v6 m6 G% i. C/ O3 c/ Ndataloaders
    ; E0 E" v/ [/ v1) y; E) m* d9 A
    28 `5 Y5 m0 o8 t/ h4 \) [- p7 j' J' ^
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,( u+ a6 B8 E) V
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    8 |# u! @1 \( O, T1 H! O1
    0 G7 }/ _$ B) K0 [) O2
    2 t2 B$ Q2 i$ ~( q) a: Tdataset_sizes& M& o5 \* N( E7 O/ H, Z/ ~
    15 q, Z5 G1 ^, j2 y! z+ h, J! H' p4 W
    {'train': 6552, 'valid': 818}* E5 i3 ?; ?; I0 k: C. y
    18 \& D! V; y$ `' _1 U* d" P4 Q5 I
    读取标签对应的实际名字2 Q; }0 t& [9 k- ^. N
    使用同一目录下的json文件,反向映射出花对应的名字
    ) ~0 \* y* ^( _3 G: [8 C5 U! i3 {8 u0 O/ i+ }- U' L* ~
    with open('./flower_data/cat_to_name.json', 'r') as f:" e+ v+ M  g, Y/ J' \
        cat_to_name = json.load(f)+ O) }8 p3 |$ G2 _7 V
    1) ~! R+ n" E9 m' U8 G
    27 ?$ w. h1 O& |6 M- g7 O4 i
    cat_to_name& m$ W& b. ~2 `: f9 y
    1# _2 e; u. G) X! q# b
    {'21': 'fire lily',: c7 c- n6 K. g; O6 _
    '3': 'canterbury bells',! a# q/ q1 a4 Z% j- i+ e2 |
    '45': 'bolero deep blue',
    - w" d) i! _' P: s# H '1': 'pink primrose',
    5 U; G" i" S: K. P/ C  u& e7 l '34': 'mexican aster',9 S3 L/ e; K4 c* R8 H( f3 e1 g
    '27': 'prince of wales feathers',7 ~9 I/ a& Y1 o
    '7': 'moon orchid',4 P0 b/ L% O: D/ x9 l! W
    '16': 'globe-flower',. \& F6 f: F% ^" m& [+ }: s: w# Z
    '25': 'grape hyacinth',
    6 x. b5 L7 D5 N& A3 q( {- s '26': 'corn poppy',1 [$ L. U- U! x" i
    '79': 'toad lily',* r) r# r/ k2 k0 Y& h3 ^
    '39': 'siam tulip',
    5 p0 A" _3 x- V  x1 v '24': 'red ginger',+ q! j. D- S3 q% X, D
    '67': 'spring crocus',
    $ V/ k% p, d% g# V6 }8 _$ [% j '35': 'alpine sea holly',
    & n, x" d& q1 b: @. ` '32': 'garden phlox',8 e  e# p, X/ ]
    '10': 'globe thistle',2 H0 ^1 B# l0 o' G
    '6': 'tiger lily',
    5 j4 v' c- ~7 |1 n8 E5 f '93': 'ball moss',
    1 {3 _& p5 x; a7 F4 k- O9 H6 S/ V '33': 'love in the mist',
    + D) a$ R: ^8 j/ s* a '9': 'monkshood',
    ( F$ V- U0 R5 J# z0 w% ~ '102': 'blackberry lily',
    0 b9 R& i. K! h4 M '14': 'spear thistle',
    " X; N3 b6 b. O! p. Q* z '19': 'balloon flower',4 o1 M6 g* \2 Z3 {% m3 r% }' b6 D
    '100': 'blanket flower',% i# @) |% O( `6 T# ^( ~
    '13': 'king protea',7 `  n& `: W( i- J
    '49': 'oxeye daisy',
    6 p) }2 I- a  r5 y8 \ '15': 'yellow iris',8 h# k% Z- k+ S% D8 N2 f; v
    '61': 'cautleya spicata',
    " ]$ p" w+ l! ^& u  B! J- F: Y '31': 'carnation',
    * F4 K- u$ ^4 z* [/ ]0 [, e" y '64': 'silverbush',
    & Z, K" c/ ]% M '68': 'bearded iris',
    - I, k" n2 R0 |. H" H '63': 'black-eyed susan',0 f+ N. v* S7 z5 X+ V9 U* Y
    '69': 'windflower',. S, I; [/ K7 I- a: n( J) J
    '62': 'japanese anemone',# b- c4 y, V7 d* C
    '20': 'giant white arum lily',
    , n$ J  [2 j! E) _ '38': 'great masterwort',, G0 {) p# N  `, A$ C
    '4': 'sweet pea',
    8 d' h3 t$ l0 G& | '86': 'tree mallow',9 k2 y2 b& p7 q1 ?5 ^; e
    '101': 'trumpet creeper',2 N6 B$ t" \( b0 T) u$ K' ?6 q6 x
    '42': 'daffodil',4 E- ~& Z$ w, N  o7 L7 E
    '22': 'pincushion flower',$ I2 T$ d" z* Y# [' A/ N
    '2': 'hard-leaved pocket orchid',
    0 B* Y1 ]' w* }# a* F/ d; k '54': 'sunflower',
    " u! {- s& V$ z6 D. ^9 E '66': 'osteospermum',
    + A* k1 l! L" z4 o '70': 'tree poppy',
      k  q" v' A2 O- }* L2 j '85': 'desert-rose',; s2 c# n2 |1 t) x8 @6 u
    '99': 'bromelia',
    4 g: ?5 K& Z. L. P4 j, a3 x '87': 'magnolia',+ I2 b+ i% x+ }. f5 V+ h" ^2 e7 C5 ~
    '5': 'english marigold',
    ; J9 O( d* a7 F5 u '92': 'bee balm',; _: e$ p& ?) z; d5 M
    '28': 'stemless gentian',2 p# U4 W! ~& ^% Y% ?$ C" a1 z  j/ j
    '97': 'mallow',
    , F6 S2 `( l  w9 M7 Z, [" h7 J '57': 'gaura',4 b8 @; E( T6 h8 S  Q
    '40': 'lenten rose',8 _0 x/ g! U9 r1 u* p1 H
    '47': 'marigold',
    2 k/ @# }3 _! h8 e! o. P '59': 'orange dahlia',' N0 o+ {6 w6 X# g% e# W
    '48': 'buttercup',/ W$ n. `& @+ l. ^
    '55': 'pelargonium',
    - _4 y" ?, x( [) `# S7 c '36': 'ruby-lipped cattleya',0 a7 s+ [8 P9 U- r4 Z+ T
    '91': 'hippeastrum',( n/ O3 u  f1 o2 e" T" ]. ~+ F) d& c
    '29': 'artichoke',
    2 g  h4 T, A+ l) x- V& p5 ^ '71': 'gazania',$ L/ H- h8 ]  b! [$ P) k8 g
    '90': 'canna lily',8 T1 ]3 n7 T6 J$ {! U+ B
    '18': 'peruvian lily',! Q2 j7 z2 E- L0 y/ t
    '98': 'mexican petunia',: I. q- `% g7 w! f, `5 ^/ P
    '8': 'bird of paradise',
    + B1 z2 ]! s. U3 J) _9 s6 l '30': 'sweet william',
    ' g; T! H0 O. {. u; D '17': 'purple coneflower',
    3 p! _' s9 D1 p7 ?( I) b/ J '52': 'wild pansy'," w+ L& H2 w- E) Y$ X2 V, H0 }
    '84': 'columbine',
    2 m$ w. D4 D5 C. r7 r; H/ J9 r '12': "colt's foot",
    0 B: ]3 d6 S5 T: v '11': 'snapdragon',
    + G' O  W; k0 s2 v# |! E& R1 `  d '96': 'camellia',$ M" ]- ~1 X" E4 C& _6 h6 u' j
    '23': 'fritillary',8 q( v( r4 a1 @
    '50': 'common dandelion',
    " p; i! {* F, K% z7 l '44': 'poinsettia',- L7 q+ \) l* m
    '53': 'primula',; o# z8 @% V$ ?8 @7 O
    '72': 'azalea',
    * j2 F, \( h" S0 z. x" q '65': 'californian poppy',
    : L9 i1 L  y6 b! p% Z '80': 'anthurium',) _" @5 ^: R% x: d% ^; o) g9 L
    '76': 'morning glory',7 q: O# j. `& ?8 F! W
    '37': 'cape flower',6 @! X) g$ T! S. T9 Q
    '56': 'bishop of llandaff',
    7 y" J0 }/ N7 T- I8 n4 P" A '60': 'pink-yellow dahlia',
    2 N, t6 q% @5 P6 d '82': 'clematis',
    * q1 h) c; U$ J" g$ g '58': 'geranium',
    / L/ c' ]0 B! O8 e '75': 'thorn apple',
    ( J$ u+ G8 x* n/ H3 q) V" _- O8 _ '41': 'barbeton daisy',
    ) C2 ?9 `4 v) w# [# I '95': 'bougainvillea',2 C+ r/ Y- Z- G# h! c! A8 |% r7 O
    '43': 'sword lily',, W( m( N  P+ c, n
    '83': 'hibiscus',
    ( A+ C" w2 j; W '78': 'lotus lotus',
    ( p) w" {; M/ [& z9 O '88': 'cyclamen',
    1 y8 G" L/ b0 ~2 F0 K9 S% D '94': 'foxglove',; `- w1 u5 Z) Q( v' j
    '81': 'frangipani',
    $ Y: X  A. W+ X8 V '74': 'rose',4 u$ I, {- F2 @( [" c( T
    '89': 'watercress',1 o3 s3 V1 t% s/ l6 r/ Y+ {! s
    '73': 'water lily',3 D, S) p2 b) _9 X1 G, r' G
    '46': 'wallflower',( b2 l2 v; x: M& N
    '77': 'passion flower',
    % \  o9 |+ q" u* C8 @& E '51': 'petunia'}( P7 ~  R& p; S1 c3 {

      R3 e9 g' }6 v& s6 L1! X: o% ?& q4 Q
    2- ]) q3 u' f* b+ C8 A
    3- `$ Z' |/ ^4 ]( B
    4- A9 A) u  r# ]8 A9 v, |8 L
    5
    7 P8 {2 w9 b0 K1 U5 y- r6
      z( i$ h4 s3 M5 L: G& |7+ S& i) R0 ?; ?2 C( C
    8
    ' Z  |" B4 y0 u! d9
    : n) d! o6 R& U( o( P! Y, ?10
    # W; L/ U- h2 f; ?" M9 |11
    ; J" v0 S2 q) }12
    ( r: _, c5 [9 Y$ w4 ?8 V$ W+ h137 l" _" c, o/ U9 O, D
    140 J4 G3 O  ]9 d* j  B, \1 L$ O0 W
    15( [$ b. z7 I/ M+ h& C
    16* `7 ?+ V; K5 K( I* S- O
    173 M# B- E* z8 f/ j) D' @5 a
    181 L1 N: w/ P' n
    19* |. f) h, x: q- L  L
    20
    * O! h7 d3 o- n" s, {21
    % N& n+ u& T  s  u4 `1 e  P221 Q' B5 j1 ^, d& |9 c6 K: G
    23& A, C) T9 {1 K9 U+ Y! v" {
    248 V% S9 y1 W5 [" p/ K$ w, C1 H- c
    25
    7 i+ k" J4 ~* V! C6 R26
    . I( ]0 M- t8 x8 h275 u- v$ |" X  {- p6 V
    28" |0 t0 O  L5 M9 Z: w
    29
    / e' O8 g: z$ U! |# C4 S30
      p- D. X( h: G' v. [3 J, D! J31
    # r) i0 f9 a3 @2 B& N6 d( B32
    ; ^$ H0 d! N1 r) H5 L, C33/ |+ w* `0 v" e4 ]! V% S
    34
    ; a/ `3 X! E* G7 J5 @" o354 \1 K2 ]9 K( J! }
    36/ J* c' g$ {/ n$ Y
    372 r1 h6 s# U8 J4 |. V  @; V
    38: z( ~' s/ M2 T( `$ R8 `
    39
    4 B5 S0 ^  ]/ I2 y; Q% F40
    9 [0 M! M2 r0 P% T9 I! U/ R41
    4 g2 ~  L* M! X3 f; U427 @7 Z2 \7 C( ~+ a+ y: ]6 S. h/ g
    43
    ( x# m1 \- V* p: X44
    8 p# p  |4 S/ {6 d/ K45" E/ O3 E% |& D! t& j! `2 X
    46- P$ D$ i& U+ ~; ^4 M# j
    47
    5 P5 @+ y7 \4 W/ R4 @48% }$ x+ ]0 \2 X" I9 @, A6 e
    49
    . K5 @* ]/ v( z# w50$ K; ~# u  x! l
    51
    " r3 @  G7 B) n- m6 r0 i- {4 a3 b52
    ) R. B) W; e) J: [53
    6 M5 r3 v6 h5 r. x. X" b- ~  f. c54
    , H. s. A: _! j; e55
    / i: K. J2 \/ o56, H; n3 ~/ W$ j1 [) x
    576 F+ g) i4 }+ b, m. l; E3 S7 s
    58
    4 }. F2 k( K4 H- @) z# {6 x* ]8 X59
    ( {6 b$ C/ q5 N$ p- i60, D) m$ @" N% x4 o. Z  A7 q
    61* u& x/ u# o  `6 l  `/ e
    62
    ! r! {9 ?& U0 B9 q3 K" C9 i634 }- d& ~6 `2 O! f" H% h6 C* a& u
    64
      l$ X: X# ^. z65. l" ^  l. S: A% M" g7 }  t
    663 J6 ^$ B& {1 h9 y: x
    67
    ( C1 {( C- k7 A: W6 q682 _/ Q+ b4 C4 N2 Y
    697 I6 d2 X$ k; T$ F! O+ p0 U, U+ T6 o
    70
      \* ]7 {0 |1 X" U: P, b71
    9 @( h8 K, Z6 W+ ?/ z$ p72! o3 y5 \# ]3 W& }+ e% Z
    73
    ' B: s8 ?, g3 z2 ?74
    , v7 u' B% i* S4 q/ G8 a75
    2 [/ p% J9 {6 r8 u1 v76( }0 B" a& [0 a- J2 S" E+ O8 Y
    774 a. j8 P+ M# F( G- K' Y  X( U/ D4 F# ?
    78
    . P; W) X1 s/ ]; }79
    3 P+ d# S4 j" G80
    # N; z9 a$ t; M- F; ^; P$ u81
    9 h8 F# V' L+ J  j820 w) c# y* d* k9 j9 m. ~
    83
    1 z& l7 t8 q) Q1 }- D# }! I84" ?5 ]' R3 u8 w. e, D
    85
    - X0 X: \$ i  F* M8 ~, D" l86
    : f5 `( f5 ?) P87
      s4 N! {# B9 U- u; w9 n7 \88
    . B# H6 ^3 e8 V3 s( `5 k! R# V89/ t- M4 o- j. y
    90
    ; O  ]* N, ]; i  J91
    - S! c# L  p: R  j! v9 a2 W928 d7 N* K  \( H$ n. E
    93
    . R) M2 P4 I' L2 ~+ X94
    3 V& F. T1 k. a- F/ j$ N& z95
    % d; ~# e+ x5 {, {4 r2 O, {& W96
    6 k4 L8 _( G1 s% Q. |979 \, z: s8 H5 P; B
    98- V' T) _4 B# b+ L3 ?
    99
      r4 c* L5 f: K' N; ~, e4 A100
    " {5 X! m& |: [! J: N101
    : u: q- r* H6 F) M& ^1028 s$ I: G* e( L4 c! p
    4.展示一下数据9 x0 j5 I3 h, u8 k
    def im_convert(tensor):' }0 Z# s4 Z( V  }$ P* i' c* j
        """数据展示"""2 b0 r4 P9 O4 R/ U8 n" m- Q
        image = tensor.to("cpu").clone().detach(): @. T6 r! Z, j% t$ M
        image = image.numpy().squeeze()
    ; Y4 y8 a* Z: E! X, L    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图' t& G3 n9 h: Q+ U' N4 a  @
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    8 b4 Y7 \0 P7 R    image = image.transpose(1, 2, 0)* q) p# S" Z# M4 u6 H9 H& a. T
        # 反正则化(反标准化)
    # V3 n' _' @% w. c, V    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))% {" T# V4 `8 i( D' y$ i
    6 Z, {  U9 a6 I) g7 G  u* K
        # 将图像中小于0 的都换成0,大于的都变成10 A' {, C0 L5 v+ Z' l/ Q4 X  |
        image = image.clip(0, 1): M: Y# _6 \3 I1 C/ b
    1 ]# B. n' L/ C' _6 g) U# _& b4 t: N
        return image
    $ D; W5 }6 l7 m1, d6 F$ `! q3 X" e
    2$ _  c$ ?% ^7 H7 b
    3) [6 \3 g9 ?3 C( v
    48 e7 e$ ]; Z* T0 g! v/ e1 ^
    5. B$ k2 A; [7 c% O4 E  t: B( z. Q
    6; F4 c9 e; N& ]1 m0 V
    79 s* c0 R& h# B8 v; |, a: q$ E9 F" _
    8! S) G6 x/ t! e. P0 z) Z; B& b, Q' [
    90 s8 j  P$ ^; E2 W
    10
    ( F/ u& H# t7 d- F9 h11
    " H+ }0 o# B' q" ]  `3 S8 e124 I3 U# R2 R6 r0 K6 A3 s" `+ D
    13; X# O" k% l  i6 B0 m: f. R
    14
    6 s$ I) d3 ], `. v3 y1 r! l' Q* u# 使用上面定义好的类进行画图0 I% L& y! Y. L' A# n
    fig = plt.figure(figsize = (20, 12))1 N+ ~6 N. o* P9 {
    columns = 46 M- z" s5 ]- K' L
    rows = 2
    : U9 r! c& w3 l5 X/ P$ V" ?. M+ l% y, w. j* q4 e
    # iter迭代器
    ! |5 G  k, @" Y" l7 S* ~# 随便找一个Batch数据进行展示
    $ s8 v2 {7 }5 [9 pdataiter = iter(dataloaders['valid'])
    6 j. P6 B, R, @: K3 c6 }" F5 \% pinputs, classes = dataiter.next()
    # _0 }1 o7 S4 E. i4 ~: b* k! ~: I, U  t5 K1 z
    for idx in range(columns * rows):
    4 |7 r5 I7 [- C7 K/ c% X4 ^, ]    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    0 h- k0 F4 {, Q& I. M9 N! ^# p    # 利用json文件将其对应花的类型打印在图片中! F* _$ |( l# u8 O3 @( n
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    6 M7 j# e& e7 m; J" ^5 [# w    plt.imshow(im_convert(inputs[idx]))
    % d7 g" N  Q6 b! z$ z9 Z% I* Yplt.show()
      }( j  ~: v. ?5 F8 W, |0 H9 w) F$ h9 G
    14 i$ Z, d3 o# k
    29 j+ X5 B7 c+ f7 C) L
    3
    & J$ a% |( l' p" n- {4, |  H6 Z7 V! y1 t8 Y" I7 U
    5- F3 `/ y: _5 c
    6
    4 W: |* ^$ l: ]% h! Z4 q74 U% p* u+ p, Q( a
    8' r0 m2 W  c- S
    9- A  S6 u  h% ?' u
    10
    . {  u3 t# R* C' K, Z; T4 T11
    : s& G. Q5 D) P! M+ K/ E12, {6 \, @7 `, m5 w! J1 w
    13
    6 n' ^8 L: {7 M1 t* I14
    ! f& o6 }9 X7 Y- E; e3 T15
    ! W" ~! V" q1 }9 H7 ~8 y; u162 R  r% m2 G# z

    ( e- O9 z! q- E0 A5 L' I
    ; V; K1 [* H% t3 h1 D0 s5. 加载models提供的模型,并直接用训练好的权重做初始化参数
      i, t$ Z$ w- p9 Fmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']  i5 X) j! e, c# A$ |4 n
    # 主要的图像识别用resnet来做
    3 P% T. y. e' f! n% B" L/ R! F# 是否用人家训练好的特征/ L. @5 Q% D1 w3 d# F5 \7 A% \
    feature_extract = True: C2 Q- d; \# j3 {/ G' A1 |' S
    1! m9 a9 b) H3 e& E0 A4 n: E; l- c
    2
    " _) a0 H4 w+ @, @3  k' a; _; _5 I' v, |4 Z) B: @0 u5 @
    43 w; j: d4 A- p2 A0 X
    # 是否用GPU进行训练
    . h1 ]+ t7 D0 i& Ntrain_on_gpu = torch.cuda.is_available()! I3 b! y& h; m

    # Q3 w" h: r9 Z) sif not train_on_gpu:( Y# ]& z9 a( t) t$ x/ P
        print('CUDA is not available.   Training on CPU ...')
    1 E" a2 @% N. A- v3 O" U- w- {% Helse:: g! ~; N2 W) j7 N/ w9 n1 [
        print('CUDA is available! Training on GPU ...')
    7 \# T% h" P' y
    6 v+ h2 Y$ ]1 J$ }/ `device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    4 g& n9 l3 B; J2 @. v. |1" D- x5 b) T5 w+ R
    2
    # q( `2 R! T5 s( o" x3& N2 ^  T0 |0 {" h" i
    4& ]+ |4 r1 p1 z2 C0 T
    53 `0 D- |6 Q5 D
    6
    3 O- ?. O. S0 H# _& n74 E' m6 n6 C/ b- v: ?
    85 ?1 U1 h$ `, |+ U$ d1 j
    9( Q4 ?' o& K) @) S' p# Z0 a
    CUDA is not available.   Training on CPU ...% S# O- g% B+ `  J( D0 w' z
    1
    3 @4 |, H. ?1 E0 Q, X1 H$ \# 将一些层定义为false,使其不自动更新
    . S$ |) Q! y. L" x, g; k, s0 E$ q, y: Xdef set_parameter_requires_grad(model, feature_extracting):1 `  h# r9 c% N" f
        if feature_extracting:
    ! R  p5 d8 w6 t& L8 g        for param in model.parameters():, ], y# E7 d6 r  r1 C5 W
                param.requires_grad = False
    4 l4 A+ d5 e& w) d- ]# ?9 w& {1
    3 ?) n( F3 k% S2
    # E/ B: q# j: p: k3% ^2 X  B5 d3 v+ V
    49 z- r# ]' c8 B7 K1 V& E" G
    5
    ' D: u; q3 a# s" |5 a- W, N  p% D# ~# 打印模型架构告知是怎么一步一步去完成的
      S( }- V0 z+ B" J; F# 主要是为我们提取特征的
    , _, x5 S* K" ^9 U! g
    . O- q$ @0 T* f( P5 Fmodel_ft = models.resnet152()
    8 ~7 `) u! z. O8 S, v, Mmodel_ft1 w- L0 W: |* Y5 R; A# u6 N* k
    1
    4 _( _+ n* q0 G$ G( W& b/ b2
    4 N: k* x8 Q, u) O. K3
    1 b# H; J# ]. y" D/ [4$ l. W( t( x0 B: B1 Z; E. V
    5% j$ F2 N" L- z
    ResNet(
    2 X! y) ^- k4 l0 s& I  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)% P" e" q$ Z0 \# P
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)6 D* {" b) J  G' D! X
      (relu): ReLU(inplace=True)/ B' g- R) i  u
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)5 l3 u+ g2 \+ r" V
      (layer1): Sequential(
    * M0 g3 b, a# f  p# `    (0): Bottleneck(7 \8 @) f& Y& `1 \
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    # ~- L7 H& U" L" @* L      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    $ G3 H2 E. {) H3 o0 A  n      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    : T: Z4 J( a4 T/ c      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)" Q+ E" @+ u: L1 e" K  _# D9 K
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    " C: r7 t( J( H  D$ r      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)0 i8 J. i' ]1 x$ q( I
          (relu): ReLU(inplace=True)& ?* s+ g" E3 `. W
          (downsample): Sequential(
    ( Y! U2 g6 }, I+ i: C' G, y& H        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      A- S5 N4 b/ ~9 e% M        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)/ h$ u3 @6 y* I, N8 Z
          )
    , n2 h  @7 E2 g+ I  A6 m& L    )- }& E* r% k% u& l) V. x5 X7 O) U
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。; t7 v+ E1 m; O& }  L4 }; u- m
        (2): Bottleneck(
    / N4 i+ h  ]: i6 j( K1 Z) W      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)0 O5 {8 W9 M* C. x
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    " P" u2 F# P" r; ?5 M6 A      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)8 t5 i* j7 O3 a2 X" j# ]/ p
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    * n; Q% c; e( ?1 [. H      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False); F! w& e, u9 J" q
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    0 u, g  U" f5 D! ]9 x      (relu): ReLU(inplace=True)2 c% K  P; H+ b& h
        )
    ) [) g  T$ u1 a, `  )
    ) }9 ^8 o$ G5 L" z: n0 u( G% N  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))& }$ r' l" f4 x- G! u7 b
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    & H  O5 c4 |: M& _2 D8 R)) }: @4 h! @9 J% L. q

    6 ]- T+ ^9 q# P/ ^1! e+ Y, f2 P0 z0 ~$ @1 ]5 k
    2
    ! a* d  q  P. A2 A& z( W8 {& r: ~3! v- j  _, @. ^% s% R+ A
    4) M& n1 A; v0 S( e
    5; [  @# R4 L& V# Q1 e
    62 [& v( u, S" |3 A; r
    7
    0 ?4 P% {0 ]1 F' g- {7 ?8
    " f1 y4 }- _+ Z  g7 S9
    8 P5 ]1 i/ c% Z' F& s; `10+ ^( d8 O4 x; u" i1 a4 c. V: o
    11
    . b+ E" E6 q) H" O& D* v2 N; v5 D12) n. B/ _/ k+ ^7 ^4 K$ {& @
    13  Q; [+ k( o! @
    14! \# L- ^& h+ I
    15+ S, q) T6 R0 b2 m: Q
    16, \  [$ M" v+ F6 O- ~" g1 l& f
    17
    $ l! G, a4 w( K9 ?2 d; _18
    + [. O  Q4 B  y0 t190 \" F' e3 A" Z) G3 u
    20) c3 ?# U: Z! F
    21
    : j5 x, }% U) u( [: h; ~' M1 M# S22
    0 b$ h) j+ n1 ^/ j9 |23- T* V, o! J5 t: I( ?
    24
    + N) f$ t: c' [2 o/ R  s, }+ I2 y1 J+ t25
    & P2 H7 g6 a8 O! @% U, o- T% t/ y6 z26
    ; E5 _3 K0 u( }- {7 k; n2 z; z27; |! M+ {( }$ T; ?' Z( u2 Q
    283 P$ t& z4 v& c( b
    29
    4 }( i$ G( m. g+ K30
    0 V6 U9 R& z* A' s* v7 K, j( s+ ^31& E8 [/ K9 p0 [" Z
    32) U3 h& c& |; M+ N
    33  d3 u. b' B& S+ }* L8 U0 U+ U: n
    最后是1000分类,2048输入,分为1000个分类
    ( X8 h; p; c) b6 E8 Y而我们需要将我们的任务进行调整,将1000分类改为102输出1 c/ E( \/ S0 I

      ^- f' r$ }% X7 i" M. I9 g0 W6.初始化模型架构$ s/ J, p# @( B5 V" i
    步骤如下:
    ! W8 G- X; v0 b) V& H+ R+ T) P+ Q& Q
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
    ( u/ K0 N& K6 U# F" K9 ?可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False), _( n3 U0 z3 A( s" S
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数1 H. N1 P4 |9 z) @; a
    官方文档链接" s2 p( Z, k* c5 t3 W5 ^9 Q5 Q
    https://pytorch.org/vision/stable/models.html
    ) O0 a' p9 t( a! @- }6 p$ S
    , G* E9 x5 d1 F4 G# 将他人的模型加载进来$ B; z  E+ \1 S! i2 y
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    8 A3 d% D4 n1 }4 K9 ]$ I; `    # 选择适合的模型,不同的模型初始化参数不同) r3 W7 R& L  p  p
        model_ft = None
    ' x2 m: n$ p! N    input_size = 0" F" Q) ~$ k5 `- |
    . W- p4 \+ s4 {" s6 t
        if model_name == "resnet":
    9 e0 `1 V4 |9 J/ h9 ?( L0 m        """
    ) X; L  _4 t9 Q: m        Resnet1521 r. |# d' G, M1 r7 S. F
            """0 @9 B: c2 x  Y* R8 R4 K

    $ u8 a* Y. b+ {0 p. g& x) Y        # 1. 加载与训练网络
    4 {$ q! O2 D  s5 l  ?        model_ft = models.resnet152(pretrained = use_pretrained)# c' K1 A6 g& M/ I
            # 2. 是否将提取特征的模块冻住,只训练FC层
    # X. D7 e( n. Y/ {9 D. d7 v        set_parameter_requires_grad(model_ft, feature_extract)6 G& o* r1 c; a/ J" j3 @
            # 3. 获得全连接层输入特征9 N" e4 f5 F: ?' m4 ^. o8 n+ ?
            num_frts = model_ft.fc.in_features
    0 d: W7 V* x2 S( D        # 4. 重新加载全连接层,设置输出102
    1 K, s* f. D) P1 k% Y! e        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),; J" _0 R  Y8 r9 M; H- k& ~
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为18 u6 [* T! Z- I# L# B5 P% Y7 A
            input_size = 224( r( b, {) @1 W
    6 P0 Z, L& I, ?( H: {) r, \
        elif model_name == "alexnet":& Z: L0 {# K' r8 _4 v
            """
    2 `( C! |. x% w' v! Y        Alexnet
    & a! J* k, B; ~; f/ b3 f' \8 x        """7 }& D. s2 x5 }6 W4 Y9 q& F! l
            model_ft = models.alexnet(pretrained = use_pretrained)
    $ w) m' V# U3 K# [+ z3 x        set_parameter_requires_grad(model_ft, feature_extract)
    / |$ g& ~( G/ v* u4 G& m% X( ]: u+ e- L0 Q) o
            # 将最后一个特征输出替换 序号为【6】的分类器/ F! o& L8 H! ?" M5 B! y7 L( G+ f
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入: w% D: T- R# R# n$ V- w8 u
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)8 p- }% n6 Z* D: v
            input_size = 224$ v. \* H: q* q: ]( y. K( v4 u
    5 C# |8 y/ o5 _$ B! |( M/ X
        elif model_name == "vgg":$ V5 K7 o6 n, g, M6 j* j7 H: L
            """
    - b2 q: M$ i* s9 {1 `* ~        VGG11_bn4 v5 [2 R# [$ w
            """: L  \) A. p( J1 B$ T5 z) D* d
            model_ft = models.vgg16(pretrained = use_pretrained)
    5 ~' o4 ]* a- x( k3 D        set_parameter_requires_grad(model_ft, feature_extract)5 u; u& K6 J) J
            num_frts = model_ft.classifier[6].in_features( V& Z7 z. i" r
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)6 K+ u6 i3 C; L: Z' k
            input_size = 224
    ! _0 L" i# w" i# O- G* C& L) C4 v5 t* t6 |9 }$ I  o
        elif model_name == "squeezenet":
    ! m; q$ s0 ^2 K0 k2 M& M        """& Z: D/ E1 K7 U: P1 z3 a# o
            Squeezenet7 k. B  D3 J/ D3 v5 |( ?% o( C
            """
    / j( H. R) j1 ~1 O        model_ft = models.squeezenet1_0(pretrained = use_pretrained)( x7 E2 E4 ~) ]1 o! Q& `2 \# w
            set_parameter_requires_grad(model_ft, feature_extract)0 i  h' T" l* [* t% _$ y! \1 l" G
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))7 o( t3 }+ k1 q1 f! L
            model_ft.num_classes = num_classes/ h* L6 i& i- u4 x+ q- i5 A; d! p* b  S
            input_size = 224
    9 D$ Z5 ]- U9 ~0 J5 D; D* n
    ' J2 Y3 C% h# G7 ~! N    elif model_name == "densenet":7 n/ |1 b  r% }4 h# K& V
            """
    ! x; E$ K* G* a, P        Densenet
    1 X. E9 W2 M9 K+ Y, \3 Y        """
    ' ~2 T( I; l5 ?' P% Y        model_ft = models.desenet121(pretrained = use_pretrained)1 H/ m# H# G2 Z0 _; [
            set_parameter_requires_grad(model_ft, feature_extract)0 |# O' y" W3 e# a& v! ~- A7 x
            num_frts = model_ft.classifier.in_features% c+ z3 u, Z& W% q6 o- W) H, q
            model_ft.classifier = nn.Linear(num_frts, num_classes)9 g- h9 H, g: o0 U
            input_size = 224
    , ~5 _# l, X) f  r7 G) S5 |- i' X& x) l! S& k
        elif model_name == "inception":
    ' o: x* ^* H+ X$ z2 p) d, z( P* U        """  @4 @$ M+ J" F+ _$ L
            Inception V3
    * H. e5 v) p. I" n        """
    ) X: R6 E! [2 u" k3 _; p. D        model_ft = models.inception_V(pretrained = use_pretrained)
    - k$ T( y6 w( K, F        set_parameter_requires_grad(model_ft, feature_extract)
    6 o) ]- }. V  D$ q! }
    * L9 ~# H9 w0 H' H/ k1 e! P2 t* K        num_frts = model_ft.AuxLogits.fc.in_features
    2 h/ W. ~3 }  V, G, Z9 h2 A3 o        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)/ x; q5 K+ s* t  k0 J+ s7 Z
    / X8 i3 B2 T/ }7 c, ^* {: Y# o
            num_frts = model_ft.fc.in_features7 ^, _# x3 {1 I/ _5 P
            model_ft.fc = nn.Linear(num_frts, num_classes)$ |  D5 ^, }- k% n% ?
            input_size = 299
    # n" k! p0 J4 C" c  X% o4 y# _- Q
    8 T8 j7 o" l! d- J    else:
    & f+ _5 x8 x5 Q4 u- n        print("Invalid model name, exiting...")
    8 e$ n9 e, T) R# k0 @: K        exit()+ G6 R- x+ T5 f% d
    ! u; `! X+ w) T5 c- F
        return model_ft, input_size- I& W% w0 U2 x; c0 y% ]1 G' p

    - V0 _$ Z* ^- f# H' K8 [1
    6 l) l2 Y# s. d. ^. a3 }2 W2* P, U" z; d1 c3 y0 l  ^- E7 ?0 }
    3
    & t$ x" ?1 X, ~  ~! o$ @4, L, q. f0 _* x7 |; k
    5
    ( ?! E7 H% |9 [" t& ?6
    ; P3 E0 B: l0 ]  Q% M7
    ' c6 `8 x( z% f0 M/ C9 B' @" o. Q8
    : W4 w) E( d2 j9* S& R  l* U* f% A7 ?
    101 p5 Z5 ?; I; `% G  s& a
    11
    & k- i0 Q7 [$ L2 F/ a12( z& C; k2 M3 z& ^) Y
    134 s$ `8 ^) V" z3 X* i
    14
    1 x5 d) k8 V( e15
    1 o& p6 d7 z2 Z5 \165 ?4 U7 R! u* Z& t
    17, b% J) U. U/ B: J1 P" Q
    18- }9 o5 S2 a4 n  b9 C
    19. [; R3 Q+ Y/ F( C3 |2 W3 d7 J* z
    20/ f2 C8 v7 i, {, u- f
    21
    ) U: @& R7 a: j% [  L22" ~& H: |* ~3 I: Q% ]/ }) q4 G
    23  r( r# }8 T8 @, V, p
    241 [& H& G* B' i
    255 ]- q! w, i3 H1 O# }
    26
    ; e; b( p; |. @$ E273 G. @0 j  E8 o% \$ `6 K
    28
    : S) e8 W2 w0 d! m+ G+ o29
    1 }1 _* q# j, P  I. \$ |- S30& m/ [  Z; o9 e
    31
    ) I2 I( e. k  x: W0 a32" i% c1 e( I# p( ]4 {- }3 Z0 j& d( q
    33
    7 S6 ?) A* O, V9 M5 |# L346 U7 S: m3 _. t5 _& i: Q
    35. z& z0 T, k# W) G$ [, C5 S3 \; ~
    36' k0 y  D% a9 S& @$ I
    37
    # C$ x+ _6 y" C! I38
    / c/ w$ c  g3 r* h9 }" v( P* D; Y39
    " E3 I8 \4 M+ ?' l40
    : `  E6 Z' t% s: Z! f  D7 \41
    . u. I- [$ s' V) u* ^, X$ ?42
    2 {" W/ T% Z5 i& ]. C43
    & b" c( |: G5 t; I5 ?, l& J6 u$ q443 v8 Y% o! I: Z) @# V* K
    45
    + h0 W& x; b5 b1 |0 {& X( H5 @) k46+ y6 j5 F4 [* H, t9 W
    47
    : s6 Q9 S, j& o3 ]+ O* E& A0 @5 s48# E5 r7 S  u9 g2 o. @$ |% Z
    49- v7 z; _4 F( l' f
    50, P# M6 {" z, Q. S/ Q' |5 }
    51
    + V' _' J$ L4 z* e3 T( j4 Z/ S" J52
    7 a" S2 ~# V& T3 `! Q  A53/ c( w! z4 N+ w- e: U8 a2 N) Q$ V
    54
    . K1 b" j: G% C55+ l6 H1 ^& Q+ p
    56
    " n+ e9 g+ O2 ~; }, J2 c$ C# ~57
    # n# w) }6 a% [58% k( Z9 x9 I% m7 i1 F
    59( E' D; g& l* b7 d" |
    60$ s! Z1 n, ]) c) V2 Z! l! t
    61; W3 `5 E7 l% t1 q' `$ Y- y) S
    62
    5 Y9 v* ~& I1 Z& ?* R63
    5 e* A: d* }8 i* P64
    & O- a) a% X5 W. Q+ h3 c  x65& E( h3 y' E: r  t  [! ^" ^/ y; }
    66  P8 ^. k3 w( N3 N
    67
    - x4 I$ z/ \6 H! ?' z- {7 t68
    ' }# }" N0 `( N4 k: z69- X& _7 i( F3 y+ Z" r6 u8 \
    70. U: f$ M- v. ]/ G8 z
    713 _( {- v% L/ y) J$ Q
    72) {' a! X: R/ I' V
    73  c( d, e2 w( k. R& l5 `# V& s
    74; @! T  b* F9 h! A# d% L9 c
    75% @; C3 I1 T6 o
    766 N9 G" u0 `! F
    77: c$ b: L( G% A* d- w+ `# S
    78
    5 L6 v2 L$ Q0 b$ L* }# S% Y( q& V79
    # e. B& P: [$ y8 N) f" _80
    ) D/ y& e: m! S4 x81
    8 O0 z. O* R% C; C5 g9 k  m82
    2 h# _# ?; v2 G$ x6 O3 b83
    " s& i& w. i* Y5 l7. 设置需要训练的参数
    4 w; J5 }7 h; e# O% \# 设置模型名字、输出分类数* ^- P4 O" h2 r7 Y
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    ! T5 d  v: y3 J
    4 L# m6 O7 d$ x7 ?) c# GPU 计算
    4 ?6 ^3 G9 \0 w1 ~% X& P# ?* O" qmodel_ft = model_ft.to(device)
    , X7 m) x( D8 f5 B; n+ R, k0 {/ h" T# ]5 Y' ~" k  K
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取! V$ A- ], E3 J8 S6 V) g  L* H; X: K
    filename = 'checkpoint.pth'+ J  g+ l* r7 m, y9 p, N
    0 |9 O. ^8 N" Z. h" G+ s; n8 c
    # 是否训练所有层! Z0 o, K8 \, D
    params_to_update = model_ft.parameters()
    3 T' X; ]" L; R5 k' M! I; D" ?1 \3 U# 打印出需要训练的层
    5 Y9 x# E0 v- S8 F5 {4 Iprint("Params to learn:")1 b8 n  I2 B2 w' V
    if feature_extract:: [# K/ u- r, t" l2 u4 [
        params_to_update = []
    3 T. O6 A2 S! `1 ^! j$ T$ }- t    for name, param in model_ft.named_parameters():0 U2 n8 q& y( ^7 H5 p- j  h
            if param.requires_grad == True:  ^; N6 Q% S( c8 p  W7 Y
                params_to_update.append(param)0 b2 {" K  B* g. v- y
                print("\t", name)4 n/ A0 k+ ^* h* E6 ?; R
    else:9 }' S" L5 n9 X0 C' ]8 I  n; R
        for name, param in model_ft.named_parameters():1 Y2 d; S9 r# @# h
            if param.requires_grad ==True:
    8 m2 V: `9 u7 p* L7 e2 q8 X) x            print("\t", name)4 W" j, ^! H9 E. \: n& q

    / O! W% B# M8 S19 N* [3 D+ h" H, v0 q, E
    2
    8 \6 l) m# @! b' N/ X. l9 W32 t* R. t# `4 g  ]$ P
    4
    ' W2 A; Z+ \. S- h& U0 M9 t$ w5" \2 J5 @; C% p% Z! X, K' j
    6
      A8 ]1 D( `7 ^+ A9 [+ ^- P7
    4 @4 R/ X8 q5 x: ?( o! ?! j8
    0 W5 U: i* l2 _6 C2 }% R9
    : A1 V& j4 Q% i5 C$ v0 u  f10
    8 T. |+ {* }* S1 q11. |, {! ^  o1 z! P
    12
    , t; s  s0 ?, L( I" X- J/ O13" s+ {3 P7 n/ R0 ?
    142 X* B/ V/ E$ y1 G
    151 n! ?9 d* F" O$ `5 c' N1 n
    16
    8 X! m) H% O# t2 H2 E, {. q% w$ q176 g0 S" s2 `) }
    18
    " t) g3 L- i3 G0 f9 A3 V. ?19
    , V  y9 P; k1 O: X/ G20# {$ b  A# _6 |6 W
    21
    6 m4 K$ j) Z- E8 X2 m22
    ; ~6 _$ F  x* C+ y9 `3 d1 @" ]234 h, n; b6 s2 M2 r' {
    Params to learn:
    - P8 ~9 ?- O* o         fc.0.weight5 v& l1 b1 K2 R. Y2 u4 J
             fc.0.bias
    # g) r& k5 l! @$ o: v0 u& {1: f( h% H0 l! M! e
    2
    5 L% ?: M# ^# A$ M: Q1 C32 C+ y* x6 ]# e8 j/ [; T
    7. 训练与预测5 y6 R8 d4 Q4 R
    7.1 优化器设置7 m  o2 s8 i* d/ \
    # 优化器设置
    9 h7 S- O/ k0 E8 k# u% M7 Soptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    7 G( H" L' x2 U0 r8 |0 M! X2 v; I# 学习率衰减策略# ^8 d& ?* b4 O; F- \% \- `) H
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    9 N. ]+ r" F, k& Z7 f# 学习率每7个epoch衰减为原来的1/106 e6 Z8 |5 x/ G, J0 d
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    # S- I7 N, T; J& p7 k1 v. a9 m- D
    % j8 F2 }, [! x( z: T3 icriterion = nn.NLLLoss()8 E3 V' S$ e6 O2 H
    1
    ) a* ~4 I7 g$ D3 _8 Q2
    ' E. s1 g" h9 R4 y" \3
    1 e) }: E: I4 K7 b  e" y4
    " Y" [: l( G' O+ j* k" \8 ~5
    ' F$ o, _; Q& p% a' L1 S- Q+ M6
    9 }# m( P. B* N" m. l3 [77 q+ ~6 K0 H- C' }
    8& i, r  G! P9 c0 k) Z, }
    # 定义训练函数
    7 U" j0 A5 B+ F7 Q! X#is_inception:要不要用其他的网络  L0 K6 W$ e! z, w2 b
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):" J, ^8 Z+ {( E# I2 M
        since = time.time()
    2 ]1 ^- }) K) p9 T1 p    #保存最好的准确率2 c+ v3 ^+ O% j" I6 ]: \
        best_acc = 0
    / F1 W# [+ V/ B% c1 J    """
    # {/ q7 a: K- H8 w4 W$ `. O; A! z    checkpoint = torch.load(filename)- ?" s3 W  k, @3 Q8 [" Z7 B& Y: _
        best_acc = checkpoint['best_acc']
      v! p1 I% A5 ^: c! ~' I8 K- L    model.load_state_dict(checkpoint['state_dict'])% T# M3 j9 O" o! q
        optimizer.load_state_dict(checkpoint['optimizer'])
    8 b  G% o3 d, U0 _! V0 W. m+ y6 \4 b1 e    model.class_to_idx = checkpoint['mapping']$ E& l- o# t# N! z3 u" A& v
        """. q: r$ S. [5 t0 P) M
        #指定用GPU还是CPU
    ; ?- U( S$ {5 }) d8 H* x; {0 p    model.to(device)
    9 O2 n2 B8 L2 k9 h. B    #下面是为展示做的2 B3 m2 m6 N" L
        val_acc_history = []
    . d8 K% s  U5 d" ^% I$ A" A# a    train_acc_history = []
    ( D2 T" O' o% |, s4 _* K0 E    train_losses = []
    0 q" x: \# A! p+ _- \' d    valid_losses = []4 F# y4 u  C9 {& b
        LRs = [optimizer.param_groups[0]['lr']]
    ) \( P3 w) W$ o& W1 N+ Y7 C    #最好的一次存下来
    4 d+ Y8 \! H- g" B% r% X- G- k    best_model_wts = copy.deepcopy(model.state_dict())) f$ [& x, u2 V9 [

      {8 j$ k6 I9 N/ t  `- d6 q    for epoch in range(num_epochs):
    % }& s2 G0 `3 {/ e. p& M, P        print('Epoch {}/{}'.format(epoch, num_epochs - 1))/ O0 o7 o  `; `& P8 u; ]$ F) o
            print('-' * 10)! A4 U# M. |7 ^: n
    - Z) `* ~% H! E  H# ?  G
            # 训练和验证
    2 Q! f0 L% _$ {# w; }4 S# G: f        for phase in ['train', 'valid']:5 X+ }3 h$ x9 {" [/ G; ~  b
                if phase == 'train':3 f+ y: \* @( I8 B" {
                    model.train()  # 训练
    + r9 s" K3 \) O            else:
    0 x* r: i0 @- n+ _) H) S                model.eval()   # 验证
    # r: |& [$ `+ A9 r, G9 e% u$ P" l+ _. G, X  g/ M* R
                running_loss = 0.0; j: ]; p" m9 s- k0 L* y
                running_corrects = 0
    8 \2 h0 x: {! E' L- {# b% p' B* P% ^6 m
                # 把数据都取个遍
    & N/ H7 U( E% i            for inputs, labels in dataloaders[phase]:
    / B% N" q3 D2 L8 i6 N% H                #下面是将inputs,labels传到GPU0 X/ W- M- B. X# [( M8 A: b( p
                    inputs = inputs.to(device)
    " T$ p) G# q% `- ?! I7 D                labels = labels.to(device)
    7 s+ }# `9 E- B7 h" h) R8 n9 e. ^/ D! S& W7 t$ w, c
                    # 清零" ^' W2 K# C0 J* @+ g5 k
                    optimizer.zero_grad()
    3 o' F! x  J. e8 p) Y9 r+ _) b9 ~                # 只有训练的时候计算和更新梯度( E- \: B) z" n) |. `( F
                    with torch.set_grad_enabled(phase == 'train'):; ?5 u+ T8 {7 J3 g8 J
                        #if这面不需要计算,可忽略
    ' P1 R5 d5 C8 `7 m+ n+ S% N                    if is_inception and phase == 'train':3 m  z4 w) h0 i4 |
                            outputs, aux_outputs = model(inputs)
    + ]' F! {& Q& B( @" f                        loss1 = criterion(outputs, labels)6 w% t, p( `7 p9 i' ?
                            loss2 = criterion(aux_outputs, labels)
    9 {$ j& R! \) V( u0 j. V  k                        loss = loss1 + 0.4*loss26 \* t6 m$ f9 o* |6 j7 l
                        else:#resnet执行的是这里( G9 G& B1 s* V
                            outputs = model(inputs)
    " `& K, v9 V8 a$ w                        loss = criterion(outputs, labels)2 j; g3 E- O' |9 j

    # C! l% E& E% o/ e% s9 T                        #概率最大的返回preds& x- }3 v5 U9 G# n
                        _, preds = torch.max(outputs, 1)+ Y: u6 T: _* u( X' ?: c
    , p6 V0 L7 ?" E; x0 m0 y/ z
                        # 训练阶段更新权重1 W  A* b$ p( u1 b. ]- ^* y
                        if phase == 'train':8 G* l& R3 J0 |6 }- V
                            loss.backward()7 S' l8 @, B. W) C% c
                            optimizer.step()! i, r6 x/ S1 R

    + }# b* B- `% H5 R) B6 h3 I; e                # 计算损失
    # q. Z9 Z# y, K7 ?# ^                running_loss += loss.item() * inputs.size(0)
    6 d; U5 w" M7 h( O& i6 O4 g) [                running_corrects += torch.sum(preds == labels.data)) N* S. O1 o% v( Y, h

    & ]% f) V  v5 h$ E& a# L# r            #打印操作. I7 }8 ]& B. m9 {
                epoch_loss = running_loss / len(dataloaders[phase].dataset)
    7 T% |7 o: z5 y8 a3 A9 w            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    ; O7 h; E, Y2 t2 K
    : J% a( D( ]6 s8 @; w
    + o8 ^9 T8 {" v' G  \/ y: M            time_elapsed = time.time() - since5 ]5 Y% S4 ~( g! G
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))# u$ l" C2 X1 L/ u: Y9 }9 U
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))( [  R) s. C: K* r8 W

      N& i. `" V3 Q6 a* ^" b
    6 e- r3 B. L# |6 b2 W. [" Q            # 得到最好那次的模型; g% _/ [! v& L& G6 u, Q3 |4 ]
                if phase == 'valid' and epoch_acc > best_acc:3 ]5 b5 D. f. O% a: c& u6 W
                    best_acc = epoch_acc
      H6 ]0 D% w% ]                #模型保存
    $ v8 a- U1 J( _4 ^                best_model_wts = copy.deepcopy(model.state_dict())
    $ k( W; F0 K3 P* b0 V% K" f                state = {
    ' Q4 s* d+ t/ G: Y- \                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    4 f9 b3 a5 `  B/ l" _% o9 x                  'state_dict': model.state_dict(),
    : K8 s& B: t+ ~5 ?+ j                  'best_acc': best_acc,
    , ]% X1 A- J- F- p* [; D( h& N                  'optimizer' : optimizer.state_dict(),# E2 a- p4 k( d1 m0 O
                    }
      ?# W5 n! m  P$ H+ c8 h. d7 g3 J                torch.save(state, filename)
    / p% g  N3 ~  }/ P2 v            if phase == 'valid':
    ( J1 N, L! K. T/ ^3 s" d                val_acc_history.append(epoch_acc)+ s( b# ~; y8 `% ~  @" V; ?) D1 e
                    valid_losses.append(epoch_loss)4 m' ^; O( u' t8 y# m
                    scheduler.step(epoch_loss)
    6 G# i: q% p3 Y/ }) Z1 C- E            if phase == 'train':
    1 \0 X# y2 k/ ?8 ~3 m                train_acc_history.append(epoch_acc)3 ]& A7 k8 v8 [3 I+ I5 s
                    train_losses.append(epoch_loss)' t* U& ], X' b" f
    5 b4 a& n$ z$ P, X: q3 Q
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))9 A, d2 [3 ?! z# b0 X
            LRs.append(optimizer.param_groups[0]['lr'])
    4 T7 K, M$ n2 R! I( @. ~        print()% Z2 C" ^5 y1 K( ^8 x0 V1 [

    3 F1 Z  h( w& g* W5 c    time_elapsed = time.time() - since6 `% l& W8 \1 A+ b% D) S8 U! |
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))$ @( g5 e# l8 f
        print('Best val Acc: {:4f}'.format(best_acc))  w5 Z: l# d, Y; ^- ?& s' N$ _

    # E1 Q, O! o% V) d7 V) o# t    # 保存训练完后用最好的一次当做模型最终的结果( P% Y9 w  g, ^  X  t( K( t* |
        model.load_state_dict(best_model_wts)5 l( f4 ^' v- ]- ]
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 3 r. e% O( Y+ f) _

    2 D, M2 P% w. n! |1 D" [. ^# E6 F4 d# O0 q
    1
    4 I- J3 f3 G* t2* T1 E. U: T2 `6 ^' h: i6 W, }/ Q- V2 {
    3
    3 o3 U! N, i2 h( g) H0 W. f4
    7 G2 T* g9 N' B5 o! V' w% ]# ]0 a56 N7 W* ~2 `6 |5 b2 t9 ~/ {2 k; i
    69 W. f1 k7 b+ ~, |. c
    7
    9 ]0 c# I, b! G8 V. D! k8, j$ z+ U) A% H; q$ ~* O& [1 F8 `
    9' s6 E, l( q8 a' B  t+ t
    10
    6 x  d$ B" I( s6 d) G11- P" E! z& `2 v6 U- ^, g' |) o2 _
    124 d+ l8 ?' E* U2 F
    13
    9 ]( _: O2 y+ W+ \: J14
    * o3 C! p/ g0 A* I8 @2 U9 e  p15
    6 ]6 t0 k( p" J1 b16
    1 ^8 V% A, O. k9 \! k5 p# m' f17
    / J+ H" q; o0 j# p$ ^18
    : H( f5 j! N5 _! o197 q0 o6 F6 Y0 l) N: Z# o9 F
    20
    ! q; Y6 \* e, s215 a# D! _+ t5 b. N3 ?% G2 }  c) `
    22
      O9 I3 a3 l$ ]1 @7 i, ^6 T233 X. i6 F; O$ ?0 O
    241 I, ~: M2 w3 M' S+ q8 ^
    259 k: r5 g9 q' _4 o2 ~
    26
    3 t1 P' @/ \# a27
    & |3 g8 i9 B5 A& Z28
    4 W# d6 p# s% t1 I) y5 C8 v2 r295 ~. O6 k' o0 [% [
    302 C2 ^; Z) u3 f- @/ P& X* l9 }
    314 K6 n  f8 `& a
    32
    / V' R" ~( q. z  f33
    7 s+ P7 s2 h& @) t) O  q34
    ! l, s$ u/ W( i35  _0 e- e7 I5 Z) P
    36' {  b+ P! d, H+ C
    37. e" L/ V, |8 O6 n, [
    38
    - c# ^9 g$ L/ U- L; H  F39
    ( ^( b4 I/ [* ]1 H1 O' |4 t  k401 R1 L% Y, m; n( ?
    41* Q# ]4 J  F) d! b( Q
    428 L' \4 u9 O6 a. h
    43
    9 j9 u' ^8 x1 W2 ~44
    % q; \" k% t! K45
    6 Z2 c5 b6 N6 m/ L: r46
    - S; V8 V1 g! o: a8 ~, w9 V7 v7 i1 h47
    " r& }. @8 N+ G3 t, N48) @& A# r  H4 c, M' K" A
    49( x# n. e) G. {# ]
    50* v9 R6 j# x7 x% J. V
    51" @6 v; {5 q0 b1 v5 o
    52
    ; z) J3 Y6 f  D* k) ^, U53
    9 M5 o# O  o( |3 U/ c) C* `54$ A  a4 ]/ _  ?- g8 s- v
    55
    5 f% b6 M5 ~* X  \  M( {9 u, s560 ^+ \6 T; f; V1 W3 |; c
    570 m: y! z: M, G7 F+ _
    58
    6 [7 e1 b& _# E6 w9 \  T59
    ! r- o* w$ ~& {% W; Y1 e, Y/ E9 H$ D603 \* m' |9 u$ c& b
    61# x, V- L- ?; L0 z3 L$ d
    62- Z6 @  s# _2 ~/ }1 M$ q: X1 O' @
    63
    * c7 u& A4 B" H6 ~% U' `( O2 a' \- s4 |64* Z9 Z. J1 D# q8 y( ~0 Y2 B, b' Q
    65
      J/ e7 {# S0 N* k+ e2 k! B% Y: [7 R66( a7 ~& S  J, q  T( u# q
    67# I  P5 h- O$ r
    68
    - z: ?- |1 Z& I5 B! D. q69
    ' n* Y  y" Q8 ~6 P6 d70
    9 z7 S; ~1 ]& R2 _711 [+ j$ ]/ A; C: y7 y
    72
    + a8 [$ f# j& d; g+ O/ {9 `2 p/ E73; |$ S6 x4 X% ?" e$ o9 R  ?( s
    74
      x; j8 ], g/ w) ]$ Z+ D" f) G75
    ! s0 u- E8 a) G* w( q76
    , u+ ]! p% B8 C8 q+ |% b4 x5 N4 U; f77
    2 P: o6 Q2 x( h: \# C2 d, P/ z78
    2 g6 s3 S1 t3 M: A8 w79
    1 Z' k5 q  H& f5 I  N" x80
    ( K# Q) B8 N4 U8 h81: [8 l. @2 Y6 K  Z% H
    82
    , e  r; Q3 C% e2 j& ], `83" v* _% Z8 l5 j2 M" K
    84+ g- b$ J" U% u' \5 L. I
    85
    6 e$ z7 R: Q  P2 ^5 a7 R+ n+ E86
    2 r) s: j. Q& u# S87
    0 g& ~" R) B0 f  H4 `' }9 T0 O( @! n, G88" l) e. Y+ d+ B, H; |
    89" _1 m" A* R3 ^/ s' _1 m+ C5 S; ?1 T
    90
    ; @; K& `# L6 p916 O/ K* I) Z  I2 J# g7 f( Y) g0 z
    92, n+ x# @4 V3 @' v' L
    93/ Q- ]7 z2 F* ?) }! v* W( f& g& a
    94' V) q" u+ m* b* f  }* m
    95& D7 T, P* W8 j1 b( J; L  B/ R
    960 O( T+ a5 S( N6 k3 ?( }2 ~
    97  |8 t' |, M: g: B7 \, b
    98
    5 \8 s  t  O0 S2 L- @5 _99
    & ^, d8 X% k& q, R& p9 {100. t2 y" d9 R; B3 f
    1012 j4 p5 q% T. }
    102
    2 ^) s  F# t8 Y103
    ( z8 ]/ B8 U! K104( O/ I" ^; P% S' u
    105
    8 d# k' q5 u" E- D! |106
    , M2 w/ h- c1 P6 {2 }1077 y# y+ x& e1 f' Y; M' Q
    108, U7 {* G) f( d  \) S: `3 e" X7 j
    109
    ) {9 x0 u  [' M9 n110
    1 D- Q$ c7 G/ U' n4 |0 T  }" d111
    , j; Y# a/ @6 q. B$ P9 ?112* ^2 D3 K$ v7 M# p6 u
    7.2 开始训练模型
    0 N# S7 W) S# z我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    " @+ g4 p7 C% o7 N
      m2 t3 M0 F- i5 ~/ |#若太慢,把epoch调低,迭代50次可能好些
    8 {- o# {) K+ g, i6 J4 M# _# S- Y#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    2 m" S3 |4 b+ u! @# o: pmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    3 \+ O+ }1 }9 L& w
    # p8 |7 b' s" S2 h1
      z% E1 c1 r7 [7 ?2
    5 }% X! R" }: _) g7 {5 Q9 S1 k3
    3 ?1 V3 t- `, q  r) v0 v4
    6 S' v* t/ X; EEpoch 0/4
    2 P; l! G3 ]) ?7 K" j$ C3 G----------7 `7 t. L& T! k5 a' |2 z
    Time elapsed 29m 41s" v3 x. w; |6 S
    train Loss: 10.4774 Acc: 0.3147: p0 x+ [! Y6 V2 A+ U, b# ~$ R
    Time elapsed 32m 54s2 ?+ ~& `+ K. J5 q0 P3 L+ C
    valid Loss: 8.2902 Acc: 0.4719; z- R7 m6 _4 @, P( B
    Optimizer learning rate : 0.0010000
    . [3 L, D) }. O5 @+ G$ e4 {2 X0 Z
    ! j7 m7 u% C6 R* ]7 i5 P1 KEpoch 1/4: K# c6 d; A8 z4 [3 _% c
    ----------1 L# B1 v5 R8 l0 K" I7 `
    Time elapsed 60m 11s
    ( g5 Q, Y% L' }3 btrain Loss: 2.3126 Acc: 0.7053
    $ G6 |0 ^) d- V8 n0 XTime elapsed 63m 16s
    6 M6 C9 w' @+ J: E  x) ~! X/ w* fvalid Loss: 3.2325 Acc: 0.6626/ k) E0 B2 F1 v% l6 q5 W7 }
    Optimizer learning rate : 0.0100000
    ( `# T  r; s4 g8 i" w6 R" O  _/ y9 ^  b# E* N
    Epoch 2/4
    4 M4 M9 y  D: d5 g  v1 Q----------
    3 v  A4 b/ @  s# e* C; b" {Time elapsed 90m 58s
    6 D; g# T0 E. `6 ?train Loss: 9.9720 Acc: 0.47347 E6 u2 l9 |0 p7 ^1 D
    Time elapsed 94m 4s$ R. s8 Q" ~: A: u5 K, ^! ~
    valid Loss: 14.0426 Acc: 0.4413
    # M4 Y! |6 v7 [1 GOptimizer learning rate : 0.0001000
    1 D6 z3 Z5 z0 W+ A
    / ]2 W4 A1 C& A% b  tEpoch 3/4
    / I; `/ p/ a! @& H! _----------* H5 f! M' O) J% s
    Time elapsed 132m 49s
    & R/ E4 y6 b# m1 L# T' q2 htrain Loss: 5.4290 Acc: 0.6548+ t/ m, j$ r/ F+ d8 `; g
    Time elapsed 138m 49s: d8 x. ^( W0 y* B
    valid Loss: 6.4208 Acc: 0.6027
    8 T( Y! J4 f4 sOptimizer learning rate : 0.01000009 j# O! i/ k8 K7 E" E
    + n2 ?0 F- Y  ~2 X0 h
    Epoch 4/43 c3 ]% x& A2 V
    ----------
    + f1 C: R) C* h  q' I! l. eTime elapsed 195m 56s4 }  i  N4 S3 f% ?, M% t# q6 I
    train Loss: 8.8911 Acc: 0.5519
    9 W& f% s4 F* ?6 T4 L+ W, STime elapsed 199m 16s( ?# k2 q) N) ^, I( I' ?! C
    valid Loss: 13.2221 Acc: 0.4914
    6 {0 V: q, ]/ p. aOptimizer learning rate : 0.00100007 P1 W' A8 W3 A  V% v4 h' x

    ; P2 I" e) f7 V; S+ z2 C2 w" ^/ ATraining complete in 199m 16s
    ; k3 o1 j+ \7 {; c$ v! iBest val Acc: 0.662592' s4 X& `. j3 n9 {' w) F4 e6 j2 w
    / \" W2 A, i% i# R& ?2 Q5 }' G3 f
    1
    0 V0 S  z2 Q5 I0 ]2 h: _1 Y9 Y0 u2
    , w& R* X  j# L- y0 _+ \3
    , I* p" [) c2 g- I6 D. u4 O4  Z5 f: q+ b5 h7 H& }% j7 X' _* g
    5' j  g5 W: v% A' P5 e* V/ C
    6
    % T/ g, H( `3 D4 u77 J( Q# K6 Q, D- p/ a0 f
    8/ C" i/ l! \, y- o. F, n
    9
    5 o9 s% Y2 t% A4 A" S1 v3 B10
    3 [+ b- r8 Q+ [$ D) U1 K11: Q% O4 z! K) U; Y8 }/ x) Y3 m
    12! m# _; f* _1 n5 Z8 M
    131 V/ g* e2 P) [0 V$ y; s
    14
    : c1 V5 h) o* h* t( H( p0 J15
    % p- L8 Z5 N" M  Z# c4 ^. H16/ ?1 @. R4 G: m* d
    178 W! h) d3 H. e( ^! U
    181 H! z4 u# G5 b# ~
    19
    / O6 s: a; Y4 C; F/ F6 h$ B- t0 p20
    7 r$ f" b6 _7 @/ E+ s- G21) T) ], c9 w1 c7 v4 t  t, h4 b* V
    22: b3 L  q8 b' w, j/ h/ q0 Y
    23
    : A, A4 ?: f) x$ C; Y24& h) {; \* k* C
    25
    5 ]* b- @2 D6 e- r% U26
    6 v7 w) X2 f1 B1 x" c27; y1 ?6 L) r, K5 D9 g5 b2 q2 u
    28
    . n8 Z  V* t: [5 C. N4 q; J292 F  _3 G/ \$ z
    30+ r+ Y" W5 d8 H9 \! T7 u
    310 t" h- ~! A2 V, `' ?' y
    32/ z6 N- q7 @2 c# k3 N% `
    33
    1 c. a- j; X9 a" v  a- j+ k/ l, k34) d! Q+ a- B  X5 t! H2 @
    35
    , r# w  l8 f% ^0 v* W36
      s- H/ _2 s; ]6 v0 Y8 a1 a37. d9 p1 V: {5 D( x* @6 n# M
    38% S  i  P$ d; k+ E" j5 D7 t
    39
    . V' A. B, n6 b% ~0 `; Y8 _* i40! E" i/ Z% F; D6 ~1 e* E
    41
    2 M. t/ x8 ]# [( f: T! p/ h42
    0 r8 S& ^" Z: c" f7.3 训练所有层
    ! i# y# A# l9 t: i" ~# 将全部网络解锁进行训练
    - ^% h9 C9 r& H! J/ V, ffor param in model_ft.parameters():
    5 @0 ~$ L! N# Z  R2 y3 b: m    param.requires_grad = True
    ( J6 l" P1 v# G% W+ s$ Y
    8 D/ t' Q3 O) H$ h' J3 Z; \) H# 再继续训练所有的参数,学习率调小一点\
    ' B( ^9 m. K: Voptimizer = optim.Adam(params_to_update, lr = 1e-4); J; Y/ E3 P# C1 t# r5 H: f: z
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    2 r. s5 d- ]) }' D3 }. @. e: W" f: h& O. S
    # 损失函数
    % ~4 W+ c% m  F$ W, v/ ~/ D% Ccriterion = nn.NLLLoss()
    6 N3 G. B6 P: u& A: C  t: m1
    5 h9 C5 Q, M& I  h3 I2 v4 \2
    - r  B$ ?: ^0 y# W9 _) ~/ q3
    " P2 _" f& {( Z: D$ O. d" J4% c2 q9 a- h7 w' I; y) M5 ?
    5
    # |( d1 k# x9 v/ D7 Y9 S6
    4 v4 h6 ]! j3 w( x- O: _( `, n7
    . I$ f+ r2 g+ Q3 m81 U* @# M8 `$ r. k" v
    9: g5 B6 o$ `! s6 ?) W/ ~
    10) I6 |# Y0 P5 j/ Y8 ^
    # 加载保存的参数3 w( h; y* Q' f  _3 E: t4 n
    # 并在原有的模型基础上继续训练* a1 |5 _# \2 g5 m8 T+ p1 l% F
    # 下面保存的是刚刚训练效果较好的路径1 c+ J) P" D# h6 X: L
    checkpoint = torch.load(filename)
    - }, @% P5 k5 a$ Y) F; _: Z' \. Ebest_acc = checkpoint['best_acc']
    ' ^& Z' ~) H9 Y, S7 W& \model_ft.load_state_dict(checkpoint['state_dict'])
    & h+ p' Y1 Q3 [: T  U2 [/ e" ], ?optimizer.load_state_dict(checkpoint['optimizer'])5 f; h; D# o: C5 ]' K- O4 h' X
    1, H5 @7 I' W. L- U( V" `
    2: F' ?3 Q) b# Y+ b2 D
    3
    0 |. c2 E* r( u' [' C$ m$ p4
    5 I. B/ y% g, L; N, |" B1 l55 H6 L2 W8 v: {" p
    67 \$ T; |* g$ \2 j! b, }5 @: g
    7
    & p& P3 r. ^9 o7 l开始训练
    + r  H+ O, `9 T# }注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考7 O$ r' ]5 }8 g& ]. f) p

    & {7 p; Z( T$ E( M/ Ymodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))7 G* ]5 u% u/ a! P9 N9 m& {* H) U
    1
    9 h- X: ]0 ~0 s, X. D+ |1 l: e" {Epoch 0/18 a/ _1 t9 n$ O+ d
    ----------' z, K% @- N0 ^" `! `
    Time elapsed 35m 22s5 C9 K" G' I6 L  z0 Q, v
    train Loss: 1.7636 Acc: 0.7346
    + O; A5 T3 f8 h! NTime elapsed 38m 42s
    ) e5 W1 f+ U3 s; S4 y& d6 Nvalid Loss: 3.6377 Acc: 0.6455
    " Z8 @2 L+ T+ j. Y& z* {Optimizer learning rate : 0.0010000
    5 w. ^/ A  P) `* |/ ?8 U3 i3 A* U1 ~9 U8 q+ n' n% `8 f
    Epoch 1/11 Z8 n( ^. s3 s5 z: T
    ----------
    - u+ i8 d, }, dTime elapsed 82m 59s; }/ H. U. b2 H1 J$ m$ B
    train Loss: 1.7543 Acc: 0.7340% Q" ^3 T, q# u- e
    Time elapsed 86m 11s5 M% t+ Q9 P8 P- J" K
    valid Loss: 3.8275 Acc: 0.6137. |4 y9 d* r4 c! j
    Optimizer learning rate : 0.0010000
    2 S9 C5 S4 ~% V, \5 Q
    5 m4 u  v8 ]* d5 K( S3 _. r+ q$ uTraining complete in 86m 11s0 t, ?9 m$ Y( C8 j0 b# _5 ^
    Best val Acc: 0.645477
    + \& U- y2 }( f5 H  O. f5 b( e. W8 m- x
    1
    / o$ x( C) w9 d2
    4 s, l& O2 D; m* K9 j: F3
    6 I/ v& A+ ~- E0 F. P4
    ) p" E! X, H: S: \. N5. |, D9 k4 P8 M) `$ F0 Q) G. m5 v  o
    6& o% M% j8 g5 T4 t- t0 v# P
    7, n: Q* q/ t9 S4 V
    8+ j3 f& @: J- N0 K9 n% |
    9
    2 y% @; q. B3 r* |/ ]10
    0 }( U4 }% M  L6 U' }8 v# A11
    % X. p* L* t& Q! j12- s( \1 J/ L4 c% d/ `' \0 `  e9 @* T
    13* C7 A6 y4 T) ?: \- s) k! H
    14, f" Y$ h; b* y5 F; m. q) |- _
    15
    1 T/ m; Z. c' T  H& z8 R# ^+ d" ]( e167 I7 V& e" U5 t% c
    174 w, y+ O9 K- \* l- p/ ~6 I
    18
    : B. _8 r: Z* W8. 加载已经训练的模型
    : T* ?, B9 U1 U1 _- p相当于做一次简单的前向传播(逻辑推理),不用更新参数
    8 d: O2 P+ g: R2 t4 F
    / U+ a2 w) c- q& _  [$ Q, }model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
    ' \2 \* x; }9 m  m  V/ p0 K7 k8 |
    # GPU 模式& p& U/ s8 D3 w0 x1 B
    model_ft = model_ft.to(device) # 扔到GPU中
    0 s3 M3 g6 H( g1 a( L: |+ w" J6 H7 B9 Y7 P2 @
    # 保存文件的名字
    2 N5 e( q, r! a2 W+ a: c! Ifilename='checkpoint.pth'; \$ a7 ~, J/ V/ F- Z
    + E* @- q! G; t
    # 加载模型. @8 d% f+ a: J
    checkpoint = torch.load(filename)0 m7 _; c  p2 o+ C/ a% z! b$ `# |
    best_acc = checkpoint['best_acc']6 X! h' d. l* u" ]
    model_ft.load_state_dict(checkpoint['state_dict'])
    0 e) ^3 J3 I; h: {1 ]1 ^7 a: |1
    3 l4 J7 H5 B$ ~0 F6 h: D# L# L! W) o; L2( X8 l9 ]0 Y+ q( r, A' A( T. J- r
    3
    , ~3 ^. }, U1 O% n4
    8 p- ^6 I2 S. Y5
    # U& Y! p5 t7 p3 J/ F# u! C62 J$ v7 q, T  C, H# Q1 E, T
    7
    4 g+ x" [7 S/ L( n4 i; z- w6 ~89 E" f! \, t9 V$ m4 h8 P
    9  L3 b1 @. q$ N1 o6 m3 t. t4 G) u
    10( _* M1 H+ K* V$ b* c5 L3 f$ H
    11
      z3 f+ U4 ~; E: w& [2 _  m( d' h12/ k5 I6 h( o6 \( d
    <All keys matched successfully>
    6 o$ c' {# k  b1 `6 n, L1
    . F$ |6 a" y. x9 hdef process_image(image_path):5 R( m! {" `$ t2 L( L
        # 读取测试集数据# k6 [% Y: J6 @) o  B3 k' f" b
        img = Image.open(image_path)- b) V" q7 C: m7 H$ v# M' q
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断+ A  q9 C! P/ |4 O
        # 与Resize不同; s2 W$ V9 }- U9 v+ v5 W' R
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    - m" m4 @5 ~" d( }    # 而且对象调用方法会直接改变其大小,返回None
    - u9 h" d! k; C/ T    if img.size[0] > img.size[1]:- j+ S$ N9 ~$ d: g1 t) g! [) q
            img.thumbnail((10000, 256))
    ! L- U& [* X/ h5 O; V    else:
    ) [: N& l0 T- C; u/ V        img.thumbnail((256, 10000))2 V+ @2 O1 C) ~0 O4 t6 l+ t. _

    ; L5 ~0 ~5 \0 }1 z$ r- t6 N    # crop操作, 将图像再次裁剪为 224 * 224
    , }! F# g# \- p/ {' H' f    left_margin = (img.width - 224) / 2 # 取中间的部分, X. K5 M3 i5 ?/ E: w4 l. q7 y
        bottom_margin = (img.height - 224) / 2
    # a2 ^! `9 _4 G- h5 i5 c    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度7 a; u: S5 A0 O5 |' [! H/ f, \
        top_margin = bottom_margin + 224
    ) C4 W  Q  Y& T8 ?
    * H4 {( r# y/ ?; v/ o1 P    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
      m! W3 H' [! Z$ R! Y
    " I$ I' b/ N, L+ Y! ~    # 相同预处理的方法8 W* p+ S) ^; i
        # 归一化
    0 }6 c: p+ K  \9 E, T, o    img = np.array(img) / 255
    9 s' _- @/ A: C    mean = np.array([0.485, 0.456, 0.406])0 o4 ^* R/ w* Z( s! u/ ^1 n
        std = np.array([0.229, 0.224, 0.225])( v5 X9 G0 l; v* h9 j) t
        img = (img - mean) / std
    , b- ^* G: o7 D' e' E4 Z0 V
    4 u; s/ ^1 X. ^. h8 j" `    # 注意颜色通道和位置& h3 y7 N  I: @# \2 u. h' O
        img = img.transpose((2, 0, 1))
    ! C1 V+ b, y$ y
    & N  c8 o7 G  l2 F2 X    return img
    6 y6 A4 K, r+ L) X, `+ d
    : p9 N* r9 b" F$ cdef imshow(image, ax = None, title = None):: _: {! H' s, F
        """展示数据"""% @+ E" x5 m: Q% O
        if ax is None:) t/ E9 L+ n$ Z' a9 Z
            fig, ax = plt.subplots()$ a8 H) g& ]% b
    ( ^6 w+ a" ^* W# ?
        # 颜色通道进行还原9 B  h- o( e9 _8 y1 S  E5 _  j, S1 L
        image = np.array(image).transpose((1, 2, 0))$ ~* G# w- e8 ~; D  W

    7 H7 Q+ [! j6 v! C% H+ G  u; {    # 预处理还原
    1 c' ^6 @9 K* [% n  x) O    mean = np.array([0.485, 0.456, 0.406])
    + u5 J3 y) p% c" r3 P. a    std = np.array([0.229, 0.224, 0.225]): |3 E" x5 j/ I3 \
        image = std * image + mean
    ) F. s* J1 t* Q; o- ^. U: J    image = np.clip(image, 0, 1)
    5 J% S: |9 @' j* o! {
      T5 O: K$ C8 s* q6 o* b, s& q( R4 ?9 n    ax.imshow(image)
    : H2 @  u+ E( _; F: s    ax.set_title(title)
    " l, w- e1 l, S! ^- |! F* |- \# g. O
        return ax
    ( k5 y, c* D3 Q+ i/ N8 r# o* l( y. Y2 b6 L
    image_path = r'./flower_data/valid/3/image_06621.jpg', s- i$ c' e  V' D2 p/ U, T/ r, ]% V
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理7 l  |: F, \/ C
    imshow(img)' s7 B/ q# W2 a% s# O4 T
    ; U" @6 F1 D) f# X8 `# I; d
    1$ h+ H0 ~, y$ ~2 H$ n: z1 k
    2  {$ D" c' h7 x! Z
    3
    2 X, H$ P5 `& |# r' f+ d4
    : N4 I% h  z6 Z6 K) ?5
    9 P, r& u% C) Y9 H/ v1 c6
    6 D4 r, e: e1 a. z) m$ M7: z, \( C# c1 O, ~# C7 }1 N
    8
      i/ J  d6 t2 F: }" }2 |% z( h& @" ?9
    ' f! d: b. b( K# H$ Q10" R0 @" G2 B" N, ~  n$ y) v
    119 M+ O. e% r4 E5 R4 L4 t. \. d. e
    12
    3 ]  `1 s; ?: y* _$ ]13
    # R9 `$ |' Y" P* C4 I5 _14
    ( r  y# r  U* Q1 w157 w2 ]. M( S: v) k+ U
    16
    # ^8 m6 o, O2 j; X8 V8 }# `17" y9 _" n( g9 K2 q
    18
    - u0 [$ z, Y/ _% I. Y" Z, l19
      _0 R  b" _, E- W8 }3 B20
    * B/ i" d  ]( f! R21
    5 C9 R1 K; g. V2 e3 n+ d22
    - }5 f6 d: {0 ]1 h3 O1 _23
    8 E0 h9 f+ }, V1 B! i! y6 m24
    : \& V' I( ]5 d" X" B# z2 o25
    & h  Y* _! {3 w$ d  A# \26
      d# m' z' l+ P+ b) K" e" m8 H6 b27, w/ b) l2 l' h3 U5 I% v
    28
    9 K6 J* V* B: \' B% H: b( E29
    : D1 O5 r4 ?! r- z: i/ p9 W30
    5 K. l- i, p/ Y- s; `, U7 I31- h: M' F  M6 [% l
    32
    : z- Q: J% r! r. b33
    - t( Z& x* o/ _34( m) G6 U1 A) P% o2 o% v# p' k
    356 ]' E* [7 L+ u  q( D, |& f
    36
    0 a: U+ G: n. K4 a) ^0 }37- B! S& X- o0 m4 `' N' u6 _
    38* D' u  b# J& k
    39. @+ o- ~; O/ Z  P; W6 z. a2 @# t) F
    40
    5 M6 u+ P# V8 G  ^5 x6 Q$ c, J41* b# l- p) V5 u0 U
    42, m, F& B  N0 _3 o
    43! f% U! o; }# ]
    446 X, H# v9 C- V5 f
    45+ z6 I) y9 H7 u$ b' b
    461 a; R, H% w% m+ }4 m  f6 [% O
    47
    ! u/ Y# \  I$ c! K48+ }# }( B9 }9 f3 c. B; \
    49" K4 o7 E! I0 [  Z  M# g
    50; r  t5 @) T) c
    51
    - Z% P' f1 A+ Z9 j$ @; V5 I- d) m52
      M2 z/ U  D( Q6 B6 R2 E) e9 W53
    + l% [/ B" i% {7 a$ E& i8 ~54
    & g) v5 y* i) x, v7 }( x, D+ P# ^<AxesSubplot:>  p( F; {  X7 E6 v) O6 e  d0 A
    1
    , g, ]1 I9 l% J' E7 B& p8 B
    , Q: a7 `% h, ~  O+ A; o& |4 t$ [上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确6 J5 ~0 t/ q. a! I% b  |$ u: X( k
    6 [; Q7 u. Y$ Q& W
    img.shape% \: F- T6 B1 q1 `# X, L6 Z
    1: ~0 N1 H: e9 A3 N: m
    (3, 224, 224)9 Y8 D$ A) A9 @9 S8 o. l- ]+ b
    1. T7 P8 l; o* A) {! M  s/ R
    证明了通道提前了,而且大小没改变
    $ j6 {9 ?5 Z$ O+ u5 F2 c6 X' Z. r! h3 E  ^! p
    9. 推理' N  t+ M/ n5 F: G
    img.shape
    * i" Z6 A3 |" _7 t" n
    0 J( F4 c5 f0 x( t& L4 c) m# 得到一个batch的测试数据
    + x; }( X7 N) w  S& m& }  vdataiter = iter(dataloaders['valid'])8 f- A" W8 H4 Z( C* n& D8 P6 K
    images, labels = dataiter.next()" P0 M% b$ @$ |3 B0 K  C* Q

    & s5 \: l' C! ~" E. qmodel_ft.eval()* ^! g4 p7 d$ O9 t$ r

    9 m* `: U$ |0 F) @! q; T+ rif train_on_gpu:
    : T( @" n6 Z% k$ d  K& _( \" e    # 前向传播跑一次会得到output6 x2 p/ C, D2 Y: n
        output = model_ft(images.cuda())
    6 x! I1 Z/ M2 F, z8 \% Z$ C/ A4 L5 Melse:
    1 G6 J5 G7 S, n* `7 l5 ]2 n    output = model_ft(images)
    ' p6 l, G- Y: q: a) U5 k8 F0 x" w" ^" C6 U) [0 t5 d9 R
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值; ]! c: Q) ^5 I; ]8 j9 y" l
    output.shape2 N* v8 _- p7 s) ?5 ?" v6 c
    ! D: t8 F. n% K( t6 j/ p
    1# T, }1 c! }- j: a, g2 U" t+ ?
    2
    : V) k7 z0 o4 F2 a3
    5 Y+ W! b( {7 _- D2 z6 \4
    , `. r9 o" q6 z$ E3 ]2 |5: g% x# N9 V# {+ i: H% d
    6
    ; Z1 l) ]( E( j& A* a% S7
    7 J# t, {+ a- a3 c& v9 g8
    & @6 }. {* m3 U: A( X9
    ! ]; J0 w6 Z1 n* h10% e) i. ]5 R( o6 A
    11% x. f4 E. Z2 s6 Z/ g
    12  S* s) n4 E1 C0 {$ ^( }
    13# N; X$ a. i# ]3 k5 l& p$ }- r
    14
    ) W; D: }3 f; Y$ J3 t6 F1 |# d15% n; D7 c% R* K. z$ q
    16; J$ i& |. J& M# D' j9 j
    torch.Size([8, 102])
    ! j: V" Z% P0 o: E( o0 E) \6 m0 M( D  P1
    2 L4 A8 p: i  L9.1 计算得到最大概率
    * d5 J9 `4 J) P+ ]- {+ c_, preds_tensor = torch.max(output, 1)
    1 `! N* ?- ?6 W: H. V2 w
    9 d8 ]; e( u" i3 E6 Ypreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量) x" x0 @) @/ V8 Q# m) J; v7 D
    1* T5 G+ K0 Y7 m; y% P6 i
    25 |- y! D) c8 a# R6 N' M: }9 x3 e
    3. t) I8 F) a" b! d
    9.2 展示预测结果9 ~) `. z( f- G2 m
    fig = plt.figure(figsize = (20, 20))
    1 X4 l- V6 @3 ~6 |) ~6 q! ~1 V7 ccolumns = 4/ ]) P3 B$ t0 y; y
    rows = 28 L/ c/ q0 p2 @  i% L
    3 t( X  U  O5 {1 A$ w
    for idx in range(columns * rows):- j8 y/ B& \; ~5 V
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[]); y& S1 v# @, @& {9 W9 I4 j2 l
        plt.imshow(im_convert(images[idx]))
    + }: D! A5 H! l* h! e! A! w    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
    2 a7 W2 v6 d' @                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    7 m" [7 p4 K9 K* K) @( [* Fplt.show()& w  o; E4 ?$ H. H
    # 绿色的表示预测是对的,红色表示预测错了, r: |8 G2 F9 ~( d9 u$ |
    1
    & |* d4 v2 k; u3 j8 X/ Z. U4 @' k2" ]9 {5 K: R# e2 b0 ]5 i: E
    3% k, [- {" d- m) [/ B, Q3 f) }  J- n
    4) C2 k% J5 h! z/ R5 a
    5& D# [% q4 A& n1 V* Y
    6
    3 f- i' a* v2 Z7 z7 O" @7
    ! D6 z8 i1 s- p6 Z7 J8. O5 a9 E$ Y) {& J2 B
    9" d8 ?* d4 q! C# K! z( j, Z$ J
    10" a# }% Y7 \* U- a" z
    11) p- [8 T7 X% {; |7 B" j  e
    ! F9 Y+ _4 V& P9 @. v3 A1 }
    * h8 p2 B+ z3 L* H
    + k2 N9 ]  j% N3 M
    ————————————————5 D. M3 o4 \2 Q' p; c
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    4 `) M( j3 `6 w, |原文链接:https://blog.csdn.net/LeungSr/article/details/1267479404 \& R2 ^) n$ T. [6 L6 B

    ! |0 F$ _/ W6 a1 C  j. q- [$ `' k
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2025-12-9 00:51 , Processed in 0.500241 second(s), 50 queries .

    回顶部