QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2756|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
    ! m3 _( g$ u3 |; z1 h
    % i+ J9 m- t  k. @' X9 p文章目录
    5 Y# ~" E1 R7 ~# u2 y8 k0 b+ ]卷积网络实战 对花进行分类
    0 M# G# H2 }+ B) t, A* h( p( i8 K8 x4 e数据预处理部分/ I9 r, ?) H7 n) O7 k$ q
    网络模块设置
    7 u8 ~$ R+ K- M7 W7 M8 N网络模型的保存与测试
    ! a2 o6 R0 Y) j  J( H数据下载:- d# `4 l8 |( `, k
    1. 导入工具包
      M$ F" b- A8 w; I! N+ `3 t( i" j+ d2. 数据预处理与操作
    # u) x& U7 K* `! f: b# c3. 制作好数据源" V$ y5 f/ C3 s# n& u8 \* _
    读取标签对应的实际名字0 d4 p" R; E. h! Y3 z; k
    4.展示一下数据
    8 j& {5 m- k/ ~7 p( u2 F% Y  x5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    1 N# E4 `8 U2 E6.初始化模型架构
    0 g$ ~. l' X" r  T7. 设置需要训练的参数# s% R$ c+ [  L* T, v9 _
    7. 训练与预测
    6 i# ]; M1 g8 z1 R; g7.1 优化器设置
    8 v; c) Z1 e0 o4 [7.2 开始训练模型
    2 M2 W3 k+ @$ k3 u8 ~7.3 训练所有层# S( M5 W" h. X8 i7 Y
    开始训练( v* i% X% D( u
    8. 加载已经训练的模型+ ~! @3 Z9 L; C/ A: O5 \
    9. 推理
    % s6 u' v6 E) t. Y9.1 计算得到最大概率5 ~7 Y) `! E7 v3 T/ E4 C
    9.2 展示预测结果/ R+ e+ E) H! S* n6 S% ^9 J3 Y3 q- L
    写在最后6 F6 J: W' S- [0 D
    卷积网络实战 对花进行分类3 B- Y7 R0 w1 y3 r' H$ A7 B1 f
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能3 J, c! f0 @: _' O( j2 j8 D

    3 H$ j3 {  w( _1 }4 w在文件夹中有102种花,我们主要要对这些花进行分类任务
    2 c; Q- T9 O+ |# W2 u  v/ l1 T文件夹结构0 [/ o$ t4 h/ J4 e# {/ n) s

    ) d4 m, E, e% ^7 s* Fflower_data& Z/ ~7 N7 F; Y% U0 t/ z& B
    8 L, k" A, G' Y1 X' X
    train. C# M' k3 T( O3 E- x9 |3 h+ [

    * O" P- G' {" |0 s& _. W% H5 @* Z: p1(类别)0 `: O. ]! N) q3 ]. O
    2( ~: O  @. t& k8 h
    xxx.png / xxx.jpg4 Q5 H+ Y: J1 H# l" a9 h
    valid
    , R" l: ?9 ?5 R  Y( B
    1 T8 M5 q4 b- R+ F# `主要分为以下几个大模块$ S1 Y  u% D4 h

    8 q2 E! x6 s/ U) R1 m数据预处理部分
    8 e; D0 A  H) Q5 b- m5 j数据增强
    4 o( s  s6 I' o. ], ^数据预处理
    8 a3 s; e( u8 A# A网络模块设置
      D0 u) |2 Z; o! t; q加载预训练模型,直接调用torchVision的经典网络架构% K' }$ |4 ~2 A1 S/ h
    因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务) Y# t) r, B. g& A  |( }
    网络模型的保存与测试. l9 \4 d1 Y* y5 E5 Q" t  ?
    模型保存可以带有选择性
    * o2 y* k  M6 X; x- X: e$ H7 v- E: p  {数据下载:* T9 {* f; q3 K+ B; E
    https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset  m' t/ d& V, Q" r" L

    # U1 O) E9 B! t' T7 ^3 \: D! N- L+ t$ c改一下文件名,然后将它放到同一根目录就可以了
    ' E2 \4 l6 g* A, t; |: E0 \. k, W4 s' W
    下面是我的数据根目录  w6 x7 p; a' {5 [
    " r8 L0 e" W5 f# l
    6 ]+ ^0 W, F4 O5 \* A) v8 d7 p) {
    1. 导入工具包  h9 F9 U7 W  H" ?3 O. I" H2 e
    import os2 C0 {, a0 k# _2 y# f+ g
    import matplotlib.pyplot as plt" W7 A" Q  \; s5 D+ N  k; O
    # 内嵌入绘图简去show的句柄
    : Z! ?3 J& \5 [+ d1 G( q  G6 \%matplotlib inline 0 s4 O+ N+ t* f( }
    import numpy as np" Y- z3 O% V8 W+ A* T
    import torch
    - V: Q6 }- W& u* Ifrom torch import nn
    " p2 F2 v1 B, w" i7 B
    0 ~9 n4 o* u* Iimport torch.optim as optim( ~$ M# j! ?7 q" k; q4 T# l
    import torchvision6 t+ d2 ~2 ]6 x3 i1 H- k2 n
    from torchvision import transforms, models, datasets
    - U: v9 M4 F; O; R! `7 N2 ?
    # m; x: S! o. V$ k6 Y# x8 \  Oimport imageio
    8 k4 I: B& {% M  Mimport time
    4 Q1 [. y9 |4 v* s: Gimport warnings
    # ]$ B  j  |: G# T- {4 Yimport random$ a6 S( N: J5 V" v9 i7 T
    import sys
    . c5 g7 \& h' N8 Rimport copy8 ~5 A) }; q- M, l2 F: ]
    import json
    & X5 q4 O7 t6 H4 }' k; a( d" t; Xfrom PIL import Image. X) Y) _6 H+ N$ r
      R, I: l2 w: p4 z  b( e
    & m( w( J* l4 k: A
    1% f/ M$ [# d" s  j, t3 A
    2
    5 k% M( r( Z. J3
    0 B- a4 X: b& O& j/ |4
    6 ~2 K' E- \* S* e) ]5) g* ?" P: i; T* }# _0 s7 f/ j
    60 q, W( F) @# `7 y
    75 }1 }0 p5 b/ e8 T& J0 L( n
    8" Y/ d4 A- e# j; e
    9
    & p( d* N8 J' T! ^105 K+ _+ {6 Z; D8 f$ B3 q
    11. L, ?/ O0 O" D; L
    12
    ) ^; V9 T; o$ o6 R& {# r, q# ^13
      `7 H: ?8 F" d7 Z# a14
    + y7 x0 S. l7 I- s3 A0 o, @7 m" r! a15
      H2 F) w( G. ~2 F0 w16
    $ H( C$ [" W+ O9 ]17& ?5 k3 T# R) a! Y: r( C7 q( M
    18
    2 M4 z7 @, ^+ t. T19
    : z  {  N3 N8 \: k! W9 l0 _20% {/ T" {: ?- l5 \
    21
    + O) n( v" `! N; K) ?( Y  g( }2. 数据预处理与操作% P0 u& a9 C" W% ~
    #路径设置9 z$ E. T! @2 ~2 n7 w3 ]
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录
    5 W. h, ?* H% N8 j" k9 ?/ A/ n  ~train_dir = data_dir + '/train'- s* r3 P, j/ ]/ J* c) G
    valid_dir = data_dir + '/valid'
    & b* K$ p, l6 X& b# y5 r, H0 X6 R1
    6 P3 {) M( h/ g21 K$ O) j& F3 B5 m! M' F
    3
    7 |1 x, B$ o. R9 j/ w6 H4! B& T% o) w/ u: q
    python目录点杠的组合与区别( w4 m: n2 X: `2 t% q
    注: 里面注明了点杠和斜杠的操作
    ) ~+ ?* L7 S' O' q* Z- m+ O( u1 i& F
    3. 制作好数据源
    " n) }  d7 p3 sdata_transforms中制定了所有图像预处理的操作
    0 H, E4 L5 N: d/ n  T0 k' |( xImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片" Y4 i! j4 {% y
    data_transforms = {
    7 j% T! I' t' k8 O1 ?5 }    # 分成两部分,一部分是训练; l6 b& V4 w5 U% w
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间+ h$ s0 C, K) ]. a: U' b
                                     transforms.CenterCrop(224), # 从中心处开始裁剪
    4 r6 [  O, W' ^2 ?/ c0 g                                 # 以某个随机的概率决定是否翻转 55开0 r! z) K8 m1 b# e' z
                                     transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
    - Y' u- e9 ?! ?. k                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    ' O6 s7 R: I% ^$ L. y                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相* U* o+ s' G" p% T1 Y* U
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),, m% q' {. @; |9 V- Q7 T
                                     transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB2 e* N0 d, A! U9 j1 x3 U" x7 z! W
                                     # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    : u# @! m& h) T' h/ A                                 transforms.ToTensor(),1 {4 `( k2 q# m+ Y) ?
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差/ ]0 j  ^* H. N" a) Q
                                    ]),# Y' L6 m) p/ f" K6 g# j
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    8 X. T3 Q$ {! {  b    'valid': transforms.Compose([transforms.Resize(256),& F  c$ `1 A8 P' |* |) x! g  W
                                     transforms.CenterCrop(224),
    # Z7 J* s& [1 U/ S                                 transforms.ToTensor(),
    ! e2 w; T& c% Q) p( A4 g                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
    : Y* z+ y! ?% |$ @- P7 F& t9 \                                ]),5 n" J8 ~' C( g& {" }0 k0 `6 v
    }
    . X) Z, b- }5 a% @4 `4 \- ]$ m& K) }& L% {! q" B/ \* p: X! l6 U6 M* L
    1+ `2 H4 m# |, c# `6 L' R' {1 H
    20 \9 y" G. r9 P5 p
    3
    - x4 r. l, x$ p9 J+ F* a4: t+ w& O& I' q, m& ]3 ~9 u" I
    5
    6 v  F, y) W- E5 c5 W9 }; P2 `6
    6 u' i4 R3 v( D5 g# o- W7 h7& G- c# w; ~: h/ P
    8
    / P4 K& C" A* v91 y0 F% B- B1 ^7 o
    10
    * o' [4 U+ K4 U1 M5 j( ?  c& U! Q11
    $ H8 U; |" m) u: D7 C122 }6 ?1 N7 V4 ?' q
    13) o9 y5 g$ H0 Y% [% a
    14
    4 r* B8 x( u6 T; k1 q* M+ b15
    ) P: Z: b- r# c7 u16
    $ C& W+ b0 M0 Y17
    " V7 P8 k. m: D18% S8 j/ Z8 W; C6 D+ i
    19
    , W3 @4 Z0 i9 ^8 a% U20
    7 k; t. Y% z0 `4 {21
    8 N$ n3 i3 W) Z1 t! a4 h) W) A" lbatch_size = 8; _7 h: m- B3 p9 B% I5 b
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}" i, }1 f4 ~6 P* i( z+ Y8 u. W
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    , B) s" m3 r; y* Zdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    + v/ s( {  b8 a; {" _$ l. E" ~class_names = image_datasets['train'].classes
    0 S3 j3 q; ?* q% c8 E& N
    - r" W3 M# m% L4 E- B9 v- N" T#查看数据集合2 X- Z" C6 A! u. y; b! D
    image_datasets2 _7 D  \# J1 k0 z' r) ^3 q) h( W
    # T# T# P; B! @2 k
    1
    " U5 @# H- h$ P" V. b& X2
    3 u  |. g& [" a7 u& j$ X8 @" R3* x  \0 M  U0 e/ X5 U
    4
    / l9 Z  G5 x. |# W6 v5 Q) n! [5" a8 ^: F' W  {7 f0 ?4 U6 T7 [
    6
    . d. @: w3 b3 e7
    0 i, ]" y% J1 u4 ~80 V- `' T) U" H3 J! g5 Q+ a; Q
    9
    . {5 D: g& y5 [6 z4 o& k{'train': Dataset ImageFolder' Z' F% I! h: J$ F* y) Z
         Number of datapoints: 6552/ `5 A" v- P  _8 r" `* c# |
         Root location: ./flower_data/train
    . i# ]8 Q6 m% u. J6 N     StandardTransform( e  u, [4 b$ y) O$ Z1 d! x
    Transform: Compose($ I! z9 j. C- s9 s/ W
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)7 C* J4 U8 X7 T: Y
                    CenterCrop(size=(224, 224))
      E, T3 t0 k0 c5 O+ l                RandomHorizontalFlip(p=0.5)6 }2 B: C2 D* D3 q) e2 m* i! W& M! R
                    RandomVerticalFlip(p=0.5)
    ) r+ G0 ]  r* t5 Q/ o                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])2 B5 u8 ], @6 r, S
                    RandomGrayscale(p=0.025)
    / N. y; i3 x7 \. d' o# F  _: _                ToTensor()
    + S8 j% X4 j6 U3 B                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    6 j0 [( D& t: N: e' E+ t" u( T            ),% g7 ?9 C" t! ?  g, r
    'valid': Dataset ImageFolder
    . T* h% B7 k7 B0 _1 K# S, E, A     Number of datapoints: 8189 W0 J7 n  Z- J+ r  P- `2 S! _
         Root location: ./flower_data/valid
    / I, n- x( h. g! _7 l; X* ~* c+ o, C     StandardTransform( h9 G1 ~! \, M8 W. I. g, I( J) Q- y( N
    Transform: Compose(7 h+ I/ u& S7 @, ?* H
                    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None). \1 S# T9 |1 N8 u' K2 `+ j( P9 e
                    CenterCrop(size=(224, 224))$ \6 W- b( u1 T' `
                    ToTensor()  |+ z6 `( X) a
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    1 C. M0 u& J# v            )}+ j/ [* k8 o' }9 c9 r
    + Y3 S' _- ~$ ~8 ]1 y- a
    1
    0 u2 U: G! o+ R. z- E29 _  F  I# x$ S0 e8 @' }! J: }
    3, V& T- r3 k$ t, g. J3 J, o+ b- ^) d
    4
    $ m& D7 W4 [/ g9 c5+ i( {5 q2 r5 \) e) o
    6
    " f' e+ b. f! y+ v8 v! U+ H; E7! ~5 g/ m6 T3 M; A5 C
    8& i, d6 a0 s' Z
    9
    $ a6 a# E3 ^9 f  W10
    1 Q0 a0 d/ I* \. {- Z/ M, `5 t11
    4 G9 z5 i! T5 h+ V3 j+ M1 ]124 q# W! p, ?0 [9 Y5 g, Z* ~
    13  |! z8 _5 F' d7 l! e: |3 i9 u
    14
      x: V; W3 s2 X6 W% x* [$ X+ V. b15
    - o; E% [) r! Q  m, z8 z% s16: k3 c& s# U  q/ m7 ]3 G
    17
    - e, r: P; R+ o! {' H185 A. e# n' o" \3 t
    19
    $ V2 v. L- w6 q1 e. T209 V, D3 h4 h0 x8 n. T1 D! {
    21& @2 t$ L5 i/ h! _
    22
    & `0 A0 q& n* A" Q9 G23" _& x/ d8 ?4 U
    24, `8 `! n% j9 Y8 ?  D- f# J) R" [
    # 验证一下数据是否已经被处理完毕, |0 |, d  f* t
    dataloaders) O$ r5 H9 B2 c& @
    18 _0 J2 D7 P/ C4 z- R0 _' Q3 \
    2
    ; {+ i  m6 p! n" _{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,  G8 d* [, O# F' t. O% A( S! k9 `. H' T
    'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
      g6 V( k0 l, G3 n3 l: L1
    % O- f  w6 @3 o4 o. {5 u2
    7 H5 k; K9 P$ W. ?dataset_sizes
    7 ?9 r; k% S+ w" u2 P6 j10 K, v) j9 m$ T5 M7 z
    {'train': 6552, 'valid': 818}
    2 R3 W; O5 g0 z/ |; O- X$ X17 i% _  i6 F+ V4 \, M
    读取标签对应的实际名字
    , {7 N( Y& r% o) N) g使用同一目录下的json文件,反向映射出花对应的名字5 a9 x: @' M  T" N4 }

    0 ^( Z- B0 H7 }2 G" M- O, ~with open('./flower_data/cat_to_name.json', 'r') as f:
    0 g- S" N2 T7 z4 O- b. P1 O    cat_to_name = json.load(f)/ G+ V: Y6 J7 t& F9 d
    1
    8 @% Y% N% i* L3 K. l2" F( c8 X2 V) ?; k
    cat_to_name2 B* O9 q% S7 A
    18 U* j6 j8 N+ i( U, p* L
    {'21': 'fire lily',7 j2 o, `4 C- H; r+ j4 U
    '3': 'canterbury bells',  u$ z( V, G1 @
    '45': 'bolero deep blue',
    * p* v& }0 u  _) u" L1 ~- o '1': 'pink primrose',2 {9 E/ i" K, v. {, f
    '34': 'mexican aster',
    $ M0 `6 M6 \: c! y7 ~ '27': 'prince of wales feathers',
    8 p/ r$ P# V. V. n% U- @9 a '7': 'moon orchid',3 ]) y) H8 Q1 y3 C- N/ m- o! k
    '16': 'globe-flower',
    5 r% H$ ^& T  c4 q, ]' ?/ K '25': 'grape hyacinth',$ P2 W( K2 ?) d9 M9 J
    '26': 'corn poppy',* j5 u9 P3 Z/ @( f* a5 J$ }( }  Y
    '79': 'toad lily',
    # t/ @+ @# O4 Z '39': 'siam tulip',
    . G& m: r$ O* `- T& z '24': 'red ginger',& Y3 c2 f  y3 z1 H" z
    '67': 'spring crocus',2 J+ n. G# v0 }% F! h% b) L
    '35': 'alpine sea holly',
    8 G1 ]4 R& t/ Y8 w# f '32': 'garden phlox',
    1 M2 k3 P7 N  _8 h '10': 'globe thistle',$ u; u; U0 W4 O$ B
    '6': 'tiger lily',
    0 N) w+ W0 v7 s  p" J, [" _ '93': 'ball moss',, y" U- G. f& K# c( K
    '33': 'love in the mist',
    ) X- `) C: v0 P9 i7 b '9': 'monkshood',
    , I0 S$ k; i0 b+ U$ B# I '102': 'blackberry lily',
    9 @9 s6 y$ \- c- c% n '14': 'spear thistle',
    ; ~+ ?9 h" ~! q! ]# T# l/ [ '19': 'balloon flower',
    , X) p( O2 M6 _, ~. V '100': 'blanket flower',
    ) j) p3 {7 r" e* D9 r '13': 'king protea',/ G7 x% q8 P' M' P% _# ~
    '49': 'oxeye daisy',/ u6 Q( H1 H3 \! j  ^) h( P
    '15': 'yellow iris',& r! \  {4 P4 f$ R
    '61': 'cautleya spicata',9 |5 W9 e+ F1 f1 r# S
    '31': 'carnation',- z8 D& A( L' ^* G5 {" y
    '64': 'silverbush'," [/ _$ \) L9 F" U& o  K
    '68': 'bearded iris',
    1 o' V. _  d" }$ L5 A! C2 Q9 U '63': 'black-eyed susan',* t* [0 b. p# L8 L0 p
    '69': 'windflower',; l% w- `5 |' k6 ~8 U
    '62': 'japanese anemone',
    - j2 p- ^$ a6 G( Q  G/ r& S3 j; H '20': 'giant white arum lily',
    " k0 C; Z% `/ Q8 c+ ] '38': 'great masterwort',
    8 H: s6 U7 n7 f6 i2 i '4': 'sweet pea',6 }8 u6 z7 V4 r
    '86': 'tree mallow',6 W5 s- U+ d$ D9 b+ N
    '101': 'trumpet creeper',- d8 b+ `' N) K* t. G
    '42': 'daffodil',9 x( Z. n- u# }' Q6 x
    '22': 'pincushion flower',4 I$ Q) c) P" `9 m* {6 N+ x
    '2': 'hard-leaved pocket orchid',1 f  v- U2 _5 Y
    '54': 'sunflower',
    ! A3 h; g0 U& M/ g7 C/ d8 k# K0 o '66': 'osteospermum',
    6 ~5 Z4 e& f! w  \# N0 O& c9 P' E '70': 'tree poppy',
    4 c; j" @/ h/ H% Z '85': 'desert-rose',& P: r! ^' \: G6 e" d
    '99': 'bromelia',
    * x- b8 m& n& {' i" e% H '87': 'magnolia',# N9 m6 |7 \  W
    '5': 'english marigold',  y' i/ \9 v' M3 T
    '92': 'bee balm',9 Q0 w  D( M( k
    '28': 'stemless gentian',$ ?" N. F) m- \. ]( y
    '97': 'mallow',
    5 }7 u& ?: E+ G4 ^ '57': 'gaura',7 z1 [  o9 o8 v0 \% F
    '40': 'lenten rose',& l/ Z: M/ F; @& V
    '47': 'marigold',
    , n- T3 f! U* M '59': 'orange dahlia',& q1 {6 j4 |# T: ]) c9 z
    '48': 'buttercup',
    7 B6 U& G+ E( y. j% K. y& u  g '55': 'pelargonium',
    ( ?3 y6 _: h. H4 d% U- Y '36': 'ruby-lipped cattleya',9 q9 {/ f1 {- A2 `: e0 E4 O; N
    '91': 'hippeastrum',
    * M0 y/ Z8 q0 u/ N '29': 'artichoke',- F! z4 e7 s+ D4 k2 i& l$ x
    '71': 'gazania',
    $ a& c! H+ D6 |) }& R4 t% D0 Z7 w '90': 'canna lily',% k) m! H! o' a8 x/ k& K& N% ^
    '18': 'peruvian lily',1 i7 m2 Z5 ?6 w
    '98': 'mexican petunia',* P4 T1 D3 E8 Q, U2 R
    '8': 'bird of paradise',. T; ]% m- U6 H! I# i5 n
    '30': 'sweet william',. [9 @* K4 z7 }9 Y
    '17': 'purple coneflower',
    % ~" B& a. q1 V2 t '52': 'wild pansy',
    + K9 m  n% B" U6 } '84': 'columbine',% K! F& r: f! z4 L$ i- O4 c
    '12': "colt's foot",. E' g" h+ k: p* d, V9 ]( z
    '11': 'snapdragon',
    " W( r4 \) O" V) E1 x4 ]* T '96': 'camellia',8 g4 l# e# R' M$ j, {  M. O1 x
    '23': 'fritillary',9 O: q- s3 Z' J' F9 |, V
    '50': 'common dandelion',8 k  R4 |  C9 i
    '44': 'poinsettia',) b% O: @7 c6 r7 S% N2 |+ N2 G
    '53': 'primula',/ e6 `$ Z( e1 {$ `( `3 a6 l7 ]3 @- [6 v
    '72': 'azalea',
      q7 E/ p& ~/ M '65': 'californian poppy',
    3 _8 }% c; _1 F '80': 'anthurium',
    , w* H/ W; M% t( D/ d9 x '76': 'morning glory',3 \: Q  ^) B5 V! q5 l
    '37': 'cape flower',3 N, k4 [8 p+ h/ p
    '56': 'bishop of llandaff',& d' p4 [- N, \  P8 C
    '60': 'pink-yellow dahlia',# d, B' G) ?. w, {' S0 O
    '82': 'clematis',. |) M  w2 p4 Y8 W+ o
    '58': 'geranium',
    # [6 w+ h* t: t3 ^- D% ?1 P8 ^ '75': 'thorn apple',+ e1 Z& c! y$ |$ K* Z( {
    '41': 'barbeton daisy',
    , }( a+ F2 Q0 a4 Y/ g; H; z '95': 'bougainvillea',' z2 h& j& h2 ^3 g5 b
    '43': 'sword lily',, @  E- m* N8 a8 @4 l1 F! k
    '83': 'hibiscus',! h$ v/ y  q+ h& S% H! _5 [
    '78': 'lotus lotus',* I4 e( @7 c! e4 o
    '88': 'cyclamen',/ m' s( m+ ^( e2 j' M5 V4 f$ c+ d
    '94': 'foxglove',/ g' j8 G! h& ^1 Y; `0 w
    '81': 'frangipani',/ H6 }2 t9 v) d2 h+ P2 m7 p
    '74': 'rose',
    9 ^; p$ f% U8 G4 q '89': 'watercress',
    . _9 t; d- r# |+ K8 N1 ~6 ^+ z '73': 'water lily',
    , f: P9 b  Y4 h' \. L' U( D9 A! G+ d '46': 'wallflower',
    6 D) c  G0 ]/ ^8 ~5 r '77': 'passion flower'," R( F# a4 P8 P% {  ], j
    '51': 'petunia'}
    0 z/ W$ m) {! ^! Y; m6 }# _/ j& i8 W' d9 G0 H' ?0 ^
    1% K, J! d& L  H8 c5 u0 {
    2; i/ B/ P7 K7 a& h- ~6 Q* a( E9 d, n
    3
    ! k) h" T5 G) G4
    ! p! o' Q9 J) f+ |, h4 c5* I" f6 t' O+ j7 i1 l4 V
    6
    0 u5 F+ D) Q" i, S) }, V* A76 b* ]% R& |0 J. q1 L( Q. P- B3 S
    8+ L7 \$ N8 Q7 g: s8 }
    9
    + Q5 @- o+ t- y10
    $ ]# s- R$ U9 x* ^, V11
    1 W0 |4 c, S7 Z, k1 [& j5 j$ D12! x% I" `  l7 Z7 [% k# Y
    13
    0 o; a$ h: }+ o9 |2 ]: I% F2 Z144 [3 n" A. u3 @& E* I( w
    15
      l9 s. s  Z" Y1 c16
    ) U/ Q, ^: V* K- }: I5 N; v0 C17% M% ]$ c9 L0 w" w4 n
    18
    4 m4 P5 G7 H7 b) @19
    # |9 `! `1 y* ]0 W) I20
    ! d" \% x. I" t/ u& V21
    3 ^# N" {, U( f7 d9 z% F6 ^22# v  ~3 @8 Z2 P9 q
    23
    2 v4 a7 e9 f4 G' v" {24: `' @: [% [( _6 m) W8 {
    25
    ' o* }2 b5 A5 a, a. H9 C* n; V" F26
    " ?& n! {4 X$ ]" Y6 x$ j7 b270 [$ k$ ^2 M* q  p
    288 u& e+ g' [: }+ |% y7 ^
    29
    ! e/ Q6 f6 A& G$ e: D, V7 w30
    4 K2 d* Y. l/ |9 l* C! [31
    ! M- h- c% C1 C32* j9 n4 e1 g- \/ B  Q9 L/ U. y  P0 ]
    330 b1 b/ V+ P" [8 [+ N+ f+ l
    340 E5 `- x7 ?8 n) F! c5 x4 x
    35
    & l1 R8 K% q1 T  l6 r8 m36' g" S* f/ i( F9 w1 ^4 h
    37
    9 J  N1 o0 v: k  F* F2 B5 I# K* k) C38
    . X7 z6 K# p8 o) t. l; m* E39+ ?5 K4 g0 H0 n) ?5 A
    40
    # a! l, u2 Q- k41
    ) i# i, |9 O) P; n0 v42
      U& z1 B$ y- Z* j43
    # G8 G) m5 ^, G( x44% x2 @" S: ?0 l4 Y& {6 d# d
    45
    7 A/ K. y5 s7 `* n& `2 T467 M7 X% A$ w' ~- U  [" X5 J
    47
    : t( V% w0 R: V4 D  I  ~481 e7 S# E1 J; h
    492 F8 F1 ^# d8 v6 N: [- y) n1 c
    50
    - j! S, q# s; S519 q4 h: T0 j) K3 ^& K
    520 s2 R, W  Q8 E6 B
    53& R, K% W# H+ ?7 m$ g
    54
    % P, V( z4 p' ?551 p  W/ K  {$ k7 @
    56+ I; [( S8 R# [/ L% p7 t; s3 E! P
    57, y% r2 {! L; a" G. `
    58# t. C  O* ~- Y: ~; s  i0 `; X2 V
    59
    . F! F* z) I. T% M- U! v% k9 }60
    ( U0 D& n: O9 |3 k61
    ! R, E4 p& F: o6 W' H1 s  z/ y62, a; ]' {, \, Y* W! [3 v0 s
    63
    . N# Y. s6 F- ~* `0 m% U' a7 T64. b# _* {  z' X* y" {
    65* i4 e+ P5 h4 V6 b4 H
    666 C3 @& Q; R6 l* X4 j
    67
    6 ]) L$ `0 z6 h# k6 B68& l2 X/ G( B: S- l
    692 n* Q+ w9 ?6 ~) a* T3 B' R4 a, F( i
    70
    9 M% A; q% W4 i3 W715 y/ l  J) K! m5 h
    72
    # _! _8 Q( C  \% s  m73
    " `' k+ N; M- L; B: a: ?& S74/ b( @. A3 _9 }+ i- y7 H" U
    75+ s0 A( K+ h" z3 v5 @& j2 S0 l4 _
    76" F5 s9 Y  n2 n" ?/ y( s# S
    77; m6 g2 w+ q4 P: O
    78# D/ G" k! \. L3 b$ j3 K  j
    79
    4 q5 l( f- C5 v6 C- S% w80- M) u# t  f: a, b
    81
    3 h% v! f# c" I( M' q5 M4 Q) Q824 j* r; W" }; O& G$ w0 c9 A7 p
    83
    1 Q9 r$ l" l6 c; P8 U4 t" d84- S& g" H/ d- u* T
    850 a; T4 A1 X9 \2 v$ ]
    86! `, w. o0 b4 w% r- p/ ~
    87
    7 |( V0 k- X5 T' a887 ?# o- W* A6 e( ]+ T: C4 X. J
    89& [) M" B( g5 o3 y6 ~
    90
    # Q+ \% w+ N4 Y- H! F7 Z/ W- f919 G) B' H9 t2 y9 Z/ P
    92
    5 k, J& `. I8 {2 ?2 Q93
    " Z( r% ?! ^7 p* \, y$ \! @: m94& C* n: P* v. u" f& a4 j' j6 I
    95
    : }) `5 `6 B( E. C8 {96
    8 J" Z; Q" ?+ S7 w97% y- z+ q4 v; A# E% c
    987 J  P1 q* w+ Q2 y" f' A# l
    996 f/ y( e$ r# u# S9 J6 g
    1003 a$ ]- x, {' |$ p# R' {6 U
    101- K- K( Z6 p- q& K7 N3 h+ B* D' |
    102
    1 C. A8 X2 z9 t; G, _9 e: `0 F. n+ i4.展示一下数据7 b4 F6 `- L1 D' D4 B. r/ D
    def im_convert(tensor):
    ' D' {9 M5 w+ [( ?! b2 A    """数据展示"""
    ) n1 p( H8 P6 M- E6 C' q/ d    image = tensor.to("cpu").clone().detach(), y9 m) L0 V% C/ s, I
        image = image.numpy().squeeze()
    : Z# J9 G2 T7 a1 O9 ~    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图. o$ W9 E" y0 x( j8 S6 T) o2 c  Q
        # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
    5 P5 j3 B$ z  v  ~, `8 Y9 Z    image = image.transpose(1, 2, 0). a, z3 f" d* t& z- w. N
        # 反正则化(反标准化)# d9 O& }# Z& ^& ^
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    $ D% F/ u* W* f  X+ I: o5 W' D
        # 将图像中小于0 的都换成0,大于的都变成1* b; F3 J1 G4 s* H$ M
        image = image.clip(0, 1)
    ) v- K. u) s5 o4 g  A4 C! }
    " F3 c# \5 L% v7 G6 Z    return image
    9 f6 d# D& {) `8 I3 `( V- i1) E5 \" _8 [/ [& ~3 ^
    2
    8 Y1 K" {0 Y3 b) i+ @" X3# [0 a  g; w( b+ \: O. E
    4
    3 f. v. X0 `; I& X55 ^$ X) ]* I/ h
    68 z- D/ N6 i) i8 a9 ^, z
    7
    9 B! }4 I, a! P: s. G8
    ( I5 e* Y1 d5 b+ y9 `6 s9
    9 A  Y& Z5 [3 ~' z# k104 a# x1 r. m' P- v( `# _
    115 J0 C- n1 W% t1 l7 F8 i1 g
    12; ^; x1 C. V  P
    13
    & X( F0 M: W' H: l14
    $ s5 a, R8 x  w' e9 D5 \% r# 使用上面定义好的类进行画图: B7 [, [. m1 n$ M0 |6 T8 ]! m! L
    fig = plt.figure(figsize = (20, 12))* J* ]3 L% {( y, u. E
    columns = 40 [) r4 _; \6 k- c- y7 X6 M9 ^8 g
    rows = 2
    ; n7 z2 W& c6 k2 v- Q1 D! K  v" e" m% @& x1 I
    # iter迭代器
    ( \) P: e7 s" y$ j- W- Y3 N# 随便找一个Batch数据进行展示
    * o- N: e+ ~# d/ _, kdataiter = iter(dataloaders['valid']); B. u7 v& r5 P$ Y& Q5 E
    inputs, classes = dataiter.next()
    . v8 u. R7 Y$ ]- k. O$ P1 G& ?! ?; `
    5 P3 d% s; G2 u) u- `' K2 m0 ?for idx in range(columns * rows):
    . S8 Q5 \5 S. ?* ?& j3 l    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])0 ^8 K3 E' B5 c" S% j% K
        # 利用json文件将其对应花的类型打印在图片中
    4 t$ K( O, R* p4 t    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])# _2 e6 P2 x) i/ @$ \6 B
        plt.imshow(im_convert(inputs[idx]))1 O- [1 J, S7 B. z  F
    plt.show()
    3 w8 J( s/ _- j- b: x
    3 i( j0 W# w1 Y; l3 [1. W- T4 z3 N0 g: \5 o% }
    2
    3 y( I& u7 i3 N3 R4 \/ m38 i9 ^  ?. I5 @! L/ A
    4
    6 l* h2 _, ~: i, t# `6 G5
    * c: u# K" B4 G2 B( }# j6
    8 }! A% Y8 o; }7 t2 r7
    " G/ ]& \& `' i# E83 U: V0 i2 P/ m6 D# T
    9
    5 {0 D$ D3 m: |10
    . F; W0 c2 `) n11
    8 }; ~3 I* {7 V- }' L& X12
    , B  B% `7 [3 c+ @- f13/ f' g5 ?7 Y" m5 b
    14
    & e1 b. U% S6 K8 Y! g; d: G$ k15! x; D! }6 L1 @3 E
    16
    * W8 e6 s# a8 q" [5 ^' ^
    ! R3 k5 g- ?  t6 x3 d8 r
    # S# b2 U8 z# q# z- r' f5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    7 d5 G8 h% P+ i7 }$ w8 e8 pmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
    + f3 w* f5 P8 ?5 T' {( I1 M# 主要的图像识别用resnet来做; h" N) @% }5 r4 c8 V8 F
    # 是否用人家训练好的特征
    ( c" n" e/ g; P8 I" f1 yfeature_extract = True
    1 @$ O. t2 Z$ |% v1 M1
    6 f; Q! L8 M/ e, O1 {9 y) U/ |27 B9 h) X) P5 {$ r( n7 K0 N+ f
    3
    . d* p/ M- _8 M+ P8 T4* `$ S; Z( E+ D: Z9 C5 ^
    # 是否用GPU进行训练# A( [. d4 r0 A  f' g. u4 C
    train_on_gpu = torch.cuda.is_available()/ ^' J+ T' t6 ?* D+ _% i% R* H

    - ]" H( |8 D/ B1 r/ ]3 n$ B4 ]: \if not train_on_gpu:
    0 q$ L$ C' \' e5 k/ ^9 g2 f- p    print('CUDA is not available.   Training on CPU ...')
    ) m* H! G# O4 r1 Zelse:% p3 Y' `* Y& R+ g, E+ H
        print('CUDA is available! Training on GPU ...')* I1 z, |$ H. i% @5 v9 C9 E
    - B5 k+ [, @& A" l" G
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')$ o. t# A. ?8 b; v8 E
    1( r0 j6 [# e1 i" h6 w+ T9 w4 I
    26 e, v: ~- `: ?: E* p& z
    3
    5 W+ k3 @3 x$ q% A4
    9 G0 z7 ?( ~3 r, ?' \7 p/ m, g& I5
    5 t& f4 L; ]. \4 _( s68 p- t5 r. @2 Z  W: ^: h# {- n. H
    76 u2 E5 Q! X% z0 W% d0 [9 O3 x
    8  a( G0 A( l% G$ z2 Z. @
    9
    1 T9 a/ X7 c3 ]  \CUDA is not available.   Training on CPU ...$ V8 x$ A: `7 H! K  Q5 @- S
    1
    6 S" O4 y& }$ }  Q# m# 将一些层定义为false,使其不自动更新
    , A) Q- H. x5 V8 ]; c: sdef set_parameter_requires_grad(model, feature_extracting):/ ^9 U$ e, Q$ B4 X& e8 \7 U
        if feature_extracting:
    : U7 d% M: R" A2 ], ^, z        for param in model.parameters():
    - D5 _6 o4 N. i/ }; ~3 e& }- m            param.requires_grad = False
    ) E' R4 r: x2 F7 L/ _- ?# [6 I1
    9 t6 V( d$ c3 M) b% u* f3 V2, T' B. n+ w( b" b: j* m$ [
    31 K2 y: Z% o* d2 z) y
    4( E, z, `) ?2 k9 s; j
    5
    4 ]  m/ T2 N( [% z# L  |3 S& l# 打印模型架构告知是怎么一步一步去完成的
    3 ?0 v% U2 ^: I( s# 主要是为我们提取特征的
    - Y; f* c0 Z1 i8 x' p5 Z2 x# @" d- z( ^! L( b0 Y
    model_ft = models.resnet152()
    : [) n- K+ {% J; @0 D- nmodel_ft
    3 y9 B( ?' r+ H* J- q  H2 ]1* n+ C' t3 H8 |6 {9 ^& `# Z' p/ }( k
    2
    " ^- ^( r/ A* h+ [0 u: `3
    2 B# [- ^2 L0 g" w. |  X, A4
    5 C4 X$ Q& p, _2 a0 }' u: n3 U9 Q5
    . f- I/ V- o' L$ }7 [: q2 OResNet(, P% T+ W8 m" `7 C: p/ a& y9 {
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)' O. w/ h( O6 Q1 K6 T+ ?
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) I+ E, z& y. \) w
      (relu): ReLU(inplace=True)" k" K/ @& y+ s9 g
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)) E6 \' b) E3 b4 `4 h, [3 v
      (layer1): Sequential(
    ' r. g$ Y3 ?/ ^2 Z$ L. z. R" B7 x    (0): Bottleneck() ]; z4 N0 q. N$ Q; Y" B
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)4 X! m2 ^0 T8 |$ n" W7 h
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    1 B5 v# l1 @% G' h4 X  [8 P      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    + ]2 g/ ^/ S( k8 z& G; ^      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ' R# v% X; U# b7 J0 |      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    ' Z* w% @1 y# u( H      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)' D8 T; A. `( [. n
          (relu): ReLU(inplace=True)/ \/ ~( C# {4 E( U- L
          (downsample): Sequential(
      }* e! Z5 _$ b        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False). n& `& k; B  H$ }( W% }/ g) h
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    + A+ d/ u2 N4 k, N7 o$ `) Q      ); D" D( I. m+ m. v' O  _! h
        )
    , z; x1 F# T6 {$ G3 R! x7 p/ q. A中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
    # B3 \* g! }, `: y6 Z! R    (2): Bottleneck(- w% i# O. c! i/ |
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)' E, x. X( \! @3 x
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)( K7 C0 ]6 r$ {, z# C: \$ }8 c( e
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)  U- F0 }+ ?" h4 _
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    % ~  U& T  G! z( i: K7 r6 q      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)( Z/ j- P# i/ n% W$ |
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)4 x) v8 ]" E7 {
          (relu): ReLU(inplace=True)8 H. n: L4 t& x, M' ^6 C
        ), W0 y" {  Z$ _
      )
    ) p* x6 q" r$ F, p# N/ p6 S  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    - M) q+ v/ N3 d) X/ }5 `  (fc): Linear(in_features=2048, out_features=1000, bias=True)3 u/ S* k5 E: i, o+ ^$ P3 W
    )& o! f2 E/ _, F

    5 k6 L% _/ S6 y& }& e( H/ N11 P1 @6 V6 W: K" G/ ~0 j
    22 Q) g" U  T7 t" j
    3
    7 D! G6 e/ F1 a; `4; r; w; E0 c' w- f( D$ B
    5# W+ {( o2 H# T$ l: ~+ U/ p. x
    6
    ( E# t' K9 Q% \! r7 A71 X5 G: n6 H8 U7 v- x/ g
    8) Y4 F2 N/ E4 [
    97 N9 p" L) t! Z3 }3 z2 t
    10
    : \' R/ \4 R( N( c# Y6 J11: d& u4 _8 A' _) ^2 {" M
    120 l9 [8 k# _/ N* M! I# [
    13* O0 g2 f& S2 k% ~9 e/ F! T; d
    14
    $ Z* w/ }" m  n6 ^9 B3 T15% }6 k: S8 n, o  B+ c
    16# _/ Y" ?, f3 ~2 J' Q6 V; p
    17
    ! {+ I, o1 I$ H6 B) y5 g185 T3 C4 y8 ]9 C4 n: g8 H3 c  J
    19
    9 }7 v: D( t( c- {1 Z  |206 C7 P* T' M( q/ c3 }8 B
    215 q  @7 T9 U" O, m1 O. w
    22
    5 L6 V2 v( ~6 a23
    2 e3 b4 R- M7 j$ i; I8 I240 ?0 A5 t! }& X+ }# A: W. e
    25
    ; s! S$ I  D' E# T1 n3 o% `7 I26
    ; j# b  ~2 j" N; J4 D5 ]27
    4 W9 a. r/ |, u) h, S0 Y' E28" F' w. Q8 ~: }+ p/ a
    29
    $ O, h  k7 Q. N4 N30
    2 k* r9 A9 ?6 w$ g9 w4 N31* y: P& b+ x7 s! E8 c' F- y
    32! q! n- n- q" b9 z" _( ^
    33
    / {' t4 Y: }) V  I- B& v, J3 e9 D! M1 E最后是1000分类,2048输入,分为1000个分类6 ^5 l0 J! W9 _% C! w, Z! ~
    而我们需要将我们的任务进行调整,将1000分类改为102输出9 x1 d+ `% {; p, O
    / V4 H: K9 x' l3 h0 @
    6.初始化模型架构
    : p- l. n8 S. \8 Z4 f' N  Q步骤如下:
    % i6 T3 @! s; V0 \) Z
    ; v1 h% Z' U9 s6 o将训练好的模型拿过来,并pre_train = True 得到他人的权重参数3 g( k& t( s" u: i" d) O
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)' Q' X6 X% {6 Q
    无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
    & Y# f8 B) ^& P. h官方文档链接
    % H! q, ^: S+ Ihttps://pytorch.org/vision/stable/models.html
    3 Z5 S& H6 X/ z
    & Z  R' `  Y& @# d! c# 将他人的模型加载进来1 D5 H. t4 K) s% K* y
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    - v3 R4 r( @4 a  M; @) Z8 [    # 选择适合的模型,不同的模型初始化参数不同" ~6 n  O  n% n4 ?
        model_ft = None$ z4 D- O0 u# A' ]4 B8 ?
        input_size = 04 ~. ^* E8 U: U# q, B

    # ?. w% O# _6 @  ^( _6 @$ i3 {    if model_name == "resnet":) B# m+ ^. u' U
            """0 M) T5 r) Z) Y: n; Q) C* E7 k" _! i  W
            Resnet152
    4 @% A7 a# m' n8 i* i2 D        """
    * h/ }2 M% G" k! [$ W7 n* S( _
    $ o$ A8 d; U8 k" [        # 1. 加载与训练网络
    ) ^! ], `; |8 G: ^        model_ft = models.resnet152(pretrained = use_pretrained)) f2 }/ N# V& Z0 E! B1 h
            # 2. 是否将提取特征的模块冻住,只训练FC层$ ]( b" s7 a3 _  Y
            set_parameter_requires_grad(model_ft, feature_extract)
    . ^& P: C' g) c3 Z6 a' ^0 |( D        # 3. 获得全连接层输入特征3 Z" r0 G& G  l: s- }
            num_frts = model_ft.fc.in_features
    # S. k" J/ e2 y  _1 E* H$ C3 I8 H4 Z3 W        # 4. 重新加载全连接层,设置输出102
    4 y- k( s/ V) L2 q7 M0 g        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
    & n7 R8 }2 ^3 r" C$ i3 z. V# g                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
    % t0 g9 ?4 O" j2 A; p        input_size = 224
    * g2 `2 t: S& [/ n# i+ N2 x% y2 J/ y
        elif model_name == "alexnet":
    ' [! b8 o- t2 d        """# Q# S0 r  \- n- F
            Alexnet0 w' H! C7 _- X0 w  P2 G0 n" J
            """
    ! t" G. m1 Y  J2 H        model_ft = models.alexnet(pretrained = use_pretrained)! a, R2 `+ Z6 Z* I: k7 V% S6 K; x
            set_parameter_requires_grad(model_ft, feature_extract)* Q9 s/ f4 A; Z/ S
    $ h0 u1 t+ R; f' i5 K' i$ c) v
            # 将最后一个特征输出替换 序号为【6】的分类器) Z$ L6 D. S; m3 e, |  L
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入
    8 W# {7 W) ?" Q4 \- _7 F) v+ ~' y: ?        model_ft.classifier[6] = nn.Linear(num_frts, num_classes), g! w$ X5 _7 B4 c$ n! o. n) V: B
            input_size = 224; w. b' |3 Q8 D( v% s; r  W9 G+ X

    3 K% W! a: b0 S3 R& w) r    elif model_name == "vgg":
    5 T9 ]  `( J& }' Z: q/ \9 h% z1 O        """/ S* e) V+ @# f- U  q6 c( j
            VGG11_bn
    4 c! J" S$ z$ B* d* V' p" X        """" r; G0 U1 c+ y
            model_ft = models.vgg16(pretrained = use_pretrained)" z* h* N/ w1 X
            set_parameter_requires_grad(model_ft, feature_extract)1 j  P9 S2 L( Q5 Y! W
            num_frts = model_ft.classifier[6].in_features
    & o" f  Q% h0 o5 Y        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)' M+ n+ i, G0 x. [/ m
            input_size = 224
    " i9 v2 F  z$ H  A; i- }( y* q4 x- c/ \! i
        elif model_name == "squeezenet":
    ! w8 }4 U" G" T1 {! |: @        """
    - r! w4 \$ [4 w  r2 w- c" x  Y: J        Squeezenet: G1 V- ?' F2 U
            """- _0 e& d, b: F1 y
            model_ft = models.squeezenet1_0(pretrained = use_pretrained)
    2 M4 `) d* ]: @. G6 ]" w        set_parameter_requires_grad(model_ft, feature_extract)% {* y& M3 ]9 @
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))3 c$ w) A5 {; q6 D/ j! F# R0 o
            model_ft.num_classes = num_classes% s; L! g, {3 b% E+ u2 i& B
            input_size = 224# _+ q6 ~: |8 W/ P6 d

    6 P  }  n7 ~! p1 x3 U6 H    elif model_name == "densenet":
    6 u" ^# Q, p9 }. k) m        """; \5 n; t0 f: K0 G, y
            Densenet6 w$ [2 p! N* V% P
            """! G# D/ y5 r: e( e' [
            model_ft = models.desenet121(pretrained = use_pretrained)' Z  s' A* S* E' F$ O
            set_parameter_requires_grad(model_ft, feature_extract)( o1 d( |& |! M6 L4 r
            num_frts = model_ft.classifier.in_features
    - e, E) ^! }- V1 {& i        model_ft.classifier = nn.Linear(num_frts, num_classes)5 `, b+ t% ]- b2 n% g
            input_size = 224
    9 `- L4 N' y5 w# k7 t& t+ R, ?# v0 Y2 ~7 r8 M' S3 k9 H  V
        elif model_name == "inception":
    ; J; u+ t8 H- r4 o# A' e        """
    2 |$ U5 j- f: G- P. z2 F" [        Inception V3( T1 V& U+ F: _" y  R
            """
    * O0 p5 C& n/ L# o' C        model_ft = models.inception_V(pretrained = use_pretrained)- L. l6 ^+ l/ @  P4 o2 ^
            set_parameter_requires_grad(model_ft, feature_extract)
    . o! a/ k" p! J( Z0 Y/ D- Q' t8 Z3 p1 P" s3 C  {
            num_frts = model_ft.AuxLogits.fc.in_features
    2 G% K" J  r7 P! ~        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)3 ?2 T+ D; f% {

    ) X5 i* k3 V! H+ t4 }        num_frts = model_ft.fc.in_features
    1 L  U4 G" v* N0 o6 S! i; D* P        model_ft.fc = nn.Linear(num_frts, num_classes)9 X% `0 }! a+ D$ s. l" c
            input_size = 299/ \6 T; ^3 u: q, H. J

    7 F; U7 r* k; b9 U    else:; ?! ?3 W5 q8 X; S1 f
            print("Invalid model name, exiting...")
    4 A" G5 h: V% ]2 s! \        exit()
    ) V$ `0 H2 T3 m: A
    1 a3 f5 i0 G( J    return model_ft, input_size% c+ \0 Y+ H4 R7 Q3 t% d' e
    / H8 }/ L+ G8 a
    15 _+ k9 }' [! X0 C2 _9 g
    2$ x/ G7 y. \; i* O
    3
      N5 g$ r$ e$ R. K4 I6 `7 D% c8 @4. a% z2 y3 C: B* f3 i
    5
    4 O/ O# l& r0 X$ D5 T6
    9 V& G; W% S' O. Q. r. X7
    2 u1 L2 G) q3 Q& k88 U- k' f9 S' i
    9
      T% N) U/ C+ n/ Q- {& A10
    & {* D1 ?6 g! `+ e9 F0 r5 [- f11* ?! a& K) ^- t: K3 D7 _- M
    12
    5 F# {  v2 j0 k. O( i0 i! E136 [" p* C3 [1 X4 O% i
    14
    ! E8 ]3 T1 _5 \! K: B6 H15
    % b  Q9 n* S, o( _16- u% ^5 B$ f9 Z# D! o( I
    17& M: ~$ ?0 r* x; O
    18
    8 B/ i% |! C! S! j/ c$ V2 Z/ ?9 \" ?  ]19
    1 v+ o9 z) ~7 ?/ v0 N! h8 p* B200 H" ]( ^7 p5 L* d  ~" Z6 `
    21
    1 G3 e/ v  i( L2 S" E- @' V& a224 @0 r" `5 v$ b! ~" X. {
    23' P: e0 W; d, V9 E% M* v
    24
    ! s) R* ?3 s& a$ B; k7 b25! Q  [; X- D/ n4 a
    26
    $ h  u; m" c# Z* }- U27
    4 B- ]. G3 a- N5 l3 E" Y28: @+ G" z) h, q
    29  j. q* D6 A/ {+ W
    30
    ( }5 q8 M) O8 O  e# e6 X( B. [2 {) O: \317 F% h% d8 M. D; K
    32
    4 b& V0 ?% K3 J8 |, U' f- Y33
    $ Z7 L) O  T& A% W( q" T  P34
    # s" d( w9 o! B, F35
    $ w' f: \# r  Y7 g36! m9 W# }! {% ?: U/ i. e1 K2 r
    37
    . v  y& l5 ^' e5 B38* W/ ?8 x1 u5 L- a) g
    39
    5 h3 P. d3 _" E2 k! V% u40
    " ]0 ~# F3 q3 @9 n  v& |2 ?3 Z/ F2 R415 L# C, M- w$ P2 E
    42
    8 n: t) r* }6 T: t43
    2 p$ @0 I# C: F8 U( [$ K& @44
    + ~" u  X& B/ ^$ i( _45
    : N& f% S+ ~7 y9 B4 E$ m46
    $ i; u- W* D2 G477 h* Z8 E% U2 H: m
    48+ z' x. Z! |4 T. m
    49" V* f/ g' _7 p. c. X  f
    50
    . m& [+ v- w" Q( p- }2 ]51& i6 Y0 t7 X( l# S' y% _
    52
    & x5 Y% f' M. J( V2 ?53" s5 h& \) J1 x5 T* N
    54
    + D. s  @! F& x55
    5 H: D* j3 J4 q. w, a; u$ {1 v6 w56
    1 J8 n% }3 R- [) d% \57
    - J: E+ _, k# H9 c- L2 B! u584 a& D+ U. Y+ T# y9 L+ o5 {. c
    59
    + K! G: a/ H, @60
    % v$ m9 }) n" {1 F) j1 P: {% @61
    / d0 o* P2 }1 a( O626 _) U- ^" _6 L
    63& ^6 u; j# f4 L) m
    64: T: B  m- O8 R! C# O* `
    65, g. v) Y+ F- h5 S2 y( a" i) l
    66
    7 M2 \5 v' L0 u% S- b, \4 ~8 U- V672 e$ D6 ]/ a8 ?: d  z: g
    68
    9 Y" [3 N: {) _69
      X7 P. \! _' X! ?4 ?8 O, n70
    7 `3 S3 ?% ]: G* L# T5 k+ M71" N' E- `3 d; n3 ?: K  e8 B1 h
    72% b5 O' t4 O! S5 P* o
    73
    / p1 E6 w/ e+ B& O0 c74' k, K. t; d$ s: E+ y% B3 P6 w% t4 E
    758 _- I+ c0 ?7 F# W9 L1 M$ w) Y( r
    76
    . X3 d+ N$ j, q" Z4 {+ H77
    ) [* \0 h+ J( {% l/ o) Y6 y78" I$ z; l5 T! C* Z/ N8 ]( v3 W
    79
    6 k2 s7 N- W" O. c4 `+ h807 ~$ q2 Q3 \4 \
    81
    % B1 M, H/ Q) K2 R- d$ h# d0 j82
    : F/ X- C9 G1 i* d; S6 `3 i" k# a83! s5 B( m7 H) w; |; V! M
    7. 设置需要训练的参数
    2 K! I6 x9 M! F+ T* ]# 设置模型名字、输出分类数' w+ G0 X9 \8 t# k
    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True). e" |! v. [! ^
    , n& o; _, g9 _, y; y' |& Y
    # GPU 计算/ i& f% t8 z/ h
    model_ft = model_ft.to(device)$ i5 y! I- Q  w$ P5 j6 a" r

    ) z, Y8 D( x9 F- s# s/ ?; H9 I# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取. G3 r) L2 n% A- `, u; V+ f; \$ u9 `. @' w
    filename = 'checkpoint.pth'
    7 a: S* Z: N/ M! b7 ]4 _- I5 J5 h' }! U( B, N
    # 是否训练所有层3 s! l/ @1 o8 n4 y
    params_to_update = model_ft.parameters()4 Q0 J& |0 ?' V9 C
    # 打印出需要训练的层
    + K8 ^6 `% @$ a8 i2 \print("Params to learn:")' Y1 G* c+ O7 ?
    if feature_extract:/ W1 p( I" v- X( f  V# K) ^7 f
        params_to_update = []
    5 ?. k* U% o# X2 Q4 c0 j    for name, param in model_ft.named_parameters():
    , z6 c- G* B3 @8 B- Z9 Q        if param.requires_grad == True:
    + u5 A' k/ {0 [            params_to_update.append(param)
    1 V- H2 p* o7 E! a% R            print("\t", name)/ d9 V3 l% ~4 ]  ]; m0 V
    else:
    * G6 F" y+ _) z9 k! N3 p    for name, param in model_ft.named_parameters():! J% _; x8 k8 x  _) R) D$ E. E+ R' J
            if param.requires_grad ==True:
    ; x/ c& V- {+ M2 C            print("\t", name)1 i' z0 I2 M% e* S- L  |4 r. X$ g
    ' I! N5 `$ A. [* B: j% j5 M+ f
    1
    2 Z& A7 S. `6 i# \& o4 A( L! M3 D2
    # b) K  i- v" U7 V* }, U: V' f3
    4 o3 O; x8 K4 C' m' @* U" {7 Y" N8 m4
    5 `4 @( U, F& h2 N3 L/ {6 k5( \6 X- l/ Z. M8 P
    6/ l- n- c/ |1 @( Y' P9 A: X- I
    76 Z# p; |2 M  P: v- X
    80 z, U6 W5 [% c( Q3 K
    9% j# v: K& J* O1 ]0 s
    10$ r  r" j( j. }' S2 N
    11
    4 A, ^; g9 f* V5 H124 ^0 t; F4 i; W
    13, g! \) G+ O  _* e2 g
    14
    % I6 j# S. x' C& g/ f& K/ v; P1 E  b15- ]( S3 M  T4 ~
    16+ S) M( K2 U4 ]4 V; t0 s
    17/ B$ [* R2 I1 B" u( g
    18
      u9 U* Q& j* x9 u19
    / n6 s7 W4 }, F0 @" b) ^3 L- }20
    ' @! q4 N, a2 C' S21
    8 F) y* p1 ?1 h" Y: c22. Z- L# i# H$ N0 Y4 W9 z
    23
    % R( r, T; O# C( G) k( u/ K% @. ~Params to learn:3 S, v; ~! {7 R% N% r' U) Y+ c6 U
             fc.0.weight6 G* C- f% O  N9 W! \% v
             fc.0.bias2 i  R. s9 V$ z5 J7 O: @
    1. N8 `# K1 G; Y" K1 x
    2
    3 E0 O! b7 s% V8 _1 R1 I5 @8 s( c3
    0 o: Z, u( t2 [7. 训练与预测
    - k1 y+ \( h; M  n! t! V. `7.1 优化器设置
    , j- B" Y9 S; Z( `- ^0 @6 X: x+ ]7 I# 优化器设置
    % ^! T" s7 X# V% m" V9 toptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    # }  t, R& f4 Z' n/ @) [8 r# 学习率衰减策略
    1 e+ @/ T( \- U0 J1 T% ?scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)" d# n1 L: D+ Y
    # 学习率每7个epoch衰减为原来的1/10; i2 `( p3 `: _: ]* O4 @
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    ; t0 J7 n$ L, O' F  k$ g) h
    : E. V. f1 \8 |4 n8 rcriterion = nn.NLLLoss()+ D- b& K5 Q! L& i; d  q/ d
    12 j4 R3 u, s; f5 ^+ [, |
    2) R3 N6 b& d# J# g
    3
    $ v0 ~* w) a% g/ I4
    9 R  t# Q) e# J3 S5. `$ e3 v7 X2 x6 r" L
    6
    " P! N; _! z- V0 ]# T0 V7
    4 g# g: B' y5 b80 K$ |: W# C; j7 Q
    # 定义训练函数
    0 K, J  [' P$ P5 X) L#is_inception:要不要用其他的网络
    * E( r/ a3 D) [2 t- V; s  o+ ~def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):4 q1 v% ], [" g+ @- n* V
        since = time.time()1 y  ^4 c4 ?5 `, [0 s; k, r9 d; ]
        #保存最好的准确率
    % W6 q: G6 D# j" ]* ~" c    best_acc = 0
      r* C- Q4 {/ y" `# b  P8 d; [: x    """; }  I" l. j8 i
        checkpoint = torch.load(filename)6 \. w6 ]( O0 t3 [
        best_acc = checkpoint['best_acc']
    ' j8 L1 r* |5 i- B1 d" L$ g" G    model.load_state_dict(checkpoint['state_dict'])+ K7 c, O/ G; ^  j: g/ F5 d2 B
        optimizer.load_state_dict(checkpoint['optimizer'])
    + T5 P7 R) y! S; N8 R) I    model.class_to_idx = checkpoint['mapping']
    3 V6 d. ~, S3 _2 T    """' R. A$ ^, @3 l- }5 ?
        #指定用GPU还是CPU. N" g( l; S0 }3 N# Y& Q
        model.to(device)) p/ P$ [$ ~0 m* `7 S+ H
        #下面是为展示做的* V* w7 g) Y; Y3 q% Y  Q" e
        val_acc_history = []# R" R+ A* t3 B9 Q- c1 u* ~. _( p0 {
        train_acc_history = []
    + R0 {" g$ ]' U6 j. I" P    train_losses = []4 f. b$ z' y9 F
        valid_losses = []' `6 r1 ?6 y) }3 R" z" \% |! y, A5 J+ ]
        LRs = [optimizer.param_groups[0]['lr']]
    - M: J3 y& k* h# c% Y    #最好的一次存下来8 u% _5 M$ p$ E: N; @- t
        best_model_wts = copy.deepcopy(model.state_dict())0 H6 |- H7 N# j& @. [
    2 v$ A& C! r7 [  J' {
        for epoch in range(num_epochs):
    6 f8 n: v( ?! H0 [. t, b        print('Epoch {}/{}'.format(epoch, num_epochs - 1))" t- l4 G! Q+ U) p6 B. o9 ?  |
            print('-' * 10)
    3 Y! F( h: d. e( y+ i( {3 [& {5 @1 M
            # 训练和验证$ J% F! d$ T! ~9 j; ?& H
            for phase in ['train', 'valid']:
    + m) ^& }6 G6 \1 e, u% i& C            if phase == 'train':
      U, o6 u6 E! T6 l1 K                model.train()  # 训练. @0 Q" A( D7 |( W% M( N( ]
                else:, X9 C1 v" S% w3 P& H
                    model.eval()   # 验证* n( g9 |# b; e1 X0 @$ J* n

    1 l/ K3 f( i$ c$ A3 C  j! j            running_loss = 0.08 n* b; A) R6 X; F
                running_corrects = 0% n6 G+ h% `' F2 u8 d" O: h% Y
    : H+ v9 Z+ o, E7 ]2 V6 x+ B
                # 把数据都取个遍! j' j9 U# _5 W9 ?2 f9 S1 o
                for inputs, labels in dataloaders[phase]:
    9 S9 }( s5 I5 `2 V, g9 u" x                #下面是将inputs,labels传到GPU" ~0 ?- K( G4 T& ^+ d/ s
                    inputs = inputs.to(device)2 r- z3 |  y  @# ~  H8 N
                    labels = labels.to(device): y" v) o/ R! Y5 g# }( M. `8 m

    1 p. B+ `0 z! f4 N1 ?$ N! H1 L                # 清零, }) G3 r" Z% [1 h, Y
                    optimizer.zero_grad()
    6 @0 @, ]: `0 Y2 F" ]/ h                # 只有训练的时候计算和更新梯度
    4 K' K; K! V) b7 X+ u% O! c6 p1 b                with torch.set_grad_enabled(phase == 'train'):- t" P! `5 |+ w8 [( h
                        #if这面不需要计算,可忽略
    ' G/ A! n& ?3 O' |                    if is_inception and phase == 'train':; G2 }: E3 L( y5 P( B6 c& K
                            outputs, aux_outputs = model(inputs)/ b$ {# S4 ?3 j! W" @/ G
                            loss1 = criterion(outputs, labels)* j- U' K5 t2 j
                            loss2 = criterion(aux_outputs, labels)
    2 n/ D7 J) a% ?9 r* l7 P  n                        loss = loss1 + 0.4*loss2
    " V% o" ]8 B5 j- o7 h) z                    else:#resnet执行的是这里' p2 s2 ~! |" y8 d
                            outputs = model(inputs)' ^2 z3 U9 i3 B6 ~
                            loss = criterion(outputs, labels)
    : e$ s  l% V  Q& U! g8 y) v7 E: @" X! q# X* E) e
                            #概率最大的返回preds, k9 D- k) W, k
                        _, preds = torch.max(outputs, 1)' K2 ]6 X! ]7 q* U1 c. O$ _

    + k4 z" G2 b: K                    # 训练阶段更新权重6 C* Y6 N! I+ `- V- ]
                        if phase == 'train':# l+ `- N8 k8 N! ?& G9 W! R
                            loss.backward()
    ( G& b3 O5 r  {: _  Z6 S                        optimizer.step()
    / y3 t5 ~0 u: O8 H" `. A+ a' {! d6 {+ Y; c2 s, R5 x
                    # 计算损失0 {7 f4 @: G  U5 n* F) n* i
                    running_loss += loss.item() * inputs.size(0)" n7 j, ?- u# M5 r( m3 L1 Q+ v
                    running_corrects += torch.sum(preds == labels.data)6 E1 R0 p1 v" b: r& o

    : m) S, |$ N3 h* u9 j# ]. X            #打印操作
    / L/ B% R7 ~, {: L. m9 A            epoch_loss = running_loss / len(dataloaders[phase].dataset)9 j2 C" {4 y8 q0 k7 ]
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
    : r7 `: b+ C+ a: N6 J) O" e- G" [* F- c& H
    8 s6 I! P6 c( Y* l2 Q; n) M8 v% b
                time_elapsed = time.time() - since
    5 c+ k2 d" L' e* z2 D. D: P$ o5 m            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)), P; x4 Q, _( j) q( r8 s( B- f
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
    6 T$ l7 o# Z) p% q6 v5 z
    8 X. q4 P# u. N( O  g# {" ?; V! e$ ], i& a$ l1 U  i
                # 得到最好那次的模型
    . u! g, \% N& P- u' m% ~            if phase == 'valid' and epoch_acc > best_acc:
    $ E' ?4 g3 U6 {2 y% x7 r; {' O  p                best_acc = epoch_acc8 l! e5 c8 {0 `/ a# C" D
                    #模型保存9 A- L1 {! J* Y
                    best_model_wts = copy.deepcopy(model.state_dict())% |( f# g/ t* W; H, C
                    state = {& l1 m9 }: H- N/ n+ l' j
                        #tate_dict变量存放训练过程中需要学习的权重和偏执系数
    ) h/ s+ {/ B$ z0 f. A                  'state_dict': model.state_dict(),
    7 W3 h. E4 m! q4 K                  'best_acc': best_acc,2 @7 w( `2 V" D7 c
                      'optimizer' : optimizer.state_dict(),
    5 c: \- c. D/ Z+ e  }6 E                }
    . |9 k2 e% l/ c- W1 \9 N: A, p. X& ?                torch.save(state, filename)5 y7 Q+ c4 i# h% A. I; X. k: I
                if phase == 'valid':/ j' z; V4 Z7 R
                    val_acc_history.append(epoch_acc)/ w7 n, |6 `" K: r7 }! T
                    valid_losses.append(epoch_loss)
    2 D+ E1 p% z7 H7 x, q' v/ |                scheduler.step(epoch_loss)
      i3 l+ U+ A  k$ N# i' E+ p            if phase == 'train':
    # R- m! Z; P) y' V6 O) _  c                train_acc_history.append(epoch_acc)
    % Q& ]& F  x; s9 K* f# @                train_losses.append(epoch_loss)
    0 {* L! J. B1 x1 C8 c: H1 N+ d. j$ W5 e/ q' ?. ^2 ^& m
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))8 Y6 ~( v5 w; O5 l8 c; |/ W
            LRs.append(optimizer.param_groups[0]['lr'])7 i" A/ z" H) H3 @# Z! d: p! J
            print()
    $ x) D5 _5 l6 s8 F
    * Q0 o3 c9 c$ x# k7 W% e    time_elapsed = time.time() - since9 o" D5 N! Z6 K/ ]
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))8 N" H1 B9 u( U" n" [4 e9 H
        print('Best val Acc: {:4f}'.format(best_acc))) @; \; k3 Y4 ]" w, z. v7 _

    9 Z8 ^/ ]0 V# z' Q" _7 a    # 保存训练完后用最好的一次当做模型最终的结果7 V0 }! s0 Q# u1 Q/ m2 W
        model.load_state_dict(best_model_wts)% x% p! r9 K: ^' J! ?
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 6 }+ O- T; B. d3 f, L7 \
    , [' Q* K% g" }; V

    3 U( o- W) G1 b% `# ~1
    ' }- d' Z! c) E4 V+ z' l. ^; d) X2
    9 n4 B) g: Z. X% O' T5 R3
    / B' m- u* _/ p5 v! r! f4
    * F2 D- k9 T: @8 c% R$ D7 Q9 a! o5
    1 g6 ?2 z% |' ]; J/ {69 J3 |/ Q3 z( q
    7
    : c1 d0 m1 F+ f# _) E0 S1 a8) U0 U5 j) ?  O  r/ G3 N2 G9 V
    9
    6 F& C- v" U' U8 ~& a, w4 X10/ }1 |# z" V5 R$ q  ]  c1 J3 q" p
    11
    8 W1 B4 x2 q) {0 i0 z( z12
    0 ^1 o9 {6 i2 x( C! r) @13+ e) d' k- y; K& z% _& R7 o( F
    148 O9 v) n0 Y+ ?6 H! }$ W  q1 d
    15* Y) k. J, g5 X5 S
    16
    / l! M" F$ O  P7 S17
    ' M: ^% E: D2 s3 X* e" N! |7 u18
      L/ I% m& R5 o- T19
    0 A- O+ e. l" f, j) n208 T4 L+ S4 U0 @- Z8 d
    21
    5 F" k4 F. {, Z3 Y22/ q8 j2 Z5 |0 w/ S: }* @4 J4 l  C
    23
    # ?6 W" Z: d  x  |' n24
    # o! Y  b; Y. F) ~9 i' Z25( ]9 w* {2 _5 B8 S, [' Q" j( w+ f/ p
    26/ z* q3 M* A  L' U3 \. z; v
    27& T9 F% H; M$ ~0 i7 ]$ G6 Q8 n
    28" J! X9 Q  t, R
    293 z8 p' q, G! A) K2 y1 c9 d
    30' o1 g( Z* ]" {, o0 {, W1 l3 c
    31
    . z3 W: ^+ n, Q! |/ q323 X8 X, O% V* ^8 D: R( O8 w
    33$ u, ?: Y- }- K5 V' c
    34
    2 v3 D# @+ Y9 P7 [; {35, }  @; Z1 N8 c' z" g4 e
    36
    & P. y6 i* s! _, A! k% @37- j; C6 @6 G3 I
    38
    5 P. o$ H) K6 A4 J) s; H% s39  @6 E' l9 ]7 i! m: O4 G
    40' d- m& B- e# x* V
    41
    7 f, d) t4 H% q% _0 R: p42
    $ {0 t- ]3 T1 d: a9 T+ [43
    3 R4 N) w& q/ t, w8 E+ }( T! ^44
    - B$ a. X  V+ m/ C45
    6 `/ Y' O0 O( i# j. Y8 [. P46& L& u( G" h+ H
    47
    ) h- B4 O- R$ l% w$ K8 n48
    ( i+ ^; F3 X3 d( N. j49! g: N- j( y' j6 d7 L7 u
    50
    : W/ I! u+ X- i' @- K* e515 Y: U; Z! C- A: ^. X1 y+ z: R
    52
    7 Y& [5 q% U7 ?) z! I# Q: S53, ^+ [1 u& I) o8 |! i1 c+ u
    54# M3 O( M6 W; P/ ^5 ]! X. r
    55$ J. \  U+ [; g1 }
    56
    4 G1 d2 q4 G8 M* n57
    - N% }" R* L; V58
    5 o7 P2 X# R( x59
    9 U( d4 U3 \" r4 L60
    1 q5 |$ Z3 P( U61- X: w, c- y# j; U
    62* ^% L( g: A0 h6 ~. K
    63
    $ i% ]; [7 ~3 y% a64
    ! p! Q' v/ X. R. H. {. L6 J' m) F65, c. t, d2 L$ M: O+ A, P" w0 g
    66  n( J* L0 Z5 v# T9 z8 X# p  O
    67
    : w- c6 t3 Y8 ]# d* s5 k68
    # b+ B8 U6 x( y2 g2 {69
    " o. z) G2 l9 i, Q706 e7 i  b) p. N* x
    71
    # h/ B* B; E, S0 S4 ]726 r' l' ^7 _8 q& _' j7 M
    73
    4 e8 ^0 ^. K: ]: v* K5 O744 ~. A8 h9 {  `- }
    75
    . d% s8 B% N" y, O( w3 G' H76  s7 z% r: W& c- K: F. F4 b
    77
    2 T% l* w1 Y% a! t& Y* g+ |3 a) X78+ q$ i" R9 Q& w( a$ W5 U
    79
    , r' `) _. N8 c, {80
    5 @4 e" [/ ^1 T1 Y' g2 w81
    , w5 l( I, V8 ~8 R3 \, E82/ c2 J+ ]2 l! l/ T0 M
    83; e) x) P* k+ ]  D9 Y# e& ^1 \
    84
    2 G' w4 f7 w  n+ m: v1 M851 P" B9 k" j1 ?# ~
    86; p( G9 T" A0 m# a
    876 q' C$ c4 m' s% M/ K
    88; O% S  C& I/ h+ ~, `7 ?4 t
    89
    9 l9 _& v4 }5 H- k. w1 }9 Z# a90& S! H# V. D* z3 S) r! a4 p- j* y
    914 Q/ i/ [+ k  U. ~6 R. m
    92  z) r  [9 w0 |% J: Z
    936 S6 w* L2 r- i4 [; u
    94
    3 |% W$ e! V% Z- f1 \& b95
    ) S& \- L$ z$ E. z5 [5 p9 F960 K# F' A- W5 a- f5 A
    979 J5 B  V7 |, V1 g( f# y( j
    982 b; n& o  `3 J1 m# [
    990 Q- K! R- W0 N0 d! n+ S; M
    1004 {; H! v: Z( Z% X/ E9 Q8 d
    1018 y2 d0 a/ o8 }5 \" h# _
    102# Q1 B% [: D* M; _$ b; G$ F
    103
    8 c# f# g% n! ~. u/ A2 V104
    6 }) c4 o% \, P105
    ' d! W; J& ^+ E) `106
    * T0 [/ i: a& M1 ~107
    3 e! y  q# V1 w# x. k108
    ) y3 j3 l- z) k. b! \1 R3 J& n# \' U109% p3 ^. g7 M9 p: t% c4 [  b/ K. D" x
    110) t; X% M! ]! n" `* f
    111
    : q- l: e/ v8 ?' P1122 V9 z4 ]6 V7 Y8 {9 k
    7.2 开始训练模型
    9 b- j+ {) e+ Y' d% S! P我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    , q) f, v9 y3 K2 k1 a! Z: s3 {( p2 g* ^" y, _1 W6 a, T; Z8 X; `
    #若太慢,把epoch调低,迭代50次可能好些
    ; p: l$ g, p0 K* Z( G#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    , N% Y% w3 u0 Amodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    + `- Z% }/ \9 }8 q, N' {- T. b. {+ e9 R" f4 t, A. K/ c
    1
    6 Q$ p/ ?  O/ m/ L+ b2: C1 A, _" I, M3 n) p) W* r, E
    3. T) A& ~! Z1 a% Z# j" `: L% C
    4: I" V* g# n% M1 @8 S- r! H
    Epoch 0/49 r* V- v: R7 H7 e' I' A5 k
    ----------
    ; A" w! l% {! i, Z- ZTime elapsed 29m 41s. l4 B( @( S6 T
    train Loss: 10.4774 Acc: 0.3147
    4 K' n- q4 [% u  L2 M3 PTime elapsed 32m 54s1 ^' K+ Y4 h6 I* ?
    valid Loss: 8.2902 Acc: 0.4719
    9 s  M+ }2 X! g( R: \5 JOptimizer learning rate : 0.0010000
    & V/ m9 v3 S% ^8 H; {  Z/ T
    * |* C2 ^  b' g0 j! fEpoch 1/4
    8 H; ^; [2 v0 @% L1 W& e----------
    ( j8 Y+ I; g8 }( @1 E+ n6 ZTime elapsed 60m 11s8 e% e- q7 w$ V5 T$ H& `; E
    train Loss: 2.3126 Acc: 0.7053
    9 |" E# u% m) ~" u7 W' ZTime elapsed 63m 16s; |: k$ W* `' f( W6 k
    valid Loss: 3.2325 Acc: 0.6626$ b  X8 i! n% w, U2 g3 x, x
    Optimizer learning rate : 0.0100000% t; Y9 I: w* q# t
    9 ]8 R/ [  i+ [
    Epoch 2/4
      m+ B. o9 S4 B/ q8 C/ J# P  A----------
    2 H/ }2 N0 z2 f  yTime elapsed 90m 58s$ Y/ ?% b- |7 U/ ^, t- K! a# w* O
    train Loss: 9.9720 Acc: 0.4734
    4 P1 w+ r1 Q. g8 ^6 iTime elapsed 94m 4s
    6 U( O+ v8 e7 b& ~valid Loss: 14.0426 Acc: 0.4413; m& W( R( E& ?0 b) U, S
    Optimizer learning rate : 0.0001000; C9 F+ I! k1 O6 \: {( {& V3 ]
    % W4 d5 W- y) N/ Q0 y4 l% e
    Epoch 3/49 T9 c* T% A& U/ w5 Y  O
    ----------
    % W( u0 G4 S: g2 y! T6 zTime elapsed 132m 49s! k9 O, I. M: d
    train Loss: 5.4290 Acc: 0.6548* ~) z3 x' C  B7 O" u8 T
    Time elapsed 138m 49s6 F* d; @$ I6 u7 }
    valid Loss: 6.4208 Acc: 0.6027
    - o+ U# M; o2 M& j4 \5 l1 f, wOptimizer learning rate : 0.0100000# G3 L/ L1 R, w$ Z8 F2 I3 e+ W
    4 N0 w; Y; d3 Q! Q6 B+ w4 j/ j
    Epoch 4/4
    # }' C3 `0 V) x* z( k' @* t3 e----------
    & ]- n7 ~' z8 k/ cTime elapsed 195m 56s
    + A* N9 j7 x& ntrain Loss: 8.8911 Acc: 0.5519
    $ n) Y; j$ E# d  h1 gTime elapsed 199m 16s
    2 w& Q* T3 c" Z9 \. Nvalid Loss: 13.2221 Acc: 0.4914
    9 O( [8 c! I% p) a3 l% x& r: \Optimizer learning rate : 0.0010000
    $ Z0 g* Y8 |8 _7 P2 h0 o: f0 L' l* h
    6 W, i& i/ H  n6 h, f% aTraining complete in 199m 16s
    / e! L- L$ D2 M) d1 Z# xBest val Acc: 0.6625921 X# C% k2 v2 E

    3 E8 z0 |0 O* g" @2 w; ]& x1
    1 t% b  ^+ P, Y- a" v2
    9 a5 ^3 ]2 w. J3
    ; \0 X4 n$ d) c4 v4: i5 [6 o* T- x
    5
    5 w! A) W! ?4 y# z1 v0 M6
    $ R& E! {0 O; Y: |. h70 c  A4 T2 {0 a3 T$ @& t
    8
    1 W+ m8 P+ K- D9
    , O' I4 J; l% S10
    ' p/ @: d) H! Q5 l11
    + m6 E/ Q! f$ F, j* o( e1 v12
    ; F" a! A( B( {. l137 K- H, O$ R: T9 R/ z+ h4 ]
    14
    9 X. ?. _5 u2 p15# L. c* K% f/ @5 P3 J
    16$ ~3 z8 F. r- \5 ~- u, ]% t2 m
    170 `# O' H; X; l8 U
    18
    + y9 `, i: a$ w% V  W( V8 b' V199 Y3 q& V* Q+ N& Z
    20
    / I) `0 `; C/ w8 J0 ?21
    ; D* ^2 ]% K/ W6 _9 K  n! ^22
    5 U, J0 M/ V0 j( O- W23
    3 R8 Z( D+ `8 z, |: a5 V24
    , q) Q9 ~: P+ W25
    ( o+ |% Z4 \7 Q& C$ f268 `% |& ~  @" P7 z: F! y( e+ v
    27
    / l- m: S5 P( R28
    & K2 P6 e# N6 Q+ U. D* n29  I7 H! C& ]! {/ M
    30
    : O) ?, m: [2 H, Q1 I31& O7 E' u8 g2 V) b
    32
    $ t, z+ K. B- V$ b/ i' c& p335 A; L7 ~6 n5 F- w, w6 @2 f' _
    34; l, B$ f9 j/ w8 |% r, W
    35# _. t9 |- R; P( V5 \+ e; }* n
    36: ~4 H# g8 u. l% n
    378 H$ S- [$ \- {2 i# _
    38
    / J7 r7 b0 I2 F2 U3 Y39
    1 z- v9 G* G  O402 C; M) e' [, K; c& |9 y
    41
    ; f) r4 t# A2 x5 G3 ~# M6 z42
    8 h) e2 ?- F: f7 _" E7.3 训练所有层
    $ H. g( j% _& S6 E* N# j# 将全部网络解锁进行训练
    2 |3 ?( [0 O# `+ c* {& Hfor param in model_ft.parameters():4 T- _! Y- K6 c
        param.requires_grad = True' x" e3 |% W& n5 k% u( o) B8 [9 S3 o7 e
    9 j# K" P* a8 W: t7 b$ E
    # 再继续训练所有的参数,学习率调小一点\" w( Z0 C9 _4 @1 D% H( ~7 e9 h: B
    optimizer = optim.Adam(params_to_update, lr = 1e-4)' x- x+ D0 o, B
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)+ @, v- m/ \5 k

    1 p8 ?5 E5 y" `" _# 损失函数
    5 R1 c) s, G0 ]7 Hcriterion = nn.NLLLoss()
      C' f' C0 m5 _4 f: Z" ]3 y0 E3 w1
    4 |' F3 j& f# S, \& F+ ^# L+ d2
    8 ~2 j. [, b( W2 N* N# ]3
    * @) F! S/ Q1 j1 k: _! N! b+ m4
    6 I( b- c* H; y% O$ Q5% v. M8 S1 V2 A2 f7 ?
    6; t6 P) ?7 z  U7 F7 v& T8 ^
    7
    3 S# J" O) r" E) g8
      ^2 ^. P0 k1 G* o; e9
    9 E$ {# T$ s0 r( P10
    , k  ~( k  G/ b8 ?2 R, v0 c. h# 加载保存的参数4 z# E" P- N0 b# f6 O8 W5 ~
    # 并在原有的模型基础上继续训练: q: s. R  T2 _- E' V
    # 下面保存的是刚刚训练效果较好的路径0 c; ^6 |  x& V7 {
    checkpoint = torch.load(filename)
    ; L2 h6 \5 c0 d9 abest_acc = checkpoint['best_acc']
    ' `6 |+ \" }8 l$ gmodel_ft.load_state_dict(checkpoint['state_dict'])1 P% D) D/ U6 D! f4 x6 d
    optimizer.load_state_dict(checkpoint['optimizer'])
    + l9 `; F  Z2 A% _1
    % g  Z: S% W% I7 X5 ^% _' Q) a2* H0 e7 D9 D5 W8 _9 w, f  ?
    3
    2 u& i2 D$ F2 H4
    7 E" u. T6 p" A: G5
    & T+ Q" N- i# S8 w6
    7 f3 K+ b: `! d2 I9 L; o74 y7 b) o' U: k
    开始训练, n" Y1 c9 F% D# H  N
    注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
    - _$ Y& x4 c" T7 C/ b% K  Y' G; r  q$ |7 Q8 h7 W  b1 a
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    ( \5 i% S* P3 p7 U. T6 Y1  \; H3 i% D8 v% L1 L' ~
    Epoch 0/16 }0 m% j- }' E
    ----------
    ( `* W4 T' Y3 S5 Y7 {4 g; |Time elapsed 35m 22s
    % }. R, Y) _1 Z/ t5 n9 Dtrain Loss: 1.7636 Acc: 0.7346
    " ~0 J' X% c* ?3 |6 ?$ _Time elapsed 38m 42s( k, ~$ J+ p8 v) Y
    valid Loss: 3.6377 Acc: 0.6455
    # ~6 l7 z( F# f; f3 D' H( O* `Optimizer learning rate : 0.0010000
    : ?0 ~$ h$ n5 t6 T3 d7 o, ]' D" \* w" M3 d
    Epoch 1/1
    6 y( e7 G, t, D+ {----------; R+ [, J" [4 d
    Time elapsed 82m 59s# Y5 m4 P2 H) y. Z4 x  X: l
    train Loss: 1.7543 Acc: 0.7340: a9 C7 E' J% J% i0 S2 `  B
    Time elapsed 86m 11s
    ( [6 r# a5 b2 P; N- }. gvalid Loss: 3.8275 Acc: 0.6137
    " c0 s) O* N, b/ V2 z: t8 N: dOptimizer learning rate : 0.0010000; T9 B% F1 S1 `( |7 u+ \
    3 W1 Q; ?3 S3 }  ~/ L+ }9 c
    Training complete in 86m 11s
    ) ]/ e" N  q7 v8 X4 u; N% y1 o" D: ZBest val Acc: 0.645477  V! ?2 R3 B4 ^, ^- E  T
    ' I8 r/ O. ?9 c3 |* u# N
    19 S9 r/ T4 |  g( M
    2" w% h/ e# D! K' n) `
    32 w( J0 R7 x/ E3 f1 S( m0 P
    49 r, v2 y& \3 w: T
    5$ _" k0 i, p6 G( s* \4 R5 D5 G
    6
    * k, i9 P2 P$ V7
    $ @9 J7 Z2 ]$ J) F3 c89 D( f1 Y+ X! q! N% V
    9
    , }; x* K6 G& M/ Z108 R! Q9 S: I: T* g- t9 H
    11& k7 l0 w# F8 }8 f& {& a( Q
    12
    ' A' L1 p/ E3 ?% X13
    . n9 \4 k0 i4 |# d14
    8 A! ^# C" u6 j* ~2 Y! [" D! h159 |! [1 j( W. C* `  Y
    16
    + G. Y- m$ O( l( l9 n17
    & f( h, `. S5 ^' t183 v: f- r+ m; v
    8. 加载已经训练的模型
    3 t$ B4 I- ^3 J) I  g相当于做一次简单的前向传播(逻辑推理),不用更新参数: [3 @) z7 `6 y

    - f) e3 k" c8 M2 F2 Bmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)) c; z  ]% x3 R/ Z2 r# y1 l5 Z

    , v8 z' S, ]8 J5 @* P# GPU 模式
    % }, ^2 C% q7 c5 [- w3 omodel_ft = model_ft.to(device) # 扔到GPU中
    5 w& e0 ?2 B1 @3 \
    9 Y7 Z( i$ a5 v  u+ t& b: e- B$ O# 保存文件的名字
    * z) M1 c9 l3 z. @7 @, yfilename='checkpoint.pth'
    4 \3 b: q1 `, R( @/ P/ p; f
    & R2 T; s! J' E1 g4 F# 加载模型9 r9 Z. P* E# L* K& A; `: W
    checkpoint = torch.load(filename)" l9 a( j* K7 ^
    best_acc = checkpoint['best_acc']6 M2 R# N4 H5 y. M% }: i
    model_ft.load_state_dict(checkpoint['state_dict'])
    ) D6 f3 j# i- j$ \- e% e! K15 A) G2 k  y. L
    2
      t& x+ e& M, v& Z3% v1 E3 g" [& }; t3 T7 X+ d- f
    4
    1 }# Y4 `% F. ^& \3 M3 }- d5, V7 K" C& W0 @% j0 k# u
    62 E! j+ }* w3 u5 M0 w) c. i
    7
    6 G2 C6 @7 |* W; L& O; X+ N8
    8 U  v: J  q6 {; Z& \! _4 J$ `  d9) G2 |& ?  \& e1 x  `( H* v
    10
    " O$ q  z; R/ }, q11
    - g( z( h# p* O3 I' I4 S4 p7 c12/ B$ S# ~$ D. `3 X
    <All keys matched successfully>
      J6 @" |+ f2 z( x: l0 v/ N. D15 \! t7 }8 w. x) C) @  b
    def process_image(image_path):
    * H" @% M( L/ k9 F, V* N2 Q. I& E    # 读取测试集数据
    6 Z2 \2 h; }. t# [7 z' h    img = Image.open(image_path)
    # Y+ i& O: q+ G1 {$ O5 S; S1 l    # Resize, thumbnail方法只能进行比例缩小,所以进行判断
    0 r* x3 y" |: ^) M5 B* ~    # 与Resize不同' u7 z; C4 X9 G
        # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小; R8 |5 m: U1 @8 M" W: }# h( {
        # 而且对象调用方法会直接改变其大小,返回None
    6 l2 J) A0 \* H: Y1 l# O    if img.size[0] > img.size[1]:
    / r. L( C# G1 A$ o        img.thumbnail((10000, 256))
    : v4 u) O( s" R2 k. K* F  g    else:. G8 @/ `& W  c! d% \- O
            img.thumbnail((256, 10000))
    $ s! o- U& M; i, k
    1 X& c; T$ i8 Q: K. d    # crop操作, 将图像再次裁剪为 224 * 224
    5 U! e! V. i$ u# q    left_margin = (img.width - 224) / 2 # 取中间的部分
    # F+ [( w' D7 X  b( q  f5 |    bottom_margin = (img.height - 224) / 2
    0 p/ L0 G0 ?# X    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度+ J* @& }  T. u% F6 X0 P# A& H% K
        top_margin = bottom_margin + 2242 Y7 O/ d6 D/ \. s$ c) s9 Y
    + Y* w: w6 b  V: w  p- t
        img = img.crop((left_margin, bottom_margin, right_margin, top_margin))# Z: W5 x$ a* `
    6 R- S  m, a% h. P
        # 相同预处理的方法& g9 R0 P9 s5 Z+ f2 T) L# I) e
        # 归一化# u2 x4 J/ {! \% s! ]
        img = np.array(img) / 255: z0 ~- V+ t, y6 z
        mean = np.array([0.485, 0.456, 0.406])
    : t2 N) O, z* q0 x9 G4 J0 J    std = np.array([0.229, 0.224, 0.225])% j7 h  e$ w% h9 e! y) u: C
        img = (img - mean) / std! C8 U8 Y& J  R
    " d0 |" X4 v$ v) p5 N
        # 注意颜色通道和位置; v. c& C6 }9 O; {* o8 T
        img = img.transpose((2, 0, 1))
    ! [5 d" x/ ~/ H7 s$ B/ ~# O0 a, u& ^( S3 H/ [0 e+ n% X/ c
        return img" [+ d0 p8 q0 H! W

    & v/ l% q4 b( M* |% H' Ydef imshow(image, ax = None, title = None):1 F( m. X; a( u+ K( Y0 U
        """展示数据"""+ i, G8 J' ^, N# A) s6 g8 P- |
        if ax is None:) ?& {- ?: `8 [5 ^: c+ ~1 M% o, K
            fig, ax = plt.subplots()
    5 a' E: P6 M0 D! o! h
    ! @" }$ S3 d! N; G3 H4 L    # 颜色通道进行还原5 k( c8 @, h6 W0 a& [. L
        image = np.array(image).transpose((1, 2, 0))
    ) Y0 `. y1 T# E" E7 B
    , z: `& F4 A6 a% F) _    # 预处理还原3 b% ~9 x, t4 q6 T# X- |4 Z& R
        mean = np.array([0.485, 0.456, 0.406])& ~3 d% W+ N  L% _) S( n: @
        std = np.array([0.229, 0.224, 0.225])/ s' ?4 e7 Z7 }) O1 S# S
        image = std * image + mean
    2 I- V* L1 `- [* a2 W; |( q. ]% C    image = np.clip(image, 0, 1)! V) \' J: q8 I' t. D% H' `

      R0 j$ d9 ]5 b    ax.imshow(image)
    : |0 k) U7 p/ h  T2 n5 `# Y* s    ax.set_title(title)) j" ?7 U' u$ _0 n0 Q
    % I$ z5 _/ r6 j/ k& S
        return ax4 O) }1 F6 B5 W3 M% Y( `2 l

    6 H  B3 Z& Z* a; Himage_path = r'./flower_data/valid/3/image_06621.jpg'" |/ U% Y' r; `0 B. y* q3 m: S: t$ k7 }
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理' B8 D+ J0 x' D5 O# J# }
    imshow(img)2 N  r7 q5 U& Z: _

    ; f# {5 K9 M9 s1
    ! f3 p. ^" b; K8 h8 D0 L# j6 e26 G' D/ t+ k& P5 G* r0 J' W
    3
    ! h% ]; {, ?% |6 _6 N4
    ) |6 K: X% L# l; E* |* q52 R* W# v3 ^) i* S7 v& \  O& f
    6
    1 q1 T; Q! f" t8 L) l  ~. t7 U! ]76 @6 q8 R% n& x8 K" C5 y4 u6 p
    8
    * Y  s, t4 Y1 r6 {9
    2 m: N' ^1 i* U100 `3 n7 k3 Y$ G. `3 V( k" ?
    111 ^/ b  v# a: w& Z: ^. a8 y0 F; L
    12
    ' @, P" u+ x: Y13' f" q' H: U* c: q+ B. r" X  H+ V
    14
    5 ~# D6 w# p: l$ c* x+ _( @154 Y9 R) X# R9 b8 G6 q7 \% S
    16) O, X# B# E. Q4 L& m
    17
    $ N; P# W7 o  \$ d; |18
    ; P2 v1 b  x: g7 j7 c: I19  b! Z: n' V: {/ e. @
    20, ]( P6 z$ ]: [# {
    21
    ( a# k$ ~: O7 W. j22
    % ^& S  k  T8 M$ R3 L23
    $ y' J% V' p# @) C$ O) N: f5 l- A241 D! u1 X$ A; d5 l1 J! a
    253 ]; |- q/ j5 h% X& d
    26
    $ C0 J9 S. P+ n0 B, q27
    9 i1 x1 z" D0 w' r( L+ B28
    ' e# ?/ q  |+ }29
    + I# _4 T+ J+ m! U0 N30
    # C/ D& c- I, v, B9 N( F31/ @( c: @6 j0 z$ K2 L
    32
    1 G( \" W/ j3 |33# H1 I2 j* B  R# n+ j
    346 Z! u# T3 b" O# w& k1 C
    35
    & e- N+ W7 O) l" s% d& ~- f$ |1 W36- h' M9 }5 C/ c/ `
    373 I8 @  F5 D4 ^6 D
    38+ |# m0 ]6 x9 F% t) `
    39
    6 k, T) i, O; u  G40
    $ v3 y3 h7 v! L$ s: J7 S2 `41+ V+ X, A5 o& j9 S
    42
    8 X: O8 t' ?! H* ~- U" ]: l; s3 W43
    : y, i% i+ f1 ~, d" G44/ z- B7 V6 l  Q
    450 V8 J! z& h- U4 k
    463 z3 P" N- G, h; _2 q/ I
    47, [3 d  Z1 M3 F2 S( e6 G3 `! n7 v
    48. [# W0 w! l! U1 U* t# ?+ D
    49  n  t  M- @/ G) r
    50
    & o# u! D' K2 ~& A% `51; l! V0 w5 k9 U% E, `% w
    52  [* x. G  B# u1 C/ n
    53
    * ^5 h+ n; l+ q+ E2 Q8 P/ ]54. x6 ^; L+ I! \7 B* @
    <AxesSubplot:>+ I1 V- c! }, T- S  d* [1 N$ U
    1
    ) }% x6 @8 L& n% p$ Z! I
    ; K! o, l: ?" B; ~3 Q* J. P* y上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确# w. I4 I9 I4 F% h, M
    + e" O. _+ L, y4 R
    img.shape
    - E( L* ~- Q; A+ G% i1 x" ~1, Y7 S+ x3 A9 }
    (3, 224, 224)
    * c) S4 v3 w' q8 q5 S; e0 L8 ?1' D- L% ^1 J) o  `* J9 ~
    证明了通道提前了,而且大小没改变7 G$ M( _' a: h$ ]# M4 M  e8 z

    + n6 r+ x3 @- z3 f1 Z9. 推理. ?, D" E2 C; L: _3 F3 h: b4 f
    img.shape
    : h/ \8 Q" F: K1 E6 [9 Y! r5 B
    " R' Y7 I0 \) B# n# m- A( [+ r# 得到一个batch的测试数据
    9 ?( F# E. H5 @4 V& s# Tdataiter = iter(dataloaders['valid'])5 ^0 l1 ]# l. x2 ~- M/ }* W5 s1 E0 `
    images, labels = dataiter.next()
    / R4 Q9 H. j, w9 w
    , p0 ~: v. g; X4 kmodel_ft.eval()
    ! g+ l- c, @4 A/ R
    , ?3 I3 M  Y. |% q1 Cif train_on_gpu:9 }, Z: c; C5 U
        # 前向传播跑一次会得到output/ c0 f' G, V  p6 u% v1 v; G0 l+ ~
        output = model_ft(images.cuda())  w! h. A/ e. h
    else:% e3 b' a( {  w$ Z% o
        output = model_ft(images)$ R0 _- \5 `; @* d

    - X; H3 q" C1 e9 Q! `# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值  P- y( v2 o2 Z% F$ ?
    output.shape) W9 N7 v* R, Z$ [8 N6 D9 ~
    7 N( }7 z5 _! h1 \
    14 K. q7 {( _5 a
    2
    ( A7 H5 @: Y  P4 ^7 Q9 H( A3
    ) }6 m' {$ @2 u  \0 _/ j$ X) s3 k4
    8 v" ~" x3 i) C7 h! E+ _# D. S5! M8 O& J- p  N0 Q
    69 h/ X7 Y7 h: y7 g, X
    74 W8 }# s  {# Z. R3 @8 T: t
    8
    0 B3 |) U  G( d9 w: i7 Z/ U- O9! E2 b7 X3 K) x# A$ G# H
    10$ a+ w0 O( v" @' O" B0 Q
    11) S. c  p* i. O
    12: a* Q4 o3 z# ]# W; x
    135 G# A+ r' F5 P3 C
    14
    6 G/ v( t: b! n$ s+ e8 b6 g15
    1 t8 J/ {# b5 e' N. ?16
    % q! ^% \2 y6 ztorch.Size([8, 102])5 o7 V5 A2 d6 \4 s
    13 n6 Z4 c: q- s7 }7 Y/ R: `
    9.1 计算得到最大概率5 a) `4 J( O7 @! v( |# U- X" l2 Y! F2 Y
    _, preds_tensor = torch.max(output, 1)% Z: y5 b5 c0 R' U/ N( K  k8 H

    - M( l7 y+ ?" f5 f/ \preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    6 m; `$ b9 o5 f& C: i; A  `9 m6 ~; ]1
    + s" {+ W5 t  x/ x, \3 |& ^5 e2
    9 Y) ^$ N& _5 u2 v( W! z3% I" W3 z9 {; m6 L
    9.2 展示预测结果
    3 U' @! _4 w1 h- Gfig = plt.figure(figsize = (20, 20))
    ; b& K2 J2 k  j* Q" c; Ocolumns = 4' H7 U; O- j+ J% z; G$ v
    rows = 26 \: T$ @, a! Q" V9 X6 c

    * ]' U0 c+ _8 O! N: Q6 S+ s+ afor idx in range(columns * rows):* w1 k5 m% J( f5 J  |8 F; i" S4 o% X
        ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    " a# K" Z9 O' L8 N    plt.imshow(im_convert(images[idx])), D- d" v, [) l: w& {
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
    ; r) }( b' \+ q1 p! d                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    2 J# v9 y) H8 Z. c' W. cplt.show()4 w5 L" j3 w3 u: Q5 {
    # 绿色的表示预测是对的,红色表示预测错了
    + [/ j% _8 p6 Y8 R5 S+ u' u# l1& H2 ?8 s8 l4 v
    23 j/ U  R: {) j/ J3 `
    3
    & `5 L2 r" o  Y+ Q4/ @, T# f* ^* P6 T) t
    5
    0 f# v& N. A' E0 b8 U2 G# ]6
    + `: c2 J$ L$ m: s4 `$ P7# J" i( `' @* U( d% F
    8
    6 P/ p' I" L# O% b! h% E! q9
    ( n8 o/ t7 i2 f" w# W10
    2 B$ i. P$ N) M+ a+ r0 g: J11, v5 _$ g: f9 W3 ?7 z1 ~3 K  t) l
    ( `) f! ?8 O  z& U& q+ a% s

    3 S: [& p% V7 h6 d9 {- G6 i* n, C
    * Z* X# J' ^) d6 w$ s! u————————————————
    3 O7 u0 A5 }; y# O2 ~4 v! s版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    9 ]* \. R2 x4 U/ V原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    $ W8 b- B- b4 ?: D" W8 ]9 [  N5 b2 q. Y$ s4 }2 }; l& {  X
    + O, g; k6 M( j
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-6-14 22:46 , Processed in 0.432550 second(s), 50 queries .

    回顶部