QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2752|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |正序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例- a) o, ^, h2 c& i2 U3 {4 c. e
    ) z+ Y' B4 {/ Q# j1 q: u! \
    文章目录, s( u0 D8 c! A, ^2 I0 [
    卷积网络实战 对花进行分类
    ) P' E, ~8 b/ [# ]# l  d数据预处理部分
    5 `9 H# S0 ?! \2 E8 |网络模块设置
    ; |( U, q1 `: e& X* o# Z, o* e6 U网络模型的保存与测试
    7 ]  k! D& D9 P: V% }+ U3 n; `% x数据下载:" h5 _1 t" D% j; s( ?' H& G* C
    1. 导入工具包
    7 U( @7 u. R7 s. W2. 数据预处理与操作5 E& z2 K' z; W% Q
    3. 制作好数据源  v9 p2 U& v/ h" |1 I
    读取标签对应的实际名字: c3 ~7 f5 \* W9 m
    4.展示一下数据
    + f3 s) g% t- T6 r$ J8 J+ u1 T+ |5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    5 ?$ A, g/ k( ^6.初始化模型架构3 c! B7 p9 c8 E0 ?. B
    7. 设置需要训练的参数
    1 H* y% h' P+ Z# N" ]) K: s7. 训练与预测& ^8 r! @$ B# v. [" y; R6 y
    7.1 优化器设置1 V9 l, K/ ^. w1 C0 f/ `
    7.2 开始训练模型5 x+ B0 k0 o4 L
    7.3 训练所有层
    ) |5 `( ~' ?6 A8 T- M9 Y开始训练3 N1 t) ~5 `" Y8 e  J* ~! h3 K+ d
    8. 加载已经训练的模型
    3 S- z  x3 |7 O; P9. 推理
    2 d" V& \' D8 r7 N5 C2 {+ H6 ^9.1 计算得到最大概率
    $ R4 ^8 p8 a& v$ b9.2 展示预测结果' T5 z2 a/ G: _9 N
    写在最后% V) P' a/ Q$ j
    卷积网络实战 对花进行分类
    % j2 Q, Y1 D: z/ Z; c  ~( E本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能+ Z3 q  u: O3 P. O) d
    . g1 L: l* i1 ?: p$ P# I, E
    在文件夹中有102种花,我们主要要对这些花进行分类任务
    ; {) ^/ G  A& u; f$ ^文件夹结构
    6 G+ P/ e5 F0 g2 Y6 _: @/ v. G
    2 S! B" U5 h  C  y7 Oflower_data
    + N7 Q5 ~8 z2 X% |7 p
    . h1 {& ~0 @& N; ?train% ~& L2 I0 o9 w. k1 W; M+ X
    . i& i) u- X" ?/ ~$ M$ w
    1(类别)
    2 x2 ^  U" _: c2
    5 {( {+ k) a& P( P/ mxxx.png / xxx.jpg
    ( O& G# t; L0 w. p2 ?4 ]. Wvalid0 \6 G1 e6 o: U  w* N
    7 x1 P" K2 t/ c; v. ^
    主要分为以下几个大模块
    ' h% Z3 t; ?! ~2 r0 C
      |8 ?& b6 }0 G9 x数据预处理部分6 R2 w+ N% E: r. c& x1 m
    数据增强4 `2 A; u& V3 Z  I+ D& q
    数据预处理
    - m) L: a  g' }+ k3 I) V网络模块设置: W' D- L' X4 b+ p8 I
    加载预训练模型,直接调用torchVision的经典网络架构
    1 i4 n9 i4 O- A' \( u% A因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务: P( a7 G" C- E5 Y+ S: S! }8 y' h( u
    网络模型的保存与测试
      @$ b# u: I! K' z! S3 f模型保存可以带有选择性" i! r' k/ T" _% d% r
    数据下载:& G+ q/ u- z% |1 |$ y3 {
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset9 K1 w( K; k/ V4 }
    ' Z: T5 V( p% I" a; N; o
    改一下文件名,然后将它放到同一根目录就可以了* p7 s  Y  v# m3 P

    ) ?+ E, n# t! x# ]; k, }下面是我的数据根目录
    : D9 P; ~2 l2 ]7 q0 ~" U3 @+ M7 U& y1 O/ h5 Z& O7 u
    $ x9 w; z# o. {7 L
    1. 导入工具包- P- d8 {; E% D& R/ ^/ e
    import os# ]) u: \8 N: p  A4 v0 l
    import matplotlib.pyplot as plt0 m, r6 T' Q) f' Z: V; B4 U
    # 内嵌入绘图简去show的句柄
    5 @' |6 H, T* ]' x  S4 \  Y8 [" K%matplotlib inline
    4 u& c0 h; H5 M4 F# X/ limport numpy as np
      S% T5 M% o* [4 D1 `* zimport torch
    1 i; n% M0 Q( n2 a% m/ y! `from torch import nn
    1 w4 i6 {$ s/ q& u; E  W9 C) t
    ) I, y  ~4 A% Yimport torch.optim as optim+ F) x) V  F, f' y
    import torchvision+ c# S0 `& Q' n9 W' @  x0 o
    from torchvision import transforms, models, datasets
    $ v! r9 n+ K  P1 W
    2 O" q. k9 w  `/ F1 m) ]import imageio( O% G" |$ ^; ?" S% T& `
    import time
    ! a, h/ y. J" \" @& M4 z$ Vimport warnings+ k4 f" T4 X# P% G& ~0 z
    import random
    , T3 c. \6 @+ Z: v+ h9 {import sys
    . l' K8 o& s! Z$ g# t+ h9 cimport copy: b8 Z7 @7 O3 n3 T3 y; I/ p: J
    import json6 _2 K4 K# W" V7 \  k0 E$ y
    from PIL import Image! h( S' K: H9 E+ n+ f0 x  A& Z
    ! T# [! I; y" F+ E/ j
    ; @, P2 E1 a6 H$ h1 b2 o0 J( }
    1+ j% h- b8 r/ Q) w1 y
    2
    : M' d! s: j& x1 m/ ^3
    2 e* T) Q; Y3 q* x& C. M4) F; h+ }' \$ `$ y
    5) g0 n- P0 v& s. c
    6) }( `  h1 s# a0 h9 u! I  d% R
    72 E& t% U( b5 i# p2 m, n
    8
    8 X7 c+ _! j  `$ ~; Z9 O9* D$ a3 ^# l5 C3 O
    106 A" _) W. X0 \; S  n
    11" s/ }/ e7 k( R7 l" G
    122 U( g* O6 a% I1 ]! q; X7 r' i
    133 r4 t- Y0 ~/ n; h3 k
    14& i1 _) a; X0 h+ A& S$ H( }# N6 y; u
    15) h' a' w# w: s9 b
    16$ O1 z9 B- t! X! J. X( i6 j5 F
    178 g  O% u/ U0 e+ t/ v: y- l
    181 @9 R; C2 F1 V  i$ I) z$ v! V
    19. E$ P8 j, }, v2 G" ?  Q
    20+ ~8 E( Z- `$ p6 B0 b, C
    21& I/ f. N: Q" D- k6 Q- C# A
    2. 数据预处理与操作
    4 D3 U+ J5 g  a#路径设置. i" W$ ~# @1 o8 G3 e: i1 p
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    / [8 |6 |4 m1 C# K3 wtrain_dir = data_dir + '/train'* M: Q6 |$ u8 s$ G3 e4 S7 O) o
    valid_dir = data_dir + '/valid', i9 D; H/ Q) _
    1$ x, l) ?8 `' X9 z
    2
    ( Q  S4 u+ L, H: {( B5 n3
    ; J3 Q. L% R0 T: Z% [' M49 g6 R5 K7 c; w; q5 E* j
    python目录点杠的组合与区别
    7 v+ b# ?; r. s' D3 C注: 里面注明了点杠和斜杠的操作
    % z- [- Y, T7 i: o* B8 V
    % K5 b) P2 C/ l8 a: C3. 制作好数据源) {3 ~0 u, v2 v: d2 H" U4 l3 C
    data_transforms中制定了所有图像预处理的操作
    3 ?, }% l# v. XImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    . e1 B- V5 K8 J" Edata_transforms = {: k+ t! z6 n( W( [0 I' q8 X
        # 分成两部分,一部分是训练$ I1 F5 k! ~, R- x1 u3 y
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    $ S; t0 N# _$ x5 |9 P6 T$ v2 j                                 transforms.CenterCrop(224), # 从中心处开始裁剪
    + H0 U. [. V+ X2 e                                 # 以某个随机的概率决定是否翻转 55开9 K" B/ l" w" {: |
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转; }- Y. F1 ~  X1 A7 f$ D, h
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    1 M+ Y6 X7 [/ d! `3 B! G                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    / W( ?/ p' m# \3 W4 d  ~4 P: i# G                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),9 V/ y4 E/ H$ j& k
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB) P9 k! }" I2 r
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的* D6 i; ^( E; }/ I: `6 Q
                                     transforms.ToTensor(),
    1 k. e8 Q, ~$ |9 y+ ~! R3 d                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差+ t% N. f+ k$ R+ t* E7 t
                                    ]),
    2 @) S- g- n1 H% |7 C    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    ! @. d5 g3 Y5 m$ k2 ]5 u    'valid': transforms.Compose([transforms.Resize(256),
    ) K3 L. G+ D6 ^; l9 x/ S0 Q                                 transforms.CenterCrop(224),/ U/ z4 C+ o2 a) |. c) u" e
                                     transforms.ToTensor(),
    3 y. B  F& i/ S; Z0 h/ h6 o                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    # g# v: v& B. T  p                                ]),
    5 e2 u- w2 H+ Q}; O  X% P! j& j/ M( H( s

    ' X9 X  d2 s* r' a1& n6 l7 W( O( {! d
    2* N+ W7 t) z3 Z2 c: v0 }# h8 W: u
    33 y( w% A6 }' Y, l0 }
    45 _$ N; a3 S, K! K/ {* I( z* B6 `. C
    50 M& \0 ^+ G3 M
    65 {& q: q' |1 S) E' O9 l/ l, x
    7
    9 q9 y% w' ?0 S! ?/ x8- k# T+ P4 d' I, ~2 Y) K! V" J& B
    9. {3 V% W( N! W- s8 |; `8 V2 C' u
    100 T! f1 F& R2 Z) S+ C
    11
    4 E- [* f/ i; H: d7 G; N& ~+ E) A  K12. Y+ a- {- ?. ~- |, }
    13
    , a5 N+ y5 A" N7 V. h7 a+ _" E5 A14# k$ m- y# a) W; |* ~: Z  `
    15
    3 f" U  s+ F( ]2 i, X0 r4 F2 H16# J; B8 |$ M% s% F' ?
    17
      q; v" C8 E8 ~; F+ \18
    5 Z- Z8 V" x' c9 Y2 D19& u" t7 ~. H3 n) G: c7 J) @+ V
    20) e! _3 Z# N0 `1 d; N
    21" T. R1 m# ]% F" [1 h& D
    batch_size = 8
      Z' N4 e& O* S% S6 E: f% [image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    7 e/ l& ?# i* c% b% m5 V+ @& wdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    3 q( c- S9 g5 N. W+ C$ k2 f! i6 X0 idataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} ' p" `; r8 l# i& H5 B; a/ c
    class_names = image_datasets['train'].classes3 ?( b# }! a4 |, k, C3 \2 q, ?
    & ^  p- W. d3 x4 G/ R4 d
    #查看数据集合
    5 p( N( h; M+ i3 o+ Q* j: yimage_datasets
    # u' S9 ^* V2 M8 Y% T+ Z* r
    : J9 {! d4 L- h* b4 ^4 t15 K' x" v! G) m! D; N. g1 u# l
    2, g* Y, }) S2 J
    3
    " T0 M. @- `. i$ N: U  A5 Y, o4
    6 ^2 h4 i& K- A7 \& e. m0 u54 R9 J: Z7 I- B3 E
    6
    # s: k' a- }8 }4 T4 y% ]9 S7& v6 s* r& ^  ]: s: S% x+ h
    8$ N- g+ m, A! Q& @3 ?- a3 ~/ A! q( }0 [
    90 U0 |. o- W/ J$ E
    {'train': Dataset ImageFolder) s+ j9 R# e# i& Q6 S
         Number of datapoints: 6552+ I) x, C7 J2 x4 h9 l% G" \
         Root location: ./flower_data/train) y4 Q) y  |- u3 N( l
         StandardTransform
    0 T1 D$ ^. a  a4 k" O+ O0 d Transform: Compose(
    # y% f, i* p: z+ t6 m, B                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
    5 J" s7 [; s, R( T# P0 |! X; [: g                CenterCrop(size=(224, 224))- Z1 X! Y5 Q" ~% r3 }
                    RandomHorizontalFlip(p=0.5)) h* l1 S- z! l; ~# Q' V
                    RandomVerticalFlip(p=0.5)/ d0 j/ ^5 c3 ?
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])$ y2 V, q5 [- ]3 C7 p3 T/ a5 G- L( {
                    RandomGrayscale(p=0.025)
    ! X6 w% B$ v" e2 _. P' x+ e                ToTensor()$ |/ i% _' h( [- J3 _$ {8 P
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    9 E: R) a& e" K6 a% ]2 [  w# s/ i            ),
    & a) B$ f2 h* g( T! w. Z- o! ] 'valid': Dataset ImageFolder
    ) x0 m. H* R: P! A& l     Number of datapoints: 8182 ~, x/ K: |  H
         Root location: ./flower_data/valid
    ; S5 E& J/ D& |$ h9 U* Z     StandardTransform
    ! @: t+ R' E& Y' [ Transform: Compose(
    ' i: T: N* [8 v( I: G                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    " A, v7 R, r, a- z5 x                CenterCrop(size=(224, 224))
    $ v+ U7 V( p. ~3 C0 L, D- I) \, U* U3 E                ToTensor()* ~% \' g! I( x
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])0 T4 j, F' R* S0 u
                )}
    7 \% m. [1 i% j! x! |
    : E, e' u  U' t* h3 E" z1" q0 v& N' S1 w$ `3 Q9 U& n
    29 d4 Z, x3 [" F: H# ~* e5 t
    3+ |# |2 k; W" Z! g9 [2 w
    4$ v0 A; _/ c+ T2 M3 }: O7 V
    5/ G1 ]+ H0 k- V4 p: y8 T. R
    60 c: m: U- M2 G! C8 J5 ~( Q
    7' k3 k6 m5 g) B7 Q# K9 G
    81 G, B6 ]- G0 u
    9
    ; E% T5 k( W' @. Z  Y10
    4 R' M+ U; l+ h6 Q9 p/ x0 @: D11
    - Q: f& B7 C; s' L3 T8 B4 [12; i( a" N) |  Z! E$ t
    13/ x4 x  M* v4 l
    147 Z6 q9 K4 ]& F2 i# \  H
    15* [2 G* ~3 h; g! f& e* ^! F
    16
    0 d$ r! S# `9 x7 Q8 T4 q171 a8 H% x3 f5 B3 x1 R
    183 x- E/ g$ ?- {0 O& P. I
    199 \+ G& i1 }: u& i
    20
    0 u* G8 H1 Y# p21: H9 U* T8 g" {6 E, K6 @
    22& q5 i: P) _/ s# B/ I
    23: e4 d& R6 {5 q' v8 t' r
    24+ l8 }2 ]8 l" K- O) ?  M
    # 验证一下数据是否已经被处理完毕
    5 E: F  U" B7 j; Z# A6 d. Q; d+ Ldataloaders% |3 V* y! r% J* r
    1
    # P% K$ f* i" W: |9 F5 z$ T2
    ) `2 L  \7 }. I2 N( ?{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,% Y) F6 S7 N3 N5 i8 ]; I0 I
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}& W5 k5 ^6 g# M
    1
    # c& p/ m% M# p5 W1 I9 ]2' B4 J: d8 F& r- b
    dataset_sizes
    # G, [. X5 Q; ^/ \  v* X1$ U/ \4 R" K6 s1 u% g  A
    {'train': 6552, 'valid': 818}
    / y! m3 J) i6 U18 f9 N- v7 a3 p3 K, z
    读取标签对应的实际名字# E6 ?/ a0 U7 ]2 W3 Q6 e& ~
    使用同一目录下的json文件,反向映射出花对应的名字
    8 k+ j0 R6 ~+ a- f5 `$ s9 e8 d. M5 U5 m$ q6 O
    with open('./flower_data/cat_to_name.json', 'r') as f:4 y+ t2 d% {1 b. H( T$ \
        cat_to_name = json.load(f)4 Q- [& y1 o( C6 b
    1/ n7 w% e9 S( H1 Z( T4 M( ^' W9 ^
    2
    ( i) L. J, q& Fcat_to_name- x( d: F: u0 K$ {0 B  v+ w+ ~
    1, E4 |  }' p* r+ B) y) d( J
    {'21': 'fire lily',/ h" t! a7 t9 D5 m
    '3': 'canterbury bells',
      E" t8 E2 O# B4 z) O. H' z' E '45': 'bolero deep blue',
    3 M2 b9 r5 J2 Q' |2 h7 d '1': 'pink primrose',1 h: G! }  Z( T
    '34': 'mexican aster',  W# J" k4 I5 n; F2 g9 e
    '27': 'prince of wales feathers',
    ) l, U/ ^* `8 S2 L '7': 'moon orchid',
    2 r& s6 c% H- y9 G '16': 'globe-flower',
    ' r& v  R& G, \9 S: u1 @ '25': 'grape hyacinth',6 f. a' i+ @# q# a: l- h) m
    '26': 'corn poppy',
    * z) K" n1 G+ ^5 F8 i '79': 'toad lily',( ]1 }3 T; W7 j8 i
    '39': 'siam tulip',5 K( N  P+ G, M( @+ v1 f) |3 l6 P
    '24': 'red ginger',
    ( Q  Y5 K) \# l$ o4 c/ X) ]# m '67': 'spring crocus',
    % c- v7 J" H5 F: f/ h. K '35': 'alpine sea holly',
    - @( l0 U" r: T8 [( q '32': 'garden phlox',' e& @/ w2 ]* P% }5 Y& h
    '10': 'globe thistle',6 Q# J: D: y6 e( z
    '6': 'tiger lily',+ c9 I% |" N% i6 M7 p# x: [8 s
    '93': 'ball moss',
    ! C6 H  O' K: W: w '33': 'love in the mist',
      P- l# y6 p0 q6 P- \/ a$ O '9': 'monkshood',
    7 |2 l* E4 n2 A0 s+ x7 C '102': 'blackberry lily',
    7 h* m: R- N/ A. X '14': 'spear thistle',
    - z% S% N6 R$ q* d- \; y% e '19': 'balloon flower',. I% \4 U( @$ Q' W1 e
    '100': 'blanket flower',: g9 r) M9 F) p% e8 ]
    '13': 'king protea',' T  A, V! B; k+ G6 {
    '49': 'oxeye daisy',
    9 q' M: g5 w6 [- C '15': 'yellow iris'," x9 l( H! t& A! G1 I; r
    '61': 'cautleya spicata',. I4 v+ y: q: Y
    '31': 'carnation',! Q4 [: ?; e2 N- D6 t! \: H
    '64': 'silverbush',6 \- A4 I6 ?, V) l1 d; L
    '68': 'bearded iris',
    6 i2 @: V! h6 N# {4 T '63': 'black-eyed susan',9 @" y! @' L8 I$ m
    '69': 'windflower',- Y0 d* C9 N: i( G* g* ~: Y
    '62': 'japanese anemone',* J3 o6 x1 W+ K
    '20': 'giant white arum lily',
    ) K0 d/ }$ _& V4 u  ? '38': 'great masterwort',
    , P6 }6 c# J$ j4 l! s '4': 'sweet pea',
    & O% `; b( F* V8 ~ '86': 'tree mallow',
    ; ?1 e7 l! C& m( M0 g" M '101': 'trumpet creeper',
    # r% B8 v! m: Y; V$ `8 r! J '42': 'daffodil',/ D; D7 a+ k. b' F' g$ L
    '22': 'pincushion flower',
    : f; H" e& J! y9 ?$ G '2': 'hard-leaved pocket orchid',7 x0 {" ^5 Y5 K
    '54': 'sunflower',
    0 ]6 B8 j- R7 ^- e; C '66': 'osteospermum',0 d2 c7 `' O8 U' n: c
    '70': 'tree poppy',4 E0 [: \) Y/ L5 E6 Y
    '85': 'desert-rose',
    / b+ l. [  T" V2 c2 m( @ '99': 'bromelia',; {$ {) J5 ^& [* ], y6 O
    '87': 'magnolia',
    2 |+ r0 R; q# p0 _5 f) h4 s '5': 'english marigold',1 U- k9 S$ w3 c8 x/ l( y0 X9 i# Y
    '92': 'bee balm',
    : R; s0 W  Z0 t3 k+ A '28': 'stemless gentian',
    ( f0 y  Y0 l: t+ q( I2 v '97': 'mallow',
    * g- g9 z$ [9 r" N( f '57': 'gaura',2 e) u) f" L7 t2 w& q1 H" Y( ]
    '40': 'lenten rose',
    % s" E4 e' B+ l '47': 'marigold',
    ' L9 f/ p9 I5 Y3 s9 L '59': 'orange dahlia',
    ' `1 J6 F4 P8 |8 S7 @  l7 k2 j  x) W '48': 'buttercup',% w9 J$ a: P+ U: l9 O' @# P
    '55': 'pelargonium',
    & J+ F6 W; z/ J0 E7 }* ]8 a7 q. y" M" n '36': 'ruby-lipped cattleya',
    4 V- [  {! C& B3 L, ^ '91': 'hippeastrum',
    2 H7 p; R' O+ {, w+ G '29': 'artichoke',
    * B. y0 I% p7 @4 z( I( C  Y3 j '71': 'gazania',, Q. G' v; a) O4 a
    '90': 'canna lily',
    9 n0 ~% c/ O3 d* S; l; { '18': 'peruvian lily',
    ! ~( m. {4 {& S1 [ '98': 'mexican petunia',1 y, w0 O  {" k1 n$ O& Y
    '8': 'bird of paradise',( `: m% t- {! m: n  S
    '30': 'sweet william',* `, t5 U0 ^0 ]. j
    '17': 'purple coneflower',4 F7 y) o, J! q% n( y4 F" Z. B
    '52': 'wild pansy'," T1 r/ Z+ R  H
    '84': 'columbine',
    7 E% Z; s: J4 ~ '12': "colt's foot",
    * \- \% J" s5 Z5 J/ T '11': 'snapdragon',
    2 h9 U5 k6 \+ s '96': 'camellia',
    ! U$ z/ V( l# j4 X5 E '23': 'fritillary',7 I1 Y5 U1 A1 q0 y0 t
    '50': 'common dandelion',
    , \- D5 S: t7 l$ y '44': 'poinsettia',
    $ i% e' J# D$ a. v! b '53': 'primula',; b5 h+ c7 C" _1 ?+ p' C
    '72': 'azalea',
    + y$ Q0 Y* a* x3 f5 ^  v/ U '65': 'californian poppy',
    & L9 T1 v0 z- Q, o4 h% H, l1 { '80': 'anthurium',
    . z/ m9 |6 ~' n# T- P '76': 'morning glory',/ T7 P; G% C- U: @( B: _
    '37': 'cape flower',
    7 c; y  N) Z. v. o. X( M '56': 'bishop of llandaff',
    + ?7 j& K: m) t% l '60': 'pink-yellow dahlia',
    % X$ {  u$ Q& ^$ a6 }! c0 w '82': 'clematis',6 @* o5 C1 b# U8 a  Q, {
    '58': 'geranium',
    - I% N8 i8 Y* S! c: K '75': 'thorn apple',
    & w) i' x" P, p0 @ '41': 'barbeton daisy',3 o% K. _  ^5 W: A+ f
    '95': 'bougainvillea',+ h- l1 Q5 Y( o" g# x$ B7 m% f+ E
    '43': 'sword lily',
    , Q( W: w/ E  d' b2 o( H0 U+ g '83': 'hibiscus',4 R/ _. E  [1 g2 ?5 w2 ~
    '78': 'lotus lotus',; e3 C8 s2 l4 s1 [
    '88': 'cyclamen',
    " z- @$ y- H" K1 D3 Q  v9 g '94': 'foxglove',
    2 T' u* G& g) G& [9 ~3 e7 P '81': 'frangipani',5 P# z3 c8 _  D: q) }2 z1 K
    '74': 'rose',
    2 ~9 z! @1 _. L( j: q '89': 'watercress',$ @- t5 i3 G# r5 z9 j  n1 f0 V
    '73': 'water lily',
    8 m+ [* q/ U6 d- | '46': 'wallflower',/ Y3 p+ w# c4 p
    '77': 'passion flower',
    1 A7 t. T+ i) B5 u '51': 'petunia'}
    . m2 \: v2 `# {4 x! h$ W( v! X; a- X! Y0 O8 }) B& @0 ], b0 h* j
    1
    ! ^( i& M' M& f- }! c2
    ' O9 q8 C! t3 U, _7 |0 B3
    + p1 O. h! V& r7 Z7 `4
    3 n9 Z0 _3 s4 _$ F, d6 h3 g5
    : ~- D) C3 r+ O( Y65 L* a9 F1 W, h0 u) e( {) T$ K0 {
    7
    5 A9 B; m! H3 }$ |7 [4 C8, V: j2 o  ]  V2 ~
    9
    ! n  n9 P" p5 ~! b! i10
    8 o) I# F. J! ]5 @# s0 e' q11
    & y+ o! L2 c: S. W' U12
    4 `9 w! m6 |1 H+ p+ o5 v3 O13, b& f4 r, u+ O, t- N5 w
    14
    ) O7 h: u/ n/ X1 g+ B15
    8 e9 a, s+ x0 q$ w6 O4 K16
      k5 x5 J0 R) X* X, {17' c0 t% l9 ^" [" t
    18" x8 e) _! a5 n3 ]& S: E5 o7 v
    19; e3 R. N/ ~+ E
    20
    ' a# N# e$ ]* D) s1 ?217 _; U& z. A5 Q: L; _
    22& T/ E4 U, Y7 |* Y
    23
    $ A1 }$ c! N" y$ r24
    5 h' ~$ H2 e& I% b4 _$ V/ X253 V; L" J0 L1 i- E! X3 I
    26) e" w0 c7 P' e* S' A
    27
    - P% V; D6 s. @8 f( }285 ]6 [7 w- Q; T  M4 i' A4 R
    29
    , ^- i  [8 t' q) d& ^30
    / r( K6 ]# f3 d# R# n" R7 u5 X: m  P31
    1 ?" s: [0 J5 o% [* E32: ?! U" Q/ h2 q! @: h% j
    33
    2 S) x1 ]  j# s. x34
    . Y+ C& A/ E$ ]4 M# d35" G+ ]- H* Z5 t: \- Z
    369 @1 S! Z: I; R: L
    37
    " s! H2 k! ~) w5 Z% }384 M! d! e) J% ?7 d. N8 y
    39: o/ Q* W9 R0 \8 L0 c) u# z" A
    40  J3 w/ U9 q# P) Q5 m
    41* ]0 Z2 E+ u7 {" I
    42& X  q0 v: a. L3 L1 H8 T
    437 k; Y9 D4 B' l/ V
    44
    - _" \& C' t- D5 A) `2 j% P45
    ' S; |7 ]! D: `6 e0 I46
    " L* F8 I! h. [: I* H) {47
    ; ~6 k& @7 R3 ^1 E* N; J9 [8 v+ S2 f486 k+ f9 h5 e% p" n( E0 n; n
    49: j* n+ o& h! t( Q3 _" n4 Q
    50& ~% q+ D. [& p7 Q+ x) t8 }* o
    51% d( e) G) n% L- C- z5 K
    52; e+ f9 o2 U. }
    53* h' v% u) q% ^0 H% L
    54
    6 E; ]" y, Q1 r2 E, f55
    2 a/ p: Y2 W6 [* U566 f6 H6 y) D1 e0 b7 e! I
    577 J  j% N/ Q4 M& C+ h
    58
    $ H) s! }1 y. l59
    0 N, m0 |( l0 M9 ^! H60& o& O$ i2 g( ~
    61; ~" m2 b. Y& [/ M. w( z
    62. `) H$ [$ l8 |8 ~- J. z1 w
    63; L, G3 W! ?. M
    64
    / r7 c# G( t6 c" X4 o- d* K65
    $ o6 t6 g, y" n; }4 N3 m/ |" i66
    # B3 Y& l) i" \" T/ b* @* n! [; i67
    2 q0 B- _3 {: T( x% i68
    ) }% n; r0 g1 Z5 K' W$ e* T7 P69
    0 }) y' G+ x" v* W' R70
    : k) M' R  }; V, u71; O: ?/ Y- m$ X0 {
    72
    $ Q8 g% b5 G! N5 f: X! t( Z9 j73
    1 E" n- A$ v; s+ h- ?  X6 |" F/ r74
    2 {7 R! u$ ^) S- |75
    0 _2 f, p# N" I" g" s6 g76
    8 g. o" [. n1 G' ]+ p77
    . V4 n, E( s* R& g  r& A) T78
    / H4 S8 u% r3 S# B" x: X79
    * c$ z/ H  u; O4 Z( @' s806 h) M' ]- e: }3 l
    813 _' I' {& \# W$ J7 ^1 R
    82
    % `8 n# H) q$ [5 ]* m9 |83  ^% F5 z5 W  S& q: [
    84% B. l2 B# i5 w+ B  u9 K
    855 Z& l* s0 R  l7 d
    86
    % F% v9 E  Y2 U* j. v3 N4 `; L870 x% |5 W+ T* |3 A! S+ y( Y) D6 R3 V
    881 X5 k# {  i% B+ a! ^! [
    89
    ) @7 }& y' o) X! f9 A5 X90
    3 P) P. |' t7 C; E8 N9 z0 `91
    : W& a. U. l* ~' }8 i( E1 `92  V$ t( s( x; f! Q) }/ d3 L9 @
    93
    " \/ t  |" [+ k- K0 h" \; T94) P# j- o8 A& H6 k5 M) Y9 A% ?' o
    950 m. }1 O9 x! `2 ?
    96
    3 H2 S' s* W8 B3 ?97
    8 p' z' N; P( n% o1 a98
    3 ]" J: y5 _7 ]+ l99
    ' I* `; }0 @/ Q, N100
    3 f* B1 b* l5 H; Q6 ^  b: j101
    : H. s9 e3 o8 V* e1 f1027 _6 `. M4 [1 t  G
    4.展示一下数据
    . K; P: J  k$ n1 h5 c$ o" j% W4 Sdef im_convert(tensor):) v5 g( q( V8 i9 W) C
        """数据展示"""
    ' D! h5 v6 n4 o1 t) E2 o6 w    image = tensor.to("cpu").clone().detach()
    & o# l7 y0 W: Q7 a7 q& W0 @    image = image.numpy().squeeze()# g, q4 t' u  A" A' h7 A9 P
        # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    ( n0 f9 ~  {. W. H    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)9 Z6 B- Y. T! M: f
        image = image.transpose(1, 2, 0)
    & ~: k& N6 D6 g1 v    # 反正则化(反标准化)
    4 b6 i3 B, q3 f7 |) ^4 C5 B, o' `; W. R/ q+ }    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)); J, [0 x9 W6 [$ O! w* U

    ' L3 {4 V# U$ [* I    # 将图像中小于0 的都换成0,大于的都变成12 {# O7 `7 O3 k& E1 h- j8 U
        image = image.clip(0, 1)
    ; y3 {) @2 j6 H8 t+ q( K
    $ s& ~/ ~7 l0 ?& y; q    return image; V0 l% C: H0 i/ v: Y1 y
    1- J- E/ N3 F6 m
    2* U8 C  {4 S- W6 ?, J' j
    3. A4 T( E$ ~( m, e  c
    4
    $ b' J* P: e4 G) P) I0 D8 L5- b! `6 e( z7 W8 \6 i' T, Q
    6; e3 c8 ]& x# M- i
    7
    4 \! O$ P. ]) y* V0 z8
    8 z) r  \& r9 a- S2 R8 n9 f+ W96 X  H  f$ \9 b1 U1 E
    10
    " i5 Q- W0 K) O110 X1 v! e6 n! t5 ~; ]
    125 `+ }, v( ]2 i8 s+ B) B
    130 d+ D& g0 w/ b5 l4 C, q8 O
    14. }4 {5 v8 G, c' ~3 D. M
    # 使用上面定义好的类进行画图, |  A8 M. {9 J6 ?, C( _9 `
    fig = plt.figure(figsize = (20, 12))
    7 N' x! J6 z/ n" G' Zcolumns = 4- |' G! [4 X3 ^/ o/ r  V1 d
    rows = 27 V( C6 _+ h2 D
    ! U/ O+ n3 y, v9 c6 s2 D+ y' ^" Q+ k
    # iter迭代器3 h( A6 `6 R7 s7 C( k
    # 随便找一个Batch数据进行展示
    ( b! U2 `5 e* o3 |3 Hdataiter = iter(dataloaders['valid'])
    - W; [; c  z) ^. minputs, classes = dataiter.next()
    & e7 D! R* \( d( _4 @- g; [( i* |1 l
    for idx in range(columns * rows):+ \7 p* z/ Q9 }7 R7 S
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])  r' D1 E! D$ i  @2 o1 v) ^2 e/ y
        # 利用json文件将其对应花的类型打印在图片中& h) P3 H# t1 E: a9 d! u
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    + A/ l# B0 }# @( K( |* k6 J    plt.imshow(im_convert(inputs[idx]))
    - I$ l% X. A8 Tplt.show()+ [2 J7 I+ i) @6 n
    3 L* x. e1 _/ }1 B9 S- ?
    1! I9 B, j5 k4 Q  l) T) C6 l
    2, n$ _$ `8 n. ~. P- o9 z3 j
    3( g' S9 I; s9 y4 M4 k) {
    4- z7 T# s! g" f, H$ x5 E
    5
    ; e, u  r: f) R2 ]. J* ~3 J  K6" m1 N4 c# F% c# D& Q: P
    7
    4 L! L  N7 ?! X3 Z+ B1 ^# |8
    & w1 |* Z9 V9 }' k, p# ?0 d2 }9$ A4 v. d8 t- v: I, M" G* g, t% [
    10
    # u. `7 P1 ?+ x0 t, o7 T112 J/ x! v% I9 Z: @1 |1 `
    12$ F$ ?7 o! z- G3 S
    13
    4 B5 F8 e, z+ O% c) N14
    ! d, @; f* J; X. r9 l7 ~8 h, B15  N) M1 E$ Y1 H4 O. s
    16
    4 u- j9 e8 F& m; e$ T& I
    ( ]( W% x# Z- h. T9 o8 g) J( j$ M+ t0 O
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    ; g" j# Z/ H! j" q% g9 K; [model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    6 I% E8 r' c. D. z& Z( Y# 主要的图像识别用resnet来做
    - N5 n" s$ |: N4 s# 是否用人家训练好的特征  e" |7 \" V3 c- }* m/ d
    feature_extract = True+ o. T/ A6 b# i7 s1 M8 E" i, V
    1
    # |0 e1 U6 y+ j4 `( b2
    ! b* c5 @7 M0 M. z30 _: a8 [. e0 Z+ v# r/ E; E- r! ~
    42 A. V% `: m- V- K& Q2 c
    # 是否用GPU进行训练- K+ @: M9 V: F0 F4 [
    train_on_gpu = torch.cuda.is_available()4 @3 t- g5 C# G8 D: y: n2 B

    5 x  Y9 x6 ?' n$ l. H: [/ kif not train_on_gpu:% U1 l0 J+ `/ O/ D# K- H
        print('CUDA is not available.   Training on CPU ...')9 u$ @& s* z7 r0 p* Y
    else:; i5 [$ M% r( u5 M, O; ~, M
        print('CUDA is available! Training on GPU ...')
    0 z3 Z8 {/ a1 C/ k5 T7 F/ U
    4 g( V/ Q  z7 n. Q7 ]' I; ~device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    / t: J; z3 S+ a! }( w) A. x8 G1
    1 x7 f& ]! w( F8 f% }: y# X$ M# A2& u/ h! p# r: h! |+ X
    3( p6 b, _& g8 z0 v: q) Z1 V
    4
    4 Q$ S8 t% o7 `* k- i9 Q- q5' X2 i; v  G2 \% z' p/ G3 @* F4 F
    6
    / }4 z) h, v  C1 u6 F- e9 Z7 r72 g7 P; Z4 Z/ B. c' _
    8; ^+ o' ]( z% Q) e
    9  c- `6 i! w) R1 B- Y% s2 i
    CUDA is not available.   Training on CPU ...  a6 i! Z# y7 E0 z) ?: R2 E
    1$ l5 V  I9 R1 t( U7 A7 K
    # 将一些层定义为false,使其不自动更新5 K% V- b: M, C, |% D+ f& E# J
    def set_parameter_requires_grad(model, feature_extracting):
    ) ~# O, Y  L8 A. I. r' d    if feature_extracting:' Y  k0 F' o3 e1 D4 v7 m9 v5 ^
            for param in model.parameters():0 C- h3 I6 U1 [+ ?0 r8 k' v
                param.requires_grad = False! T* i+ i# x* s, s; y: a3 w3 K
    1
    & I+ Q6 T% }" T, J6 \8 x2
    3 c$ |* a' v5 h0 }3
    ) W3 Q8 I2 p6 i* N- ?4
    : V( g  _9 b( E, v5& H6 Z6 o6 ~8 f+ T
    # 打印模型架构告知是怎么一步一步去完成的
    4 E; M/ o; Q0 D6 z7 \' t" ^3 T/ j# 主要是为我们提取特征的
    + m& S& g2 H# S& o/ _2 m  k/ z0 Q6 ]# w
    model_ft = models.resnet152()
    $ I" [; e' }" X" U1 B4 x* zmodel_ft2 F* z( ~& L( J1 T5 i
    1
    : B  p% W, x) E( h" ]2
    ! k% e3 p1 L. Z  ^3 |& H$ R3
    ; l* L1 S* D, ?1 g" i8 J, n# a4; @  q# d9 `5 s( P1 U" ]  I
    5
    ( C! S6 i! ^+ \5 vResNet(
    0 A0 n- @) V0 t1 A  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    6 f9 F/ N7 \) [5 k  a5 F( w  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)/ o$ `( Q& }& X. l$ [) Z
      (relu): ReLU(inplace=True)
    ( S, Z# u$ [  R% ~8 M  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)6 x4 f4 S  C4 G
      (layer1): Sequential(' f/ t2 O5 ?  L- K+ h/ Q
        (0): Bottleneck(
    / d' l. K) D' b0 q2 c& H' ]% K      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    , C" }/ V. S, C% }      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    9 L3 p# x0 R. i' L8 Z. R; t4 k9 U      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    2 S6 j4 `% Y0 Q" j; ~$ t      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      B6 @& E: G8 J, A4 R6 i, Y" c      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)/ i, P4 S  d5 W" R+ i
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    $ Q' _$ D! [- J6 ?7 F1 o      (relu): ReLU(inplace=True)2 R8 ?$ m5 l/ o4 v& d+ {
          (downsample): Sequential(
    9 Z5 Q% Q6 e0 M( j5 p3 _5 s        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    9 c4 A9 r8 J" m2 H; ^        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    9 u* A( J2 H- R2 i5 \      )
    $ ^" H8 O/ e7 f1 x    )$ N& ]2 P9 c/ v. P
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    8 O% ~  F0 ]  t# l; b    (2): Bottleneck(
    8 W/ M( N2 S/ Y: m" Y7 l      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    3 X7 A6 f- T8 m" O      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ( @4 g- e7 k" p: u+ b      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)7 r. q+ g- O) _6 e
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)& ]' Q/ d( t% Z8 t5 [; Y' o3 W
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    7 O5 L7 `" D" z. N& e, Y$ M7 G/ u      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    + p6 w( A* f3 |/ |      (relu): ReLU(inplace=True)
    3 X8 ]/ j$ r' Y  \; ]6 Y3 k% W    )/ @% @/ g: z/ f& m: [" ?% G# n
      ), N. V. }: R! c# ]* y
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    $ y+ f* @2 t" O" q, d& b  (fc): Linear(in_features=2048, out_features=1000, bias=True)
      o8 y  x7 O8 D, S" j" m  O2 I)
    ; y+ s7 D, [7 K4 }! C
    6 O6 c# c& G/ t( L) _( o15 g% T+ b$ L( Z( z8 f; G
    2# e7 {; l) O. \3 J* s4 z9 l) c
    38 v$ m4 M  c; n% N
    4
    + M7 q+ o2 |; m3 K- u2 [& @0 n5  [% h; @9 b$ p
    6% n% R. u; h0 J
    7/ j0 h8 u- ?9 Y5 p/ A
    8
    ! L% _/ l  x. n/ ]6 T, |9
    / D5 j. R  [9 }, @/ Q" i5 @6 ?4 X6 T& t10
    0 [! e; B& Q! n3 q11
    ) l( L: }, q" R2 H% I12
    " \+ c+ a/ r; G" B( U( t) c132 s& L. A6 z) R( V
    145 e* |  @/ |& {  i4 ^9 A' Y8 a
    15) u  W. e3 m9 Z! }% Y
    16
    : J2 [; x- r# S$ F$ [17
    3 m+ f5 r/ K2 S, {" Q18! D! R! E7 Z3 A$ ]* G# _
    190 l; P' R" K9 o: q9 I7 z
    20
    5 s' m8 C) Y  O4 t21# I8 o  t# W" z
    22
    $ ?9 [# O2 G, D8 k% C5 `23
    + G6 n# c; G2 i24
    $ ~7 |1 r' s# O. ?0 s25
    . F# P/ f. ?+ S260 s* X8 d# a8 A: F% E
    27
    4 ?. _  y) ]$ ?/ H* x* @' [8 e) b28
    ! v5 J% }- q7 x( `3 X4 y+ v2 {298 f' r- ~" I4 u$ s# p& J: o
    30
    % }% d5 f' B4 i* d- `: F  O31+ Q/ P. w: k# f) C2 H: I- U
    32; A6 |5 u5 ]7 ^
    33  B  g5 e/ I7 @2 _
    最后是1000分类,2048输入,分为1000个分类
    ' j8 |8 n; F! u0 G+ E而我们需要将我们的任务进行调整,将1000分类改为102输出
    ! B: W2 M: c' S( ?8 M& c, H  @( D
    6.初始化模型架构
    3 }- B  [7 C1 Y& c8 ^1 k步骤如下:
    $ q! T& v5 j+ X: J+ i3 m0 s5 g+ e8 j; b) B) q
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
    4 e8 J! B3 S2 T0 [+ A6 [可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)# J; R+ Z+ K. x: T2 z
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数- Z5 C- e7 [$ u/ Z+ c
    官方文档链接; v- I. F9 y' m' z
    https://pytorch.org/vision/stable/models.html
    7 e/ r* o- I2 Z
    4 w* n! u; I, U0 U( L# 将他人的模型加载进来
    9 y8 ^  q) V, l& I/ e$ ^def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    # _0 U, E  V) c, A0 g1 ?    # 选择适合的模型,不同的模型初始化参数不同8 w+ y/ p5 ^  h7 s4 ]
        model_ft = None  k2 q9 c( ~! m" y8 K3 M
        input_size = 0- ]1 x( x! g0 F
    : S% S1 m! D$ G- \0 P
        if model_name == "resnet":; m; B& q/ G0 C* X' q9 B- ]
            """
    9 R+ B2 i- Y& w. p% N; d0 P; `        Resnet152) u9 K- ]/ R* @6 X" P3 c
            """" A1 f* A9 _. `
    ( X- T7 ?% C$ c' X2 c1 X
            # 1. 加载与训练网络8 }, z  K# `3 Z* X
            model_ft = models.resnet152(pretrained = use_pretrained)0 J9 C  E3 X6 Z% g% T) U* z
            # 2. 是否将提取特征的模块冻住,只训练FC层8 c5 K; U# X! n: j9 b
            set_parameter_requires_grad(model_ft, feature_extract)
    : B0 g0 V* w! e, i1 h+ g        # 3. 获得全连接层输入特征
    6 r  C9 \8 Y/ `+ Z4 |        num_frts = model_ft.fc.in_features$ _; O+ Q- x5 \; S. ~7 `9 O9 n
            # 4. 重新加载全连接层,设置输出102* L& q6 K0 h8 f( c4 X- {
            model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),4 t/ K0 t1 W* V8 P3 c6 i- ?9 Y
                                       nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1  R$ C5 }! X, |
            input_size = 224
    9 H9 p/ s" Z( n( d+ M" W& ]$ D2 C) `; M5 w
        elif model_name == "alexnet":
    6 \/ i  W2 Z( {0 G2 v        """: b. x  n& {1 H1 Q# Z) R5 k
            Alexnet2 Y& _2 S5 v8 Q
            """
    * ~# i( q' A* r5 m' i" S9 W        model_ft = models.alexnet(pretrained = use_pretrained)! r/ r/ n& b- g  ^
            set_parameter_requires_grad(model_ft, feature_extract)( Z; _& _! d( @1 n
    ; ~" T" d0 f7 N
            # 将最后一个特征输出替换 序号为【6】的分类器8 C: B8 V  U" B3 |
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入
    4 u# F, ~! H) y: T        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    + T' r' q" @1 \1 C7 H+ q        input_size = 224. F; L8 D% r2 {' A6 x

    & j, J7 F6 g% c" L1 I4 Z    elif model_name == "vgg":
    0 H/ L! J' W9 c" T        """) w; _: B2 y6 h+ ^
            VGG11_bn
    ' J- I7 o1 ?  [9 F' I' M1 V; V; }" m        """# B$ F& X+ D# ?( q
            model_ft = models.vgg16(pretrained = use_pretrained)
    9 h* L! W- l* N& V( W5 \        set_parameter_requires_grad(model_ft, feature_extract)9 l/ x# W$ c5 {
            num_frts = model_ft.classifier[6].in_features+ e- k. |+ {  `& z/ X( ~" N5 X
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    5 L; v! R3 c8 J- K# W: c' A2 A        input_size = 224
    + U' G( G( h: E" ]/ H2 [8 o" S# |: G" E9 o; \& z  L0 M1 P
        elif model_name == "squeezenet":- X7 f5 y9 C1 Y3 ?2 d' ]3 v& y5 S2 L
            """) ?, U/ S/ E& B# a7 f% l6 Q
            Squeezenet% u5 u$ G& H1 C' X, h' A6 c
            """
    ! V* L/ G9 O9 i& ?        model_ft = models.squeezenet1_0(pretrained = use_pretrained)3 U- H+ G$ a3 }/ c& w: H+ a# i) A5 W
            set_parameter_requires_grad(model_ft, feature_extract); t' z: }0 I) S
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1)): U0 v" t( U# E4 b/ S/ z, W
            model_ft.num_classes = num_classes* ?8 r$ C2 E0 h( n  b7 n. S
            input_size = 224
    - x) b& C) X$ |3 T& |1 n0 L7 n- J+ l7 ]* m( F6 I. F
        elif model_name == "densenet":  A0 ~6 W! \% k4 u! |
            """
    8 C: e, j0 _# @. i. p        Densenet) h( J/ y6 X% n4 u- [
            """0 r. N$ ~# o& I# e6 p" z9 Z9 f
            model_ft = models.desenet121(pretrained = use_pretrained)
    2 }) @& k; M* T. G  |. V4 d        set_parameter_requires_grad(model_ft, feature_extract)
    % g/ x" _$ p* m- R9 g        num_frts = model_ft.classifier.in_features
    ' `& r* f) g) V% R' m: |$ L' L        model_ft.classifier = nn.Linear(num_frts, num_classes)
    4 x6 R  k) {& h0 P2 ^) v        input_size = 224) J- |4 p0 r1 ?- O8 K" @" j3 [
    ( ~6 n4 P/ T4 Y5 x6 o! ]# j( U# j* d
        elif model_name == "inception":
    " P, K9 C( O; M$ U# T: t8 b% q        """
    3 H, B) E0 L+ `& y        Inception V31 i$ m# r1 ~6 x) V3 n" @8 S* Z& f
            """8 X- u9 U( U, K
            model_ft = models.inception_V(pretrained = use_pretrained)
    $ v- ^) W$ m8 M0 f! \        set_parameter_requires_grad(model_ft, feature_extract)
    : c9 u; P% H  {& k/ U  U& }% e0 h
            num_frts = model_ft.AuxLogits.fc.in_features
    1 R1 p7 y# q3 I( Q  i4 b3 M        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)( n" ?4 K1 U6 O4 C
    * O* ]" z& f, o
            num_frts = model_ft.fc.in_features1 `9 o9 |. ^: T1 ?' V) H
            model_ft.fc = nn.Linear(num_frts, num_classes)' [" |2 r* L5 {6 e( j
            input_size = 299
    2 U! r3 v* E, j. U, A1 A: V* ^1 A% J! \5 y
        else:
    # ?2 {- d# l8 r* O        print("Invalid model name, exiting...")" |  d4 G- v" d4 C# B+ x
            exit()
    8 L0 |0 K) d# n4 E1 ]' T" j( D
        return model_ft, input_size6 O  T% J7 x$ |) ]/ g7 U: N
    5 D2 Z1 I* f9 w" U' Z/ j( J
    1
    * U/ R/ Y& B7 X+ }; w* W2
    6 ]& h( p1 i! V8 i3
    & b/ W# |- y; g/ m8 [% H# d2 w4  v. |; r1 U& u' s$ R
    5
    4 P; Z- y# }$ u" L! p6
    % X; b' m' r, n) |2 e7# `. `  u8 a' s" _7 k$ A/ S1 s, i9 t
    8
    ; R& r: K5 x7 c  d6 d$ T" Z9
    . m3 e" t( {6 s, k  F8 n# h10
    : ^4 y  W1 @7 {113 W" u( ?2 j# ]0 @3 c% X
    12
    0 C# n. u" L7 m) }: v6 U8 `7 x3 d' O" ~134 B# h6 R- _# n* D' X  v
    14
    5 W8 p, V; l$ a1 Q2 B& _4 ]1 K% }15! n! R. _# t0 N4 y( p- D; z+ E" E# K
    16
    0 L4 Y+ d9 J* r# G/ |17
    , g  R& W& t% s( P18
    $ k. t! {) p% X" x! U6 z19
    " q; W( \1 a1 S9 _/ Q4 H20
    . P0 G( d+ t0 l" u3 u1 J21: R4 \! r  Y/ J
    22- L3 X" s. I. y  I! }( E
    23
    ( y; t! G8 ^" ]/ L9 k& K. A24
    ; l" e+ }9 _5 c25
    8 p7 t& S6 w6 w3 |26  g4 h5 w/ t' v* s& z
    27; T; _) ]6 ~  E, m' \6 M' y$ ?
    28
    & [7 h( J3 x$ d2 U/ f" ^29
      w$ e0 q% Z  O308 w/ _. w1 L# L$ Z. U
    31
    ' @' T* N& x: u32- M# w) w! |7 C! h2 L
    33% w# K: Z0 `, x' K& B( E2 i* J
    34" G4 M3 R( K# }/ v
    35
    - [8 e& m9 g, V: f7 p4 Y3 {, m! v' C36% R8 r' t; f) v  X6 o2 T! f
    37
    6 N5 L% q% `$ p7 `. t! b38
    " d8 n/ y( i5 i% Q- J$ C* j0 C39: D' N/ d* }- G$ g0 ]+ ^$ h% `3 o5 g
    406 p. s( s  n$ _- A9 B
    41: d9 c" v6 G5 M) E4 Y# K
    42
    + \/ i/ w; U& }43' a1 }" A# A4 {7 ~4 g0 {2 t! _6 Y
    44' G. f! |: m5 U9 R) f4 G4 _
    45
    - `( _) W, l9 W8 T46
    # K; J2 d5 v# d: Y2 k47
    - T( X- s) p3 P* M, M0 N48
    & [6 Z0 X0 g0 u9 g; y0 y6 j49; M9 P4 A7 N# X9 w, E  J" F" h. C$ z
    50
    4 U& G2 k3 \1 p" ^3 t511 _2 y7 H/ E9 e  g
    525 {% R; f" F* _% D  {
    53* S4 z# i2 P& s" D1 _2 r( K
    54; Y+ \" t, E$ r& N) m' H
    55
    & t0 m) ?9 ]( U2 h! d. `& t56
    % y1 @/ q! t1 y0 I$ D57) ~" W& ?; C: S$ z7 G) J/ F$ D
    58
    5 [. g. q9 \4 X( F3 i59
    * J) V; X: X4 ^) \60
    / C: Q+ C5 f, @) N' G6 n* _61
    2 C; e6 J1 T" Z; l! T62( K! F% [8 g4 C! Z) e, q
    63
    4 n6 ~/ e2 p4 J64+ E8 T6 i7 ]$ l$ z
    65
    6 W) P- c3 {" e1 \/ Q# D3 X" u66& u6 b1 D; |$ s1 W7 @0 p
    67% J5 W# h6 i4 J9 V
    684 @: T4 ^3 C: x  E. H
    69
      P  a( R8 z6 A5 }( u70' l6 a" ^* V; r: f6 Y
    71; l' a6 H. v) O; X0 \) L1 e) h8 W
    72
    ! k' q/ d9 w- ~3 L- l73: p$ a' {' |' m- H/ n: @
    746 J: X/ z% t& H( ^1 z
    75
    - Z" h3 r" f5 |6 _1 H76! i4 ~1 s9 A( ]) \# }1 U3 k; C
    778 {3 r0 Y7 s( F* E+ m
    78+ K2 P# I0 F( h
    790 m9 F0 W# {! b4 z  n( h' ]" P0 A: a
    80
    % H- c. m# G- x6 b0 [81# u' F9 x8 A; p+ o" K  i
    82
    ; g4 P* S4 v7 u6 ~83; g4 [7 `  F: t" q3 [2 a$ e
    7. 设置需要训练的参数
    , k" c. o8 P& [! c( ^# 设置模型名字、输出分类数" E. X' w9 A, f
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)) G8 ^9 j2 U  u6 ]3 w* A! w: Y

    2 ], F7 b1 t& e2 G6 y7 w& `7 J# GPU 计算
    ' H6 W+ T; @4 W; w+ C' T) Amodel_ft = model_ft.to(device)+ M+ G( j, B! p4 w3 V) @0 m
    & [0 k/ P9 W# [5 j
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    : y, ]  D2 _: {9 W# W3 k1 V% a/ Dfilename = 'checkpoint.pth'& O- J! y8 \8 [2 m, V
    ; u6 A0 b$ d/ x
    # 是否训练所有层. {" l: e2 A, w
    params_to_update = model_ft.parameters()
    / e& h  s% d% d# 打印出需要训练的层$ d5 B0 F) c9 A+ f0 ^- _4 u
    print("Params to learn:")
    ) \; H5 q: u6 Y$ b/ Q5 Vif feature_extract:) ]" _; k+ @/ e  k% i+ h
        params_to_update = []( q4 T" S1 ?3 m2 R' I1 m7 q: p- }
        for name, param in model_ft.named_parameters():4 n, g  v+ r! @! F
            if param.requires_grad == True:3 s, C1 E$ m" P+ C
                params_to_update.append(param)
    % Z6 n  U6 z8 J            print("\t", name)
    ) `, i8 x: {/ f; A0 }* ?else:/ d+ o; F5 ?% h& _
        for name, param in model_ft.named_parameters():
    / n, k4 l$ L# q# U4 d& Q- w* O        if param.requires_grad ==True:
    . V3 U5 j. F  G* w. p7 `# e            print("\t", name)9 r% O& t, X( T& s+ p9 _* i  M
    / M% c( x$ ~6 t% b) Z
    1
    # I+ z7 |) c8 f9 |2
    / \: x; t2 O1 d$ G# g3
    7 w. M7 I$ P, ^5 M/ u0 o43 R+ a4 c" W2 {
    59 G: J$ I/ W! N% g8 j
    6
    & {5 H- e3 D" E+ [0 n7
    , ?. a6 I  U8 U9 w# R' B0 s; m. `" y86 L- @$ t9 N, N6 E! f) S
    9* S" e9 O& N4 c, i4 Z, ~
    10
    , `' Q, m4 W1 z; D( ^11- m. l+ |5 V2 \3 A* S6 ?! B# G
    12
    , \9 F9 ^, F5 S8 F: u13) }$ q: f9 l& z! K( m' \
    14
    : I( T7 R2 z5 L: a) `6 a156 a# G  e1 m0 C& L, Q0 [4 ~" O  e
    16
    & M9 ^* m( r5 p5 v17
    4 e; x; c1 E$ a4 ^* o& }18
    ' O5 x$ _$ |$ M! ?; r6 {- W19
    7 h0 \& {: c4 C: r- \" E( @  ]" U20* a6 W( D# i2 ?
    21
    7 l% ^5 [7 V7 y% X5 ^22
    7 [% l# a1 C+ @: L23- ?# h8 m! o- o) o9 x) Q
    Params to learn:
    3 W# ~' p" g' y+ S5 z% ~; D# {* A         fc.0.weight
    . K" E- `+ i+ F4 n$ j         fc.0.bias2 r: B: w: H- r1 {* d
    1
    3 c0 m: V9 P$ _4 H/ c: O2
    2 l" }; H, j5 J9 c% w3
    3 Y+ D, A: G* \1 m: i7. 训练与预测
    % b+ @& v' z) p8 C' b$ i7.1 优化器设置
    . i, q- u: I0 N  B- G7 S# 优化器设置
    ! X4 {0 o. H8 @6 N7 a, N) u. Voptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2); F9 A' [5 U/ v2 p1 d
    # 学习率衰减策略
    - f. e2 x3 S' v# n7 Y3 k: g' X0 N8 X# [scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)% u2 c" Q9 f, s3 E1 ^0 ^, H
    # 学习率每7个epoch衰减为原来的1/10' c0 G- r' F: {; c
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算3 [' I: I+ x/ C: P

    : n( R4 W9 h( E2 H& Tcriterion = nn.NLLLoss()
    & i6 W* b# w8 b! ?) v1 U1
    ( ]" F- L- Q6 L( g2 l/ ]! P) Q2( D5 {: M3 ]/ F) q
    3
    8 t1 i" L% x* j4- x/ r" Q5 y1 Z# }+ }
    5
    0 a4 i: ^  X) m$ F: Q2 r6
    # O8 D0 }# ^, O6 p* G/ A; T; h- ~7- F; {: g3 @9 w7 y8 R
    8
    # J' p) I6 w* ?2 @# 定义训练函数3 \! _9 w$ o. s; x, [/ j% r
    #is_inception:要不要用其他的网络: v/ c; M7 p  Z  ?* J5 P
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    0 ]4 C* U. c' o: k$ m    since = time.time()
    & x8 t& R5 v* I( g    #保存最好的准确率3 F$ b! ^% \6 F' z; m
        best_acc = 05 n/ y* D; h2 p- w
        """4 w% S6 ^1 ]3 T8 t3 Y- U" ?
        checkpoint = torch.load(filename)
    , Q5 @, m' [) E- p    best_acc = checkpoint['best_acc']
    6 y+ [" B5 |5 Q7 v/ D  P. X    model.load_state_dict(checkpoint['state_dict'])) W2 U# q" ~6 {$ a6 x8 g* \3 E# n# U
        optimizer.load_state_dict(checkpoint['optimizer'])
    " n& R& B% D& j* J! x9 N& h, x    model.class_to_idx = checkpoint['mapping']
    9 o- R4 S6 O8 p    """+ }9 t+ \9 F$ Q3 Y4 Y8 j6 K& B
        #指定用GPU还是CPU/ A* _/ n1 m* C! q; G5 t; z
        model.to(device)1 X# x! t, U5 ~' p  Y1 J
        #下面是为展示做的
    - D' C- C9 \/ L0 x# ~3 q    val_acc_history = []
    . ]/ g- f- Q3 O* h3 T    train_acc_history = []: ]' z4 a. x$ A
        train_losses = []* V' u( ]' Q, c8 r% E3 l  I, j
        valid_losses = []
    0 N6 d6 O. A* Z/ s6 M    LRs = [optimizer.param_groups[0]['lr']]
    - h$ r; g4 N0 G6 p* B7 \    #最好的一次存下来
    3 H9 _/ U% b! L' B3 B% k3 Y, X% }- m1 R    best_model_wts = copy.deepcopy(model.state_dict()): F2 r! G5 g& m8 j5 s& u! j

    6 c" P/ s1 t1 a+ |; @    for epoch in range(num_epochs):
    2 H! u4 r9 P3 p! a        print('Epoch {}/{}'.format(epoch, num_epochs - 1))# \9 ~! W0 G' r. [9 B) p8 A
            print('-' * 10)4 U2 F9 h2 X8 @$ H" S

    7 O/ i5 h. a) b6 u9 @2 i        # 训练和验证
    / Y8 G5 n5 N" K  ~" X        for phase in ['train', 'valid']:
    + _1 F& j  O( z& F' J' D# y1 z8 g            if phase == 'train':
    1 Y- p4 u* h5 ^7 a                model.train()  # 训练
    , h' W  n( E$ E1 Q) M; z; k            else:
    / f7 v1 z* f' b                model.eval()   # 验证
    " U  L* z5 X- ~) |4 W8 I5 Y: E' a: |3 X( |/ }, y
                running_loss = 0.0
      l; m6 F) ]4 l1 q: {            running_corrects = 0& a$ e+ x# |' {. n% J

    : V6 h' e" c& M+ t8 N% o2 R            # 把数据都取个遍/ ?& P. E9 i0 `4 w6 ]5 f$ T
                for inputs, labels in dataloaders[phase]:
    - |$ x  a' R$ i0 g, U8 S                #下面是将inputs,labels传到GPU
    . ~8 @% e# R8 z* b7 H/ V0 F                inputs = inputs.to(device)6 S3 Y# O/ k4 M2 b6 S8 t- I% I
                    labels = labels.to(device)
    8 t: y* p9 @$ z
    9 h8 n) y. h- c5 E: T% I                # 清零
    + |) p5 E+ E, o                optimizer.zero_grad(); X/ c0 w5 }- Y
                    # 只有训练的时候计算和更新梯度
    ( h: P$ q+ u- z0 s% y# m                with torch.set_grad_enabled(phase == 'train'):
    " ?& e  p! P% [/ \                    #if这面不需要计算,可忽略
    1 J; F. t$ f1 [+ b1 [# m7 c                    if is_inception and phase == 'train':
    $ G; d! L8 E7 v  `& @7 z/ d0 ~1 l# F                        outputs, aux_outputs = model(inputs)
    8 e# A( j' g! C( e                        loss1 = criterion(outputs, labels)6 _) Y2 l+ W; X
                            loss2 = criterion(aux_outputs, labels); r" s# m" M: }9 ?5 }4 F- S! P, t
                            loss = loss1 + 0.4*loss2
    , o5 s3 z+ B) @' c4 d                    else:#resnet执行的是这里
    5 K3 p9 @. ?% e                        outputs = model(inputs)6 i& T+ r7 U2 Z7 \( z* t, \
                            loss = criterion(outputs, labels)' X+ s9 ]' J1 e3 j1 w
    6 I. `" A4 N$ u2 {7 f7 g
                            #概率最大的返回preds
    : H/ `% v  k; K) o) H! k                    _, preds = torch.max(outputs, 1)1 U0 r5 }3 P$ Q- m4 n5 c" f
    # _7 _0 U$ _* `/ Y$ @. W5 `' g* c  f
                        # 训练阶段更新权重! D+ O, Q& k8 R: I- u
                        if phase == 'train':
    " m3 O/ M& A  Z" l  z3 O8 y! r% D* u                        loss.backward()
    / g3 z6 c/ X( r' S6 K4 M$ H                        optimizer.step()
    ' a) W) k/ Y8 g# a
    2 o) D+ C* B, ]' Q" @0 \                # 计算损失
    3 P5 z2 a0 w' l- m                running_loss += loss.item() * inputs.size(0)& v2 s0 A: b- h8 i- H+ n6 e
                    running_corrects += torch.sum(preds == labels.data)
    7 Y6 o8 r' O+ x) ]$ r" U; B$ O4 l5 N- P9 x. z5 T7 b! g
                #打印操作" o' v8 z/ W7 }0 b: _
                epoch_loss = running_loss / len(dataloaders[phase].dataset), Y- s+ U7 t& s! \* o* i
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)3 F% ~& F$ D8 l( X, c& V# S! u
    ' C% Z* t2 @3 u9 }. |
    0 T' K1 W4 K! Z8 M% R
                time_elapsed = time.time() - since
    ) n) U( w% i* z( C            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))2 Z: t' S7 y( v$ l
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# {) }- J* q. k
    ; b9 Y( K. d* \% q( H: w8 z
    5 N1 m4 `8 e! j
                # 得到最好那次的模型
    ' @5 U9 U6 z% k- q: A5 ~4 f- f) ^% o            if phase == 'valid' and epoch_acc > best_acc:
    : S. @. U( d2 a4 w3 W- _- h! Q( V                best_acc = epoch_acc
    & t6 G3 Y3 X& a* }                #模型保存
    5 h5 P( }% ]1 O; E' f/ I2 d2 U+ a                best_model_wts = copy.deepcopy(model.state_dict()), q4 E% r* P6 Y! r& Y4 G+ |
                    state = {
    3 e, [3 F3 l% x# d3 y4 }3 L                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数2 F; y9 q" I/ a6 r. H3 R
                      'state_dict': model.state_dict(),4 z9 L2 t. u$ Y/ e
                      'best_acc': best_acc,
    5 H7 R+ X0 A- V2 }& h5 X% @8 ]                  'optimizer' : optimizer.state_dict(),
    / S, x+ x- ?0 H1 `/ G" u. _                }7 G( o0 y' W. h! x; N
                    torch.save(state, filename)
    # p+ Z% r% T, k( }            if phase == 'valid':9 Q3 v9 K! j/ s8 x' e
                    val_acc_history.append(epoch_acc)  R9 S8 ^: ?- v- I# E
                    valid_losses.append(epoch_loss)
    ! c; J/ S' L; A: |' P                scheduler.step(epoch_loss)( L* Q- U; z* l" |9 }
                if phase == 'train':2 s4 U' R8 x8 f6 C
                    train_acc_history.append(epoch_acc)( i! ~! V2 i2 [2 d& {0 I. d  ^, E
                    train_losses.append(epoch_loss)# H  w) t2 p  [

    0 W; B) F% A; X        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))4 U5 {+ `4 z# G' y
            LRs.append(optimizer.param_groups[0]['lr'])# @) ]- j' v7 E3 l/ D5 k
            print()( W* }0 C/ C3 ]

    8 w  G' v$ D/ N* v6 F    time_elapsed = time.time() - since
    9 e7 A: O' @8 B  q  ]5 r: c    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    ' n# L: g! h' b) N& T' B    print('Best val Acc: {:4f}'.format(best_acc))4 |# T: h& }  g! ~6 R
    8 i# ^. M9 p3 Z& d
        # 保存训练完后用最好的一次当做模型最终的结果6 f% ]" a0 T3 L3 b- R  J( r; M9 L
        model.load_state_dict(best_model_wts)) H8 p  h5 ?2 ]/ k
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs ; i$ x2 `. w2 k/ s% l+ M
    1 o5 c8 A) E0 r" \3 b4 h

    0 T- l( t1 Q4 P5 Q* M1( T* k& _$ V  f4 u0 a9 P5 c
    2' G! C2 B/ b5 i3 t$ L; j* m
    32 Z3 x( W" \' ^1 S
    4
    + g  n+ |* G" H8 [* R0 N+ K5
    2 y; j- T* t1 u, |( y4 S: p( I6
    9 Q- \$ J) _' n8 o2 X& g* ~! \9 r7) j$ \* R( F% Y' i2 q! T/ l& ?2 u
    8
    ! Y* e3 l! @  J* R* P  g/ K/ m9
    " p4 k% P( M$ q0 R" b# n* T8 p10
    , s" P4 V( w* R7 f9 U11
    " q6 b% I! `  L3 G12
    8 g  c* T6 r, X8 f, e3 Q6 G13
    / v* Y. a" \0 G5 T5 U149 F/ F/ {! r5 i4 M, S6 @4 l. M
    15
    8 A7 {9 l, Q4 l" t" H16
    / n: g" Y: M0 D5 q+ k& G170 |( s) ^9 f+ `/ o7 A
    181 {- c! o8 [- |$ \0 p
    19
    3 r. J, k/ D$ l/ v, d20
    # [& F/ I, Y4 M( E0 d+ I21
    ) w! d( h- F* L: b% ?22
    6 c- Y/ s8 c$ V5 b5 o4 b23
    / ~# R5 m  F& _; q" g245 o6 L+ f' h4 g7 }
    25
    4 ~/ }: n$ s# y! C' t; _8 r26+ t7 J, N8 u6 h0 ~
    275 o+ q3 I: ^' x& ?) D5 T) x
    28
    . }( u, j& \6 D: E6 M  _5 c29! \; J4 {+ D( I: p" k
    30
    * R+ S4 n6 _, b8 j7 ]' z31
    " n4 {) L2 e0 I) u32
    2 Q8 E' m* }: k. ]: U33
    $ z/ C% i3 E( @& k0 L4 ^. t344 J9 \% s/ _% [7 H7 k3 }
    355 K2 u' g1 q; n5 s5 e
    36
    / A6 H$ q# l: K1 x) X7 x375 T: t& b! z8 f
    38
    ) ^4 }3 V- N2 P$ O/ w( F" ~3 L39; z; R" z( J# W/ |3 [" `/ q
    402 v% g( [/ [( C. g: u
    418 h9 ~/ U3 |3 {8 L8 M1 x
    42
    5 e$ \& Z. G9 J: ?43
    3 s+ e( I* b6 C8 V. d5 Q! n. {* [44: @& K7 ^- m+ [! M2 x
    45& {; j2 y% q' D/ g
    462 k- c" ^' B: E
    47
    1 l: _, o$ v1 M& V3 l483 T" ~4 e: h! E5 Q" ^! P) B0 Y
    499 [' T7 }/ j0 i! j# O$ f& ]9 E
    50
    2 i3 j# Q3 W. \, t51
    1 Q& L! m$ A" A( r  Y52
    : q* i/ }- ~% c- \$ ?+ z' x53; c* g1 R6 [8 A) f! G  C% y3 c1 C
    54
    & j& i8 {0 v7 f- T9 U55) p- j' h/ b" d# T6 e
    56
    ; q" Q/ ~+ t/ [- o- J0 ~0 G/ k57
    6 I$ Z/ X' B2 z! I; ]58
    ) T' t9 @% ~. E- J( \& j59" K! R! X$ n, q5 |, }0 ]: \5 n. L
    602 D! O% {2 C# S7 y
    61( p$ z! K! [  Z9 i! n2 S7 G
    624 o$ G0 Q) M$ {. m& e
    63* z" I0 ?/ N# v: K
    64' w! U( k( \/ l( [$ U
    65: K3 C, x& w: ~) l/ E3 X2 c, S
    66
      ~3 P" V8 b& ^/ x( h/ ?67
    , ?1 M1 J3 E0 R" c9 q686 P' |5 D; X+ Q: s9 B# y3 @  o* }1 a2 g
    69
    ' a: a' m/ x/ S/ |. r: f) h70
    & `- L$ i' L. |# I& r71
    2 \/ Z$ h3 @6 C1 A& i" S72
    0 o' S$ n3 d$ m73
    3 ]8 f( S  W) D3 ?  x74* U# P6 u8 ?; C; k
    759 r' J4 ^1 M8 a. `
    76
    % b2 m: x6 |3 ~5 b4 D2 H- J77
    ) e4 e  |' y8 ~- c( p783 }6 `: D5 b0 f8 q5 ~  y" ]) I
    79
    ' {, T$ h, o" @8 u! {80. f( e% z' z7 y/ H- W( V" [' l! ^
    81
    3 Q( a+ J' I# S" H2 q; q/ U, |82
    4 S: s2 |/ \" U/ _& \83& I0 B* M2 w/ r6 G: ?9 T) C3 t( z- m
    847 w+ a: [5 N' Q' Y: F6 w! w
    85
    ) h+ p, R' d) G' v4 l8 @86: G7 j* d0 s3 v: ]7 [
    875 ^6 E" `8 c  P9 Y1 d' a+ X: Z
    88" m: w& l0 @. g- {
    89
    * n( P4 Y+ I2 E$ Y9 `$ e90
    : Z' o. }4 E, T# ?) \5 p. I91* b/ S& Q3 L6 q/ t5 l  J+ p( l
    92
    # h! j+ W( P8 F1 Q0 l7 N93
    ( o( B0 y5 D$ V! m3 N94
    1 Q. D, n0 Y# n: L+ N7 [; Z95
    ) z3 q  d& E6 O) T4 [" r96
    & Q& n# Q+ H# l2 u  n& J97+ D7 o3 J) v- C- r! Q
    986 b- o8 Z6 n5 p$ Z: B
    99$ c9 t4 X5 v8 R
    100" z- p. A" B7 v8 x+ @
    101
      d4 Z8 }# M: D1 J1021 G5 e  X6 {6 X( i
    103: N$ Y: e7 S( T% }  x, r
    104
    , Z& v; s5 k& n1059 I1 _( O- U9 `' J
    106
    ' @4 @" ^$ h- L: p: \107
    2 [5 I5 g% o' M* B3 Y108
    2 a/ L) b$ ^6 W109
    7 _7 A$ C# H" G- ~110
    ' ]) M, I5 V, |9 Z111# c# H2 e& m  E" a' x+ \8 r
    112# b9 h- [& y8 _6 _3 _
    7.2 开始训练模型
    ; b# h: P/ Y/ J$ O1 z我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    / I- v2 H# Y: a3 O
    " r7 M3 v/ B( H/ d0 E+ Q5 o; b#若太慢,把epoch调低,迭代50次可能好些
    8 {; K$ n) m  \, g# _! |#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合. h) r9 C0 ^+ N' G1 N
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))9 X5 c, e- q( E  H
    # J2 t$ q5 R' j: [, ?2 c, I& ^3 H
    19 ~6 i; X/ \! B0 L6 H
    2( ]3 A8 F4 K6 a
    3. F0 k1 s7 R3 Q9 t# @3 R
    4& o- _; R/ F. B) a7 M. m
    Epoch 0/4/ _! H4 p3 k7 b2 [3 ]* J
    ----------
    ! F& L3 D; {: Z" j7 hTime elapsed 29m 41s
    4 C% e- S. l. N. X& e# A' X5 O4 [train Loss: 10.4774 Acc: 0.3147; G5 I& a! A. ^7 z7 G* J" E
    Time elapsed 32m 54s' }" ~; n" h% e, K
    valid Loss: 8.2902 Acc: 0.4719& W# w3 Y. L" h7 H$ t/ r5 Q" o/ h
    Optimizer learning rate : 0.0010000
    $ _' g$ ^4 M8 ?
      x% V0 \" U2 C# T. r+ \' iEpoch 1/4
    6 P" G. D$ c0 N& ?( x3 T----------" ]/ y! X* {; u
    Time elapsed 60m 11s
    ! N# F9 p/ L2 _8 o  d+ `% h3 v# ]train Loss: 2.3126 Acc: 0.7053
    ' i! i1 d+ e5 T0 MTime elapsed 63m 16s  D- [0 {+ ?0 `2 B  M  _/ M
    valid Loss: 3.2325 Acc: 0.66266 C% P6 m3 t3 u
    Optimizer learning rate : 0.01000001 |5 x, [% v8 }$ M% ?! V4 o# ]
    % s( [% l) P+ ^" d5 k' q$ o
    Epoch 2/4
    ! Q; \: Z3 o# X. R9 N----------
    2 S# B% G: T8 ?9 c2 R; jTime elapsed 90m 58s
      u7 l( }: t+ E! u. Strain Loss: 9.9720 Acc: 0.4734
    2 h, K$ j8 ]8 ^$ G% e* i  `% GTime elapsed 94m 4s
    7 F$ A# F. u! O! W- @valid Loss: 14.0426 Acc: 0.4413
    : |  D9 p+ r' h+ B0 a& ROptimizer learning rate : 0.00010006 W. z' h+ G5 l" m' Q& d0 ]
    $ N8 F7 b/ s% _+ ]9 B! r
    Epoch 3/4
    6 k$ q+ B. g6 g1 w) e$ F0 D; U----------
    8 R2 p1 D3 }5 R4 L. b8 jTime elapsed 132m 49s% q9 W; }4 ?+ M( I$ s7 Z/ }
    train Loss: 5.4290 Acc: 0.6548
    ; x% ~6 H: y% K: `Time elapsed 138m 49s
    % x# w6 ?/ d1 b" avalid Loss: 6.4208 Acc: 0.6027
    4 j4 e2 T( A1 A& M  E* iOptimizer learning rate : 0.0100000' b# z5 s: z' y) T: q; h
    9 {6 Q: f6 p# q! F' K
    Epoch 4/4
    + s  H/ q$ ^9 f5 q# O5 }3 r, j----------* }: G( m- m5 m
    Time elapsed 195m 56s
    # d- M0 m: m9 b. n7 htrain Loss: 8.8911 Acc: 0.5519
    # E, @: V7 g- @* W5 qTime elapsed 199m 16s
    : }* H% u1 q1 I. o) Y2 h" S, `valid Loss: 13.2221 Acc: 0.4914( K( M* j) R4 j& m0 N1 `
    Optimizer learning rate : 0.0010000/ q' }  K9 w9 j: Z! N

    $ [  ^5 z2 F# x& }) R3 N8 w, OTraining complete in 199m 16s
    5 |9 R( [9 p1 S$ Z1 j, YBest val Acc: 0.662592
    ; D! ?$ d& S5 N. J( p: L" T  [$ Y3 A/ j3 d6 H; M
    1
    ( i9 K$ k3 O+ K" ^3 e; n; Q2' s' {3 F- {2 r2 |9 S
    31 w/ p" y) n; U
    4: F3 c- |8 L. i* C3 J% ^
    5
    ) s1 K' N; M, O6 i4 S* C6) P3 M( e9 O/ g6 T
    7
    & V% B: n- o- Y  i& V$ ]82 u0 O9 |% O! X' f0 a
    9
    0 Q+ I) Z  o( b& T5 w! Q8 ~# m10" X$ I: H2 w+ Z. a; s: G/ Y
    11
    ( b3 {( b/ F, H) G4 w12
    5 x& B0 d6 x6 D# P9 N# e13
    , [) V, y5 e3 R14& F5 c) S0 `  y( |% A
    15( ^2 p- ?. g$ u2 R
    16
    / N. @+ z+ \+ W. O8 O4 P17
    2 e: a+ M; U; L8 r181 S; L7 B; N8 z9 J
    19
    0 c+ D/ L4 t5 i/ h+ L, v20
    : h  S# n# n) j/ b21" [; _6 b. X) N8 b3 n/ G' z! _( j
    22$ W+ S2 R7 L% ~) R3 J
    23
    2 ~7 J7 r0 C1 c5 v24
    - l0 D( o2 p; j7 {9 x7 n4 `7 s1 J& p25
    4 q3 |( M- ^( x: \& h2 m& c3 u9 i26
    4 n, C4 z1 j9 w- V27
    " y3 t" C% G4 Y28
    6 |- _( w2 h/ n, y2 I5 z4 j29% r7 d; C9 A0 U! M- m
    304 m) A" d% E( a; \2 G& V, l
    310 i* [: L( w4 s
    32
    ! a- _. ~" a' M! w4 w  a33
    1 {1 \9 L; B7 W; H# J% {7 e34# h0 ?$ p# H1 w3 ]* @' F6 d
    35
    ! _# m6 p& }- Q366 Z, D2 w4 I" m3 p+ @' I/ D
    37
    * Z& q  g$ I/ b0 |3 {8 z389 r: l' k: g' x0 F% ~, @
    39+ I0 ?# `4 `1 p# g5 e* _) n
    40
    + U9 o7 X6 A# y5 c! I  T410 \: ?7 h8 Q0 G* O
    42
    2 ?5 e: J. e; N! S: @1 a* W8 S7.3 训练所有层7 \! S; F2 y# I3 `  q/ ?% |
    # 将全部网络解锁进行训练
    & j3 K; \* l$ m4 }; M1 i; Wfor param in model_ft.parameters():' ~# J3 E1 l/ T9 l" h8 H
        param.requires_grad = True- Q6 B7 o# d6 a* |  H
    5 r2 m: S. [) Z+ C$ H6 Z& ^4 Q. n
    # 再继续训练所有的参数,学习率调小一点\
    ; E2 E: z7 m8 o: c2 K2 G& eoptimizer = optim.Adam(params_to_update, lr = 1e-4)) h4 k( D, D9 J6 J
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
    & m' V( M( C% e1 Z/ V+ X
    7 ]  D' [' o8 x$ b0 s# 损失函数
    ( F5 p  D+ H! wcriterion = nn.NLLLoss()
    2 V+ X' u- i* i/ S1
    : Y0 ]  g0 }6 x8 T  ?% t, k6 N2
    9 L. O$ C/ o) F- m# K& r2 z2 j3
    & O/ a! y5 k- F2 I2 ^. v% G4( [0 X6 r6 |+ \
    5
    ' ~+ ^/ ?# M1 }/ W$ i% y6; E9 E1 I- J, {: u3 J7 f
    7, m* k& X7 Z" |7 L5 `+ r
    8
    ! ~9 f2 E( `5 r3 i# q8 ^9
    6 Z0 ~2 z* o* p107 ?8 I9 q( v$ h9 p
    # 加载保存的参数9 g$ ?  l: z# o3 ?3 m7 }3 K
    # 并在原有的模型基础上继续训练
    4 f2 j2 W$ `% w* l# 下面保存的是刚刚训练效果较好的路径
    % t, b& d+ I. B$ W" ?, gcheckpoint = torch.load(filename)) ~3 X. A; f1 Q4 O: |
    best_acc = checkpoint['best_acc']
    , {6 a1 v9 T; ]" q  O/ [4 @5 D2 emodel_ft.load_state_dict(checkpoint['state_dict'])' {0 Y6 ~) Y4 ^
    optimizer.load_state_dict(checkpoint['optimizer'])& n* ]. d8 {; z2 t- v
    1
    : R0 s& }) ]) {+ S. i' S+ A# M2, x* i! H( L" O4 W
    3
    0 h( g% ~6 L% z' G. y9 W4
    1 u- c! ?% B7 q5 `! {5 S' R8 h6 b5
    ; h4 m$ p& G  u  h  t- Y0 N6
    / D/ f& x& b* D& q! a, ]72 o9 [$ D: B' ]0 U& e3 f
    开始训练
    2 s, D3 |( T- l注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    - u0 K& X8 V- t9 Z+ \2 ?9 ]' K) q, T; @: Y
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    0 `  C8 Z3 Y; v) s17 G& F/ m: e; j
    Epoch 0/1
    2 ]7 U" r2 o& o9 d4 m0 D+ j----------+ t3 r8 ~4 K, w1 `+ P( J: b# r
    Time elapsed 35m 22s
    . F0 f; c* q. K+ K- Vtrain Loss: 1.7636 Acc: 0.7346
    1 z1 I% |6 M; r0 ~Time elapsed 38m 42s
    # p- E, C9 h. ^" Vvalid Loss: 3.6377 Acc: 0.6455
    2 ~2 L# J4 @3 Q! J4 U  c7 eOptimizer learning rate : 0.0010000) K2 {/ s" m& M) ?% Y
    ; N! U  V6 S2 C
    Epoch 1/12 @% B7 [6 `$ f: F
    ----------
    # g3 P) `& Q" X1 F, k! FTime elapsed 82m 59s& o" n! j5 G" e8 _) o% h' @" x
    train Loss: 1.7543 Acc: 0.73407 m3 l1 f2 C$ R. l, }( [: \
    Time elapsed 86m 11s
    5 N% q9 |) W3 a: `) Avalid Loss: 3.8275 Acc: 0.61370 F$ N, j) V; v1 i1 s
    Optimizer learning rate : 0.0010000
    * J+ k' m! e& T' G+ F
    5 r0 t# t9 h8 G8 v8 o( kTraining complete in 86m 11s) G. k4 k3 m! V9 |5 Q
    Best val Acc: 0.6454779 @7 `# H. [. ^% I* f1 h$ F- |* R
      {9 Z8 v4 k# ?( l
    1
    ' ?* \+ ^" F3 g) j2# h: {8 t( _7 T: E
    3- s/ C5 B9 [+ l: A; b' B
    4
    1 b- L4 Z1 |8 g; T52 t' S; C8 @4 r, K" t* A# r8 Y
    6
    ! g+ E( v# G7 W' @7
    0 Q( r8 {9 d2 {9 N* y3 [) b0 A( G8
    4 b0 T$ B$ i' F6 L0 ^% R) {( j5 d3 {9
    : e% W! Q8 m( ]6 {7 D$ ?1 g! o10
    " V$ l- c7 E1 F; l1 ]2 p' p11( }- s9 h( x8 t
    12+ N/ i$ Y3 `& V4 y! G( O+ t
    13
    # u; l+ z  k$ |14
    , j; e& `+ b% `9 ^3 U8 l0 R15
    0 F) e$ I7 W4 R3 n16! {3 v9 M& g+ M  B
    17
    , n0 z: \6 a( O/ a, r18
    - }+ B( }% Z. o4 T  d) y8. 加载已经训练的模型
    5 D( J$ Y' R2 d相当于做一次简单的前向传播(逻辑推理),不用更新参数
    : v0 Q+ d2 v# Q
    6 o; v  `1 N$ A- L" Y/ b: Pmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
    ; o7 i# N& ?0 ^8 w# ]3 J+ X# o" T4 d8 {% M8 z
    # GPU 模式
      y) y0 ]7 t8 x9 m" b, Smodel_ft = model_ft.to(device) # 扔到GPU中
    4 j8 b( ~( q, A6 _. q0 d# E
    * H8 M  R0 z. W5 `# 保存文件的名字
    3 z2 E2 X+ H$ b  A: ^* jfilename='checkpoint.pth') ~0 e1 T; l# z+ _$ x- I9 c' |4 ~; {
    4 @" D0 F( P/ Q% }; b% {
    # 加载模型( [; z1 ]! l1 P- S
    checkpoint = torch.load(filename)
    3 g+ U) {! x7 a; O. B0 ]. E+ [best_acc = checkpoint['best_acc']
    8 k' Z4 |; E5 o) p! u2 Gmodel_ft.load_state_dict(checkpoint['state_dict'])8 c# s0 u" |" K
    1
    ) Q) K$ Z  X% o% o3 p" ]( d: W2
    1 c- {8 @( A" B5 I8 n* G3
    7 O- \) g9 }* j0 H& a% @4, A6 M. l' ?& P" x5 i) b$ L
    5$ n& l$ u6 k2 C4 r
    6
    0 V. O' g( Z1 B5 q7
    4 U* M4 Y0 ^  B9 {% W# {0 W' P8
    + u% e9 l! ~! |( i9
      @  q/ \# b2 h* H, S/ D10
    + e4 B3 }& [6 k* s  l11
    * g* V3 ~- G5 T12
    8 M- t8 b' u$ @; S& D: Q+ F8 J3 j<All keys matched successfully>" }( o  ^7 f# D) y
    1# T. }* [( z. y/ B8 v, h" Z& g
    def process_image(image_path):. d4 C5 x" @8 e5 a
        # 读取测试集数据# O( N  B8 k! k1 c  O! `
        img = Image.open(image_path)- f0 }- }) _* e: V! Z7 y$ J/ `
        # Resize, thumbnail方法只能进行比例缩小,所以进行判断2 G7 x1 G; I" _  E0 _* v
        # 与Resize不同. M: F% K. D  b8 l7 e
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
    ! s9 {$ x% I; J3 g0 y9 j    # 而且对象调用方法会直接改变其大小,返回None
    / `( c1 g( v6 R2 I9 X3 h4 Z    if img.size[0] > img.size[1]:' V0 p$ d( t+ u# R* Q( r
            img.thumbnail((10000, 256))1 b5 q$ {+ S, w
        else:
    1 u* a9 Q3 Z( W        img.thumbnail((256, 10000))
    9 H1 g  G% l" F: ]! e6 y- i5 {1 s" g7 P
        # crop操作, 将图像再次裁剪为 224 * 224
    & c" J* e$ x* \    left_margin = (img.width - 224) / 2 # 取中间的部分# Y" i7 L8 I) z5 Q2 T: V- j, ]
        bottom_margin = (img.height - 224) / 2 9 D+ J, F" w4 P: \9 L
        right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度  F9 @7 |) e, L- r2 p- B( r* @0 D
        top_margin = bottom_margin + 224+ C6 D: R, o/ t- ~5 g( J
    7 Z, P: J( `8 C4 h% d7 Z4 S$ |* u& k) H
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    2 X+ M  q0 U9 ?
    2 U% w3 g; F! S4 R6 O/ [2 ?5 U    # 相同预处理的方法
    + C8 I; L4 G1 K2 y    # 归一化# t' p" r+ L3 b
        img = np.array(img) / 255
    # g; r# Y; X" G; _    mean = np.array([0.485, 0.456, 0.406])
    9 b$ X; L4 k9 @    std = np.array([0.229, 0.224, 0.225])0 ^2 f, f1 L' Y, c; t# |- Z; h
        img = (img - mean) / std
    9 v2 p$ n4 C- i0 @5 F( s" e% B# B) @/ x9 M# v, p3 `1 G4 |. F
        # 注意颜色通道和位置6 ?5 d8 _1 l% q% c# u/ _3 K
        img = img.transpose((2, 0, 1))1 r3 e! O) O6 \% H

    , R3 H5 \8 M, N. f    return img
    + q1 y$ m$ R# q4 ^2 e% L( T+ g/ Y+ y0 x/ A
    def imshow(image, ax = None, title = None):6 m* m4 s& {5 G: p. u1 D8 G! H
        """展示数据"""
    / d0 C3 Q0 R6 m) l% A. v# s3 u! W    if ax is None:
    0 x6 ^/ b' @. I: F        fig, ax = plt.subplots(), W! k9 L& E; B6 h1 I4 G

    $ N/ t9 U$ u. k3 B    # 颜色通道进行还原
    0 ~2 Q) P6 _0 Y3 |; e' g    image = np.array(image).transpose((1, 2, 0))
    , D: j7 }2 ]# O2 ?1 b; N$ {" Z$ P: \" V
        # 预处理还原
    " M8 n4 M* l9 ^" M    mean = np.array([0.485, 0.456, 0.406])  n. B& J; j. g6 o0 i( O# U
        std = np.array([0.229, 0.224, 0.225])( b9 D& B" f! o9 T
        image = std * image + mean( v0 v' G2 ~& w& G2 T
        image = np.clip(image, 0, 1)
    ) X  X* ~8 e, ?% U' |. ?% E$ q$ P  V* z# t2 L' {6 A1 y* ]) H
        ax.imshow(image)
    8 {2 [  g0 ^$ O+ Q    ax.set_title(title)
    2 b% Q5 {, s4 D  }$ r) r5 u
    ) l) x' h0 ~6 ?$ U4 }. {) I    return ax6 f( L! Z; Y4 E) e5 |8 t2 r& r
    9 f5 c* t7 O7 v$ e
    image_path = r'./flower_data/valid/3/image_06621.jpg'
    # d7 b8 T. J7 |- [img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理. `" h& Y* `" x9 o) c
    imshow(img)4 z' }# D' `4 B% X7 T3 ^4 x

    ! A" ?+ z1 h8 `  I) F1
    0 y/ W" u9 c; m& E8 H, @  e5 V- x  W28 a8 `% o3 o; H
    3
    0 B- @8 y2 c1 e. O6 a" N) g46 z0 O& D4 P5 E  X+ K
    54 n( d) b& d3 P8 _) F
    6
    8 _( ~2 @- X8 z8 Z7
    4 M: g  R3 K$ [* W85 k' S  T/ x4 T5 c
    9
    9 u3 @* ^2 \# l% Q1 s10: J* `6 h8 c1 O7 P
    110 e4 w6 O2 k! a8 f- @% }
    120 e9 J' i# G6 z, r5 e' {
    136 R  _. ^% s' g5 B9 I  L3 i
    14! Y) y+ m2 G3 f3 O! P0 I. ?5 {
    15
    ( K  r( h7 e6 D: E: l, z8 j/ L165 e' ~8 @, B# v7 {  X# N  B* r
    17& m: P4 J) s& f+ X9 m% R3 Z
    185 b3 E1 `! }' k
    19
    % L5 d5 u1 i! \; x: [+ ?: n6 n20
    2 o2 }$ M0 Z7 j7 X& L21
    ! H0 ~! d5 \- P, Q+ Y+ A; F22
    + G+ D3 N/ D. X, T/ R& }23
    ; p% K0 W* E# _8 ?# k$ y/ u" w24  Y( J5 x9 o# X* U8 G; Y
    25
    / }) b; g0 \( Q5 I% k5 I26
    - }0 _+ d6 G7 M( ?2 ~27
    ' m3 k1 k2 a% [; R2 ]( e28- @: ~5 C' s. ?% K2 x% D9 U  I  P
    29
    7 I* R1 H/ w$ Y! ~# U' `# s  [30
    6 N* D9 a1 s" J# }% P31$ C- r/ W, j0 |1 z. w0 G
    32
    0 K: S9 D. w1 C1 F. a338 W$ H& U+ W& L% n" a- o
    341 m2 d7 ?3 e- z2 C2 O
    35/ t' ?- C9 q6 F
    366 u$ Q6 M5 t$ ~# K" A! j) W
    370 H& \/ ?+ i# Y8 @% [
    38# H  y2 E- w4 }9 U2 B$ d' W
    39' |9 H% P! N: c& [7 ^5 P  S
    40
    8 C4 B* Q0 w) T, x  i' Q41
    ( X0 m& B6 e- F0 O428 S. L- l0 G' w8 E+ X; g1 N& i9 D+ V
    43
    % A& X, A6 m% I% e44
    5 \. `( K+ F2 s" s8 T8 H6 {45" A' k# v. d. U
    464 j  L' [2 @+ b" y+ C
    47: l/ ~; p( L/ B/ I, s7 Y4 }" ?
    480 d2 U/ ?% }3 {, h7 ~. P* V
    49
    9 f- N1 ~  A5 p7 I5 n9 ]0 p50
    8 M, r& [) Y- B" k# h  ^51
    9 i9 |* x! M) w" F$ W: k4 f52
    + a5 u- B% c+ Q3 e8 P53
    * ~' L0 i. U' {- o" }6 ^3 |/ J3 Y54
    ' j9 C) Y2 v6 v* I$ \<AxesSubplot:>: ^7 }5 ?: q4 E2 D- a
    1
    9 q2 q. z; {# w% _
    ) A' M: i& X7 _/ {% G上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确( {! j, z. G7 o' e* m/ k) U
    . a; C. b) ]4 ^$ X9 C: |5 u3 s
    img.shape
    - A$ _$ e) Y+ ~' o! v6 O2 e1
    - @4 ?6 [4 g: W* x% n- ](3, 224, 224)
    % n0 S* |' l, o( r2 d6 H8 R1
      A; E3 E7 \! b, E+ e# k2 v& H证明了通道提前了,而且大小没改变4 s- Z1 W2 |) K; d3 q* E( P, q0 d* B
    & |  y% x. p$ q
    9. 推理
    . ~" U1 q# ^. }; himg.shape
    $ Q0 o, t; j( p3 b) ]
    " b# |" c! Q3 E' p) i, G3 M# 得到一个batch的测试数据# z9 _2 {8 Z5 B7 \6 ^4 n' l
    dataiter = iter(dataloaders['valid'])7 C; ~7 T2 V* a$ K
    images, labels = dataiter.next()
    1 c% Y6 O+ D/ f# ]9 g6 B/ Y; `4 g8 @
    model_ft.eval()0 |6 c+ M7 j' V, @! Y

    . d$ F3 i: ~9 u9 Jif train_on_gpu:
    $ J' x& y# F+ E/ l% a3 K    # 前向传播跑一次会得到output
    . W2 F$ U- e  {* v    output = model_ft(images.cuda())
    # X" E6 S+ S) A8 helse:
    * v, U) b2 T/ P1 ^% d$ F' p    output = model_ft(images)
    6 f( n3 k2 r. R5 N6 `/ @" i) Y1 l7 j- V
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值5 T# u3 O5 F0 L! ~! k+ j$ `2 j
    output.shape
    # j) O  i0 N# A, N5 o+ n
    8 S# p" {6 C9 D0 b# |/ K1
    7 B; b4 W& O3 x& T2
    + F2 i( R$ A( A3
    2 i& q9 Z) W( o! f1 S/ L: Q( ?4" k3 \" V+ s8 S& M
    5
    ' f: @! B: _  S9 D3 N" I! d: K6. G) ]3 ~; {" W! @2 J0 n% R
    7: ^: {. j6 V  k
    8
    4 r; ~+ f9 Y; q6 N7 O99 x7 P4 S/ A, a3 \4 `
    105 z8 t- q, q2 `. T& |  I+ D( Z
    11; E: t% o8 O3 w8 ]% K; ~
    12# L( q% d& l9 e( C1 i- r) z  f5 M
    133 I% y# X, _( \. l
    14$ |+ r; X8 K0 e0 u/ t' p
    15
    & p3 A; ^. A" D" b( e* F5 K, |16
    2 W% r* P) X4 dtorch.Size([8, 102])
    7 f* T0 w' b1 q$ m6 z/ h1
    - [9 o4 u7 D) q$ g1 U* P; C, J9.1 计算得到最大概率
    * e; S7 w! {2 c1 __, preds_tensor = torch.max(output, 1)4 n0 q. M! x8 M. {8 k( A
    - k& e3 D7 l" K. U
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量3 V) }( i- B+ S8 ^" c& C
    1
    : ~$ ^! Q7 s% ^9 C) _! t( W4 Y2# a% D+ Q5 i! d
    3
    3 b1 r: a+ T: E, \2 @# L) V9.2 展示预测结果0 G9 ?2 @5 E  x2 D
    fig = plt.figure(figsize = (20, 20))2 G0 ~& v/ }1 L% Z, _/ p1 J
    columns = 4$ V4 W; `/ L2 l2 i- ]
    rows = 2
    # g( J( M; s8 O
    ) d$ ~) U8 p6 j9 _7 C) u# k- a! lfor idx in range(columns * rows):" Q5 ]6 D/ C& F' Q# O" `
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])4 e: }- E+ K' S) I, I6 |5 z
        plt.imshow(im_convert(images[idx]))
    ! G& [4 g6 X# C    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), ) v% H0 P2 l2 q8 a# F
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))% v7 B; W, b4 X: k' S
    plt.show(); c9 h( k9 Z7 F3 [
    # 绿色的表示预测是对的,红色表示预测错了' E* x- D+ j5 V2 ^, J, o  R* v
    1
    4 a, m: j& Y9 g3 j2
    2 i5 z, W! ]5 U& {# t! h/ I3
    / f% g1 V+ m$ x+ R' l0 \# }4
    * o5 B, a4 {# m9 y52 @4 K) h, Y* [" H
    64 o3 L& K" U: G/ ?& [) Q
    7% T% A5 ?9 x3 d" }
    8/ T0 \' m5 x1 O1 o7 ?2 }/ j
    9
    # F% q" h& u$ X4 @! Y$ Z5 A& e10
    / Z+ J% a' }% }0 G6 \8 j119 S" ~1 ]' F4 _, D! _: |' Y
    $ R: e9 h! T4 K

    4 J* j6 m; i' ~7 }; [) U* }/ W, X  b/ ^9 {  x& ]- s
    ————————————————+ g7 W- K1 s0 j! S4 x& P3 O# Y
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    , Q# u% v% |, t原文链接:https://blog.csdn.net/LeungSr/article/details/126747940$ y  ]- z6 D7 h

    + w' ^: O$ S% _, _0 d6 Y  w  L. @  Y7 W. m% o5 E. u* ]
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-6-14 10:21 , Processed in 0.485583 second(s), 52 queries .

    回顶部