QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2753|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例2 K. m4 G+ p8 k
    0 J# B& B3 {9 ^+ F5 d2 Q
    文章目录: _% n1 `1 q3 m4 C0 u. V: a
    卷积网络实战 对花进行分类
    7 i& E' ?# h3 i* c  [: s数据预处理部分
    ! e/ E& J& G+ T& L. I) l网络模块设置# {, s/ O$ G& _9 j0 u# C7 Z- j/ L
    网络模型的保存与测试
    % \; k7 `. N% i+ t数据下载:9 a4 `8 }& d9 o, w
    1. 导入工具包9 P% f" G8 w0 M+ m& `: @& l
    2. 数据预处理与操作
    9 V+ y& G% D$ v* H& c- ^. s6 {1 L3. 制作好数据源6 q3 s" @" E3 T; N9 S( O
    读取标签对应的实际名字
    , P7 ]& B( W3 E4 r2 d% P4.展示一下数据
    7 |+ b+ ~, [5 n( U9 u5. 加载models提供的模型,并直接用训练好的权重做初始化参数5 M/ ?2 j  [# I
    6.初始化模型架构! ?& b2 o/ E2 q, Q
    7. 设置需要训练的参数
    9 C+ H1 _4 J" m/ m7 q+ `7. 训练与预测  [. S5 s+ ~" D5 z
    7.1 优化器设置
      f9 [1 }) Z" q# {, u5 ^7.2 开始训练模型
    ' ~! ^1 M: D, r+ ?# E7.3 训练所有层1 [  \# v7 c8 y* e; K8 d
    开始训练$ S" `; ^, L. ?. T8 j. D* s
    8. 加载已经训练的模型
    9 d1 c5 Q7 o; w$ r9. 推理7 X; U2 ?8 W, y( Q2 h/ |3 S4 x* o. n
    9.1 计算得到最大概率
    + [' L% |8 n, ~) h) J" B! R% @9.2 展示预测结果
    ' q1 W* ^; e1 ?$ U$ {7 B9 i写在最后
    : s" i0 X8 N! w卷积网络实战 对花进行分类
    5 c5 a) S- Y' F* l; ^本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能; s6 \2 Y1 d  Y6 N" P
    ; x0 o. A7 f% q- v
    在文件夹中有102种花,我们主要要对这些花进行分类任务
    # ]8 U) R' ^6 h2 I* e7 u; a5 R文件夹结构
    + h; L# P3 M# u  x; j
    5 l. g0 m" ^4 N0 ?: cflower_data
    / L9 [/ y% G5 ?8 }
    + |# D: f' |' |: {8 }train
    - I# v, _: b! A1 H& e% A5 F) ]" ]% J5 ~; _) a+ |
    1(类别)4 ^4 {- p* }% v! u6 ~9 P6 K
    2
    # M3 c9 R& Y8 r2 \xxx.png / xxx.jpg0 Q  ]( \( P/ T
    valid1 K  f. `7 x- J* |
    & w6 G' w  W$ @: c
    主要分为以下几个大模块
    ( p; K. _, E( t$ Y4 T0 X9 S8 r( r9 H
    % P7 h9 W) Z; e+ `* w% M- \: `数据预处理部分* f# u, \6 i0 n2 @  G
    数据增强( [, B' J- p$ _& m0 E
    数据预处理
    + N" o7 @7 b) R" h" g: l网络模块设置) v" ?$ p: ^3 w& L/ q
    加载预训练模型,直接调用torchVision的经典网络架构. r; [8 Y: [- o: g: d+ K7 B0 d
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务7 B5 ]2 Q7 P; v- w
    网络模型的保存与测试
    - F  I1 K" B4 C' d, R+ h模型保存可以带有选择性
    7 B+ U1 P9 _  A$ l9 X! A; h数据下载:
    $ P1 B. C0 S/ l. m' C* K) `  Thttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset% |* j# ~5 l3 O, f  {" S& r: ]

    + Q+ W+ {6 K0 u0 V  a7 s9 L+ J3 ~" R改一下文件名,然后将它放到同一根目录就可以了
    . L) Z, r/ W7 [+ a  |- |- j3 \, L7 y6 Z; \, G6 `7 ?
    下面是我的数据根目录: l2 u  e' ~; ]# K! f

    ! U3 o2 _0 M0 O( ~4 [3 i& k4 f' L0 b2 i$ N8 v
    1. 导入工具包
    ( s; @0 g  b5 H% k  @import os
    * m0 g! x7 o; s/ o, [& z/ Wimport matplotlib.pyplot as plt
    8 K- T$ A, \" R% |% I# 内嵌入绘图简去show的句柄
    $ E# V* P0 U3 X" Q  o%matplotlib inline
    / O9 \5 {' C3 A- G) w0 O4 jimport numpy as np4 a  J& m% I; d
    import torch
    & Q' o6 e# I6 o6 qfrom torch import nn
    8 `1 f4 l; P. I: Z4 |$ u
    " i0 T% c; U) o8 zimport torch.optim as optim
    6 i" u8 F2 c! G& Fimport torchvision
    ' x2 M! g! q/ i  O& z! Pfrom torchvision import transforms, models, datasets
    1 g! j& J* o- Z4 e: p2 S5 ^
    , ~* F% O  n# F- Rimport imageio
    5 q& m1 \" L2 ]1 f8 N7 B7 _import time
    5 r5 u! U, m' ~- |import warnings
    % y: _) o( |( Fimport random4 t4 X' c0 F% ^6 U4 R9 }
    import sys7 P2 w8 o3 R7 m' M* I
    import copy- g4 o! l3 {) a: u( A0 z0 E- s
    import json
    0 J+ a4 R/ Q( t# I  x0 ufrom PIL import Image1 R) l4 s$ [+ W" R
    7 D) }( Q7 w% a/ V

    . u. f# W, x/ z$ I: i5 a9 A! @1  G/ a& w, S! d5 b2 ?7 R
    2! F  v9 @4 \3 w/ T
    3
    & y! A. J" m! r4$ B/ l( y. Y0 I) L
    5
    8 a/ \3 g+ K4 @( M( Q+ W6" ^2 t9 G) q" a- `  b
    7
    - K: _4 j  n$ p/ i( j2 ?( \- r82 k$ M4 t6 H" |8 B: y
    9, h0 w3 h( g4 L- p) `
    10
    $ l7 B8 c/ B, V4 G0 x11
    / I7 M, t4 ^4 ]+ I  a# G  @12
    : R! }) E5 |: X$ a13# b! \1 g; P2 ^+ F
    14( C7 ~: S; B1 X  \5 B
    15
    . L! {# I  [) t  k- l- l( q. ?167 o' h) f+ T; v5 G# v) K
    17
    3 ^1 o) c, w8 I18
    # d% C9 h/ h( i) W3 W, Z& \190 }7 |6 c1 N+ w$ R- \
    201 x2 ~# @6 b6 L7 H+ @7 W. T
    21
    # h1 F8 T! ?. N7 e3 J2 _. K2. 数据预处理与操作
    # g, v, t% R, z$ M1 r+ R#路径设置
    : ]' R+ p: I8 e" \! Sdata_dir = './flower_data/' # 当前文件夹下的flowerdata目录2 F* u' F$ W2 J7 r8 Q" ]5 Y9 U
    train_dir = data_dir + '/train'
    * r. t0 L/ T2 K% ]valid_dir = data_dir + '/valid'
    0 _6 T8 {5 B9 T; B1
    % l& Z3 r6 W' W8 u2, R8 G) x) h( l
    3
    1 @0 X2 L& l: w  m  F- P3 O& }, p44 Q6 V4 D7 u- O; I2 e2 s
    python目录点杠的组合与区别! F% Z5 d/ x6 w' M
    注: 里面注明了点杠和斜杠的操作* |, V* a4 N7 Q  ~3 w/ ^6 N% V

    % ~& r3 Y2 Z0 ?( L( E" u7 |3. 制作好数据源% ]; p, ?4 X: e2 L  O$ d
    data_transforms中制定了所有图像预处理的操作
    : m2 Z& x5 B  x# r  o  nImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    " e% `) P. v8 ~9 kdata_transforms = {, P6 b9 ?; H# @9 s7 M
        # 分成两部分,一部分是训练
    . X/ u+ s4 j6 x& D# B$ {6 E    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    6 n2 n/ ~2 Z3 y5 [" E& o! U                                 transforms.CenterCrop(224), # 从中心处开始裁剪' W+ j* r4 h5 p, d2 s
                                     # 以某个随机的概率决定是否翻转 55开1 v6 ]  x  }4 N/ ~' M/ p. e
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转/ T) {9 C6 Y- l& Y" p
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    , T/ v- L3 o- @3 P( a9 I                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    8 U- t9 X, v( ~+ ]4 @' V( v& n                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),3 B5 `! H% U7 y# Q" M4 z
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB; R" i& |/ T* H; `: N5 t) R
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    . X7 W/ n* y; z3 R) Q/ W' M. J; k; ~9 W                                 transforms.ToTensor(),
    : R" k! W  X$ S( |% Q  {3 {, P                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    * u+ O- I$ J- m2 B7 `& J. g                                ]),$ ^3 w0 t5 |9 j8 D" D3 e
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化2 P6 j- [" L/ B9 g
        'valid': transforms.Compose([transforms.Resize(256),3 [- |' {7 Y$ r# m. r" R
                                     transforms.CenterCrop(224),
    9 m6 Z  d/ s' D- E" C! n                                 transforms.ToTensor(),% m% W  J$ _+ Q4 [4 g4 r
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    9 u$ j4 ^0 S  f8 {                                ]),
    5 a7 a: ~8 O; e7 U8 a1 m}
    ( ]7 I3 ?/ g( h% w% H* h) \! O' v/ G) S# Q( l
    1
    , r% u" o& e4 \) |( b: U% r! @24 i, S$ s% h9 B4 `- q$ r
    3
    # k6 {/ z2 r6 n  o3 T2 h4
    , R3 A/ N( y+ J0 j3 U4 N/ ?5 I: z5, @" q6 b# O7 G; w- o$ k7 C' M9 K
    6& I- b' \5 S, N# Z3 ]3 k, I, m
    7
    , _; a! D: l3 U3 K0 T1 n! |82 E* l# q7 J! R: e& n1 r
    9
      c3 o' n1 n9 Z10& C" u) z1 V9 i. s  F
    11: r% N$ R( K+ E5 M* _
    12) V8 M7 s7 @& u6 ?5 L7 G7 y6 f
    132 I/ M( n8 |' V3 u, H0 X( L2 z
    14
    8 }+ R; J7 _  [& R3 W8 G  U! v15! E* s4 j/ u7 U0 q7 [' z
    16
    ( y) _4 Y- V% [" Y  }17$ W9 @* b! S  z$ }
    18" K, Z, p; D/ n6 W7 @4 w  Z) f% S
    19
    " h6 b! r. O" d6 y, J6 r; q0 _20' w& {7 i8 }! A& Z
    21/ Z; L4 W0 w9 O2 r
    batch_size = 8% |' t' @4 n$ M" c3 Z' W5 o3 @
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}; @$ f8 O' v6 F. x9 y
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}5 N3 g% r( V0 m. @
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    & z5 {/ n2 D4 i; J: ?. g: Y/ A& mclass_names = image_datasets['train'].classes" _% K9 Z) [; ^

    " R" d: ^# ~( e#查看数据集合, A' e$ U& _7 [' d/ h7 l5 p
    image_datasets
    ( c  v0 N9 D4 |0 f% a* e8 k' L# h! Q$ h6 x5 J0 S; I. J- _
    1
    & t- n6 f' }7 m6 ?) M8 |1 c# W2
    : m" _( d6 H1 X. x. a3* G9 p  Q0 |) ?# `
    4
    : X2 }; j+ i0 m  Q6 ~2 ^" i. d5: F0 Q  N7 [, L/ p6 {+ j4 ?
    6
    4 I) t; u7 W* k0 @7" O) o" M1 y8 i8 ]! R
    8
    % _- @5 q6 F  ]( s4 i9 B/ D9
    1 y2 I3 y- g% d$ X{'train': Dataset ImageFolder
    ( R4 h0 c% ^( f+ G; ~& u     Number of datapoints: 6552' w% B2 |6 Y9 t2 b; X
         Root location: ./flower_data/train. }* w3 o+ \/ l8 I# h# F! z
         StandardTransform
    ' J/ ~6 n0 S- L  E  ` Transform: Compose(
    ; U, A' d; i' V                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)! H& _  ]+ |' i% A) K6 w
                    CenterCrop(size=(224, 224))3 h. P5 I% b' f8 P4 t% B
                    RandomHorizontalFlip(p=0.5)
    7 [( L/ N% E! U/ Z                RandomVerticalFlip(p=0.5)
    5 i9 `  k5 s$ J8 I8 X! q" j* e* ^                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])( P, k% W% T+ {2 d4 u6 G
                    RandomGrayscale(p=0.025)$ b0 n2 j: x: [* z/ d" V
                    ToTensor()
    9 ^7 Y% P0 i1 E3 s) k* v1 ]                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): x2 D) B. I1 c* s/ p6 D' i0 t
                ),$ x* G1 b' W8 p) }. A
    'valid': Dataset ImageFolder
    - f( s+ P! Y# ]; y  F9 @     Number of datapoints: 818
    + r1 v6 h0 G  d* e7 T     Root location: ./flower_data/valid$ l& m$ H& Q2 F, i. B4 a8 D
         StandardTransform
    : \, j5 U# d- Y- a' X; W" w Transform: Compose(9 @) F+ @, h/ A
                    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)% H3 |  x/ k+ i, ?: F$ w$ d9 i+ H
                    CenterCrop(size=(224, 224))- K/ Y, G; X# H8 T- U
                    ToTensor()
    # S/ \) `" |" N3 ~% S% L                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    % g* z6 B+ T2 Z2 y- ]1 A. q" Q# g            )}) H, @' F6 U  c
    ' j  o- A" w7 j! k/ @
    1! h3 _3 p' D" Z
    2
    ' n4 [+ v" T& b& u3
    0 t, T7 X( f# H1 {- z4
    9 D6 H/ R; Q  ?' I7 {" u5" i& Y* @, }" J, G& E" f$ |
    6) E, k$ d' D( f: @
    7
      q: R" r7 V, u6 M+ s, a5 c8
    " v9 F5 t8 B4 t9
    6 d3 \5 a/ n$ y! Z/ Q/ X0 Q' U10
    & Y7 ?0 z+ i4 Y4 G$ [11- v( c- T2 \9 @2 P0 x
    12+ H2 o/ t& k2 ^' e
    130 z5 y4 z" m0 X4 g( z7 E
    146 @( |8 l. ~8 d# T. B2 ~
    15. Z5 k) ?  D5 c% B/ [2 U
    16  z, E, v1 i; v  V/ N1 I  e, a; \
    17; c, {# S& o' D8 [4 Z( z* |' t# M  C
    18
    " B8 y  f( d+ L! N9 k4 h19. f0 ?- S8 w( v4 C  F+ X
    20; ~6 n* B8 G$ g7 x/ W
    21
    * _1 p4 ~9 Z' K* g22
    ) `2 Q. O- [$ Z6 y4 @237 n- f+ E, H# i4 U& e+ n  M
    24
    % v% ?! P% `& t# 验证一下数据是否已经被处理完毕
    ' S; n- R$ [9 hdataloaders
    - x# u# b$ v* p1- j5 q* _8 o/ s
    2
    4 ]$ u2 s4 @$ s8 h, J  d! W{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
    ! u& W0 f' W2 X 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    7 \% Q" C" C! t3 ~4 }% B+ H) N0 G1* U& I) i' d4 i4 L: Q  U0 x9 A
    2
    & m8 }4 `% F5 s1 ?dataset_sizes. s! C/ i: `1 r$ }# D3 K% }( \5 B
    1) l8 P% a- x, {( b* ~  a/ S
    {'train': 6552, 'valid': 818}6 S+ Y7 R; }; Q' R
    1
    3 a$ S3 \. H# K3 v) c  o读取标签对应的实际名字: y* _8 T8 e9 |6 z1 A! s
    使用同一目录下的json文件,反向映射出花对应的名字7 t5 b6 ^" m2 o3 n6 H4 S  C

    + ?& A" G& N# C# g, xwith open('./flower_data/cat_to_name.json', 'r') as f:
    % K- d) T3 W  m, u1 e    cat_to_name = json.load(f)9 J% f) V( h% [* e3 F, q6 ]+ M
    1: I* ^- v8 z( X$ O- b( m# L0 p6 K
    28 N% L1 U) d- X- O
    cat_to_name
    , j6 r0 q7 ]0 l) c1
    7 z- I; Q6 Z9 K% P4 q{'21': 'fire lily',$ I$ a4 Y0 S4 M! s8 I
    '3': 'canterbury bells',: Y. k' x* _4 Z- h  T9 W1 ~
    '45': 'bolero deep blue',
    * j9 c9 m! t4 ?% w '1': 'pink primrose',
    ! m( }" v3 @% a5 y( j '34': 'mexican aster',
    - I1 g6 Y9 F/ }2 U: e '27': 'prince of wales feathers',
    % \) ^/ n% u& t9 z '7': 'moon orchid',) C+ x; r5 M0 ~9 }
    '16': 'globe-flower',
    . e' Y8 `6 V6 Y9 _/ h '25': 'grape hyacinth',1 K$ ~% U8 Z2 U! ~& c4 e$ d8 O! T, o
    '26': 'corn poppy',
    1 W' Y' d* M# \6 r: X6 U '79': 'toad lily',# r7 R) Y2 S# n1 E2 X+ Y
    '39': 'siam tulip',
    " G  y9 `* A% G '24': 'red ginger',
    ' O$ B2 ?* k0 n5 O4 w4 k# _  h% A0 Q '67': 'spring crocus',
    6 \. F4 ?3 ~0 i; ?3 z$ H& Z '35': 'alpine sea holly',; J6 t: k5 H( H! z& M
    '32': 'garden phlox',
    % G# G8 h% a2 f$ ^) p+ ?) A0 D  P '10': 'globe thistle',
    3 w* o3 C/ w! R9 b/ \ '6': 'tiger lily',  h) X! w6 I* j0 s8 p
    '93': 'ball moss',
    4 ?& d3 [+ a$ C4 Z) J '33': 'love in the mist',
    8 Y5 m3 f) q2 [, ^) n '9': 'monkshood',( A2 F% q6 e0 c" G
    '102': 'blackberry lily',. o8 i! W1 ^8 b7 ~) [
    '14': 'spear thistle',
      b2 |# ^# v4 O5 u '19': 'balloon flower',5 l8 ~# h& \1 @' L1 P; n6 Z& F
    '100': 'blanket flower',4 O" m) u  b% k
    '13': 'king protea',: `$ E* V6 |9 u. k
    '49': 'oxeye daisy',
    & l8 w; M' a& M1 M '15': 'yellow iris',
    8 |# J8 I) v' U- k '61': 'cautleya spicata',
    + Y4 _5 f0 P& ?/ w '31': 'carnation',
    " r* ^: o! y4 f; T! r  k '64': 'silverbush',' R5 c2 A, B: M/ q9 Z9 r; ~% R
    '68': 'bearded iris',
    8 m5 c* P5 V0 C4 O/ c$ u '63': 'black-eyed susan',
    % q3 m; p/ A4 K( i- d' J0 z '69': 'windflower',2 H( [2 X3 o4 l! l- t" L9 `
    '62': 'japanese anemone',
    6 d( z2 k" W# k& g3 m/ Z '20': 'giant white arum lily',
    * E, [& M& m5 o3 p2 H '38': 'great masterwort',
    : E% t: D8 a. C2 Q  Z, H '4': 'sweet pea',
    # b( z' e3 `: S& E7 F- [ '86': 'tree mallow',
    4 F8 I+ r8 d( ~+ s8 }9 p" D '101': 'trumpet creeper',9 s% v  C( X. R" b9 k; `1 ]! o
    '42': 'daffodil',
      o$ v' R  W7 N9 |7 ] '22': 'pincushion flower',; Q/ @# `. E+ ^- ~! A" @7 F! j6 t5 P
    '2': 'hard-leaved pocket orchid',
    0 c: y+ D5 A- m! U) X. `6 X' u '54': 'sunflower',
      k6 K7 M# j4 u' H& h1 N '66': 'osteospermum',
    ! Z* e& G8 s" ?' v! @  p '70': 'tree poppy',
    4 s! x; o: A3 V4 z. R# Q; P '85': 'desert-rose',
      I; j7 a; g  G7 U2 G3 V& C '99': 'bromelia',9 f/ {7 G. H4 V
    '87': 'magnolia'," D2 a) E/ N; {: ?4 W1 c0 E1 h
    '5': 'english marigold',# m$ x3 }" V( N* |  `# T/ Z
    '92': 'bee balm',
    * ~  h# t9 Z, e1 y7 g" _ '28': 'stemless gentian',
    ( }2 t; N- e+ `5 Y. D1 u9 i% {2 j '97': 'mallow',) \6 ]7 |+ J; g: M1 q: @
    '57': 'gaura',: }. b7 n" O; R9 x
    '40': 'lenten rose',0 ]$ a" @( A6 }# r+ e3 z
    '47': 'marigold',
    7 m* e5 l" i# b% c- p) @$ M '59': 'orange dahlia',
    , S  ]0 w! L, W8 Z1 F '48': 'buttercup',  K' Z, Q1 }0 y
    '55': 'pelargonium',# X" ]+ P+ ]' @# ?; T7 k
    '36': 'ruby-lipped cattleya',
    & J9 h0 Q+ o+ S '91': 'hippeastrum',
    9 z6 N; f5 X2 z. R '29': 'artichoke',+ [; _/ x: G* i% w
    '71': 'gazania',
    " x# J  h$ O. k1 ?2 Z! w '90': 'canna lily',8 u. l2 L2 z3 H. ^. T5 ^2 f
    '18': 'peruvian lily',
    ' j1 I1 N! ~3 g. s6 t '98': 'mexican petunia',
    " T0 D! M. G2 x+ @' P '8': 'bird of paradise',
    / }) @- L; N0 \) q( e: r '30': 'sweet william',, C: u# Y: M; t' J- O
    '17': 'purple coneflower'," s9 v5 G+ A" T4 m% s4 B" ?4 z4 S& J! F
    '52': 'wild pansy',
    " T1 r! m, p  L% N: F '84': 'columbine',6 V6 M8 E9 ]) W2 ?- J
    '12': "colt's foot",
    . H* H7 T0 y9 D% G* ? '11': 'snapdragon',
    3 p$ Y6 Q- G1 Z5 ?  b' a3 O '96': 'camellia',
    7 o; T- u$ N1 X. B" G3 K '23': 'fritillary',, ]$ o7 B" t  W& j% X
    '50': 'common dandelion',, V8 c" X+ Q# E. _) F& V3 l
    '44': 'poinsettia',$ O1 A; z) h& `+ d8 T  T. q6 U
    '53': 'primula',
    6 |) P% G9 x; d% e& B '72': 'azalea',* [8 W. q8 f+ v- e
    '65': 'californian poppy',
    6 O  c  h6 B6 G '80': 'anthurium',
    6 r) N( X/ X1 [& q! u '76': 'morning glory',  l3 j/ u) U: [7 a& K& S9 S$ b
    '37': 'cape flower',. v1 s. `- D/ l+ q+ Y
    '56': 'bishop of llandaff',6 t5 h2 m8 |9 K
    '60': 'pink-yellow dahlia',
    / T+ A! }0 f& T) L '82': 'clematis',7 B4 s) V) n3 Z1 t
    '58': 'geranium',# t7 Q) m) T  m) s" J; D$ n
    '75': 'thorn apple',
    5 Y3 p$ [9 I) Y3 l6 i: y! z3 I '41': 'barbeton daisy'," l" o% U5 W4 `: v8 t# I6 z
    '95': 'bougainvillea',
    7 S* M* y& l; i# x! Z- Z$ M '43': 'sword lily',
    9 ]  v: S+ \3 _1 r* ?2 V '83': 'hibiscus',
    ! B2 a) H3 W$ e- V" ^ '78': 'lotus lotus',( I4 p& Q8 w" ]! z, x. e
    '88': 'cyclamen',$ F! M  W8 |: y$ b
    '94': 'foxglove',
    " y" r1 X' _6 P' R, a5 ?+ b! A) X, x '81': 'frangipani',
    " W$ y6 K1 T1 _# B '74': 'rose',
    + [5 _: D) t9 y$ y '89': 'watercress',
    8 }4 V9 W  Y# W% K6 J '73': 'water lily',, V+ {- j# c% k; O6 H- L
    '46': 'wallflower',
    * {+ t* q9 y# s( a '77': 'passion flower',. n' V3 ^( A0 s2 u
    '51': 'petunia'}
    + D; h+ n1 W2 W0 V4 i  t) p  P1 e/ a6 T* T1 z$ Z& Y# S
    10 g) e2 l# a+ e4 H; q8 n/ l$ ^# C! C
    24 B9 q2 j' M# T) ~: I# [  i1 ?
    3
    1 R* d7 o. v- A4" ^: O1 v% n9 y6 b% p' x
    5
    & @& H' [9 \! f  h5 _' E8 ^6
    - n  s  I5 m% S$ q76 A) g% I2 Y. S! f$ y7 v5 C
    8
    5 u, L2 h# x3 l' N0 _( L9
    , h7 F; h! M% B& P% L1 @10
    8 g) e1 @2 _) N& a4 Z( W- K11
    0 {0 ~; j' ~, o# H12: y3 ?  X4 q' U/ s. s! {8 o
    13
    , M8 f9 O. b5 ?3 k2 M' Z6 h$ l14
    1 K: U5 I- l/ \3 r$ @152 r. m. _% f& G3 }# Q% y
    162 h) R* y7 o- b5 ?$ A0 H/ {
    17+ C! T! f' p! i( V
    18* Q5 f" w% L' h5 U/ b( \
    19' V% R, {& D% E/ e
    20( {  v# i& ~8 e' V* O. ]8 r- W3 \* S
    21
    4 s7 p* C. F* G6 ?6 e22, O% Z2 T2 {$ X% R3 ^. F
    23) ]" E( Y4 S, w
    24. ~( X; D- D5 v( f* Y, @
    25
    : T1 |7 z# G4 X7 [  p2 ~. m8 e26
    % Z- T% \) n. N9 B27+ g+ a* q% `8 u7 X
    28
    ! m. K# I8 w$ f8 K29
    & e, ]9 ?, Q. U9 E" Q30  M6 J& @8 M* Y1 m3 |
    31
    % h- G5 F) e8 c$ X0 ^4 H! S) r32
    . e' a# V9 r( f) U8 x# S- r' Y/ y336 i8 m# |% A- ?  A7 E1 ]
    34- L1 l7 H, V$ Y% F) M
    35# F# y6 D% s% w  v  w3 M2 {
    361 j3 P3 y5 C/ V0 E
    37
    : ^; z4 h( Z$ R$ a! H2 [, r: [- K/ h38  g/ k4 f; L9 [$ a: H1 `3 g
    39
    : ^  N. h( m* a: e400 o8 ^' W( A  ]
    41
    ! j/ {7 D+ A& f7 G, P. U42
    $ \4 C8 s: k. ]- u43
    / X+ ^* Q2 y! A$ l441 e% D; @( K1 y4 d: O
    45  Z: |5 r) k! t( o
    46( Z; w* d% ?8 u9 [
    47! X& K: ^: E2 F1 \
    48
    ( T- \( c  E, T- J$ l6 ~49
    ; o5 O. a( m- d4 _50
    - l3 a7 @9 p) d2 L+ q/ [% j51
    " U) y- L, d/ D& T  I% m52
    3 k" H5 t$ u  G( ~53
    3 ^4 G/ B/ ^% Z548 y5 \6 W, t# W
    55$ I6 V( b3 R* I0 w% \. V1 P
    56
      `2 B8 c9 q0 u1 v, t578 C- c% X" l3 M9 K# }  ^+ @# W' {
    58
    + d. G, C% b3 t6 n59  f1 u# U+ O$ C; G" B, X. X
    60
    0 X7 x8 X8 a2 i61
    ; O6 A1 v1 z/ N62
    1 d  H) A4 D) b+ ^8 X% Y63  R. v, \1 f- `) ]! a
    64
    9 }, F( F) \0 e& w7 Y. _5 c65- K: L+ k& ^% p" n8 s2 I$ r$ z
    66
    & w8 k; E! L1 m67/ p5 Z# W! |3 I5 k0 _; n
    68
      ^: v, v& C9 R+ g# d3 J# x8 d69
    / Z- @/ p8 w/ C1 W4 O- W- {- c700 a3 D% B( t& N4 B( C, _7 Y. b
    71
    3 q  |. H, \( |8 B# k72
    9 t  t8 I4 F! g73# D3 b- _3 T, }$ @5 t- w
    74
    2 z& u- t4 l  y75, @2 X/ t! h. F2 y5 z
    76
    6 y* Y- @* \* E9 G77
    ) \5 }. v0 p' h, Y2 A: ?78; ~* _" O8 P, m4 J1 V. u* D
    79
    ( g, b3 g! w2 V2 E6 s80
    / y, I7 f5 c: g- w2 ?% I! {7 {" |81
    " c& k, _7 h9 l# T2 u" U0 c825 q) O9 U( S8 V6 f% q
    83
    , G% `0 r  ~8 U1 f1 e& c84
    4 C( p% a5 M3 X6 I# e854 R: Q/ I4 P1 E( G
    86
    $ n, H/ P' A: Y' n871 ^. w4 E; T4 W
    88
    8 t, [- y  i. Y. }% E89
      m  d( ]  g3 _3 I90+ ]" O" ^! }  C2 J+ }( F% [
    912 R# Q1 B2 a8 ?0 @9 G
    92" V( j) l3 D, D! V( {
    93
    5 ?! h% H& W9 ?946 k( z8 b! q' S
    95+ M/ L1 _' c9 W" R- Q
    96
    5 E5 f( c! z7 C+ U1 N* Y977 M, C' u6 j7 C% W5 q9 n' `
    98+ y" Y" N& J: x2 k' F1 q1 I% k, O
    99) x9 N  r2 A( r
    100( Q) I; ]1 r) e6 g
    101
    ) K# i# m0 E* p- h5 ]9 W2 Y1029 \3 x0 [- N, r; }* a
    4.展示一下数据3 ]# f6 z# @7 p/ |1 {
    def im_convert(tensor):9 r& r7 [4 p8 o# h8 S: ~, R
        """数据展示"""2 R  D2 [9 T  U6 |) B0 C. w
        image = tensor.to("cpu").clone().detach()) x3 T* f7 \) Z+ F8 ]7 n8 W
        image = image.numpy().squeeze()
    ! e$ @, S+ \3 n4 \    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    - N4 Y9 T# s/ Z$ C    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)8 s! T4 ]0 r! v9 y
        image = image.transpose(1, 2, 0)
    " g" [- s+ J6 k" ]* f- V$ T' ^    # 反正则化(反标准化)0 x6 E( M- F4 @, h
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # T3 _. I& i  B( g3 E
    ( [& _' _/ `" k2 b  R- L; b    # 将图像中小于0 的都换成0,大于的都变成1
    % N+ T9 E9 x2 w( Y. G1 j    image = image.clip(0, 1)* j8 E, q  H9 V) {

    # x; `5 |# E: U: m' X    return image. X$ l  ]/ X2 c  Z
    1
    5 _: V# |& \- Z2
      F2 b/ X  x" a3" K7 v: _  F) k: Q) }' R- i9 Z
    48 l. C7 n* X* J# s5 |+ `) G
    5
      g0 d: Z$ W  I6
    5 m- n& Y6 n1 n( M5 q* {77 F% b; I* o* m) l8 L; K: e+ U
    8; F& {! _5 k7 q0 N
    9
    8 S7 u5 I0 t# p3 f103 i0 T0 \1 w  w2 A! i+ Z# [  g' R
    11
    6 e# ~: ]: a' H6 a12
    5 |1 K' k6 H  |, R* b2 d13
    ! q0 v" v* |. }3 M5 o14
    6 |4 {  v: J0 k! w0 S% z# 使用上面定义好的类进行画图8 R+ w. M# c6 a4 [+ d
    fig = plt.figure(figsize = (20, 12))
    : j; L% R" O- C# rcolumns = 4
    ( [/ ~, ~5 z) o  @+ Qrows = 2
    6 G8 Y8 N2 ]  l5 Q( e0 i' `/ u6 @" Y- ~# ?
    # iter迭代器0 q) F: A! x0 S
    # 随便找一个Batch数据进行展示
    & }0 U' V5 b: P4 M7 M( d$ Odataiter = iter(dataloaders['valid'])% h' M4 B; q! U4 K- z
    inputs, classes = dataiter.next()
    0 E( c6 ?8 z# k; W. W% f
    ! b" @. Z0 |7 P9 Cfor idx in range(columns * rows):, |$ L2 P- q) ^* P
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    9 }) ~  S6 ?2 R4 w+ \    # 利用json文件将其对应花的类型打印在图片中
    3 \; d; F# y0 i" z    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))]), m) [9 j" x$ J; F
        plt.imshow(im_convert(inputs[idx]))# ~5 p" H# t( @! c" s/ B6 O
    plt.show()
    6 N: f" s4 @5 ]+ k" m
    5 Y/ {6 B7 z  \3 e* _1
    ) Y, {& K8 @; S5 D5 L$ }" W2
    ( _1 `- R9 a1 F6 g5 V1 p" f- d2 v% n! Q3- {7 r# R9 l/ S2 y
    4
    , Y1 K! p) w! S+ y5
    # L& I; }9 [& s' l6) M, i! @6 y; D' o
    7
    ' M8 g# `: e. s8. |8 M( T1 O7 O" n( T5 h7 R
    9
    ( K/ I' `% g8 e, X10
    + Y: P# H( y5 d# ^7 l# W) ~5 _11
    ; t+ w: f4 R- }$ k8 O. w9 P12
    " S  C+ ]8 s1 L* z. ~13
      l' ]+ r8 R7 J" K14
    1 ]2 U- m4 g0 q7 A158 P/ s) ?7 m7 h+ K8 [* Q
    16
    + Q( l! d& R* M& s, V9 K7 P* a, b* V# M4 ]  H; A& r$ R

    1 `6 k5 `, l& n0 h$ T& Y, O5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    9 }) `" L8 G7 B' Amodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']* h. n- E; t9 {5 V# d5 [8 h
    # 主要的图像识别用resnet来做
    . {' A2 \. `7 U4 H! M+ v# 是否用人家训练好的特征+ k2 {1 h& i- X7 i
    feature_extract = True/ C% |2 L& E; O. }& J2 t
    1
    1 o4 O  f5 ~  z* z1 [) r' V21 R- W0 M3 n/ s6 u! O
    3
    4 e3 a! r6 E' L/ D$ [% v9 c4  r) N$ w: V( D7 H" f5 {4 u( d
    # 是否用GPU进行训练
    % h# M5 W  T' y: _" Dtrain_on_gpu = torch.cuda.is_available()
    ; S, M7 c* H* N% [) P+ w, c
    . W- A/ E7 y6 D) Y: f; K- Kif not train_on_gpu:0 p: I4 x2 B. K  ]
        print('CUDA is not available.   Training on CPU ...')1 h6 Y4 _3 D& A: u; P- M
    else:
    2 d+ Y+ e+ [/ |/ e+ N; M% S    print('CUDA is available! Training on GPU ...')
    4 l5 |. E. I, V" [- O; F8 }  H- L+ m: [1 A" {& R+ a4 Z" F- f" _
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')6 P. G2 r4 r" s' k$ e
    13 Z$ Z% r7 g7 q4 ~
    20 U  F5 ]" C1 A+ n1 [5 f: E
    30 `3 \( N9 m* q$ n
    43 S: c0 z+ j! Q
    59 e5 u: L  d6 h2 n  ^
    6
    , q9 [& w# r7 _. G+ n$ A8 R& e! M7
    ; L4 _4 s, F5 F( {8 W( r7 g8
    ; m. p+ P8 ~6 j* `5 ]) G: B9
    % S6 w2 ?5 U4 K9 M* ~$ J" p$ jCUDA is not available.   Training on CPU ...
    9 v3 F5 g7 U+ [7 d5 l1
    0 {* R( M5 O6 v- ?6 b2 Q0 n# 将一些层定义为false,使其不自动更新
      @2 F9 A7 N0 ?9 a" z1 Adef set_parameter_requires_grad(model, feature_extracting):9 p0 ]) j) ~, K, u/ z' _# l
        if feature_extracting:
    6 p5 X( j9 h" B        for param in model.parameters():. g9 f+ O; y* Y$ J
                param.requires_grad = False
      w4 m) A% J( C1# ^7 C* p+ G1 u1 Y* P, V8 {
    2
    ! x* L4 `- o+ `3
    ) `. T5 H7 `) {: T. H* R( A4
    ; o; a1 H$ i. e& k2 b; O; T5
    5 x0 k; j  s+ f( i+ P) [+ m; {# 打印模型架构告知是怎么一步一步去完成的0 y: C) P3 a, k! U4 W/ w- J) u, f8 y
    # 主要是为我们提取特征的& I& }4 Q3 C( J

    . n1 k; N  C' x! Y9 A* T" gmodel_ft = models.resnet152()
    7 o- O" w9 q+ ]& X8 o- pmodel_ft4 X: r% N' m  F, s
    16 P0 U' S$ i: i" a' ]4 d8 v
    2
    $ `. S8 ]) R+ M/ i! I& L3' k7 X, s) s3 H: O/ C$ J
    4
    1 [, K+ }3 Z% Y& y4 n6 P5
    " k9 C$ n: Z$ q1 S: U) tResNet(
    + q( Z: `; F2 ~2 w# H6 W2 z# J  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)6 d& f$ {2 o3 x8 E
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    . y' B4 l% B0 `* h3 N9 G) E4 N$ J  (relu): ReLU(inplace=True)
      A- G% d- A- l5 d. @# i0 g  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    . v2 B; _( j" U0 M5 G  (layer1): Sequential(
    4 |, c2 z4 |; h! \' `3 f. m    (0): Bottleneck(0 ?# X9 S0 d& x- `
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ( f2 H- Z' w) S+ y9 g0 u      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True). c/ {1 G8 _6 i3 [
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)+ l+ v  _/ D5 N9 t; G1 v9 p- ]$ L- T
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      ]  N' U; I' ^  `9 ?% ]) y: Y$ i      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ) X! ?( Y1 @# p" P/ ?3 x/ I      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)+ d# n: Q5 B" p8 G  j; Z  C
          (relu): ReLU(inplace=True)& N. C* V* ~, G
          (downsample): Sequential(& U- H! D4 d& w# G; A. a( R
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)6 g. l5 M  z" V) n
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    , E# F) x8 D3 O      )% r* e7 M: V! c& S
        )  G: W! g& [2 v+ R
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    . S7 F! U" m- B    (2): Bottleneck(
    % n# @- K4 A+ ~8 j# d      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    * H% y) R  a# Q; h$ R+ h1 u: O  d3 ]0 o      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)2 a: h0 |, b0 |0 p
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)# C2 \" ^7 Q+ @+ w
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)! q4 s: y* u7 `5 k0 Q2 U
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    6 w/ m9 M& P: k' V/ i& t      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    % ]* f1 n7 x1 ]( l      (relu): ReLU(inplace=True)& q! L/ z: r& j1 ]9 o: U: ]- l# y, C
        )9 b8 x. h3 A7 C# F4 _( D( u
      )
    ( b7 W% S9 [* m5 _# i& [. K  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    4 r& r4 F2 I0 s9 x% V/ Y  (fc): Linear(in_features=2048, out_features=1000, bias=True)' u+ W2 w3 w& H+ T! }7 ~8 k, I
    )
    # |4 f" J) h) h: i0 W, F2 M& j# v9 `2 ]$ [& l6 _. z
    1
    " d; _! {$ F; N4 x2
    4 b% R1 w$ @: c% H5 R0 q3
    " m: i4 |2 C9 r, J3 a6 @, j) i4
    # n6 J- O+ Y- o7 _; I2 m: Q5
      p6 I4 T5 Y  l' r1 O/ v( r, G9 |6& n/ S$ x6 Y$ k& I3 o# T  o
    79 g6 k, z/ F7 ]# g0 K6 W2 c
    8
    # }1 Q$ }+ d0 z: s( U$ b94 f$ P9 J! [( f, O; G- @6 {4 @
    10" f( w3 U: _6 {" N3 c
    11( \8 O' K# O7 n$ `
    12
    " q; B) B% e9 m  d) X137 n8 c% _7 S' s. C0 B# Y
    140 o( t) K) a" m. m5 B) b/ W
    15
    . f) j/ p8 u$ e- s9 S16
    $ p* o7 H, X* S3 N& A8 q0 e170 T, C6 j: |$ t! i4 i' Q  G
    18
    2 ]# ^, s3 A7 S& u19, C- \! _2 Y% H+ g' u2 {7 `3 \
    20
    - x- b0 p! t7 b9 j  g  M21* h$ S& b9 w1 G+ Q; b
    22
    7 n/ X7 y( i, f+ O) W' c6 I23
    * S1 F) ]& ?' n# n" E244 X) ^  b7 S( m: h/ Q
    25
    # h% w4 y& ^2 Q8 l* Q26+ T: T6 M$ `0 L
    27( E6 E  l- J! P; n" m/ \6 ?$ i' I
    28
    1 {% y( P/ E# T, |/ e29% ]& y3 }. b" x
    30! j- \' t" \5 [% W8 T: g) d
    31+ F. p1 f, v+ v, Z& ?4 x9 {0 ]; K
    32
    7 |8 W$ C3 s2 c$ E3 c/ b33
    1 A) L7 h5 w8 F! G# {, x1 b最后是1000分类,2048输入,分为1000个分类
    ' j/ R9 N! I4 Z6 D' D  n4 w( _& O而我们需要将我们的任务进行调整,将1000分类改为102输出* S' j# q; ]; O6 M8 G8 c

    + ^+ ]# D2 u6 ~+ F; M1 Y1 }8 L6.初始化模型架构. y* F; ]1 H7 \/ w3 S4 U, X' r
    步骤如下:9 G, V; A4 `, E' `. M, |6 c

    & h# l5 F( x1 Z将训练好的模型拿过来,并pre_train = True 得到他人的权重参数- d! e" x# r5 d" [& ^2 W
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)' M7 d! M6 G! ?" I# @
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数' T* r  X, V# y6 x3 @
    官方文档链接
    9 h. r8 ~8 [8 k' F; ahttps://pytorch.org/vision/stable/models.html
    : y5 U- q3 w* n! z) P- g, J; P7 W: X6 _5 c2 S
    # 将他人的模型加载进来  `- O2 I' x5 f: |+ N
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):+ a5 [7 N) a9 l" ^4 Y7 v
        # 选择适合的模型,不同的模型初始化参数不同/ x( N# C" o1 l8 h" U0 a
        model_ft = None
    ; O* R! |7 ^7 b3 M0 f' L    input_size = 0& M$ D. E+ @1 L' D5 H

    5 D4 Q4 ^* |. |$ T1 H" ?, q    if model_name == "resnet":2 S4 m7 ~3 R  c8 j
            """
    ! L# h! R4 X: y+ B& P; h        Resnet152( t* k) S' p; T+ G7 r
            """& }) a8 v* \% M9 P3 J1 A3 R2 s

    ( F; K, L5 V: x5 {: ~        # 1. 加载与训练网络4 {1 N8 g0 G, r
            model_ft = models.resnet152(pretrained = use_pretrained)
    ) i: t/ B+ f; k  c+ t: W5 X" f! T        # 2. 是否将提取特征的模块冻住,只训练FC层' Z" m$ O% R$ D! o
            set_parameter_requires_grad(model_ft, feature_extract)
    * \* Y& B' N1 x4 h        # 3. 获得全连接层输入特征: A( A3 y8 N+ q( r9 K; Q* e& ^
            num_frts = model_ft.fc.in_features8 T3 ?3 O5 [: U' f1 g# w7 e( H3 D
            # 4. 重新加载全连接层,设置输出102
    8 R8 n# q  ?2 A# ~. z; _        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
    7 }6 t+ h9 r) @# e# ?                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1# ]- V! X3 e) t8 e4 f6 [
            input_size = 224
    * m6 P4 E! f! q7 Y8 r* F2 f8 h( h) \; l- Y
        elif model_name == "alexnet":
    # W- L! @2 }- u7 K( G        """
    + l+ p/ o/ a. n8 h        Alexnet
    7 m+ ^  W1 t( }) }" k: r        """
    - q% O% |! s5 t5 p5 |; o        model_ft = models.alexnet(pretrained = use_pretrained)
    8 {3 W: \  W& v+ P" l+ e% u  Q1 o' s        set_parameter_requires_grad(model_ft, feature_extract)! Y! m4 v3 d: ]8 h8 M& T# b: K

    : z- m& X+ T' h; U3 e: i$ B        # 将最后一个特征输出替换 序号为【6】的分类器+ o& y% S, `3 z& x, P- `* y
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入
    / ^" M' d% O; U8 f* p0 a% L        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)& V7 H; f' x8 H, b( `" x% m% H( X# {5 R
            input_size = 224
    / C" ~2 Q# G5 n1 I0 R4 f4 q$ J
    ; b# q$ u2 y- f  v( d    elif model_name == "vgg":
    - k; _; P: m, @+ n4 T1 w        """. }' r% Y( r  Y/ n
            VGG11_bn2 k9 Z2 A/ A4 O5 X, q
            """
    - ^1 m/ A( ]' T        model_ft = models.vgg16(pretrained = use_pretrained)
    1 w7 q: p+ `) T        set_parameter_requires_grad(model_ft, feature_extract)
    % r& T# C; R3 C7 M3 {# d8 M        num_frts = model_ft.classifier[6].in_features
    , T/ `& l6 I9 x0 u# l: [8 Q        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    , d( Q: G3 ^3 \        input_size = 224& E& p3 N) t3 Z6 n

    7 [7 I9 R1 V( Y" h# y    elif model_name == "squeezenet":
    & V* v' n% s4 r. Y( a        """
    : e$ C/ v$ Y$ v+ a5 T& y        Squeezenet
    8 v% }/ P! h' p% C% q  k3 z        """. ~; A1 g/ c" P& v$ x* [2 X
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)( O- I4 w" h& ?, M) i- k1 m9 t
            set_parameter_requires_grad(model_ft, feature_extract)4 u  H3 K2 f1 T: g
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1)), s" U1 ]* x- {: l- Q( ^/ X; t
            model_ft.num_classes = num_classes* D( F7 n2 d. }/ v5 l- _
            input_size = 224
    ! D4 d- j1 n+ q, S: I' H: z5 I$ r% K9 l9 ?# J  @
        elif model_name == "densenet":" U1 j" ~& K" D, u+ y% {
            """  ]+ ?8 c4 n7 F( @5 v
            Densenet
    6 X- b# \$ j2 j' A        """
      ~' g& h9 G( ^( V* Y        model_ft = models.desenet121(pretrained = use_pretrained)
    3 U0 `; s- }8 N! }" @/ ?8 d7 z" v        set_parameter_requires_grad(model_ft, feature_extract)/ T. h, ^0 p$ G, k; A0 k- p: h& o
            num_frts = model_ft.classifier.in_features- k! \4 ]2 W* k" N$ S: q
            model_ft.classifier = nn.Linear(num_frts, num_classes)
    - Y6 z2 Y3 g6 g  {$ ^, O        input_size = 224
    0 ]* x9 V; O6 R) T
    % G3 L0 G2 X( _  y    elif model_name == "inception":
    9 V3 }7 q# L+ e9 h9 y2 V        """
    6 w! m6 Z+ _# g3 @. o$ T! ~# Y        Inception V3
    . R/ K% ]/ e4 x) u& j7 Q        """
    ) F1 [6 L, [* r0 J; h( n        model_ft = models.inception_V(pretrained = use_pretrained)* n0 e; P. E/ @& W* p3 B
            set_parameter_requires_grad(model_ft, feature_extract)
    7 Y4 A+ u$ [9 q# W  ]2 d/ |6 b
    9 H: V( `  A) t/ Y7 }# F        num_frts = model_ft.AuxLogits.fc.in_features8 V( @, |- u# U6 G
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)# Q2 y3 V8 u8 D) u) V
    7 j  v. i; y" @* w. `% S. |
            num_frts = model_ft.fc.in_features% |- t  ?9 w1 K9 ~* {/ m1 b
            model_ft.fc = nn.Linear(num_frts, num_classes)
    0 ]# g6 i4 U5 `- q+ J% f4 p; O        input_size = 299
    : A) Q) h# a- }7 Y' C+ E1 x( l8 L4 W# ^) v" o( P8 P
        else:! z; _# j  }6 i) }& u6 z
            print("Invalid model name, exiting...")/ u/ f/ ~- `% o6 @) r
            exit()
    ' p( ^2 u8 y0 t$ i/ u% y9 c% E4 j  N  Y% D
        return model_ft, input_size, g0 Q. ^# H2 W0 C" J1 w
    % V6 V* U2 u7 p+ _3 f: j
    10 h( P1 j' z' l% _0 [+ t$ g
    2
    2 R) x) Y+ H! Z4 x: d. ?3
    , v8 U4 O4 Z3 n( U4
    * o; L! _7 P% Q" b. g8 D3 Z5/ {( `. y5 _% X# f( {4 G' p2 S
    6  ~9 B* E: n0 y! `, o  T
    7+ R8 \5 f9 w$ R2 ^
    8
    ( |7 v3 ^4 f1 a+ ^2 J9. q5 R$ n2 B( L. ?# `
    10
    " R, {9 u, z9 i3 H$ {11$ H) |' H) S) ~* I
    123 b. B! x# [; w- d) Z& n, Z
    13
    4 v6 a# B2 \, G! T7 k# W14
    5 N. z( c! s+ K, p0 U15
    - m' c6 Z( h4 G) k167 E, S1 f- \# h
    17
    7 t' E( f+ f7 W8 ]9 _# q18
    % |7 J6 j! h- d19
    ) }$ P! @  r5 G) E20
    " L0 y7 s+ N0 z# p21
    7 f& ~6 I2 _4 ~22  w' u# l' h0 W2 b
    23
    * g" o0 ?; J! J& K246 }8 n6 B- q6 z1 c1 G& q" K
    25
    5 }8 x2 J# s( z8 o7 I+ ~7 J0 U7 L26
    * E) f$ t' p9 V, V; J$ Y278 x3 M3 y2 K1 `4 [
    28* j. F- s! C$ @- P! u) }
    29
    7 b3 v) E, m1 i8 _% X$ p! p30
    3 a' T  q3 Q0 l* m7 P6 ^1 y" g313 a- `" w, ~6 p* G) [6 G
    32
    3 a' ^3 L7 E% ^3 H! D33
    $ ^/ i) L: ]2 R2 O347 b9 D( R4 O* b
    357 e3 J4 n, O' [9 w9 b; h
    36, ?$ w5 `& n! t
    37
    . V& Z# N0 K$ `5 h* \% K3 ~38% C6 m8 _" \+ f0 v+ a
    39
    ! {2 B# K6 V, B" I407 I& x" }$ W4 w/ n! C( l1 W
    41/ w3 g9 Q7 {9 X5 M& I. h4 u! n$ y4 \
    429 q' ^- o+ f- e5 F1 E2 L  e0 V
    43. |( D  e- f$ ^' A; w
    44' U! m' `) G- s" }: W8 i/ W. f
    45
    * \0 ^( o* f8 z46, z3 K: H4 k. W/ p! w% Q: |
    47" ]/ T$ ?2 h' L
    48
    6 N0 t; z/ y* A; n4 \49
    ! ^& g6 y) _5 N! M" T! l( z& d( m3 I50
    / T; U7 h' ]$ F* G$ s8 O$ E, w  H* e51+ T1 @, C1 y+ d' e
    525 Z, D: l5 [# o& E( g4 j7 |9 v
    53; L' h  r7 i9 e$ m8 m& h1 J
    54) {5 X. V3 E0 f, X" ?5 G
    556 C+ s  K# R6 [0 u! {! z+ Z- C' ]
    56
    % [0 J8 V# {: `0 N1 H/ u/ s575 x9 P( o/ O6 h0 l+ [
    58
    : T! n3 d, G* f7 n# l) T; o) f/ ~59. ^8 S9 ^; N9 b3 I) U) i1 C
    60
      X" X4 i' P/ j7 ^, a, b! g/ O9 C61# J* Z7 y5 |' u! A: S
    628 R1 _" j/ i& V
    631 D2 K& v2 M1 e2 @7 U4 k4 w0 z8 y
    640 ~4 t( L: T1 v
    65
    " `) B  T# o( {' w66: i! M; w* B# y4 f4 _
    67  g  X: Q. @3 A- F- K) L8 t
    68
    0 n- ?/ K4 s' V+ K. I0 B+ d: M- E69
    ( f" [4 N  J7 A70+ p7 I( V+ K6 s1 l
    71
      L$ t# I7 Q3 H72
      Z  @# u% Z1 W+ y2 Q" p73* \. O  Z- ^! T+ o
    74
    ( }/ x+ n2 G0 u& R3 M# \75
    4 ]" k8 ?# N/ S* k- `76
    ; o2 o$ D- D3 v+ e: _( z5 c77
    ' _# C" I# i2 n789 D( r1 U. `, o6 b: ~
    799 j0 n- i! ~; {; a
    80
    " D: A- B! Z$ m7 K5 O3 W- Z814 m1 w4 ~6 i' v( f/ `" r
    82
    ' V+ G& j8 t" _! z83; T) F; z$ W, K
    7. 设置需要训练的参数! k+ |4 P3 _' R3 p) w$ [
    # 设置模型名字、输出分类数
    * u. i! k- a) h2 z" _) Z' A# ?model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
    ! i5 \; f  S: Q, O/ j( U2 ~7 C( Y4 a6 y
    # GPU 计算
    ( G% {, L  W: p- a9 u( T) y3 D& Pmodel_ft = model_ft.to(device)
    / d6 {2 N. i& u5 L1 B/ P: A
    # d# {) A  i5 u# I$ i" o4 A# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    8 g- J, U/ `; ~2 {8 L2 Ufilename = 'checkpoint.pth'
    / T; V/ c- a8 k- V; O* a. F2 q' p  ~
    # 是否训练所有层2 Q* u- ?) W& Y6 L& t+ h2 i
    params_to_update = model_ft.parameters()
    " b0 y8 ]7 h( L) _* ?7 z. @# 打印出需要训练的层6 ?) @) c" S! l. l* X- y! N
    print("Params to learn:")
    ' S  w% ^5 l# b( Q6 Rif feature_extract:- i2 n3 z( }% ]5 r7 N: w4 ?
        params_to_update = []. e' W( B- K/ V
        for name, param in model_ft.named_parameters():3 T$ v. l; T( b' Z+ @6 |! b9 Y0 o
            if param.requires_grad == True:
    ! r1 D; _2 U- }* @& c            params_to_update.append(param)
    * c' j3 K0 Z  q  i  E! z& b7 p! P7 E            print("\t", name)
    ) G* W4 `$ E' ~2 O" Ielse:: w& P7 _3 Q) H& X: `
        for name, param in model_ft.named_parameters():7 a: l* g. R6 d& K
            if param.requires_grad ==True:: j% V' e7 n& n3 t2 V2 n
                print("\t", name)
    & w6 q. @# W9 B( z
    4 P' t4 o! V% H1
    + J% ?2 E- o/ Q2( h0 q! \/ F) E0 X
    3
    $ {% R* x9 r3 l4! L) F0 E% \( D; p+ ^
    5! g. W: U4 L3 H5 d$ F2 R
    6. o( T& _5 [6 `
    70 @6 H& ^1 h4 v$ |, S
    8! O1 H& }' Z* S- D) v
    9( h2 R" E, R! H6 _6 I
    10
    / j0 H/ D" ]# E( |$ @8 v; O. K11- r/ d/ n+ k* K) q0 d$ h% `& |
    12
    8 {2 _3 ^6 P" Z0 |2 j13
    6 |( I; @2 ?# K' F  r& C* O) S3 p% u14
    5 L5 q. B8 G* Q6 A+ X15
    % p" V, F( i: E; H6 G( ?0 i16( F5 B& G+ x0 K/ Y
    17- o" m# |4 e" o
    187 |1 X0 h  _1 T
    19
    $ N6 }( f* E3 {( Q9 T205 S" Z# T; j9 m% J
    21
      o1 M( D" n4 T) |22% c, E) T* C$ L: A. _1 e2 ~
    23& @. T& F6 _  `; Z+ ^
    Params to learn:
    - ]% ^& |0 c. i  P% C6 V, O% u         fc.0.weight
    + V( N2 o7 w& k5 Y: P3 T         fc.0.bias5 T6 H% T& |# e
    1
      Q) ~3 W8 o  G2
    8 }. `+ Z# L, z  ?0 i% \$ u# B3: r- r# p! K/ y
    7. 训练与预测
    & m! I  X' z- _! l' l0 `; u' A* p7.1 优化器设置
    2 G, l6 W2 i$ w# 优化器设置
    . b% d+ w, l8 T2 Y& V9 B; }$ Eoptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    9 R: \- r* X% k0 v/ @+ _* Y# 学习率衰减策略
    * G2 X4 X' T, L' Jscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)4 }9 l- r# ]; k: x) P3 s- r
    # 学习率每7个epoch衰减为原来的1/10
    $ Y+ h  ]: g6 y6 R! F# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算* S" b% w% F( K7 u

    8 c9 }$ B) `+ g1 P: Ccriterion = nn.NLLLoss()2 b8 H3 i# F  @/ j% s" H/ U3 {
    1+ z9 m$ D" k' F+ x* d+ Z
    2
    8 z; S. ]+ a" |1 x3+ T  d; E2 o7 M8 a' n0 M
    4
    0 ~4 ~8 U/ F. ~- w" o5
    9 \  ]& \7 ]4 P6
    # l  U! j7 d. a  U) i7  I& U: s3 D, H/ J  ~0 D" u; r2 E" S
    8
    # V$ e/ l# ]. V9 a0 l1 f# i# 定义训练函数
    5 ]( `2 q4 |* ]! A#is_inception:要不要用其他的网络
    0 i4 z6 h1 B  G+ V8 ndef train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):# a5 J9 A* Y  j& c: d3 L+ l& A
        since = time.time()
    ! i8 U  ]2 s' J0 p/ B- j    #保存最好的准确率
    4 P, U/ i7 X: P8 C. s    best_acc = 0
    1 F+ I2 i" h. `4 v    """
    : q- `% Q$ J3 K2 m9 C    checkpoint = torch.load(filename)
    7 m0 O, f) s) k3 \) `+ O1 f    best_acc = checkpoint['best_acc']2 z+ A# b( |( }/ `9 Z5 H  j* L
        model.load_state_dict(checkpoint['state_dict'])0 X, a2 s0 S. N* I
        optimizer.load_state_dict(checkpoint['optimizer'])8 s0 {+ @4 c: K* [" K: b* A
        model.class_to_idx = checkpoint['mapping']
    : o0 J4 D2 L3 Q. v. ^" O* x- F    """
    ; E: B, G+ b2 o6 o1 @3 P4 f    #指定用GPU还是CPU2 {( v" O' Q9 {
        model.to(device)4 l9 N) b; W" X& V& |
        #下面是为展示做的( ^) t/ O3 P2 n6 v: J
        val_acc_history = []5 J, o& [1 u, y& L; z/ w9 r
        train_acc_history = []. o# r' e  Y! k
        train_losses = []5 B3 ^' H+ z4 i& n, i
        valid_losses = []
    . `; n1 p7 S. A& u6 y+ u    LRs = [optimizer.param_groups[0]['lr']]* G8 d% a* n0 G4 C0 M
        #最好的一次存下来
    7 {) |: C$ z7 j8 U    best_model_wts = copy.deepcopy(model.state_dict())
    " P" U6 }0 d. L) Y6 Z3 d. [# U3 J! n7 @8 p
        for epoch in range(num_epochs):
    ! i! K/ h" x1 `& m2 ~) u5 {% Y5 Q        print('Epoch {}/{}'.format(epoch, num_epochs - 1)). L% |5 S! W1 Y" ~9 I
            print('-' * 10)
    9 I! I( f  y1 ?+ Q0 ~) }  P6 ~! F  H9 n. \7 M$ ^/ b
            # 训练和验证, O0 T4 k- ]8 Y" ~3 {
            for phase in ['train', 'valid']:, G7 u8 R9 n  U
                if phase == 'train':
    . ?* {$ Y' ~9 e                model.train()  # 训练# O: y* a  S1 }' B
                else:  N& {0 v6 m+ b9 R4 m# z2 {
                    model.eval()   # 验证3 j- l! i5 H/ m! h; k" ]2 I+ ]
    5 H! ?( [  ~( _4 ]* X
                running_loss = 0.0
    2 d5 P& }, ]' j) m- n7 d            running_corrects = 0* H* M* c+ T, z  b' \# r+ h; W
    # d, u! z6 @7 z& K: }: J
                # 把数据都取个遍
    9 B$ E0 C4 J5 t) ^( r            for inputs, labels in dataloaders[phase]:  X8 q3 A+ x0 i) x/ z
                    #下面是将inputs,labels传到GPU
    ) N& X4 Q: {& f' D                inputs = inputs.to(device)
    ' [# j2 x, k1 f1 q                labels = labels.to(device)% c$ E) i9 C7 {
    # A& s+ p- a% M( g* t1 n
                    # 清零! D% a+ h3 z- U: |
                    optimizer.zero_grad()
    : @$ f6 z' j. U8 \0 o2 k+ i5 z+ n                # 只有训练的时候计算和更新梯度! G" l" l7 Z3 c: a2 D
                    with torch.set_grad_enabled(phase == 'train'):5 x- ]5 O* L! A6 ^/ c
                        #if这面不需要计算,可忽略3 v" S; l8 B6 o
                        if is_inception and phase == 'train':- y7 f; l+ p9 g5 D0 H
                            outputs, aux_outputs = model(inputs)
    9 Y8 j, i' i" n' ?. s% A) i                        loss1 = criterion(outputs, labels)+ T0 |" a, B$ q, Q/ L# {
                            loss2 = criterion(aux_outputs, labels)% y, l& w4 h5 P7 \
                            loss = loss1 + 0.4*loss24 B0 A3 h  L! w4 @# C# m( u
                        else:#resnet执行的是这里7 R3 u: u2 d! p8 [
                            outputs = model(inputs)
    1 L0 N3 w! F( p( v" E* Z8 V$ H4 ^                        loss = criterion(outputs, labels)& V/ U' _* U% v/ I/ \+ D
    5 p/ J8 ], v) ?
                            #概率最大的返回preds
    + g/ B+ P* n% b* m" v" U% l                    _, preds = torch.max(outputs, 1)1 V0 U' {+ a& R. U# h+ d, l$ s

    5 E5 E% V3 _- y: v' d                    # 训练阶段更新权重  o4 L; N# x1 v
                        if phase == 'train':
    # t& H4 L: R( t: `' _4 N( _. ?                        loss.backward()
    ) {2 W' @4 q( b4 h6 O2 B9 i                        optimizer.step()# q. H/ w4 P* l" Q9 C" ~- A

    7 ?( ~$ |% [! m) c* e4 v                # 计算损失
    + C" ^$ f, r% C                running_loss += loss.item() * inputs.size(0)
    ! g* M* R! p- v! _% z3 L8 x                running_corrects += torch.sum(preds == labels.data)! k# [! f- S6 V, |
    * p; v1 k' y8 W3 y2 @" `# |$ j
                #打印操作
    / l* c' b) d- O% r& @0 v, [# ^5 [% K            epoch_loss = running_loss / len(dataloaders[phase].dataset)5 p6 @. e( F2 U: ?) d+ o5 {' F; u
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset), J' s  V( K; J9 J

    . [  p/ y9 i" ^% @  U# D( u; Z# R2 C9 {% U
                time_elapsed = time.time() - since
    : s5 c  _( o* D; P. ]% P            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    , P# |/ Y8 c1 D            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))" e5 V: x( }! Z" W5 d- R
    - m/ N2 x9 h7 M1 g# `8 E. i

    + Y4 x4 X+ k" m5 `            # 得到最好那次的模型$ l, d6 r4 s, l
                if phase == 'valid' and epoch_acc > best_acc:, A! P  Q6 e) f/ ]$ C7 m! V: T
                    best_acc = epoch_acc  l2 w% Y! e- X
                    #模型保存
    2 c, _9 b& F4 H# S. W                best_model_wts = copy.deepcopy(model.state_dict())1 N, H% _2 d# h* J" n( }
                    state = {) H0 a' |2 K: O) ~* {7 F+ o: J
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    ; g0 f  u+ p- s                  'state_dict': model.state_dict(),
    ; d) @" U1 P6 i7 l7 T) H                  'best_acc': best_acc,
    ) A- M* ~4 I* \. h7 A                  'optimizer' : optimizer.state_dict(),, a8 G, N% c4 k, G0 D
                    }  c: R- L" M2 X2 X  `/ S
                    torch.save(state, filename); p! ?6 D! C% p0 u# q' [
                if phase == 'valid':5 D: J. o7 {/ H6 G: z- m
                    val_acc_history.append(epoch_acc)* q# H5 f& d* C* F) x
                    valid_losses.append(epoch_loss)1 K: \* z% ~* Q
                    scheduler.step(epoch_loss)
    & ^  A; B& j7 ]1 R* w% h* @% G            if phase == 'train':
    - l2 Q8 y6 Z) C6 ~                train_acc_history.append(epoch_acc)
    0 I/ B4 r0 Q2 a+ ^: z                train_losses.append(epoch_loss)4 T3 \( Y( ^, x
    " \) P2 Z& r* o$ i1 N/ W% X- A
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    / p( M. _; `# w$ A; k        LRs.append(optimizer.param_groups[0]['lr'])$ `/ J0 u3 ^; D5 J2 z
            print()
    ; @7 C6 q1 k; d7 A0 X0 r- b2 ]) c2 N1 l  p* S/ L) j, f
        time_elapsed = time.time() - since+ o9 c2 n7 T& n# c
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    & X0 a; I: H1 L& V6 K  k) J    print('Best val Acc: {:4f}'.format(best_acc))
    5 F" L" ~# ~6 w) ]. N% O9 ^! Q
    , b4 u1 n/ ?0 R! O* @5 d6 z    # 保存训练完后用最好的一次当做模型最终的结果
    % r& ]9 S: y; B8 h) H    model.load_state_dict(best_model_wts)- s  [8 t/ m+ y1 W$ l
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs ; m8 g# K8 O9 g+ O
    ( O' m; l- r+ E4 x3 g6 s
    : _$ d8 L# f* Z8 G! [7 q$ x! d
    14 p0 R0 ?8 k, H$ q" r( m# N1 [
    2& p) _) I/ h: v3 u5 J+ o  L& V" y% [
    3
    8 a' `9 \9 T/ X$ D# S4
    9 }( `1 ]9 ?6 O4 r58 e1 p' C% M# b( p' }
    6; V* ?! q' h$ {6 x4 T9 j* ~4 P' |" ^
    7
    , g& c, w3 I  M" a82 R! S9 c: f& Q) O* O8 T
    95 C$ ^2 S* f8 Q' [
    10
    # T) o, c/ `" g% \4 Y; I0 E+ N11- ?" _* @1 ~/ B2 c& Z
    12
    3 g( H) v# ?% Q* Z5 [& C, k13
    5 V! u4 W( J9 ~0 D141 Z; q8 l, A1 f  q) K
    15
    7 K, ?+ r  \8 E4 m+ A  z1 _16+ J, P4 M; X3 y! A
    17
    4 x/ a3 E5 b9 ~* }5 D: Z1 v18
    2 G- v, W) u% ?$ Z19
    9 _9 k- t) |' G20
    : X/ E) G; q& r2 X: i& a5 m/ p21$ G7 @% w: P: C' E. Q& u  C
    22
    : g2 Y1 X! V/ N1 }1 b23
    ' F$ Y- x4 L6 l. ]# @8 `. U24, L/ J4 L5 l& U# e4 m7 Y
    25
    ; g( a7 x3 ]$ B6 a26
    , ^2 g' R9 H$ y% b& T27. I: c, o1 {; Z2 A, c
    28
    ; _8 k2 R0 o) ?; t& {2 t/ Y29  F) p8 H2 F( a5 E2 _
    30& I! `+ A8 \* U1 s* [
    31
    ( ?. a0 h. X: E( z  u- i: R32
    " A0 v/ d# j! n! S) E* C33
    : C7 V, _. ?; w9 U34) X) v4 m( p' o; T. [9 L
    35+ L. J5 n# O7 O, M& A) Q' y
    36+ h, z( A$ T) b  t$ ~' g7 S5 |5 J
    37( i5 T- L, h6 ~/ a
    38
    8 Q! v2 r+ d& r1 {& G1 L391 K! u' z+ n6 c  \3 v
    40
    ; c' E  D" P9 o2 _41
    % ?8 ?; ?& j3 T4 d6 @# G: G# o42
    4 C3 }: r: B: G6 A; }43
    * j, v6 R4 ?( @- `44
    ( @9 O1 {; h" a2 f* w, Z7 ]; |45
    ( L; c% ?! J5 W# G( N6 N7 u! k46! Y; M( }5 a' a/ R
    47
    ' W$ ]! U# ?7 x/ W: C; w48
    . Q1 T$ g" R5 }% F4 _/ b49
    / @5 p* G3 [3 ]* y: B2 h50
    0 g+ K, H) r* Z# O0 W% C3 G; b' @51- R  z- y7 \/ r7 Z6 G# J& \- V
    528 P0 B( A; z( o9 t) W
    53
    # ^( F% W: w/ @$ C54
    : W0 }1 G: X! j: g7 L+ E6 d" y( L" w55
    8 [3 J7 [% \3 N/ A" F568 A6 A7 r: _! w$ q3 n
    572 A; P9 W5 I7 Z  V
    58
    ( `: _( b5 ]- X0 A& |2 c2 @59
    * a1 R8 a, p/ q' K0 o7 I+ V( b60
    ; Z, J$ o  \. I1 R. s/ a, ~61) |/ z9 ~! R4 A0 F- s) P& R
    625 Z6 P* l# A% o  ~2 A" X
    637 m2 w! M4 l0 \( c0 y& W
    64
    / `# p8 N0 P9 D* v+ d3 q65
    ' |$ s1 u" V' \: b9 k3 Z66& J# R1 H5 {# v
    676 X3 w- U2 \2 |: Y" h* K4 p/ {
    68
    0 E$ z$ ]2 I+ f$ k" A69
    3 H8 l. [3 N, `0 ^0 S, N70: R* B5 M7 q5 j
    71
    " J6 v% h% O( _& `7 [72' g$ X: n' s! w- c
    73
    0 p! o0 G/ e" [1 s+ b- j% G74  o  `0 J' ^- n# q
    75
    + a$ u2 T2 o9 _/ ?7 a% c' \( Y76. `# [3 v& k5 l4 o! |, w
    77
    . m7 o$ G- i) M& h8 P# P2 f, O78
    , D9 _9 b5 w( D0 D# X% M79+ e# r& O; a4 l1 o
    80% o" J! g1 {9 w  w7 H0 b
    818 i. |2 J/ v% M& {* L' G# I8 h2 g
    82
    * `# r& {  p0 \9 [' A+ V83! y6 J% J. |0 L) d
    84
    ' ]8 v: H' [5 l+ c2 O) _$ x" A# P85
    3 z& {" l3 S4 B: j& x- p86' `- k7 E4 C, g1 I5 U& ]
    872 H1 I1 {) P& R8 P) _9 t0 |1 q4 s
    880 }* _, M4 ~! R7 r4 j4 u5 F
    899 o7 b* f2 w3 D* X7 N
    90
    4 [: E/ @8 G. ], L! C, M91. ^4 a  w3 D* B7 D7 o& q
    928 ~6 J: s3 }* n- i- }) c# s! C2 B
    93* {+ N1 i$ z0 i+ y/ ~' H
    94
    & L: g% y+ e5 X2 W! G+ U/ g95" B0 N% U8 M( M' P  c
    96
    - Q; v, W7 H! j2 |  g6 C! O; v6 ?976 b6 ^* i- t- e' O( T' g, w+ d& l
    980 j0 d# e' {0 p2 s% A  s+ a$ D
    999 G6 E# @7 F' h+ X& L9 r5 C+ R
    100
    # E& R/ g+ p  J. d101+ j. }! u& l4 j% h. F; @: V
    102) b. S) v: I6 k* W0 X. C7 L
    103
    $ e: j, Y" D9 r" W8 M8 @104. a% C7 |1 f. P$ A9 n; w
    105
    % S4 W4 w! H  M- A( {8 k1062 I5 c: J& M* N) }; d9 z6 c
    107
    8 a$ w- X7 C+ G1 p' b& Q8 C108
    % a/ i- A/ r1 X  o7 U1090 S7 j  C# q  q/ n
    110
    8 X* s! `( s' p& `  ^# O/ M111
    . ]% T2 u: m9 R1126 Q/ O$ h# o. k
    7.2 开始训练模型3 [) }) @1 `0 n  C+ X
    我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    & h# q+ t. {, e  f, ~+ ~
    ' F9 h. y; {) o1 {: r, C  @#若太慢,把epoch调低,迭代50次可能好些
    4 T4 C' G, C! c$ [#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
      Z, i- R8 d9 Z1 Rmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception")). h! b& Q0 [8 T( H

    8 u% l; V* ~' X% f1
    ) y& [' p( y9 l) V. r2, U' A% n: u- X  D$ {
    3
    6 I/ J7 D. S9 X5 Q4 _& U! l' h4, B" D0 p& ?8 s; ?
    Epoch 0/4* g6 ~, z! b1 S1 v
    ----------
    2 b! M; E* u' q" RTime elapsed 29m 41s1 u4 ]; H( R- x/ V1 n' l. i) e
    train Loss: 10.4774 Acc: 0.3147  @: w  Z- d1 g  M% c
    Time elapsed 32m 54s+ B6 y, K6 t/ [. h- Y7 ]
    valid Loss: 8.2902 Acc: 0.4719
    4 n  w9 v3 v: J- ]! yOptimizer learning rate : 0.0010000
    # B, z1 w; U$ ~; }% I& _! h# u% T; h8 p
    Epoch 1/4* W+ W  z, `6 L/ b# q% }
    ----------
    3 T& a5 ?! I" B" mTime elapsed 60m 11s
    / @' ^( u& d! u3 z8 f2 itrain Loss: 2.3126 Acc: 0.70536 r8 t+ P4 A8 Y8 y* L
    Time elapsed 63m 16s
    & r9 k6 @0 d5 y) lvalid Loss: 3.2325 Acc: 0.6626
    9 }! @4 ]9 A5 U7 ZOptimizer learning rate : 0.0100000; W0 e$ V- t# f9 e3 c/ {

    4 B0 |) M- K/ P/ W8 p" vEpoch 2/40 G  p& i; d5 s
    ----------! a" a" Z8 X! ^, e5 u/ g; T3 H" |- O
    Time elapsed 90m 58s
    & u3 }% X9 j* Vtrain Loss: 9.9720 Acc: 0.4734
    & S9 y' {% e3 XTime elapsed 94m 4s5 H/ O) p6 w, N% s* J
    valid Loss: 14.0426 Acc: 0.4413
    3 w: K% M* D& v  \; K% Y* f" eOptimizer learning rate : 0.0001000
    4 D: M% j8 V- x6 L' x9 E/ g% u( g. T4 r- r' I8 `9 [+ `5 R/ J! R
    Epoch 3/42 L8 h& A9 C& U. o/ W
    ----------( M) o: `/ O. v
    Time elapsed 132m 49s
    # ~5 R6 \9 O8 M/ J& _& t  Ctrain Loss: 5.4290 Acc: 0.6548: ?/ U/ d5 K. F8 I+ v$ q: m
    Time elapsed 138m 49s
    # l2 K4 y& i% w! T7 Yvalid Loss: 6.4208 Acc: 0.6027
    - w. e3 o3 A2 V' U, R6 w2 N# dOptimizer learning rate : 0.01000009 {9 A/ V  c  e4 P
    9 H8 B/ C: K1 w  a7 k- s: ?/ B
    Epoch 4/4
    ( M) O* ~* S5 V1 F9 ], h) ^----------
    3 g1 Z0 _# A) b! ~Time elapsed 195m 56s" X- t2 D/ ]( [; n- L0 F6 P% F2 |
    train Loss: 8.8911 Acc: 0.5519
    " w" x' h; w+ V9 CTime elapsed 199m 16s/ q& ?1 u. X, S- E
    valid Loss: 13.2221 Acc: 0.4914: Q) B( I0 t, K* ~
    Optimizer learning rate : 0.0010000
    4 F( u; ], H0 \, _
    + g$ c$ ^4 e- WTraining complete in 199m 16s# {# p1 S+ b" t7 O
    Best val Acc: 0.6625924 h8 A4 v& X5 |4 O1 u1 U6 T+ G
    2 Z; P' g3 v' _* l8 H7 y+ P0 t
    1
    # m, _# i1 G3 h7 e* V1 V; Y# z8 o2
    ! V8 }4 T: y8 R2 e3
    8 ?3 j1 k' h  N4& a- }7 }  Y: s4 b2 E: l7 i1 H( F
    5
    0 _5 R3 O. U* D! r6
    ( W2 C. E8 z0 \( @7
    $ n' f3 d: E% O( j. O" [83 H$ ^1 ?# s  Z) Q: Y) {" n
    9
    , _; B/ q4 z% g7 u0 h! u  x0 F10
    + K6 V& n1 O5 _1 S11
    : Q/ C+ ?' M* q$ d0 X$ U0 `0 t12
    " Q4 R, t' @/ _' a  C13
    & b# n  Y. x+ ~14
    # c5 c4 Y3 w8 a( Y' ]15
    - F7 a+ w# T- e% V" h16* J# d6 p& Z; s3 o9 B3 ~
    17! c7 A% Y' P( m  N! }/ o
    18" G7 g# k( K( D9 h% a" r
    19+ L, k6 A9 K: L( I. U! X' t5 L# U
    20
    $ A  y4 K. v9 V, r0 V" r218 w! |8 ~, s( V& b$ t
    22- M0 P) _% C  ?7 o6 w" M  D1 Z
    238 a$ j3 I' o$ u4 T& j) U& W
    24
    . y+ P6 f# q/ R  q256 C5 _  x6 i& \4 S4 Y* v5 f
    26" R, f' |; F3 c3 u
    27) m0 s8 {+ o2 ]3 ~$ U
    28
    , v' h! \  m6 k% {292 _+ o6 \, h: N4 G! I/ X/ @
    30- \  W5 F( O- i& `, h* M/ y. @
    31
    8 W: ], ?; J3 ]& B2 t/ y32! U8 f& }! L3 Y8 C$ b4 i
    33
    / q- D' b) r0 p  e1 G: `344 g0 H/ p- Z' c' H5 l* ], m
    35
    3 l5 l7 M% I8 X6 c# D  n36
    9 k6 W* H; d  ^4 d1 S37; C  }$ I! |, K( u  B2 b2 F( ?
    38
    ) O. U: _( R9 O4 ]- b! n39, a$ G; K# K( N
    40
    ( z) b3 q' V5 T4 O6 u/ A41  n, e4 j6 g! T
    42
    " w5 w, U& r& y; _; U" Z' r7.3 训练所有层3 y% f( g% B* x
    # 将全部网络解锁进行训练  L' g- m: m! H+ Z  X
    for param in model_ft.parameters():) d* L( p) G. |! m
        param.requires_grad = True8 m' f  @1 Q$ T
    # M* ~0 ~: J' S5 z0 Q
    # 再继续训练所有的参数,学习率调小一点\
    3 a: L$ a1 K: s% j5 n9 ~optimizer = optim.Adam(params_to_update, lr = 1e-4)
    ( q; \& u3 ]% B  ~8 Cscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    1 ?' `6 O( q2 I+ w3 k1 K
    4 i# z- l- r$ |" b7 x# 损失函数
      f  J6 n! Q& L+ ~. w6 [criterion = nn.NLLLoss()
    ' g6 r/ S' _! f9 Y, H5 f6 I; f, W14 K$ E$ @- ^4 w8 Q/ S
    2
    ; N% c% g  s% }4 Z3
    2 f  C8 J) I; n# o0 u9 z& L/ ~0 z- c4
    & f0 g. q4 e% Q+ B/ Y) Q5- @, W) P2 ?3 y, r5 W
    6+ b8 [# o5 g2 T9 u
    7# E! l: F1 x+ }2 i- a1 ?: ?
    87 |- b. ]' `! B, a7 d, w: `
    9
    # W$ M- r4 K  c, C$ @& F  q( b10$ O' n7 m' ]$ o. D7 n9 v* b
    # 加载保存的参数; C. Z" p& _* D- G$ Q! z6 t
    # 并在原有的模型基础上继续训练9 V& T& I" B+ `7 @- A, D$ v" t
    # 下面保存的是刚刚训练效果较好的路径
    # t) x1 l: U- |: Jcheckpoint = torch.load(filename)
    ; p0 k: ?3 L) o; u# z' f5 w# Ybest_acc = checkpoint['best_acc']
    . p9 P5 U4 _2 u7 z" f) i. r5 D7 ~model_ft.load_state_dict(checkpoint['state_dict'])7 @* H& r# e" J% P& a
    optimizer.load_state_dict(checkpoint['optimizer'])! C# Y# J# G. x/ Y  L+ q
    1
    , j2 x5 b% M& q2
    : p0 D: ~- e6 K. |( k3/ [$ z7 Z8 C. i& Q7 ?% [
    4
    6 z! D2 T6 o2 i% ]58 x, R6 P4 @7 r" e
    6$ x7 j; M/ K5 z# g5 e
    7  d8 ?$ p$ S% @2 ~3 Z) X
    开始训练7 e% c9 T$ s( V4 G& z( Y, N; ^
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考  d: r# d& e; ~) Y7 v& N& A( t

    ; }5 V, D$ ~9 P% K2 nmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    / h3 r% E& m8 {1
    . s# W( H( E4 oEpoch 0/1
    ( J& V# g7 q4 A----------0 }9 u: c. u& m" a! I8 @' W
    Time elapsed 35m 22s
    & y! c! V4 V( r. \train Loss: 1.7636 Acc: 0.7346* Y2 m2 {% c8 q+ {
    Time elapsed 38m 42s2 m1 [2 Z% e3 t& N8 Y/ M; ], Y% s
    valid Loss: 3.6377 Acc: 0.64550 b/ w( U7 G7 x4 I3 i+ H
    Optimizer learning rate : 0.00100004 _* \: t- N; g7 p/ u1 l3 C2 t0 \

    . L' s! R: H1 |1 n6 {- EEpoch 1/1
    9 n( E( i5 L. a----------
    * z6 m! _. u9 H1 |Time elapsed 82m 59s
    " d! S0 K  {# i2 u* A1 i+ l$ ?train Loss: 1.7543 Acc: 0.73404 g# ]  M+ y- b% o7 s* B, \
    Time elapsed 86m 11s) V! Z! f7 Y$ J2 d2 {+ M+ `
    valid Loss: 3.8275 Acc: 0.6137% f% e8 \4 j  W
    Optimizer learning rate : 0.0010000
    9 G: x$ i+ v* K$ p; f+ `
    + k; b# i$ Y  }; \7 P) hTraining complete in 86m 11s' m% J: U: S. q- c0 t
    Best val Acc: 0.645477  O" ~5 m1 L( F4 I. d7 C4 w1 X/ c
    " }% `! W$ p! W( q' Y: s; _
    19 h. A3 ]' M$ D- c+ g0 B
    21 D$ Y  A, g7 x
    3
    , O6 k9 a# m" l1 j9 Q4
    $ q9 {( Y1 M8 l3 @( W& z0 o8 {51 t0 _( W  V0 s6 j" r' V
    6, Q' y& [: F! C+ q- B8 I% J6 t
    7
    $ `' s; R" k. X0 k/ [8
    ; z  g( v; R5 W9
    : m$ @6 b$ n' c. z) u; r106 b% o- V4 W( J
    116 l+ A- A9 G$ C: j
    12
    ( @6 Q0 K! Q2 I, ~2 k6 U) h131 w+ D' t$ a5 i  Q2 `: ?) m
    148 J6 M4 Q8 u" {: I8 i& t% i
    15
    ) Q# o. C. f8 B! w; u, j7 ?16; b( M+ B8 j! N
    17/ m: H9 k/ \+ R( c
    18
    ' o# I! }* H( a- f* J; g  P4 ?+ w1 x: p8. 加载已经训练的模型6 i/ \& J0 j  q- Q8 k
    相当于做一次简单的前向传播(逻辑推理),不用更新参数
    6 `1 @: P* L' B- s0 d, Q& c+ G; b+ y0 A
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)7 v0 w9 G) L4 D  D( N: I3 y7 _; ?

    # r9 r) W3 U* t8 f4 m- y8 q& `# GPU 模式9 v) T1 Y& a  o4 t0 L* @
    model_ft = model_ft.to(device) # 扔到GPU中
    0 Q0 P" ^7 S  m+ [5 b2 r0 P9 {4 X7 X5 U7 V% ?' _6 T6 o
    # 保存文件的名字
    7 }. J' f* B/ v: b! W/ f) F8 U; Qfilename='checkpoint.pth'# Q1 F3 q! M% H/ Y' Y2 p# W0 L

      \: ^. J5 W5 F7 f- O! x# 加载模型- v, E' G3 `* G0 Y4 C
    checkpoint = torch.load(filename)
    ! O8 R# x9 N' v% V, U, I- ~best_acc = checkpoint['best_acc']+ h/ k1 V3 c+ C9 J; }
    model_ft.load_state_dict(checkpoint['state_dict'])3 D! D* z  ?2 B
    13 O( c( M; [1 q- _, K& u
    2
    2 g6 J' x* c  V( [* p( e3
    # K6 z& t% b3 |- u. V6 y0 d4$ m8 T% |3 L2 p( I& f
    5; {; H9 ~) _' j" V9 J3 ^
    6
    ' Y& {/ Q4 O7 e$ @& q, w( Y7
    ; B7 r; K$ o6 u8
    % ^& f9 a( ~' f% I9 f0 [9
    " k5 x! X5 G3 H! H, q  D10  l. y4 y: ~  ~, f9 X, w% H
    111 t: P# u4 Q% ^5 Q9 d
    12
    - o8 f) c8 L  A% D<All keys matched successfully>9 W/ h" N6 d$ u8 b
    1$ q! x) d/ O2 a. I2 W
    def process_image(image_path):
    " I  J7 T+ C$ X7 M2 ]    # 读取测试集数据; z2 {' [: D* k0 @8 I
        img = Image.open(image_path)
    6 p# d, A' D5 Q    # Resize, thumbnail方法只能进行比例缩小,所以进行判断* R* g* m9 x' e( O4 F
        # 与Resize不同
    ! ]$ R# v) l; y$ o( b  C& `    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    5 k$ E' v* K3 ]  D3 ^    # 而且对象调用方法会直接改变其大小,返回None
    9 @4 Z5 Y/ ~" o3 ?9 B, u! ~/ z8 |8 s    if img.size[0] > img.size[1]:$ o* u) j, `& E; Y" s) T# m
            img.thumbnail((10000, 256))
    5 W4 ~8 m: Z& l. t    else:
    9 B3 a3 ~: v4 ~+ T) P        img.thumbnail((256, 10000))& b( N% _) |+ W# `: V! s# ]
    $ x1 J5 n& d7 E' J5 J
        # crop操作, 将图像再次裁剪为 224 * 224; W0 H# k. u: \& p7 b; Z
        left_margin = (img.width - 224) / 2 # 取中间的部分
    5 ^  b( e1 ~) c    bottom_margin = (img.height - 224) / 2
    + ~/ E3 ]6 J" X( n: v4 Q( G    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度" F; b* r. c7 f# D$ o
        top_margin = bottom_margin + 224
    ' V! j! O, D# {% h; f8 S& C5 p' g8 d+ f# Z. w2 M
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))* a. E- \- t# v6 d5 D# R+ L
    & y# L& S- ^6 \: i2 `; f( o$ f
        # 相同预处理的方法+ R- O! X2 y6 U! ^) n' q6 o
        # 归一化" {2 Q8 Z: X+ v3 x& y
        img = np.array(img) / 255
    + p/ |( V: E! I+ i2 h6 v% h    mean = np.array([0.485, 0.456, 0.406])* B' z( r( S1 q) L' A
        std = np.array([0.229, 0.224, 0.225])
    + k2 f+ C$ }  }4 e9 \" E    img = (img - mean) / std
    5 H1 P) D( [( Y9 `/ s) g9 o1 y( X6 G" m+ ^  H" z' h' o/ F4 }0 s* p% R
        # 注意颜色通道和位置
    0 g% _' u4 e. N# W    img = img.transpose((2, 0, 1))
    6 X( I. J7 t& ]. j
    - d) M& Z6 g9 R+ q3 q7 f    return img
    4 h( j. Q3 ^9 z8 v7 ?4 x% v/ _" B$ C% C" Q8 O) z
    def imshow(image, ax = None, title = None):, M1 w( ^' v/ S* O0 y3 X
        """展示数据"""
      p+ h: q+ x2 {9 a2 j    if ax is None:
      B$ }4 z; Z1 M! a        fig, ax = plt.subplots(): l4 L; X+ v2 y9 T
    0 y# A" D, @+ C
        # 颜色通道进行还原
    : T9 f2 G; u- Y    image = np.array(image).transpose((1, 2, 0))" r! A+ z' p; H; p" `6 ?
    6 [( l. ]1 M+ \+ Z+ ^
        # 预处理还原7 C% v+ Q% t2 P$ N
        mean = np.array([0.485, 0.456, 0.406])
    , N9 Q) B% [/ u1 p    std = np.array([0.229, 0.224, 0.225])
    4 f) v* p  U" d% E, @' Z    image = std * image + mean
    ; Q) z1 B1 X- z8 j: e    image = np.clip(image, 0, 1)
    0 Y* c3 A5 d/ d, g. u/ W/ M5 a
    & z$ B, k/ w! q  g4 @. ^9 h    ax.imshow(image)
    1 n5 Y$ Q. ]; Q    ax.set_title(title)4 F! X) Q& p( K3 D& |  Q
    + G0 f# ]$ }& a( c
        return ax9 ]* r  S, S2 F; ?. i# m- z

    : ^% f9 ]- B$ pimage_path = r'./flower_data/valid/3/image_06621.jpg'
    0 Y% B* u) ?& t/ Y) q: |img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    # e  _$ R, F" U+ G# V6 H" K7 [imshow(img)
    * c# M( |7 j( j: Z7 L2 z: A2 V  [$ X- X
    1, h" k; ~! K% K) C: _7 R5 o
    2/ e8 t0 _* l/ |. o2 n6 X
    3
    * I: ?0 K% H) L8 Y5 {& r7 Y4
    2 Q8 K; l: p  \+ S5% H7 O# \+ k8 @- I% T/ b" P
    6
    ) Z/ V# g) i  o0 V4 t) ^0 o, \2 X7. j. H! k0 ^4 `( P$ Y3 _4 h1 _5 j
    8" O! i3 F& L( U
    9
    & Z# e% t0 }. C10
    ! W6 V/ W0 R* |, G! P' G5 B11
    * H. o; `3 p5 |8 e% x( ~12. h7 K0 f* ]4 `& d0 t( V' q
    13+ @% a% ^9 d. r" x( O. Z0 |) r
    14% V* z, L: W! ]' D0 t* g
    15
    3 d& D  Q+ I) T/ {! I7 C% r. c16% x/ h3 {5 W& x* S- c) r3 U
    17
    . B+ ~+ A+ J7 A: q18
    " q% d- s7 L' M9 G( I3 w19( `9 z  @2 x% e2 r7 p3 t# `+ d. o
    20
    2 p# A. h* k" K3 o' S& }  I21! C" N; U/ j' L" i
    22
    , B0 A8 S8 r4 F3 J232 c3 h  ?: `1 l! l4 y' x
    249 P9 J7 v5 K+ n; L! G
    25
    & ]' N  C: z9 h8 Z0 ^26$ s' u1 n! g8 u; d; q8 [
    274 ~4 R3 f( g. k9 e
    28
    " m) N8 C2 E# D% J+ m29
      e; A0 O/ `; |; n1 ~30: o+ Z5 M  a( d; P' G
    31
    : n, u4 L+ Q4 \32
    4 `0 W2 u8 [; D33# W. k5 t. S$ n) d9 _
    34
    / ^. q  G# t9 j8 g0 v) f35
    # K6 }3 N- J8 R% _$ F2 {36  X) @7 U9 _8 z! z6 a
    37
    ! l! J' U9 d/ l7 ~* F/ s% }8 o38
    ! k8 p2 p, w; J7 V7 u39
    % a1 v9 l/ P! a; L7 k- u% T; Y: f40, d! K! f9 N% f
    41% Y% O0 a; K) q+ E9 b8 Q% w/ Q
    42
    . d: @$ f8 A2 B' ~' a43' d1 s! Q' \4 t! q  G. p
    44
    # V2 R' ~, Y* ?1 y$ x" [454 |0 o  P4 g& u, \  C# _+ |
    462 ]9 z$ F; H) S8 `' o- K
    47
    , Y. k+ \, Z' d2 |7 z* {* I48
    3 `! a' N- h" J# {( e( g+ E49
    4 f& S6 ^( Q# f: v) A# p50
    5 I: O% T- W* ?) q" \3 M: R51
    ( a1 m7 O8 e. E) {- m) @52
    9 j) D9 w' b* ?6 T5 p53# N! i4 r4 Q5 }, M( A5 l9 M: \! y
    54) G1 U, ~( W+ j( b7 x" p  s
    <AxesSubplot:>3 J4 Q, f, L9 U6 x/ S
    1. s. y" C( H  t4 [

    , m1 Z3 k2 d3 Z- T2 d上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
    $ T9 l7 S' _& B! I$ |
    # a- W; J" q3 n7 n  ~6 Yimg.shape
    , _; e$ ~; {  H2 w; y* z15 H: w/ C$ Y9 k/ M$ d
    (3, 224, 224)( z. W6 g4 g# m  R& C# ~
    1, B# }- @1 M6 Z+ Q0 N" U( L
    证明了通道提前了,而且大小没改变
    6 v& w9 L/ z$ ~  X5 Z  _7 Y0 }' C" r
    9. 推理
    ; l0 |" ~' v1 I6 x* l! jimg.shape
    2 [0 X& w7 q" q7 M0 k% @( C: |0 x, U0 s  t/ I" U. {4 C) Q. H
    # 得到一个batch的测试数据
    * {* t5 |5 @$ Q7 d3 jdataiter = iter(dataloaders['valid'])
    * q- s" s+ b! f. |& I% H# `4 R4 ~images, labels = dataiter.next()+ @& |  Q% h5 b: M  k- i2 y) h
    , D" r. U3 G1 D
    model_ft.eval(): D' H; E5 v5 m- y+ E
    * `' L( x& o8 W: i4 f9 h
    if train_on_gpu:
    ; D- }$ O; a) @6 N' X! Y    # 前向传播跑一次会得到output
    ; C6 |' `, H+ k% v5 C1 e- d- P  r: `    output = model_ft(images.cuda())- y$ {$ M& G  X/ w/ s. ]9 s
    else:- o( k2 P& y" f1 V6 j
        output = model_ft(images)9 D7 C2 g4 \$ `( @

    5 e: d$ x& v) o$ \# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
    ! Z% _- l+ k5 Z( K6 T5 \output.shape
    % i  K+ v" q% @5 X+ O$ U  w, ]1 S. m1 m" v
    1
    # H2 e$ ]  J5 F: {: [- x0 }2
    - F/ Y  r/ O4 V8 Y) E/ y3
    5 T/ \: m% i( l- B/ [. f4& Y: @- ~. Q2 e# i& }# l
    5) e& Z% @& H7 y. U" l7 ^7 i6 W
    6: I5 Y- h. B! @+ h
    7
    . \- _% ?9 h6 A5 k8( X' W1 G. C' M7 p2 b
    9- w6 z# p" a. R4 }; ^: z7 B
    10
    . t4 Z. F" b8 v/ I4 [/ n11# ?$ S. E9 p: l; c* K; e9 ^* W; [5 ^
    12& D$ [" v) U. @
    13; b. k$ Y' o# }+ {  T
    14
      }% e3 T# {' O! V) p" L15
    8 {. w5 d9 B: }( ]% F6 _16. q, k1 n! ~2 _* }
    torch.Size([8, 102])
    2 b; S1 m8 y: V; y2 E* ?1
    & `0 W/ V4 B& `" X, \" W1 ?9.1 计算得到最大概率
    * y! Z3 f1 T  [1 G! x_, preds_tensor = torch.max(output, 1)
    * m, ?7 t: k% G: g
    , m6 G9 f) f. Q' }- `+ T. Epreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    0 {* q3 x, p+ u1; o: v4 Z0 t5 t7 W6 g3 `
    2
    ' @! x, a; r1 a6 J$ q% {% D* I3. C- k1 B9 H7 @4 W$ Z, c
    9.2 展示预测结果
    . ^+ I! `8 D/ t" `% C# ffig = plt.figure(figsize = (20, 20))
    * J$ {7 _  A* S# D! Q3 mcolumns = 4; [2 S7 e' V& N# z# ^0 W
    rows = 2. I9 H% U1 {2 f2 p1 ]+ y; l9 }+ m9 x4 |/ o

    ! G) ~: i6 O" Qfor idx in range(columns * rows):' z5 }5 X+ H) R0 Q  O+ R4 a2 H
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    , ^1 O% S9 h: ?) L& Q    plt.imshow(im_convert(images[idx]))
    3 I: K  R6 Y" E  Y    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), . N) \+ |4 `, l; X5 P' `2 c. f0 i
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    ) Z1 V6 j" t0 R: ~2 x, e4 Y* fplt.show()
    , z6 A8 u2 Z  o* t, H, T# 绿色的表示预测是对的,红色表示预测错了3 D0 G2 b1 L. s; `
    1+ Q8 H4 A3 G$ i: ~1 g# K
    27 r/ B0 z! v, H) G% |6 i( p
    3' U% ^; k) r, X! f9 N
    4$ r4 b2 B$ q/ r. ^( W# D
    5
    3 j" D6 x7 Y0 n* K! ?) {6
    ' y$ g+ o" a0 M; O, y7, ^+ R4 D4 O& u" q8 i% B
    8
      _& Z& S5 q! w1 r9
    . f+ a: q$ _4 {, a$ b10
    $ q/ j' T+ `0 @( N11- L) S  `' s$ j/ H8 m

    8 `+ b: d% [: X' W1 s2 r
    , X* w% m1 m4 @4 w4 L  _: z  P1 q1 M$ a& X$ g9 S7 w* D' S
    ————————————————. ]: h' y$ M- Z6 ?1 v
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。4 r0 I& s  k2 V! V- y
    原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    - J' ~5 Y! M7 W8 j
    . E8 s' d  m4 E$ A- n) [. N
    8 q( s* b# H( f0 |; b0 }! F
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-6-14 14:42 , Processed in 0.477729 second(s), 51 queries .

    回顶部