QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2716|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |正序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    3 v: B6 D) R( Y9 O/ x: w8 h9 b, q- ?8 E. T
    文章目录# S6 ^0 A$ b7 S9 g2 K& s* e; w: ^! G
    卷积网络实战 对花进行分类0 Y8 e, a  F3 l( d. I) i
    数据预处理部分
    % Z5 ]# C3 B/ ]( t8 `! Y4 T网络模块设置) U- \4 D) r9 ~, b5 X' M* U5 I# h
    网络模型的保存与测试
    2 R7 T7 M* V! Z$ M数据下载:  {% E$ X0 \# v
    1. 导入工具包& U" {: Q: l: p, b5 o* R( G
    2. 数据预处理与操作* x9 h* ^  {3 _0 J$ J  D
    3. 制作好数据源
    - P3 e. M% U& {, ?读取标签对应的实际名字
    5 B# G( y' C) \$ a4.展示一下数据, F8 h. ~# A- F* `2 I
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    . O3 v0 Q) }  d  _. y6.初始化模型架构6 r( e) V; a$ @+ v7 I: q
    7. 设置需要训练的参数
    . m6 G2 c7 Q) s5 [3 }7. 训练与预测
    " ?8 |2 N6 O; {5 p: ]6 a7.1 优化器设置* P1 h: \2 W  Y
    7.2 开始训练模型1 r. J% m- B/ p$ v* p/ P, a
    7.3 训练所有层
    5 O8 {2 A  P# `) n! Y2 G' Z) p/ K9 l开始训练
    : c2 s  |& H5 a9 l" h8. 加载已经训练的模型
    ) i4 E5 n( S) W% J7 V: B9. 推理* G1 R% h& u5 m& \' j. }
    9.1 计算得到最大概率
    8 P9 G/ ^. t( S  C. T9.2 展示预测结果  Q  Q) h% ?% ^3 E
    写在最后9 A0 K% [" F: Z) T: G
    卷积网络实战 对花进行分类1 c* Z# s$ T' i- M* U
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能' e+ H6 P7 @; X5 X

    5 s, y, @; p9 _+ A2 w在文件夹中有102种花,我们主要要对这些花进行分类任务
    4 A( N6 X# B% I* D0 o5 \文件夹结构
    1 l' z5 V2 {  R6 R. X+ {
    $ {6 N: y% K# a2 Xflower_data; ]: b2 q7 d  J. }8 n. y
    8 o0 W  E, }' @( O. j4 t/ s/ J
    train
    8 B2 ]  @2 V- O9 F
    * F$ A( r* Z5 |# w% l1(类别)
    0 ^9 K$ `1 Y2 f2" w3 }! }# @* r7 R( c
    xxx.png / xxx.jpg" E/ n8 {8 h6 V6 _
    valid
    + H0 Q* U6 D$ L# ~0 g6 z
    / x" C1 u7 p+ v) `. D7 V主要分为以下几个大模块4 G8 |' k3 S1 G# g# N: Z

    / H' j( y5 y' p( F' C/ W数据预处理部分
    9 U0 U2 z7 [$ O% M1 B" U& p8 Z数据增强
    $ E) L  r9 ^  o) b- i数据预处理
    , J  v* c: ~7 Y网络模块设置# h* t$ B1 ]0 Y& c
    加载预训练模型,直接调用torchVision的经典网络架构
    : K5 S5 L, e6 }$ L因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
    0 G; d0 m/ P3 y4 V2 m$ \网络模型的保存与测试
    ! p. W! D! R+ k  y) K模型保存可以带有选择性# k, A( x9 Q/ e* i  o! `8 r& M
    数据下载:
    7 G% Q$ N. Z9 n" Ehttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    ; X8 _8 k* q: S6 [7 V% r6 {
    + {; p; m. v/ C6 L8 ~" m0 m- }- C改一下文件名,然后将它放到同一根目录就可以了
    / n; ?2 u5 T1 q
    $ r9 u- \1 V4 d/ E$ l, ^下面是我的数据根目录. b3 x7 j3 L3 h

    ; M$ O' X1 O& N8 Z3 }$ I+ i! n4 `' X
    1. 导入工具包
    4 t# U6 H5 d% o7 W$ Rimport os
    2 Z8 I( s% z( m4 ^% fimport matplotlib.pyplot as plt* {9 p+ M7 w( Z" D- T. ]% r4 o
    # 内嵌入绘图简去show的句柄
    - h+ W  g6 Q9 {# |. r; H0 o& _+ h" N%matplotlib inline ! e5 Y# C- @/ t" d% Y
    import numpy as np0 `; Y3 B% W' O0 H4 k& O: Z
    import torch6 Y- i+ M2 W4 [( u; ^1 y
    from torch import nn+ M# R' o; l2 ~2 {
    + c8 {$ z5 V1 D1 s, K0 a0 I
    import torch.optim as optim( o( b# K% W5 Q  x5 ?" f
    import torchvision
    4 t7 M" g- R6 x4 o6 W+ gfrom torchvision import transforms, models, datasets, Z' b* R9 ~" O

    # K) ], E  ^, ?& y9 {import imageio
    # Q; o7 z- [" b0 h& l# V: Oimport time
    ) u; l0 V" q, `% L! z* Fimport warnings3 M1 v9 {0 ]9 M: f/ v" {
    import random: c2 G: y' x1 ]+ C' }& ?4 O8 q& ]9 Y
    import sys
    * R9 C" F% r; }2 x. dimport copy
    ) R0 r" c3 s0 A) |7 w' B  eimport json$ N; z( Q  l2 D1 f* l
    from PIL import Image
    $ v0 ~$ D& \: q! ?( r5 f
    ) p$ B3 E( i$ q8 J+ f* {  l) ]0 t* |3 e$ s  I- B
    1$ d3 o4 y) o( e- ?  A7 r: q
    2/ G, O$ ?0 _; I% l
    3+ [' e% c1 n4 H- y# z0 V
    42 I! M1 i, s8 j; u; ^
    5! A) O; @+ R0 r; r
    6  i. K' N8 \& }& ]9 t9 O) x$ Z
    79 R) t0 m, @; M7 C! d+ W. \& ^
    8. g: @6 `9 V; q8 N
    9
    . R: z* ^" ~# t4 v( ]$ d& k& `5 j10
    + W- ]4 M# L; X) {, L* E11
    # C! X$ T8 @: q/ Z, B122 i! N1 @$ O' [7 E( g
    13
    : [( O' o8 X. {5 j9 t# \; c14
    4 h) ~' q. y4 s6 s) H& y156 {9 O1 S2 h4 P- L  n
    168 }5 M3 A$ U' ~1 y# i* k+ F" p
    17+ l( ]9 S7 C* f& a
    18
    9 V1 S, q4 _. c$ ?  e# z  a# ]8 @190 A- L1 L9 U8 k, z; P$ Z
    20' Q5 s4 a5 A: ^9 t2 k
    211 E5 _2 Y9 z, }' c
    2. 数据预处理与操作- _. {! C, [. j! C, c& F
    #路径设置) Y! f! z+ G$ I2 Z
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    ! S9 k1 T- m2 R! ntrain_dir = data_dir + '/train'
    " X% l% }7 y" Z- F! ovalid_dir = data_dir + '/valid'2 O# t8 L0 I# G! L) J
    1$ H% P: U" \* E& C4 ~
    25 U1 _! d. B  Q3 ]$ a
    3& n5 V+ G% h) h6 {
    4$ k: ?5 U5 [$ o, ^, v5 t
    python目录点杠的组合与区别/ z% f4 A+ H. A! y
    注: 里面注明了点杠和斜杠的操作
    - F8 Z! X. E' D) \' p# B( s2 e4 y) s) j  D4 K8 ~0 H* L# l; K3 m$ W2 \
    3. 制作好数据源1 {- R) A+ m$ @2 _# M0 {1 Z3 [+ T
    data_transforms中制定了所有图像预处理的操作
    4 h) I. a# O& G7 V; g9 VImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    7 |4 t! r: K0 A& o1 idata_transforms = {
    % U0 Y3 j* ?$ N0 \+ q) C* G    # 分成两部分,一部分是训练$ s+ q& Q. o" f- P5 [
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间7 ]  X! T: {/ u3 k$ M, p
                                     transforms.CenterCrop(224), # 从中心处开始裁剪
    # F0 U! d: @3 P6 Z                                 # 以某个随机的概率决定是否翻转 55开
    7 ^3 L' h! A2 q  V' G! s                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    1 Q6 y- U2 i/ m# h, K5 D                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    # A) L8 n6 f) d& S' f                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    / Q: V7 P9 v2 L                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),* T" p% x' _3 d. t) g5 m
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB6 v& @5 i" S' J( c& H# y* l/ {
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    # x. {% I/ x/ w# _& J( \                                 transforms.ToTensor(),4 G( L1 O- B* [! C* Q" U8 x, b/ Q% `3 @
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    # t' Z  s$ i) d: t                                ]),
    8 a. o# G" b" _6 Z( Y( A    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化. o  s( a4 w* ]$ A
        'valid': transforms.Compose([transforms.Resize(256),
    9 j  A) \  d$ K3 @* N. }                                 transforms.CenterCrop(224),
    7 L9 y) q! k( u& U$ x  _* E                                 transforms.ToTensor(),! m$ A2 J( o0 W9 |2 [
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    ; J5 P$ x- `2 t% x& Z                                ]),% n3 j: v' @/ \# x( O# j3 {( k# ]! ]
    }
    3 F4 l+ c( l8 U. G  w4 N. i6 b1 t  `! r5 u4 R& R) L3 W$ u
    1
    : y. s; a% T1 c. t2
    ) ~( E' b" q) p+ H" v3
    4 Z; z: ]2 T, m! C4
    - U6 d# \% ]3 j: v5# Y- J( M: Z% {" }
    6
    / ?6 j; n7 x+ w5 m$ I7
    9 o" }9 P8 t: `3 a$ ^8
    & p, I' B( o5 }8 J5 s4 R4 U0 k9/ K3 A; L" a5 ]3 d
    10
    " ?7 C* D9 U! l+ J/ h& o2 E" q11
    6 C* V  g  x3 m4 O2 q; p- G0 {% c12
    # A: E/ J) C- C3 Y$ m7 e+ L13
    % Q5 Q: S6 s" l+ X+ K3 T. s" x148 ]3 Q# I4 E3 s$ a( H" f% Q3 f
    151 _  i: e. d0 H9 d$ Q5 @
    16( S1 c5 x2 F6 Q5 _" `$ h* c4 g
    17  I' n; P9 J) ?% {
    18: I; m+ M$ k. w7 d% E
    19
    6 m8 ^8 r/ `; ^0 K20) P5 |: ~! m: o& ]; ]
    21
      n1 \  I6 `3 O  t1 |batch_size = 8
    9 `+ N& G, Q4 L% S2 g% Yimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}- Z7 ^/ t- `6 u( o; r
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}7 u( A7 P) I; w7 W( @3 g
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    $ |4 W% d+ X, s4 C6 T5 nclass_names = image_datasets['train'].classes" k- B, Z5 k9 h2 h
    4 c% _6 |8 ]8 R! S5 L/ ^
    #查看数据集合1 n7 E1 c9 [6 {
    image_datasets
    # Q3 A/ g6 m0 O8 Z8 E) i; A
    6 z1 [; b0 J$ `$ x5 G1
    ) F1 G) H3 X3 ?& {3 v, ~6 A2
    " E6 K; r2 U1 d8 e0 g2 t. Q5 `3, g4 q4 `) l" L" b3 j* @
    43 i3 _/ Y: j, J& C
    5# E, y. c. X! ]. B  P/ W8 b- R
    6, ^- O. E, t, H( |
    7
    + c) [% G8 C. M1 [/ \( M8
    ! B. O8 R3 \( L" [5 @8 ~9
    ! G% @, R! z/ t6 m5 @{'train': Dataset ImageFolder
    " g4 ?% ]7 ^7 k2 P4 E     Number of datapoints: 6552
    $ e5 M3 K& p. ^' P2 c     Root location: ./flower_data/train
    ! x; g0 o7 ?5 E     StandardTransform
    , o' V9 f9 _% \" Q. S Transform: Compose(
    6 S" y& n! b9 \                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
    $ d  |/ E- O! M# C8 j  n( D& p. X                CenterCrop(size=(224, 224))
    6 w7 B9 S3 W  H: ^& `/ {8 t                RandomHorizontalFlip(p=0.5)* B. f* |* p$ s& ~5 \) ]+ V
                    RandomVerticalFlip(p=0.5)7 Z" v* x" @& ]/ ~+ x8 `- q
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    ) b0 j) B+ k0 l( G                RandomGrayscale(p=0.025)5 ?  d. j. T1 Y, p4 n
                    ToTensor()/ q5 ~/ [+ W8 R: w
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    . Y% ^  K) C0 h/ E            ),: z- Q  |3 J! ]: e1 X3 r
    'valid': Dataset ImageFolder
      L" @8 _6 S0 Q1 {* d! D7 Y     Number of datapoints: 818% m& t; L+ M! j
         Root location: ./flower_data/valid0 M: Q7 o+ s4 R
         StandardTransform
      L9 {& O- G3 c1 M Transform: Compose(4 B: q  z6 |4 V- B4 o; u
                    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    ) E9 z/ V) Q& P8 g                CenterCrop(size=(224, 224))1 J, Z/ M) n) \+ `
                    ToTensor()! x4 S) S  ]% d! V. k- R
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])2 x. F7 k5 l* L3 s  I; y1 r
                )}
    ( G; s, u; ?( \3 C. j- Y% ^/ k  c1 U, z, W% }* `
    1$ {8 i' w( D/ D0 o; I* D) d0 t
    2* B  H; c0 @5 P$ B% H8 F* ?7 x2 D
    3
    9 d. Z8 K0 z7 w0 r" W4
    ( G/ b. `/ K& e5 U5# {: C) c/ Z! ?' c) ]3 U: R+ F
    6
    ! F0 \! S) V- T2 r, H  w! {! U7
    : R5 L' f: T( c8
    $ ^: c7 i( n4 G# s. y9
    ' @) }, S) o" t5 q5 q109 Z9 w% f6 X* H% G
    112 ?9 ]2 |% _! J8 w; Z
    12
    ; P& f" @% r0 K13
    1 d* n$ M0 c4 h9 g  X6 d144 f- X8 @7 u4 S! y, p& v2 `4 r
    159 B& o( S8 H* [2 p
    163 s2 C; C8 T8 A- ?4 }% E7 n
    17! u' p: r- Q, N5 V8 q
    18- ?+ ^. O0 ~/ m, m% z; a- v8 R1 j% J% ?
    19
    , a  K* O( g5 k1 E+ Q- {& B& N20( H" g) Q) y# C! F6 P
    21( B/ O+ a" J: \
    22- t+ V2 e2 A; Z: ~- M, ?* z
    23  }- {0 E. s) m& z; t9 M
    24
    $ l  K$ T2 t! Y) F# 验证一下数据是否已经被处理完毕
    . A$ m6 d8 b- Jdataloaders- i, C9 p& x6 _& n$ I$ E
    1; p7 |( X% Z( E
    2
    9 s% r$ O9 @& F7 m, v% G0 H{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,1 \4 i% U- Y; w) p8 F3 K1 s
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    0 o# L8 G+ k  b. ^/ m1$ T3 |) Q; C* Q& k- K
    2- I6 S8 I' ^) q! g3 Y& p7 b) a
    dataset_sizes
    5 {# d6 M6 i' G( Z1
    7 P$ b( |0 e$ U! I{'train': 6552, 'valid': 818}
    ; K$ K1 }/ X' A/ f/ k; L# X1  d/ u# U/ E- V
    读取标签对应的实际名字$ k' E, k) U) F
    使用同一目录下的json文件,反向映射出花对应的名字6 w! o4 J! A" U1 P4 ?: i8 l

    / W3 ?+ M/ [3 m# k" Twith open('./flower_data/cat_to_name.json', 'r') as f:9 d6 L% @% Z) @. I/ e' i+ i
        cat_to_name = json.load(f)
    & h! a% x, X3 K3 ?+ X0 J2 Z13 @& M' W( D- b" m% y( o' N
    27 E2 u7 o, e) G, D) \8 g
    cat_to_name3 W# A, a: T6 p* H% _, _+ y
    1
    " w9 v. }0 S* R, h{'21': 'fire lily'," U9 i9 n' n: J! o& F
    '3': 'canterbury bells',
    ' {/ T6 N, ^" m3 g" ? '45': 'bolero deep blue',
    2 P- R- ]$ l% H- F- n$ l '1': 'pink primrose',1 W1 v9 A' v% J7 i
    '34': 'mexican aster',
    ) @- |- K  p6 m& d: T5 Z '27': 'prince of wales feathers',/ t- N( m) i0 L* {9 g9 r0 w% w
    '7': 'moon orchid',
    3 E) k9 M/ V3 m6 x3 A' |8 q: N0 `8 p5 Y '16': 'globe-flower',$ o; v, }. ?7 u( U& C3 h3 U
    '25': 'grape hyacinth',
    % T# v# Z% ]7 r# E& w( D9 B '26': 'corn poppy',
    ' O4 C' L6 i; z" n* y '79': 'toad lily',  Q9 z' _& j$ X- W2 E
    '39': 'siam tulip',
    8 r' C( r; ~9 n# H  [0 K '24': 'red ginger',3 n' c" s( E3 y  C
    '67': 'spring crocus',* d. }; m  h" L9 a3 z* q* u
    '35': 'alpine sea holly',0 ?5 D% I$ m7 _% S8 x
    '32': 'garden phlox',$ G' C9 a( E' ~( r; N2 `# l* r
    '10': 'globe thistle',
    , I- J& {8 s) m '6': 'tiger lily',1 n9 b8 A0 L+ a) T: w' g' X
    '93': 'ball moss',. W' D# C5 S: w2 f8 T0 ?* ~+ Y$ Y
    '33': 'love in the mist'," F# l4 z3 r- x
    '9': 'monkshood',
    - ~. Q$ ]# @  i4 G' x1 I '102': 'blackberry lily',2 l. D& x2 h, C, L
    '14': 'spear thistle',& _+ O* j! d+ U( c
    '19': 'balloon flower',+ }7 v+ Z1 L  i1 H" n( ^1 G2 }
    '100': 'blanket flower',
    " K: ~3 B1 d; m% h+ J '13': 'king protea',2 n9 d4 i/ n6 H
    '49': 'oxeye daisy',6 |$ c- }5 q1 J/ P: I
    '15': 'yellow iris',/ J% P$ y7 k# Z9 B1 \/ w
    '61': 'cautleya spicata',
    " y! {: `9 A* t9 {& b% s '31': 'carnation',% v: G9 I" j% z
    '64': 'silverbush',
      `: t2 x- v9 X- Q8 H$ D5 W '68': 'bearded iris',7 \8 {5 W( g) O9 Y
    '63': 'black-eyed susan',1 y& P2 z, i# i2 i
    '69': 'windflower',7 K: l6 ~. W" T: a. i. s3 `
    '62': 'japanese anemone',0 l- F0 }/ M7 Z3 j" N
    '20': 'giant white arum lily',6 O. p) o4 `9 ~. e
    '38': 'great masterwort',
    " g$ z4 e/ s6 u% Q# \ '4': 'sweet pea',6 ]8 d( i2 a" g# S9 ~0 m7 n, x
    '86': 'tree mallow',
    , `3 H+ I! n( v" B) M2 x! C$ l: S: L '101': 'trumpet creeper',
    . X+ l( O2 a: z0 a' H/ i" Z/ m '42': 'daffodil',
    , g5 @: g5 t1 v2 z( } '22': 'pincushion flower',
    ; Z5 L: U7 i; O9 y* {' J '2': 'hard-leaved pocket orchid',+ g( \6 D1 z5 y5 B9 }
    '54': 'sunflower',
    1 I8 B0 }/ T% [& ` '66': 'osteospermum',2 r+ ]3 @8 e, p. \" b% ^
    '70': 'tree poppy',* C  @$ X( C+ L5 s. X
    '85': 'desert-rose',$ A5 g6 `6 j9 U5 R  D8 |
    '99': 'bromelia',
    # E1 i7 }' B# t3 L '87': 'magnolia',
    / l3 l: |) e+ j+ z '5': 'english marigold',
    7 G4 n/ v* v& g6 I2 L5 U '92': 'bee balm',6 g% i. T% ?5 j1 `" k
    '28': 'stemless gentian',
    , t) r  J; n/ o/ _ '97': 'mallow',
    ) I: O+ A% w& S& P '57': 'gaura',
    * W% ]& g3 K2 V9 W '40': 'lenten rose',
    ! I5 Y9 m2 [' l* U) X$ d '47': 'marigold',
    ) Z/ u% X9 p' k) d  I '59': 'orange dahlia',
    7 j( j' j2 g) p- G1 _ '48': 'buttercup',
    2 K" d" k* w" o/ o+ G5 t  P6 g/ Q '55': 'pelargonium',
    0 F! C4 [$ g8 Q9 @ '36': 'ruby-lipped cattleya',  c- q. D9 t8 `5 d+ P' e9 f) Q6 t
    '91': 'hippeastrum',1 P1 z7 E* N' p
    '29': 'artichoke',
    ' K+ v& Y0 R, `! M  C$ P6 ~ '71': 'gazania',
    3 B9 C( F/ E  K; O; k5 V '90': 'canna lily',
    ( [. h* f' v- t0 B; ] '18': 'peruvian lily',
    2 t/ R: o1 B5 g! Z '98': 'mexican petunia',
    : T# X! y( q! G: }5 [5 S9 L0 { '8': 'bird of paradise',- ?# C: w# ~7 q5 n) c' ^& H
    '30': 'sweet william',
    , _" J  Y) Y4 e4 u- T* p '17': 'purple coneflower',
    ( a. z6 t" J: f+ i+ S '52': 'wild pansy',
    3 g3 ?) e8 ^4 x '84': 'columbine',
    7 F: F" F4 q! u: s '12': "colt's foot",
    ( P6 Z6 ]' @6 m$ \ '11': 'snapdragon',/ n9 k- G& a# k
    '96': 'camellia',1 U) P. m4 _* Y1 u
    '23': 'fritillary',
    1 A% ?* J* i; L  E) ` '50': 'common dandelion',; D! Z% t* E4 w. g9 P& T( r2 Y
    '44': 'poinsettia',
    ! g$ d% U3 j0 [- t+ u" l) M4 E '53': 'primula',
    . m( i6 E/ z0 K7 ]+ { '72': 'azalea',
    $ A" O/ l$ c' l: O, M '65': 'californian poppy',; d1 }. ]7 y& \
    '80': 'anthurium',3 q- Y' x0 y9 d( ?7 r( e* t2 h7 `
    '76': 'morning glory',) ]7 T! H! B) I# Y* z7 c" A
    '37': 'cape flower',
    $ h$ X; L  w  U) C '56': 'bishop of llandaff',( ~' B9 {. Z: K* z" I% L
    '60': 'pink-yellow dahlia',
    7 k0 L1 {( ]; E# U '82': 'clematis',* O! c: _$ y! |. f
    '58': 'geranium',
    ( g1 J7 f2 _! g0 c5 ]( c" N '75': 'thorn apple',
    ) u6 k. o! Y1 N/ P6 o' Y2 l* G4 J '41': 'barbeton daisy',
    7 g5 j5 J: l# C9 U' A* e '95': 'bougainvillea',) L6 x# G, i' R, G7 r5 ~* w- `
    '43': 'sword lily',: Y, r  o- I/ A- c
    '83': 'hibiscus',* B* L: Q3 W( [) ?% N
    '78': 'lotus lotus',: P+ q0 W0 z1 J
    '88': 'cyclamen',8 |6 O/ u) ^) B% n1 i+ q
    '94': 'foxglove',
    * v% G0 {9 `& ] '81': 'frangipani',
    " f; c2 {! M- r. S8 R5 Y, }2 z) p '74': 'rose',
    " \& {+ C9 u3 L+ K8 i6 `: d '89': 'watercress',
    - c( v2 Q8 Y- X* p% k '73': 'water lily',8 B( L6 G2 X: X, F9 V
    '46': 'wallflower',' ~1 M8 Y3 B, _1 k# C. q5 o8 ?
    '77': 'passion flower',8 O: w  G) w+ a3 L8 ~- Z
    '51': 'petunia'}; [8 d; S3 a; _
    6 s$ `/ ^( F5 A# c- i# f
    1
    ! M( H/ w( |$ B+ b, X( F2
    ; o+ y/ o# x: R1 `! T' U7 a% y3
    9 |+ {* X+ y; f- B4
    8 X; ~- z& `! d0 I) V+ s5  |+ I. ?$ y# p
    6
    4 A& @$ ?; B$ ?# M7 d3 e$ W2 E7: c# o+ `; ~7 I- |% b- O" a
    8
    * Z3 P- @! z8 J% H98 u- i# V* ~5 V" |
    10
    ' K0 A1 s4 F$ ]* k$ B. o. N! \/ I% L  R11) _( u( c% l/ ~  p9 M
    12
    5 o- \, V1 K% L& V, d7 I5 ~' R* w6 w5 Y+ i133 ?" R$ r3 h6 O' |0 Z. K& h
    14
    + g$ Y! V) `) I% o( B+ @$ R9 ~15
    ( w9 s5 t4 ^1 k; L4 _5 C! _16
    1 V( x. Y  T+ X2 Z. k: S% |17$ B9 |$ L  c4 w
    18
    8 O: P2 i# E8 z7 b2 C( X19- u# b6 p4 N) }! y
    20. R* c2 n8 n2 u$ r: _$ u+ [, y% R
    21$ ?( y- o+ f& n( i
    22
    3 W/ Y. u3 r7 }238 z4 ^2 I" }& f& P5 h6 S
    248 D0 Z4 ]) y5 a8 }, k3 V& t
    25
    / r1 \' t0 ]5 C# t# B$ d26# [9 e3 I7 e6 S6 i6 J! _! m
    27
    - @! r% N) `& ^28
    ' K1 x' U$ f- S4 P- w29
    " R- r& X! [+ I5 a305 ?. G( S* A8 W$ B5 s
    31" p; M! H$ v1 m% V9 [5 _( k
    32+ @0 ?7 L! G0 `! E
    33
    0 ~+ y4 w4 Q! T( U& U' H4 h* P: A34) N  O" ~$ p) g/ \
    35! R4 o3 g/ x! [# H; P. T) {
    36$ r) f$ ]" }0 o8 T
    37! p/ c/ \* n& N: }) ]5 `4 y
    38; w- |% a, ^+ K( T  f0 ?- F2 z/ _
    39! k1 T: }) H6 n6 D0 Z
    400 U& T& n7 D$ F" e$ g
    41' G3 [. a' E3 N; i3 }# p! J
    42
    & U! H0 H  Y7 u7 \3 p- ^43
    ) g3 i7 D: `/ s/ w44
    + o6 g- P  ^% Y0 }. `45  Y1 T! k# d2 j9 w5 M) `: V/ K* a  r
    46
    " f* L5 B8 }1 z; c- j5 D9 n47
    * n7 b5 A, M, W: G; l) C& U+ d; Q4 J4 A/ e48
    7 d7 _8 C% L* M; g- C' k, F49
    , d' x- P- F& I. p2 q% b509 o9 g6 \/ J, J7 V" I
    51( e% g! G: ^, C# l1 B' M
    52
    : p) }4 w' k- K% o( \9 y53
    : ?, B) O! b. ]; t54. `4 B0 K2 @% ]
    55
    3 e0 w7 F5 g1 w8 v5 o5 }& h/ e0 y$ r56
    1 B5 _! z; t% m57/ R( V: t: j. R, x" S1 A5 U; [
    58
    2 r9 ?4 Y& Y7 E+ V! E59
    5 l. w5 ~$ U1 Y! }60
    3 @7 y$ ]& ?# E( W4 G61
    0 N2 g* m; o7 G7 N62* t* v$ J( K7 M6 e, Z; F3 I- O
    63
    6 w3 B8 ~1 i1 I9 M2 ~" F2 \+ Y  ^64
    6 e; Z% k! t. L- X! ?! j/ P65
    ! w4 I0 b4 O: n; s! V' z66/ @/ G# D9 y! G# o; H# i
    672 u2 Y" x# a  [6 O* [/ }9 J1 a
    68' p; _2 s1 c' c! O! D5 w  U
    69. [7 ?( k! g5 P
    70
    9 p: E% B9 K+ v2 y, @71
    / \3 b" g( V2 H# ?5 P+ Z5 X72
    9 t* n. _1 V5 S- l* `- }/ y( p73
    9 d9 G2 {/ g5 {743 e  i& h. M+ _  b7 H
    75; b. u1 E  g+ Z2 g- c3 B0 F) b
    76
    0 i7 c" v7 X! }77
    9 E9 s5 k% P$ ?. g3 X+ D8 O78
    * ]4 ?( [! q( b* x! \7 s% y: J6 z79: X/ d0 j7 T9 I$ y9 {
    80
    : B7 e0 P, h* q& q81
    4 c) ~8 {  S7 ^  Y% c82  w# d( A- h, i0 o
    833 |) f2 T5 B. n; b
    84  g) ]' G( p! G' n
    854 f1 ~& I& g. L+ S1 k% N
    86/ ]$ J, ^5 P* ^1 j
    87# R% e3 o2 M& c8 Z9 _2 f
    886 q+ @' w# w, }0 `, l
    89) M3 M+ a) {% }! @- H6 {0 m7 q: f
    907 c7 x; c* n1 J  p
    91
    ' M3 H9 [2 n/ H" P6 m( ?$ [92
    ( h  l- E& V/ `93# d, M7 }* u; E! h
    94
    , X( O" \1 o" m0 s+ D+ L6 g95$ w  w3 x& y) W3 m
    96
    " r  C$ {2 z# i( `7 X. r! H- k975 ?! h" N8 W3 Q; d) e8 b6 R( }
    98
    - ~5 j! p& {1 g' m, ]99
    $ ^! S, ]5 l$ n) t' F100
    , `8 I* J# [, }. ~101
    $ v; q) J* K2 E: e; [9 F6 U102
    ( p( f5 {5 U' i6 \4 j! {4.展示一下数据
    / n) B2 W6 @2 l# Ldef im_convert(tensor):) {3 h- _+ I! t6 ^
        """数据展示"""0 H( ^# `( j% [9 N- m
        image = tensor.to("cpu").clone().detach()
    ) D+ Q+ n9 D5 p' v6 M6 n: d% }    image = image.numpy().squeeze()
    8 |1 b; N' k8 w( K8 W    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图" G/ Z, ^; e! l! r
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)' u6 p5 D) y3 Y4 X3 L' R" m& t& r$ X
        image = image.transpose(1, 2, 0)( w* V+ Q% O* ~5 Q
        # 反正则化(反标准化), F9 S8 p% F! z2 R: `
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    ) h* J" I3 x9 B/ @. |
    ' B* Z# \7 w9 y4 `* g% ^; h5 A/ z    # 将图像中小于0 的都换成0,大于的都变成1
    + j, F: w& r4 \1 I. P- H5 H    image = image.clip(0, 1)
    2 u9 [1 I; d0 ?! Y5 E4 E" x2 ^8 ~* d+ z3 s/ y1 c! j
        return image7 t; C( [/ i" M
    1- |, h7 J. H. ~# S( S2 @# C
    2
    ( t' ^( d# Y7 Y2 i3
    . y$ o$ [4 c1 |+ ?/ O4( L  @  n/ G9 X/ X) J, ^! E
    5/ b* N$ I; |" \8 y. M
    6. }" t: {$ |5 F9 N
    7
    4 ^2 R; ]  I, C% @" `" I- J6 N& `81 A( L: D. q9 m9 u5 W
    9. m( U" n1 g$ D7 Y, t: r
    10, Q0 [' }; L3 Z3 p' c* [
    112 e1 x% c+ Y! O# _6 A) ~" H3 O* W
    129 v- u& ?/ Q6 b; Z) m/ n
    13  J8 n& Q+ H8 [
    143 a/ K  Z: f3 X2 u8 u5 ?
    # 使用上面定义好的类进行画图
    ) H2 M! G2 p; `6 Xfig = plt.figure(figsize = (20, 12))5 P' J% Q+ v* G4 b
    columns = 4$ p1 R9 u5 Q. i" P
    rows = 23 b$ h4 H. y* X% D& t

    9 r5 h# r9 S1 w! o, R8 U1 @, Q0 T- @; i# iter迭代器
    : ]1 L6 P7 b) a5 V( H# 随便找一个Batch数据进行展示! V- n; i! l6 J4 L: R4 a5 F
    dataiter = iter(dataloaders['valid'])
    * ^  y3 @; _) g7 q/ }3 Finputs, classes = dataiter.next()
    / \4 g* ^( V" d2 @9 a, n" v
    3 `( S, e* ^6 @8 N" G. Vfor idx in range(columns * rows):
    0 z8 J' g2 y  Z5 V" \9 u- j    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    + n. R% K3 d" u8 I    # 利用json文件将其对应花的类型打印在图片中
    * [! h6 K+ [9 E5 H+ q    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    * o, _- z& ]0 V" K; Q, z) u    plt.imshow(im_convert(inputs[idx]))
    6 D9 s7 H6 I3 ?! x" tplt.show()/ m- L9 n# d( ^& }. u, ~

    6 \8 z# J3 ?$ g' @6 p8 J8 j1
    / V6 O3 X# _* o/ @  o& x29 ?/ C3 s. P6 _+ w* j) S" r) {7 R( F
    3
    % ^8 v# e/ `) z4 d7 A4
    ) Y' J' l* ^5 U52 Z1 z; u3 T2 M1 J. a: M. x. S
    6
    : A  ]: v2 B9 A& R7# _0 J9 ]( W9 D  v) N
    8
    7 }% y; ]! w  V: a$ d+ V* u9) B: k: e6 x8 t$ A, Q
    10  C5 \3 F% j, c3 ^0 P# ~
    11. F, l" e% j* ]. b7 _- k( I
    12; A  M; u( N5 k8 o& u* c; O% ~
    13
    ! ^/ _+ i" i/ F7 Y# W% w14
    0 K5 \0 Z! G1 w* @' E153 b  ~# f1 E: K( @5 J! k/ h7 a1 P: k
    161 `* `( P# X5 B1 v

    # a: x% ]# d5 u/ l7 B) B% G0 _8 T' K5 e. \/ b
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数6 h+ l/ D2 y* N% r( @6 g- X
    model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    ( W' I: L1 p9 U$ H# 主要的图像识别用resnet来做
    , F% @; C& G, f' I8 d0 ^1 D; m# 是否用人家训练好的特征
    3 l) Y- \$ w( _& Y/ Q. `feature_extract = True6 T! N6 r# U+ H9 B6 J
    1
    8 K7 S' x$ S: @3 G; D29 Z2 L4 b, i6 @
    3
    . T) B8 c; k  v: t+ U4
    - R, z% Q2 `7 R$ c* b# 是否用GPU进行训练
    3 A+ x+ l0 _2 f6 J, utrain_on_gpu = torch.cuda.is_available()
    7 q' ]" r, |5 ]; y; g4 _# ?8 l( ^6 r$ _- c5 W
    if not train_on_gpu:- N$ m: B$ S: g) Q) Z3 b, D) x% T
        print('CUDA is not available.   Training on CPU ...')% v: b* R$ e2 Y6 u( `; |
    else:( T+ _6 A0 }; l  e/ F6 ~* j
        print('CUDA is available! Training on GPU ...')
    ! T4 C9 Z4 t' T
    # H- g/ T# Q: _! rdevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    0 b& ~; I$ {$ C! x& M/ `1
    " h2 v" `& N0 u9 _2 I2
    2 K8 \' q; o1 j; ^4 ~7 {3+ s; [$ q$ o* `2 [) @8 L2 I8 P1 s
    4
    ! D6 E8 }6 ^0 D0 x5
    / d3 h+ C, N% v; Z% N- L6
    - S! {& n; t1 G9 Z3 ^, _7
    # W0 g2 S2 A: p% w& F8' q: S! N% X3 |1 I: B, \- Z
    9* i) m- ~( y' M) W, ]
    CUDA is not available.   Training on CPU ...
    8 {9 t& C4 U% o2 r+ g2 X8 p1
    ' _- Z0 O. y* h  q" F  v+ ^# 将一些层定义为false,使其不自动更新
    - I4 ^& _9 R& `2 d3 V5 D) `7 Ndef set_parameter_requires_grad(model, feature_extracting):
    " p/ c2 A# V+ {( B* e    if feature_extracting:
    " V0 U! \6 h0 o% t: S, d, j        for param in model.parameters():/ n; ]/ y7 s5 F6 G% |" l( o
                param.requires_grad = False* w3 x1 P( d4 `& N6 z% q. P
    1
    1 k4 n) }& Z& F$ d) ]( V" X& m# v2
    * l5 q5 u: }) _# Q% d! G& Q3
    ( o$ O9 N0 C* ?$ U' }  L& T. }) R5 i4
    ( [5 R, o. o0 c2 r5 a5
    6 q1 c7 h& I9 R# X4 w" d8 A$ `% e# 打印模型架构告知是怎么一步一步去完成的
    5 S( _* w' W, Q; e- D# u# 主要是为我们提取特征的* ^5 R$ M' |& `1 J4 [. {0 z
    5 W# |; ^" R  P6 S: H& |  [
    model_ft = models.resnet152()/ O4 t( x& I. m
    model_ft
    , r" ^/ b' P1 P2 x2 ?& R& k19 @! l) ^9 \- Z) M6 p$ j. {
    2. \' X% r: d5 i( n, Y! d
    3
    * r  j4 c: P* {  c, W* m1 w4) ~( {+ j# f# n( N) O1 i
    5/ x$ t8 y+ r4 K/ o6 n
    ResNet(
    ; ?8 i# e# G0 B* W3 X+ W% }: v( |  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    % ]1 F& K$ q) H  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    : a2 _, \; t3 J* |+ Z% P9 K$ o  (relu): ReLU(inplace=True)& y5 {: ^. p) Y
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    2 D( G# e6 g) p* t6 z/ a( W  (layer1): Sequential(7 T9 Q1 W, Y& N: `
        (0): Bottleneck(. [) \- j6 W8 ^; {, }' g
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)( c: C2 W! `2 M- ^2 {
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    " i3 c  i) y- Q      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), z% F8 \2 M1 E/ \3 E4 _
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)3 s/ a# `$ G) m+ Z+ C- z
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)% K0 B$ L& d5 C3 n  y
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)* Q. P( E  E7 M
          (relu): ReLU(inplace=True)
    9 W) O; N% e; M4 }      (downsample): Sequential(4 G5 Y  i# v( R  n) u7 {
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)* |% r: Q" Z# t, C- ]+ G3 e1 i) U
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    5 d; o1 B4 M& T7 g3 C3 K8 s      )- P$ f! C% `. Y/ I. ^
        )
    0 \' L, `" f$ `- ?3 S1 c中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    & m- R# P: z2 v6 w1 t; Q9 G    (2): Bottleneck(
    . i  |* N+ u/ X% k6 i1 o      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False). W$ x) x/ g* X4 D1 w
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    4 O: S" S; d: T* w/ ]6 ~      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    $ R' `% }& y2 C/ k      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Y: Q9 C" Y* w2 I/ x  B
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)- s1 f" ?* R6 d5 a1 a9 g
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)( |8 g9 \( O' g( [1 F$ `* {# B
          (relu): ReLU(inplace=True)
    6 A9 g( P7 Z. H6 z% H' D    )
    - B, R1 @  D6 k& _  )& Y1 M* b+ ?5 d: x
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))* _4 r1 \* H1 {& a
      (fc): Linear(in_features=2048, out_features=1000, bias=True)& T: Z& X. D: j
    )
    % g, x  S+ h+ {( |' V: g1 ?. h: p+ B" V+ |9 E$ m
    1; q  s; F4 ?/ Y7 P
    2+ z" d+ ^8 S/ A$ v' n' m
    3
    2 }" d6 a2 p9 X; ]' O45 N! F% N9 s4 J! t2 I: I" B- n; e" o
    5" |+ X/ E! G/ a% m" i8 B% W
    6' B& e# z" R% Q  Z
    75 Z- T, r* p4 p6 a9 t2 ]3 P
    8: W7 l# c8 O2 q+ F' }9 B) B
    9- D7 l, L2 d& H( b
    108 A! l4 g, N& G
    11
    6 L, i" H3 [& D5 R12
    5 j: P, P) @3 m5 z( r- O13! u0 a; O* {/ c: J" J
    14
    * @! {7 U: z# l  }6 ?0 p156 p& Z7 |2 v; _6 M! M
    16  K* ~; ]1 J0 Y
    17
    / q- B' {. x& R7 D$ d8 \0 S18
    1 M) C% o( e/ {  X' n192 x. U3 i& G% j3 I4 q
    201 j  [: F1 Y) `# @
    21
    % l1 `9 C, A7 e, I, d4 t/ d22
    : y3 \# t2 \1 j% b235 T1 x6 D. d) w, M. \4 }" C
    24( D! g) n: i! h5 X% {
    25. }' J, G- f, i' i
    26$ ]( i% {$ R+ u- j+ ^# o
    27
    3 W# X+ }; R( u( i8 U" g1 m. @28
    5 w4 ~3 O3 a- n29, b& y. J4 r1 t8 e$ |' N
    30
    $ u8 z) d5 l( u% D2 Q5 D+ b7 y31/ _* O( p1 e% ^' M) I: X- t
    32
      m% z. O0 u5 f7 N* w) }, F338 }, W* C$ V7 V" D
    最后是1000分类,2048输入,分为1000个分类
    ( j, X* ~3 \, Y/ J9 ~而我们需要将我们的任务进行调整,将1000分类改为102输出* v* O% p1 ~6 ?5 O

    , R# ?0 j% ?# l; A0 d" h6.初始化模型架构( l& R( ^% z8 G5 J- a; m# u
    步骤如下:0 I1 q4 ^6 A+ H% f9 D; h

    % I' A- C+ A. V* L& l9 R1 ~" c4 V' x将训练好的模型拿过来,并pre_train = True 得到他人的权重参数7 E9 z7 D- `0 F) b. y6 d& g
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
    3 E) ]( j7 Q% g5 v3 v' X无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数1 a8 u1 a; D7 U) S% e
    官方文档链接# A; W3 A8 D7 o, W. `
    https://pytorch.org/vision/stable/models.html4 Y) G5 a: u* z( ?5 @

    - g! B1 J; G$ q( Y5 d. o' E# 将他人的模型加载进来
    & S8 S' k9 a( x9 m3 v3 wdef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):; i- r1 a- K9 W3 y7 |  G: m
        # 选择适合的模型,不同的模型初始化参数不同7 J# i" M9 p' a4 C
        model_ft = None  E1 a+ u: u$ S
        input_size = 0
    " ^2 @1 [! Q+ W) f" \$ B6 O3 D/ R' X: S% q/ j2 u1 m
        if model_name == "resnet":
    3 K! @- y6 ~3 ?9 K% C* t        """
    # f' L% q5 A3 y3 c* O        Resnet152  ]% _" e7 g# N" n
            """
    . O/ I  O4 U. D. a! ]' |7 x1 Y8 F- ~0 _
            # 1. 加载与训练网络# \  y, K  E3 ], x6 W$ u
            model_ft = models.resnet152(pretrained = use_pretrained), `6 m! p+ Y2 }. l" k+ V
            # 2. 是否将提取特征的模块冻住,只训练FC层
    1 u5 U+ Z+ @6 }; C4 G        set_parameter_requires_grad(model_ft, feature_extract)
    , P2 ~! G2 B4 ]/ |( ]  j' l) H        # 3. 获得全连接层输入特征
    . U- Y3 X. y$ R% G$ n8 {/ h; e  F+ q        num_frts = model_ft.fc.in_features" i* |4 r2 `6 `  C
            # 4. 重新加载全连接层,设置输出102
    * K: Q: g  O& _$ j% z! {        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),' Y2 E) Y! |. |2 j4 f, K
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    + b, u7 B( {- s4 `/ ?; U, M        input_size = 224
    ! ~; U) r2 a7 ^& w+ u6 z" z4 D0 F4 p! v6 N) M: q
        elif model_name == "alexnet":8 o& K8 U# u* r7 r
            """; ]/ G/ u1 p% x% z  O* B
            Alexnet
    + l8 {+ y, f4 @! _- F+ c        """3 W7 d2 [% b4 K
            model_ft = models.alexnet(pretrained = use_pretrained); o. _0 a* Y( Y0 W: n
            set_parameter_requires_grad(model_ft, feature_extract)3 i) u% d7 B( w9 n- `
    / ]9 h: h4 x& o4 l: K1 y
            # 将最后一个特征输出替换 序号为【6】的分类器. f5 ]1 `( V/ Z: u7 |+ f0 d0 t
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入
    / G0 S4 ^3 y* _- C  H  H( b        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    2 u; d* M+ L* p, |) b3 Y1 ]        input_size = 224
    0 G9 ?7 E$ A2 A; A- E
    8 ^) D. E: {# E2 v+ b. r" V$ z, Q    elif model_name == "vgg":
    ; t" S) M7 v* o( g; n0 i9 {        """; E, f4 L6 }' e# h
            VGG11_bn* b- B4 S/ D5 h' R- [8 I4 ]" I* p$ q
            """
    4 r& n' y4 Z/ U: n9 e. D8 r7 v& K        model_ft = models.vgg16(pretrained = use_pretrained)2 Q& }( Q9 n. _5 O
            set_parameter_requires_grad(model_ft, feature_extract)+ L6 q# E/ \" z
            num_frts = model_ft.classifier[6].in_features( n. F+ x) E( S( T
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)" A* \8 g: I5 T  f  \6 e/ \" n
            input_size = 224
      ]: O( n+ A$ z1 D" ^: a$ b# @$ S3 {! P9 |
        elif model_name == "squeezenet":+ w* z/ `8 C8 m) f
            """
    + k! q$ ~! Q8 P8 p  ?( z# W        Squeezenet: t; x/ N/ O; |) e; y* ]* M8 j! @
            """( n! `( p, W7 p* z7 g' f9 k) K6 m7 J
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)2 `% C+ x' M- B  p4 P: c
            set_parameter_requires_grad(model_ft, feature_extract)5 W' ]5 R$ w$ H: `7 v/ p
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
    ' {4 r/ q& {" }" R' H3 [) z        model_ft.num_classes = num_classes
    9 a) X+ J; W% Q        input_size = 224( @3 `5 e2 k+ v, S" h7 ~( c3 q
    0 ^5 E$ @2 o  W" K1 e* ]" h
        elif model_name == "densenet":- z- ?2 G+ O1 C' H  M9 U
            """9 ~9 r! t1 Q! P
            Densenet7 m& v* |5 P, k, e  t+ H6 ]
            """1 q* P) [& D& X  n8 L& Y5 F
            model_ft = models.desenet121(pretrained = use_pretrained)
    0 k8 S* x- I& M2 q) I# j! F        set_parameter_requires_grad(model_ft, feature_extract)0 h. J- @* I& {$ b: G# V
            num_frts = model_ft.classifier.in_features6 R8 ~! g* c1 R7 W
            model_ft.classifier = nn.Linear(num_frts, num_classes)
    + G" g  q# W" ^. i7 ^        input_size = 224# \: }# J0 \& I1 P% t# |; r  e

    2 ?+ G2 M% b, _- F    elif model_name == "inception":& `, a& ^/ l2 }) \
            """! k1 X# `" t  }' n+ W
            Inception V3
    + a  Q3 U  B- y        """# h4 K. z0 e0 L8 F" p! O
            model_ft = models.inception_V(pretrained = use_pretrained)
    " G3 U, N, i; w) S* Y9 s; V; q        set_parameter_requires_grad(model_ft, feature_extract)
    - F$ e0 i4 r8 b) |7 Z' g. m( W5 c/ H7 U
            num_frts = model_ft.AuxLogits.fc.in_features9 A# i; o; p; w5 B6 |* r
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
    3 s* g: U+ G- e/ {4 }9 e$ I
    , w1 n; ^$ L& x2 p        num_frts = model_ft.fc.in_features/ g4 t7 [# j$ q  ?; ?! ^
            model_ft.fc = nn.Linear(num_frts, num_classes)
    5 H) w( T% v4 C        input_size = 299* ]! s0 }  t6 p" w" d
    , v- I' ~/ T! W2 T# F1 n" s5 V
        else:8 A& O- x' J0 D2 W1 @: |2 H2 v
            print("Invalid model name, exiting...")
    0 q/ K. k3 E* u0 i; f- n        exit()0 E9 T" V) e7 I2 N3 V

    . _8 C- j) o. L    return model_ft, input_size" o& u( r6 B. U; T' ]$ \  Z4 \, q' E

    4 x; t8 f0 [! j" l1
    , g- k" d% G) t2$ I, {  v- I) g: g" @- [8 N9 K( X
    3
    1 b8 D+ f; Z, }' W4
    3 r3 H  Y4 C, ^8 u, o& f5
    + Y, H. ?; o' }. c1 E64 N/ D/ f, @- m. [$ Y
    7
      S! l) o: h8 ^4 \& i% f87 \+ x: X; c1 O+ a
    93 ^/ ]* T& M' p8 j  @- L9 k9 v
    10  C+ z- q8 ]) J% V4 L, S
    117 b0 c/ W% t  ~% p7 W, K$ d1 r3 n
    124 a9 l* T& i6 e0 Y; _4 H
    13
    / \$ f  m; |; E# j9 D7 C14
    ! p# E4 y# A/ n% O/ Y. S151 S' M% |0 a2 F' c; Z* ?; N% s) j
    16
    + _) D5 Y* y* }! }17
    + b- H# d" ~/ b: ?3 N18
    ( L: g+ ?3 D9 ~/ r0 L$ `19
    . {5 M) |( \! \4 [! R# ^1 N20) j) X/ g4 z( O/ ]3 e
    214 v1 A& z/ B' t
    22
    - @: P0 p6 p; _, W, ~3 s239 j: ]: [& A, S9 v- F: x
    24" ?! }" `- b, q1 [/ v4 d2 U
    25: j" B8 ^0 h4 J( H0 Z1 X
    26- j" S9 t. t. W
    27! O, B2 v; \# i2 e4 `4 O' [
    28
    6 \7 f7 f9 _3 r* B29. I6 j# F4 M: z# u
    30
    - q# l, [% V: C, A; T2 [317 T0 z# R" G0 S# `( C! T5 s
    325 P' J' [7 p9 o+ `/ A
    33
    - o# s! h$ P8 W# n4 v346 I7 u1 M# j+ e( M
    35
    3 j& y, z9 w, U/ o! k$ w  _7 q* X36
    ' k/ K8 [, i) A; i! w- Y" _375 B# }- s5 l: p8 d7 w) h
    38
    2 W$ ]( R# b9 O  G2 J& Z8 A" B394 O+ h& ?5 N' r8 x
    40
    . N" l8 Y6 n# O3 l5 Z415 Z0 v% x; i+ o7 _; L# [; u+ d3 y; H/ f2 ^
    42; t: ?- x7 S1 o& p. }  t
    43: x9 ~7 r- M* E7 V* `
    44
    " _- y! b# X4 V45
    1 b" H+ K. i1 f! r46
    " w  x6 v) D& S: f, F; @/ h& h* e47
    . M# j, Q  [% w" ~485 D* ]. l( C+ i; k) i6 X
    49
    ; Q# X3 T6 D6 Z50
    * H; \5 C5 `$ [5 g51- q  M9 y! V5 `/ _' {. d
    52
    + B, z) d; @7 X4 T  W5 M% X53! [( F/ j1 O, R
    54. s0 G7 D2 X) P
    55
    ; L8 \+ Q0 L8 @1 A$ X" q0 W/ y! j56
    + V6 w. d+ a. x0 c: K57
    5 B  y- ^+ `% W1 Z- g58) w8 v2 I' f9 E. C
    59
    ! s6 p+ |. l1 U& I; n" t602 V( Z; R& n5 _( A9 Z
    61. T! B. t" O/ ?( L, S( [  J+ _
    62
    ) c' Z7 H. A% Z. o7 i63
    , _; f5 r' F4 \' l64
    # M+ i, Z; u: J; t65
    ( C" e9 J! r% o' P. h, ~66% t, q7 ~1 N4 C* M
    672 ]" h7 q! W7 O  j; m" @" S
    689 J* u) x3 w# ~8 Y& Y
    69
    / }2 U  H# ?% K! m701 c3 z3 `  z; ~3 S
    71
    7 |' D, W3 Y8 s4 I) W6 v728 z  i& D8 I0 I- T+ V5 I: k4 Z
    73$ W) z) b8 ^+ F) @! s: x* y& `. \
    74
    : c5 \5 P+ \9 N% \0 e9 T8 D: x756 o( @" A+ H5 {1 U& P
    765 y: |2 v% |/ x6 R8 G% D( I6 A0 ?- C
    777 b0 N' a7 {$ z9 F6 r: d- C5 T
    78
    5 D1 F0 f2 Z% X# o& ^79
    + C. o7 M' i% z1 r" m3 e; v80+ d5 [! g4 P! Q8 X- L
    816 [% P0 |" U6 q5 W  @( s  g
    82
    ' T/ X* O- y# }# I2 H5 N  E83& F# W' M/ V+ N8 @5 l8 [; `: G
    7. 设置需要训练的参数
    6 |# p: d$ z8 ~# 设置模型名字、输出分类数+ @$ \) ~# k6 Q1 g8 O* R! q
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)) h/ V5 G8 d, _/ J. w
    % Y* N% b! s# y
    # GPU 计算
    : V# @5 H2 M, k, K# wmodel_ft = model_ft.to(device)
    & R: q9 V0 O" f; m/ H
    & i% h* r8 |2 a# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取' t/ r  ^2 O5 D" p2 }
    filename = 'checkpoint.pth'
    8 q6 h3 v9 K5 R" F' V/ D
    1 f$ ]/ i6 m2 X. ^* }# 是否训练所有层) x0 U+ B7 Q0 Z
    params_to_update = model_ft.parameters()( S8 r8 r( a2 |, h) V! ?
    # 打印出需要训练的层
    * b  N. D+ k; C0 @! w6 M# P: {4 n2 T7 cprint("Params to learn:"); j+ i4 n  f+ K  M& V* _: i+ V
    if feature_extract:2 H# E5 W8 i* ^8 Z7 @
        params_to_update = []
    ( `) t8 B5 O% ~" N' q8 q1 g    for name, param in model_ft.named_parameters():
    9 D1 j5 L  Y! B+ |        if param.requires_grad == True:
    * y1 L# l. y8 A1 X. q3 H            params_to_update.append(param)
    $ V) X9 i, b: {$ H0 `" m8 {) a            print("\t", name)
    - [0 h, T$ R9 C3 w  Ielse:
    , z1 V; }9 m) ?( I# \    for name, param in model_ft.named_parameters():
    1 d# m" ?4 v8 e' s        if param.requires_grad ==True:) W2 S( q8 ^$ A0 s/ v8 V, V
                print("\t", name)
      K9 z5 l$ n# b$ b: D
    / X, k) M) m' c; C1( {$ R( ?$ ^$ z( k* u7 p
    2
    6 e. \+ y$ G# Q  \3) K# d# X! r: W6 ?: p0 P
    4& O) d6 [* K' }; D! s1 s; X" g
    5" P' Q; @7 M& G  V7 [
    6: A5 M3 |8 L+ {" {3 m8 q, p, u
    7
    1 W0 h1 |3 s7 K+ ^: P8
    : {# J: _$ U: [/ H" m9
    ( q% s; p- ?0 r* L. S* A102 `+ ^" }, |  q% i! x. x
    11
    : o" S7 M4 \2 w( F1 E: f12
    5 r: L% m  `. I8 k/ b+ l; ~' J13: b% `/ [' f  [* ~( Z
    14
    ! c6 G( `9 W* T$ l' r$ q6 I! [) m15( H$ Q( p$ b  M' J; y
    160 E, J7 |3 k' `0 M" I
    17
    7 l; j- T/ Z- Z18
    + r3 c/ o! l: p+ b3 G: u" ^19
    4 t; S+ V# c0 y+ W" D3 L- e& {: y  ^2 w20
      w$ f" F+ f1 ?7 u' K1 Z21
      d$ H" k% U  ?. C' k' \! J" I22* h0 ]# i) k, ^2 i& c5 o" W$ ]% |
    238 Z! E, U1 Z" a# N5 J  }
    Params to learn:
    8 e0 _: T3 y& x+ v, v% b0 H. P" a( ~! H         fc.0.weight
    # n* ?+ o8 _6 p  h" y$ _  `$ E         fc.0.bias
    0 O$ j) Z4 I, Z2 f# Z3 u; l3 j19 @/ e& V: h3 ]$ Y% f' |8 v
    2
    8 u0 x4 ^$ M- q8 m9 p3, ?1 T7 P. F$ ^) Y3 R% H7 T
    7. 训练与预测
    . m2 n3 c. p1 \! H7.1 优化器设置4 G1 G# D# m; }4 ]: j% T* I3 B
    # 优化器设置1 U! N' m. [/ W3 d; B2 Z0 ?
    optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    % H# f  B) |$ H. }4 A" F. N9 V# 学习率衰减策略4 V, D9 h; P1 k6 j, p% @' [
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    " E4 {6 S9 o3 `) z4 t/ v1 H  R5 \& [# 学习率每7个epoch衰减为原来的1/10
    % J4 W1 {9 X  S; @/ T! d5 u# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    % j' G# Q  p4 R- B: u* Z; D; l0 e# G
    criterion = nn.NLLLoss()
    " g, P  G$ K7 G' K1
    - ~& v7 U3 h$ V: `1 x, l0 {& G28 K! Q' V4 ?8 ^5 }+ v
    30 ^- h7 @8 q$ B- O
    4; ]" P( d' ^* F- M$ k
    5
    # S, u6 w% P/ H# t, x2 H5 u$ ]6
    8 `3 b0 J6 c. V9 h7- l6 M/ @& |# f+ e3 \# l9 r+ H
    8
    / Z: }/ T5 J5 Q. s1 P" D# 定义训练函数8 o8 i9 i& C6 F. p0 a" U
    #is_inception:要不要用其他的网络/ b& }6 z" y8 ]& W- T
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    * D/ n, O# J8 k( `1 a- W* C    since = time.time()7 `  x  Y. z- z/ g
        #保存最好的准确率$ d3 \5 R: q/ U* t
        best_acc = 0! |. D* F; P; K) T
        """
    " i5 w% k1 |; H/ b* C/ w7 r    checkpoint = torch.load(filename)
      E( ]% V% l7 l    best_acc = checkpoint['best_acc']
    2 c, A2 q' ^% T% d3 ~    model.load_state_dict(checkpoint['state_dict'])# ?% e1 B0 y: \6 }1 g, k- k
        optimizer.load_state_dict(checkpoint['optimizer'])
    / w% r5 d# o2 {0 t: o    model.class_to_idx = checkpoint['mapping']
    ! ^1 i5 `7 e  k  E* n    """2 |. b  W5 H9 T" k0 n4 g7 g
        #指定用GPU还是CPU
    9 W1 Q1 O9 [4 _4 l# X    model.to(device)" q& X. S. W! G
        #下面是为展示做的
    + R+ O$ {$ ~+ q    val_acc_history = []
    1 p7 |, s2 m+ B3 Q    train_acc_history = []& a4 M2 i1 `  P/ x) k
        train_losses = []# _5 S9 N- N% h; ?8 m
        valid_losses = []) B* x' k% e# `, b8 O* n& D+ u
        LRs = [optimizer.param_groups[0]['lr']]
    3 K' s# {5 a! x* C2 D% _    #最好的一次存下来4 Q3 V1 h& \; G# K1 x4 n9 y, E
        best_model_wts = copy.deepcopy(model.state_dict())
      t$ v+ r9 `, O- z) J
    9 g( P6 E1 t% n# f. D4 L1 v' S! R    for epoch in range(num_epochs):
    8 }& t! Q8 f* s        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    5 ^8 p2 h0 j+ w4 ^# O9 ?! \        print('-' * 10). E' s8 V4 W/ I

    7 O" E) X. b5 x3 L* n        # 训练和验证6 l( v" V$ M- [% F$ P4 v1 f* {
            for phase in ['train', 'valid']:
    0 u5 C( J1 M; m2 c5 j            if phase == 'train':
    % a3 F) n1 f5 C/ t6 I  G( d, ]) _                model.train()  # 训练
    ! B. p: B$ m0 c! {# f            else:
    $ M2 s9 X* ~  x! t, V) g1 s' K                model.eval()   # 验证
    # D. A( M0 x- N& r. V) H3 v+ W+ ?2 ^* P! f
                running_loss = 0.0
    # F3 U1 N% r: H3 S; ^            running_corrects = 04 [' L( b( y" L+ n

    0 k) O- }+ m+ b, l  c% t: @' c0 L            # 把数据都取个遍9 y6 D8 S* i0 \& |; e. q: e
                for inputs, labels in dataloaders[phase]:! L* Z$ q$ F' N6 R
                    #下面是将inputs,labels传到GPU
    * H3 s5 T1 Z. @/ T9 u% V                inputs = inputs.to(device)
    ' I, C: u$ M' J4 A                labels = labels.to(device)/ b2 b! i% e: h  z$ s

    5 B6 r2 i0 u% V( A" o4 P                # 清零: D9 C1 r" R5 q- C' _
                    optimizer.zero_grad(), d# C, v" V  @3 b9 A
                    # 只有训练的时候计算和更新梯度  m  ~  i9 Z. f1 D
                    with torch.set_grad_enabled(phase == 'train'):/ y! [: g. e% v
                        #if这面不需要计算,可忽略
    + O1 h/ E( L( o2 i7 K2 K                    if is_inception and phase == 'train':3 B+ D) B8 Y! N0 ?' ?2 W" {
                            outputs, aux_outputs = model(inputs)
    ' O& z; o: b) T1 G- I" w( {& O                        loss1 = criterion(outputs, labels)
    & H1 d8 d, k* ~0 V* d3 i                        loss2 = criterion(aux_outputs, labels)
    0 N8 T5 e& h1 h; k                        loss = loss1 + 0.4*loss2
    & a# K4 G, I9 M! `. h                    else:#resnet执行的是这里" h; ]  d; q% I  u
                            outputs = model(inputs)
    + W* [% Z) m3 Y4 T0 R                        loss = criterion(outputs, labels)
    # W  ~, \0 S/ R( w* _
    , X5 j* L8 _: B* Q/ C                        #概率最大的返回preds
    / \4 }- U* p! X8 x- t+ G5 D                    _, preds = torch.max(outputs, 1). {- _6 L% y( k8 \
    3 s9 J4 s! O: d4 G" o! U
                        # 训练阶段更新权重
    6 X- @, M' b& ]                    if phase == 'train':+ X  b# e- r5 c. m. h: k6 m. d4 E! A
                            loss.backward()8 e! [$ g2 @1 h5 d, l5 F
                            optimizer.step()
    2 q. f' S. X/ `/ ^7 P+ u$ l5 A0 ?3 {* P* t4 s% x! M  @
                    # 计算损失5 y. t3 m- E2 u' P
                    running_loss += loss.item() * inputs.size(0)# i1 x" h8 a! y; b2 o) n3 D) Y
                    running_corrects += torch.sum(preds == labels.data)9 E" n* J# C6 H5 t) j" M! w0 N

    1 C' Q& P3 }: R" n9 D1 ~. H            #打印操作
    - G0 @1 g* H* W9 |1 S" n7 Q. G            epoch_loss = running_loss / len(dataloaders[phase].dataset)
    7 y" G# o2 S  P% e7 H            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    0 d; S3 O# m' R5 m. P
    * V) y, Y9 T: ]5 p7 C4 P
    ( w# S2 K* Z6 e8 P& J- U* e2 o+ r. k- e1 }            time_elapsed = time.time() - since  Z: I  k- b$ E% q1 \  [% ^# [
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))  P3 Y! g. @8 |. E! G2 P2 q
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    7 Y+ b/ F& ^3 G* C& x8 {
      r3 q" }' `9 j% t9 j$ b( |$ g; Q: @6 Q8 q5 |$ f/ h8 l1 w
                # 得到最好那次的模型
    ) C. G& O0 l0 B0 O3 E! h            if phase == 'valid' and epoch_acc > best_acc:
    * E) N: t4 X, {" p4 x                best_acc = epoch_acc1 ^: C0 K; \: n/ X/ r3 t2 q
                    #模型保存
    ( A0 B8 y1 v  O& H8 G                best_model_wts = copy.deepcopy(model.state_dict()): K, Y4 z0 }: ]+ X
                    state = {' n" l! D4 z- ?3 w- H
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数
      B& |  Z* m( q+ H5 ?+ F) n9 Z8 j( y+ q                  'state_dict': model.state_dict(),
    " B0 b6 Y: T7 z8 [( W* x                  'best_acc': best_acc,
    & {. P$ k4 c$ k# f# h                  'optimizer' : optimizer.state_dict(),3 ?- f( x" i/ P3 c' u1 |
                    }9 k$ H  r. i% [& t( A% y5 e9 R% ?7 e
                    torch.save(state, filename): U% K; M( }% j; [2 v0 @
                if phase == 'valid':
      H/ C$ W) m& D1 n0 m6 [                val_acc_history.append(epoch_acc)# w. Q( x% V0 e0 J6 Y) L' N
                    valid_losses.append(epoch_loss)& Z1 K6 I" {* l0 j5 R) w$ M
                    scheduler.step(epoch_loss), i) U' r. S' D- V
                if phase == 'train':
    0 x3 V- o' Q5 Z) ?& g) J- N; n                train_acc_history.append(epoch_acc)7 C* b5 F8 |6 E7 S9 ~9 u2 R
                    train_losses.append(epoch_loss)( |4 h' h* Q% R7 R$ E
    & a8 R6 _" [8 @; f* J
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))9 X/ q* M8 y6 @$ R; t$ E' t+ p
            LRs.append(optimizer.param_groups[0]['lr'])2 V* [! s9 a0 G3 ^4 @
            print()
    ! h3 z1 J& R6 ~  j8 L5 E, n9 ?+ K. w* h
        time_elapsed = time.time() - since8 ]: E5 x- U) V# s
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    ; S+ n) N) i4 U7 b    print('Best val Acc: {:4f}'.format(best_acc))
    3 ^- H2 m1 E6 B+ G# l# W
    ' ?/ e- [. x7 H- F- d    # 保存训练完后用最好的一次当做模型最终的结果
    2 L2 a- I3 t, }. U) `2 L1 x& |+ N: P    model.load_state_dict(best_model_wts)( q( q" j. X( W6 M  h5 U
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs ' U# \) T" @4 o( n3 q% Z
    ) _" r9 b! j6 F& |, B

    * G( x$ b, X5 w: p1
    . ^$ t/ y: B5 H( l  B. r2
    ! r& ^# d6 a: c31 x; z2 x6 _( R0 U. O7 r
    4
    ) v" z1 B+ ~) }' ^& A5
    % b/ o1 t3 d& F& V8 F! T- [6# l! n' U! y' [- @; @
    7* b: O% X' P5 o1 d4 i  M( j
    8
    $ ~/ N# T# B+ |7 A" |9 W9
    7 e% ]/ H- ~9 ^' c* I10
    : S7 i+ u, s" M- O( s# b11
    $ Q$ H. t1 `4 w$ S" q- G& ^  E12
    8 T$ `3 n% Z- F6 L1 ?* D138 Y( S5 D$ X4 h. d! B
    14
    0 ]( R- o2 A3 p# K& B" w15
    " t' a& \- l. O164 {& W+ d1 H, h4 O/ N6 m
    17
      B9 `- u. z' X$ Y2 j- O186 c& e$ p$ ^/ ^. R! w; X4 C
    19+ Y/ }; q* U) {# `* @
    20
    ) v$ [( u2 {) g& `( _1 T21  z6 ?2 D- C0 B: a
    22* o1 Q1 D0 b/ V4 m; b" A
    23- [! G. M* ]0 T, \
    24
    % B) d  }: y" k) Y5 r  _6 o25) O  [& [8 e5 L% S- X0 _/ _) k
    26$ C+ e$ ^, ~- g' Y, P* K- m
    27( U" }8 V. T, l( F) o8 Y/ f, M1 U
    28! U3 h' M# R  N# T' a3 T0 U9 k
    29' n$ g- I) z% X3 P" v
    302 E' T* s6 b" x. F/ r( `( N3 z
    31
    & j3 @7 a4 A) @* @6 ^" }3 O32+ S8 p5 `( F+ g
    33
    1 T9 I; q) v, K* s34% w8 ^, L/ s' Q& m
    35
    : Z% L" P8 f# J% M: ?; b3 w/ o36
    9 x9 ~/ l1 K& Q. [# W37
    ( g3 F) T6 R2 ~, ~( g8 T) v/ M38; |' Y7 n/ x2 s, |! }1 X7 s
    39
    5 x0 ?. o' w4 X% l! ^; h7 z40
    & v8 `5 G( M" e1 B, P41
    , c+ H: i1 L$ u: @42
    9 s- h6 U# d& h+ i2 m/ ]& E/ T+ U43$ ]" o: v$ Q+ H1 L( ]8 v
    44
    5 j, r; \+ v- q$ n# O# q' v& q45
    ) I: n1 R( g8 D; v46, N* T7 J5 {- t% q
    475 J2 U$ q2 U9 d- @$ C; H8 z
    48
    % A5 d- p6 S7 \5 l9 U497 `. M0 h$ n# p
    503 `- s7 e) X0 v5 d7 ]; [' c) ]3 e
    51' |6 @" m& E2 T/ e0 |$ c$ [& R" D
    52
    . g( o7 D/ c+ d. y4 o$ Y53
    0 Y  C" h9 [" [7 N' P541 z" W, }% k( ^6 }
    551 c3 z' Q& F( @6 v
    56
    ; {7 f& ]7 q) J/ U576 s/ L# c& Y4 [" T- _/ s, H4 v
    58  b: b9 |' N6 M" G2 b) L" Y
    59
    2 I& o/ }& O4 N/ W5 k% S60
    # C" v* g; e. F/ L61
    7 v$ y# A3 I: A: D$ U+ v8 k62
    2 \: N' H  v$ S4 v- k63
    * [9 K+ u2 _4 @64* k- b/ D+ f; Y  s$ l0 w/ G
    65
    / ]# Y5 I: x+ ?; [! ~2 Y( P66
    3 B- }- V( G8 {: w6 Y3 e0 a8 P67
    5 p& ?% F" C9 a& v' z! m) _# @68& Q; I' R" j/ A- l7 Y; d
    69
    2 J) @0 `% H( B2 _3 |4 D- [70
    0 |$ _3 g" D, t0 y+ t71) y- S* O3 V% J$ R* ?
    72' ]" p/ f5 N# R; s9 q6 x% l
    73
    ) H9 v! d$ G2 S; G6 n9 _74
    ! d3 P1 N8 L& G* }6 d* x75/ ~& k6 n# ?4 {2 U
    76' |% N& h, ^$ A2 K4 c& |. R
    77
    # H3 L% w# P4 y$ V( D; E' A78
    " B2 f+ H1 W* t+ V# w& L, p' B79' k$ o. M3 d( @' k1 v
    80$ A  u5 W5 a5 S3 a  z# v
    81
    . u( k# u4 _2 M% X82; O3 r, ^0 [4 n7 g3 S
    837 H; d; H7 `2 S2 e- X
    84& a9 z9 b/ P7 c
    85
    $ Z& N3 @6 Q, P6 i86
    6 y$ w  R, S! m) J, r6 P+ A87! ^6 Q7 U7 S1 k0 N$ J% D
    88
    / k6 D: g$ k8 i7 r  u  r5 [89  J1 l: K& x* F
    90
    9 k" `  c4 Z5 `9 p$ }91! J: m* ~: e2 D- o+ G
    92
    ) m% k2 C% ], G+ ?$ d& C' L9 d93
    ' M* W& m! v* s+ }94
    ' t/ d9 w, y+ G$ X9 T" v# t95
    * j+ b9 n$ D8 w7 t/ N( V) j96
    # ~- S/ h% z, O, O& ~' f) n! T97
    - x. n8 m4 U6 _98
    0 f7 Q% }6 u6 c' _# W7 a99
    ( h3 P5 m! ^+ c$ D$ G100
    # B' p9 ]# b2 g% \101
    * c7 O' V8 f9 s102! ?/ V1 A8 [- [$ z7 j& T
    1030 w8 s/ z/ x4 G( d8 O; W$ I5 l% }+ a
    104
    ( X" ]; v5 n" y, k. W2 i105& `. N: s6 L+ @9 l
    106
    # w4 x& C3 m; S( r7 T5 `& k1076 |) J. s: i% }# S5 K. M* I
    108- J7 S/ T; W" v
    109& E5 ]2 O2 l9 b
    110
    3 g5 B6 s% S  _111
    8 b( A* O$ I3 o( q' P$ T/ E112- e$ _3 L8 V) e4 @/ ^  j' F1 i
    7.2 开始训练模型
    * s; k! A, \) D' l* v0 p$ Y我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次% k3 `. T9 \+ `

    9 l" o, X" I6 ]; s3 w6 _#若太慢,把epoch调低,迭代50次可能好些1 l/ {* Q( ~- F# ^0 J
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    # s1 T# ]' n, M8 P1 m/ `model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    * U  F# F  n" _  ^, L2 n0 R) n0 y# {
    18 F  x# \! y( s& i: x
    2
    4 f' X6 f* f+ B, g# c33 r3 z# r# b( b! G2 P
    42 G8 V( K8 G# _" H* f
    Epoch 0/4
    ) x0 S8 _9 y. L, m( V9 w& D& |4 b----------
    ! E9 h4 Y: i% u5 ~Time elapsed 29m 41s
      y- x3 q" r  G) E) xtrain Loss: 10.4774 Acc: 0.3147
    8 J: E6 {! T' I" J. t, |Time elapsed 32m 54s
    , Q" t1 Y/ M2 Y' k, ivalid Loss: 8.2902 Acc: 0.47191 G* ~& u1 {3 e3 V: [0 x
    Optimizer learning rate : 0.0010000% D4 |/ l6 b9 `5 x* x3 ?% s
    8 G+ U, X1 D9 S) O
    Epoch 1/4
    $ x% _/ `2 K! R" q1 x- ?----------
    : ?- U9 w( y' G  NTime elapsed 60m 11s4 c5 I2 S0 {8 [0 P8 Y4 j5 u
    train Loss: 2.3126 Acc: 0.7053/ \5 y7 j$ ^7 ^) b' Q3 V
    Time elapsed 63m 16s& G% }" |( p( ~, n/ u( F/ o% q, ]
    valid Loss: 3.2325 Acc: 0.6626
    5 ~( S7 E7 @$ g4 r8 _; mOptimizer learning rate : 0.0100000# s- h9 t3 E( {/ M1 t% x2 d

    ; v& R7 b% p2 _  F( }0 TEpoch 2/4( r8 h' @! s* _8 r
    ----------
    % v2 ]( a$ M0 g3 A1 [Time elapsed 90m 58s
    2 A2 }4 D2 P  u4 k) K5 btrain Loss: 9.9720 Acc: 0.4734
    6 }# w# {( h% v! Y6 ITime elapsed 94m 4s: V- E1 w! V, _- h# t) g, ?
    valid Loss: 14.0426 Acc: 0.4413
    - m* W, {, t' g! a+ LOptimizer learning rate : 0.0001000
    ! O  L5 w1 q! \- B" d$ T4 v) H4 ^* q0 R% H# D
    Epoch 3/4
    " n9 C! g+ X  Q& b----------
    : V8 ?% i1 }/ c1 QTime elapsed 132m 49s
    3 Z5 s6 x7 s6 }; m8 }4 l+ h! Xtrain Loss: 5.4290 Acc: 0.6548
    * Z9 V% ?" _: M2 Y! U0 \Time elapsed 138m 49s  d! Y* n' G. {9 I/ X5 J! `
    valid Loss: 6.4208 Acc: 0.6027
    , [# d, j8 P2 b! |Optimizer learning rate : 0.01000001 p7 z9 d3 ]6 ~  b3 U

    0 U. o/ T+ s# m& o1 I, B. Z9 DEpoch 4/4
    . J) B( D: e% B1 N! H! v----------
    6 G4 K. ?& S$ TTime elapsed 195m 56s
    * K0 L  a9 M: C; ttrain Loss: 8.8911 Acc: 0.5519
    , |8 P. y5 z& I% g  ~& DTime elapsed 199m 16s
    ( V5 b- {% g, n9 F( dvalid Loss: 13.2221 Acc: 0.4914
    ) L/ v2 G: U. R0 dOptimizer learning rate : 0.0010000
    6 v' Q; S5 q3 C! k8 f4 d$ M' k5 i) S6 q
    Training complete in 199m 16s) A5 s6 v2 o1 N5 y6 v
    Best val Acc: 0.662592
    # ?5 U+ M" L+ l
    1 R- g6 @  [% d5 W. B/ [/ e1
    , Y5 s7 n: W* t7 B2 C2
    ) B  R0 u6 o/ B$ _' O  M' I3
    ( w# m; B  }/ e. W44 i6 h6 |5 y8 K% H7 t1 T: q# c" }' l
    5
    2 z9 Z1 u: k; V! G4 }2 t0 `; f60 ^6 q; D- ~1 {! o7 X
    7  D6 U; X& z8 ]) O0 R* n
    8& G9 j; F7 S$ ?& J
    92 o8 j/ A! r* x/ g
    10' }5 b0 N4 ?( c4 J* y) S
    11
    4 k" [6 b0 ?* g1 P3 @7 K6 k% s12( C, h; z& k# ]+ a+ t+ J( `* \
    13* Q  Q0 I7 N* |) S, w' R
    14
    $ B7 ^4 ?  a9 I/ Q. c# P" x4 `15! X6 j% V: @7 ^
    16: @+ y! s+ }: ?2 q6 b  p
    17; ~. W% ~0 G  c; |. n: p+ P
    188 T. V9 C' E' m$ A3 Y
    196 N! u9 z1 O9 R# Q! s
    20  H9 x- p2 U: m
    21
    4 h6 f& y1 Y6 M5 y0 E, q221 T7 r4 G+ z7 c- F! x& G
    23& T/ U4 P1 q, i; l) t
    24
    ! i" ~) U- R7 d5 v) y# o25
    ! V1 n2 [; Y* {! A- G26
    * H1 {" h  g, a1 b. o27& V2 q# T/ q9 b2 Y6 n; a; |; S/ A9 ~
    28, R( ]7 \0 h1 C3 r4 I2 T  `. J: Q
    29
    3 q+ r  _" [: B4 n$ I' m$ M6 w30# P/ l5 w- X8 U( C6 U
    31+ o# Q! A0 T8 v$ F3 M
    32
    ( `1 ~6 e' G8 s. e' ^' i33
      C  h+ Z+ R2 W7 C5 _+ [- Q34
    8 Q% l) E( u) a# H0 n/ y/ E, D35
    , p- Z' u; |% R  w$ ?365 F/ J$ G) F' r; T) T, Y1 d
    375 a, n+ f- H' C1 |( l: y; e
    386 n) R, u! u+ o( m
    39
    ! e2 \7 e) M; Z4 _+ ^40; ~/ d6 R+ X% ^) m4 L  Q; F8 P* o' r6 F
    41! [: ]' Y+ F6 ^3 q+ |2 l8 M
    42' ?/ S/ ?% D: j9 c1 E" I
    7.3 训练所有层
    ! b: S* a1 a6 q# 将全部网络解锁进行训练$ i2 r6 G/ ~; t3 i7 k  M4 T! J
    for param in model_ft.parameters():
    . D7 B; a4 e8 w    param.requires_grad = True7 d  \1 `! F9 I
    ' m5 z* S; S4 M+ I
    # 再继续训练所有的参数,学习率调小一点\0 R9 g  Z( l$ A0 h- M) U
    optimizer = optim.Adam(params_to_update, lr = 1e-4)9 i% D. F& I% b1 I3 Q) y4 K
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)7 G& l1 F. N' t0 M* x8 K
    6 M2 ?' U7 o: T5 v( @+ N
    # 损失函数/ \7 Q3 @4 z. N$ k# e( k
    criterion = nn.NLLLoss()
    . O+ o. @& H1 P; x) y2 K12 ~5 Z# M6 t. ]  Z8 U# g
    2$ f$ V3 O* `; ^& _, }
    3
    1 [" h: v; Y3 N3 r6 ?7 v1 l. Y0 K44 C  J& q$ R: H; D
    5
    & T+ E/ `, Y5 [- p6
    1 [. h+ p6 y6 S# ?7 K( d7
    0 L  J6 E0 T0 P4 E' H: f) I- H9 s8
    / }0 W+ s# }7 \% x5 f4 P1 C9
    3 g6 y0 ^0 \* J6 }; _9 C1 K, ]) J10
    4 M7 u) y+ Q$ Y# 加载保存的参数  R4 Z3 T- T( J
    # 并在原有的模型基础上继续训练' {7 J, a$ A5 L& ?1 q% s$ w) j
    # 下面保存的是刚刚训练效果较好的路径+ _- }; ~' i# B7 X7 n3 f$ V6 O/ l
    checkpoint = torch.load(filename)
    : Y2 @6 A" S. e' R; Z. a/ W- Bbest_acc = checkpoint['best_acc']- ]5 B2 j  \) K1 {
    model_ft.load_state_dict(checkpoint['state_dict'])
    / p6 f% Q, ~1 k8 Goptimizer.load_state_dict(checkpoint['optimizer'])
    ! E0 f$ s4 m7 Y' k# D% S! i17 m- e* K1 Q& w. O* N0 S3 U
    2
    0 {# `& r! x: S" c; b3( T, `& j4 N0 }; W8 ~' n/ q
    4  g# V; t; S2 R. J- j2 F* D7 U
    5% E' k  M6 ?! E- b# f! Y( e- |
    6+ [: G3 M" D2 I& C, a
    7
    $ q" J; R6 H4 \1 P% _3 ~开始训练7 L! s- ?  K/ q/ _" @2 W# P/ g
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    8 a, a! u/ m- f. d6 E1 H1 L
    6 z$ U  t) L% d4 i9 r7 x& jmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))4 L+ T6 B. t5 o3 |; j: `8 E
    1
    . Y: p6 ^( s" L' Q3 H/ N8 OEpoch 0/12 K0 \, ^; T' T1 P' G  N" k  d1 M
    ----------( h$ X0 U' N/ @# _0 G
    Time elapsed 35m 22s; P; b/ v0 f. e
    train Loss: 1.7636 Acc: 0.73461 T# q9 a2 l, \0 ^: ?5 }
    Time elapsed 38m 42s* v/ J9 m( P" M, j( j# A' `7 ~1 O1 D
    valid Loss: 3.6377 Acc: 0.6455" z2 D# Q- b# R2 {/ A
    Optimizer learning rate : 0.0010000
    5 p& T! r# K* t2 J* Y) x8 }" Z4 S' ?0 g: F' `3 }' g
    Epoch 1/1
    3 u7 Z# P! {6 L  U& J" B----------' V5 }- w1 p& G3 c
    Time elapsed 82m 59s% q! o6 e# u6 u4 @
    train Loss: 1.7543 Acc: 0.7340% o) ]" j5 [9 J
    Time elapsed 86m 11s
    / }" S7 A1 l( E, r5 Q% H2 v6 @valid Loss: 3.8275 Acc: 0.6137
    / Q% t4 `/ Z' l# u: lOptimizer learning rate : 0.0010000& \4 i# x/ W1 s6 Z8 M5 x
    & L- u0 f4 d8 }: [# |" ?
    Training complete in 86m 11s
    " C: y: Q6 K; G/ {7 qBest val Acc: 0.645477
      Z- {" a- y& w! j; E" I8 U& o  M; f- [: Q- i! h
    1! [7 x1 W# B4 F; h
    2
    ; J1 N7 H+ v) s/ F5 A+ H% x% l31 r# I6 U0 V9 w, y
    4( N5 k4 o- E6 T% M7 @
    55 R0 |! [/ b" `. p1 D
    6
    0 t$ N  z1 c/ z5 s6 b, Y7
    , f4 d8 B3 c/ N1 S" ~8
    4 ]; a4 L% l( n' W  t3 `/ ]5 c4 y9; j& L9 v  f, m: ?
    103 ]( M$ Y! W( F# U! p
    11( J" G2 M! u& l& l+ X# Y
    12' I) L' E) {5 X* D8 u
    13
    4 D" z9 g3 S' T/ r0 I. m, D14. ]0 U& K' s7 d( v0 m) i
    15
    5 \! B' s( C, W' [: G16( ]1 g/ ~, x/ J+ [$ V; V
    17
    & ]; y' w1 A& N1 Z) n! J% d184 }0 p  r' I+ a2 |% a
    8. 加载已经训练的模型9 c; g" c8 q( u" z8 p) j6 [1 C
    相当于做一次简单的前向传播(逻辑推理),不用更新参数
      P3 Z8 B; s$ A6 U0 `1 e$ G. L2 i  k& g7 W" D1 k, B/ [/ \) Q/ j
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)7 B- A" X$ ?6 E- Y- f
    7 A/ N; E9 }' X( i; f+ N& s$ M. _. v
    # GPU 模式: s  i+ a) k" k/ g4 y; A  u0 G6 d
    model_ft = model_ft.to(device) # 扔到GPU中) ~' U- s. [! a% o8 F
    # M  W5 w4 U. M; U; N8 _
    # 保存文件的名字8 R! W8 T1 B3 q$ L6 R( B6 i) H0 s
    filename='checkpoint.pth'
    9 c# H  A! K0 l- E% y, R* w% n! D) U- f5 u$ k
    # 加载模型
      p: D8 a, c+ g! M: Y+ B9 A# ^checkpoint = torch.load(filename)
    . H+ y+ o6 V) Fbest_acc = checkpoint['best_acc']- l! l! E& P6 F& G
    model_ft.load_state_dict(checkpoint['state_dict'])' _' v8 Q' X& U8 x4 Q$ ^# j
    1
    0 v" i* T( j$ k1 S2; d1 }2 ]0 M) v2 K; h! M
    32 Y0 i+ l- s) c5 m6 ~5 R
    41 H4 ~- J0 n" D$ ?2 F2 y
    53 C/ l0 i2 t9 A% r: A
    6
    ( Y* d4 }, f* t% \9 j% _7. s- E  T+ R9 y0 f5 [  T
    8
    ' y- d! r7 H5 V. N) L/ _; n- V9
    $ k2 P  U! P$ c# W: B9 R7 O10
    ) ?# b2 A. m$ O$ O: K6 ^* y11' c2 i, x7 K* i5 r: ~9 ]1 j
    12
    ' L4 n; g: W- B. j<All keys matched successfully>
    8 E3 g, d- h5 W" f1
    % w1 a8 h6 @% n3 M" idef process_image(image_path):
    0 X! |/ ]! Q  N" I: S. l    # 读取测试集数据3 f; g5 S: `4 v% g  b* ?
        img = Image.open(image_path)# r' W7 p4 J3 k- b5 o$ G5 z! E
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    * O+ |& G0 U3 ]2 z6 ~    # 与Resize不同3 \5 H1 T' T/ S( V2 [4 G/ R
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    , Z# U  K7 s* K8 ^2 S- I8 \    # 而且对象调用方法会直接改变其大小,返回None
    8 h7 I* o+ |, a    if img.size[0] > img.size[1]:# m& L( o; g4 K
            img.thumbnail((10000, 256))* x6 t" k7 _9 \
        else:& K. l+ {, e$ R3 n% v" A+ V* D
            img.thumbnail((256, 10000))
    7 J5 Z1 Y" k- c. C) ?  G  S
    1 b. V: _( z9 d2 G8 m    # crop操作, 将图像再次裁剪为 224 * 224
    2 ~7 x, a; W# m5 J& u' j: m) F  b& C    left_margin = (img.width - 224) / 2 # 取中间的部分' Q) ^, }& e  [  J* `4 P
        bottom_margin = (img.height - 224) / 2 : ^0 J4 p; U: r
        right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    / I9 b/ C: C2 C/ W* M    top_margin = bottom_margin + 224
    3 `# F7 Y& R9 V9 q% M) ]1 C: a/ o- W- ]1 q: L- E& b2 }
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))3 f' k2 c+ f# Q8 \( N2 e; I, U
    . ^4 H$ ^1 X7 q
        # 相同预处理的方法
    . ~$ F; l7 c8 A5 L    # 归一化4 c- ~; m4 N# a
        img = np.array(img) / 2554 ^4 W* P, E3 P3 s$ ]+ R* u6 n
        mean = np.array([0.485, 0.456, 0.406])
    / B6 {; N8 v$ e) c    std = np.array([0.229, 0.224, 0.225])1 e, B5 I, l: d2 t$ v: p0 g
        img = (img - mean) / std
    ; A0 Q7 l2 H. b, {% c+ n7 A/ I9 y- Q/ a7 v
        # 注意颜色通道和位置
      M8 T* A3 d% F7 k    img = img.transpose((2, 0, 1))" v6 e$ o9 v4 v! m1 t$ {; w3 f
    ( k6 c) ?$ r) H2 q  [0 [$ E
        return img6 C6 t' H( T, D( Q
    3 `6 l3 {& U& \9 `7 p) d
    def imshow(image, ax = None, title = None):
    ( q$ {0 A) V1 g  a- [    """展示数据"""; \  M, i8 ~; v* L6 M) T6 F- o
        if ax is None:( e4 d) g2 t4 F2 L# ?, n! M$ a8 U
            fig, ax = plt.subplots(). |1 @! ?' R+ I4 \! W! N- Q5 s3 J
    . ^* w: \$ e4 k( P1 ~$ |6 g5 ]
        # 颜色通道进行还原2 ^- O. h% q3 ]4 O7 H( Y- Z
        image = np.array(image).transpose((1, 2, 0))
    # h4 T$ O1 V5 K3 b( k
      s3 S" N% G! b/ p. }# x8 e" P+ K    # 预处理还原
    . I0 Q/ p5 A  i: |3 B6 m2 n1 u, F    mean = np.array([0.485, 0.456, 0.406])) p: w. V+ d+ I! _; Q& s
        std = np.array([0.229, 0.224, 0.225])
    1 K; `5 x. n5 G) N5 w8 F  q    image = std * image + mean) I, ^3 q3 B$ A8 q  l& M& p4 d+ u
        image = np.clip(image, 0, 1)& B. m2 L' o" ]" [' d% ~
    2 y/ ]4 K! c7 ^0 r
        ax.imshow(image)
    4 ~( Q4 j/ v8 R! [! R, N/ O    ax.set_title(title)
    4 V4 w, _- t! X- A- F1 [* X' l: M/ F
    ' [- S' F3 l) ~6 {6 f# h    return ax& `0 s' ]6 j, l4 d/ `

    5 E. I% Z% O. B* |" `& {) A2 gimage_path = r'./flower_data/valid/3/image_06621.jpg'* n8 A8 x8 J2 M6 t
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    8 s7 T8 U8 s5 d6 ~& _8 aimshow(img)8 D4 L& c" y; r. Z. l- D0 u5 I! x. H. M
    ) a; T6 P' C4 ^/ Z6 t$ \7 k
    1' M6 N! @4 I4 F4 \
    2
    : H, E4 Q" n2 o3 q% `2 n3! a4 r" A1 \$ I. U$ J6 t
    4& e  Z1 c+ E+ N) D4 t
    5
    7 Q% H/ S/ I8 n  @$ Z6
    ) P6 M+ d$ e( S8 q- l7% ^# V  t  P, K$ X0 ~* L
    82 V( ?) y, t! E
    9
    5 @% w5 s9 F* A8 {4 m10
    . q0 P* C/ s- @  A3 p" D2 |' o) D11
    ! s* j; \5 }  q( {, h8 S12
    # C; D8 J- @3 p132 D6 U/ z3 A; U% V" `
    146 `4 {0 U* N# g2 {
    15
    ' |/ D9 c  Y2 A: O; {3 A7 N  w8 R16( i$ s+ d" R8 i5 i$ F
    174 K  [% T5 d' c
    18; B& M3 T( [6 ]+ y" c5 _5 }. \5 {
    19
    $ X& F* F0 }9 d& y20
    8 w$ @  B: d3 ?: `$ G! |21
    7 b8 d9 \8 `8 P) x4 V+ v' {22' r% W4 v5 m9 b
    236 S9 g1 t* X" J9 P) J
    24
    : y7 p5 b# k/ U" H25
    " ?5 L8 T& D( K26
    . H- [# i8 E7 G/ s" |" A27
    " S6 I+ T- q7 D) j28% j& d) i6 h* Y% i7 z
    29# i  z5 F% Q; Y  `3 f* b; @$ L: u
    30) |- L- e! f+ q% X: ~+ v
    31) P! n% v' |/ G( N. G+ W) j
    32
    + h) g7 c( d# y8 h4 y# }33
    1 P9 ]9 n- d( W1 i" L" e" P1 u34' t$ X7 J# b) s* v
    358 {$ P; m9 U8 [
    36# E; p! ~, U2 x  U3 u6 v5 Z
    37
    9 k- d: u% |2 k! J38
    2 D, j1 w$ ^; {# g+ l: U39
    ( d/ `0 `2 s5 D' N/ R40
    ; l% j* d1 b/ V6 g7 J% c41$ Q" Y1 h$ {8 U, h
    42, e  F: d) g) K! k& Q# S
    436 M( ?5 z! P2 v5 f9 V
    44
    8 K, K0 R2 |1 S1 B45; a. A, d9 i0 j8 O
    460 A+ n  u  K& H: [
    47
    8 J  B4 F% @+ A. ]4 i48- M8 @, L2 X9 _7 n
    49, a6 T" ?! |! g6 Q
    508 \4 A  h% L" Z! I& _9 y* C3 C
    51
    2 {; `  d6 ~  ]) h' @! D0 w524 d+ g4 S: T' g. }
    53
    / g$ [! A) ?! [5 ^* d* D* G. p8 p$ \54& ?1 i! n: y7 h8 j1 }( J
    <AxesSubplot:>- B! H$ l/ t0 h- m7 a0 t
    1
    + M6 g/ D; r8 |0 v% S; N4 a; j' `* v3 r: F
    上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确' [/ r5 x9 a3 ~& y. Z+ j' O3 z, Y

    ' r; K+ s! x% O/ \# p# F( R* q7 _img.shape
    , |; p! T$ T( s5 [1
    3 E) F5 H8 ~. o, B(3, 224, 224)
    6 i/ k8 P( a7 G/ k9 ^1
    3 w" q' n5 a: k9 A$ y2 g6 k证明了通道提前了,而且大小没改变
    2 ^8 \0 [$ u/ F  ]
    : ^7 z6 S8 f4 X. \, e# T9. 推理
    5 N/ K8 T9 a% c' z" Timg.shape! @* F, i0 T% Q% z
    & r% |7 V: i5 v/ Y+ s7 U: F
    # 得到一个batch的测试数据6 p* F/ q1 e2 R  _0 E$ n' E
    dataiter = iter(dataloaders['valid'])2 N/ I, R% k9 f* T" N, B) \
    images, labels = dataiter.next()9 g: o5 \5 C" {  Q( i
    : m7 _- Y. R9 h7 D
    model_ft.eval()
    * q! [2 b: i; L* T( h& H5 x  ^0 T# T* D2 R: K, x, p
    if train_on_gpu:3 y/ t/ ?/ ^' e8 e$ X
        # 前向传播跑一次会得到output  n% N; I7 `* f3 @) b" D7 W
        output = model_ft(images.cuda())- C' l  \1 Y/ c! ^& S1 A5 @
    else:
    1 ~$ \) c6 q0 }$ }; J    output = model_ft(images)
    : ^0 b; S0 L2 {3 q$ H8 z- w7 u3 q
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值9 C! J0 L' `) J, p# v& _* M4 P( u
    output.shape
    * B8 c, ~  q- X8 C
    6 c0 O3 J4 k; D& r! s$ I5 d1( R  o- u/ W  f4 m8 }$ O2 I
    2
    3 c5 a, O& ~, D& z7 R1 b3
    5 U7 Q+ b4 `% l; l2 C" U4/ Z9 N; o- X5 h* G1 s% g; a( [
    5# \: X8 G. q! a& A+ B
    6
    + r& I4 u/ ^- X% ~& f4 M8 ^7 Z8 o7
    8 O; S1 {+ j2 w% K9 H/ N. @8 ]8
    . a. D; O5 n0 j7 k93 u" u/ |! G+ i' X3 o9 N
    10) |7 H0 i# _5 [2 A0 t1 B) t+ z
    116 d1 X! F( @) |
    12
    2 T- U. v# N- |2 e- E/ }3 o13
    & A' P0 I% @9 d14
    " h1 }$ t4 p* n. o* g158 }2 i5 ]  \& I) `8 \
    16
    ) b) n) p2 ?- k) m& r6 J# c9 ltorch.Size([8, 102])
    ' H" F" |/ ~+ S! B" I# u1
    3 E: J) g, b/ n, L0 }4 a9.1 计算得到最大概率+ R- _3 Y, Z( {2 X( J8 ^, P, `
    _, preds_tensor = torch.max(output, 1): a# F( w) J4 v

    ' M) W  e7 l, n5 u, F- i0 A- u2 b: m+ mpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量  m$ L1 a4 G' v7 k; g' ]" k+ g8 Q: t
    16 @  W6 w7 D9 }
    27 e0 E; h. U' E$ Q+ D4 T- @
    3
    3 C* D* [$ j+ r( h/ F8 \' X/ B9.2 展示预测结果
    ; V: H2 ~, S% x/ z' @3 c3 Pfig = plt.figure(figsize = (20, 20))# R2 g8 {" e$ w$ M
    columns = 4
    - S! f4 p% A  V$ A$ Y2 qrows = 2
    8 m/ j* e6 s5 W0 j8 V9 u! \! L; K8 L1 Z0 J& }
    for idx in range(columns * rows):% \* g2 f  w" o$ E" l2 k' m
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    - G, I7 t. Z8 e1 q" ?. m- }    plt.imshow(im_convert(images[idx]))1 [0 T, Q, o! G$ l; \% |
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), 8 S) D, ]& ^3 x$ X2 J
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    2 D1 q/ n9 Z) H0 G) ~/ Qplt.show()
    : \6 S' ^! |" H! n: D# 绿色的表示预测是对的,红色表示预测错了
    : N* N8 x# B2 g" d. ]14 U8 K& G' f' s% w+ m" Y9 i% s
    2
    ! F  Q2 K7 {8 x0 H+ `3
    # W, a$ j; @# K; p4
    * x9 f. q6 A! O) L$ l' a% g5
    9 W; c+ m( w) N/ ?- Z# }/ l6( R: C5 M. X; R9 ~/ I
    7
    4 ?4 S5 |& U7 c8 E# H/ y8
    # z* I$ `9 ?& b; q9+ \% ~  O6 I" O4 h( a
    10
    8 p* M- T' w; P0 K2 p3 o/ c, w! ?11
    , D: r  T: W5 E
    / H/ X  ?5 e' _; ~: r( i) U# t( D2 c0 s

    6 w1 S4 j  t4 A. S, V- n% H————————————————" [: X: g2 N1 v
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。6 N, O: C0 H* P9 |6 j1 E" a" T
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940! c) i& p7 `6 E! @$ E- R& W4 a8 _

    ' P8 N+ [3 K1 q; D9 J, p* l  f& e% M* I
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-13 06:33 , Processed in 0.960276 second(s), 51 queries .

    回顶部