QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2296|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例" V. e) _1 b. ?& Q& Y# X

    8 S8 h# I% T, C3 }1 K文章目录$ i% D0 Q2 w% O* G. E& a7 z
    卷积网络实战 对花进行分类
    2 u3 ~) [' l5 C  I; ~5 Y5 j! V数据预处理部分8 j6 |: C" D+ z( K$ u& L( i
    网络模块设置8 [( Z% X6 _  Q1 F" `+ g
    网络模型的保存与测试* d3 M2 E. [* f
    数据下载:- W2 X. w' g# n5 J
    1. 导入工具包
    3 R; f- x0 ]! Y+ i4 K* N2. 数据预处理与操作
    ) d+ D& @. Y/ ?6 N3. 制作好数据源4 g% z- J% B9 Z+ A
    读取标签对应的实际名字5 i6 p. h- @! K# p
    4.展示一下数据
    % K, y7 w0 J3 b8 m2 _5. 加载models提供的模型,并直接用训练好的权重做初始化参数$ f% ?8 X- @7 n" L& C3 |
    6.初始化模型架构
    5 d7 v3 G* k, P9 n/ h6 W0 y" r+ R7. 设置需要训练的参数1 Y- a! {3 u' b0 A
    7. 训练与预测
    ( |4 w4 X! Y4 m: k: Y0 R; E$ I# d7.1 优化器设置
    * s1 s6 B6 _' T( Y+ V0 j2 w* v7.2 开始训练模型
    ; B2 ^5 L/ l: k6 n7.3 训练所有层! g/ q; N& C3 s0 c  C
    开始训练
    9 H6 V- e7 V2 _# J  b( V! ]8. 加载已经训练的模型; j3 ~, @0 V3 f, {& u2 q
    9. 推理" w# d* @8 m8 {& s' V: u
    9.1 计算得到最大概率1 |7 `& [: u7 O; h: K
    9.2 展示预测结果
    ) W3 k# B) W; k+ }9 c写在最后
    / ^  [2 F7 }$ o0 R* }卷积网络实战 对花进行分类
    1 w+ }, @  B: Q2 u/ `0 W本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    ! b9 Q: X. Q& K2 N5 `* j# X' w% M% e& l
    在文件夹中有102种花,我们主要要对这些花进行分类任务3 m2 M( ~) m) }. F- P2 O6 D1 u3 u
    文件夹结构
    & E: [6 h% o! \% B( r
    % {+ O9 G/ z1 _1 r3 w; Wflower_data$ ]# k0 M# r4 o% {  `" h

    % H5 H. s" ~4 u) W% mtrain3 v4 w; I% A; M. W

    % d- e+ k/ h' z4 \% O1(类别)
    * p6 N& W% W7 K  }" N/ j" d2. I/ [1 y/ W+ e0 ^) D! a* X6 W
    xxx.png / xxx.jpg/ h5 M" _  j! X% ^: W; c
    valid
    " l/ D0 }% n7 x: k# U- p/ G% H. ^" ~/ M0 E. L' p
    主要分为以下几个大模块
    3 `! }( K! F& M) m9 u9 S7 \2 R: {- z& N0 J8 z3 Y
    数据预处理部分2 [9 _; I/ @8 m" a
    数据增强9 @" r: A( y: o3 _5 Z% f. U
    数据预处理# c4 t& q5 l5 @& G) N* y- m
    网络模块设置
    5 P2 O5 Q: M; v0 W$ O6 y# Q' ]+ T加载预训练模型,直接调用torchVision的经典网络架构
    - ^2 @$ H& ^( }% t因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务2 q0 [& j- w( v3 X
    网络模型的保存与测试
    ) ^& t2 i/ m) B3 t; R4 k模型保存可以带有选择性
    # [3 ]& y4 f* {( ?8 @, A数据下载:, X  ?1 a; W0 v
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    5 B" l% ^" M" X$ `7 c: t6 I$ x" H) p% D& D
    改一下文件名,然后将它放到同一根目录就可以了
    # L; V$ e, O! O; k6 W
      f3 b! k  \% H: j0 H+ w+ h下面是我的数据根目录8 e9 Y9 V' H8 k, y0 Q' b
    3 H' @! @9 e0 \& {. L  k

    7 |+ {/ S4 A1 q# A9 w1. 导入工具包
    6 X! Y$ {9 G6 F4 A0 t( Z- Zimport os
    ' w4 }* `: q9 O1 z# ^$ {import matplotlib.pyplot as plt
    9 I( x* h3 m4 a8 [# 内嵌入绘图简去show的句柄
    & w& L! C0 \6 I%matplotlib inline * Y+ [5 v, C% X/ ^
    import numpy as np
    0 g4 J, L$ {. R" _0 gimport torch" N) u* F+ t6 D1 Y" {% S
    from torch import nn# x- f" V2 |9 d6 c) X+ Z% p

    3 ?: a8 _% ?8 himport torch.optim as optim
    . J: a- I0 S* ~3 Qimport torchvision* Q' l8 A# p0 x) t
    from torchvision import transforms, models, datasets, V5 r$ C5 t+ X  v" x% W2 M" V

    , \- B& }& |6 N# iimport imageio5 |6 _9 _' S* L$ {, U
    import time
    % b7 x, S8 e* y' ?  g  \7 G, himport warnings* l' g5 i8 n5 {/ A9 L/ j
    import random
    * J% x$ i( P5 e' T# D8 W) W# kimport sys5 u9 K/ E# O$ j, ?
    import copy* J. t) l; z6 F# A5 w4 i
    import json1 P5 }8 M) R) @& X! ^
    from PIL import Image
    0 P  o+ Y0 p5 w8 M- a7 g% Y4 `) X! r( j; y' u) u

    1 B. F% Y- s: E2 F4 O- p0 |1 c3 X1/ l9 K6 |7 u6 O
    2" e; y3 ]: q& L: n4 y; b4 g0 d
    3
    ; U+ E) f! j1 S: m+ `3 P. m3 E4+ h. `# P0 {# x4 {
    5
    # F8 z; r' C7 ?0 Q7 V4 Y. V# D+ u  x5 A6
    * O$ W, x$ v0 j: N  J70 [9 w# |1 J  j' `' K
    8! f' C- t* p" A+ q! g
    9) ^% O' C% w3 F2 b* b1 b' y
    10) A& p+ }' k  j0 I
    11" D* {7 X' E) V  B5 ?
    12
    9 H6 E1 `; Q- D: \5 T6 G13& }* u/ {5 h/ |1 F* P
    14
    4 t5 b; G  l* J/ y; ?# Z) b' e15
      r" M- k3 {. ~3 s1 j16
    ' m/ y' F% d9 t8 E$ g* d8 h17
    8 C9 w/ v- S$ R2 ~/ m180 S% N7 Y1 y8 |
    19
    1 X2 _4 `$ b& o" J* [7 l20+ k+ s, y9 }8 X
    21; Y. [' L7 F4 D/ y+ |6 u. Z
    2. 数据预处理与操作
    , u* S4 U/ O' d$ \, e2 P; W#路径设置) x( F/ Q' J( P5 m5 w
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    ( O+ f$ Q# K/ `9 Z8 ttrain_dir = data_dir + '/train'
    " O' e( ^. x" V/ Xvalid_dir = data_dir + '/valid': D4 p! g, }# X$ @; m5 y( g3 ?
    18 {$ r9 w$ @. U! A
    2
    ( e' P2 O( M8 V, `3
    ! k. W$ D3 w, _/ g  n: \% @8 N4
    / T: i, q1 [& \4 t' R2 _8 w) apython目录点杠的组合与区别
    % m  Y- w6 {  k注: 里面注明了点杠和斜杠的操作
    + d# k* I+ Y5 o7 R* q8 {. w4 {6 m9 O- }$ r9 k, o
    3. 制作好数据源
    - c2 R! w) b' T! d# wdata_transforms中制定了所有图像预处理的操作/ K+ V. w: ^( w! m4 z+ k
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    # F8 l4 D5 U3 B% k( {$ i; W$ c  x8 \data_transforms = {; Q: Q8 o, c7 Y8 I
        # 分成两部分,一部分是训练# ?0 \! D- r- ]  L$ |# F  ^
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间1 F' {# c& X/ F9 F5 b1 T' [
                                     transforms.CenterCrop(224), # 从中心处开始裁剪
    8 t/ c' D; W+ H3 d% V: E                                 # 以某个随机的概率决定是否翻转 55开' C, O7 ?) ^( H
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转- k- R2 G* K  ?6 {: ]0 L6 r
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    ' D0 Q$ K9 u7 {5 M                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相( K( F- _. @( B" Y( d6 q4 l
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
    3 A$ F( {/ s- N( e2 A9 g1 a                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB: T/ K, P1 G# T. S8 q, l$ ]9 u/ n4 [
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的- \2 M- O  H8 c2 u1 i! c
                                     transforms.ToTensor(),
    ; K0 L. ~1 |7 B; |  N                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差* T1 B9 A- [# x1 f
                                    ]),/ C, u* e8 ~0 p! Y; w) g1 M
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    0 v4 e/ S9 @: }4 t& q* J* k    'valid': transforms.Compose([transforms.Resize(256),
    7 s3 }% Y8 e9 j1 F                                 transforms.CenterCrop(224),
    / u6 h. U4 ~+ O( B                                 transforms.ToTensor(),
    2 w# Y, _2 s7 ^6 `2 R" w                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    & `0 ]5 U- W" E! P0 A5 l& f                                ]),
    / A. G8 }; v1 Z$ |4 j1 H( b}5 Y! A; T" f: n6 S

    # F/ G$ `4 R1 w5 w. ]1) `2 I" S: ?% t8 h
    2% l  Z8 \  }; D$ c/ _4 Q" E
    3
    + l  r1 v5 u" f; v, o, k' k4" b) M# T1 u. l% x4 @2 y5 ?
    5% H6 k* D4 _/ q5 H# ?* Y, i: m
    6$ M' F& q) A) s' @
    7* u" t/ Y. H8 O8 _5 ^$ z& G
    8: g. S. G! [9 `; F; D/ b; w
    9% _+ }/ e, d/ ~" f: V; K# i
    10  e; g9 A/ V$ Q0 _& Y7 W: ]0 X
    11
    : c) z- j+ K( @6 @9 R12
    9 K1 }: F+ Y" K' H& i: H/ w- t13
    ) T4 _0 N$ H2 l+ B14& T0 y7 [& _7 L  R
    155 i0 c1 k" {8 S1 e' Z
    16
    " ^2 W( r) |# A. h. [6 R# z17
    + y: P5 J6 |( L2 }2 ?6 `- X1 ~18! B1 x; p: {+ ~
    19
    * q. F5 x5 e: m, R5 H20
    0 M9 O! N0 \$ d5 C" r  W* V21
    3 }/ Z1 w9 t' G; Q' n. j) N0 Fbatch_size = 86 T# I, ^/ k) ?7 I( S+ ^' |
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}- q4 H% S' M6 z
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    , }1 h5 y1 T( G$ _4 U6 O& Ldataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    ) I* X0 I9 |8 L+ e0 ]class_names = image_datasets['train'].classes  y: j" ^1 f! ^* g1 N$ A$ Z7 Y
    $ e7 M4 A# F; c7 b1 X/ }5 m* z, a) n
    #查看数据集合
    + q0 J2 X) G! K4 W& _image_datasets+ l5 y8 k0 J7 N/ G; V( z( y

      s+ J3 C# V9 d1
    4 A$ [8 t0 I% _/ @2 S- P# O20 \2 l3 c# X2 Q5 E' o& V
    3
    2 a9 k& I6 Q5 I: s4 X0 X, d42 ^3 D! y7 F; u/ t6 p
    5
    " m1 j# ]$ u8 K, m8 J/ f6
    % A1 L+ [( m# q4 N. T8 o7 D! E9 [7$ n: p& f$ C" y# a+ B* S7 S
    8+ h  K% z  k: j
    9
    ! f8 z% I$ g/ Q, B, n2 T{'train': Dataset ImageFolder) c& U( N# d% t* t5 O
         Number of datapoints: 6552. `, U8 s6 F+ N& s6 x4 Z
         Root location: ./flower_data/train
    % F; R) Q2 a: K& b; h! l     StandardTransform
    3 y% I$ V$ d/ d6 P9 c8 V- u1 ]/ K Transform: Compose(# K+ ~  N% I9 \" R1 O
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)5 Q. ~" h/ e% B+ e
                    CenterCrop(size=(224, 224))
    . X) \7 ]) g7 d. J+ k" |9 ~                RandomHorizontalFlip(p=0.5)3 A, x. q# O2 q: m
                    RandomVerticalFlip(p=0.5)
    7 |/ m( p3 T( k3 D# b5 ]                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
      w) k( v: A% v  m1 U  v9 o* c4 s1 z0 w                RandomGrayscale(p=0.025)
    : |. p# `' A) T) }# d6 I+ J                ToTensor()
    - f/ K- x  q$ T, L7 ~; _) C" t+ h) Y: r                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])7 Z0 W6 p' @" J* u2 a0 t
                ),
    # r1 A' i; a8 `. k2 Z 'valid': Dataset ImageFolder6 n7 e) z: X* Q5 W; L0 _0 k
         Number of datapoints: 818/ X6 n' W4 d2 `& F% Z- u4 [
         Root location: ./flower_data/valid
    5 K. p: m3 S- j* X" g     StandardTransform
    $ e+ E8 L/ I! S! }! e0 b% q# V0 ~! I* { Transform: Compose(
      g3 o0 N# u  i! i# |% G                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    3 `  D/ M  {4 }! F; ^                CenterCrop(size=(224, 224))1 }5 `+ |0 j* ~" `
                    ToTensor()
    8 K- g2 H6 a, H8 |7 Y1 }                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # h. @  p* ~8 L+ Z' v/ }2 a8 j2 P            )}
    ! n' j7 L: ~& Q9 q9 N' w- S
    & e2 c3 ^3 Q* r1
    2 \; A# E, K6 Q+ i& g27 c6 y. @. B8 }! d
    3
    8 U- G7 R6 R8 v( P7 g4% H: h0 A: b6 K, Q9 M/ L
    5
    , k& ^' o  u8 V2 {6) h8 u; a- X9 L4 l6 |6 G
    7
    1 Z4 e+ d5 C, y6 T9 Q3 t* U% l8
    ) S$ m# E6 {/ q3 r9
    & e6 r9 D) |2 }% b10
    . r/ O1 j5 o# i  W( ?! r11
    3 S9 m: ?: w. l1 `0 |% n' ^1 ?8 E12
    5 J  ]2 M# }, F5 N( R13" e1 L" h( M/ i9 F
    14" H$ m2 _  `9 x
    15* ^1 }5 K4 \4 @- n' ?& X% |
    16
    ; t) S7 `) z, w5 a% ?17
    $ ]9 M# i/ b( O4 x/ i: T2 Q& o8 B18
    / x+ K3 D9 g7 Y191 S) A. S+ Y. o8 v3 w+ B% }
    20
    $ H, O; c; c8 O% z4 e) o4 z21% P/ o: {( J3 a# n. w
    228 k; k5 W: y7 x
    23
    + e- e. u2 z% d- A# E: Z" t24
    * ~& y  P1 W% |6 z/ \, u: t0 _# 验证一下数据是否已经被处理完毕# o  V2 u& }* A* p% c
    dataloaders) l' k) z/ i" w7 c& H) y
    1: B( L0 g* L" o
    2( E( o, z" P' M0 H! E$ t  U
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
    5 A: J7 e3 n, _$ |% l 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}4 H; N3 a, a  @$ u7 E. ]( L' d" c
    1- ]( Q7 g! n( ^$ g- d
    2% ?7 T' _* B  I  o3 e
    dataset_sizes
    3 P1 f  h( ^; P5 H7 x1# @- Q0 X+ f1 w% m
    {'train': 6552, 'valid': 818}% f" u6 B/ G, N. w9 L8 t! w
    1& V; [3 f6 m0 F0 ^7 x& j( [
    读取标签对应的实际名字
    ! \; P: Y& S4 Q/ {0 n4 h' j& n) Z使用同一目录下的json文件,反向映射出花对应的名字
    8 Z: Q4 n) G+ D$ I9 j  b% u1 P' [6 j  Y" y8 U% e
    with open('./flower_data/cat_to_name.json', 'r') as f:0 A6 u1 I- v( }' N
        cat_to_name = json.load(f)
    * K  u- r8 c2 D' e6 _3 i1
    # ~  x, }' q) ?; j2
    9 Q* \" Z3 w/ b/ c" w" Q3 b. ncat_to_name% @9 P8 H6 n8 g( {( Z( D3 ]4 Q
    1, t% A# `$ y3 D2 @5 U7 w, e
    {'21': 'fire lily',3 ^% u- x- f: I& f
    '3': 'canterbury bells',$ H  G% B0 t7 n6 v
    '45': 'bolero deep blue',
    ) N) m  s, Z7 `& ^6 J2 y6 i '1': 'pink primrose',
    3 Q! ]" U! H, t4 m5 d4 F '34': 'mexican aster',* Z4 l6 @4 k" P; r: D( X7 v
    '27': 'prince of wales feathers',1 g9 y. [; M) Y! o+ ~) s: R
    '7': 'moon orchid',
    ) U) I, I% i: v! E- F '16': 'globe-flower',
    9 b4 c! t9 S4 _5 N" D6 g& D '25': 'grape hyacinth',
    , `7 Z9 G/ v! k3 k '26': 'corn poppy',
    # F, e7 n. u% k: d+ e '79': 'toad lily',
    ) m2 B0 V6 N% n6 _3 X( u# ^- p '39': 'siam tulip',9 Q; M3 S, o) Q5 B+ c( u% C
    '24': 'red ginger',. O9 {* I$ F& Z  H9 l! U# w3 w
    '67': 'spring crocus',
    2 ~$ l" h  V6 ? '35': 'alpine sea holly',/ j5 e3 E! J$ D$ ^' x6 C& y+ O
    '32': 'garden phlox',
    0 f6 G9 k8 V8 i* b '10': 'globe thistle',
    $ H  J( d2 D2 X7 u+ r" Y5 |* T' H '6': 'tiger lily',
    9 \) f  N+ e/ Z; Z7 q7 u '93': 'ball moss',
    ) f9 [- b# I  N. A) ]( r% r2 |9 o2 _( b" K7 h '33': 'love in the mist',6 ~% O: r2 u/ \' ~4 ~( `+ p& C9 S
    '9': 'monkshood',/ F* K4 U! G, q; v2 x2 S# u: Z( X
    '102': 'blackberry lily',
    0 a, `9 u$ G! L6 [: p) L# | '14': 'spear thistle',
    4 j! n$ p" S- h/ e* S '19': 'balloon flower',& P3 U5 @' I; Q9 F7 m
    '100': 'blanket flower',
    8 B% S0 ]- F* d# } '13': 'king protea',# W& \$ f6 ]$ z( ]+ `/ P6 z8 a2 E
    '49': 'oxeye daisy',9 w: b0 F2 `% ~3 O
    '15': 'yellow iris',
    2 E5 Q+ z' Q- X  q- _+ X" t '61': 'cautleya spicata',% q& N: ?1 Z; m& i
    '31': 'carnation',
    / H6 V: D* |, ~7 Y '64': 'silverbush',
    . p; p) y  d3 J" Y& M% s; w '68': 'bearded iris',
    , m# D' I5 A" R: I4 n( W '63': 'black-eyed susan',7 R. @+ Q# Z( Z5 i
    '69': 'windflower',
    7 T8 j" H! ?0 s: a1 _) k; h0 R '62': 'japanese anemone',
    ( m& t0 D) D0 u1 O4 H9 ~/ N1 G0 i! w '20': 'giant white arum lily',4 a! _4 n& k! k  Y0 ?
    '38': 'great masterwort',
    ' E4 V$ p$ c3 q* ~ '4': 'sweet pea',
    % b" b2 z, c& \  F '86': 'tree mallow',
    7 u' e" m: n7 k, y& C8 z '101': 'trumpet creeper',
    ! C7 S) x$ t3 Q4 I, V '42': 'daffodil',
    / Y3 \0 N+ F5 C3 }7 j '22': 'pincushion flower',
    , R! o( ^4 ~7 n- f9 E- N$ v '2': 'hard-leaved pocket orchid',
    / Z7 v$ T  f' c( Q3 C. } '54': 'sunflower',( r+ r+ {) g) l# E
    '66': 'osteospermum',
    1 n1 Q* ~, v" ]6 {" ]& j '70': 'tree poppy',
      d8 A* T, ?& T7 i% ~! m '85': 'desert-rose',/ `4 M! R% X6 A6 W0 w
    '99': 'bromelia',+ x$ w, e# P" b7 g3 x  y" \
    '87': 'magnolia',; M( a! {% E  z
    '5': 'english marigold',8 W% ~3 ?  I# [: V3 [
    '92': 'bee balm',6 r  x- h% C' h
    '28': 'stemless gentian',- C0 m0 _- G5 p2 [# J' r  y
    '97': 'mallow',
    * O- Q; @2 ?7 [. w '57': 'gaura',- `4 ^# L* ~8 x5 h! K6 F* A$ o
    '40': 'lenten rose',
    + V; u6 s4 I0 U- ^5 }1 R3 ^ '47': 'marigold',, X& _7 g. \) V  A$ T1 P6 {
    '59': 'orange dahlia',, ]8 E' }: L( n* V9 x! k7 g
    '48': 'buttercup',
    " g9 [9 J! v! O) S+ t '55': 'pelargonium',* W2 k1 U6 C. d/ U4 w
    '36': 'ruby-lipped cattleya',
    1 W' P" |1 A$ l: Q0 y0 l& _$ H '91': 'hippeastrum',7 B4 D" t: Y( Y( ~
    '29': 'artichoke',) b, ]7 T3 ?0 K- T! |
    '71': 'gazania',
    0 V. b2 S9 L% b6 _4 X '90': 'canna lily',/ v8 S# }. `8 l: z5 ^
    '18': 'peruvian lily',) c) L% A, A2 S! a, K" g: T" u
    '98': 'mexican petunia',! Q5 O8 D9 q9 X* ]; c3 D
    '8': 'bird of paradise',
    ! u4 P2 v2 o7 x* f6 J2 e '30': 'sweet william',
    / X- o3 C5 f6 H# u  I '17': 'purple coneflower',
    ) ?1 x& |2 J. @" c+ a '52': 'wild pansy',& k, t' W6 ~; w
    '84': 'columbine',  p' d0 Y( p8 M
    '12': "colt's foot",
    9 q% k% J4 |& U1 U3 P4 H '11': 'snapdragon',
      q: _0 `; d; M '96': 'camellia',3 ]7 [  ~2 K$ H+ B! c7 ?
    '23': 'fritillary',
    # F' g& T) L! L1 X/ U9 H '50': 'common dandelion',5 {" A3 r, r+ W% C  r1 U
    '44': 'poinsettia',
    + _. f% P* ~8 O* F) G3 } '53': 'primula',
    9 R# {9 A3 g1 d3 W; m: P/ e  V '72': 'azalea',
      n0 r( \# M6 ?4 ?5 c '65': 'californian poppy',0 m& y( j% j6 x6 P
    '80': 'anthurium',: X; `" h$ u3 b1 r# }
    '76': 'morning glory',
    3 k6 V9 C4 N# i2 a5 I '37': 'cape flower',/ e1 H' c8 K4 X) _
    '56': 'bishop of llandaff',2 ]2 C! i: N7 L9 I5 W) f
    '60': 'pink-yellow dahlia',
    & D4 ~# Z! ^  X) [; U '82': 'clematis',/ y, c" {( X$ t$ V) s8 C
    '58': 'geranium',
    7 V5 U$ p3 t+ z2 Y4 x+ N '75': 'thorn apple',
    9 N2 `1 [  q  V2 o5 C '41': 'barbeton daisy',
    & k& I! C0 a7 |( [" o '95': 'bougainvillea',
    . u5 R4 ?$ p- E; U& { '43': 'sword lily',
    0 Z* Q9 d3 R' L; N+ Y1 F '83': 'hibiscus',
    0 `3 ?: M+ f* ^2 S$ m '78': 'lotus lotus',
    ( R6 y$ X1 L( | '88': 'cyclamen',$ n( ^/ y7 S. Y  ~/ U/ D9 P- u
    '94': 'foxglove',
    ! X3 I) e+ S: {3 U3 N '81': 'frangipani'," Y- M- q1 ]4 J5 Q
    '74': 'rose',
    + ]1 @" @( x9 I '89': 'watercress',  F; T0 L$ E' o- ]- y" y
    '73': 'water lily',) Q5 H! N8 _; C+ f1 d" H
    '46': 'wallflower',
    0 @' M. b' ?) B '77': 'passion flower',
    ' m" ]% O) f( s# c3 ` '51': 'petunia'}7 \/ W9 a* k, X4 s$ W' C( o4 m" k2 G
    4 i( r1 H$ h2 a, [, `
    1
    ! F& ^1 B- K$ m! c  ~. p2
    ) S, Q% }9 S$ R+ S+ X" @3
    , A( r2 R) [( B5 i& P; X+ P" h- D. j% S4
    ! i0 @7 o6 u5 a- _1 c, g5' Y% f- }, q( Z, ^9 T" \* x4 S
    6  g6 y) Y# g' G) C3 J9 ~
    7
    $ I& q- w7 a2 R1 Q8- H" J, m# i, _7 |. @
    9
    & ]+ y6 w2 f8 h( t4 G10
    $ C5 P7 E  s8 Y( \! N: B114 q7 d9 _; S* l  M) S
    12
    6 d) \& ^9 h% p9 T, \& l4 c; k13
    " D& x5 `7 X, L; [, V147 F* w8 r+ `$ @2 m2 r
    15) S9 O  d( j: I
    16
    , R% e7 o5 ^, z$ r17! U6 V( E: {) h3 @% D
    18. m% L% P* r# l" @
    19
    ) k& ]4 ]: S$ U$ M. K20
    % `0 ^& e. {; F0 k+ ]: ~21
    8 k4 q: g# c) Q( c( _9 L: p1 [2 s22
    ) s# Q$ N/ B4 k# b23
    + K8 u. @- _" c0 G24' q& H3 o) u! r  y) n+ x
    25
    * c+ F3 S' f: H! f; x* h4 o26
      Q! ]5 N7 _$ k$ M27* O* g, P% d9 }6 g, o
    28' G7 ?2 d' i* D6 M7 f! D5 M, F
    29' }1 l& ?/ w0 N- Q: G% w5 E
    30
    $ U1 K1 [) ?5 c  q% R  u& ^0 T31% ^/ U8 D6 |$ A$ J( j& I. S
    32
      G5 S& R! ]% f% K, r2 ^33
    $ F6 G  {0 Z3 _9 ~+ v34
    - ]5 _' {# z" d' H6 D5 k- ~354 k: b  R1 a4 q! N* d
    36
    6 f  x/ U) B+ ]8 ~( B% n. [) ]379 }" E# w6 I) v2 G, u
    38  Y; I3 k+ P* k$ l5 H- F
    39
    ) G+ ~. E+ ]* N40: x9 P0 R( C# g0 w
    41
    / q% J4 Z4 `6 \420 n" _- K% J' G+ p) D% z4 o: I
    43
    3 c7 g; `1 T3 @3 ]) ~" E3 i44
    0 M! r5 a7 W: ~7 V( z6 K# z  ^45
    ! X, B$ b! ]; r5 |" w/ D/ D46  B& ]9 M" @+ t) s, ~
    47; D/ }- h) f1 Q" ^. w6 _* h
    48
    2 E* I- F9 t' r- H  r49
    * |' |! O) [! ~3 k509 V4 z2 U& U. z1 j; n4 y! V* k
    51, ~* p4 L* q. x- q
    521 u  u8 Z8 g% [3 T! S
    53
    2 v' b, f+ R+ j+ g  ?" z54/ k# z( \+ w1 e  {7 c3 p/ e5 @2 g$ Z
    55
    $ ~* V1 |2 `& D  a' ^5 F56# E, B% B5 j5 y( ^9 r0 G) f
    57+ ]5 z! x- B9 Q8 }0 |& M4 B
    58
    7 t/ B' V8 U7 ^% |% Z$ E5 P) ]59- C* R2 l# Y; H2 k. d4 j: r5 b! s
    601 G$ T, Z. D6 [9 o! H: C
    61
    7 Q. Y, o2 p2 O+ o1 h$ i, o620 R- X1 u$ e' |& G3 A2 m# @9 {' g0 O2 e
    63
    ) C  J  ^- Z8 W$ r! J64
    % ^- f* Y! d& w* \+ Z65
    ! y) @/ g% b8 \9 y+ t66% F/ H2 }$ t. `. s/ }+ S  }. ]
    67
    ' V, a' U5 }+ A( D% K687 o3 g9 i! u  A* S9 T& R2 l
    69
    8 i+ Y3 x1 \! K7 k6 d705 A% a+ W; I+ \" H
    71
    8 v0 q* R* a& l/ _72
    & Z8 p( ~& c: C# m73. Z2 i1 g; J7 u. u8 R
    74
    ! W) d9 ~) G# i! i6 _6 y75* e3 k; D: c* C; m2 E" f
    76
    . w3 ]% I4 ^$ {4 L2 l' U77
    . u: `3 q+ Y; n4 t, s8 B78  c; n; E- ?+ @
    796 ]5 t/ i* q) j9 u
    80
    ) k0 a0 m9 }1 \: ^81  g, ?! y* c0 N2 B5 t
    82! x, U" }2 E! Y) w6 J, r
    83
    5 z5 l2 Z6 `+ [$ e# m8 ~5 j84
    ( u& ~0 ]% q$ V- u" U& A85( M) t4 b6 }1 |, v# F
    86. k0 ~. C# ?. n8 `
    87' |% Y  P2 \: `- ?2 }* x( m& L
    88
    - b/ k- P! i! g6 t891 t- ~  P. S5 f, ]
    90
    : V" X. ?6 G# B0 ?# N6 {3 u91+ }  x# I/ z5 b1 ^: _
    923 _) q6 G# }) l) A- d
    932 ^; O# b6 ~3 J7 l' V) i  j
    94
    # E7 ]. I- y" `$ J7 g" q' p95( }6 t9 `8 o* {7 f* o
    96+ j9 i5 L4 v8 S5 o; h% B: i+ O4 }
    970 }% h* M2 t4 Z: q' ~
    98
    1 ~2 d- V# s# ]7 _# ^99! k" K+ `# M; g% O( Y
    100+ ^" Y4 g1 a8 w8 I$ x0 A
    101# E& f% o2 {0 [' a* B
    1024 w7 \% z0 i: S' U1 `# P
    4.展示一下数据+ X1 i! R/ ]1 P4 N" x0 |! v
    def im_convert(tensor):
    7 |" N5 Q3 K9 Y$ Y    """数据展示"""
    5 q9 _+ J3 }7 R! T# D3 H    image = tensor.to("cpu").clone().detach()! G1 B; a0 k" v8 P- E
        image = image.numpy().squeeze()
    7 i) ?9 b+ B+ S8 X4 l    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图8 B/ ^3 e% a* k# Y* y
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    * `5 x' P# B) s  Z    image = image.transpose(1, 2, 0)
    2 g# {) H( t1 B5 X$ c' J    # 反正则化(反标准化)
    ! G7 s2 {* P; w    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    1 v- {4 P+ L% N8 t* t
    - o6 z3 m$ h" R; C2 n7 ^    # 将图像中小于0 的都换成0,大于的都变成1+ D; m: W. h- F1 G6 b& [  l" r
        image = image.clip(0, 1)& V+ c, R4 h! z1 \4 U

    ) K; C( X! A6 V( _    return image3 j3 q4 f" G% F
    1
    ( c7 q8 o$ Z5 q( F26 S% Z9 r2 ^' H3 |
    39 x* t* m( B) [2 J9 {6 g& ?
    43 {/ i6 [  s2 k
    5* o/ N+ l- G8 l4 s  h
    6
    8 W+ Y" h. L9 h4 `- \/ q- E! t2 h73 |- V" }; i" k5 f, T) B1 n
    8
    ( S( S% e! X$ y. r. Z' T7 p) B9
    8 @5 w( e: Q; d! y10; V# o6 m( H) {
    11
    " ]5 |# L$ ^" |7 t- M: G5 C8 n12
    0 [8 h: T& R- ~13( R: @/ K5 \0 [& f& k* b2 l6 X2 p
    144 x: e' G5 P  t4 M
    # 使用上面定义好的类进行画图
    . B& p2 V$ }8 U# m4 D/ t, m/ ifig = plt.figure(figsize = (20, 12))% m, o- v8 J0 F2 W% d: N( M
    columns = 4
    9 F$ l4 P1 y- X0 O; I' U: Trows = 2
    ( F, h0 X3 [/ D9 A( f9 ^2 s; c: F. C
    # iter迭代器, t1 S# @- A+ ~0 f- t
    # 随便找一个Batch数据进行展示
    4 R* U. R" h* _% J. Cdataiter = iter(dataloaders['valid'])9 f9 S6 D# y* x% ^& P
    inputs, classes = dataiter.next()
    $ Y7 W( n2 C: G" v  F0 P7 ^
    ! S' N! \4 }: a7 O% Z6 T8 ?. ofor idx in range(columns * rows):; h% b2 t1 l: Y: y/ z
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])- U0 i6 o4 n; o( }/ Y
        # 利用json文件将其对应花的类型打印在图片中6 v0 |9 s# [3 W
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    9 g0 B  |' D4 g. i7 S5 x    plt.imshow(im_convert(inputs[idx]))7 s* H, A7 G9 @- f' g9 ~
    plt.show()7 J# q6 \2 G6 {9 ]! u4 T7 Z
    % x" A1 b6 ]5 v( ^" m- p
    1
    ; S: n( O; P) m% T' Y21 W5 z2 F8 B6 D; g/ d( o
    38 F6 H1 Y# |3 ^/ Z( R4 }* k
    4  q# i$ F& Q3 g2 t+ R
    5
    ) }' Q6 I! j5 q  M4 U9 M, s. J) w4 W6
    / j9 Y0 f/ x4 h* N7 m7 o$ t7
    + F. P7 ^* R! p& d8
    2 e" i) k  Y# q, ]) t3 C9
      J  ~; h: ^! s0 Q- S. u10
    * o1 u" f' Z. @- g: P116 v) t3 F. K; l$ a. `; H
    12) r" G# i6 d5 M6 h2 e
    13# ]8 p6 H2 n( ^1 x' r5 v
    14
    ) m1 o/ p' P) z/ `+ {15; o* ^- B, m8 Z, [
    16
    ; X5 `  [0 Y: n$ z* K
    $ r9 g/ k* p% ~4 Z/ W. }1 g" j4 F; b* @3 ^/ y
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数
      M" i# X6 }: Q" w8 N, N1 u# Bmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']% A4 m8 z- H$ [6 J3 m. o# j
    # 主要的图像识别用resnet来做
    0 B; t) r: r1 r+ u# 是否用人家训练好的特征
    $ q! L0 p! r% P) i( Hfeature_extract = True$ W2 G, D* I; C1 P- @0 [9 g5 n
    1+ [! M4 U" b; {% J- i4 o. ^
    29 h! _. E* G) R/ D0 {3 B7 J7 n, L
    3; T# O5 m+ D& ~1 \4 g- g
    4
    - {1 |; \% S! I# 是否用GPU进行训练
    1 T- k" O# Q8 a4 H  qtrain_on_gpu = torch.cuda.is_available()
    . ^3 Y2 f( }0 l2 Y( }. z& D2 S' d$ W" Z6 \+ U) N8 S
    if not train_on_gpu:8 S9 R# a3 o! D, Z
        print('CUDA is not available.   Training on CPU ...')
    7 |, U! v: p# ]else:( L) h" y6 r" f0 z+ s" ^
        print('CUDA is available! Training on GPU ...')% N! ~8 N2 v$ N6 Z9 J* u! s8 o
    + ?4 Y, k' e3 I. W
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    0 h7 |6 \. o: r* f& l2 {1: k$ m) Y4 n  \3 ^
    2
    5 a' C" V1 J! w36 h6 [$ D0 g6 l3 y
    4) `4 ?3 o8 C2 i6 I! q
    5- j: d- v- o+ U: S0 l& i7 u
    6
    ' z4 x" h, b* [  R* |- [, E/ }" U( m! I7
    0 r9 m0 a+ L0 I0 U8
    0 n& L, i% i# K/ w- k0 S8 Y# Y9; J' w; o2 B- k
    CUDA is not available.   Training on CPU ...
    1 I* U! r/ T8 b" r5 b1
    , J, ?0 j4 L7 H" o" w# 将一些层定义为false,使其不自动更新
    ( q' W- c$ J" U% @6 N" s& u0 S& U1 Sdef set_parameter_requires_grad(model, feature_extracting):/ O& _3 J- ~+ ^$ y' e$ w
        if feature_extracting:- G5 R. B( ]) J4 W+ j" i
            for param in model.parameters():! g$ B# }  [( I0 S" R
                param.requires_grad = False9 W" B0 j# R' |( V' v0 w( F
    1
    # d8 b' c- G) J3 I1 x4 ~2" R# y7 }" v2 Y  f( k! G* t
    3# v2 W" f. e# v& Z
    4
    + T# I: a$ `7 D3 S5
    % b/ K, Q8 @4 _  [! l. Q7 _( q- _) m# 打印模型架构告知是怎么一步一步去完成的
    6 P- q, u4 I  J4 ]/ i; [% a6 {# 主要是为我们提取特征的
    ; @( X3 e2 L8 }
    5 [+ _6 h' l" r3 ?model_ft = models.resnet152()# H2 e" u9 B) R" X" j! _
    model_ft( r2 M/ u$ C- P6 G8 h& ]' G7 u
    1& J( Y) i6 H) ]- R* ]5 I; L7 ?  a
    2
    2 w6 _) t; m* x( x$ t$ s+ `3
    6 o+ `0 B" E. ]4 u2 @! {* E+ j* i4
    & d% V2 X1 ?; f8 o5
    $ M  O1 \6 X( t, L0 M' \" jResNet(
    / j! b7 ?2 h$ V7 T  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)8 R) w6 ?- z$ d. M( y+ b, D+ p
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)$ L& {$ B' R' {$ g, T4 w
      (relu): ReLU(inplace=True)
    " }) d3 |5 i' r, q- @: }  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False); `& E  T  M" t9 I7 f) z
      (layer1): Sequential(
    ; d1 A+ ?8 g% l5 @    (0): Bottleneck(' g# d/ u4 `6 j9 q% r
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    5 V+ h1 v# p8 m7 T      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    7 d* @5 }& U- P) L3 p      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    . E! j( U  {7 x' A5 y5 x5 `      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)+ P; Z6 \) O0 c9 Y; D' [
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)& ~# F3 k. U" q/ v
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    1 Q5 R. V. {3 Q1 q      (relu): ReLU(inplace=True), W6 A- y+ v9 N# t
          (downsample): Sequential(8 }7 }5 [8 A% _7 k
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)- H( W! k2 N" d0 g1 ^2 e
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    - }; x8 z4 f4 r4 g7 p# T2 G5 k      )! N" {! ]0 [& ?) [1 Z: l
        )$ @: q2 S3 N5 E& N0 X9 l& ~" B
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。( G4 v" M: W( }+ Y. D
        (2): Bottleneck(
    ; m$ j9 }3 H3 V% U5 m. W      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    5 |: K" F* ?, _+ s( g) K- W' p- y      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    9 ^# O( U7 J- w/ I% T2 X" }      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)6 Z# k+ ~' l- i  ~$ _' v' Z7 \
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    % d( U2 J; `) h4 s+ w  c* D( _+ ~/ j      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)2 T% k" m& M1 c( ~6 x
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    4 K, m8 G1 R+ E  z      (relu): ReLU(inplace=True)
    % O  x( h" x: o0 g4 h) ^    )
    & F, w3 @* W! l8 _( L  )+ v  c6 j/ k5 }( ?7 L
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))5 K/ ]: a" H$ E. ~
      (fc): Linear(in_features=2048, out_features=1000, bias=True)
    , m( m# E  e% u0 a)
    ; l" e5 T6 x. V; n5 Y! Z) J" z0 \+ z1 B$ u
    1
    # U6 B4 E+ y4 n# A2
    ) j+ z6 ?- s: Y7 ?5 y! K3
    : Q" B) M# ^( v0 ?0 B5 ~' g; a4
    $ D. y' [4 j- g5
    + B3 [/ A* C+ m* n6- g. X; O/ g3 K( d) x
    7$ i$ H( g) U1 Q# y: W2 W
    85 b# w9 S9 s* v! e
    9. P: P6 W3 e$ s- |2 X
    10
    ' J, f% J6 P1 c+ W  w2 A0 y11
    : }+ Y  t2 i" D! k  D. g12# Z3 Q6 m/ \" Z- j* Q/ |' c* x7 N% u3 Y
    13
    4 B4 K# K& F! B# [# ?: \14& U1 ^2 U4 W- y  C# T8 |4 s2 h
    15
    " i" l2 W5 q8 H6 q" @& ?- V3 [16
    - r8 ^: X1 m" ^4 g7 e' A17% T" K' U: z+ f4 i
    18, e# q3 S* t# P5 N$ l
    19
    0 B! t  X  ~# F4 _3 ?/ n208 |* y' m) V& h
    21( h  z8 d: d, z) d6 m/ k
    22" m/ c% @* N# Y. O( q
    237 @" p3 ~2 D: o
    24
    8 D! y5 o1 q% u  Y9 k) f( q25
    ( s" L& X6 ?6 B- U; [  f26# ]4 x0 y2 D$ ?* R6 Z
    27
    , I9 K* S4 J) y6 B" w' J; r, ?$ d% O28
    ' ?) U/ V+ a9 d' P% V29) J, x& P% S4 d5 P# s! y
    30
    9 L0 Z2 y/ g( U$ S* g31) F- S- v$ R- k' C6 A3 @
    32! v. V3 {% K5 d6 }! \- S3 a
    33
    + n/ d0 f0 `0 Q4 }最后是1000分类,2048输入,分为1000个分类, p" V! m7 g/ f* }( {% y
    而我们需要将我们的任务进行调整,将1000分类改为102输出
    ' @& D/ _( [' x. K* C5 U2 G
    $ f' U. L$ H' v) U6.初始化模型架构
    " O1 W7 r1 c' T; A. U3 d: g1 Z: q步骤如下:4 p" n% o; y" [$ _# ?; }

    . {4 }5 g& u% A! N) \* ~* G5 p将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
    2 l/ z) w2 Z0 |. k可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
    % n- J9 Y8 n/ ~) C无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数6 U- t( ^# h8 E* p: N9 P# a0 C! s, T
    官方文档链接
    2 h- t/ A4 F" m+ Ghttps://pytorch.org/vision/stable/models.html
    $ |" S8 g: n& {8 ]' m6 q/ x& @
    1 z: X$ g! v$ e, k! p5 P  Z# 将他人的模型加载进来
    ( u% d& d  @( s: o0 t7 ldef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):0 X! p) I" j1 B# X' z( z
        # 选择适合的模型,不同的模型初始化参数不同
    " m5 C2 a9 E4 s! P    model_ft = None5 e! u# z4 z, f  P: m+ f) G( o0 _
        input_size = 0
    $ \1 `" c+ u. r$ E: s2 j; m3 D/ |, _' z  D% Q
        if model_name == "resnet":
    8 x8 h9 J8 C2 p0 C        """
    : S# F+ W) g' q) b; c6 r        Resnet152
    4 O/ |: ]8 ?! U6 {; |" v        """6 U* L3 ~; A! o$ j2 {
    ) F! A& v8 I- J# ?; H7 `
            # 1. 加载与训练网络
    1 `. W/ h# V; t2 D& D        model_ft = models.resnet152(pretrained = use_pretrained)3 f: x& o7 H9 t6 N2 b( R% L6 p
            # 2. 是否将提取特征的模块冻住,只训练FC层
    3 o: n4 V' T0 @4 q. v0 v        set_parameter_requires_grad(model_ft, feature_extract)
    & I' K/ S# w. o# _5 M        # 3. 获得全连接层输入特征9 Y( X$ g, ]. }- @/ h* }: a6 R
            num_frts = model_ft.fc.in_features
    & G( z' k/ i  i' c) B" J( W/ q# Y, S        # 4. 重新加载全连接层,设置输出102; P3 E  P- r. i
            model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
    * }8 ?# i0 I5 r1 G                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    ( F/ \  t5 |1 R* u! r0 d        input_size = 224# C, c& F. @4 q! U) _: n! I
    9 d' b9 L$ A) \5 v- c3 Z5 `3 o7 E
        elif model_name == "alexnet":
    % ?* O5 w$ H& O! I3 F, W  }& m        """
    " u4 k6 M2 U$ N/ C) W        Alexnet' R( U/ ]  ?* J& w
            """  h3 N, y" k5 {* g6 Z2 _
            model_ft = models.alexnet(pretrained = use_pretrained)- C0 @  b4 `$ a5 e
            set_parameter_requires_grad(model_ft, feature_extract)# W6 s: T; i; o& l' L
    ; K% D/ ]' y5 ~  s) G
            # 将最后一个特征输出替换 序号为【6】的分类器
    2 |( N% ]& h& r* N        num_frts = model_ft.classifier[6].in_features # 获得FC层输入
    , F" A9 }1 D' q% r& s4 C        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)4 ^" ]5 H8 w1 B: _8 R. Q4 D/ I
            input_size = 2246 W$ u; e; ~, n$ r7 I3 _
    9 S! I3 Z! g* f0 E9 n
        elif model_name == "vgg":
    3 l" n  C6 z+ k2 k! B        """
    ( A% S! ~# K2 ^/ O1 X% Y- m7 K        VGG11_bn
    % F- K, J4 X4 q- @* T: |        """) v8 c  x7 V9 d
            model_ft = models.vgg16(pretrained = use_pretrained)- ?: A) T9 m. c$ l1 d+ z  q3 i' k2 ^
            set_parameter_requires_grad(model_ft, feature_extract)
    1 d) Q$ i1 T1 u7 Z0 c        num_frts = model_ft.classifier[6].in_features
    . n4 ?; |0 Z; K! V) G  y        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    8 V4 }2 \2 h- P9 ?% H9 z3 _0 f( l        input_size = 224
    ' W/ Y2 ~$ v6 e. Q( v' ^
    - ?0 c% B5 p, @$ F' }1 a( s    elif model_name == "squeezenet":; ~3 k5 Q+ l4 m  k
            """% D2 K; d  _" t. E7 ^; w9 O4 a
            Squeezenet
    " {6 y5 J2 _7 u6 G5 _        """: x7 H: T' C) Y3 \7 l/ C% _8 a, _( ]
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)
    1 |) l7 A6 l: R& _9 s        set_parameter_requires_grad(model_ft, feature_extract)5 m/ O& @8 E; m. s( j# S
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
    - H8 j3 P9 b  A9 f        model_ft.num_classes = num_classes9 [5 x+ J, o, ]: E
            input_size = 224
    " W# @# N  E, I$ w: U0 d/ P; }3 z' U+ R% ^0 ^
        elif model_name == "densenet":: o6 e0 I' j) P' |  U9 Z
            """* \$ H- b0 D4 l
            Densenet
      ^, j0 K( ^9 A8 [        """0 T, r5 s2 Q. \
            model_ft = models.desenet121(pretrained = use_pretrained)( u- Z* ]# k' \" @# _& J
            set_parameter_requires_grad(model_ft, feature_extract)6 d5 D/ R0 z& M- H7 V  ^
            num_frts = model_ft.classifier.in_features
    # A8 @, O5 l8 O) u        model_ft.classifier = nn.Linear(num_frts, num_classes). }- A& E' U7 l  Q
            input_size = 224
    , }4 p4 `/ D( w7 L& p) w
    6 @2 a! D' O! v4 M/ g% H  d2 O    elif model_name == "inception":
    9 S$ a, J& I1 H# u1 l  A        """  o/ M2 ]! `, r. V+ g
            Inception V30 U' V( i; v: ~' u
            """1 q" E. T9 Y. }+ }. [$ u4 U
            model_ft = models.inception_V(pretrained = use_pretrained)
    ! T2 C5 p& V; o/ N0 `, `# L        set_parameter_requires_grad(model_ft, feature_extract)
    ) i5 v" h) b5 T* t  K  s3 Z8 Z- r" ~) X7 h
            num_frts = model_ft.AuxLogits.fc.in_features
    , F# W, {. `4 o' Y3 o        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
    & F5 x; b4 e' {% v8 \+ w# z1 R. t- C) A
            num_frts = model_ft.fc.in_features
    8 _6 F0 t8 Z3 y% [        model_ft.fc = nn.Linear(num_frts, num_classes)
    : z  m' q  C5 u4 @; a' [* c+ k, V        input_size = 299. P1 c- K7 d' K1 |7 L' }& c  P! T

    , y# _- p3 w: H4 c* d  `    else:
    0 Z4 T: m1 k/ K+ o! h3 X        print("Invalid model name, exiting...")6 f; j+ y0 N$ ^3 G% N  N7 @
            exit(). y: \- O8 u0 |! _' Q# i

    1 p0 G* W, ^& x+ E  q    return model_ft, input_size# ^; \# l, O) |5 F9 e% D
    % I  d* r; D8 ^+ H5 Y" o
    1( n0 N  n3 H- K6 c9 J
    2
    0 n' Z' H/ q; ]# A& I9 ?3, e/ D. J- p% n/ x0 i; w. _
    4
    4 W3 [( P/ \! Q! C3 A, f+ k4 s* W7 m: _51 Z7 l) y. o( w. Q
    6$ t7 W8 @) X! ]+ j8 `0 i
    7
    4 D  g3 o. [, w$ @. m: q. v8/ h* \  X3 {+ W* B/ Z1 b
    97 q. `2 V9 y* K; ~
    10
    % A4 J( U% Q. m* M: g# }( S110 _4 k" G! d9 q0 ^
    120 ~  k% j: z' t6 @7 [
    13  d0 M2 [- m- i- S
    14
    2 ^0 z, V; n3 |( g159 ~- x/ P, `( x- y# P" N
    166 {* x" I" B8 J* B
    17" M( Y7 ~) v& y% F& L; Y* _5 n
    18
    & F1 {: u# v% G8 g$ \# O& R$ ?1 ^- z196 y4 ^! _- I& P$ Q: y6 X
    20
    + e: t/ d6 A1 U/ E4 @0 R6 y% e21
    3 O' I* r' I  w5 r0 A22
    - q# Y1 Z% s9 |! ]23
    2 `( f6 ^$ @& M( E- |$ r24% f9 Y# j. l5 }( c2 t+ l. S5 q
    251 T" ?, {* ?6 L$ K
    26
    9 m* p% a: r" B7 Z+ p  s  z8 E27
    ; {, a4 {1 W; T5 y$ m2 n7 u28
    $ e7 x; t) U- n& X' L29. t! D1 O* @0 k( f
    30
    ! m% g  I4 t, K! C' U# W31
    7 `- o9 g1 h! W  [, S' ]. `32- s* ~: e! J% ?" c  ?
    33
    5 p$ t5 K" _( M0 _) g/ q34
    % ~' G6 p/ b/ m! G- a35+ P: A, r( E0 f3 [$ r1 n
    36! J) m4 E8 C3 ~/ B5 p
    378 k" ?3 f$ C" o0 ?. G) i8 M
    38
    " A: K, \5 G* ~# I2 W, Q2 V  s39
    , A; w" h0 ?% a+ Y/ W  L40: \* i8 r$ L* P. k- K
    41+ @2 w* [( H' G* {2 L
    42
    * A8 D8 L4 X3 `/ I43
    6 O6 ?& m* R- T, h- @$ ^7 G44$ r0 Y1 Y) E/ j/ i# ^
    45
    , u# {) X3 h3 f, W! e6 {46
    ) o5 {+ K. M* Z! G1 X! Y47
    ! }- g/ Y. G& P+ N3 u48
    % o. n  k* ]3 u1 z49( |2 L4 P: N, V$ ~& s
    50
      w4 Q) J, y, V51% ^. f/ f3 k! t5 W( y% \
    52, P2 c8 A' P% t3 R  l$ k: e" I
    535 E8 e: |/ e* [$ W. L$ n
    54
    $ G; h" k4 k0 T2 l2 E- u55
    : L3 }9 W* C( q3 M$ j56
    ) |0 e9 W& o. e3 t) t' @6 \% f57
    - t1 P$ J; \7 A5 N7 ^! S4 {58
    6 H2 p( Y& \& s2 F59$ b$ r& x; K; Q: i5 Q" T7 b
    60$ e) d- @! Z# p6 v1 G
    61) A5 l5 b4 U5 `* `$ W& K# Y
    62, ~+ ^( Y. g( `' j# r- [
    63% Y( X4 y& R8 Q9 \1 ~
    646 W( I+ _8 {% L' W- k4 q, w. V
    65
    % k7 i0 F* h/ |) m; j: F66
    : V! {! Q5 }' n. e5 G( k* s% ]67. h1 c# I0 o/ ~9 o% h' G; l
    68
    3 A4 [6 M1 z; R2 M) _9 b8 a69
    7 ~; G; i, _& E& q% e$ [! a70- L; |! `: g: z8 r! T
    716 G; Q% R* z' s5 Q( F2 X3 {: D0 ^
    72  O% u9 C4 D+ V( ?" W1 M& M/ _
    73! [; d  t) v2 Z0 s4 |. h3 \' X/ ^
    74
    * R7 ?1 Y8 y: B: }5 h" t75
    , z2 G; P" r9 Y3 _/ T# j; T764 u: }5 \" K# A0 t4 k& E8 ?
    77
    % {& I9 h" \; F7 O8 t788 x( A$ J. y% L
    79
    3 V; v& X2 D( u9 Q* {, B! m801 U- P+ p, Q/ E  @1 T9 @1 y
    81/ r0 _: ^% I/ x! b: m# K8 k
    82, u" L, t+ q5 N" u& J& c  U
    833 K7 m6 U; V- W/ n
    7. 设置需要训练的参数3 U# p4 [; h# e+ z6 J
    # 设置模型名字、输出分类数' R, N, m5 ^3 Q5 {% h
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    5 Z% T! b2 x' ?4 t! n* `
    1 s6 i6 h' ~6 t8 e3 W3 v# s. h# GPU 计算
    " [- n+ p0 v! Cmodel_ft = model_ft.to(device)
    5 Q% j2 `  e6 t' l" N
    & Q' ^) F- h' r& A# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取7 s6 N0 y1 p$ E3 J
    filename = 'checkpoint.pth'5 ?" J: h1 Y- k

    - i  B3 x; y* y5 X+ V) U! |+ x" m# 是否训练所有层
    1 x' a. H+ K' F% bparams_to_update = model_ft.parameters()
    , j3 M! E) x. o1 i1 s( {# 打印出需要训练的层$ `% I0 Z) \5 y  ^: y$ O
    print("Params to learn:")
    1 O3 d7 S; X6 h7 b. Eif feature_extract:
    . `9 k! q1 v1 A    params_to_update = []% ]8 o' A1 Y7 ^7 b# X* j
        for name, param in model_ft.named_parameters():1 a$ W0 H4 ~) Y, j% u0 N5 R
            if param.requires_grad == True:
    7 ^. s( A3 B' q3 a: l            params_to_update.append(param)# r6 z4 Q  I- j: H
                print("\t", name)2 M6 u5 j7 S8 Y4 h- o1 I6 a1 E
    else:
    8 a9 }4 x- n  L! M    for name, param in model_ft.named_parameters():, S4 b. ^- }1 n# n* f* O% W
            if param.requires_grad ==True:% t( r9 k' m5 |( `5 V& f5 D  G9 W
                print("\t", name)
    $ R* b# T  z+ P6 w0 ]9 v6 V1 X1 P
    5 `4 m; o0 S# K: t1, i: T: U- @; L% ?- G5 g! ?
    25 i# [+ f- w! T) o+ c  Z# \
    3: H( D  r+ n- B
    4' e- z  f% P$ e9 C
    5
    6 Y/ Z) w! b! P) r+ q% r0 B% H! q6- L. g) A; J# x! u. J. v9 X5 ?2 p
    7: |5 b* l: }) k
    8
    . a# L+ m6 l; j  t( W9 T9' l$ W; X2 W$ W' Q+ w
    10
    & N# Z2 g$ _, _4 G0 n3 m11* ~: @5 T6 \6 @  f; M6 Y; f
    12
    # m' D+ o5 ?- ~& n132 G" h- Z& O& k+ _% o' p
    14
      |* B) s# L% x4 e! u! S15" j, D7 e  }8 j' y# }2 @: [( k9 Y
    164 g* M' x! Y. x7 J) @9 p5 M
    171 B4 }# I$ H2 A5 p( \* M% p3 J
    18, F! r. L4 P: z# r, P% L
    19* L" t' P  z: F! B, f
    20
    ; Z. t. S' w! a/ Y* q21
    2 W9 j, Y: Q- j  N22% ]# v% b" R9 c
    23
    5 z' V# N- ^0 W' O3 Q  ], LParams to learn:
    . ?0 f2 ~3 A  x         fc.0.weight0 y1 [0 W3 q7 {" K
             fc.0.bias& x$ N! o/ i) q. p9 ^; B
    1
    1 x( I+ Z& U' Z8 H0 ^6 [7 v2  y/ Q/ N; T; |. W
    3
    ( ]# \9 Z$ a9 ^3 X9 {7. 训练与预测
    2 e$ g6 g/ T* L7.1 优化器设置/ C5 p& e: G* Z) C( u
    # 优化器设置
    9 d2 g/ o+ I$ W+ ~. P& `' H) l- Y; u+ foptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)) s2 [4 N4 m6 T/ w' h9 U. M
    # 学习率衰减策略* B; c  ?$ Y5 X  t$ H+ N. [
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)3 P; l( F+ h0 m( m/ q' h- |
    # 学习率每7个epoch衰减为原来的1/10
    ( `) P) f) C- o$ Z# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    8 k1 |1 H& L  \' e& {* M' }; p
    7 N- @4 r1 v4 c  l6 ecriterion = nn.NLLLoss()
    8 P# ?) Y( ]% B, n; c$ @1 P1
    # e! @  X( f+ t. k( C' c7 E2
    ) u# x8 |) r" }% d8 ^/ [, c  J3; l/ _* q. A% ?; Q: z: v. P
    4
    : B. `$ ?2 e# i7 g8 N57 k. a( p9 R+ X& ]
    6; I+ y+ U, Z! }' e
    7
    $ o) K0 T7 h2 T/ r84 V6 y0 A# }, w# k
    # 定义训练函数
    + w9 j+ {+ o: O  T5 a9 p9 J6 O  o1 A#is_inception:要不要用其他的网络
    ! g8 ]+ S9 Z1 z- ~% z% bdef train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    ( B( x1 _  A$ r0 l# \3 Y    since = time.time()" H& @/ u  I% I9 L0 y
        #保存最好的准确率
    7 i/ s& _" k1 q( E    best_acc = 06 L* g* `0 d8 p) g4 d0 u) S
        """: N* K' q+ h% [4 G- z7 e% [0 d% v
        checkpoint = torch.load(filename)
    4 z- v( ^* ~$ a9 @    best_acc = checkpoint['best_acc']
    / A' m. y: K6 c: e8 s( m    model.load_state_dict(checkpoint['state_dict'])9 \1 ^0 t! u5 Y( k+ I  m
        optimizer.load_state_dict(checkpoint['optimizer'])$ z7 g7 P' g& G; M
        model.class_to_idx = checkpoint['mapping']7 A. L2 C0 i9 ]+ o1 J2 Y
        """7 X7 u3 p( I; S% y  A+ w
        #指定用GPU还是CPU
    1 M4 A+ k8 d2 [1 h/ S2 W    model.to(device)' Q* H# m" `1 M3 H8 m" q
        #下面是为展示做的$ P: M9 r9 P: Y8 t" M& [' T: U
        val_acc_history = []* q7 x- K, _7 [; V! g, t* Z" J
        train_acc_history = []; O# d# A: L1 v
        train_losses = []1 }, d0 d1 o: h  t7 v7 v& Q" P' l
        valid_losses = []
    8 C4 e; ]1 A6 K5 T    LRs = [optimizer.param_groups[0]['lr']]
    $ b3 c5 y8 m' @* |    #最好的一次存下来# k7 q9 m' C, ?+ F
        best_model_wts = copy.deepcopy(model.state_dict())2 }  m- r5 ^! U" G6 f1 i

    1 x# M5 Y( ^# N- ]    for epoch in range(num_epochs):- m' g1 L! A  B# P& F
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    5 w9 m+ L$ T, d. z        print('-' * 10)
    * i  P- }1 s; Q6 j* ~' P  g1 O
    3 r* k- }- \0 D2 y# o        # 训练和验证
    4 T7 p' ~, s+ v. x        for phase in ['train', 'valid']:) S1 n+ _: b: r' f6 `" ]( Y" \
                if phase == 'train':
    + d2 i* _' k3 U% R9 p! n( B3 M                model.train()  # 训练
    - b+ t8 v% m' K5 Q3 D  J2 z7 ~* u            else:
    - j4 Y% D# M0 d4 M' F9 Y                model.eval()   # 验证. Z* [  l! U  S/ `- Y, n2 ~; e0 J
    " o: P! F( M' b; _  O. j
                running_loss = 0.0
    . F  S; X+ Q( T            running_corrects = 0  @' f" e" \9 R+ y' ]) e; ~% J
    & k" I2 ^( k* h$ L) M& Q2 C2 S
                # 把数据都取个遍/ V8 X. K+ C+ c/ O( w) D
                for inputs, labels in dataloaders[phase]:
    5 L# t' j3 v, c1 {                #下面是将inputs,labels传到GPU
    : x! P; I$ t+ z1 t) A                inputs = inputs.to(device)
    5 ?2 y: q& [, d/ I" ]- ~& i. R' j6 `+ O                labels = labels.to(device)
    2 v, o  u) b( m; o8 v. e9 V" B& i) W9 n! E$ B/ N% j
                    # 清零
    1 j* X! l  w- ~9 l; c                optimizer.zero_grad()
    ) S7 S9 o; Z. _2 ~. M/ I' Q0 g                # 只有训练的时候计算和更新梯度) e! i2 h* p6 T) }$ D" a
                    with torch.set_grad_enabled(phase == 'train'):
    : {3 N% Z, }% _# b5 O                    #if这面不需要计算,可忽略& M/ S& W8 B. Q* P
                        if is_inception and phase == 'train':
    % W4 K6 g3 _4 `0 p5 d                        outputs, aux_outputs = model(inputs)
    & p% y+ [2 M$ @; M3 s& N1 M( @                        loss1 = criterion(outputs, labels)
    6 q0 n5 a8 R2 I* v, J" {; u1 p                        loss2 = criterion(aux_outputs, labels)3 L; E7 G# L# a3 S& `" G/ F+ r
                            loss = loss1 + 0.4*loss2
    / P" r) c! P) u% H0 F                    else:#resnet执行的是这里
    , [$ _7 r* ?  a7 ]6 X  S3 w                        outputs = model(inputs)  p8 p/ Y( e7 U" W9 w1 {2 B
                            loss = criterion(outputs, labels)8 m. B2 V  c- N' \8 V$ e

    ( \  w- {$ b9 Y$ ~: P! h                        #概率最大的返回preds
    5 Y. u6 W' G0 Z% f/ k) s" [                    _, preds = torch.max(outputs, 1)) o! X: i# V# S

    * o1 k+ p: X% C) C/ G                    # 训练阶段更新权重
    ! R) s1 y/ A7 e* r/ v3 _                    if phase == 'train':
    ' R1 u+ B5 g) h3 C! ~! V                        loss.backward()' ?  g) u" I% @; e3 y: Z- ]
                            optimizer.step()
    - D* ~; `2 @4 M9 a/ @
    8 g5 s. E' R7 C3 H# Z                # 计算损失8 s, I* M# h# o& F
                    running_loss += loss.item() * inputs.size(0)
    , c% x' N' ^1 _0 u$ Y                running_corrects += torch.sum(preds == labels.data)! C5 m  Z+ h5 e
    0 ~1 u4 u& _" A
                #打印操作
    ( @: }" R2 O. e            epoch_loss = running_loss / len(dataloaders[phase].dataset)
    % y. t" W! A! G/ |/ G- l  m            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)  L8 \% {/ c8 H, Z. M
    . q( e8 y( ~$ u4 H/ w$ y
    ! D2 ?. O- q! M3 B" z
                time_elapsed = time.time() - since
    7 l6 t- T9 a* Z+ ], R            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)); i+ z* [$ g" @( g
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    1 m. Q& c" {$ ]" f+ L) q! X! _7 M) Q5 M: }# k0 L1 Y

    2 K  b$ `) Q: E! C+ W2 O            # 得到最好那次的模型; v& ]9 E/ y0 j9 W, z+ N
                if phase == 'valid' and epoch_acc > best_acc:- ~; @+ O% X+ K* R6 c5 Q$ o
                    best_acc = epoch_acc
    ' [/ Q! ~- [* [6 k- v4 ^                #模型保存2 Q: F; T3 u* e+ ~* [. J
                    best_model_wts = copy.deepcopy(model.state_dict())
    - _3 l& @3 A! I5 ~. d, a                state = {
    2 F6 h# B, J# J. m0 P* `$ ~                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    $ P/ E2 a# Q: u2 K# I/ L                  'state_dict': model.state_dict(),, i8 s8 G9 ~$ D" O4 Q+ n/ e
                      'best_acc': best_acc,& D) M1 b2 T' T; D) F3 [: |
                      'optimizer' : optimizer.state_dict(),, p6 Z7 D* }: y! }* _1 A5 U
                    }( l. @; r! N0 t' |; |
                    torch.save(state, filename)& E" O: [# Z- g* L; y3 c
                if phase == 'valid':
    & a, f; `: |2 z0 G; g. h# P: T/ j                val_acc_history.append(epoch_acc)3 K" _  a; u' H5 }  E) Q
                    valid_losses.append(epoch_loss)3 ^% w( H( J2 o; E
                    scheduler.step(epoch_loss)
    ) d5 i* b) \* p            if phase == 'train':5 M- B* _9 U' _- J/ \$ o( ^  ^$ X
                    train_acc_history.append(epoch_acc)+ S/ j5 ~' H! {6 n) Y/ S/ {
                    train_losses.append(epoch_loss)2 J. }6 C4 m& l- X& \
    6 K& x. Q! l. A* f. n; N8 ]
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))+ ^4 J* W  U7 _: _* y( R$ i1 y
            LRs.append(optimizer.param_groups[0]['lr'])
    ; u% }6 h+ l  b0 I, G        print()
    9 d+ B/ Y% ]; ~/ _. h6 Y  z0 H" h, r% f) e1 Q2 h, d$ K
        time_elapsed = time.time() - since
    5 p- F" r3 I4 @; d6 x6 X: a7 G    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))0 T8 o5 Q7 K4 T0 b/ o9 l
        print('Best val Acc: {:4f}'.format(best_acc))
    3 f2 p5 K7 k6 {( {; v
    * A3 X5 ]0 V6 x5 }0 f5 K    # 保存训练完后用最好的一次当做模型最终的结果
    & i2 c$ D6 A/ l, ^; q/ Y% G' h. J    model.load_state_dict(best_model_wts)
    ; b: c; ^; w5 \- `    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    ; D( ?, \9 R( t1 g7 s) P) o: h0 N8 X: p6 P9 [. r: t% K6 ]

    * X5 N. V& I( [) {/ a$ c, p: G1. L# s/ R; v0 R' h0 {+ u
    23 L' Q# `; p. W
    3
    ( s5 t0 F) ?0 W5 \6 x3 P46 A! N- j3 b+ j* c
    5: Z4 ?% e( A  X5 g5 l6 w0 i
    6' {( z. {9 d: S) V* R0 k  N
    7+ N1 H0 ~/ y0 x+ W# t  q/ K  k9 t0 a
    83 r/ q7 D% Z1 P
    9: b. J, z! T, B' {
    10
    4 ]5 l9 _: U( ]- J( g; m. J% l" p5 H11  t5 g1 w# }) _. D% C
    12# Q4 _5 f$ h1 [; Y& h
    138 ?! [' n+ Y- x( v/ K
    146 V5 v4 w* T; L/ b+ W
    152 C3 E% l4 y) C
    166 m& v' n$ ?7 |3 ^& \' j+ @: t
    171 Y$ \4 l3 N5 e- o+ H! }
    18
    ( {- F& J5 v+ ?7 v( A19
    ; Q" B3 C. q7 }1 L7 v/ ~208 g1 V8 h8 @8 n" ]4 c
    21" O9 B6 w6 R" n9 \5 `
    22
    - n. [' g9 _6 j. Y) P. ~0 ]4 }1 h23
    . ~0 Q% F! i: j5 X& O8 o4 D. Y24
    - l% Q6 g3 |' \4 w9 P) V25; j1 G" Z& L5 D1 A9 o: `+ U: J
    26$ M+ ]1 B+ c  i7 ?& }% W& P
    27
    - d; f3 t1 _2 V; a286 P' o& S  ?) z" r
    29( W' J* ?- B; o1 P8 s( q0 N8 j& a
    30
    4 Z- N4 R3 m4 W( y3 @/ b" _31
    2 V( x5 q5 P( O, e3 O32- {7 C% P: Q8 ?# i3 r
    33
    % H- g5 G  b/ d7 V0 |34
    7 ]* Q8 a4 V1 ^7 x( @4 H+ g35
    ( \: k& n# T( t36
    . }- f0 y, c/ b  S4 U: g( B: g2 ?37' s$ _2 B2 J# z
    38
    ( r. P# ^" g! M$ O) p39
      f  y# I! k1 C; ?& M" C- H40
    * B. I/ M& a6 s% G41  w! Q) \' |! u4 G1 w3 H' {3 o) z
    42  m6 [* O$ [4 i5 d2 O6 J
    43* X$ F7 s0 H$ _9 s: j- `% ^" w1 }
    44- f) Z) f- }& T& W, z5 j7 u3 f
    45, ~8 I. w6 c; X0 w* A  r
    461 ]* ?$ k' q0 Y$ O
    47
    4 `4 B) \9 n: o5 E5 ~48
    4 G: s8 ^8 m' y4 d3 R3 C49
    0 K+ ^, [. H4 n  k, C1 P50
      H0 W* Q  k/ B  W% F51
    , q& v: ^$ l) V* U52
    + F  ^1 T4 v! J& A2 @538 j5 k* M0 D) c- p
    546 o9 c" `. x0 U1 u
    558 {7 Y4 g+ o! r- i+ V+ N3 h# M5 u1 L9 l
    56! y* r/ G9 X6 V! c' K' ?0 X' b4 r
    57  \1 H2 `* \0 R: h( a8 `# C% o3 E
    58
    + y: c8 c  w6 a' i& i- Y0 h, o59# f- O0 j; x/ j7 c
    60
    : o1 C$ L% q- v& O* N& E61) R  q6 D9 U% m) T
    62
    1 t' o* r# _  l63
    5 y1 B; }" {8 q64# ~0 b  \$ r  V0 _) V( ^& J
    652 c" K& T1 Q0 ^; ?" i) q/ i7 Z
    667 h# t! K+ Q  q" ^' R
    67: t7 T* F$ }# R* P) l2 F2 e
    687 ]( f5 t, q! y  K( A2 O
    69
    5 A. R6 {& x* z8 ]  i# g70
    7 V9 T* P- w! Z9 s7 ?3 l71
    6 K) r3 P% H- G3 Y( h72. K* |9 @! O# M; f2 n
    73( q' u! ?6 g! J8 g4 @/ C2 f0 W& @. }
    74
    ' b# Y" \8 f3 K; c2 j754 P9 q; i+ C+ z" E" ~# y
    76
    ( g; `. |$ e  `: v- K# r8 j* m7 {+ h77* }8 N% q: K$ ^
    78; n( N6 h0 ?# \( X& E7 z  p+ u
    79: j3 b/ X, y+ G" p& Y) A
    80: g3 x( D) P% o2 q# X
    81
    . v: ?% O8 E" w+ ~7 v% t$ ?82
    5 U( ^8 o# V2 G3 w  y8 ]. e# K83
    8 I' L% ~  B" d. V84
    * k4 L8 E1 O$ C/ E  r' E& \+ A; M) B9 M85
    : W6 q; ^: U, y# L. B8 W86+ Z9 a: n" J; n
    87
    * k% p) ^* T% A) ?2 m; V' o& w1 V881 d7 {4 ~- P4 N) X( s& e8 E
    89
    2 Q2 o* O" z& }3 o90
    8 i' d0 _) J; v" Q919 T  h8 t' K  u6 l/ [0 x
    92
    1 S$ g; G6 i; R, f% P93
    . `0 z  R$ d1 P% k94
    ; N0 U# @5 t9 G0 @% w95& j0 S$ o9 ~5 A9 U' I: N( U& v8 y
    96& Z# [! `3 W2 c
    97# v% J0 ]+ F1 v$ u
    989 L/ D, T$ `' T& y+ w
    99
    8 ?# r, ]- @( V# `100
    ! p# C% C! f( K% `8 ~0 D101% [* v% g9 t) z4 M5 ^
    1021 A$ E, }2 b/ R' S4 N* s; h" s: H
    103* r+ W# F7 R) \
    104: U7 d$ j3 Y- n0 m9 h+ Q( T0 H
    105
    . c0 t% m5 T; q0 F& [106  Z& L7 t! X4 W, o
    1073 d. Z: Q; N: m5 A# P6 J" m
    108/ u& w1 T$ E  r; ]
    109
    / ]4 [. Q2 Q5 F6 P; f110  s) Y: A. M# d
    111
    ; ?! S- b9 A5 |( d7 X) b" B" T112% Q/ n9 b! g+ C/ M9 S1 @( a
    7.2 开始训练模型
      ^, [+ P1 |% W, b' F4 E) o1 `我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    0 g/ x0 i0 G3 M) }. B/ f9 ^
    1 z4 H0 `' n' _) M7 k8 U7 y#若太慢,把epoch调低,迭代50次可能好些
    : _/ v. Z& q1 P$ b/ j# W3 G$ u5 l#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    5 L% A. K' I; ]! u  m0 }$ q& Imodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))5 _- ~+ j1 ^  R
    4 {0 B2 k3 z9 }  L/ \
    19 Y+ i6 J3 Z) d
    2
    . }) L! l3 g- j; G5 |9 N3
    - i2 z% p' i+ n. L* O43 p' D! V3 J( V7 D* k- t0 `
    Epoch 0/4
    ) A; l; C# q9 D( _- B# t  n2 z----------  I4 F8 e. ^( \& Q( r
    Time elapsed 29m 41s/ W/ Z4 S/ a3 C. T" b
    train Loss: 10.4774 Acc: 0.3147
    2 @0 O8 f2 H# u) w  h) j5 YTime elapsed 32m 54s
    ' o8 X- J3 t( Y; vvalid Loss: 8.2902 Acc: 0.47199 H9 ^9 |) D% r  d4 O5 }8 f
    Optimizer learning rate : 0.00100002 p" I, f! K2 C& a" m, d

    + A/ v* {( m! \1 |* r. aEpoch 1/40 R* m# A, @; E; t$ z
    ----------
    2 M# B& K) R, N" s! M% gTime elapsed 60m 11s
    & [+ J$ Q  |' ?5 b$ mtrain Loss: 2.3126 Acc: 0.70536 v- j/ m4 T: n
    Time elapsed 63m 16s
    1 o4 l8 M+ J& \- Svalid Loss: 3.2325 Acc: 0.6626. A8 ^* c4 Q+ g/ Y! E4 a& v
    Optimizer learning rate : 0.01000006 N& u: H5 l1 s+ ?# ?% |% b

    $ u# z% ^+ i" I9 f* LEpoch 2/44 A1 t8 V$ B4 x* m" b( ]
    ----------
    2 P4 H# x1 ~6 m8 F- bTime elapsed 90m 58s% T. q2 d* m: W2 c% V2 l8 |
    train Loss: 9.9720 Acc: 0.4734- Q* Z* L- {2 t) c# M
    Time elapsed 94m 4s
    . g; N: y  Y+ N6 Jvalid Loss: 14.0426 Acc: 0.4413
    ! ]3 I, b* W4 h0 B, H, t/ EOptimizer learning rate : 0.00010003 ^' Y& ]) F. w& d
    $ s+ t6 Z% {: B! X, g. a7 @9 k
    Epoch 3/4) `, `) q+ a6 a8 }$ i/ D
    ----------
    7 ~% t, `& P5 Z2 W# oTime elapsed 132m 49s
    3 I/ s" S  v" i8 Ytrain Loss: 5.4290 Acc: 0.6548$ F( k8 m0 F& h8 z# A! `  p7 I
    Time elapsed 138m 49s) y: }/ F" p& ]" ?! l' {) J. r8 G
    valid Loss: 6.4208 Acc: 0.6027- r8 [( K/ W; o! d2 _& d$ s5 S  \
    Optimizer learning rate : 0.0100000# x9 i" s9 K8 j+ ]- P

    - v& Y' w1 y2 F6 H- c, i0 \Epoch 4/4
    ' s5 A6 K' e$ {4 U0 r----------2 @$ p! P+ i; [
    Time elapsed 195m 56s2 G& J: s5 Y5 i+ r$ ?  }. y
    train Loss: 8.8911 Acc: 0.5519
    - I% M: [  B7 L9 ~$ X  |2 r9 y3 ETime elapsed 199m 16s
    / a: f  y6 O# Z& ^valid Loss: 13.2221 Acc: 0.4914
    - F9 [7 y' z7 j. q: B( }Optimizer learning rate : 0.00100002 s1 S) M: S3 n! V5 `- F
    ; e6 s3 y" Q" r: @7 b2 o9 g4 H
    Training complete in 199m 16s/ x* f2 t' i$ K" b
    Best val Acc: 0.662592
    ' ~* `  N6 C: Q8 M$ i+ ], m8 @
    ' F3 H3 M1 m3 V; O1 B1
    - Z: X% H! ^: n/ w( L2, R8 z, n: T& e
    3
    & [' W- b* e% i* x  ^8 ]/ @; V0 P4" k' T5 H' F1 R, M, i# f
    5
    $ S3 E0 `% t. ~: |9 P68 P' l6 F( o' `- P# \- \0 B
    7- f  B9 L' x. e% |: z8 z3 A( |! v
    8
    1 P) M: X% f& ^' \. J98 F1 N/ k  k' p* }, ~
    10
    / `: o) }; }  `& p% C11, r6 ^4 b5 J9 W1 \5 H& N
    123 Y- I" |2 [. |. M- q- s0 z' A* p2 U
    13; a& B  X: p9 A% Q9 n0 F, i3 u
    14" Q+ _( d+ V: T- i5 D- Y
    15
    . T4 K: o' S8 [! y" ?3 E( Z$ \16- }/ X* e7 i+ o0 G
    177 @% w- `+ V! {
    18* U& `  v: V) w* e
    19
    6 p5 ?# O+ D9 u6 }20
    9 m6 T# |9 o$ }, m; O21) b: \7 g1 i8 U* Q6 S9 Z
    22
      d% P' J/ {* C! r( @1 H23
    4 B3 x" L: H6 [* r5 `4 N24& u! w4 C  [1 _, }- ~
    253 C9 T* `6 A" X; a, B& |/ R5 \# z* V( X
    26
      w; f% i; w  T0 \) ]27
    6 A5 S, a/ x" `" w$ c: ^/ k9 W28( D" ?" d8 D- j! s0 E0 W& w
    292 y; d' ^6 G' s- {3 l
    30
    # s, A- X  E! m31
    ) w5 J# u: V* F% U+ a: B1 \# @32
    % r+ x0 |$ a1 T: V33- j5 {/ r- Y# ?8 z# s4 \
    34
    : c+ g" t( Z, B$ M35
    0 S# H& B1 f/ \- E9 W/ t36: Z- R. ]1 l, O/ j2 L/ o3 @
    37
    $ `3 u- s6 T9 y4 d  F38
    / b5 n/ {, k: D, H: T* t) g39
    8 `3 j; }1 V- T. I# n9 M& _40
    0 K% F' Z% `: E1 z  }41
    ( t6 Q) P; x5 l- g. ]- f" N42
    5 f9 F  x: {8 [5 P7.3 训练所有层
    0 F3 U0 w& n" ?# m9 `0 e! H. ]# 将全部网络解锁进行训练* ^& P5 ?- O) A! C
    for param in model_ft.parameters():
    + ?7 R! s# x$ u! Y9 C* d    param.requires_grad = True
    8 n( k% x: J$ ~* R" _
      l  t8 X6 H3 R# 再继续训练所有的参数,学习率调小一点\
    5 g, s3 q7 Q0 Z4 i& Qoptimizer = optim.Adam(params_to_update, lr = 1e-4)( a- H3 c8 n2 Y: p  H* I
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    & V' h; }9 }. b  d2 }' x8 M1 u( e% k8 n; }3 r6 K
    # 损失函数
    " L& J6 L. w9 e" ^. tcriterion = nn.NLLLoss()
    6 p, f, n  Y  Z! j8 K/ y1
    5 U8 T( \" q" e/ G3 u/ O2
    + Y6 H+ N3 Z  x; O7 F3
    3 V+ c3 K% `7 h48 N/ a6 |! ~- F* k
    5
    7 e+ |* N. Z: l0 J% u6
    ( W( K2 H+ x. B/ z7
    ! W/ t' Q! w$ W, ^( m86 [4 h7 c: a3 w" \
    9
    0 q( m! I8 z$ K) A( h10) _0 b7 r6 f, q0 y
    # 加载保存的参数
    6 q9 N0 u' |! _- }% l! F( }- i' L# 并在原有的模型基础上继续训练5 Z! A* n) N+ V* T
    # 下面保存的是刚刚训练效果较好的路径9 s  W$ s5 c% h  t- V' z
    checkpoint = torch.load(filename)# y1 G5 y4 B* s7 f# [0 z2 k$ C# y
    best_acc = checkpoint['best_acc']
    / I! ]" H# }" u# A7 B2 x. `4 Omodel_ft.load_state_dict(checkpoint['state_dict'])+ Q3 J  H$ d& H7 k, x
    optimizer.load_state_dict(checkpoint['optimizer'])
    $ w" Q. K4 p+ f: L1
    5 e: [# \$ w0 F" \1 K& }22 z; W0 v8 s8 Q, i" X
    3+ i0 `8 [2 e. H( s
    4
    8 ?  G4 N( H- X3 W) L- T! N5
    0 i" X1 S4 A: d# r, X1 _6, u3 g' \+ V* H
    7
    : k0 c& t# p) x# T& d7 }5 q开始训练
    3 |) w0 f, L4 S& |  G注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考8 ^" ]" Z' F3 P6 `! k$ g4 V+ X" e
    : W: F" R2 b5 K
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception")); e) T( V+ x( Z7 [" s0 @4 ]
    10 O8 v  q8 H5 J5 s; M
    Epoch 0/16 W* c* A( l- F3 {8 a
    ----------
    6 {5 W) k" E* E# VTime elapsed 35m 22s
    % w( Y) `8 A" f5 W( Vtrain Loss: 1.7636 Acc: 0.7346! G3 a+ x+ t0 x$ ^, P2 y8 R
    Time elapsed 38m 42s
    ! A5 J7 y$ r+ c% q7 @/ n% m% ]6 H3 y; Y8 _valid Loss: 3.6377 Acc: 0.64557 f0 f, k1 n* d3 Q( ~/ k
    Optimizer learning rate : 0.00100001 I% K2 u  L5 d* S( }1 K

    & T! S9 |' s( f# l! mEpoch 1/1
    + L6 a6 Z  F9 o4 \1 F----------
    ( w% r) S4 \: S9 H8 b' D2 L" STime elapsed 82m 59s
    , K5 Y9 k! W- L& @9 }: P7 dtrain Loss: 1.7543 Acc: 0.7340
    ' I8 _( X3 N; |7 t: PTime elapsed 86m 11s( ^) e$ k% ]1 c  y" u. @$ @
    valid Loss: 3.8275 Acc: 0.6137
    ' b: K) B$ I; V/ d9 ^1 ~* P0 d& G  ROptimizer learning rate : 0.0010000
    & y) P. ~) T5 I" b" @
    ! `' W+ F2 X# iTraining complete in 86m 11s
    $ m. I/ H5 c* VBest val Acc: 0.645477
    - e$ O+ `4 Y# r+ X" g; I6 q$ a. p0 }6 p8 L. f, r
    1
    : i8 r* ^+ Y) e8 V; m2- D: ]& w' X6 j8 F, D2 ~
    3
    : }9 D5 |- l' F6 x( w2 m4
    , @+ _3 U( Y$ r' z2 f5
    ! I9 G% ^% A4 s6
    8 V: _5 q9 p1 i) o3 O  N7
    + v, e1 a' G% w9 i2 G$ q2 i1 ]7 X8
    * [* b+ c* T! I0 U9
    + b+ w& F# ]7 Q' b' S10
    2 x3 L" x& N. p; r4 r( a/ f8 B  a+ k! r11
    ) P' n3 N- K, L$ \. j126 f) C6 q) s/ Q" x) p+ k" E
    13
    4 T4 d5 K9 M; L* q% r7 f% E14
    5 D, ?' K3 p/ u# @# }1 T15
    4 {  f* X' R" N  g16
    7 Q$ B% |+ A/ u; _8 |- m+ ]17
    4 z  n  @* }! Q- j18( o8 ?3 e# b: @* e
    8. 加载已经训练的模型
    ( P' w/ D3 t  Z) w0 s" z5 X/ Y相当于做一次简单的前向传播(逻辑推理),不用更新参数9 O- _& v  t" `) H( h4 O/ ~, e' K
    + R9 J. u$ c$ @* d, |/ ?7 U' Q
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
    8 u; w+ @& b( Q6 {# G. R& M4 D8 I3 L. R' S
    # GPU 模式
      `/ w. i. q1 z4 D1 Dmodel_ft = model_ft.to(device) # 扔到GPU中$ r3 y& }3 Y! j# r  x

    $ c9 u% P: W6 S" A5 ?: z# F# 保存文件的名字
    : l4 p) Y) N# k1 c* c3 T8 Efilename='checkpoint.pth'& ?2 ?2 D6 T0 d- d$ q% _

    & H( m  ~1 S4 g" o# 加载模型+ C& {7 X. m/ E5 n
    checkpoint = torch.load(filename)
    $ P- o* X2 i$ S$ L* ]best_acc = checkpoint['best_acc']
    2 D1 X. C' W: o8 i# Xmodel_ft.load_state_dict(checkpoint['state_dict'])( ^' K2 u, e& d' |' m! e# ^
    1
    1 _) S2 K  E3 J2 T: v' a- s0 u2
    : F, @1 G2 [, u! t( T- D3
    $ `. V1 B  Z! u  n* d% g1 m% W4
    1 R  E+ w; F( ~7 O3 N  K59 T3 l3 N  P3 _& d
    6- x4 R  ]+ w1 N* T( u
    79 M8 {6 ?) k2 ?! E
    83 @8 C9 z& \5 l' Z" i
    9
    , R( y/ A; Y* x( l10
    & E  p' i5 U! J3 b" o3 Z% Y111 w4 ], ?9 |5 _; O  U" l: l1 F
    12% t" P; P; D) r: b% d* [
    <All keys matched successfully>) b  o* @/ ?. T0 d+ j6 u
    18 R: [5 z: B' i% H5 a) }
    def process_image(image_path):
    ! m, J/ }4 Z' A    # 读取测试集数据+ D3 f2 q3 o) B
        img = Image.open(image_path)7 |! Y$ _+ R' t: E. x! b8 V
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    7 T5 o& E7 G$ b( s0 O& l5 t+ u    # 与Resize不同
    ; d1 `! I1 @$ u: y) [7 Y    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    . W" Q4 B& B" D* s; g4 I: C4 y+ P    # 而且对象调用方法会直接改变其大小,返回None3 V. j. i. {0 X
        if img.size[0] > img.size[1]:' b) K4 H4 M, w# t
            img.thumbnail((10000, 256))7 X( Y% s6 _2 i: @4 X3 V+ w) R1 f
        else:
    # r' `9 Q  `0 \$ M! @: b, P% H        img.thumbnail((256, 10000))
    % ^$ Q  F7 Z6 f. @' S$ Z$ P
    4 s; Z. W! N2 ]( ?1 I; y' ]  ^    # crop操作, 将图像再次裁剪为 224 * 224! d1 f7 @, S& I" @
        left_margin = (img.width - 224) / 2 # 取中间的部分6 Y* u- `6 M  J  t" {' n# D( |
        bottom_margin = (img.height - 224) / 2 6 a( |, S( q, N) S+ d
        right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    ' e; s) B+ b4 f" H7 L- |    top_margin = bottom_margin + 224
    2 }) S6 F$ A% P2 p0 L! k, u+ L' r# H3 ^" o2 K5 d! K8 M
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))5 C; z& s' I; o: Z" Z4 A

    & e& y9 q2 I/ F1 Y9 S. z    # 相同预处理的方法& z6 `* W6 p$ m3 `( h3 {, n  j
        # 归一化7 I7 C, B& G5 b1 x; h. M  \
        img = np.array(img) / 255
    + z+ r6 K$ t& {" N! U6 J) ~    mean = np.array([0.485, 0.456, 0.406])- {/ Z) w2 s( g( F% U+ z
        std = np.array([0.229, 0.224, 0.225])
    ; [3 [/ r" B7 C+ w; [  ~    img = (img - mean) / std0 c5 W, `# M( y$ e# z4 d1 Y5 F- Q$ I

    - B( b0 H% M/ r  ?& z0 n    # 注意颜色通道和位置
    . A5 t& _/ D& r    img = img.transpose((2, 0, 1))# ?) f/ _! g# i* r

    , A5 o0 R, Z5 ^: q$ ^    return img
    7 z7 \& L# P9 b/ g) {9 G  T: x) x2 p" t+ ?6 D  [5 f; e: ~- Q' J
    def imshow(image, ax = None, title = None):
    " W; C4 L8 ~  |    """展示数据"""0 {% w# {6 d2 Z! E* o! b
        if ax is None:
    8 x9 Z0 p; @9 x. d" c' u        fig, ax = plt.subplots(), i; X! Z. X  R( a

    " }' M$ p# {1 }( {5 `8 E! }! i    # 颜色通道进行还原
    7 r3 T9 B8 t8 y4 D' ?    image = np.array(image).transpose((1, 2, 0))
    & X9 V6 q2 q; M# u7 F) q
    0 T( i0 G- H' `# F    # 预处理还原5 i$ p3 h( p4 P. g" [7 C2 E3 Y
        mean = np.array([0.485, 0.456, 0.406])
    ( y/ f& e4 @, U0 l    std = np.array([0.229, 0.224, 0.225]). r4 t9 g" b# ^% l0 c0 P$ |
        image = std * image + mean
    , Z5 T! g- |; [    image = np.clip(image, 0, 1)
    4 ]/ R6 l8 A5 _) M( {. X$ I0 r0 T# D. ~( ^. |
        ax.imshow(image)  A4 g( Y) B) C. I% ?
        ax.set_title(title)6 [/ K& c0 ~5 I1 |5 x; `
    ) J$ z+ e. M1 G7 v4 Z5 m
        return ax5 T8 O' E+ f8 j* h1 f3 U
    # t/ N. l0 e# M8 u. K$ X
    image_path = r'./flower_data/valid/3/image_06621.jpg'! N  O$ D" u& _: x; o' P8 g* T
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理$ F8 D' {' p1 B, ?/ m* N" n
    imshow(img)5 [& V6 J5 `. |* b5 M- {

    2 j0 ?# F8 M5 B+ c7 \% {# U15 Y, u: Y" Z/ t; }8 w$ H
    27 C% F& i1 i0 f0 u  t" i
    3# ]! e+ b. k) J3 n2 L3 E2 @
    4
    8 F* B# D+ Z, J, M1 a" T3 o5" y* x, q" i: X  }9 g& d9 ~2 P) F
    6
    2 L& O: b1 P; j) F7 W# t  |! O7
    ; ^2 D4 K/ }( a5 a6 Y  ~2 E8
    4 L2 E, {+ i$ t' |3 D9
    % T( a0 [( H% _& F8 Y10
    6 n% S- n4 Q6 L7 D- P! v! `/ L( z( \$ D11; k* B1 y( ]0 m/ N- y
    12/ \7 U' Q& L( H( R
    13
    0 x6 ~3 S2 R: z140 N  [1 v; a% Y6 d& s: |
    15
    - ~& a, R$ v3 d. p5 _; y7 o163 D6 h6 g# x/ j/ g- _# I( P
    17
    3 ~0 K6 e; I9 j( j18
    6 N7 ]1 I  L' Y/ p' X" k19
    " ^6 n8 x* r( C: o" x7 A202 @  u+ D* A# X9 @1 p' k1 M
    21
    * ^! H. C2 [8 P) l22
    ( q) y1 @+ e" [$ M/ `234 B: F. U# E" M- h$ h8 X" R
    24
    8 \# V! b6 J4 m- [$ S9 U# h25
    , \, ?# m0 k* X& T' ^26
    2 h; c& U7 U: l- D7 S3 s27- D* _9 q, w, Y( l
    28
    ( c" K. f# E6 M' C+ w' B29  d. n! L  l' s# N
    30
    # i7 }, \8 \9 u6 ]1 r1 q) L3 F31
    8 k: C8 m2 Q9 ?0 Z' ^; Q32# p8 x* r3 H+ D7 u- y
    33: |  r+ K/ e5 t4 ^% ^% A
    34
      M  P8 L6 h% {: O- t35
    " q8 `! u& |: v) @% [  X% z36' C- h7 p& ]. |6 T
    37; \, C/ v9 _0 G* y" b0 u; G+ f0 K
    38- m, T7 O/ p% Z: I
    397 R# ?6 e* K8 [% Q0 s2 e
    40- \! ^" k1 o' r2 d2 ^, x+ \+ T
    416 Z: q3 o1 E9 M2 L4 C  X$ o" o" N
    42
    , {/ r' q9 `6 H9 X43
    : }7 ]2 H% `1 U. l44; t0 U; S- v3 _
    45
    & `4 p2 G" I/ E( ?: O46
    , A2 x2 H6 [3 ~478 f0 R0 z1 ]- a: R& K  [$ H  B# S7 M
    48
    3 i  b# j/ }' |9 E0 z49( C. f2 R$ r3 x5 n- ]
    50
    0 l2 W  Z, n7 r, `, |518 z; U9 B! {" U. K" U
    52  F. m3 x/ q' }
    539 o4 T; q  i+ }  }( |  n, s7 N
    54
    7 d% c+ O8 U: x8 I( P2 P<AxesSubplot:>
    9 v/ l' J3 _8 z) W5 Q$ R1$ j3 g; A8 I  w- w
    8 l  I" \: Q# X/ ^6 U6 V
    上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    ) Y# b- \* w7 I! S$ [1 K( d
    ( R- _! h1 a4 T" R! r  t$ oimg.shape
    ) F: r% ?3 [# }& J5 {# w1
    8 R8 B, S$ ?3 i(3, 224, 224)
    # R9 l7 q! W+ `- t% d1
    # {& J1 Z$ _: K( p8 T证明了通道提前了,而且大小没改变7 j% u0 ^* U, t. J( Q" K

    3 C% i8 O# }7 f$ F9. 推理: T$ g$ J) K( K
    img.shape/ ]! Y8 f4 w( l: e+ @

    6 b9 A+ C+ B# \& N0 D! c  `( i# 得到一个batch的测试数据  Q) j2 Q1 z5 F4 d) [1 m% m) f9 j
    dataiter = iter(dataloaders['valid'])
    0 q+ ]; k; C3 _2 G+ a- u* l# Mimages, labels = dataiter.next()
    ( H4 I3 r' h0 y. e) x6 f! a" s
    . e; A. a7 m) X; E4 N  |model_ft.eval()
    * B* u* _3 A4 N3 Z  a" r
    5 k/ q& O; F3 j4 x: p$ C5 s& pif train_on_gpu:
    0 B1 F6 H7 T% k  m* d    # 前向传播跑一次会得到output
    ' G) M" `( n; [( X+ ^- U: t7 q    output = model_ft(images.cuda())
    # u. v( b' @0 K( ?2 b8 Eelse:
    : n& e8 F: [: ~3 F' |. B    output = model_ft(images)
    - D' k! }- W- ?; V. ~- K  m8 I3 l
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
    ) z- ?2 a6 t# P4 q) Ioutput.shape
    : R% P9 p2 w+ x, o% l2 `
    ! a$ H/ A1 [" u. R2 [1
    * }& P9 |$ I9 Z/ D2/ y; h1 M  W: |: [( x2 t$ O
    3
    2 n' L& D/ O! x4
    - w8 u' ]1 C0 n! N+ U5
    ) x7 {" J& C+ L% a9 z5 r6
    7 x% ]  p9 V8 i( V7; j  ^# x, c; E) s4 m  M
    8* g5 C+ I  c+ F  R0 D+ M( q
    9
    ) g6 Y" p! h+ t10
    ; b3 K1 `# }3 j) k7 K( q8 |117 a/ E* f# X9 K! u5 n2 v6 L& o% f
    126 g# c4 l% g( i; }9 a
    13% Z$ U) h5 W4 ~0 B; V
    14
      c3 E0 i- z: Q6 j6 a( e* u( `15, \$ W$ G% T1 u* m9 K
    160 E2 C( @6 R7 U% ]
    torch.Size([8, 102])9 ^9 v$ }0 w+ L5 N
    1
    1 ]& S8 @# V9 \8 ^$ Q9.1 计算得到最大概率6 O6 _; ?' V. @9 ?6 @6 [
    _, preds_tensor = torch.max(output, 1), o7 C' s$ o0 Q: p

    4 H4 |* W" \7 `- D  d% T6 Zpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量, s+ |/ C" A$ O7 p6 F0 y
    1) R9 ?# O1 P) G1 N+ k+ {3 ?
    27 g: P/ H6 [4 A5 \5 `2 g
    3
    , R: c' V, Z4 U+ w; S- N9.2 展示预测结果
    $ k1 Y% o9 b: Ifig = plt.figure(figsize = (20, 20))- ]0 c3 C7 I4 P/ }# w# F. l
    columns = 45 ^% G8 F2 h- E. O* \
    rows = 25 p; c) I/ b. ?) U

    " [2 M: ]3 Y& M. i* pfor idx in range(columns * rows):# W& B) r- Z: q8 X4 U* o
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])( `* f7 s# k7 r- n' s
        plt.imshow(im_convert(images[idx]))
    & {+ ^" f, E# G- M    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
    & W, M& U# k3 d6 v0 K                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))* L8 |2 `0 a; a* z
    plt.show()4 e- M- |( I  L. e4 k: V
    # 绿色的表示预测是对的,红色表示预测错了
    ' P5 {: j+ S. y19 Q6 ]0 G. k' I2 ~- v% \
    2
    5 h7 Q; M8 q2 _2 ^3
    # _, W9 }. y1 @. p4
    ; U7 `1 L7 z! O  q5% ~6 d0 F" y$ m+ c! l3 s
    6' }8 T/ {4 ?3 q
    7/ `: z) n: z* D8 J. r6 w# M8 `
    88 j4 c& e" `$ P$ Z
    9& _5 k% Y, L3 M2 \) B# j
    10
    * X. v3 \& C3 J11* W9 W- U$ ]: u* m2 T
    ' [% ?$ C: R0 g$ s

    , V9 q/ @4 n- b- k% I
    / j9 a7 p" Q: m————————————————
    ( u, J; ?- V  m6 M( F* G8 S版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。7 I- }7 A* q  ^4 Q5 x
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940- X( W# F+ _/ C) d; b

    8 _7 L4 z9 ?& `! a
    # b3 `0 G, w" `) G( L
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2025-7-26 19:04 , Processed in 0.592074 second(s), 50 queries .

    回顶部