QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2719|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例' S4 S0 ?2 x  h) M8 {
    / a: T6 y+ F' l) z! T1 l& E' T( |( L' b
    文章目录
    : ]3 W* J# P0 y卷积网络实战 对花进行分类
    , X/ ~- a, O- W% D数据预处理部分
    % d7 r% E" h& T: H* c- |网络模块设置# i9 a( U+ o9 J/ F6 c
    网络模型的保存与测试
    # X$ c; a3 S% Q  |% T9 D/ U4 K数据下载:, I5 v0 F( ?% a$ M
    1. 导入工具包, E- C; k- ~2 l* Z' G( I, Q1 J
    2. 数据预处理与操作
    4 B# {; A+ E( E! Z0 Z3. 制作好数据源( @/ f! C: P: R: \8 v& N2 C1 @7 B
    读取标签对应的实际名字
    . l( N% B7 e" v2 N! P4.展示一下数据
    ( `$ ^% u$ M# k6 d5. 加载models提供的模型,并直接用训练好的权重做初始化参数5 {3 ~9 l, R: _0 g% B; D
    6.初始化模型架构
    ' Z8 V. h9 J" j& s* S, ^7. 设置需要训练的参数% C. [+ Y9 }7 R( ^/ ~1 w4 K  q
    7. 训练与预测
    # C+ P% a! ?: M) B8 n" \7.1 优化器设置
    + U+ Z9 z0 T: H* G7 p7.2 开始训练模型( p# m; N* @: s
    7.3 训练所有层
    ! i& g5 l6 ]/ ]( k5 |) b8 A" z开始训练
    # B9 g& a5 _0 h8. 加载已经训练的模型
    : t- c% o9 `7 C* [! w7 G+ g0 q6 x1 V9. 推理
    3 q) c9 |+ ?, E: N9 p, x9.1 计算得到最大概率
    . D9 ]: u: T1 a' R% H. ]9.2 展示预测结果0 B) F+ C6 C- P0 A6 a
    写在最后0 `! D" R6 h( U& U% B$ z
    卷积网络实战 对花进行分类
    ! R5 s# [! q: g" D' d本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    * K  t/ b* W. |
    9 z6 ]9 T$ G7 `; c& X% ]# I在文件夹中有102种花,我们主要要对这些花进行分类任务
    5 r# _/ ~  |9 D$ c/ r文件夹结构: L6 f$ E. q& B# }8 Y
    ) N5 ~5 K, v4 V! M/ q6 f
    flower_data
    : h5 x. O4 Q8 d* e8 [
    " {" K) [6 \% j/ k9 o' b7 I, Gtrain' M1 Z8 R7 n3 J# A) C

    / H% Y# o- d" ~4 c! C1(类别)8 O, i  ]: \4 |) S# t+ Y
    2/ A: x% c/ g) G2 [
    xxx.png / xxx.jpg
    ( S3 _* \; X2 q. i) |valid9 m3 m+ e. u- `" y

    ! a; p  k9 b5 P( p主要分为以下几个大模块
    % m! {% P3 R! X! o& j
    3 t: j2 y/ m7 S数据预处理部分
    ' c! F$ |2 s) P( n0 }数据增强
    " _; y, [- K- t4 U- \) j数据预处理
    ( ]  }$ O# S6 [4 u0 r. e  G( ^网络模块设置( P9 b4 R4 f! D- Z. a( U
    加载预训练模型,直接调用torchVision的经典网络架构* B1 v6 k0 D( U6 \
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
    + X4 s: B; _+ k; ^, d9 M网络模型的保存与测试+ j/ Y; r+ P  P" e0 N! r
    模型保存可以带有选择性
    $ l! G7 H0 H, E  `数据下载:
    $ l! ?& T) H. rhttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    6 M8 ~. r+ W5 N. @9 v8 [8 J3 V& g3 `, n* |7 d8 V' \
    改一下文件名,然后将它放到同一根目录就可以了
    & }' S$ A( _  ]5 `/ B+ ^) y4 |+ G( w4 H! ~% M& J( `  H
    下面是我的数据根目录1 v1 V6 K# `. Z6 L: `4 \

    ( u; B! {/ D, E8 K+ S
    6 ~4 M* @+ d% y2 s: l% h1. 导入工具包
    ! O2 \" P5 c! Z$ A7 U7 \import os  ^* q# {" v  i) g3 Q) L
    import matplotlib.pyplot as plt, b3 J9 @) V$ O
    # 内嵌入绘图简去show的句柄& J/ J0 `2 r) D- p8 U) }1 ?% c7 I
    %matplotlib inline
      g) v! L1 a* W- j. O1 Qimport numpy as np
    - N8 Y9 S1 r& R+ n' X; `8 g6 s  Q2 nimport torch0 T, Q& A) n5 B* m/ L: W4 P* R. n9 j& R5 @
    from torch import nn8 E1 ~0 J: g: R  b) ]9 W

    5 y" m) W) P# L( O4 Vimport torch.optim as optim) A% B" E: l$ _' D
    import torchvision! x( F' l& I4 s/ E2 A2 V
    from torchvision import transforms, models, datasets9 D2 ^+ K3 U5 s! q
    / A1 q6 X& d' |0 M. R; C$ h$ S( A) G
    import imageio
    5 Z2 \. a! `" w- l' ?3 l# H% d5 y8 cimport time
    ( }. I1 \' t% O, {* }import warnings
    5 r$ D0 K( B. `3 oimport random
    ! V# O9 E) v2 z. cimport sys: L" Z$ g/ s9 E2 b9 z' h6 c
    import copy
    ) O' \  Z+ F9 A+ Jimport json( ^! v3 B- h1 E3 r9 a
    from PIL import Image
    5 [  g' o, f4 _# _1 n. v9 z3 r# c  Q: X

    3 X3 V% B  M9 A/ W% D8 E6 [1, ^. c" K6 n% a% ~0 [
    25 r+ M7 l' x, X* C% L9 d- p
    36 a" N' h1 e$ K1 Y. J4 X
    4' K9 a! t* S4 k: h5 U* a4 [1 X) A
    5( }) v$ @/ E5 H/ u+ I% b+ s) o* u
    67 [! ?, h) m$ z& A% f" K+ i& c
    7% f' U5 b! p3 d9 c. h* X: U
    8
    # v! H1 S4 z! Z2 J9& h; ^7 p- H& P- ^' j0 g
    10
    0 a* _. i0 W) j- f11/ X! _# K1 u7 I7 y& n
    12
    5 |# Z0 @5 }7 ]  t13
    , }4 [& J! Y5 ]. P4 G14
    / W2 g, c* X) U. e9 e2 q$ N150 K/ j! a. Y/ \1 E" C* k' r" [
    16' H8 L" l; c4 t2 t
    17
    # V+ W! a2 W6 Y; Z8 Y0 p8 e18
    8 ]4 m; y% {4 D6 C8 a& G19. v) ?& J7 _6 J
    20  V- V5 a/ E* e9 N6 w0 J
    21
    9 ^- C9 U8 r7 J* y4 j7 N  b2. 数据预处理与操作
    / H* Y! u& y' }. W+ r#路径设置4 I9 I  R# v! j! Q
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录* @  P* h/ e8 M9 z# ^
    train_dir = data_dir + '/train'
    1 N. u9 N+ a- E1 x- Pvalid_dir = data_dir + '/valid'
    ) `. B' U6 Y% {. U' Y1
    ) J2 q" a. y2 g2
    : j  r- Q1 x: l' l: _0 Z3
    , g+ P+ Q& t) P4& n) _/ [7 I+ u) a# A
    python目录点杠的组合与区别0 S& ^1 `6 K. |* C" b" g
    注: 里面注明了点杠和斜杠的操作5 ^- C' Q* N. @4 Q
    ! B0 r1 o9 `# O! Z5 ~
    3. 制作好数据源
    & f/ \1 X# L, t3 `! [/ h$ adata_transforms中制定了所有图像预处理的操作- b7 ~5 c5 I, E1 X0 q! ?
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
      O3 y" n0 y6 P: U6 z0 cdata_transforms = {
    4 T  D9 k) p* z; {( Q' c    # 分成两部分,一部分是训练
    9 ]: l* I3 w, B* _: A( z    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间3 D* v- z! j6 h5 r2 r
                                     transforms.CenterCrop(224), # 从中心处开始裁剪
    4 X3 C& h4 s' E- f; f6 ~* w. j                                 # 以某个随机的概率决定是否翻转 55开1 v! y- J' Y" Y  W% y
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转+ B* c3 n9 Q+ u8 }9 G! {- l
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    7 w+ P# p/ n" H0 L0 A8 B; g                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相/ L( w9 p: {4 V- m$ z
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
    6 X$ T' m# t, e8 ?                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    1 F1 a$ |* d  _! h; a                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    7 m; j, V( g; E" D                                 transforms.ToTensor(),% Y$ n0 E- p6 f3 J& n
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差; g1 C9 G( O# Z  [  K
                                    ]),
    : f& A6 @' o+ `/ V8 o# S9 A! U& e, |    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化3 |5 x' c( `9 D2 z" d0 C: {
        'valid': transforms.Compose([transforms.Resize(256),
    ) _. S) d8 T8 G! m( ?6 A                                 transforms.CenterCrop(224),2 p: V# Q( B# G, G* a0 Z
                                     transforms.ToTensor(),1 y# Y0 m, n1 }
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    8 U7 k) q8 c  _" ^0 ?                                ]),
    3 G% d3 [5 Z2 D# g}
    # {7 r) w9 c7 X* L3 e
      t4 V' I. q2 S1% q7 C% k5 p4 x% N
    2
    ( W1 [" r6 K) ~1 r3
    9 A& z' r( r. f- t4
    ( ]; b9 T# q, N5 F% k7 Q$ ^4 Z' ]5
    ' Q, j* t4 i' Z, c/ V6% X% A& U, R! t; e
    7
    ) K) f, s2 d0 y" N. K8
    , f. d8 f! r: q8 Z90 H1 h# u3 \/ J- u/ \9 A; O# e
    10% q9 W( K& K) A8 T4 Y$ ?
    113 y* r; }8 N. n' y0 u5 g) I7 O
    129 @0 c2 U3 n+ Z  V2 X& c4 l4 t1 z
    130 }$ y4 }3 I! j0 Y' z8 C# A$ P
    14# E; s" v+ L, Y" k# N+ e
    15' c1 K0 }+ I" T4 N# R% k
    16# o0 k4 [5 F* A" l: j6 b
    175 R* w( x# h# ]4 ~/ L9 Z
    18
    ( Y! a& ]* j4 b. w, G19
    0 {) k# [; e+ h6 u" |+ ?) P3 `( D: d20( g  ~* q# r( T7 r* Z) ^) ]
    21
    7 I" d. k6 Y: O& b0 g! Dbatch_size = 86 d3 _3 s9 S: {, V' w+ `+ [
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}3 i. q. P. ^/ e1 `3 Z! D
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    5 S7 s& g* ^1 P5 m0 ?) [. Ydataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    + T( y. `/ P6 {7 F! j# Nclass_names = image_datasets['train'].classes) V0 ?# _% h! P
    ( _, e1 f% z3 `( w0 z3 F/ c$ b
    #查看数据集合- O2 R7 o1 s. i( o; v& M9 j
    image_datasets
    # c8 l' v* s# ?
    & L( P0 o3 ]- z7 s6 @" J1( W, J' Q9 {" w0 u& N/ z8 o9 H  g
    2
    5 c# b, j/ K" Q* k+ h3
    8 H& K& P; Q' ^1 T1 [( {% o4 U, j4
    7 y3 `- F* N6 B2 S, g, m# U53 w4 q  ~' w; z3 }8 A( {8 a0 j
    6
    4 k6 f/ n" t, A$ f! u9 O78 I2 M# n( e/ n5 J+ A& w2 }, _! v
    8
    + ]& D+ Z  a1 w! f9
    : z" d# }/ ^. k) T) c{'train': Dataset ImageFolder: K) f4 J$ A1 z
         Number of datapoints: 65526 v# A/ e- E1 q5 g
         Root location: ./flower_data/train
    ; w, e2 `+ ^; v, i9 F     StandardTransform
    ! d! W  ?( ^8 s; U+ m: i Transform: Compose(! I7 F2 n+ \' x
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)4 w6 Q6 c4 y' x3 m7 c$ K% T' Z
                    CenterCrop(size=(224, 224))
    8 h" D% c( V! ~* E( Y2 B                RandomHorizontalFlip(p=0.5)
      Q6 A; O: ]" V" m+ p6 R                RandomVerticalFlip(p=0.5)  ?+ s7 ?+ a; \" S+ U" }
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])# d* ]% H7 H/ k) Z" ^$ d! _. n
                    RandomGrayscale(p=0.025)
    2 k# m4 f# o* i7 [                ToTensor()
    3 h0 v9 ?% B! w  L                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])/ b- d' V$ S* D3 Q) ?" K+ j* d2 g
                ),& {, R( ?" ^5 ?$ P- u$ {* b
    'valid': Dataset ImageFolder% s5 L& G: r3 T* P: r4 r* Q
         Number of datapoints: 818
    / }' w. M6 A* i- n3 h/ o     Root location: ./flower_data/valid
    3 \2 ?6 o5 w5 K2 S" {4 h4 _3 v3 U     StandardTransform
    & n" s, Q2 [4 H* b& P- d1 x5 D Transform: Compose(
    ' K% ~7 L9 F) i, Y) |  W) C                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    " [2 T% n, K: U2 v# c5 D% J$ C+ ]                CenterCrop(size=(224, 224))  V- l5 F" u$ ?! y7 u: K0 [! D; e5 ?- H
                    ToTensor()' d9 J. y& O7 w# k9 }. \
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])8 p8 M& y, ^/ o- F
                )}
    % @- c7 x  I" V5 U) M* q# E+ |) t& e0 e
    1
    " l# ^4 p( t( L7 P" K# f( G9 T8 q, m0 u2
    ' e+ O2 C$ t  D3 A. r, V36 W1 B" i' B" r- z5 p' K$ Z% ^
    47 N9 Q- x) S: [% K; ]
    5: @: _3 `0 K' n) a" L7 a. W
    6* |# ^6 J- a3 p& s' l! C1 j6 K7 {
    73 }- V6 [' g3 T5 X( I
    8
    + G' \: s' F  f6 `9
    / x7 Z% G6 O8 u2 M- m9 |& T1 F/ D10
    0 A4 x3 Y, U/ r7 }" }% V11
    5 z9 Q7 \9 y  M$ _122 s& S# O5 m3 L5 C6 g4 `, y
    13
    9 R- C: y+ X# R14
    7 I8 o: G# w: F; N6 V  r& Q. ?15/ D6 v/ x7 t: h4 i1 o# r( S4 k
    16
    / j8 W) j1 z4 a$ o# y$ r1 F17
    / ]/ I1 l" w& ^. ]+ l8 [18$ p3 B7 ?# p4 g
    197 P5 w1 y% g+ j7 z$ Q) X! c8 Z
    20- m7 P6 F, p. B) |% V7 a# Q7 Y
    21
    . u7 e7 O1 j' G' c( X* o22. s4 e# O- ?1 E# F7 T
    23- h  D& u) |1 H. t
    24
    & v8 W+ q  V4 p# 验证一下数据是否已经被处理完毕
    5 a1 _5 k% }' [8 J% ]dataloaders3 P  P1 t, u+ _, |1 d
    1
    " |5 ~9 n% D( i. H, e5 j2
    , M/ K8 |; Y' e; y: N. z6 }& C( t{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,& D. i5 a+ b* B1 L
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    - B/ x  B* S) A- ]# u; `/ p( X2 k1; J/ V5 ^5 E/ N* v3 J
    2
    ; g) d& J4 _# p  Bdataset_sizes
    - A6 ~" ]7 b& v' }6 `1
    ; ^) \- n8 Q5 a{'train': 6552, 'valid': 818}3 J: c! r1 \* ^& P
    1
    / A! z' x  L% K读取标签对应的实际名字
    : X, P2 t  S  ?. K使用同一目录下的json文件,反向映射出花对应的名字1 a6 j2 ]. r; U% ~; F2 h
    ( A4 t& B8 N5 {/ q: a
    with open('./flower_data/cat_to_name.json', 'r') as f:" j6 Y5 i6 k# @: L( g5 O' m
        cat_to_name = json.load(f)
    9 _9 _3 v- J- s3 ?: `" f, V1
    4 |* u1 E/ l7 ]/ h- f2
    # ~4 z. G2 }: w  _* Scat_to_name
    ) J2 T5 ~/ D( {1& Q& w4 @2 o, P( r
    {'21': 'fire lily',
    . Z, O8 L0 H- ^# S7 ?4 r" S9 e6 f '3': 'canterbury bells',
    $ i2 s& e  H) D" {' w8 k8 |+ { '45': 'bolero deep blue',
    ( C4 E( C, x& E- E9 ]# E  y '1': 'pink primrose',. z* I% @9 |. C2 F& R" O
    '34': 'mexican aster',' B7 p# V' [7 i0 m* k- d/ D
    '27': 'prince of wales feathers',
    . Q% c* D) r# {$ T '7': 'moon orchid',8 S9 v4 [+ t) R" l/ A
    '16': 'globe-flower',
    2 V, E" P* Q) t7 I '25': 'grape hyacinth',
    % v+ _' B7 I8 A7 \% N4 M: O '26': 'corn poppy',
    8 Z5 E2 Z+ Y9 L$ Y0 K( ^( v '79': 'toad lily',4 E* `2 G  T. O1 U% _
    '39': 'siam tulip',5 U# K+ I& N6 n4 I1 U% U8 o
    '24': 'red ginger',
    3 d& _4 A4 I9 N '67': 'spring crocus',. [2 O4 S9 \# S3 s; Q0 S9 M( m
    '35': 'alpine sea holly',1 K# Y% z2 _0 h) x2 s& T$ ^
    '32': 'garden phlox',
    # a  r. d& N; v- K" C2 l '10': 'globe thistle',
    ( v  u( s$ J2 e( n '6': 'tiger lily',1 ^3 U6 f' A' F4 @
    '93': 'ball moss',9 i9 p8 X9 s. s1 o4 C4 x3 s, {
    '33': 'love in the mist',
    ! l9 r9 `# ]0 j6 s& s" W' b '9': 'monkshood',1 u; c0 V' C# n) g" N8 ]- v
    '102': 'blackberry lily',4 H( F& k2 E% G& ^( X+ p
    '14': 'spear thistle',! B+ b& O7 P) Z$ B- E/ \7 H
    '19': 'balloon flower',
    1 b3 U+ Z' Y; J7 R" b4 Q '100': 'blanket flower',
    - f9 A7 \: C( Y+ I6 r+ r  S1 S! R '13': 'king protea',6 S0 ?# w, N7 I' R( P5 `
    '49': 'oxeye daisy',0 T* r% ~5 K! w; h
    '15': 'yellow iris',
    0 T4 B! M0 ?$ Z) K1 m: e, a '61': 'cautleya spicata',$ D* x7 s/ z! U- F. s
    '31': 'carnation',4 f. @( \. J8 {6 [' @, a* c
    '64': 'silverbush',
    : W2 l6 d# a( i+ x8 S '68': 'bearded iris',7 G! y& n! k5 Y6 l: Q/ y, _
    '63': 'black-eyed susan',
    7 Q" U3 Q+ u. |- G8 K. F '69': 'windflower',
    7 Q- i) R! [- {+ ^" U0 @ '62': 'japanese anemone',& G0 z3 w8 m* u" A4 u6 k
    '20': 'giant white arum lily',  \/ A6 h, R8 b/ l% f& v* w
    '38': 'great masterwort',1 @& d) D: R" [4 O; h- ]9 ?
    '4': 'sweet pea',
    ; k; x  Y" M( N2 a- l# Z '86': 'tree mallow',% e$ u* [( {9 [0 Q7 _9 b, B
    '101': 'trumpet creeper',
    / u) n0 z. B) s: k '42': 'daffodil',9 A9 [' a. f# }' x/ L$ W0 T( O
    '22': 'pincushion flower',: o/ X; D* W7 z- g4 o
    '2': 'hard-leaved pocket orchid',, G* l6 \: l( h; x! L
    '54': 'sunflower',7 o7 N) j: s" Y7 i8 C
    '66': 'osteospermum',6 K  H1 n, ?& T8 D: z1 v; {  y
    '70': 'tree poppy',( y& c2 E4 h% w# @- ]! u5 @' h
    '85': 'desert-rose',
    / e2 A3 N3 a& j5 [ '99': 'bromelia',
    - ]5 _: r9 b5 e# U- X2 V! r '87': 'magnolia',: |- D! B6 O( A" w' z. Z4 H. Z* t
    '5': 'english marigold',
    / S* O% u' I% O0 F2 }& G '92': 'bee balm',
    ( y3 O& g! B. S6 s/ f '28': 'stemless gentian',8 e7 x. C  \% w! j& O' y* y7 i
    '97': 'mallow',, g7 O! o* n/ W  s8 X2 m
    '57': 'gaura',
    ( y) M& k  d5 o, N& x# `+ P- D" h '40': 'lenten rose',% H" F% U) {+ e* M
    '47': 'marigold',; Z# j3 m1 A: z( W& _3 H( |
    '59': 'orange dahlia',9 n6 X. z, m( |  e' |1 h
    '48': 'buttercup',4 O* v  U4 D' ?" @+ P4 P) N3 L
    '55': 'pelargonium',
    7 q3 m6 ?1 q6 g. {  k1 b! n% w '36': 'ruby-lipped cattleya',) y+ ~5 Z) }- n, a# @
    '91': 'hippeastrum',; e) P, D' Q; A$ Z* k) N8 t: E
    '29': 'artichoke',6 s8 g0 a7 y/ A1 i  y5 s% t' ?2 ?
    '71': 'gazania',
    ' _. v) j) W; e4 L7 S& L" l4 l '90': 'canna lily',
    * o9 G" G1 N3 f  z% F) u- v8 | '18': 'peruvian lily',$ n, Y: g  m* j3 k9 [6 I7 J
    '98': 'mexican petunia',
    6 V  {7 O( ^; _" A& @ '8': 'bird of paradise',
    $ K- @' |2 j# k% B* D1 c% U( q3 m '30': 'sweet william',3 p$ w2 B% n0 x9 I* z5 y
    '17': 'purple coneflower',
    0 u5 c# `1 b6 i '52': 'wild pansy',
    8 i7 j. t3 o5 @1 n$ y3 \ '84': 'columbine',
    # x, B0 [) g' J '12': "colt's foot",
    - D" p' E) n5 @4 Q '11': 'snapdragon',
    / ~3 p2 x! r6 J7 k. M2 p '96': 'camellia',
    , U# m/ J" H3 E7 h2 L: j) g '23': 'fritillary',% |; h; Q( _0 n- N+ q1 f0 z) |
    '50': 'common dandelion',
    ( P/ k7 k& S7 I  [ '44': 'poinsettia',7 x- |9 K' q2 z0 u3 c
    '53': 'primula',- U3 \( K% q0 t( F" g, [
    '72': 'azalea',
    7 T1 y& s* F; s8 ~1 o, r '65': 'californian poppy',
    / r" b  z) o8 U' J9 C$ F '80': 'anthurium',$ [, T/ I9 S3 {6 e- p, y+ B, N+ c2 t
    '76': 'morning glory',
    8 I& [9 I" H' ]% p2 S- q9 C' ~ '37': 'cape flower',
      q5 t( ]# N1 M, @' k; z; {3 U '56': 'bishop of llandaff',
    " X) u- S. @9 X5 S6 c' s9 \+ t7 y" b '60': 'pink-yellow dahlia',
    . q+ g4 i0 _7 K' y' ]- ] '82': 'clematis',& P4 P, c  @3 p& N/ x: K
    '58': 'geranium',+ h% K1 P5 S& u5 n% C* P" B- K
    '75': 'thorn apple',
    0 H; U+ E3 H- w6 W7 y5 y) ]& d '41': 'barbeton daisy',
    4 s* I$ g  |! [6 z+ f& Y$ ^/ @ '95': 'bougainvillea',
    ' A1 W5 b' X- {4 Q* ~ '43': 'sword lily',( F& w* A7 v/ j# S6 T
    '83': 'hibiscus',
    ; @& ~) D+ l# s/ `5 T7 v; I4 B '78': 'lotus lotus',$ i* I" E% V6 a  x, Y4 F
    '88': 'cyclamen',
    ) R0 i6 k) O7 k7 A+ \- N '94': 'foxglove',
    + v' y1 T% L7 j3 S8 ~8 p3 }, j '81': 'frangipani',2 x+ i" C- q, N( y3 H# G
    '74': 'rose',4 L0 b+ }3 c+ J! @
    '89': 'watercress',. T7 R- v; M- |0 [0 G# d/ @
    '73': 'water lily',
    ' Y: C5 a2 r0 @, m9 c; }, o '46': 'wallflower',4 Q$ J% h+ N0 @( P# U. Z4 Q
    '77': 'passion flower'," C8 [! k% P* p$ |0 S8 V2 y" B3 |
    '51': 'petunia'}% _: G: f8 M0 ~6 l- M

    ! @1 r; }" s7 S) n" q7 [1$ O  r1 V7 @: n
    2
    $ @% H9 r$ p/ D3 p) m3
    & u) ?3 e! Z2 R' |4
    9 ?3 ?/ a2 R; S1 s! ]* j% ?7 c5: O1 `5 V5 Q# h. F. q
    6
    7 c- O( u- }' ?/ l1 o1 i5 t7
    , J6 `1 u0 N/ O8 M9 t8
    # Y& a9 |8 Y( Y$ `: V  i8 H9
    5 [( B6 I% G: T10
    4 O7 E* V6 B& Y8 U4 t110 e2 V: \6 p4 t1 K8 Z
    12
    % U* C2 E0 V2 k3 ~  Y5 m+ w13
    4 l( \2 s5 P3 i( q% P14
      ?1 N) V8 C: c, w15
    " U! _0 b) k8 z* q16; A& z& C1 u: Z
    17+ [5 |8 T$ Z2 v0 [0 g, x+ G
    18; B3 O/ n6 _4 q, q4 a/ y3 D
    19
    # g: G6 h* M* F/ r7 ~0 e20
    + G: r! O6 o) v3 B* k21
    ; U7 m4 E- X( ?22% m( t: d* s* |9 `# T, X
    23
    0 W) i  M/ h/ Y24
    " X8 |' G9 M9 c) |- Y4 }1 D25* Z5 P0 l! u, z# x# {! e4 v2 q: O
    262 l" Q  ~  @0 ~
    27' H& L6 b& o. z
    28% w: E* n/ C. _8 T
    29& i/ {8 T* g& g+ G  [
    30. N) W6 |- O9 S, v, R, n
    31
    4 G  _5 Q" M- w$ r: s7 W( S32, Z6 q- T: T. B; l  N& B" }, }* I6 \
    33. k( K0 v- U8 r9 v$ Y
    34
    2 B, W8 W( e4 v. p6 ^35
    : q& B: M! ?8 n) y8 [+ c& u36
    : {* B) \3 M0 C/ U37
    0 k3 D3 Z2 |5 i2 P382 i/ y5 `- j3 R( S
    39
    6 w3 t, }% |& h( H40( n! D/ j2 o1 i; @* B+ e( t* W
    41, B) ]- L' ~* y
    42
    3 r5 \! Q; D7 |) ]+ E, L430 c" p( T/ f6 t4 X! x7 Y% h- Z
    44
    : l8 h/ W* F- x: B7 f0 }45' q2 v' p7 ]4 U1 {
    46/ w! t8 l6 G) C
    47  J# ]9 |0 D, T* N/ K
    48* k% X& d. J' h
    49& J* t( k! L- T+ }! y0 ?
    50$ b4 |6 l: Z- w. @5 l, O! O, W
    517 L4 Q) G8 i; d9 o  s
    52* C! b% h) N; n: K& {( Q; R
    53
    2 E' O  N" W& Z: l* T54
    ( }8 J. v' {/ r55
      V9 |" m/ E$ y- Q! f" `56
    2 p* u3 ^0 B& V- i573 X- t* i0 b( [' r! j
    58
    1 w! S7 I2 {8 i% O2 h59
    : `4 P2 f! b! E  x* l& m9 F- o60
    : |( {: K7 F3 I! {+ z61' e8 b+ a, X0 a) s6 w
    62
    . Q. r$ k1 l3 R  L635 k" G9 R4 D  w- i0 h# Q$ C) e
    64# m2 Y0 ~. u7 i' p! H/ p
    65
    0 [' A$ ]& M4 x66
    " D3 Q; v& l" f; U$ {& Z67
    7 u# T4 W6 B  x1 a5 y1 L68
      t  j% b, `% I( o+ g# n# {1 s69
    ! A: D# P: T# J4 y( X70
    5 m, X$ X  C2 X( ^* [0 L9 W* E71
    ; {1 F" P) A4 H2 a72" i& j- F7 O1 h. C! e. O
    736 u# I; v$ T0 q. [
    74/ i& M9 \9 }: {, d- p
    75
    6 Z- k7 j! g6 `76
      ?3 X7 G5 @0 K% k" T; M77( O6 K9 ?" A- m: e6 ~
    78
    , @# |$ V. u2 K& V( N  j79& s& ^3 ]) A7 q, \. P( j
    80
    5 ~9 O8 M' M3 Y8 B) B  E81% {" a7 w& ?3 |/ Y3 }, \6 @
    824 Z' T) }+ b$ b8 S0 p) f
    83
    . U* A/ ]4 R4 W2 _2 X+ y84' _9 k6 \% E7 _. [% z
    85
    ) U" p+ F4 d- R6 G* ]0 e3 v# C! `86  z! M, K7 W5 s9 V5 O( a
    87. t; L9 ]7 G( W2 m! s' Y) W. e6 M6 C
    88
    ; r6 @; T, t: s. m" D* k89
    , M: Y3 {3 t& ?4 {- @) H900 E9 |" o! u0 J! g. W
    916 R5 x) c2 A$ I1 d! y4 C
    924 h# F8 M/ S7 W+ S( g
    93. t9 }" {1 _1 H9 U" T2 N, f
    940 t* |3 x5 R% Q/ l
    95
    : C+ V3 o* W6 |- |% E3 h. K0 C964 a: i& [! j  t9 M3 C
    97
    , }3 B1 h) n' Y# U987 d6 ?3 q5 d, B: O* W. b
    99
    5 U- o1 {# X# k0 s  k. V100
    3 ?" s1 h, q, W) [2 M  k  X" R) ?101
    0 [, s/ n% H7 o# W* ^& j102
    # w9 W1 N# N& H; |9 R' p4.展示一下数据4 ~9 N8 G) t  z3 u  R
    def im_convert(tensor):; Z; t/ t6 m! `  G+ j! Q) R! N/ `7 S
        """数据展示"""
    0 Z, A3 \1 l! Q- X' z    image = tensor.to("cpu").clone().detach()
    ; \+ N+ P5 R8 J1 I    image = image.numpy().squeeze()/ r, ?: d) l3 ?  y
        # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    - y' w3 U7 ~7 X    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    * [; [( V  V+ ]# u' C    image = image.transpose(1, 2, 0); I" K: @0 _+ g# s
        # 反正则化(反标准化)
    & ?" ]$ v5 u( ]: `+ @' Z    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))& D' y: f, U; \/ I& `" w1 I
    5 \* l3 K7 l  G/ g2 n' G; n
        # 将图像中小于0 的都换成0,大于的都变成1
    / D8 ]. n5 h4 u% W. @    image = image.clip(0, 1)/ ~8 p3 L: A3 v/ O' ]6 O4 w' _% F
    4 C6 Y( F* m8 O" G4 _6 ^! g
        return image
    ' Q! h& L+ f" x7 Q: a5 f) h1
    ' d0 n& W# q' Q! T2 x+ d* m4 b8 k- I8 V2
    - _9 X! ?  e7 l) h) v3
      H4 r. j% S+ }& \8 y$ c7 Q40 ^; ^! O4 q, G6 i* J" r& [
    5
    7 p1 a- J. w; W1 U& T* m6
    % T; S4 t) Z3 B% c3 Z, i7' R- @- V2 T  @; c& z  g! T# Z! S
    8/ x4 E! P: F% _
    9$ O* V/ ]  \" y, P* Y  e
    104 u1 k9 f9 b& `
    11" `  _2 p. m) L# w8 T& d  C
    12
    * P1 I) d" a* k1 G13
    - _( W' w% @$ P7 O8 ?$ s14
    " Q0 `: [+ P$ m# 使用上面定义好的类进行画图
    ; Y1 U$ s& D* p! }) Efig = plt.figure(figsize = (20, 12))# }0 q+ K9 s! N5 r
    columns = 49 n$ C9 o* W! l, e% W4 g9 b
    rows = 2
    6 u' k/ `4 ?* W/ ~; F, ^6 R, p! d0 @& T4 b+ {# l! @' Q/ b% X
    # iter迭代器5 C. `' e; M3 {: E; R
    # 随便找一个Batch数据进行展示: Q% o6 u# e$ L7 r
    dataiter = iter(dataloaders['valid'])% P* b0 W/ j6 u8 n( L9 ~, w
    inputs, classes = dataiter.next()
    # h9 ?/ X+ T: H1 w! H0 K* c8 z  U" f2 _7 g7 B$ e9 x; H9 I4 ~
    for idx in range(columns * rows):
    5 U# y  _, p. w) v% K7 j  ]" H    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    5 l: d$ _3 u/ y4 {8 L: k    # 利用json文件将其对应花的类型打印在图片中
    8 k% f, A6 ?" t. m    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    6 Y# e- o$ O' g; h0 M    plt.imshow(im_convert(inputs[idx]))
    * ^' B- C2 ?: M$ r8 Dplt.show()
    & p$ C5 C" K. J& o0 A- V0 c) j. j5 @
    % V* A1 c3 C4 c* e4 I8 c* B; S# D/ @' Q10 }1 r  K3 h$ j* e( O
    2
    7 ~5 ?+ D: o0 J# u+ E) l7 I) V3 s3. J5 a1 |* L0 L" o& _
    4
    " t- k2 S8 ~7 z" x52 [6 H4 c- T  v* i& b
    6
    . L/ d/ G0 U1 j3 @" y! Q6 Z7% v! _* W# e6 b: N* `5 ?; P
    8+ W. V! I" ~% {; x) i' h- `  o
    9
    ( B% c" U" b& C( J* K9 p( d109 {! w2 a5 N% ^8 ^, x% K& ~1 Y
    110 r' C7 G) @$ C6 A: t* r6 j( M/ N% U
    12
    + j- i2 Y, z! \: V- ~13# |) {' i' {5 w* Z
    14
    : H: i  k- \) B  [) b" D15
    4 o5 O4 _# J" f/ E3 k4 x16
    ' P! k7 H4 o' N3 o  ~) a7 g% d% f' f2 P/ B

    * L" l  h' {% R5. 加载models提供的模型,并直接用训练好的权重做初始化参数, B6 g0 e% G) q/ q7 j7 Z( |
    model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    % m" p/ R# v5 A# n7 u# E* b# 主要的图像识别用resnet来做
    ! c4 d6 `$ ?/ j2 `# 是否用人家训练好的特征1 f6 N7 P, ~$ I, V. U5 _
    feature_extract = True& A* \) T  e  l& y
    1
    - e4 r) q( @# Y6 G2. H2 Y) h2 _7 A9 a
    3
    $ p! K  V% h2 Q& c6 }0 R9 \4
    3 c* q! ^  Q5 h- I: A3 r# 是否用GPU进行训练
    & I- m! w+ B+ L; T8 Ftrain_on_gpu = torch.cuda.is_available()# |: K! G3 z: O8 n

    ( T" L# L. H9 _6 rif not train_on_gpu:* q. H, w# ?% v* r4 H
        print('CUDA is not available.   Training on CPU ...')) `% {5 V$ E1 f0 U, x; e
    else:
    $ h: i$ @! b5 W+ F  V- t/ R/ D) w2 }    print('CUDA is available! Training on GPU ...')0 ?) |6 A' F! Z- N/ O
      m. H; u3 `- h: M( @  r
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')8 r' P* w3 [1 @
    1" C/ V" c& q" P% D
    21 |, M$ _( I+ I0 _: d( X
    3" d: J7 C: u2 W" \& u7 W- _$ i
    4& N2 X/ \. W3 b9 {
    5
    ( J' H% E7 O5 B' z; ]) q67 n  ^" f: z" g' T9 e: C- f
    75 W$ V: ^& E/ c9 {+ H
    84 f3 O/ ]; p& o
    9
    " n: {1 x8 x6 i8 |& }# R+ e2 ?CUDA is not available.   Training on CPU ...% g" \6 V/ Y8 P" R, X
    1
    ) j+ s; @9 Y2 B* X# 将一些层定义为false,使其不自动更新7 ~5 |' \  T! }1 l8 g9 u4 @1 I
    def set_parameter_requires_grad(model, feature_extracting):" f$ t2 b  m# o  U0 g( D
        if feature_extracting:" m" F% j- _, l0 T
            for param in model.parameters():# b/ M. C" h- Q: _
                param.requires_grad = False0 B* g9 [" o% f$ N. y
    1; L3 r/ q+ M" w2 k9 t
    2
    , l. |3 u; G0 T6 b$ V3- }. b, @* Y- M
    4" K5 P' f8 S, u9 z" L9 g5 J+ O
    5" F0 `, d3 q% x& T+ {) S) d$ C  }
    # 打印模型架构告知是怎么一步一步去完成的
    ' E, F! V5 \! B( V6 w: f( }+ I' }# 主要是为我们提取特征的
    + S8 E* F7 c( F" t7 F, c% U3 i; i+ ?0 G) S9 f) u8 n; E% G
    model_ft = models.resnet152()
    , E1 c7 D+ Z$ j- C. z, u" Q- lmodel_ft
    * b0 f" A+ K' e& {1' `$ W( X- e7 I' P
    2
    9 F% V! W' X9 p2 m/ F3
    ' O1 N8 `- H7 v% ]; M4
    , G3 g4 [; c3 K) X7 {" b; E# X$ w% q5
    - w* `4 a, _# w; B3 p( M2 |ResNet(7 W9 O7 I" q& D, x0 V: s$ L
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    ) t& `! ~4 A+ @  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)/ }: ^; y% G* F7 I, Z+ b& V9 f
      (relu): ReLU(inplace=True)
    " O4 k' U$ z5 d" U  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)" d5 l- j3 K; q
      (layer1): Sequential(, @# U* F3 m, h4 T5 I
        (0): Bottleneck(' {* B; I2 x; U6 w
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)" u- R# T; ~8 E4 e. H0 j
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    0 ^9 w. D3 S4 S& K. F7 E      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)1 c# q% l8 O; H# Q
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    2 V) @9 q$ S( M- W; p( n6 @5 d0 o0 J8 x      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)5 k, ?7 ]+ `+ x, p5 v. {/ C, x- |* a
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)& u  T  N# d& d+ F
          (relu): ReLU(inplace=True)$ q# ]1 j: N# n
          (downsample): Sequential(
    8 U! ?2 w9 f4 L# r# t        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)1 h6 \4 ?1 i2 U9 x1 a/ c
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    5 y" F" G1 V( ^$ s7 N      )! d0 e4 b7 x; q% B( e
        ). ^/ w* F3 w# ^/ m5 Q  B
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。% ~- c) [. |0 m5 X' s) v6 J5 M2 A
        (2): Bottleneck(
    + u- w" e% K( d7 i( j6 J1 i1 m! Y1 K      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False): o/ Q4 z/ \/ }# ]: h# M/ u  F
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)9 s& U9 [* N* O' q) s+ ~1 m% N" G
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    - Q) n" f/ U9 C- Z' P1 D8 R8 }8 e      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    4 Z1 |* T5 f; z% E% d+ i      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)1 Q) b9 @: ]2 v. Z! B: [
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    - _' y) f3 M& ^$ r% N6 |      (relu): ReLU(inplace=True)4 D% ]# g  m! e+ x/ F' @
        )9 `3 _' ^# F5 z' B, e/ b+ F
      )% C: {4 e5 O. C3 {6 m& e
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    ( G6 w9 L1 S9 g  X9 P, I! i  (fc): Linear(in_features=2048, out_features=1000, bias=True)
    1 @/ a+ ~% x: q' Y4 v+ l4 I6 x)& Q+ K3 _: m9 H2 Z
    9 V, ]# j4 T6 ]$ _
    1
    ) Z' ?, F- X1 P, ^* {/ @2
    0 B; R" K+ J7 N" j% S3/ q6 N2 c# `8 Q5 ^) t# B1 A3 @) k
    4
    ' v; B5 a3 \: E' S7 N58 _# S/ Z) {7 `: C3 |! D- M
    6* [# x; j5 i  i( x/ V0 r
    7
    , d1 c1 k9 G- w7 |0 O8
    - ~( j% A3 ~3 |1 Y94 B9 X" H4 T8 N7 O$ |" Y2 G7 S$ P# X
    10& c; C/ U; q# ^( X( e) z5 m- \
    116 [& e! d7 h% f: x# T! g( ^
    129 |; ~8 E! V$ a7 D  A3 M( ~
    13
    $ U' J/ q9 w/ X$ B14$ K& I1 @/ a: `! \" Y1 c1 w& C
    150 U3 n1 T; U; }. N2 k; n
    16% u1 s$ e. I% ]8 q2 z5 `- ^+ \; j
    17
    / b  q0 Z8 s! E; n( m4 l18
    . G1 g; Q! h! b( e  U. `19$ S  W3 @& u+ ]7 y
    208 v" L0 `7 ^4 _" I  J/ e
    21
    9 V! ]) ?2 i2 y7 z& {22
    ' _* j# K4 s6 P, D. J0 F23
    ! h9 F$ J, v! Q& J24
    & Y6 T: k. {5 f; M1 ]) K) B25& \1 W9 a. M3 n- y" e
    26
    : [( T1 `* P% O' I3 W# ^- i8 J27
    & G1 {5 Y& k+ b; _0 N28, Z! d* D" m" d5 J8 t6 K  o7 N$ O
    29
    7 ^  y# W6 Q  T, R; g$ Q30( p. l! W6 V3 b4 b
    315 t5 {/ [! {- C, B- ~
    32
      P9 M: {: m0 y" E! R33
    6 i3 s# H7 p9 T9 \6 z最后是1000分类,2048输入,分为1000个分类2 }4 j7 }" I  E
    而我们需要将我们的任务进行调整,将1000分类改为102输出
    - }+ E) n$ S; r- }. [
    ; ]* Q$ E  i$ i6 g0 B% \  m6.初始化模型架构
    3 U- w- N" R, m) O) S7 k: ~6 c步骤如下:8 D3 ~) B5 m& j& \6 V" i
    " ^3 _* Y' p9 l0 H
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数" s% m; ~( Z# q7 T  l, u; y& _. o
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
    : y& S1 V! u( V' H无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数4 ?! K3 ~6 R( u# }2 X
    官方文档链接
    . Z: b- F' k+ U- {" thttps://pytorch.org/vision/stable/models.html
    * L- _. {" N7 S6 ?; Z2 j0 G: L# P; C0 g5 O
    # 将他人的模型加载进来7 G  J3 k; w. i! _8 ^
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    # u8 Z- i1 Q3 {9 q# X2 ]8 E8 h/ }    # 选择适合的模型,不同的模型初始化参数不同7 i7 Y! Z1 ~' ~6 ^$ A  g2 ]
        model_ft = None7 L: T. s. B6 F7 Z( G: W1 [! B
        input_size = 06 K- k5 g) `% m0 h; @5 Y* r

    7 J9 K& s$ V6 }4 l, o' z    if model_name == "resnet":9 H; u* l+ A1 r+ l- F
            """
    % [, r4 z) v( m6 I$ N5 X        Resnet152
    $ u7 S0 A1 u" m2 b5 a        """
    , i+ }& d8 b' X  I3 g( a
    % h% p$ ^9 `* t0 P0 G" a        # 1. 加载与训练网络3 G1 P/ b1 H4 I3 [6 s
            model_ft = models.resnet152(pretrained = use_pretrained)) ?3 ~/ |7 }) x8 L* p
            # 2. 是否将提取特征的模块冻住,只训练FC层
    % _: h' w. U) c" I        set_parameter_requires_grad(model_ft, feature_extract)& Q1 U9 L$ ]: U% r
            # 3. 获得全连接层输入特征
    $ l* z/ a, w$ J5 ?  A$ x        num_frts = model_ft.fc.in_features
    ( o4 k- M) C1 H0 x! _        # 4. 重新加载全连接层,设置输出102
    - A& q2 l, G7 b6 Q& |, A        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),- R/ Z, L7 W: d6 M
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    9 a5 }4 B$ T# {& m        input_size = 2245 E; Q6 T, @2 N  X( i

    9 K, P0 o; e7 H/ [2 v0 ^' L* f& C+ F    elif model_name == "alexnet":
    3 S- F' L6 S( O$ [  H7 v2 ~        """
    7 m# S9 d2 _" l# D        Alexnet* u- r% M' V2 f4 c+ e4 |4 {
            """
    5 w  X& X# r8 F7 H: v; \% Z" f        model_ft = models.alexnet(pretrained = use_pretrained)6 f% k: E5 M+ H+ g4 ~4 K
            set_parameter_requires_grad(model_ft, feature_extract)
    4 X* b6 P& ^$ e2 E
    1 z7 E& W! O6 z+ u        # 将最后一个特征输出替换 序号为【6】的分类器. N. I5 H0 Y! k
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入0 D) @+ G: z' I1 O- D5 l
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    ! h/ x4 K' @8 L  [4 w7 D        input_size = 224
    ! @4 @" u4 |" ?1 d( J: x; a5 D) C2 ?
        elif model_name == "vgg":" T. Q' g2 Q' U
            """$ q5 g( `, p2 h+ Q2 U3 J
            VGG11_bn% \6 x  _7 X6 q( k7 M: R
            """8 f$ ?) ^3 q- m6 ~2 `/ X
            model_ft = models.vgg16(pretrained = use_pretrained): A# V' e9 g9 A  ^3 O
            set_parameter_requires_grad(model_ft, feature_extract)
    7 n0 T6 J, |* S7 |! y# y$ p$ T: q        num_frts = model_ft.classifier[6].in_features
    5 C3 w& l  R9 g# j# e: ^; K        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)& x5 `9 k1 Q, N' h" f
            input_size = 224- G. w$ I0 J0 R

    0 P4 q! `. ~- v* |4 Q    elif model_name == "squeezenet":8 F) D: H7 h& T$ W
            """
    5 g; H8 X& C4 z) z        Squeezenet. m  L0 w1 `# V$ w
            """
    , k/ w2 v: Q$ C" i/ S1 m        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
    # w; b, j, k" e8 v- W9 C  g; X        set_parameter_requires_grad(model_ft, feature_extract)( Q4 S& A4 y" i$ D
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))* z) d" E3 K) S% u/ u0 _2 Z
            model_ft.num_classes = num_classes3 }+ G+ x% l8 ~8 C! d, t
            input_size = 2244 h* B% S; T* q, W( F- t" O# z
      n0 }7 M  p2 u0 X5 {* P
        elif model_name == "densenet":+ `7 `# p6 a$ E  a2 Y7 j: N. C
            """
    9 R" Q% K& B" a8 q6 \        Densenet
    0 X5 W) P9 Z# R' Z' E9 d$ `        """
    - o: q& x7 c6 q( n% l/ n        model_ft = models.desenet121(pretrained = use_pretrained)+ Y9 \* _5 f1 G3 n3 f  U" U% n
            set_parameter_requires_grad(model_ft, feature_extract): _+ d) H8 B2 Q7 i; }0 Y+ k7 p
            num_frts = model_ft.classifier.in_features6 W& P  N$ S+ o1 Z! n' v% h
            model_ft.classifier = nn.Linear(num_frts, num_classes)7 O# C6 X) U# k! w. J+ M
            input_size = 224
    3 X, O  i$ t0 n5 G. I
    + R( `7 E- ?, n3 G    elif model_name == "inception":
    0 L' C" `; K! x4 B5 ?$ g        """
      a# \: P; [0 d, h% y        Inception V3, R/ O' v8 @2 q, V+ O
            """
    - D) Y) z2 J" b2 e+ I' ^5 J7 e& I        model_ft = models.inception_V(pretrained = use_pretrained)
    2 l  g! u8 x7 F4 r# K        set_parameter_requires_grad(model_ft, feature_extract)
    1 e# j/ a! [, c) `+ R: u& j) }. U+ w. _2 E" n7 C
            num_frts = model_ft.AuxLogits.fc.in_features% ?: v' F! Q4 u7 y( N- N
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)2 s7 D3 k& U: R) U' S5 q# q1 l/ `3 C
    . M- n9 _2 M. M, y
            num_frts = model_ft.fc.in_features: @  v, f6 z: Z8 J. G
            model_ft.fc = nn.Linear(num_frts, num_classes)
    8 d" X# E2 q8 m% D; a        input_size = 299
    0 O" `0 v1 Z5 t' ^4 o" L. ~( x
    7 N. e" j  q" V0 \    else:! ]' b' @9 L7 i
            print("Invalid model name, exiting...")& Z, P" i/ y* E# F  y- l6 C
            exit()- o5 A0 c+ V) F" T$ R

    + s) Z& I; I0 |. w) L    return model_ft, input_size2 m4 k8 b4 P% k% G% u+ R

    4 l/ A3 S; [9 w& ~4 s# Q) `1
    ! D. s; s' ?! ]/ {$ P21 M7 N# F) i. b5 y" `$ s& C
    34 J- d7 u( [: L
    4
    0 Y; r, B  c2 e6 x6 h; G5
    7 [, X! n! i8 y7 j- a% I) `7 ^6- N+ W+ I( u. p. |+ V3 N& M; V! {
    7
    , \) ?" |( [8 E! K8- U+ H7 B, G+ U
    94 ^! e# o5 @+ }' ]3 Q
    10: t# ?& _& U4 _5 y4 i% v8 x
    11
    * n1 t4 L: q: G/ h; c12
    & L7 Z" c1 @* \: \  V. p/ o13* e' j: y. q# j0 w% N" Q$ J
    14( @; f# [% }4 a9 r3 u
    15
    6 a" N4 I$ w, y* ^1 F4 n16* f0 Q* j! v: l% h& L; g
    17
    + x' v7 k: _( h; L; J5 F18; I/ t  U. d0 H2 m; c- b7 S8 _
    19; d2 J, _5 {* P( b. |" G$ e
    20, ]( a' |& e7 ]
    21
    - n, d# U. a% v22, M& v& Y! B% x# T) p) M3 Q
    23
    ( y1 B4 `! Z. A, w3 L246 H" n9 v* t. O- W4 G) w5 k  B& M
    25. U( Z! j1 _% i5 @* i
    26
    . f" k7 S/ ]: n  N277 L8 A5 t% y+ X& }: T
    28
    $ p- j5 P, H1 W9 I* N29
    $ d. w- {. b  K4 |4 E/ A2 Z30
    + C. s1 `4 @$ D3 \- t, D31
    + g6 z) x) ]- d1 m& s+ z' V; \. z. _) O& ?32
    ' l7 m( |, ?* v2 E33% V5 |/ l" A6 ^) f- Z! @
    34" t9 l: j5 [& A1 P$ D
    35
    6 E1 o% _; r1 i364 E! F# E' f3 u' J. z0 W" \7 R* ~
    379 v! ^7 w" t6 b1 D7 K8 X9 \
    388 ^" g3 b  J. L8 j8 ^
    39
    ; S: m0 H% g4 t, N40
    # [5 E' M( R) t9 h2 }* @) f: m6 m41- s; w0 c5 n& {
    42
    - b6 r" K, ]9 V5 s4 _$ ~8 R43
    - |7 K* y/ t9 L- e: d8 O! W# R44
    ' y" c7 ?* C* O. s/ W45
    & z4 i# D+ r4 Q9 Y. [3 j46
    ; I3 p8 x- h/ U* D0 k( D47
    , U! j9 Q& I. m7 }) u48
    ' B, B/ P0 B% R- v) c, {: P( O7 r% z1 U49; k* N$ ^; j6 O: ^( v: V' b( _
    504 R9 I! E+ m# H: X, @8 T6 V' |) G
    51
    ! u. j7 z, }  c. v( t52  a* G9 u$ M1 H' T* {  @- u5 L
    536 E' ^$ c2 ]5 z$ g; D: n
    547 [  I6 Z& [* R/ [
    55
    * r* w+ k3 o! `0 _/ f  ~- o4 C1 g56
    2 D" {' L& J9 k" J/ N2 w57
    7 N7 \3 O$ Z: [  R) e1 P& C0 H( A588 W* p3 E5 ?+ {# t4 A9 s$ \4 [! a- H
    59
    # `- K1 A; h: l! I3 [, h, m! {60
    , V' m" [: s" q' {: r8 |9 V* j61
    ' y5 t# C2 f2 t& x- I: P7 }62
    6 g% a$ }, @* \, W# J3 u# o& M7 P63
      H# P& y8 Z. I* s1 H; B  g64
    7 C2 c( T8 }6 {( ^/ N3 ^4 K65
      C% m# n/ ~- N665 V( \8 R' Y& T9 I# r6 F9 I% e
    67
    ! }" s6 ~' Y5 P$ x68
    % E- v% T" \8 x+ F4 ?5 ?/ F6 {3 ?1 G69
    3 ~7 }" M6 X( j6 r( z4 p- k! |5 l70
    5 b5 H2 t* s" O$ O" \5 L. D71
    7 _# g0 t8 `3 l# E* m, a( T72
    5 S+ G4 R5 U0 h/ v& I+ J" e9 _' X1 Y* p73
    ' G# k( U& J" k) g( ?) G74
    0 Q) @. m1 _" i( h& D3 }6 B" [75$ f5 N6 J/ W* _( D! U
    76
    9 s+ w) s% M( y: O5 A# j2 o77
    . H9 v. g  ]$ k" b781 D- Q& S! A0 d" t# g& i
    79
    / a! V' {: X" X8 S& \/ o7 R* f802 T8 F& U- C& c/ T1 d
    81
    5 P' l6 n6 g' Q, @5 z82
    0 M* g4 Q  U9 f; q7 C; E83
    ( [* U6 o  ~* s, X" f( @" \/ C7. 设置需要训练的参数( i/ f6 e" x6 {6 ~; e: u% A
    # 设置模型名字、输出分类数/ d. v! x$ i$ O+ G
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    + }* R+ E3 {. J/ X1 E. Y2 i$ ^, b1 p9 m
    # GPU 计算
    . E  B6 s9 u0 ]( d% m6 Kmodel_ft = model_ft.to(device); z7 C8 z5 y1 x0 n' s

    " ]/ ]: ^0 k1 |5 q: d# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    " j" m% |$ v! m. R% J' Wfilename = 'checkpoint.pth'
    9 ?' z& T/ t  z! S, p! s
    - L7 L+ ^) V) B+ E# D% X# 是否训练所有层
    - U" K* D% U0 R9 v0 |- P, v' ^- qparams_to_update = model_ft.parameters()
    & y! C+ e- D5 V# 打印出需要训练的层
    1 M& d9 |$ }1 m: r) Iprint("Params to learn:")
    * g. d! \& a  }' C: b4 E: eif feature_extract:6 J  p7 G! k; e4 C) a" G: z
        params_to_update = []( |; {0 D' G% F7 h# {
        for name, param in model_ft.named_parameters():
    " i1 v7 z  K! _& C; m        if param.requires_grad == True:1 D0 L3 C0 D2 e# P6 T6 ?4 G
                params_to_update.append(param)
    3 M' u/ ]; y  ~            print("\t", name)
    2 w9 j" D9 p1 jelse:
    ! \# N9 e4 G* M' b* q8 O4 |! D    for name, param in model_ft.named_parameters():6 I/ r  h; I1 o6 I
            if param.requires_grad ==True:9 I, D( _+ N% H. E$ e5 m
                print("\t", name)  ^6 K; X( Y6 s+ l& U

    % E/ S$ q: B; o6 x$ K17 o$ v1 N7 `2 e
    20 n7 E. [/ D5 i
    3
    2 p, q; A7 x  c3 n0 U7 n0 b3 J9 n48 }0 W# R1 q, E/ `( B. v) a. V! X
    53 J/ k. B; ~' R/ [9 u
    6
    ) `7 i) j+ z- d: h9 U3 F6 `- o3 Z7
    2 c4 c2 ^) r4 x& j7 U' k  p, {8. v) e- w- l4 u+ B  g
    9
    5 a+ f6 S0 c2 H10
    2 M, \; i. X2 D110 o( u0 c; a- e
    12
    9 V3 j+ O9 q3 Q( E; E13' i6 d, \/ X3 B% Z3 C3 P
    14" x" d) A, b5 ]8 g7 t
    15
    / g- f6 X* l# l1 U8 L16
    9 L( E) n# A1 q7 {6 B2 p% I17+ r  o6 ?* Z4 b, I
    18
    8 M% m2 Y  t) s" @! q! G! A" z, L8 d19* y8 s8 k7 C6 J7 P: X0 w4 D# q& i
    20$ ]9 N0 i3 w4 f& y  e; f2 R
    21% [( v8 C4 g1 \! g
    22
    * c' |/ |) g2 @- l! T' x& R9 Y23
    8 G$ e2 n  w# N, k! T6 n, yParams to learn:
    + w( m3 C- b  v! d/ R% F- v         fc.0.weight
    * [; o/ h5 H; C! ?0 K         fc.0.bias
    & {' j7 \! X- M; J1% @. V4 `6 Q# X  a' [- _6 w
    2
    ; d& C# v& f) D/ J$ j" Y# K3
    " H# U) u! s, Q7 s" ~$ c7. 训练与预测) i8 _: I4 w- o$ n! i+ A. ~; f
    7.1 优化器设置
    3 }5 E; D  O# e: {( \# 优化器设置9 ?5 g! I$ ^: l# r$ N4 W* M. ]& Z0 O
    optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    & {9 N0 Y5 ^+ ^9 y% x& E" v# 学习率衰减策略/ k8 u6 b# G* m' J/ r" M
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    % g% R; N1 _- P9 f6 d( ]# 学习率每7个epoch衰减为原来的1/10+ w( ?/ `6 A: C' D
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    ! J0 |. D9 B3 u2 W! y6 C; D4 t0 d% h+ I- i' J! M0 ?1 p- J
    criterion = nn.NLLLoss()( z: f0 o5 L: A, [: M& O
    1% j8 |1 r+ S' B, e, Z: g) C
    2% h$ S; E9 t0 ^
    35 j# E' D# X) K6 u. ^0 s
    4
    6 Y8 j+ k( O* j) b5
    8 Y* D/ ~: R( C; u4 q' y6
    : t1 O( b! C) r! `+ Q3 h* l1 n1 o72 G9 J+ l& C) e. U- c
    8
    6 ^  }: ^3 P( v3 F# 定义训练函数4 A7 ^) Y, d6 l$ u
    #is_inception:要不要用其他的网络* G* k$ u' o4 X0 N4 x- t
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    1 A& \1 v8 A  k. m4 A" U! }; e    since = time.time()) N! ^& M+ K$ Y# I1 Y* |3 t
        #保存最好的准确率! I' B9 p* s! S$ _4 j
        best_acc = 04 `  l% G0 a/ p- W+ ]; o
        """
    0 F) P+ T/ [8 K  i6 b- F0 ^7 I    checkpoint = torch.load(filename)
    0 B+ z( w" \# K: N6 u: S  `5 w    best_acc = checkpoint['best_acc']: s5 L  y6 G2 m0 z+ `; }' d
        model.load_state_dict(checkpoint['state_dict'])
    . m* ~: t2 @$ D& S8 U2 P. q8 Z# k3 S    optimizer.load_state_dict(checkpoint['optimizer'])
    * i. |( g6 ^+ t# Z3 I    model.class_to_idx = checkpoint['mapping']
    * h2 B! M! _5 W& e    """
    2 G  u) f* r5 g    #指定用GPU还是CPU4 L* T! o- j' g1 d$ v* G
        model.to(device)
    * C& L- l- R* K1 X    #下面是为展示做的
    ; Z5 B  v; M+ \# X. ]0 `    val_acc_history = []
    & t; o) U- d: P4 ]' [    train_acc_history = []9 R7 D& J7 V6 |6 c3 b% X. n2 z+ V
        train_losses = []
    0 i5 K  s4 ^- i2 n  }9 s& S    valid_losses = []
    * m3 W; z9 c; b! k8 s# f1 X9 W    LRs = [optimizer.param_groups[0]['lr']]- K4 {5 \5 p; ]# R
        #最好的一次存下来
    * y" z5 C8 R7 Z+ P, V    best_model_wts = copy.deepcopy(model.state_dict())" \( {/ ^! R/ ?' N9 s- g$ j

    . ]9 P; Z7 S+ I/ R: A6 t5 g* @    for epoch in range(num_epochs):: x0 T  m' ^; u1 |# L
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    0 S; T5 a/ {; v        print('-' * 10)1 y+ S. u2 R: l- J1 Y: K+ s; P' O6 ]

    4 g: ~1 P* t8 m        # 训练和验证4 y+ {1 ?& ]7 o0 i" ]$ g
            for phase in ['train', 'valid']:" {; V1 Z, s# |0 ^% V, J$ B
                if phase == 'train':
    4 @! c* \/ I3 g) W( K" j* R# ]                model.train()  # 训练7 }, {# x5 Y+ Y; i3 b0 r4 J: B
                else:
    - `2 O0 S$ a& B                model.eval()   # 验证( g/ C* P! i- e

    % Y  ~( i1 v2 y) y/ f* w5 s            running_loss = 0.0: _, K+ R9 O2 w9 q* f, ?
                running_corrects = 00 e0 ^9 t: T& s' _. y7 t4 Z. r' O
    6 m2 K/ h+ @9 A3 a* G' Y
                # 把数据都取个遍' G: }* l0 a3 E9 g4 Y  V2 G
                for inputs, labels in dataloaders[phase]:
    ( ?; e0 o$ O+ `                #下面是将inputs,labels传到GPU
    9 I# h# E8 ~! E                inputs = inputs.to(device)
    + R( O2 Q3 m$ O, F- h                labels = labels.to(device)
      ?2 i4 Q' ^5 Y0 S
    , Y! |6 m4 Y, |: i& {1 w                # 清零+ d$ `+ d* h: E0 X2 ]
                    optimizer.zero_grad(): r# L- r0 Q$ r! _, t* I
                    # 只有训练的时候计算和更新梯度
    9 D* W/ ^; o5 X  D3 ?$ E' B                with torch.set_grad_enabled(phase == 'train'):
    ) Z& }$ n& h. p" H4 w, _                    #if这面不需要计算,可忽略5 [$ q* U% v/ B7 g$ B4 k8 b
                        if is_inception and phase == 'train':
    $ H( z/ [/ P7 ^- u. l                        outputs, aux_outputs = model(inputs)# |7 i) B* x3 _" L" ~, t8 N  r
                            loss1 = criterion(outputs, labels)
    , K& P' i% j; K. S                        loss2 = criterion(aux_outputs, labels)
    : p8 S2 p# J6 o                        loss = loss1 + 0.4*loss2% w! z2 E/ {4 v% g9 N
                        else:#resnet执行的是这里
    + q. c% g5 P; f$ A$ z' ^- H/ k- t# r                        outputs = model(inputs)
    & O1 I" w  `! M& k2 v$ m1 d                        loss = criterion(outputs, labels)+ r, w& G# c" B/ b8 U& N- h( }: z
    : s/ s$ b6 F7 x% P, q- o7 P& i9 j- i) H
                            #概率最大的返回preds
    * M3 u$ |( Y" A* [$ r                    _, preds = torch.max(outputs, 1)) F- u! g( M" L: e! P3 L0 q

    " v1 ]4 A! X) T/ I                    # 训练阶段更新权重! o& G( O; m' O7 ^) @* r6 c
                        if phase == 'train':
    1 Y  ]( V* A9 c1 v0 X                        loss.backward()
    ( y3 z" v, T6 T$ M8 c  S9 H                        optimizer.step()
    2 B2 C7 U7 V! r9 i1 z4 n+ c1 q, c# q2 t8 I
                    # 计算损失- J0 v3 o8 a/ Q
                    running_loss += loss.item() * inputs.size(0)
    ) o6 T/ Y# ~& k- d. X6 X                running_corrects += torch.sum(preds == labels.data), E  V+ ^' n5 c; Z5 ^5 J( p

    . p! g" D; o# {# _" b% @            #打印操作
    . R8 O  b+ C3 g% L# Y: V            epoch_loss = running_loss / len(dataloaders[phase].dataset)
    ; Z8 m$ e. y) @1 X2 W' B: b2 w/ ^( s            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    $ X; b, ]; c/ {& c# ?% b6 o1 z# z5 F/ `0 ]4 K1 D$ l: \

    ( }4 C7 c: v  H8 m; D' s            time_elapsed = time.time() - since
    9 r3 E) p$ V8 w            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))$ X/ m3 x7 H0 ]$ d( b7 |& `5 ]9 j& H  T
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    6 y4 F8 o! ?! r) [: }/ B: ]  y/ V4 u8 S& U3 k6 o

    / d# F  c' m) m4 _            # 得到最好那次的模型% W! ?/ Z3 u8 ]9 c# F
                if phase == 'valid' and epoch_acc > best_acc:
    % x- {+ p6 H" P) M) H9 d& R                best_acc = epoch_acc
    ' Q4 ?( D" s  U5 P: e' \' m' o                #模型保存
    6 \/ A1 N+ h* T1 Y5 t7 H  ~                best_model_wts = copy.deepcopy(model.state_dict())
    " f# M& a, {- u" X: t+ N" t5 p# z                state = {
    4 Z$ X. o0 l* j1 m( V4 @+ M8 e                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数; @7 C1 [3 R7 [& e0 _" N
                      'state_dict': model.state_dict(),
    $ x4 Z; _  v- I/ ~. ?# _; u                  'best_acc': best_acc,
    . U+ R. N% E" Q7 i. _                  'optimizer' : optimizer.state_dict(),1 e- i" l7 t. e5 A
                    }
    : H; l8 _. p" c3 [                torch.save(state, filename), ?% g8 \+ G  B; C
                if phase == 'valid':
    % K* K' ]! E, ?* e, X1 ]+ e                val_acc_history.append(epoch_acc)
    ! l1 A* x7 Y7 d, r# c                valid_losses.append(epoch_loss)
    - p$ V1 [2 j6 }                scheduler.step(epoch_loss)
    7 w, Q' ]: S; L# V8 S            if phase == 'train':
    5 Y* v1 |- F, X! ?                train_acc_history.append(epoch_acc)9 L( r; C% O% g- }
                    train_losses.append(epoch_loss)$ z; z9 N3 c/ z: C; G* l
    * |4 a4 E6 o  i- w& S
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    , x8 _6 j. w. c6 z$ ?        LRs.append(optimizer.param_groups[0]['lr'])
    7 H0 v6 W' j* r' Z        print()2 i6 _8 h& c: W  q4 M
    * r+ `* h4 X9 m, o4 N
        time_elapsed = time.time() - since$ S% w* a$ B( W+ Q
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    - Z" ]9 e) Q1 M/ |' t' y1 A    print('Best val Acc: {:4f}'.format(best_acc))
    1 W0 |) n# \* X* ?/ ?
    6 t2 m" [  Z! X; y( }: v    # 保存训练完后用最好的一次当做模型最终的结果" m4 ?/ |, |& a! X& D+ T
        model.load_state_dict(best_model_wts)
    9 T% x: m7 ?7 e; ^    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 4 n$ c- f5 S" A/ b
    0 y2 T- O. M8 O, }' O
    ; \  r* x9 y: T8 x7 |) [& p' [
    1, \% {/ `+ k1 v" J3 C/ _( e* y
    22 M. N/ d# w; F8 h
    3( Z/ G% f' k# k2 {
    4" v; X9 ^. M+ D/ ~8 C. ~
    5
    2 W% Q! N$ k  H8 |* W6) u: _) i7 ?% `- P& R
    73 h# ~* ?" |% ^7 b! w  E4 A- M. W
    8
    1 S0 s/ N6 V4 x6 X" F9
    ( C/ U& x& e3 ], r, S0 I# b, s10. N8 a4 E- q" D& R+ h3 Q4 X# z" l
    117 ^* H' X) t. F
    12
    - ]5 r) {; @9 D+ f0 k13. X( O3 l- m2 k; }, ~* N6 c
    14( u) f2 m9 M" Q: N
    15% D% a5 ^. @: H, T1 Y2 |% G
    16$ ^: q/ R, m4 b
    17
    0 {% O6 ?9 a/ U2 G18
    + W3 p6 Q) @6 c3 ]6 Q19
    * I4 N7 i7 k8 c$ o+ I# T& c20% R5 `. c5 o& ]1 r8 o) K( X3 E  G
    216 ]; d9 ^- R, p5 o, N
    22
    * `  Q  {3 g5 G3 _/ M5 ?234 O3 W% y  w, f4 p  G
    24
    ' L7 c3 G1 ]$ B1 g& c25
    & d6 \! O! @  G+ V( }  [26
    9 |2 t2 @+ B! u/ m+ R2 b/ C27
    7 k3 M* R, n9 w3 \281 O; f) I# n6 f5 O8 \* k
    29
    , P3 x& |  j7 J/ e6 S$ ]30" Z% q+ e5 i2 ~
    31' ^8 I+ m/ o' h1 {4 Y9 c0 [5 O5 V" ]
    32" A- X0 Y( s4 e: {! A0 z; S
    33
      b8 |4 U7 X7 h4 z34. D) f% n- Q6 j! e
    35: {- H$ h7 ]' M6 M  G( A( a" |, u0 y
    36
    ( ^9 C+ t# L) n0 t37
    # B1 n. |" _& ^3 t' O38
    , `2 v0 Q3 P5 S0 e5 d- P5 ^. r39
    5 f9 X  h/ w0 D8 V1 I5 e* v  z$ m1 u40
    & c5 H3 _4 x5 w41
    + }8 {7 }: Z  c0 z42  a+ z; |8 _! S+ t% p" E/ P
    43+ s- s! Z8 K" R7 k( M3 i# h" y( N# \
    44. ~& j9 g" U5 _% U
    45
    + I0 @0 @8 `  H3 l' g46
    * n* _/ |$ K; }" }1 v. x4 l* I47
    " n) N& M$ G1 P  F+ k482 K! g8 y$ B+ Y( r
    495 [8 M4 z. p( L/ o  Z! m1 u
    500 |6 J0 P/ m5 M& [3 P" R% O6 S
    51
      [3 X! A  @6 D  w- ^$ d, ~52
    8 P( {/ Z( T2 p* V53
    6 y! j+ Q4 |* \. y8 ^54' `9 H1 p7 u- F  R1 T
    555 }2 H# B) e& Y' N/ ~
    56; t( y$ q/ h. K. U1 B8 J
    574 K- {' Q( s, f3 l
    58& o4 ^! T, j' n2 q' T
    59
    ! Y9 v& O% y) c4 I- B60
    $ X! B+ J' p7 s5 Q  Q. U610 g& }1 q( ^/ n
    624 O) ^0 }! W7 c2 |3 N3 ?
    63
    6 l4 B1 ~; x& m# T64: z3 Z( }4 D: w; b1 j
    654 w- J; j2 u6 ~* j4 g* X
    66
    " P2 x# Q& v: E* I  _/ J  E678 q: c, O7 ?- g
    682 u% G; w% z- U5 j
    694 @+ q$ Q) `0 G6 r- H6 t7 R  s
    70: T- ~- x* ^" [3 M$ ^
    71
    2 m4 v* Y. f  }/ G3 R0 n72
    # ~. R2 N4 S$ }4 c# n# Y% \730 X% D9 U* J2 H
    74: U9 h* m7 B0 D) Z% H1 E: ?- w/ }
    75& w7 Z% S  i' [; g) I" W- H
    76) a9 z+ I1 _! ?: y
    77
    7 M$ `- X* a- V) {4 @78
    3 a0 f7 T0 ?/ m& V+ |# q9 |79
    ; {5 P7 J; j" [80* l2 N; A! h# R+ v. K
    81' x& k. u7 Y; z. V3 v* Z' o1 s
    82  E# s7 B1 X. Q) m$ y" q8 ]
    83
    + }) t/ z) C9 C5 T+ [84. v5 R, {' h6 w# G5 e
    85
    * {+ O* ?9 h# N' a6 |1 p  R86$ a' s) w2 R, D, g9 c
    87* w2 m' d* `/ D' V2 [3 w
    883 I/ F! f! Y% I5 C  B
    89! r: X  y) x; H7 `7 L
    90
    * u  P5 T8 T% M: p$ Q91
    # u3 \  C+ ?" l6 W6 Q1 N92
    . F: D9 m+ L  {93
    % k+ v; @. ^2 }: P0 Y; `94% ^0 i9 t* W2 X, |- q# G
    95- P' P3 W: c* s3 X) ?# Z. [! n& T5 H4 y
    96
    . e$ M6 k6 s/ W) L8 h/ Y5 M97# o: F! D; ^; ]& q/ r& b
    98
    8 I1 u0 w2 x3 G/ |$ T! s- |0 E99
    : ?2 A( ]0 e4 W9 j. ?# ]5 }; s" V1009 k; C( t& z% F+ R
    1010 M# Y7 i+ a/ Q7 Q
    102
    # ^* J* G* S- R; w% m103$ N4 z* g/ P. P$ c7 u! z
    104
    4 Q- H% W7 k* Q! A105
    8 |7 T4 x- }  l) {106- _$ U# B# u: j8 s
    1075 b& Z; Z# z' V- o& I3 x( L
    108) F, A- C& s' F( u5 S$ O+ {* Y# d
    109
    . g! R& A/ _' j2 i1 h* c. l% m6 v1103 c  J! V: N% {% I, ]
    1118 G" }2 p  r0 E, O
    112
    + n3 `0 ?) K( d4 V9 K0 M; [7.2 开始训练模型
    7 C- \8 E' G2 Y) @7 I+ _( X' [" p我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    ' a, m. d6 f4 o
    7 H- s0 O" v( D" E6 y; k#若太慢,把epoch调低,迭代50次可能好些: q0 C: m/ r3 K2 d7 V0 [; h2 G
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合0 _! p2 T% C/ d9 q
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception")); w' x* M1 p" Z1 N

    ! P& O( w5 E; O# B: V1
    ) \" g! {) Y9 C0 q4 j24 D5 q( {/ `1 m* z6 m
    3
    ' C1 e6 Z! l, s4
    + Y- U! N$ M! OEpoch 0/4
    $ c! E' B8 B4 L----------
    ' ]( J/ e: x0 G8 O$ U5 iTime elapsed 29m 41s" t% h& L/ ^, P3 j
    train Loss: 10.4774 Acc: 0.31473 l1 Q( Y4 O9 c( f
    Time elapsed 32m 54s8 @, S: N% E. D$ [* h
    valid Loss: 8.2902 Acc: 0.4719' h" o8 k0 A" U
    Optimizer learning rate : 0.0010000
    7 o6 @4 [( f" [8 p4 D
    1 b8 B, Z" s' fEpoch 1/4$ N! I, m& @9 t4 F( Q) T7 ~) O4 b
    ----------& b% i' l: T: Y
    Time elapsed 60m 11s; H4 ^, U0 i* \$ P3 m
    train Loss: 2.3126 Acc: 0.7053. Q7 U' H8 y! j; G, \
    Time elapsed 63m 16s
    " x; p& O6 O3 c2 \valid Loss: 3.2325 Acc: 0.66262 i# u8 s  I) X9 y6 t- G- t) s
    Optimizer learning rate : 0.01000007 s: g+ \( T2 Q( {7 Q0 w9 Y

    + B- D( M0 j7 g  x* Q8 _Epoch 2/4* m- _5 {# X$ d2 z8 S1 J
    ----------
    ) X" Q2 D6 Z- S* X  L: bTime elapsed 90m 58s
    : L8 m* \( W; n5 J  r( dtrain Loss: 9.9720 Acc: 0.4734; @: m0 F3 C5 J3 v9 T( G( `$ K7 t8 N
    Time elapsed 94m 4s+ f3 I( H3 F7 ^2 @1 ]
    valid Loss: 14.0426 Acc: 0.4413
    , G) o( e% X: ?5 v( q+ h4 H* XOptimizer learning rate : 0.00010009 J2 Z  ?- v8 K$ T" h' ^

    % x" N4 z( ]$ r- n) oEpoch 3/4
    , _6 c% z9 E1 Y4 [----------% t: t7 N' S! Y! o. ^
    Time elapsed 132m 49s8 O) A+ a( X; d9 w: ~9 @
    train Loss: 5.4290 Acc: 0.6548
    3 O& C* H% I; F# ~Time elapsed 138m 49s- [5 p7 }( O1 H/ w
    valid Loss: 6.4208 Acc: 0.6027, f. }# w' l. L9 `
    Optimizer learning rate : 0.0100000
    9 U% _* z: @7 J( W; O, \* ~* E+ U
    6 }2 T1 x( F+ P9 d, U& c8 fEpoch 4/4
    1 `! u( t! H3 b: X) ]* b----------% z. c6 e3 ~" V9 V" c" U* u6 t; V
    Time elapsed 195m 56s  l7 ?- O5 T9 f" l$ `: x
    train Loss: 8.8911 Acc: 0.5519
      w" i, i& y; gTime elapsed 199m 16s# Y0 w4 s7 j4 ~- g% `
    valid Loss: 13.2221 Acc: 0.4914
    : p0 f8 n6 H6 S4 [Optimizer learning rate : 0.0010000
    5 q  `' v5 O& U& w6 A7 `$ g% g3 s0 {0 L( |# M0 ~
    Training complete in 199m 16s
    8 q: p+ h* V9 a5 P; ^/ sBest val Acc: 0.662592
    ( l" y$ c, N. O% e1 x4 W
    * K( M; R$ Z: L* z( Q' O1
    4 v7 K' v, ?! r, r! K0 T; P2
    ; H" V; j+ m5 [  Z- M, ]: b) O) m3! ^# T7 c" @' k4 i
    4* l/ h$ k( l: ^/ l. N
    5
    " \& v# \% F, e1 E3 _6! T& {% `3 [2 j3 a2 k8 V
    74 W9 P% Q- h8 |' @% }' ~
    8( X( @/ x9 ~3 q3 R$ ^
    9  _, y- Q! P; G
    10) p3 K3 a- F7 l! L/ f" S- p
    11
    - w3 Q! e; y- A12
    # l$ ^* \" h  x! p) H, E. e) P, j131 }1 u) u% o( U& H$ I5 {; u# E
    14* x+ |+ N5 v- y+ j* X3 f" c6 B& J9 A9 X
    15( [- U! Z5 n! |0 {" x- W  C
    16  g* f5 `3 i5 N" E/ y
    17
    ) E! z# S$ n+ \4 g7 j. j' o: [$ c7 ~18' ]- [2 h" O" I! X8 w6 @8 C
    19
    - f9 h& [' h$ B5 h5 k( f* H20
    , h* y0 n- s. r& y215 b1 q! V) z; O( h% S
    22& z/ b3 T- I% b$ w/ N/ G% K: s2 T
    236 a( N3 y: U7 P
    24
    # n) ]0 n2 N- Z: H25
    & B6 v3 O+ p. L9 x! z/ r! i26
    0 b! o1 u. V4 S. a; o4 J27
    ; o5 s) i+ R  F2 v) a6 r7 a28
    4 q4 d3 R5 M2 P4 f8 v7 k" a, [3 [29
    1 N0 [* x" d# C; I( o% c( F30
    ! r; U* ?; D' k  K4 O8 S& ?( U31
    % E4 D/ V1 P  H2 ^5 e; D: y32, u+ N: f( z7 A; @  H) A
    33
    3 b$ J" \& q1 ]' F' V+ a3 `34$ }! y7 Y5 S& w' R+ w" g. i6 Q
    35% @) l* U2 g8 R, N( t
    365 b' S8 X$ I' H( T" g
    376 V5 j% Z' e" w
    38- W- o1 @, |; U" l8 f2 ]3 l
    39
    8 A* a) v; D: @, t" r' _40/ u% B* C3 Z/ {3 c, @6 g/ g
    41
    & G3 K6 z: D5 Y+ I42" T1 [0 y) t! i2 e% T: \4 v
    7.3 训练所有层
    + ]/ v6 a9 W- q8 r, j! N# 将全部网络解锁进行训练
    : V- `% y, F+ X# e$ i6 v- ?+ w5 U7 V/ ofor param in model_ft.parameters():
    3 A( K* m; p: m    param.requires_grad = True
    9 Z; m# N( A# P% I, w; z* p" h; w: G; w7 ]9 x! j  _
    # 再继续训练所有的参数,学习率调小一点\" H" a$ d/ O) I% ]# K# S- c# w
    optimizer = optim.Adam(params_to_update, lr = 1e-4)0 e# b9 _7 m- @6 r: E7 x
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)  {4 l5 ?1 K$ I. N* w' ?
    $ _- Y0 D3 m8 }" D% {
    # 损失函数
    : a7 k- d( V1 Dcriterion = nn.NLLLoss()' e. P: Z7 G7 S5 C% w) ^
    1
    2 Y6 m9 K) I8 b21 k# r, I5 X2 r1 O  p
    3
    % r4 m3 p1 |+ v  z) Y4$ n# A% W) k0 J2 @
    54 S- }# [  r/ }# `& T+ {0 u3 O
    6
    2 e) e# ^9 C# ]3 }) ~/ Y" b! R72 ?9 H% W/ a$ G% i. L4 ?# v
    8  R2 O" K5 d2 S% g! j* U7 t: V( y9 s
    9
    6 ~+ A6 U! {; D! G; \10
    , A! D# J, T+ z- k2 I4 w& }/ H# 加载保存的参数4 j+ s0 i# t1 r: h0 g& b
    # 并在原有的模型基础上继续训练4 q- Y  N! P+ e" _% f$ ]
    # 下面保存的是刚刚训练效果较好的路径; h, y: T! a: h% t% p; M
    checkpoint = torch.load(filename)+ V& W1 M9 V: f
    best_acc = checkpoint['best_acc']  F. |( N) {& l5 X8 s; g
    model_ft.load_state_dict(checkpoint['state_dict'])
    6 f: G  u2 l) I& q! w' ^optimizer.load_state_dict(checkpoint['optimizer'])+ e& W* `8 L6 q9 \7 V
    1
    ! q9 j' F9 R  Q& a& ~2  {% O+ ]# Z) a1 z- V  v# }8 _
    31 v/ I% c" z4 t/ A4 L/ M
    45 h& t- s6 c* m7 }
    55 V7 {5 J1 {1 D4 p( F
    6$ a1 ]7 o) y6 L3 Q' p; ^2 T' d
    7
    7 c, V5 n8 w. F1 v& g1 s5 w2 j开始训练
    6 A1 @0 n. x. K5 q% L# h注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考9 {' h: E$ [# n2 h' R

    2 s- f. ?4 }' v) V+ l5 k/ v/ Xmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    7 n1 G% `4 J3 ?; ?" g1; G& Q' b7 I0 N9 @
    Epoch 0/1
    9 W  I) r# E! I! Q/ L& u----------
    + K1 C5 E5 x2 Q0 D* g. TTime elapsed 35m 22s0 N+ S4 q( [- I  ^# V2 ]5 L
    train Loss: 1.7636 Acc: 0.7346+ ~9 M& Y, L' ?4 u
    Time elapsed 38m 42s
    ( E0 \1 f6 C( K; c! k* _! Vvalid Loss: 3.6377 Acc: 0.6455) ^" ~- W9 W) c/ R* O/ W0 c' j2 ~
    Optimizer learning rate : 0.0010000
    : W$ E5 o. v4 T9 F' V2 T
    $ @4 C9 ?8 S: U  u& i; gEpoch 1/1
    4 u4 C) k# l/ }1 p+ P) w----------
    4 B. U6 h9 @9 T  X0 N+ NTime elapsed 82m 59s
    " Z$ C* ?3 d+ \6 y" strain Loss: 1.7543 Acc: 0.7340) M0 a/ d5 Q3 O% F
    Time elapsed 86m 11s
    $ A1 f0 k- v2 a6 a  Z0 Kvalid Loss: 3.8275 Acc: 0.61377 h$ W! s. E# r, k: T' o
    Optimizer learning rate : 0.0010000
      b. G. C( h6 V3 q7 j& C2 e. {( q7 Q0 }3 w6 t
    Training complete in 86m 11s7 F; s& N. {4 \+ k
    Best val Acc: 0.645477
    & v4 E; `% `; @' Y$ p8 q
    , f# R( ?, ~  }% P3 d1* p7 }8 g0 O& z
    2
    5 R, z7 d# t( h3
    9 z; P" t. O  y  w2 s' m4
    8 ~  i* J) K- J+ |: _, L+ Q5  W' q4 R1 {- n) r2 c+ J6 F
    6( e9 {+ a! U1 k0 C& |; |* W
    7
    9 ]8 ~9 r9 \& m5 n8
    " J% K) E% O% C9
    6 Y, v  {7 R2 @2 l9 l# R, c10
    , C& H. t$ e/ h: b& w11$ I# B6 T# n& v! G1 \( k2 ]
    12
    $ @% F. b- f5 [8 U3 ^/ K4 Y. }: p13% Z, R' A# ~) B! _# l/ a
    14
    # N1 h% J. o  t- w0 E" {15
      }& @* ^# `% _165 p9 I( J! W9 q6 F7 j$ k9 W& w: ~
    175 x" @3 i- o/ P* z. x2 L
    185 y1 m9 g3 ]6 Q% t* \5 m
    8. 加载已经训练的模型
    + [( k3 C, w- @6 _相当于做一次简单的前向传播(逻辑推理),不用更新参数' _3 u% H  |" s  |

    7 {, |0 n6 ?' K2 v2 h8 o7 [! ymodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
    , S! N* ~, l1 i  I
    ; F# U4 Z7 e; _8 ]0 O# GPU 模式3 V; E$ Q: A  \
    model_ft = model_ft.to(device) # 扔到GPU中
    7 c2 ]* R7 h( r% `' y! T5 H9 l; U) ]. l0 j% P% t1 s# `; n' P
    # 保存文件的名字( }3 B1 i9 X6 T* b- _
    filename='checkpoint.pth'
    , @: ?% H& S3 V
    - z; D- U+ w  }8 ~9 a# 加载模型
    . f8 v2 d0 u0 S! A4 n2 {checkpoint = torch.load(filename)
    # h  M. H5 p9 x8 @& k1 b) ybest_acc = checkpoint['best_acc']
    5 M8 s  _9 J; B* e2 {9 n' {# Dmodel_ft.load_state_dict(checkpoint['state_dict'])
    ( t# u# c9 \% e! @1
    . z8 B0 X* K! a) p' p  |; f7 h29 Z- P8 h# u" @% C1 Z3 s
    3! a% }1 U/ L/ @9 Z% }' `9 T
    4
    3 Z' U) e( O/ D3 S2 T5
      A+ E8 U6 T/ s8 n2 F69 h6 o1 F( H% L6 g  i2 _" z
    7" h" z$ Y  J' T' n. u+ w
    8/ P8 e  R! x; n/ b3 G- y$ m
    9# J; w$ S+ F, q2 e  I4 h5 b
    10, h' Q3 }" @- g9 M
    11# k- x+ H$ [: K! k' T
    12
    4 l6 y, Z8 _; ^5 F4 d% f1 |5 Z<All keys matched successfully>- R5 r' p- q9 H7 X0 m
    1. Y; j5 u: f' y2 I; Y
    def process_image(image_path):
    - h# L" ~  _, ~7 B. F; |    # 读取测试集数据" T( q+ S6 m; f  b: @3 p: \/ _2 a( g
        img = Image.open(image_path)2 `& T7 u4 _# C. j9 c  ]& S
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断5 d+ S; r2 o3 |1 @1 I
        # 与Resize不同# ]+ I6 z2 B5 D
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    . h& \, W. |. w/ o$ @    # 而且对象调用方法会直接改变其大小,返回None
    ; k+ i8 {5 O; l  L. Z    if img.size[0] > img.size[1]:
      U5 g  A7 W/ X( G, c        img.thumbnail((10000, 256))+ v5 [& z$ F7 a) `* V: j! b! N
        else:
    # ~0 @3 N% g5 B# o+ g        img.thumbnail((256, 10000))/ |. c* Z5 Z: O0 z
    . d) d% y  E2 X- Y1 i
        # crop操作, 将图像再次裁剪为 224 * 224$ a0 L4 ^8 b; |8 S9 s3 m/ l, h4 `
        left_margin = (img.width - 224) / 2 # 取中间的部分( ~( s: r9 j- E5 O0 v
        bottom_margin = (img.height - 224) / 2
    + m6 U: C1 T0 X+ {' {( J$ O& ~    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    * r: n' i; U' j8 K" {    top_margin = bottom_margin + 224& v$ H* D3 X. T( q0 n
    ! V6 @0 g# R; I- S' |4 m  Y
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    ; ?; G% ^! F. R  \9 X2 g7 _- ?8 o5 w! ~+ x5 Q
        # 相同预处理的方法6 v7 o* ?+ y& \5 ]
        # 归一化
    6 p, |$ s# k3 J: M    img = np.array(img) / 255
    9 ^( l$ H4 v5 u# Z# s6 E; \+ @    mean = np.array([0.485, 0.456, 0.406])3 a( Y7 X, j3 G
        std = np.array([0.229, 0.224, 0.225])
    $ F$ `$ \0 F0 C* z1 y1 l8 e, j    img = (img - mean) / std1 y, N1 r$ p' }+ I. `
      T( M  a2 |0 ~
        # 注意颜色通道和位置/ A% f$ Y* S' Z& d) m6 H
        img = img.transpose((2, 0, 1))
    7 O+ n6 s5 D/ V* @6 K1 j/ c& {3 ^% W7 p% r$ h/ B) l6 `
        return img
    / K1 Z. n. w9 l, s; E" q: n5 @* }, p/ S4 f0 R9 W
    def imshow(image, ax = None, title = None):# L+ Z) T4 J1 e; J' G2 v0 a1 g
        """展示数据"""6 e8 ~9 M: k: r* J2 J1 B) |
        if ax is None:
    8 J5 @$ l; v& K$ t# f        fig, ax = plt.subplots()  O1 G. z3 s% r
    % K+ x2 l! K5 o% e  O/ ?) @1 z
        # 颜色通道进行还原
    . P) L# e* O- p: `    image = np.array(image).transpose((1, 2, 0))
    8 L6 }4 _- |3 a6 N+ g, E' o& F) b2 p9 V* p  l3 r/ m
        # 预处理还原
    8 S7 a3 W/ J2 W0 n# n& B! ]    mean = np.array([0.485, 0.456, 0.406])- B) {# |8 U; P. k
        std = np.array([0.229, 0.224, 0.225])% ]/ ~& P* r) i
        image = std * image + mean
    ) S: a) F, D% `    image = np.clip(image, 0, 1)
    & F) ^) h2 i# `: K7 M
    0 E% u1 I( |2 f  C, h' i4 M# _    ax.imshow(image)
    ' `6 i# @1 j; t1 X  ]    ax.set_title(title)& e( `8 _+ R4 G/ P/ P& ]% ~
    6 c* `* [7 e) c1 H5 d/ v: B
        return ax# m; |- n% f! c+ g$ P) s4 S
    - X& L  F* t- r* y& F
    image_path = r'./flower_data/valid/3/image_06621.jpg'5 w1 ]9 d2 p- g. P, D3 k& |
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理, [) l+ P4 X+ A; E; |$ V
    imshow(img)
    4 O- v1 ]- H% i- ]$ G5 ~9 X1 ?- {8 B; N, D* V/ x+ f' {# B4 Z
    1
    & U6 k( W  `9 a) b# P2 N% B3 [: }2
    9 D# E/ \# ]9 F  ?1 d8 \( @3$ B  N: E. Y) e: q7 E
    4
    9 S0 v% z( Y& E* E+ D51 e' _; R+ c1 P" y8 ~! ]8 `/ G
    6; B  M; k: Z0 l4 M8 \1 t4 E! X4 P
    7, [2 C7 m" g. M
    86 v, ]& V0 b# D( x
    9
    9 \) R* Y4 r6 Y- M$ V  x" [' y10
    9 f# N) H6 }0 v) [" ^" @4 v+ L11
      r9 Y) E# L7 o0 f, |1 E8 j12% l. b3 ^& O4 t' x; r9 J, U+ Y
    134 F/ D& J3 R# E. a5 O
    14- M3 O0 N; d( `3 D+ l
    15
    - b% L8 l: `9 [16. Y) N4 b7 D5 T
    179 J  ~0 `9 Q/ [) e. ~7 [% N+ H( P
    18
    & f+ i/ g" E$ A) z0 C' m19
    ) ~7 ^; f. T& i% V1 d+ B20
    % b( f# `9 y, M  Y# w  W, |6 z21
    & B/ H/ K0 R1 _0 l5 f/ Z: Q$ d222 X9 K& ~4 H8 d7 a$ q
    23
    ; n4 U: D2 k) }* q7 ?7 P+ W3 e24! y! E6 a4 H) J- j5 I; E* P
    25/ d) r/ _1 O% u! c1 n% ?3 `& O6 j
    264 W! V& ^" W5 a' L6 w8 |+ c" u
    27
    8 ]' r5 a5 I% I$ J" j" a% [28
    + u' _: X" F/ h, T  j1 B3 Q: S0 V& {29
    8 q! h" J7 N0 A30
    , P! s. x4 A& N5 d+ n: j31
    . f9 W9 u, n6 h: \  i32
    . W3 @% ^% e7 g3 o1 b) O- b33+ f+ J; x1 z" L( G. [. m1 S6 b
    345 c. W8 p. Z5 H% ^+ F& B
    35
    # }* _1 c: Q, L' D1 H  u7 W# s36
    + E- y$ X5 f* d6 D# I& v+ F  t37' ?) d: p. G9 a: ]1 Z
    38
    0 _. W7 e1 Y$ K! a' A  w39
    ( G. _* ~2 M9 d" y/ C40
    ' n, P  I/ D$ b/ ]. x4 ~, [0 A) j411 N4 P$ o0 h7 W6 T# i
    42
    # U9 z: `! {& T43, Y' q9 N  t1 l* ]/ |
    44
    ) b5 f8 I2 ]% B2 T4 r45
    8 e8 U& t! P& b) p1 C* H46
    ; r# ?9 g* y0 t! f6 S/ ]3 G! j. c/ [47
    7 j/ y) C  x, b7 W0 V' i) e48
    $ W/ r* r# ^& }9 j% q1 z) Z, R49
    * S% I1 ^5 e# [; a50( K7 n/ G- {1 I7 F; R- o* T0 f" u
    51: r& R3 M$ @( `& G( a3 g; E
    52
    * v8 M: D( U3 @4 e& D5 f- C. K7 y9 E53
    ' X1 `$ C4 ?$ [- i+ d54
    ; Q2 X6 @& W: _6 t9 Y  n7 I<AxesSubplot:>
    - i4 O2 t' S  ^) H5 u# a$ q; {1. Y; a1 a) L4 Z7 X% M( ]. Q
    5 D  M) {8 u/ {
    上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确1 g" G! U7 s: f2 C% t( q  C7 {  {

    1 O5 H8 A' d' L! z( u; limg.shape' n5 M+ ?3 r$ }6 F. ~* S; H
    12 t9 W; L2 `+ y/ ~+ F5 ]
    (3, 224, 224)
    " T9 n( Y" X0 L3 _1
    $ ~+ k# r2 x; ~! b/ N证明了通道提前了,而且大小没改变
    . r) B8 _* c& b7 d  Q2 [  D8 g: s5 M3 R3 ?+ \# O! _
    9. 推理
    9 {: S* o; p, S8 `img.shape
    ) m9 I3 Z0 o1 q5 g( M8 e+ {6 A* f  T1 V
    # 得到一个batch的测试数据
    . l/ h, d3 B; ldataiter = iter(dataloaders['valid']); T' O: D5 r, v, V* Q
    images, labels = dataiter.next(): c6 G6 I0 r$ u; d
    & o' q! Z; M$ o  O# X1 l
    model_ft.eval()) f& h' }: n4 f; N9 W! A1 C

    7 A! N6 }- U6 [% T2 eif train_on_gpu:* N, g0 p% P! O
        # 前向传播跑一次会得到output
    # J7 ], f3 K9 v    output = model_ft(images.cuda())5 D7 N0 `7 A# ?1 l5 C8 Y
    else:
    1 h* O# k; T( F5 \: Z% k7 z    output = model_ft(images)
    6 L4 C/ Q5 O' o; k! c
    6 F  ], |& W/ v; @* r; G8 }# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值2 M* m) i* a- y: \5 w( j& H
    output.shape
    : ^* V: D2 a( I* `; G: k! Z+ F
    " n  k, ]5 k* F11 w) F0 A& n- ]0 a
    20 J1 b! ?1 Z0 u" M9 ^
    3% W/ p6 E" r, @
    4$ D9 ?* n1 H. v
    5
    ( U" y: w% Z& q6
    % L9 \7 x9 n) @+ i1 J  S! B0 k72 E( W, T1 v3 z" B; P3 f$ p. E
    88 F# |3 Z0 y4 L  D
    9" K$ x3 ~8 F& C. K1 n- ]8 G
    10; _* D; p* d/ u
    11
    # d6 c! W( F- P# o) M0 K/ y12
    ; x5 b2 l0 \- \2 R/ Z8 Q* L& g137 X. E$ O+ I/ `" R* R) A
    14( d. K1 m& l' `' Z
    15
    ( J) N; H- k* i; y, g% z# \16
    $ h7 q4 b9 X6 V9 d# Xtorch.Size([8, 102])
    5 ~- t& p2 V' s0 b1
    0 r& R- G( U" a+ }( ~9.1 计算得到最大概率
    ! a  @0 T: D! E% o/ S. F4 S_, preds_tensor = torch.max(output, 1)" I+ q7 I6 @" c
    ) F/ G. i0 F# C0 W
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量& I/ \8 l, h; r2 g6 V
    1
    & h3 P$ _* d/ ?0 g  X: i0 C+ l: Y& U2
    5 x! J- v- a8 o' l$ H  m& }* P) x3
    + r" A% B0 V: S* x% G! L9.2 展示预测结果
    , V3 ]/ ]0 p( [% r% b4 g# ?fig = plt.figure(figsize = (20, 20))
    - o* x- F0 _( u+ P$ C) ?- ]columns = 4. _! ~+ g4 u: i1 |
    rows = 2
    9 p3 Q; \& X. n9 G/ H: b! A: z# d+ b/ d# n; {- G# v: w) {% D
    for idx in range(columns * rows):/ Z& m* B% O; p5 P- p! z. y0 g% f
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])9 {2 `  _3 S* d* [+ Z3 s7 }4 c
        plt.imshow(im_convert(images[idx]))" O1 i0 ]4 L$ ~: @. B( B+ p, O
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), & v# c# C0 |" y
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    $ ~7 g6 ]. t5 `* lplt.show()
      n/ J1 t+ H1 n3 G  W9 M# 绿色的表示预测是对的,红色表示预测错了
      a' q6 W& O- o" W1
    9 f1 r& C  ?" \: {& _' v. `2
    0 I% N4 ?9 U0 H+ Z3$ @. A! ?+ Y. L) ]/ m$ ^# x5 D! i( o6 x
    48 I, }+ L: I1 k; i4 F
    55 w6 J. m8 f/ p- ?* G( g
    6+ J6 k/ @; Q& [: y9 `% n
    79 J5 m6 C! Z/ \
    80 o0 r* P% ~- ^  w$ ~
    9
    7 ^- M! d, Y& a, U$ d# K7 S9 q10
    1 I8 Y6 s# m: S112 q& d  }  K; A% n  T

    1 A7 s; H  K6 E1 D9 ~
    4 X6 h0 w* _2 G; Z: T  e6 g  m) l: u, F0 c; \2 |/ g* m  ?5 }5 Z; F
    ————————————————2 D  z$ y' h7 a0 B
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。2 m5 W: [; v' w" o) n' V+ q) T
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    ) k% ~- b- ?+ u1 h, H$ x* H# w; Q9 f6 f, @( T+ W9 b- {, _
    7 q4 W# L; ?6 y& x7 |
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-15 03:30 , Processed in 0.508441 second(s), 51 queries .

    回顶部