QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2720|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例; r7 i. P1 z4 _% W
    4 Q/ d7 B- p' A4 c  C1 `
    文章目录
    4 X5 Z, T9 w+ A' n- }卷积网络实战 对花进行分类: B6 w  e0 k1 @& p0 M3 j
    数据预处理部分+ I6 \( O5 j' b  g% T! B
    网络模块设置/ |6 N* I; {( V1 _4 X5 F. s7 ^: r
    网络模型的保存与测试
    4 F) s1 [6 y2 G- R" \数据下载:9 K* V3 O# N8 h) F# Z. t
    1. 导入工具包
    6 p# ]& f1 o; T3 }6 J) p3 o2. 数据预处理与操作* P# O* C6 @: t( N
    3. 制作好数据源7 @0 n1 G) a" o* r' o0 _! r$ |
    读取标签对应的实际名字
    ! e+ V! d2 W% `* n  |, {4.展示一下数据: ^) H7 N- s4 L) ?+ ?; S9 P
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数" C2 R6 u! z1 F  K8 X
    6.初始化模型架构8 F: W8 }, d% k
    7. 设置需要训练的参数& l$ i% N; V+ m% ~
    7. 训练与预测; @! C: i- _" f1 N, p' l! w
    7.1 优化器设置' V1 F- \( ]$ ^! ?' D, ~3 Y
    7.2 开始训练模型- `+ U) P, p; l8 N4 F( X
    7.3 训练所有层
    , F3 |7 p2 T0 T. ^+ q开始训练, f$ h, h, U& u) m9 w. p
    8. 加载已经训练的模型3 q. {5 \0 G- C8 Y: x- C
    9. 推理- q$ \% i4 e( E* O  M1 C
    9.1 计算得到最大概率
    % p' G& e6 ?5 j! M1 Q  L: f9.2 展示预测结果
    ) M8 V4 ~  }  j% k写在最后9 I9 q4 H) h/ d' L- G; ?- o
    卷积网络实战 对花进行分类
    : v" g" O# k7 p# e/ f- S' [$ G3 _本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    + Y: s- O! I/ F2 i  X* u' O3 W  `- A
    9 ?/ r4 b* e) x$ e/ ?1 U在文件夹中有102种花,我们主要要对这些花进行分类任务
    8 Y/ t+ q5 \- ?' l# {文件夹结构! B1 \  t4 m# x7 l/ }" k
    + x4 R- F/ X4 {0 A2 l3 c; o
    flower_data: Z" p3 w' f( C7 e9 f" S+ ^' p8 u
    % p" s2 i( i( n+ k1 L
    train
    9 s5 ]3 j5 w! c1 |* ]6 c! V0 J5 s- g& f) ]+ \5 l0 Q' Y; G3 l
    1(类别)
    . c9 @; c; X% r$ H2 @$ ^% ?& ^) g2
    * M. ~  |2 }0 r2 S7 e1 ]xxx.png / xxx.jpg4 y. K3 D' G) r# ^% X& w
    valid
    $ ~, i6 {  z5 x. R' U
    0 w# V+ W7 T% ^9 i1 @! }  j主要分为以下几个大模块3 w  L% z; V  p' y) z' J7 ^

    " \8 m' p/ x6 L+ l+ n数据预处理部分
    9 i0 W8 |% ~* }数据增强
    ; T! Y, y6 P5 Y( T. P/ a0 ]数据预处理/ Q! @- I9 H4 V
    网络模块设置3 ?1 F. C( o8 W! t. {3 [
    加载预训练模型,直接调用torchVision的经典网络架构2 p6 k$ L& a* d# G0 v
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务* i5 ~4 E( A. ~$ T9 o
    网络模型的保存与测试* d4 u& ~! {  m2 b9 ?3 @: C1 q
    模型保存可以带有选择性
    8 o. y8 }- ]8 Q" q6 ?数据下载:
    " M8 [# @% |% T5 r- Ehttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
    5 p5 N' @! A: x2 x+ ]
    : O" ^; p( b3 _9 h8 t  ?$ s改一下文件名,然后将它放到同一根目录就可以了/ A: N. Q0 d. v! e4 B2 F  M5 K1 |
    4 N5 I. L3 C+ d% m/ ?2 `' P3 ~3 N
    下面是我的数据根目录7 l( e3 E: e9 M6 C9 t3 n" ]/ @7 @

    ; m9 H) K) Y% E. j  C) @* I# D1 x4 |  d9 s3 B* {( f, ~+ u
    1. 导入工具包, Y; w' N; o0 D: J- i
    import os
    0 g9 I3 H+ v  [0 n" Ximport matplotlib.pyplot as plt" x2 t% @( `8 V: C
    # 内嵌入绘图简去show的句柄  G" m- I  G: l  f/ x
    %matplotlib inline
    / W" `. j' [2 d+ Gimport numpy as np2 ~4 z9 |* Q2 n& C
    import torch
    ! z2 e$ d, e8 E! A5 jfrom torch import nn
    ; W; C0 {/ h' Z- c' \2 T& ?  x
    ! q7 U! {/ o& n* R( k" r3 }import torch.optim as optim
    ( m) ?* G, Z: P6 \9 uimport torchvision$ y+ ~8 Z3 |% }
    from torchvision import transforms, models, datasets3 W7 ~+ s6 n8 ?& @4 m( s
    ; p* o+ q8 m$ V4 t/ ?
    import imageio; J4 N3 {/ a& F* x$ V: w2 l) D
    import time; c( B" k: x/ \+ E- e6 Z4 l
    import warnings5 L, k4 Y, n4 A  N4 q) _3 f8 E1 p
    import random5 y3 g! p- p  g7 `2 J
    import sys
    5 {$ d3 |& F/ \5 jimport copy! l, S: h* k6 h7 S& D; B
    import json" ]4 A3 w; _! p/ w1 f/ U  P
    from PIL import Image; g" o$ j* K2 ^! \5 J

    8 E% P7 \) y( i2 ^$ R1 z; }% y
    0 M, [: x, |: O, M! v& G8 n$ `- ]1
    8 H1 u# d# c9 d- S7 l0 ?" Y. |2
    / w+ i9 m9 N$ a5 ]9 r1 @; \3* k# A. ?- ?4 }$ i
    40 F, ]7 `4 g- \. V
    5
    9 J7 F  F5 R  S! j% S2 }2 `6
    6 p: Y7 c$ g$ _9 E/ S" q- t' ], }76 B# ]2 ]  a5 m1 `
    89 ^# |6 S& w7 d. E$ Y4 r9 I- U
    9& `# C. E8 c6 o
    10
    . B: M7 Z. Y* T/ O; M3 T. v114 B+ t% ~+ V9 s! \2 |$ E  @
    12
    ) m6 J& {9 F; I/ |- `9 [, K13
    + M% R' D4 A2 t: Y1 p1 q# L9 R14
    & y: N7 C. f, o" A3 O8 O- g9 Q6 k# z15
    / F* Z4 I" Y# R8 j1 o* u5 I16
    6 r: y& g6 i5 |" d( F: ?17
    9 }8 t. {  v$ t* t9 q18
    9 u) t" D; |! ?* x8 D9 n19
    ; {$ [( i3 z5 V9 M7 |20
    3 ]6 Q" h) t$ a( q) V. A- r21
    ( \* |" ?+ S; l( ^5 ?8 X2. 数据预处理与操作
    ( P$ I# z% E' S$ F- ^#路径设置
    7 l& z: w0 R/ g+ C2 W  zdata_dir = './flower_data/' # 当前文件夹下的flowerdata目录) p, d& H9 p7 B& j. m
    train_dir = data_dir + '/train'
    * h$ Z9 x0 J$ {, w: Yvalid_dir = data_dir + '/valid'
    ( r( w. n% {* r: R5 k1 A1( r+ c# {0 ]9 d2 b! T! t0 D% e' y
    2
    4 h. ]* L( M$ I+ N3
    ; @, w3 F& I  Y% {4
    ' W2 y/ c9 N' ^python目录点杠的组合与区别
    % e' B% D7 a& N4 _3 Q注: 里面注明了点杠和斜杠的操作1 M) Y# `. {. B, ^* ]! _" h5 P5 J
    , j# R: d* g+ _5 I8 N6 P
    3. 制作好数据源
    ) ^  u" k8 ^; V* E& ldata_transforms中制定了所有图像预处理的操作2 Q0 m$ U; E% o
    ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    % T# K- e" `" z% Z: [* A1 jdata_transforms = {
    7 n6 H' g+ c; k5 `    # 分成两部分,一部分是训练, S8 \2 v: @8 [# O* _/ c
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间# H( B% g! O0 U- o. w9 B
                                     transforms.CenterCrop(224), # 从中心处开始裁剪; I  a+ g4 V( ~4 q2 j+ E- ]' L7 [
                                     # 以某个随机的概率决定是否翻转 55开) a2 P( t2 x" P' {% O% t
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转7 z3 ~, i4 _1 _6 w
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    / Q5 P( F  s6 T$ u" I7 I                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    % @6 s- t( p0 c3 S: Y0 D+ ~' g                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),) G: g: e) o' h5 F: ?
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    & |+ l, i: L# I2 u; F6 r                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    , ?/ C, m$ i: u% x- ^0 ]: v0 U                                 transforms.ToTensor(),
    7 X1 a- B' V2 M. c                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    5 {9 y4 r1 B! I( e/ t                                ]),$ G$ R( b. a% p2 N  B  n9 a
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    $ h6 l5 F9 m; w+ |4 F0 i    'valid': transforms.Compose([transforms.Resize(256),* r! |5 t% |1 Q: \: Q+ z
                                     transforms.CenterCrop(224),
    . m* R4 j+ b% x/ k                                 transforms.ToTensor(),
    + S, O4 _8 {+ j2 Z' u# r+ ]. b5 y                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同: N. P7 R( @" N' |6 J
                                    ]),8 \6 k1 M+ s3 P+ ?* l
    }* {' \  |" z, \' L. i8 F1 x
    8 w7 {* v( n$ `, F# ]
    1& l$ Z/ }$ K0 S* H$ B7 U
    2
    - ?. j% J, P/ b$ W! z3
    6 S8 y/ f- p! r; r" F4# M! s; q6 L' k
    5
    : U  o: g& ]7 m, \$ U6
    + L3 T: ~8 c0 J. F! I5 c7
    4 r6 j' |( V$ p( p8! U& k8 f$ ^: [' N7 D
    9
    2 m, ?- t' Z* D2 T) [1 f6 j10- n( r: D5 T  _9 _) S
    11
    4 P. g; y5 I5 o( A12' f# d; z% C" b% [! F: ~% m
    13
    $ T% m# p) u1 l3 t- n$ q9 |  B14% T$ m# M: r0 D  v" O! Y- d
    15. Q( K' t( }% b4 m  q
    16
    $ e/ G% c9 n9 D  P% r17
    7 X$ W3 Y  q! R183 X8 A. x, K; O" B) P
    19
    & Y5 y0 d, y" `5 y5 U20
    7 t7 }) D% {  X21
    , Q" {7 Q  W: p' N0 M. Hbatch_size = 8
    $ [! J; p+ W- `. zimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}/ c6 A! M8 b0 b7 M9 u. }
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    7 l* j, @4 H3 e2 r* v3 X# tdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} 0 d; |. @  x+ X3 z7 H" t7 Q
    class_names = image_datasets['train'].classes. i1 M7 c4 C! X/ Q" d; D& t

    " W" N! x0 `9 ?# Y4 P#查看数据集合# \6 k# a4 [4 J1 L
    image_datasets
    0 J1 o  [: A$ Q5 s% G0 T
    5 F0 I& z2 D8 k5 k9 |1; Y. e2 B0 v0 K7 e# i0 q: l
    2
    % ?5 q4 F. H0 I+ @3 n  P0 m  S3/ l. @$ P8 ~# @0 {! |3 x6 C4 p) d. q
    4
    & Y8 [" m7 `+ T# n/ f- V& _5
    ! Q! i3 w: h' ]3 J9 H5 t7 K6
    1 ?; Q* q+ ~& F# [% y2 r7 T7& x& L5 o3 l, H8 q+ m" h
    8
    * x. F6 n1 \' f8 S! q7 e$ ?% g) \96 d' @" ~, a7 e/ Y& k/ L- O
    {'train': Dataset ImageFolder
    0 K7 p* C9 J: C0 ^1 [     Number of datapoints: 6552& U6 c2 h5 l) d# S' G8 j$ u
         Root location: ./flower_data/train4 u5 a  S5 X" \, t  [* K
         StandardTransform! c- x  P( {; Y9 G5 v- ~
    Transform: Compose(
    8 a& l7 z6 p( W9 L2 f                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0); t$ T  G; y& e2 O
                    CenterCrop(size=(224, 224))  c' h+ \+ ~: J6 Y
                    RandomHorizontalFlip(p=0.5)
    " B- C: s6 @, T! d9 e                RandomVerticalFlip(p=0.5)* m# ^; G0 s8 g
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])- }  F4 N: _  u3 w# q  m* Q
                    RandomGrayscale(p=0.025)9 z5 Y0 K/ R8 X% l4 d0 `. A4 J7 n
                    ToTensor()
    1 X! b: m+ B8 Z2 U6 l2 ]3 g                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), I' g* V5 a$ ]( D( }+ ]4 m* y
                ),
    9 D- h7 ]. G6 R8 W5 j+ u: q4 E$ f' a8 j 'valid': Dataset ImageFolder1 x. B: n' G* V) x5 }/ K. @
         Number of datapoints: 818
    $ s% J) L" @( r* o) h- ]' w3 r% x+ Y2 U     Root location: ./flower_data/valid. X! _3 T+ _& k8 W% S! Q7 q5 ]3 ~6 P
         StandardTransform
    - L. D, T" l7 U; p0 `+ q  Z Transform: Compose(. p$ r8 s' Y8 R1 S) R/ {
                    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)% m& Z1 S5 E0 h& p; W( P: D! O- [+ w
                    CenterCrop(size=(224, 224))
      F3 L0 {/ `. {- x9 j0 R7 V: _- M6 n                ToTensor()2 V% c4 j$ U* B- a9 _) g
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    % I; E* v0 y0 Z9 l+ n7 i            )}
    ( O) v! Y* \. L# Q, X3 j- j( {+ R0 R# r6 h
    1
    ; f' X. J! }; A  X; _2
    2 t5 r  S/ u7 a39 s# ^# ]4 Z9 I' x  Q' J3 q0 `
    44 J# U5 [6 l2 Z
    57 y+ b" h' p- P6 z8 a9 |! n" x0 ~
    6- p( D; Z- V) z
    75 \/ H+ q' ^, _8 _3 Y0 w, F0 A! ^/ b
    8
    & m; {, E. K) C1 \3 S4 o9 g7 S96 l4 T0 ~/ R6 ^! u) R$ R" s7 e
    10
    + R2 o' q% ?% M0 q3 I) V11! }' b3 K" t! U! C1 \1 g3 P" h
    12% {2 n8 Q! U; ~( {
    13; k2 K2 R. ]1 j
    142 w" v- E  \+ ^+ D
    15
    " h- p( k" M* J! C+ L: o5 Y16
    3 ]& ?  N% W$ L, v17
    : {* A" N% \) Y, M+ F5 k18" H% a$ t6 i/ ?" W0 P
    196 V; e: L& Q0 A- N0 K* e3 @
    20- ~4 R  h* U1 N- X$ Z+ ~; c0 ~
    21* f4 o0 Z. \6 M1 g$ V4 J! j
    220 @' n* Q3 M6 A  S1 t  }: U# _
    236 D( l7 l; S9 e  y: D4 f8 V
    24
    " a- I9 H0 @+ F8 n$ N6 Q# 验证一下数据是否已经被处理完毕4 d# b9 X: T  C8 ^, ?
    dataloaders
    5 F4 r/ N0 r: G/ B0 }5 I# p1
    8 l' d, K, N  z6 K2
    1 q3 Y2 H' k( f+ O! C{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,) r( D7 R* C- c. R5 T" X4 O
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    + v& K% o) t  p+ A; e18 c: Y/ b# R. ?+ i* Y+ x  A8 A
    2
    ( ?( k6 \6 z+ ]4 ]" L" S! J- r# }' A6 Q0 Edataset_sizes8 }; u; |7 f2 `+ ~6 n/ R
    12 Z1 Z! G5 p/ c& I1 N: x& R6 G9 Y( O7 J
    {'train': 6552, 'valid': 818}
    * R0 _; ]& W. q4 U0 N1
    0 _# B! c. O8 _( X" N: E0 y5 o读取标签对应的实际名字
    ! [8 t- q) w! V5 b6 v. k) g; T使用同一目录下的json文件,反向映射出花对应的名字
    * U" A4 m" V4 y+ Q& {
    $ [9 u* k/ N" x) G6 X. V7 Vwith open('./flower_data/cat_to_name.json', 'r') as f:' j7 C' f8 D% G5 p4 ^% p
        cat_to_name = json.load(f)
    , `$ M  y) y# K( f: e' o# G11 H4 A1 V. R7 A/ S! i3 @
    2$ J" Y3 n! t! ]0 D" D
    cat_to_name
    9 r6 d( _+ J+ n! y( g" N. T1
      x( W" K+ x. ?/ Q{'21': 'fire lily',
    ) R6 D) U, G+ l* P '3': 'canterbury bells',3 ?; f& E/ Y3 D
    '45': 'bolero deep blue',# }6 W' H6 P9 g  U( O: n
    '1': 'pink primrose',3 S  L9 O( g! J: ?# y
    '34': 'mexican aster',& |2 D) G: U% m& D
    '27': 'prince of wales feathers',
    , p; e: A, y2 C '7': 'moon orchid',1 m0 y1 F* V/ ?8 s& k6 p
    '16': 'globe-flower',
    0 [! ~1 w3 A; \7 k4 U4 O$ O. q& ] '25': 'grape hyacinth',
    7 j0 X. r' p* Q2 w$ _8 u2 z& h& | '26': 'corn poppy',
    8 }. W+ ]0 @2 e+ i5 a/ D( [$ [9 ^2 s '79': 'toad lily',7 C( o! [' `% r! }+ ~7 S8 q
    '39': 'siam tulip',
    6 z/ f2 p2 N# Q '24': 'red ginger',# p1 h5 d0 P8 G  `
    '67': 'spring crocus',
    : ~! M. R* I# r '35': 'alpine sea holly',
    ' ?+ [0 F' E' ]! e" w/ ~1 \5 f '32': 'garden phlox',' M3 e& i( K) S/ h; E6 y3 A4 u" G
    '10': 'globe thistle'," p/ J  n1 R8 D! Y6 k/ N. X3 E
    '6': 'tiger lily',2 j; C  \# Y( Z8 ^# L% L
    '93': 'ball moss',9 {: |, `! _# N7 i8 O
    '33': 'love in the mist',. j* C! s. Q- A* L
    '9': 'monkshood',
    $ z  |- O+ z4 B '102': 'blackberry lily',
    ' j! f  M6 W; d/ X% \. J( A '14': 'spear thistle',
    $ ~% u! ]7 Q  z$ F4 I+ Z '19': 'balloon flower',8 ]! @% b7 K# O  v6 O2 E! F. D( g! t
    '100': 'blanket flower',
    8 O, s0 d3 m. u+ F3 O/ O3 d '13': 'king protea',: D' ?9 I3 a  f. K4 f
    '49': 'oxeye daisy',
    * k" g4 q& ~4 F  _/ K2 J( Q9 J '15': 'yellow iris',8 z1 \$ B! h3 T0 c7 B4 m6 X
    '61': 'cautleya spicata',
    ) Y  U( k: U1 A8 G( V  }- t, A '31': 'carnation',. Q& ?, |2 {$ Q/ F5 Z% B
    '64': 'silverbush',0 U& n* A( d7 R" j8 b
    '68': 'bearded iris',0 |6 J9 h( P* i- x  M! P" G
    '63': 'black-eyed susan',5 G2 _! p2 W5 ]4 }
    '69': 'windflower',
    % q( @+ \6 u" q '62': 'japanese anemone',0 Q6 K+ r, q/ ?9 X& U+ _; B! ~
    '20': 'giant white arum lily',0 A$ v5 O5 d% ^5 q5 q' h1 J5 J
    '38': 'great masterwort',6 S3 m! b$ P0 {0 |/ e
    '4': 'sweet pea',5 a. x% k% c+ F! d! b2 Y
    '86': 'tree mallow',
    % R$ h& L4 y% U/ x; x9 o '101': 'trumpet creeper',4 Q; Z/ T8 S) @1 z: W2 k8 G, d6 o
    '42': 'daffodil',/ p. W8 H% f5 Z( E0 J
    '22': 'pincushion flower',
    2 L0 `" ^8 F) U( a/ Y/ {+ D '2': 'hard-leaved pocket orchid',
    * O/ n; [' }! j$ `# w3 O7 n '54': 'sunflower',
    * W0 K  I8 Y$ b$ B# n '66': 'osteospermum',( M. }3 @; V* s8 d
    '70': 'tree poppy',; L$ [( h& k) ]. ?  g! V- W( P& M
    '85': 'desert-rose',* z+ O- T% o4 i( A$ [
    '99': 'bromelia',
    9 w7 e7 a' A' E+ L6 F9 i '87': 'magnolia',
    # t8 v  s1 o& `& i8 ?0 O4 }: a: H, u '5': 'english marigold',
    0 M+ b3 a- r1 u" `6 \9 ` '92': 'bee balm',
    ( P3 I( [' z% T* E0 b$ d; D" ? '28': 'stemless gentian',' |6 p8 M1 a/ l) ?, O
    '97': 'mallow',8 [" ^& s. r9 R- k: i3 m( e& Q; y
    '57': 'gaura',; E4 f+ l  V) @) j; U0 ]7 f+ |
    '40': 'lenten rose',7 b' F2 v: O, h. D! {7 E1 h3 U3 E: g
    '47': 'marigold',
    - F( r- R" `0 S, y) p0 S" w4 G '59': 'orange dahlia',
    1 |( `$ ]; g! M4 C '48': 'buttercup',1 [1 w! W$ J; o
    '55': 'pelargonium',. l  \$ m) r, |2 e
    '36': 'ruby-lipped cattleya',5 G6 r1 r' n/ M/ p! \' y; h
    '91': 'hippeastrum',; G8 t7 D+ |5 ?. j0 a% x0 O7 i
    '29': 'artichoke',$ x+ `6 P1 Q4 `' f1 j( {3 E
    '71': 'gazania',
    5 m& z& M* T' A9 ]) g9 C9 T4 O '90': 'canna lily',
    + Z) C/ m" o- N2 A( [4 H$ U '18': 'peruvian lily',, X7 n3 G; \# e1 E  ~
    '98': 'mexican petunia',
    + o3 E$ }* M0 @( s- A0 Z) I '8': 'bird of paradise',0 U/ y0 Z, W7 ~( J
    '30': 'sweet william',
    ; D& U  k& D) K% M3 d  M '17': 'purple coneflower',
    . D: x4 W- s7 P+ N) K '52': 'wild pansy',
    ! c4 d- e0 \1 \1 z '84': 'columbine',
    , ]: o% y4 \; C( ]: q5 L3 q '12': "colt's foot",' ^) G; @9 B/ J8 F+ n
    '11': 'snapdragon',
    : G1 e8 Q; h: Y/ Q5 ` '96': 'camellia',% u1 e) x& U1 s  y+ g2 f; Y" B' W
    '23': 'fritillary',
    9 p6 H5 }! P, y, \8 _' K, w8 y '50': 'common dandelion',
    4 ]( N& I/ E; H" J6 L '44': 'poinsettia',
    ; g$ X' Y! S8 d& f( J, C '53': 'primula',( \2 n+ Y2 |5 g! F
    '72': 'azalea',2 a5 }  W# u. L& e; Q+ Q; z( }! e
    '65': 'californian poppy',
    3 ?9 X2 Z) _# w! h' X0 t8 g. A5 ] '80': 'anthurium',
      t' h: i9 ?. ~: i+ F; x  r0 G '76': 'morning glory',5 ~2 k$ ~4 C8 V! K. T9 U& k
    '37': 'cape flower',3 I6 c% g6 r8 d/ m3 N
    '56': 'bishop of llandaff',! x2 u! _/ O- q6 o' l' }, w
    '60': 'pink-yellow dahlia',# Y7 }+ E0 j* a
    '82': 'clematis',
    & @  ^/ ^/ R1 Y- Z, r '58': 'geranium',3 D  I, u5 H0 D. B/ o* {
    '75': 'thorn apple',
    # R( K% V& _6 O '41': 'barbeton daisy',6 Q4 ]. A( p% m8 v
    '95': 'bougainvillea',* c6 v( R& i0 c! I) }: C7 B
    '43': 'sword lily',
    . R3 b9 ^- s9 x& E '83': 'hibiscus',9 c7 U7 ^1 Q6 ^0 O
    '78': 'lotus lotus',) a9 Z  q* D1 i; D/ S' R
    '88': 'cyclamen',9 {' g1 w$ `! M% V
    '94': 'foxglove',/ h. X# W" `/ j
    '81': 'frangipani',' R( h/ [2 K) @7 B2 m
    '74': 'rose',4 D; k8 b+ r" N- {- F" d% t
    '89': 'watercress',
    : F* r: X9 \$ E' f; q '73': 'water lily',
    * w# w! c& P( I+ p- b- i '46': 'wallflower',
    ; Q5 _( a' r! z( {  D% y- ]2 H '77': 'passion flower',8 u5 W4 Q+ h1 C1 y1 Y- d! G
    '51': 'petunia'}
    3 x" Y  x' \  k2 p
    # x2 S6 U" R9 b14 s; s" _' X9 ^# e
    2
    4 K# T2 F( Q8 ]. r# R+ [7 Q3
    ) m$ _8 j! \; b1 u4
    . z# Q% a2 J( `5
    ( x' D& n0 H+ x/ y# E: N5 N1 n5 T6" s6 i) R9 \( |5 q
    7
    6 [: h! {2 R6 E9 q88 x3 U& h; o% Y8 \- A: s3 |
    9
    3 o3 F/ c& \1 L# y9 Q9 d10, z3 z. z4 Q; s2 H5 v/ m
    11- [. U% S; C# [7 c% H" p# h: o- X
    12
    % V* Z$ r; D7 v& N! v0 x132 ?5 [) T6 `9 T2 D2 Y9 p2 Y: h; }
    14& ~. M4 N2 L. B; n5 i
    15
    ' N+ R6 P$ Y) M9 |6 x9 d16  a  h* w% ?. L7 W3 T3 x4 X3 V
    17
    ( m" H9 b& ^  ]& P, p0 ~18
    # {: M. J- O) W- ^3 z  ?5 M5 M$ V19
    1 X0 m! n. j5 j5 ^, Z/ |5 h& n20: N& b1 m6 N4 l% n# q2 r
    21
    . P+ n2 ?) q/ t3 W9 b# U# k226 ?0 Z( j/ s* y* a
    233 Z0 }; F) ], [2 y
    240 j8 M* U/ R4 M9 {' x3 x
    25
    0 a* G0 v2 W: d3 X& G: d1 @267 M1 Y/ o8 W+ p7 o4 x
    27
      c" M' f/ C8 X, o! S! ?% [# m28
    4 Q; s8 i! d8 M' Q( Z7 m295 z! i: k5 p  a; w: U6 w$ j; o
    30
    & @6 R3 ^' f7 @31
    # p$ v2 U5 d" a# b8 n32. d- }) N7 G; w) G4 d) N% N; k
    33
    / Z# O3 S/ E5 ?+ [" a348 W) s8 t' U4 i  k. E" v! k: S9 I- b
    35& x7 ]- ]0 R: }& ~6 O- X
    36
    : ?% `6 {. K  m/ R% a8 t5 d0 L37
    $ N- M( N( |/ |/ [4 v  T5 Z38
    2 J+ V+ U- r8 e0 m6 q- [9 J39
    # R0 J( F# l. O40
    ; J* [4 Q! ]) z( x413 v! p8 q- i1 I* M& q$ h
    427 Z9 W& L- H6 ~( i' \
    43
    $ Z" B; l5 X3 F. b* m44
    - l  L" @# S/ O2 V" I45
    6 v! K& c9 C% f$ R- H, F. T  C46
    9 u# E0 @: v' g+ y! ]) h47
    " H6 b! Z$ y$ O1 d, f. O7 f48. ], f5 E/ g: N5 _3 s# w
    49
    & O& k" r* \  {9 c1 V* ]0 D) Z50
    8 n" X- T" ~$ |3 L' d' n51" Y) k3 T/ \3 s& [2 W5 g' f& u
    52. R" K$ p( j* J$ |3 [6 z% ?
    53+ J2 ?& s* |: X; t8 J
    54
    ' E: M% ?9 J& a: E55  j5 b; N! }( k9 U2 g8 z9 P2 N, O, @
    56
    3 _8 z) M  I) W3 X1 D$ w: A& D57
    * ~8 c' Z/ f( {8 i58: k, n; p1 A) z: L1 ^
    59
    5 @4 M! j" I# f' m* o60) j1 d- S% K" {) U
    61
      Z) A. H1 U) ^: ~$ I0 h62) q: f2 M2 o. z
    63
    * A. Y0 \% T! t0 P6 k64
    1 t- k* j4 ]( \6 y" B* H: R1 `65% j; ]: b0 J. _1 m, S
    66) l# f( o' q* q* j' _
    67
    7 C7 }, k* C6 T. a# P680 Q8 n- b& \  i% d1 v
    69
    ( t- i* ~- o, t  D! x70; y4 n( t, @6 \/ k6 T% G; d
    71
    0 |& b- |6 D' C72
    6 z2 K6 w# c0 x5 d73
    ( S4 t* K6 M) N  V1 I74- P6 i1 B5 u( S/ [
    75
    + P. q/ b% u, f! X6 ~76- x. e- H) d7 i; \' f8 I. w
    774 c: g# n0 f- H2 B) a0 p
    78# u4 x; F+ L) q5 _. j) ]8 L
    79
    / C* Y6 ?# E6 D7 w% ~4 z80
    , O( ~1 g( H$ \! K) f7 `* @81
    ! }/ |& ^/ D" B6 P# J1 k1 S8 c82
      W- }( d8 C5 l- a- a$ q8 H83
    $ D/ f) O! s* l# r6 q; H849 R! n- w2 S- K2 q
    85
      r  ?) x4 I8 S8 w86* b7 f/ F6 k, a$ t7 [8 A
    87
    ' `0 m, g3 F4 {  o# K88$ T$ a/ @$ U# l' p9 O! ^
    89
    # z: J/ `# V' }: c( r* _' R90
    # u1 d5 x, ~2 z6 B$ l' T/ {91! s- i/ h5 F3 X1 L
    92( V# e, R! J; v6 A  m. w
    93! p" A  S' y, ]; W' Y# ^
    94
    3 C; z5 n# z( V' I% n" s( a% U( J95
    5 j2 p8 b6 \1 S5 Y969 k3 G. j# g* o3 c* n+ V4 @9 l4 I
    97
    + F! q$ F' R, b  x+ e( @986 B' l- r$ q: X* p: O, b
    99
    1 E; m8 V+ O& H! a" _1 {100
    " G7 _) r! B! P101# K5 e3 f6 t& v9 p7 q! Q
    102
    0 q  Y7 e- Z: ?) {3 e+ T$ u1 U4.展示一下数据
    ' C+ ]2 ]" [2 X; @+ Sdef im_convert(tensor):5 f7 G3 c( _* ^& n% h  R: p
        """数据展示"""
    8 w" B$ G1 F( [( d" r    image = tensor.to("cpu").clone().detach(): U! P4 v! @. C* @7 ]" E
        image = image.numpy().squeeze()
    : T5 _# ]. \- l- r: S0 [8 |    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图7 I: V8 S1 D) _3 Z# Y- R5 Z
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    1 B2 E& K$ {( k6 f9 d0 B, d3 E    image = image.transpose(1, 2, 0)
    + ^, c/ `! h6 }% U' {2 O+ b    # 反正则化(反标准化); {3 G( u1 A9 y" ^
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)). v: L* a& q7 o* z8 R
      k8 y( K/ i( b1 F. S& `/ C/ m6 M
        # 将图像中小于0 的都换成0,大于的都变成1
    ) c6 u- Q5 c0 x3 y' ]( M    image = image.clip(0, 1)
    ; _) }, v& t9 m# m5 N" T/ k2 D2 @
    ; }6 e% d% }( z; p    return image
    3 g1 \. H+ O$ a. B8 l1
    $ e, Y. x6 X* O. R! a. t2
    ' E) M0 h( P3 z' E2 A2 ~+ D3
    , l" s5 ]8 n& R9 |48 o- O% P) r5 H
    5  r) @# Z) A& G# z+ F5 W! x
    6
    7 `  O5 h' Z% p# J7 s7 S7 D7
    # H% |: g# R+ B; @! @8
    0 g/ N1 |7 ^3 L/ l6 o( r1 B3 T+ X" N3 L90 G1 f8 T$ w5 `, X! w# d
    10
      s- e6 G5 J' H9 D: v6 ^3 j$ ~111 \1 Z+ S5 D" N0 u
    125 ]* N+ Q$ V4 W
    13: C' U8 M; o6 i. s6 S% X
    14
    7 ]2 x3 y2 K$ o: T# 使用上面定义好的类进行画图6 Z) `  D8 I: x5 k8 j
    fig = plt.figure(figsize = (20, 12))
    . R4 S8 R4 d% z3 ?7 F# gcolumns = 45 q5 I2 e/ S5 l7 O- Q8 o
    rows = 2
    $ |' a% W6 x& L! ~1 ?+ Q, {: v3 \0 F& {& C
    # iter迭代器5 e! x  u' j5 t5 e. K0 h/ T4 I
    # 随便找一个Batch数据进行展示
    9 W& [4 c& i+ z! S; }7 c6 Jdataiter = iter(dataloaders['valid'])
    4 U& n/ q3 w( E, K# pinputs, classes = dataiter.next()' c  G( P1 ]. q0 a

    8 ?/ M, b: z' z  ~) [# }0 Vfor idx in range(columns * rows):4 D; \- ^, n6 ~& d7 x
        ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = []): j2 e  E! U; Q( e5 q/ g) b% ~/ I( E
        # 利用json文件将其对应花的类型打印在图片中( Y* b7 Y8 v0 f6 I$ q; k* {# G
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])" N* h6 g& o( c& V& J
        plt.imshow(im_convert(inputs[idx]))' `# L& @, K: V  w
    plt.show()* n9 @7 q0 O8 Z$ _- G: s

    . B0 ~: x& L7 o1 Q# q& K2 g( P1
    % K% J& F5 Q) _3 P9 L; i25 U' M& R& C% z/ p3 ^8 @+ V
    32 l& v. L( ?7 Q, N- f
    4  a0 k3 P7 L* M3 X5 g
    5
    / r. q! U( z0 g/ ?! l5 ?6
    , w) t" J. P" Y4 G* Z- V7$ v7 h1 M9 Q, Q
    8
    2 `8 t. T9 N( j. Y* P9# [! X7 E3 u+ |- Y3 N" @
    10
    8 A# |6 `9 ?; w" {% w% E11! m8 }" ]* R8 e1 t5 b
    12
    * W0 Y0 }' ~) _7 W! V* l13; y2 i% X# }5 T/ V7 O
    14
    2 `3 c2 s2 c. h2 ^8 U15
    ! d' d3 j5 L. i& D) A/ o160 |2 |( j  A2 r4 q- {0 [% ?
    9 _( I5 W) W( \  E" E4 D# n

    ! Y' n& n/ i2 j+ G& k# R5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    & X. r3 [% ]! w4 t; J3 b' J3 Lmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']' q! O& u9 D) F$ ^: z. E0 @9 B
    # 主要的图像识别用resnet来做) V- u  b) @# F& n% P5 }& T/ ~, y6 i
    # 是否用人家训练好的特征
    # a5 i5 B: \5 V& e5 c- Vfeature_extract = True
    - g0 r! v# p7 m1' m, y; ]: k! U% D- i% N5 W
    2
    9 D! U# ~* R* t" p4 C3
    2 Q( ?( u( A$ w4/ L: g. |9 Y7 b, w9 H4 a
    # 是否用GPU进行训练
    8 |: ~, Y& s  k% f+ y7 @: i5 Ktrain_on_gpu = torch.cuda.is_available()# @! b5 x+ ^; w+ }

    ) n9 z/ F+ Z; c! U; ]5 }" nif not train_on_gpu:6 e. f0 U- P" F6 _
        print('CUDA is not available.   Training on CPU ...')( d& c3 e9 O# w
    else:5 S5 o6 l2 m7 F9 f1 ^6 J3 O5 {
        print('CUDA is available! Training on GPU ...')  D7 k9 W/ O8 o2 t: ~5 t

    & L  J1 y; t' w. b+ z+ J' @8 hdevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')% t$ }7 L4 Q" }! |% }! ?
    1
    + U* t8 g: R! U' m  `* B( P2
    1 ]: X, Q5 B3 s8 q5 }- C' s- O' Y3( D2 }+ ~3 w+ k1 X* Q
    4$ H% Y& I/ z, I
    54 s0 f' q% q/ M% d1 B9 D
    6
    % A. X& P5 `% w. X7
    ; N0 v6 D9 e- X1 r8. u  a: j; g, w' h) I) O
    9% D3 U! a2 b1 W+ ^
    CUDA is not available.   Training on CPU ...' B1 \5 p6 s; ?6 y5 w; o' h
    12 w/ F4 T4 ~5 s
    # 将一些层定义为false,使其不自动更新
    " L5 e5 Q) v* h4 m4 idef set_parameter_requires_grad(model, feature_extracting):5 h3 S, i  I- b/ h2 r* d" A4 b& M: M
        if feature_extracting:
    4 A) }! Y/ z. V- j# Z( H! w# `        for param in model.parameters():: r7 Y& t4 F9 v: [" P
                param.requires_grad = False
    + k; v# a. k% t. |* t/ j" l1
    - k( F" z; a$ n9 s' c. V6 A2; p! Y# R* M4 G1 a1 P
    3
    + V$ _& O9 `+ {3 A) e. E4
    " h% Q# l3 A7 ]* n5
    & j8 `$ ~+ m0 H' `2 R# 打印模型架构告知是怎么一步一步去完成的- U% T; v' K& _3 X% v5 x, N0 k& n
    # 主要是为我们提取特征的* I' T( L, m$ s4 Z" O* w" }

    + g! H( |! w" d0 k  x7 y2 ~* ?; N# [& fmodel_ft = models.resnet152()  {, U# |7 ?) \
    model_ft
    - U* o0 X! V/ d. P# C9 ^1 }# H+ I1
    , ?* T1 g  X6 d4 v! I7 A2 G2& f/ m8 i. g0 g# L, n& F2 d8 E
    3
    % L; B7 Y4 i3 ?% V* `5 z40 y* y. ~/ N4 S7 M. @' C9 e3 u
    5
    . x  B. ^, t( t$ t) `ResNet(
    ! K- P; `) Q* T( \7 G  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)! v  Z7 \+ ]- E# C8 l
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    " `! h. e( K& Y& t; }+ N: B, k  (relu): ReLU(inplace=True)
    " U8 @8 J) J2 |" F* M& u; M7 M: t  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    5 b2 F7 L8 ?9 W+ G  Z- z  (layer1): Sequential(
    " v$ I/ G- ]& W" h- N& P3 b    (0): Bottleneck(! t6 ]/ v3 o  j& h  Z! K1 O
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)+ s9 A- N" ^" C( t" M( R
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ( q1 G; |6 a- h1 L3 V8 L' g0 c. J      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)$ b+ A' @9 I2 M% p; c
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    5 k$ f) Q! a7 I0 K) y1 K6 Y/ f3 H      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)2 S; v7 U, C+ E: G( d$ H% A  J
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    - ^3 R: u- b1 l3 j      (relu): ReLU(inplace=True)
    , p9 `* o' t  _: B4 I      (downsample): Sequential(1 P* A, l4 K- k2 ?
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)) k) v% {1 d' `/ u4 J8 j
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    $ R! I4 i  C9 \" x) E      )
    ( ^7 @% V, e" C: h; N9 t; I1 }    )0 @3 q4 V/ z2 {4 w' D0 {
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。# B- x/ o) i. E- C
        (2): Bottleneck(% h6 f5 }" F4 H
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    / R' v* R5 B1 Y( [      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    2 M0 o7 Q1 T: }8 {  E      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)  A5 D3 m& g4 u8 x: G0 s  v
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)5 h# `& i. a/ n$ O- u- @6 f$ r
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    , J7 U9 F1 Z9 l$ P" w. Y4 j      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    # O, L9 S7 l. u7 ]      (relu): ReLU(inplace=True)
    ; k; S7 q+ C# w    )6 Y7 w0 a9 }& h8 T5 `- J
      )" W3 T1 h( u: y9 e" w
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))0 Y2 y) O( s3 @. Z% |/ \' s7 s
      (fc): Linear(in_features=2048, out_features=1000, bias=True)) \  F+ C6 }& P
    )
    - Q+ V) I7 \% j% s! L
    8 c- A) H, @$ y9 z1 u: G& O/ y16 l& P. z5 h- t  Y( c6 a
    2* m# q9 L  _% H* S+ u6 k) E( Q/ S
    3* u6 N- O! r4 B6 S2 e9 d% T
    4
    1 M) [3 o9 W' d1 n5# V; P' Y* g. f9 ^! A/ i; ]7 j
    6' ~0 J! {1 y. ]
    7
    & i% [, d3 \4 |8
    ' E9 K6 X" ?5 _4 h2 [98 X% H( x9 C7 ?$ N3 Y" m
    10$ [6 [$ g& l, [' a  x
    117 o1 S/ d" q" C% T  _
    12
    $ Q- n# v  M( R13. S6 R6 Q# L7 V; l, _& \
    14
    ( D$ r+ X9 r  C8 Y3 L; F15: ?% V/ Q# d/ v8 B5 g
    16
    5 T* u# D5 P; x2 Z17
    / s8 D( ]1 f  W% m* i2 \18
    ' H& n0 L3 U6 a$ w19* W' |. K  f# }2 U/ Y5 e7 I
    20
    + a7 A2 D8 O/ L$ t7 h4 [; h' d21  J5 v: t$ J* y5 p, l3 p; [7 ^" p
    22) A) d; R) m8 S& b+ U1 j
    23
    3 k5 W! W1 p4 c+ k0 `( w+ Z2 J241 y- b0 e9 Q: [$ r. I. m
    25
    5 O, v2 q+ l: ?) S& z  d26
    : F, e$ c& T6 O3 D9 r* s- `# D6 H27
    7 p8 L$ n) l2 V8 {( j# `28, O/ O, M* b2 q* v$ {# L5 i/ O9 E
    29
    ) V4 C, m! N6 J4 r30
    & N. F% o" W6 }, S) K# C5 `1 V31
    , Q) O* |7 H6 T: C# |32
    # G% C* ]* o. L" I6 S33
    2 _( m2 f5 n: \) T  Q最后是1000分类,2048输入,分为1000个分类2 ]7 |# l6 y9 |! k8 [9 {
    而我们需要将我们的任务进行调整,将1000分类改为102输出6 x8 q3 F8 @( f! s8 S& U/ D
    " C9 P% u3 v0 d' R
    6.初始化模型架构
    4 P+ {% N4 I  z: r步骤如下:
    8 ?+ ?$ p- L/ k' F! ?' m) g$ A$ x& h" B" F4 h& W
    将训练好的模型拿过来,并pre_train = True 得到他人的权重参数6 I; ^* {3 V( a
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)0 ^2 V% v+ x. i4 Q
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数4 A: R6 D. T& [2 g' l8 r  j. q
    官方文档链接
    - t. N: H# w  A. O2 {* }https://pytorch.org/vision/stable/models.html
    1 u" b4 l3 I6 b: U, M4 Q! C/ a6 ~, i$ T1 [5 Q3 h+ q. |
    # 将他人的模型加载进来
    # s: a3 I- m- X2 S/ _7 \def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):5 J. I8 V) F8 m6 a# D) N
        # 选择适合的模型,不同的模型初始化参数不同
    1 H7 k+ c6 D! p- t8 M    model_ft = None' \" x5 f, m2 n( X6 y
        input_size = 0
    ' b. `; t' z7 x* T4 ]& C
      G& A& W& j/ m0 t0 Y2 ?    if model_name == "resnet":
    " n1 x3 u) _$ Z- l# y6 G        """4 p6 O* \" o) `- X+ Y
            Resnet152
    0 E, I2 j( q% [: T        """. x' h* h% e% h. Y

    " G) F' K# B% @        # 1. 加载与训练网络
    ! y* l$ a9 B( c2 d/ ]* a        model_ft = models.resnet152(pretrained = use_pretrained)3 n. }4 O3 T' L/ w
            # 2. 是否将提取特征的模块冻住,只训练FC层
    # U; d( u. K; n! F# J" L/ I        set_parameter_requires_grad(model_ft, feature_extract)8 p: f- U' Q1 q5 h
            # 3. 获得全连接层输入特征1 }& \% ?; g7 {+ K
            num_frts = model_ft.fc.in_features; [) n% J% t5 s" y2 |
            # 4. 重新加载全连接层,设置输出102/ E$ D6 ]2 _+ b0 u& r( p
            model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
      |* R3 e  k( b1 c' M                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1- {; X' J4 T0 ~4 t; K1 a2 a
            input_size = 224& F& F- o+ [; K* O' a- J$ ^# U0 s2 c$ ?" {

    : q+ c0 r4 D& X    elif model_name == "alexnet":0 ^9 e$ \; s* i; M1 Z
            """
    6 f7 {0 G; B9 V' w4 D7 }2 A8 Z        Alexnet" M% x( T) c  J, ?7 u. s, |
            """
    6 B: B; j' }) x: q2 T        model_ft = models.alexnet(pretrained = use_pretrained)+ @' }, D1 G% K: O: @, c% C
            set_parameter_requires_grad(model_ft, feature_extract)
    2 p5 i1 P: }" N( Y
    ) y& q7 S# ~$ q0 z        # 将最后一个特征输出替换 序号为【6】的分类器  o9 u9 f( w- s& g: D% U
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入+ ^& A& L8 L) @; P
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)9 t& Y: B6 U# z5 @
            input_size = 224% h3 [! {" r0 T/ J7 G
    * O  u8 p9 L$ @9 H0 t
        elif model_name == "vgg":/ W+ a. S$ m$ [0 {' c; p: }
            """. h3 A9 }' d* K0 M8 O: v& k# Q
            VGG11_bn
    ; W: N' ~( ?5 y3 n! o' B        """5 G6 M6 {" m& K9 r& H' h1 k
            model_ft = models.vgg16(pretrained = use_pretrained)
    ; P( u. [0 [' E& s5 I9 n        set_parameter_requires_grad(model_ft, feature_extract)
    7 Y& S& ?7 E* A% S% [  h        num_frts = model_ft.classifier[6].in_features
    8 V* Z6 g* F! A5 D3 {8 U5 C; S* }        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    ) [; I0 _# g. |; Z* l* @        input_size = 224
    % J& Q+ U6 M) ^* `' X3 t* r" w$ Q: ^
        elif model_name == "squeezenet":
    & l. I# _$ i2 p# c$ e        """
    ; H$ g  m9 p& \: V1 a5 N2 A3 F        Squeezenet
    / r  w/ C9 t; V6 O3 C8 Z        """
    * t+ H7 m% F8 X        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
    - M5 k! t! n' \% }+ N        set_parameter_requires_grad(model_ft, feature_extract)
    6 J: Q8 W+ t/ z$ w) `1 e+ x2 S: \        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))6 F. [: s5 j% Z7 c
            model_ft.num_classes = num_classes
    5 n% j9 C# p. g5 Y' S- A& b0 o- n        input_size = 2246 j3 j. x  i, @! D6 |( v0 @0 T% p
    . r; I% t  W% q! m' l# X  S
        elif model_name == "densenet":
    * H$ a$ k' I+ I. r5 V: o        """
    : C6 y- D+ [) W& a        Densenet- _; K$ M6 z2 w, n' z, |$ G1 m
            """
    3 p! G0 w( p& L0 l1 R0 b) b        model_ft = models.desenet121(pretrained = use_pretrained)" y' E( A6 Y6 m, K; u: J
            set_parameter_requires_grad(model_ft, feature_extract)( _  `+ }6 y6 F7 {
            num_frts = model_ft.classifier.in_features/ f- f2 l6 A6 z0 L1 k: E4 {
            model_ft.classifier = nn.Linear(num_frts, num_classes)* J; D- ?7 Y+ e- @/ c' _
            input_size = 224
    - {% H0 C+ A( I& v6 u1 v/ d8 y* J, Z9 @8 B2 Y% L
        elif model_name == "inception":/ s2 l# K; f2 v7 B/ H
            """' b4 m/ o8 g- M
            Inception V3  C5 \. R& {& `+ p; {/ N4 b' y, s
            """
    ' X8 u! T& p4 m+ p: s. d& s1 G9 G( j        model_ft = models.inception_V(pretrained = use_pretrained)
    # b) E( Y  L* I# k        set_parameter_requires_grad(model_ft, feature_extract)8 E9 ]8 |" \' {! G
    5 Y, K3 B# D2 U) }! a
            num_frts = model_ft.AuxLogits.fc.in_features
    * e) \9 H8 c0 g+ h( T) F3 g! A        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes); b5 `6 q4 Z1 @+ A5 X
    : \+ c( V" b0 N" m# q
            num_frts = model_ft.fc.in_features9 o5 J# y- T; c6 t% `- i$ Q0 t& B
            model_ft.fc = nn.Linear(num_frts, num_classes)+ I% v: r4 z; {/ a
            input_size = 299% p- g% c% W8 p5 g& d( q- t7 S
    1 w8 T4 D7 r2 z. r3 K
        else:9 r3 L" V( H+ J! o. j
            print("Invalid model name, exiting...")
    3 o) W9 V, L; `8 a  c/ a1 F: |7 K: i        exit()& q! \- {/ E" V. a
    7 W% X) \- h9 l" i
        return model_ft, input_size
    & r5 ^: a' K6 z4 q; T) P1 l6 c# L2 ]( L- T1 l
    19 ^$ h& @6 D" h! S
    2
    4 ]) z% {0 j4 ~1 \3# Z" T; D5 @; Z  y$ X+ `, @
    4
    : }: T3 x$ {, I& U0 @) Q53 l3 D5 m$ Q! ?, N" K6 G
    6
    % N3 ^1 O3 u8 ]0 V9 X+ x7# S( O, R3 H* m) o1 o8 F
    8! P0 p  D; c/ @8 X# t/ w6 h$ O0 ^4 m
    9
    & `6 {: k1 o' u107 \# K, n. d( b& a0 Z6 _9 n
    11" V% C# |- B" g2 x4 W0 w- W! m1 d$ I
    120 [  m, y" B1 _( Q  [
    13% ~7 E6 o0 a+ b3 O
    14: b" ^3 o$ b% n3 c; p2 e2 ?8 y
    15
    : z6 P2 l' I) O2 p" {. X16
    % D& p% w. E  x$ t17$ g5 K7 c( I. q( ^
    18
    * Q* @( c7 i7 k$ G; A7 [. G5 ]: \195 @0 g6 a* l0 n
    20
    * o5 c" J! S9 n213 `: ~& G$ A+ v4 }! \; t  q1 k
    22
    7 B; {$ T  C0 z* m- g6 Q23' O2 o! S1 M, R! M0 Y
    24
    ! S+ Y& A' F; Q% c25
    3 p1 V- N6 w/ K. }# U5 Q26* ~: K3 F% d% Y+ m' O  ^
    27
    - V1 H# V- X( P( G9 m; J28
    ! t$ g. I: m4 x& M. u; |% g' M29. Y3 a, G! S; f. ^  u. ~- F
    30
    * O% I. s- f3 N, O, l8 }( [% B313 o( N) M( x! h* ]  k2 k
    32
    ; [9 M3 I. I  _3 c, R33' e; G) v6 i$ M- `* ]# p
    349 I: c7 P* |9 b6 p
    35
    / H- N+ W# B1 H$ v, p+ E, q36
    7 X# E; q6 Y$ H0 X37
    ! E9 J: P# I0 r# ]* q1 _38
    2 E! v% ~5 n+ N0 L0 P39$ n6 N$ q2 Z' l
    40" Z& `" b8 L8 h) o* L
    411 L3 }! y/ k0 @. ~" j- v0 L' R
    424 [+ z& t0 }, R% E8 z/ S
    435 z0 X& z+ U% `- `  [. W
    444 J5 m' p: R; z& f4 W
    45
    # N" G( _2 H7 l% B' ]2 T2 }46. I$ S* {3 g3 S
    478 x; W4 R6 p, e+ V% ]
    48
    2 n" p1 t3 M2 ~5 S49
    . W6 o1 q: L- {" s7 r- {505 w. D* y) q* J: U' V: @
    51
    8 Z) r( e$ B4 F" a. l. H. I52
    1 R6 r- {7 q: F% A4 T7 s/ X, ~. E53, s* g- e, G- }. X6 L
    54
    , ~; Z& s" S8 Z7 h55
    3 Z- {" X# Q* P1 }0 j56
    , i' |: Q7 W/ U57' r& n4 A  f' A* H; u/ Y# u$ [" I( K
    587 R+ l9 x* m0 }# S& H1 n
    59
    ( j% I5 e! f2 _; b60
    ; w! A" D; A5 b/ A61
    : U' U1 y* S  b$ r+ j4 |62
    ( \* V+ y& T  K- N& c9 |63
    7 G9 ~  `" ^; w. _64
    # u6 j  `: c" B5 t: V8 I6 r65
    $ t( s, m( E/ c& ]3 k66
    ' q! J# b" U( V- h2 n2 U67# |! l' B# U0 Q8 P2 o3 k
    68" ?6 P3 l" T5 R8 h* e) Y
    69
      N! G' R% k8 L70
    0 u$ F  \! g9 [- K# l; C1 O( s71# B$ L0 o& ^' W4 X0 y3 ?6 T1 ?" j4 W+ D
    72' A) L0 {# }: V# D4 p  h
    73+ C% }. ?* l: X. P+ z6 `" o; e+ b% @
    74' `6 C6 J6 K& E5 E7 n5 F* ]
    75% C) a7 Z3 p9 @/ U1 x
    76+ q( _3 Q" p  C5 M: X3 e6 d2 @
    77
    2 H1 D- Q( t9 d! E782 v, W8 b: P* t* r, h
    79
    ' J9 c3 C# L1 ?80' m( [$ E9 K; L* K
    816 I3 Y$ |+ ^8 ?8 y
    82
    5 w- k; t- R- f- [6 P1 u' `6 b. D83
    0 I* X+ g# m; k. c- }7. 设置需要训练的参数
    3 X4 t6 _: ~7 U% B- r6 N6 S, \, V# 设置模型名字、输出分类数6 n6 v- R" X7 f, [: ^- g3 V
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)6 S/ Y7 [1 ^# o, J% p2 c7 H
    ( c$ U6 x: W: I3 R/ \' u" o
    # GPU 计算1 c1 Q: B& l3 \
    model_ft = model_ft.to(device)6 w) m, x2 M$ [! k' a
    3 v8 T- Q1 J+ E# U  U
    # 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    5 [; L+ G* @: O. I! n) Y8 r3 jfilename = 'checkpoint.pth') b# D' ^5 ~: M# ~. X. p% m0 b
    , e. i- W( C; P% D* _4 g
    # 是否训练所有层7 O2 l: z! h  W8 j5 Q# l
    params_to_update = model_ft.parameters()
    6 ^8 S) M& f; Q+ U. D# 打印出需要训练的层& Z( `& H' I# v5 b
    print("Params to learn:")7 P2 I. S2 s! B( k& D$ ~1 I# p
    if feature_extract:
    & ^7 W; m1 G, u6 K0 Y/ L    params_to_update = []- J1 `4 x, `) N% y$ x
        for name, param in model_ft.named_parameters():
    , q9 o9 }+ C% x9 d        if param.requires_grad == True:9 S  J' z4 R3 i: g
                params_to_update.append(param): _9 g4 Y, S  i
                print("\t", name)* p" G  n7 J$ b( B% R' }
    else:+ a5 O/ X9 |- C4 y
        for name, param in model_ft.named_parameters():
    * [' d5 M1 T* V( H0 h        if param.requires_grad ==True:
    ( F. b0 k6 s$ t, ^8 Y, S2 q- {            print("\t", name)/ p5 R6 o$ \2 }, t3 H4 {

    ( J7 ]3 o; k: ~' N/ S1 a- f1. w3 @- l0 h  Z
    2! L6 D  C: g5 m+ i' l1 B3 r
    3( ~) B& t7 x% m
    43 s* [7 R, U5 d. V+ \# J
    50 p: S. ]& |/ z- Q' b  @1 Z. r- t
    6
    0 u# p. Q( T' {( t6 f5 ~7* K- I$ I' _% @5 n5 v2 i# `
    8
    : G5 P2 X; U0 g' p- k, v9) ]" `3 k" P2 g, j6 R
    10
    2 z1 h9 A  {. F* t11
    $ e' a+ Q7 `/ {, x) U. d12/ Z% b4 p/ L, ]4 n! |9 d# H
    13
    3 s5 a7 m5 S) Z3 J: h14. {5 I4 A% Y& Y5 ~7 O
    15. _6 J6 e- Q& O. w# e( l, T: K4 T
    16
    3 h3 g' g0 k. l$ j17
    8 l- K' Z- I! _5 o18
    ; }6 w' Z2 K" T5 S7 z9 w0 T2 a" r9 `19
    # |8 @" |+ Y2 q: k20
    & X+ _1 `9 X6 Q$ m& V  E7 G7 \21
    & R8 J: `8 F4 W/ \, {223 n( I: E8 q- k( A1 o
    23
    3 K  V( r# ~( j, M' nParams to learn:6 m. o. t, v) @* P: S+ D
             fc.0.weight
    % c, _) \# ]. q         fc.0.bias
    7 q7 n3 @2 g/ M  v1  m& @; Y5 s* E' l! z& H
    2' _6 ?7 _- T; W3 i! w* }& q, @
    31 y1 [% h" Z1 v( h1 W
    7. 训练与预测
    - y9 q+ v& i2 P. n; G7.1 优化器设置
    0 `" a- ^# u3 _) X# 优化器设置* t( U/ C7 c9 q+ d! L8 C4 r
    optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2); P2 |6 f5 x& H- P# R
    # 学习率衰减策略6 A. a4 ?3 x; f' }
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)3 M  b3 J1 u2 E8 N9 m& u' y4 X3 u0 b
    # 学习率每7个epoch衰减为原来的1/10
    ) T/ f0 P5 h9 ]" [8 l3 R0 {# @# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算  O* ^0 {. ~% B5 R. h
    $ j; |$ V" T% ]6 A* Y1 b3 _
    criterion = nn.NLLLoss()
    ' |& @- V7 S2 M# ^( ^1: ~' O0 H7 P, P! I+ n" I
    2
    ) W8 K' H% I& N1 t2 _3 {8 G3
    + ^- U. @8 n% }44 E( b/ R' Q" T0 H
    5% O6 y% d/ x2 r7 x* ^
    6
    / h' J* P5 x9 d) p; b  g72 K  ?! [% o% B# p+ O
    8
    2 @* q3 |, A/ G# 定义训练函数
    7 s9 ]% s/ j+ W4 u3 G- f7 V$ B' k#is_inception:要不要用其他的网络/ h1 m. t/ {  M) ~+ J( \
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):# i0 h3 W( L) c
        since = time.time()
    & h% W  ]' m4 j, R) \. f2 q    #保存最好的准确率
    0 _+ B1 ^4 |2 M8 g    best_acc = 0* ?! z6 w- r' g% G8 J9 e
        """* v% u3 d6 I+ ]+ e
        checkpoint = torch.load(filename)
    . m) q. i( [3 Q    best_acc = checkpoint['best_acc']- p( L5 u# u& g7 r# ~% u; k; C2 @
        model.load_state_dict(checkpoint['state_dict'])  _0 r/ y, d- o' r0 i% W
        optimizer.load_state_dict(checkpoint['optimizer'])% T( y- V- y1 {
        model.class_to_idx = checkpoint['mapping']
    . K$ k) e3 a& l( [( C6 @& D  I    """. W) A; F& s( X. {
        #指定用GPU还是CPU3 k; ~. R) o0 P4 u
        model.to(device)$ R. A) n: U: A3 Y
        #下面是为展示做的; W0 C$ I4 D. n) B
        val_acc_history = []
    % v# Z" V2 t+ i4 l+ D8 D    train_acc_history = []. O2 Q/ I$ `- C8 }
        train_losses = []
    ) l1 C% A, z, Y! h8 ]    valid_losses = []$ u( t5 S& z. i6 w
        LRs = [optimizer.param_groups[0]['lr']]% i1 t8 D0 [1 z9 J- M2 [
        #最好的一次存下来" [/ m5 g7 M8 w+ z( R3 |  ^: X
        best_model_wts = copy.deepcopy(model.state_dict())
    1 ^& s: S! l5 g5 {0 E4 @8 C+ |. a5 r- L* _- S2 B; h! x: d, q0 T
        for epoch in range(num_epochs):; ?' G1 C: ^# p- e9 v% [/ i
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    . y1 X: v! S' T1 h& G( a4 t% I# x$ N        print('-' * 10)
    / k$ Q3 K( r/ l5 Y* L! N7 G8 |( B" x6 a! H8 J3 h' `
            # 训练和验证
    ! p+ Q% W+ ?4 R# i3 X: }7 V        for phase in ['train', 'valid']:
    3 y6 Y; D# X8 W  o) h2 v+ S            if phase == 'train':+ ^- C1 j2 g3 E: t& G# ]
                    model.train()  # 训练4 t' b; \- T/ E' C5 D
                else:
    1 ?7 n; `0 C: b4 N2 F1 Q: `  t5 \                model.eval()   # 验证
    ' x* W9 \8 b( Y5 v4 s4 S* i9 d! b) ]0 T6 k" @! k: P
                running_loss = 0.0/ m+ {5 J) S, T) J  q- w
                running_corrects = 0/ [2 J: K- r) l
    ( X# i+ f& e' i  q
                # 把数据都取个遍* f- c- M- u$ I* P3 l1 h
                for inputs, labels in dataloaders[phase]:6 X* N. ?0 i- G! E  n6 x9 x
                    #下面是将inputs,labels传到GPU, C5 g% W1 w- r5 W+ e! }
                    inputs = inputs.to(device): h4 R) Z8 m% B' ~- V3 f) @5 Z: t
                    labels = labels.to(device)
    : v  Z( ^5 X6 _: O  y6 ?" b; _+ _" R! M/ c2 A, j/ L/ h
                    # 清零5 W2 V4 B5 Y+ q0 P% W7 E( ?# b6 ]2 C8 N
                    optimizer.zero_grad()! A/ M+ _. l+ R9 d4 Y9 a0 i9 O
                    # 只有训练的时候计算和更新梯度. B9 k6 t/ ~  h: }7 I& X
                    with torch.set_grad_enabled(phase == 'train'):' z' n- q- b5 q) s
                        #if这面不需要计算,可忽略* j% Y" S. |6 I8 g/ ~* l& |
                        if is_inception and phase == 'train':8 j! C3 [. a0 Y4 f. r: j4 S
                            outputs, aux_outputs = model(inputs)
    ) F/ j7 z# n' ?' X8 j; }4 D3 e; g                        loss1 = criterion(outputs, labels)
    * b& i' r7 e0 x                        loss2 = criterion(aux_outputs, labels)
    . b# _4 K3 b" O: F9 M, |                        loss = loss1 + 0.4*loss2- ~; B% E1 |6 h1 U6 w
                        else:#resnet执行的是这里
    ; I& ]$ z- `- ^: {0 [, x  q4 v                        outputs = model(inputs)* X$ m% N  u6 [1 }, O
                            loss = criterion(outputs, labels)
    + f6 F9 k& j- t" n& S
    6 n' f: h; v3 {' r2 y  T                        #概率最大的返回preds
    : N/ Q- l* n+ e- x3 ]                    _, preds = torch.max(outputs, 1)! h3 `; e- u8 C) @' N9 b

    1 p2 B6 J+ m- Q7 O                    # 训练阶段更新权重
    9 \6 F% n3 V9 x8 |5 F, n- F                    if phase == 'train':( B4 I/ f: {5 h& X& e3 A/ x
                            loss.backward()
    4 W, n, R) j" v! D& z) o, M/ |+ r' V$ x                        optimizer.step()9 S; v8 ?! [/ _2 t. R
    7 ~/ ]/ g5 Y/ V6 h0 r% N# b
                    # 计算损失
    ' j& q2 y+ H6 R3 R! f                running_loss += loss.item() * inputs.size(0)3 }) p# u8 @! E8 {4 B
                    running_corrects += torch.sum(preds == labels.data)
    5 t4 a; v- P1 r, I. g, h7 x- C
    8 S4 w+ o. s! p8 ?            #打印操作
    ! y+ J% t$ V5 {+ H7 q) G" [& `" }, [            epoch_loss = running_loss / len(dataloaders[phase].dataset)) w% U7 I! J# c1 n$ F
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    ! v5 v$ m& p/ W8 l% Y
    0 h& K) D7 Q2 d+ P6 ?, f0 {" c9 f
                time_elapsed = time.time() - since
    9 Y5 \, M6 S  a# s1 f            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    ' P1 o  D0 E. d4 H8 s            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))( [: ?1 k. Q$ l' s

    ( L2 }* B6 @# ~; [4 ?
      g9 C, f- J/ n8 b            # 得到最好那次的模型
    ' J5 a9 L* r$ @9 a3 ], k) Z            if phase == 'valid' and epoch_acc > best_acc:" k% ]1 @' t+ L- w4 L; [- S: Z9 d
                    best_acc = epoch_acc/ s( ^: R/ g% }" }/ A
                    #模型保存! B4 r  ]- b8 x) z- Z: p4 n
                    best_model_wts = copy.deepcopy(model.state_dict())
    & ^6 g4 h: `% h9 `( k+ S                state = {0 k7 _; @+ x* J1 x5 L& I1 \$ Z1 E
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    ) L. C2 e/ j/ P: O3 M2 g                  'state_dict': model.state_dict(),
    : \6 z1 M& i; R5 h7 _0 E7 ]                  'best_acc': best_acc,1 _* p5 n* y' O, {7 Y
                      'optimizer' : optimizer.state_dict(),2 r3 y  ]: y4 p$ T' R' Y5 q
                    }4 U3 _( k  E$ [4 u/ t, c. k
                    torch.save(state, filename)
      d0 B0 _# c1 x+ G  I: K9 k            if phase == 'valid':3 c0 Y3 k# a7 [2 P1 D2 L
                    val_acc_history.append(epoch_acc), N7 ~( _+ f0 E
                    valid_losses.append(epoch_loss)4 f/ _( E9 V0 R
                    scheduler.step(epoch_loss)
    ) n7 k8 {, s6 @, G) j% C/ B            if phase == 'train':
    9 p5 u" m# N( }7 X, n                train_acc_history.append(epoch_acc)$ w5 r" \" k. D
                    train_losses.append(epoch_loss)
      }+ H- `: u, o) g6 h6 U
    ! X7 P0 B) h/ o, P5 ?9 c# R        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))( H4 S* x$ G, o4 W3 ?' i
            LRs.append(optimizer.param_groups[0]['lr'])
    " Q0 W( ?* V" D! B1 X6 ?1 C3 p- E- G        print(). ^2 j7 ]5 G3 h9 ]
    , y" G4 J5 R- t/ Y1 t% R* p
        time_elapsed = time.time() - since
    % n' E" W; p, x! e/ W$ r    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    6 O+ ]' b4 P3 b( k) r8 x  e" J: I& A    print('Best val Acc: {:4f}'.format(best_acc))" T* |. A( X; v( b) G2 z
    : {0 N6 H- p& F
        # 保存训练完后用最好的一次当做模型最终的结果, F0 O5 |3 u) r
        model.load_state_dict(best_model_wts)" b9 m9 F+ @; B# c6 ]4 p
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs / v4 O0 Z* D  h; H8 @+ E8 l

    / C% C* H7 u. \, \5 }
    * u% T' P# S; x) T1
    3 \/ e, V' d  X5 r2
    3 S* C8 X# V% x# b' O) Z3
    5 Y5 n% l' Z, X2 Z8 c) |  K43 p7 C1 Q* N2 i: q
    56 X, C+ y0 L) C* n+ q- `; s
    6
    . ~6 S$ e( r* X7  o; G5 E7 W! k9 _' N5 o
    8
    . a1 Y, G5 x  [. t& ?. b9 _5 ~9
    4 S8 [" i# x- l! Y10; U4 C' B" |; r$ S* }. g
    11. Y- U: p) ]" p7 P. B5 Y7 O4 J3 {
    127 B7 F( m" ]) h( J7 A# x: g
    13
    5 |: q( t. A0 D% v% F14
    3 x: q( v9 t5 x0 l) j152 ]% q% R. p4 n  K2 [. j/ l
    163 H+ ?! I" G4 L
    17
      i2 W# G2 Y2 _4 A18: h/ z0 h, k! I2 ?9 d% H
    19
    7 H) N3 S5 z; ]5 S! B# K  R0 H20- A. h8 z- {  l8 H; A
    21. F1 F6 e, o" o+ }2 V6 L$ n- i1 y
    22
    . n/ R3 a! ~+ _7 U4 i- P23# ]* i. s* j4 n
    24* k: G) L; \. @) s6 c
    25
      i' {* Y5 [0 D, r0 {26
    9 ?% i: ?5 T' }27
    $ M2 U* B0 D  R28- J  W4 K: s* h7 S3 G1 o
    29+ j; j7 {5 k% ?7 _
    30
    , l7 z- T5 o9 Y31
    7 e" S7 o: k# {4 h32( B8 o! p6 L# O% t
    33# m2 k  O8 l0 m- k3 F- W! c6 ?
    34, q: v/ _4 ^+ m% Y* {8 I$ F4 }
    35; J: _+ r* h, M. Q2 \" a0 I% D1 L. d
    361 E& o( P, |) ~
    37/ g$ Z5 |. V7 p
    385 o3 W8 C5 r0 ?: j7 K
    39$ n6 o- ~1 l& f! I3 t7 \3 @7 `
    40
    ( L) F! q/ m8 n; g% O411 L7 r) G0 @' q" m
    429 p& J& g. W& S9 @; Z2 u) t' [3 [+ E
    43
    ! ?7 O9 C/ C+ H) a) X44( x0 R7 P$ A/ K5 G2 ~  a6 q
    45
    # j4 e$ X) @; q+ @" k46
    2 }; Y4 a4 l. ~( c9 n471 C; l7 j/ v) U& E) e5 o) z
    48
    9 U0 @7 g# W, N49
      @) E# o5 @+ a& @" [+ \50  w% u6 G6 R7 f
    51' z  d- e# S" d4 Y* n
    52
    8 o1 y' F& ~) Y- `- y( ~53
    1 ^: u" q7 q% _1 J! ]54
    . r+ k7 N/ t5 n) g: |- W5 A. c55
    4 X+ \2 ]- A* d3 q( t7 v56! Z3 W* U. x. k6 W1 x7 N
    57
    9 ?4 s& d4 m/ A6 W58
    % v/ N; U! z$ _* m59, Q9 G6 A  I, U* n, B% Y
    60; D- E/ J! ~7 x1 m
    61$ r9 v9 P- _" ?6 L: h
    62% F1 d9 A4 s1 R7 U# Z! j* u
    63, B9 x6 q$ O& X, |! W2 L# h
    64: G- a; a6 B$ N' j
    654 \9 s) k& [( D; n1 Y2 x7 t' A
    66
    - J- I6 N* b( M3 d67
    1 p, Y2 P# b' ~! ]/ p% L68( v5 Z% C' N' A  `* W
    69* Z2 @8 _+ x; U, l; p- y
    70
    : m; ~+ T! w& k8 i0 M71
    & J6 G6 l5 Z, ~, h* s4 j; N72
      e# A( Z! z% K1 h5 \, t7 X73! s4 m, B% a/ q0 a$ |, Z
    74
    . X. e* Q4 T3 [0 k75
    2 j6 e" g2 {6 E; W) a76
    & J% U' N; ^1 j77% C6 g- }- z( v+ m  n
    78
    , g" O+ \9 c( u$ |9 X79
    ! a% Q$ k9 q4 x1 M80
    # Y. e, T( M0 U) J$ O6 b817 Y% m. h3 M( I( ?& X& b
    82
    * g6 X9 p7 Q+ t83/ V, {: i. u) c# x% v1 n' ~
    84
    2 t1 G$ E, f6 ^85
    9 t# {" c& r1 m6 S, L3 M% p. G5 f: [869 A' N2 x$ s' i9 D0 J
    87
    / ]7 p- n2 k6 H- Y* e880 y7 Z  K9 A+ f0 R
    89
    4 y& J2 v. ]. W/ u7 P! F90( x5 J2 Q. `5 d1 b
    91
    6 E: ]$ ?6 [) E$ Q" e4 \1 |4 v92) z0 w3 ~" F5 D7 X. ?6 t: w( W
    935 D2 X; P: L$ w* X$ R
    94
    7 L7 G+ f' o8 o# c- C951 l  Y5 \4 [. f& x; m' b8 b) r9 i
    96
    $ T( b! L$ D3 Y) D/ H/ s% E97
    ! _0 ?1 ~! g$ F& K9 {. S2 Y988 }( ^6 d9 Z4 q% I" W9 i' A
    99$ W0 O# p, Z* k! Y; N
    100
    ) {4 d3 I3 F6 m, M7 n) v* }( _. \101) C+ c  `1 k  b  t- Y" \- n
    1020 `; E+ \1 {/ n) V4 Z$ w2 o
    103
    + {! ~- j& b3 s2 {0 f% R- Y1041 i5 H% h/ ^  n, x! O
    105
    1 X# M5 _/ F6 O% B1 [% s106
    . h# j8 x4 O6 s, J$ m107  Z# ?  A/ |4 h: E
    1088 ]" S8 {7 q% p" Y
    109* e/ D- C) `0 e( ~2 D
    110' Z  m6 F# g  `
    1113 k! ^$ t. O7 `; [( t$ |& h
    112
    1 p! e, G. f, n6 ^" ~9 {- B! N7.2 开始训练模型
    : j" H$ @# u% k& B6 E我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次- ~! d* f) J8 N3 C

      s! Y- G4 I! U0 r3 T& {0 \) H0 u#若太慢,把epoch调低,迭代50次可能好些9 j  @, B5 e5 i5 P) n
    #训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    % K1 t. ^: m0 [0 xmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))3 Z2 {: U; J: B8 K) I& w+ v

    4 G1 T" w( l3 \' |: l+ M1
    ( T: b/ A, k" |' j6 W# B( T+ v2
    ! K' j5 r/ c( Y, B& b3" U6 f$ }' B8 B: a. q# ?
    4
    3 j' d3 _3 ?. V8 c' `Epoch 0/4( n' }2 v2 }4 z0 e% b6 D1 e
    ----------5 m$ e9 b9 M+ ]1 _" P  g2 V( ]; H
    Time elapsed 29m 41s
    / \) ^9 f/ s$ Rtrain Loss: 10.4774 Acc: 0.3147
    + O& G; n% {" ~3 n0 ETime elapsed 32m 54s
    , B% c: ]$ F6 Z; {valid Loss: 8.2902 Acc: 0.4719& V* F) C8 I* a: F! T
    Optimizer learning rate : 0.0010000" K8 m" Z5 c" Z% P& E7 \
    6 V+ W, @% Z' E6 h
    Epoch 1/4
    8 o6 ]- T. v* Q0 h----------% Y. R' A2 }& j( H2 Z6 \* e
    Time elapsed 60m 11s4 U. R( C6 d) j! b% b
    train Loss: 2.3126 Acc: 0.70539 n( L- d& W5 C1 r
    Time elapsed 63m 16s) l" O! c# d9 h1 v
    valid Loss: 3.2325 Acc: 0.6626
      x0 Y* t% N3 `" y1 |. Z( [; h8 Z7 kOptimizer learning rate : 0.0100000  {2 N9 N7 |% P9 A) o( s

    8 H+ {: a5 @$ ]- o0 B* K9 ]; i6 g  [Epoch 2/47 Y- z8 V: U, v- h: b7 ~
    ----------' a' O  N& A; E" A
    Time elapsed 90m 58s) r: L5 l3 y. Q* @/ I
    train Loss: 9.9720 Acc: 0.4734
    ' @6 ~4 y8 u+ v$ u  X) \Time elapsed 94m 4s
    2 x& e# p9 a! a: ~2 s: }valid Loss: 14.0426 Acc: 0.44135 t, x7 T$ ~9 [  l2 H3 ?; D
    Optimizer learning rate : 0.0001000
    ' o1 k$ u3 x: F1 @3 r' W2 ~$ L1 |; R; R, m2 T* l
    Epoch 3/4
    & A6 I) p! l. k$ {----------
    1 V6 E2 ]! o1 w, F* fTime elapsed 132m 49s
    ( {1 Z5 v( E! ktrain Loss: 5.4290 Acc: 0.6548
    # X, `0 ^7 @5 j3 z0 |Time elapsed 138m 49s
    2 \* N. O6 c7 W$ o# {7 r9 jvalid Loss: 6.4208 Acc: 0.60275 s( K% T' v( t  y
    Optimizer learning rate : 0.0100000
    7 |6 H8 C6 V7 F3 @; Q
    8 @) Z& x, O- ^" {% S/ kEpoch 4/4
    * E+ e  w) y) [----------
    8 `7 M  ^+ h; B- B; TTime elapsed 195m 56s
    * e; n2 u' v1 h) r0 {7 d! \train Loss: 8.8911 Acc: 0.5519! u1 M% Q/ u! M' m/ B
    Time elapsed 199m 16s. \- |' R, U  U
    valid Loss: 13.2221 Acc: 0.4914) t9 S. D" y/ S9 P7 H2 Y
    Optimizer learning rate : 0.00100004 [) d4 }8 T2 k+ X8 R2 Z
    , }" u" P$ b: t2 c7 R1 {0 N
    Training complete in 199m 16s
    7 G9 B( w% r: wBest val Acc: 0.662592+ W" n5 z5 z- F; ^

    # T! F7 j0 L4 `5 Y, y7 o+ ?& z1
    - }, f9 j% m  w8 h2% |+ [1 }, t! `& k
    3
    2 j$ n# i+ W) L) B% l/ S4
    ' g1 h" {* i0 p/ P5$ Q/ _8 V6 [& X6 w$ Q  n
    6
    ' h) P5 W; x4 P9 P. X7
    8 t. s8 r; M. u; S* M: F; {8
    3 A2 h7 H% m2 t4 g) v- \' b1 k9- ~8 G. E- J- M# f/ k1 K4 o/ g/ m
    10* n5 W/ u) k! e0 A1 Z
    11# H, h4 f3 i. Z- s, n# R+ r9 V" w
    12# l% D& C3 r: I( V/ A. g
    137 ?% c, f7 l+ J. o+ u2 {8 |
    14
    + a5 Q2 k9 n$ b6 E% H15
    * _. x. v$ I6 s: y& h8 |2 ^1 g169 s5 g$ H% a# ~9 O6 g4 `: g4 s
    17
    8 d- ~/ [5 {  i) ]  ^1 d18* n) Q( M$ E' i2 B) ~
    19; v# u% ?% _$ {- n; S
    207 j: X& z1 ^$ Y) f& H4 \2 W
    217 G' z9 \8 u1 ?5 s
    22
    + c. M$ p/ l' b8 R5 ^. r23# E$ J8 W, S/ r; I1 C$ ?* y
    249 S" L( O! {3 z( j3 ]
    25
      Z) P7 ~4 k. P# S0 S26
    ) }; g. {7 i3 M. {1 f0 e+ P# W2 t27" F# m- m6 A3 Y/ Q8 k
    28
    3 Q, \$ M8 q4 r4 `6 T) M3 R29
    ; y( F* Y( A1 u7 \! N) n30) y: A# H' s$ B" U3 t, ?
    31: A4 s$ x( N; x8 @% Z+ f( i
    32' Q( P( h: |% C! z  ]
    334 n" T9 }% ^( r1 V0 \+ b6 ~
    34
    / {5 A, n. N3 w; s; U35
    + k* R8 e" E7 ^1 W36
    9 S5 Z( d% c, V6 C37
    # e4 M4 `6 o- U+ x9 m; |38
    ; u) I4 a- a4 F7 f39" I3 w" ]. n' U. `) n
    40) x; S9 d; K; w4 F9 N) {
    41
    ' H) S% ?2 Z: N' v" q. M% |42% \- z' K. q! Q+ F
    7.3 训练所有层
    # _% v0 ]+ J& E5 U  H# E2 G# 将全部网络解锁进行训练
    . k; [" ]' C0 C6 u( }7 ~5 l: R! N) Yfor param in model_ft.parameters():
    ( Q& F' I3 W; n1 U7 G- c# w6 J, D    param.requires_grad = True
    ) Y7 r- Q& S# S& `2 s: ?% O5 ^2 D: d4 e' l
    # 再继续训练所有的参数,学习率调小一点\
    ! `/ f" U9 ^( B, a# k; loptimizer = optim.Adam(params_to_update, lr = 1e-4). h3 |3 c- F+ v+ a' |, G
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)1 O6 w3 ?' R+ D
    7 a, h8 B- x. Z: H4 ^( b
    # 损失函数
    / \' _, e/ O: Q! Y  e1 M+ b. Dcriterion = nn.NLLLoss()
    & h$ m6 m' F4 h1
    + Z" _0 k0 ?" e; K; d$ @" U2 F  j+ k2( P. i, t9 c- V  E7 n8 u
    3
    3 b# L7 E5 N& r  B# \8 k" x9 d4+ }2 s# V' L, a0 K
    5) a7 `2 I& |- b) ]1 j9 r
    6$ {* l0 A% E, {2 J" Z$ R* v  d
    7
    8 {1 ?( B4 _0 I. c( A; R# i8 I; t8
    9 y- ?5 @- R& J, T" W  Y9
    6 n6 y* _  q( K2 h- `104 p7 o5 v) O4 n' c2 {  t
    # 加载保存的参数$ O; `  w2 r% Y: b( F
    # 并在原有的模型基础上继续训练" U& H- i6 V  p7 ^* y* G: W+ M9 ?
    # 下面保存的是刚刚训练效果较好的路径2 z0 X* i: Y0 a' \; G6 J+ p
    checkpoint = torch.load(filename)% e# c1 N2 p6 {8 X$ r! y( V* B( B
    best_acc = checkpoint['best_acc']
    # Z  h7 W5 q* fmodel_ft.load_state_dict(checkpoint['state_dict'])
    4 m; A3 c5 K5 ]7 O' |9 M# Aoptimizer.load_state_dict(checkpoint['optimizer']); U( S; |3 ^. e' J2 c0 q
    1
    * B3 @* Y% y6 T1 s/ o, s2! H4 }7 A3 a1 i( ?3 h
    34 W. L: i. D3 F) Z& c+ L6 @+ W0 b
    4
    : T  d( n2 i1 a3 K50 i+ g0 r0 S! p( b) O0 C
    6, O6 I. ^3 P/ \0 h7 F8 a
    7% E5 f" E% P# A
    开始训练) M/ U( V5 ~7 i* j9 |
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考% k  S6 [" d; y' L% I
    * |3 Q# C, L0 }
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    0 p+ l  D1 o& Y0 [% n" j( B1 Q( ^1/ a9 Q# X) b& @! q
    Epoch 0/1
    ; e$ @) o; C6 |9 [9 o----------4 f* u1 F# N! K, n' M' ^, X4 w
    Time elapsed 35m 22s
    6 {0 j6 ?5 g  ^* Vtrain Loss: 1.7636 Acc: 0.73461 ]3 O/ L( c. t; o/ I1 N7 Y/ w  p# Z
    Time elapsed 38m 42s
    $ y; L! B* P' k5 D/ Ivalid Loss: 3.6377 Acc: 0.64552 [2 U1 Z$ m$ v2 M. x! S  v3 _& ?7 X
    Optimizer learning rate : 0.0010000
    4 D0 v5 O* W0 E8 E  f* S
    4 n$ p' r8 t: e- p8 |% eEpoch 1/13 w% D7 w: ]2 u5 K
    ----------
    ! E' p( U9 x, f, b5 |8 @. [5 eTime elapsed 82m 59s
    ) E2 w7 I( B' w6 E8 x% b0 D% V4 [, C2 }train Loss: 1.7543 Acc: 0.73408 r  a, q* u/ I3 ]& O0 ?  x) U' s8 i+ v
    Time elapsed 86m 11s/ R- n* v5 B5 L! h. b
    valid Loss: 3.8275 Acc: 0.6137) c( U+ L& _4 W" Q
    Optimizer learning rate : 0.0010000, J/ k; Q; U6 L6 o$ e' ~8 ]3 d

    ) ~( I: P- j6 n: y$ ^/ `Training complete in 86m 11s1 W1 A! `' E2 ?4 S+ _4 R
    Best val Acc: 0.645477
    ) l+ @9 F. t' w
    7 R* e. w' d# @- H' `# L13 |/ u0 \' ~4 G7 F$ ^( G) f: _
    2
    * n0 g1 y! h% F' Y7 W31 i! {1 R$ i: g3 `
    4; k/ w' C% p$ f3 o2 H3 G
    5' f5 J1 C1 y; j- ^! I
    6
    0 g: P) `5 s7 H( \: k79 }% i% G' \2 |0 Q5 v" T. b
    8
    , \) D+ z* S2 g  n+ ]+ |+ v9 q# z. t1 f91 ?5 J) m# y& p# ]0 T4 C1 I
    10. v# K3 r9 J; k( \( v' `2 E% u- c) `
    11
    9 `2 j4 e( v3 [  ^9 O124 @; E& ^# g4 ?& H' K
    13
    1 r4 w6 W. j9 n7 b# R14
    : P+ m9 X1 h  K1 R% g; G15
    - a  d( _' z% T* ~16
    7 ~) y' Z1 D2 \% m' y171 a0 {& r% o; n; t
    185 R* H  p; C& k* n6 k3 {5 T
    8. 加载已经训练的模型" y) C# l0 f! ]; G
    相当于做一次简单的前向传播(逻辑推理),不用更新参数
    3 m! G" w# Z9 u2 C
    . _+ U" G  ]- Y3 smodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)7 F& }' u/ T; D* d6 _9 q3 |

    ) H: G& O8 S* Q( o% |# GPU 模式  Q% C9 u* a9 Q3 W; E
    model_ft = model_ft.to(device) # 扔到GPU中
    7 v5 e& E- F/ j# r( I# b
    2 ?: ]3 I7 R4 M' |" a# 保存文件的名字# O5 _! ^) s+ X1 t# s; b
    filename='checkpoint.pth'. y+ I: N2 ~% h* ]" u( D

    5 U3 d* Z5 K5 V0 j% z9 |$ U- r# 加载模型
    ' q" K0 \$ j8 zcheckpoint = torch.load(filename), }/ P9 I7 s& O. O- a9 F
    best_acc = checkpoint['best_acc']
    5 B$ u  l6 }( s6 ^model_ft.load_state_dict(checkpoint['state_dict'])
    ( C% r5 V) x) R# t1
    ) A$ u- Y- g/ y, Y2) N1 D4 N- u0 C! D$ Y9 p! s8 g& l
    3
      d- s2 g. D$ l/ C$ w! U- A* ^43 i: [/ g& F+ N) `/ o3 F3 X2 L0 d
    5
    9 I, S  b5 D6 U; b" g- k! d6
    $ g; U, H. V' u' ?) O1 e7
    / I! {3 W5 k: r# k, b( `" i9 a5 P8* N/ F: @1 ]4 Q) }/ c& i4 O6 z
    9
    ; l1 v1 F  H5 R/ o) _10
    : V8 X. `, _' Y9 N. e$ `: |1 P1 q# J11: j0 r1 T! G+ t  j5 G
    125 Q6 o- i4 J( E' l5 T8 v  m
    <All keys matched successfully>2 |% U" C1 j" x% n7 f6 i' D
    10 n4 e1 o9 h$ g$ [% n4 v; M
    def process_image(image_path):, ~; u3 Q: N% i2 a  C1 h
        # 读取测试集数据4 `2 ^8 [8 G8 c1 L3 `8 Q9 S+ A: J" b( j
        img = Image.open(image_path)
    5 M" s) s0 g3 G7 y8 r    # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    ! }5 {0 a, ?, }: n, e8 ^    # 与Resize不同8 K4 C/ a. T* g9 B! v8 p
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
      J' P: N- O# ~$ x! X9 x& @5 I    # 而且对象调用方法会直接改变其大小,返回None+ U: t2 ^3 s3 `& ^- }
        if img.size[0] > img.size[1]:
    6 _1 r) C( x* t7 }6 g4 y        img.thumbnail((10000, 256))
    & T9 q  U3 G0 Q- G% x( ^    else:; ~, M3 f3 y4 `- p
            img.thumbnail((256, 10000))' }1 N8 ?' p! Y; O

    # W& @9 J" Z8 O5 z2 O0 j2 x1 b1 @    # crop操作, 将图像再次裁剪为 224 * 2240 i5 ]$ \3 n% z* n4 a* R
        left_margin = (img.width - 224) / 2 # 取中间的部分
    . x' ~; c9 }, t2 C1 E    bottom_margin = (img.height - 224) / 2
    ( y2 @4 o$ v9 n& v0 a& d& e* q    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    / n* P3 u7 H" l1 P) ~+ T    top_margin = bottom_margin + 224* N/ ^3 a# B8 v+ l9 c
    7 e0 F- }( v; u  g
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
    ! B* k# |; v; Q- T8 R2 K4 d9 y1 O3 D$ g6 Q3 \. o
        # 相同预处理的方法
    " }# d) H3 I: ~$ `& h    # 归一化& E/ G8 C8 B  V
        img = np.array(img) / 2555 `+ D, W, W! b& k2 K
        mean = np.array([0.485, 0.456, 0.406])* F" e  x4 C$ F, x! k" o
        std = np.array([0.229, 0.224, 0.225])
    4 U. ~9 X& I( h3 C9 @3 m0 F    img = (img - mean) / std% G/ N! R. ~$ E. {* j7 i- Z
    4 m' Z2 Z7 f  X- s  M: H; \; ^
        # 注意颜色通道和位置
    " U, c$ H! v/ ?" x4 {    img = img.transpose((2, 0, 1))
    5 ~1 }- Z' ?, q& d& Z* a( X5 T# `: F! Q& R9 j+ E
        return img/ a& b% o! f; t6 x. U% y' f$ D
    6 y$ g; F, e  M6 f
    def imshow(image, ax = None, title = None):! C; z1 p8 @- D, i
        """展示数据"""5 }3 Y" S2 H; G( S( j
        if ax is None:
    0 \. a3 f9 C' P+ g8 O5 O        fig, ax = plt.subplots()* d7 _3 n7 o. K! q8 D. \, o

    3 w9 S4 J0 V" h1 S4 c" O    # 颜色通道进行还原
    " Y; V' e" G; T5 C; o# C    image = np.array(image).transpose((1, 2, 0))
    / I2 ]- C/ c4 f% Q% t$ D
    # j( ^" C+ A# S/ E    # 预处理还原9 L7 X( U( p' \$ @
        mean = np.array([0.485, 0.456, 0.406])
    * u) A( g7 z: r8 m    std = np.array([0.229, 0.224, 0.225])
    / G2 \* p% N8 `& K8 L( U9 k    image = std * image + mean6 Y0 e2 E" P" f6 k" v
        image = np.clip(image, 0, 1)
    $ R6 U# ~( q. X' @/ D1 w4 ?
    5 x# ?3 e6 y  N: {/ U2 }7 v( B# K    ax.imshow(image)
    2 Q" t: s  t! q6 p  R1 i    ax.set_title(title)
    9 p+ v, |1 g5 E/ |# Y
    & r* n3 O; @" C) i: n8 ?1 z4 T    return ax
    9 r5 O+ D6 h3 \$ E
    ' `+ r3 o; d5 [* G' o/ Himage_path = r'./flower_data/valid/3/image_06621.jpg'! V* J/ I* u/ }/ X  T! @( v- e
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    5 p% H+ s" @2 E: @& u# eimshow(img)
    & v) w) u# p4 N7 R0 r# ^* f. W" h
    * r5 D8 b  Q! C* z16 P# z6 x: @: w
    23 P7 h" h8 L  g: E6 h. y: N1 Q
    3! Z$ g! H6 W. ]9 n  s( F
    4( j: Q! _$ Z8 }  T2 U
    5
    , Q9 d, ]* w# N% _2 s6( O# t1 [1 C, q# S  ?
    7
    ) F2 `1 j3 {! x& Z0 V! }8 B% i! `4 Q8# [) f7 h4 g2 g
    98 |+ ~. s# N. b1 S# i& B0 w. U
    10
    4 H& S/ c+ ~) h1 W. R$ L/ x11; M3 a9 e: g* Q3 u7 D
    12
    , z  j" H: o" x& _  ]- G* w13
    ' Z- h  ~' v' v0 c* B8 r! Q147 s' q. h* D& Y1 l
    15
    " g: @) R1 u7 W; d+ U1 Q/ i16
    . I. K( X* F, B5 _2 x2 @17
    2 m' f7 o/ Y5 h18( O$ U2 U. n0 R# S2 ^' h7 E% u+ g
    19  f: L3 G& a3 J7 t
    20! ~! V" p; r+ i
    21
    ) F' n6 Q% ?4 z. u22( V0 S& m/ [# e2 o1 q* t0 O. q
    23
    ! n4 K) J# e. R/ D. D. Z6 J246 E2 L% Y! N' I( G
    25- e- ~/ }3 [$ X0 T2 |' U
    26" r- b/ B- D( i& B! J; M6 w" S
    27
    5 c  ^% V7 f* a4 U! r28
    1 D7 \+ G$ t! a- H29
    4 G% d9 {4 ~! A, a' [/ |3 |30, |" B5 t! S5 c  a( T  _! O
    31
    + I  p  D/ ^' i# y* f' o$ i7 [3 A32
    8 P, s9 X  x: u0 N* @8 a* o' o8 ^; Z33/ s- |8 u0 ^2 L! I! u
    34
    2 i, a' O+ G( Q+ h; k358 t, P  ^2 K0 y* ~* u. m
    369 p* C$ Y$ w" F/ e
    37
    - Q8 }0 M6 C/ {# E9 S) a  ]38
    1 F+ ^5 O! |" \8 g7 s( m5 [9 j39
    ) k' D, O0 ?$ M6 {6 p* H. Y' e! q+ q: r8 C40
    ' i4 ?# w: ~1 H; N5 p419 J+ W' Z8 w( f3 k6 E) u6 l
    42' x7 {5 p! i' K5 l4 i  S# \2 w6 {  d
    434 r6 J8 q! ?" a1 N4 D* Z1 P
    44
    : ^2 ^3 |3 ^7 Q4 x/ K. v" m& k+ l456 y: E& `, D; t: ^: v
    462 \5 y; y, B3 r. b) U
    47% {4 V3 g+ J. e- j% ]( I
    48
    . p9 h7 K3 M! @" Y499 g- W/ [/ p2 W5 c2 ?# O9 ~7 U
    500 b. T# s, J0 d0 Q( p# ]
    51
    7 ]! d$ s( f: V7 q5 C. O52
    8 ]# J7 f& H8 @; \$ l  I( Z53
    7 u# k( S' q* A549 J0 S1 h9 j. a/ _2 ?4 t
    <AxesSubplot:>4 s) o5 P( m( J( ]- Q
    13 D% w4 n: d' h4 T# P0 o. _8 B3 B

    # @1 k0 l2 |6 j上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确* r/ ^# Y" {. }8 ?+ n& H

    6 C# a& A/ i1 [* Timg.shape7 p# l! x1 b" Z3 \9 d) q
    1
    / w* U* T  M) U+ y5 t* h(3, 224, 224)
    ( x$ s1 Z5 Z) c9 u, M11 i( _1 C" l- f% a3 h) Z0 Y$ _
    证明了通道提前了,而且大小没改变
    2 h! _; Z& l% L& O( B& {
    8 t" T8 A  K6 {# H' m' d% q. h9. 推理- `! v  x  k9 ^2 l4 C6 x' }( r
    img.shape
    6 x) b% d+ G. ?4 j* I/ F5 x/ r  |- y# X
    # 得到一个batch的测试数据
    , {- E" ^/ s( q- g/ W  Fdataiter = iter(dataloaders['valid'])
    3 z; w6 J  f: I4 g$ q6 ximages, labels = dataiter.next()
    " V$ I( s& v# y' \1 d. f1 M1 R/ M  f3 ~/ a! a" v: g
    model_ft.eval()2 s1 H( ?+ W: w- Y6 V& B
    ( ^6 ?5 s, T2 m* ?9 Y6 @6 a1 ?
    if train_on_gpu:8 f9 _1 t3 D) U2 u; R' ?8 B
        # 前向传播跑一次会得到output
    + r6 S4 F4 w9 q  o    output = model_ft(images.cuda())! A6 X* v# \! g+ n7 o5 D
    else:
      v: F1 v# C7 U/ e1 D    output = model_ft(images)+ `$ p8 Z  k3 p5 Y0 m/ k0 Y
    % S( b! G, J& h2 [$ t( {
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值5 T% d+ O& P( K3 v; ?' f3 d4 X
    output.shape
    % L) ^  A3 H8 E7 g* L. o7 z0 i1 _: K
    1" h& A; g7 ]' T  b
    2
    - E$ E& k4 F+ T  {3; X- F7 x  u9 U% o
    4: c8 ~. `* W. z! \# b3 d
    5
    1 s4 U5 U+ z# v# g; @: G5 M/ [6
    ) S+ a6 k  [  V% Y& w* S. V7
    1 k3 X6 @+ L  K. z9 ~8
    * F, F5 r$ X) l/ y7 w. S94 |9 o9 T" A. j4 h6 J1 Z1 z
    10
    $ l- u+ X0 `/ S# F11: v4 M. M& v$ z
    125 B0 \. E, u' b1 X
    13, N- D! O# c# n9 m! d$ ~% x9 o
    14
    # C$ F$ B/ M1 o3 ]( ?2 E15, ^! D' I, a  w* i" [* Q' ^4 u! r
    168 b8 i4 q$ X4 q
    torch.Size([8, 102])
    : p8 g/ {2 L6 |, v1
    7 }% U% h! E8 a( y; z  W( C9.1 计算得到最大概率9 S9 c; ^6 O9 D
    _, preds_tensor = torch.max(output, 1)
    - @& ]! c) C4 p
    0 k# p0 Z, U/ H" Gpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量) ~  u; y( ~5 z1 A- W
    1
    - Y% B8 o# t0 O- m( H2 b2; x) n  R5 }" E0 g$ {
    36 F. V/ W' ]7 A
    9.2 展示预测结果: q( ]- A; L! F" P' L% V! P
    fig = plt.figure(figsize = (20, 20))% s0 A1 g$ j) O' _) D& X
    columns = 44 f/ }5 k4 V4 E) I' C" u7 t9 V
    rows = 2
    * C& g: s7 A/ [. x: Y" Q2 j3 T" S& n: W2 P9 }3 L/ S
    for idx in range(columns * rows):
    ! J9 F& v7 o6 ?9 K8 A    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    - b+ c# y6 n* \" ?    plt.imshow(im_convert(images[idx]))  [1 L! b# L) p9 e# n: Y
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), . ^  L, M3 Q/ S4 ~. n0 q& j
                    color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    3 j& P  S" |: J) {plt.show()
    6 e8 _# K  r8 X! M! x. z: o* E# \/ k! B# 绿色的表示预测是对的,红色表示预测错了
    / i2 \" d7 _3 t# ]1
    / M- H* R7 E  p# u1 N6 c4 M; h' z* \1 Y25 b3 `5 y" r5 `8 `1 U3 U" {
    3
    7 @. @0 C6 }9 F2 E1 Y& H% Y44 n0 p5 W! }( G
    5
    7 Z# X' J$ L+ R5 a& c6 o4 @3 e* J8 W6
    1 a0 E/ c/ }8 l7 q7
    ! r3 c" Z3 ~" Y6 W. R; ~% E8+ A$ [7 X+ x+ x1 x) G
    90 P3 c/ w0 d1 R$ h0 n& ]
    10; Q1 m; J% u9 T6 p/ D7 R0 ]
    11
    * Z" O1 p* b% D. ^3 c6 [1 O' a+ I1 f2 K- l
    - C& b( y, j/ I& n0 D# t- i" V) X

    ' U! G  |  e, ?5 ~% |7 g+ g————————————————. R& v" s8 R8 F: r5 Q1 z- e
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    6 G" ?( r: F- G8 U; w原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    0 K( [2 n  q2 i; Y
    4 l" [( q9 f) w: g
    " O+ G4 K  g4 Q! h  A) \/ `
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-16 08:15 , Processed in 0.491553 second(s), 51 queries .

    回顶部