QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2751|回复: 0
打印 上一主题 下一主题

【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2022-9-8 10:41 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例5 X& P& R+ g, h: y7 v/ r" y
    ! e1 j% p9 N2 I
    文章目录3 s6 A5 P9 ]; H
    卷积网络实战 对花进行分类2 O1 {( _9 }- L& \
    数据预处理部分2 h  i' N8 W8 B+ z& R8 i" ?# B  ^
    网络模块设置" _2 K* ~- z! C. P0 }' F8 ?; Y, V  T
    网络模型的保存与测试+ T+ V4 |7 c$ G0 M
    数据下载:
    1 O) Y2 z2 H$ _1. 导入工具包
    2 H+ n* o, ^- f+ u/ r9 x( N2. 数据预处理与操作8 n8 v' H9 k( F& c+ w2 e' e
    3. 制作好数据源0 a1 L( [) s2 V& y  T
    读取标签对应的实际名字
    / c0 ^5 x; k; [( e% Y4.展示一下数据6 l5 J; n( O9 ~1 `
    5. 加载models提供的模型,并直接用训练好的权重做初始化参数/ n( ]: q$ I6 h4 Q( y0 @8 M
    6.初始化模型架构
    , ?# i0 k# L: N$ m5 b# a7. 设置需要训练的参数+ G5 w2 w' f; M# F4 W% E! |3 E6 e
    7. 训练与预测
    : H( _) q1 i7 v( c2 h" u+ P7.1 优化器设置4 P4 L5 b$ r8 `( f1 Y$ v
    7.2 开始训练模型
    7 o0 \) T7 i1 x; i( i' C; D1 P7.3 训练所有层
    ( R( `& _* \/ C8 H) N开始训练( `+ ?  U2 C6 P( b  y
    8. 加载已经训练的模型) i9 p" a( e, P7 S. }$ n: r: L3 i
    9. 推理- m* l! m! O2 X" D& {
    9.1 计算得到最大概率1 g) y$ t! x* @5 }$ f
    9.2 展示预测结果! u% h* f  u% u! H
    写在最后
    8 {3 A& l  T+ \0 A卷积网络实战 对花进行分类/ E9 j$ L8 k0 {* ~( r
    本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
    3 ^) Z% _" O" I- O& X) `3 U' a
    6 ^. e+ ^  U0 V7 V# a% W9 @在文件夹中有102种花,我们主要要对这些花进行分类任务
    5 Y( _/ T  ~9 z# Y" r$ [文件夹结构
    0 m4 E# r: G- Q
    . R! y% s( `2 n1 ^: l: dflower_data
    ! v: N8 q! [) ]* o: G& I& I) t8 s) _0 S
    train
    % Q- K: F1 ~4 M' {2 U# s! n3 ^+ b& ^0 b
    1(类别)
    / S. M1 R9 j- d! ~+ l+ m5 _2
    , T. V) W# T1 S# |xxx.png / xxx.jpg
    - J1 G0 f( w! B8 t# kvalid$ o5 ]4 E0 e& P% c3 H
    ! Z$ \& W8 z+ K
    主要分为以下几个大模块
    & ~' I; s0 u. {
    + s/ T7 ?/ R( s6 n- B数据预处理部分
    * w; f5 ^  L% P: `% F数据增强
    $ P: q7 d6 {4 n* X5 A7 V: n( J数据预处理
    & L+ t' q7 d, d/ l2 n网络模块设置' X- `# E  j" J& F% v
    加载预训练模型,直接调用torchVision的经典网络架构
    - G6 s5 }" m; X) V2 s% W4 H9 C因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
    ( A. Q* i* c# @4 Z, t3 m* E网络模型的保存与测试4 u# S& p7 L6 c2 r8 `1 V; M5 Q3 U
    模型保存可以带有选择性
    " {0 @; i9 ]" d, U; S数据下载:
    ; x& Z5 H& L9 u5 o) e& U6 I* I, x! ]8 ohttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset! ]; R6 u  [; s7 y
    9 z6 I. `" i5 x# N) J" a, B
    改一下文件名,然后将它放到同一根目录就可以了5 U$ Y3 u( A" x' P

    $ [* Q( e" |- a7 ]% j/ l下面是我的数据根目录: M7 t7 S, ]1 M+ l  k% D
    2 q1 H5 I( R7 M
    : H- X* c  \0 M- p7 g5 d
    1. 导入工具包
    9 \8 H8 c4 U, h" I6 Q: ?0 w! ?import os1 J( n3 G" b" ?
    import matplotlib.pyplot as plt
      S- Y. M. T$ ~+ K# 内嵌入绘图简去show的句柄& D# a; d* ]2 t" C
    %matplotlib inline " q% Q% G1 ?" b2 z2 E
    import numpy as np
    * }! E/ D" C- j; F. ~1 u& G: Eimport torch
    . I1 B7 @7 I6 H' J% b7 Nfrom torch import nn5 J4 h; `/ [. a* `& n
      L5 f" y# S3 M% Z
    import torch.optim as optim* [: u$ x6 _; r; R& M5 Q
    import torchvision  v& v) c2 O, n: Z/ l3 V
    from torchvision import transforms, models, datasets6 q$ }" D- Q! S3 a5 _& V
    ( {& |9 x% e3 w3 ^& D
    import imageio
    ( K* h" x: H. q4 d! ^import time
    # R0 L* T2 P' S- V9 Zimport warnings* q# K0 h, l! o" D$ \  N0 h
    import random1 L' a! S6 o: g- z
    import sys# H# y1 o- R" r3 R! `
    import copy! h8 x& K2 T3 m3 x, y* d" M6 h
    import json4 M% O7 n7 s- z  P: h
    from PIL import Image
    ( d1 p8 A. S; ^7 P& p6 r% a' t' |1 h' H- I9 N( @0 D

    * e" @+ a! V! p0 m) N- D% q13 @' S  A; s% X3 s! H
    29 q& h" Q0 |- C3 {
    3
    $ M7 J. R7 Q2 l: T: R' n! O4
    2 Z* H7 [6 t! W  P+ L1 Y  G5( {" ~; o) W; ^" ?+ _5 L4 A  l0 Q
    6
    ( L# Y0 t8 A: c76 M' w" C6 d1 R0 o" ^8 V2 T
    8
    # j3 U. b7 j5 H' x. B9
    $ K3 V  a# \! `; K: \10: |+ a6 U& _/ d1 B1 E+ {3 N
    11
    - W9 T, Z# ?; O, F+ D" V, M# Q4 s12
    3 `/ F0 v" s" v4 a13: K. f3 j* F7 g3 a- T
    142 t  _5 W! Z1 ~8 {
    15
    * o5 h7 S( ?- f2 x# j- S16  ]  }: B- S) A. C3 d
    17
    & L4 x: D# K% C18* U; I0 c0 H! x% X% n" G2 Q2 r0 ~& S
    19
    ; N$ u' R/ H, ]( ]! J3 b& c202 i' r3 r; ~& ~# w( V7 f9 ?
    21
    4 d- ^9 A/ c5 q2. 数据预处理与操作
    5 J7 _) {2 u  q9 x2 q#路径设置4 A1 i: |6 C% \* S2 N1 L2 m
    data_dir = './flower_data/' # 当前文件夹下的flowerdata目录$ r, c8 |7 o* B1 n) w
    train_dir = data_dir + '/train'
    5 [2 e+ k4 f6 i6 m4 bvalid_dir = data_dir + '/valid'
    7 x: Q, |5 V; k9 H2 G; l1
      h5 y- y! ]2 ]; r2' c1 |( e0 B9 N* ~4 g7 d0 k
    38 c9 c) v) G. W7 b9 a- y
    4
    6 T/ R# }; D$ n. Y  ?python目录点杠的组合与区别
    5 o0 d9 h: J; C" }* S) _, N: D/ }注: 里面注明了点杠和斜杠的操作
    " B$ U/ ]: N: ~" k3 s* [8 B9 ~- b. f0 ~1 B  \8 s( g" S
    3. 制作好数据源
    * i; s* O4 @! y0 Q+ `3 v& _; E2 m- @data_transforms中制定了所有图像预处理的操作
    0 w8 X. S9 Z$ X+ ^( {% ^ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
    ! l# h8 C0 {) a, P1 G/ Z& ^data_transforms = {
    7 E2 i! C3 s& E( p, U    # 分成两部分,一部分是训练: y1 M/ k+ [" \, J4 s1 ~
        'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
    / J# T. O7 K* ]( Z& e( \# Z% I                                 transforms.CenterCrop(224), # 从中心处开始裁剪+ P6 H' W) ^& w7 R
                                     # 以某个随机的概率决定是否翻转 55开
    6 ?1 I- ~' G: H' c  d                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转7 u+ u. N- ?% \6 p/ V3 |
                                     transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
    4 [9 a7 K# C7 D, _6 M4 P                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相& `0 C/ e8 S9 D7 r
                                     transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
    9 Z1 C8 m# o5 i# ~                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
    7 P+ [6 _+ O/ L& q4 d' Y6 _" [                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
    / F' K$ {" F& {                                 transforms.ToTensor(),
      W0 n* z& ~; _8 j$ |* \$ g. f                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
    / Y4 p) N/ I. D! J; \6 p                                ]),2 K2 t3 X% B5 L. W
        # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
    ; X) @) t: u2 z/ X! Z7 Q2 o) E, I    'valid': transforms.Compose([transforms.Resize(256),
    : d' n0 _! `. c4 E, f0 O                                 transforms.CenterCrop(224),/ v6 Q2 l0 X4 e0 s
                                     transforms.ToTensor(),
    ' p% s. z% D) b, s2 l/ Y                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同4 H, t  f% {/ O! D: I
                                    ]),$ |, P9 R8 B9 D& _; `
    }+ Y8 K3 i3 `; l4 |' p$ i4 ~

    1 [; k9 ~) U/ K8 [  f* L6 Z1% p9 w  P; a' q) Z
    2( r5 B6 ]0 H" h, G# ^" Z1 p- K& `
    31 c" T" l7 t7 I
    4
    6 R$ K+ p) f( W$ W5 w9 K57 r( J& D6 ?& P' S0 ]( u
    64 x5 C# ?7 \7 i2 e2 y/ ^
    75 V- t4 S1 @1 u: N: }
    8; |" R" O" ?9 d* C7 p& ^' M  q
    9: {* x' O! s8 a! f5 [
    101 T  V7 ?6 N& d" Z4 E0 ^
    11
    " }" I" W( ?7 P12
    8 z9 x. ~: j9 a- G) |* {# ?8 Y133 a8 W% G# F$ r* d* U
    14
    ) k9 A  A- ^' Y1 O8 V15% W3 C! J# r9 V. v* M4 z$ v
    16# I6 Y4 j& ~" ?! ]
    17# a  C; h# \8 K1 w: j. N
    18. O% |1 l9 y. F5 h2 ^4 K! L
    199 E2 b0 Z9 ]. f* x; b  J# h
    20
    1 f3 U& E+ `$ {  c$ ~$ a, j; H6 i21* H0 ?4 P( I' f" R
    batch_size = 8
    3 X2 m3 G9 a# Q! Q; A; [4 Nimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}
    : Z. i% Z) j  z# pdataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}. y. N% D* C- X4 W2 r- u: @1 `. [
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} # y* n; b! R% N1 L
    class_names = image_datasets['train'].classes* A* ~2 T' e; V9 ~# O: D6 ^' I
    " J- l6 h  T- P! S0 d' o
    #查看数据集合2 J1 L; \% D' {: `5 F
    image_datasets7 F: V& `- d! c2 n- g$ _
    / e* e% G3 o- D4 a8 U
    1
    * _0 K' R* F/ A& G2
    ' N( B% p  ^2 e/ O! k34 N8 X; z) M8 O* A
    4
    ! a9 Z' A* e' A; r2 q& c5
    6 N* g. v$ T# z9 q  {' x6+ ^7 T8 o" P; y) v7 E  G& V
    7
    ; b; K% C1 _% q& o" u85 d5 u7 z) ?# P1 [( D0 ^9 ]; t
    9+ Y% J1 q; q9 w1 \) x
    {'train': Dataset ImageFolder
    " Y* k. Q- G  @6 G  O% y% `     Number of datapoints: 65526 X+ y5 m; u1 D( {  I  ]3 g' ?3 M
         Root location: ./flower_data/train& ]; H( i8 w( K$ e+ K
         StandardTransform
    % }7 Q; B6 p% K' B" Y. Y Transform: Compose(
      F( t5 T- @$ o- b3 d                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)) M& c, c! v6 f' F% b
                    CenterCrop(size=(224, 224))
    - ~# U$ q, ?0 t3 w( u" m5 U                RandomHorizontalFlip(p=0.5)
    " a1 M1 }$ \. _* o/ p                RandomVerticalFlip(p=0.5)
    8 q4 H8 c2 `! s9 L( S( `, _                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
    9 A  \3 `& W# X; T' l- _                RandomGrayscale(p=0.025)& ^! [0 W9 e; Y% j: c0 Z2 }
                    ToTensor()4 a& M* v. y8 W( [: G) ?
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])9 Z- J# ]2 K" D. Z3 M7 i
                )," D3 S  M4 G+ A5 K; H3 P6 G9 K  d9 v
    'valid': Dataset ImageFolder, T$ p2 R$ s  m" Y) T
         Number of datapoints: 818) T  P  j1 Q4 |- q0 R
         Root location: ./flower_data/valid: ^" `0 [1 H# u& p
         StandardTransform9 t4 }' O5 [  o
    Transform: Compose(
    ; `- y  N1 z' T0 H3 A6 P9 Y                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)( t8 V" }* l0 i/ g7 K
                    CenterCrop(size=(224, 224)); p; s5 [% E! `0 L! i0 i9 @4 H. W( B
                    ToTensor()
    1 v& F& r* Q5 T# b9 S' D3 I. W1 Z                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])% U% u1 Y# u/ @. `' }
                )}
    ; q' k+ r, d7 L$ A5 `# d. Z7 r$ X9 V/ [* L! O5 M$ d" |0 ~
    1
    ( @8 N% K: i, F4 r6 y2( d* C( s2 r9 t# {
    3
    ' d9 G5 y1 |9 i" N0 t/ v4 U/ L! o9 R4
      G7 N! v% ]; x( b8 a$ z9 x5
    9 j( F  ]7 i- a; ?7 o1 b9 [. M6
    . f# ]; v! O. U7/ D$ m/ M3 l2 i3 k( u" `- u
    8
    : \8 W3 [; U6 b4 s9
    2 ^6 x# Q) `" H8 l10- h/ d- {" u/ t# n! R
    110 R; w/ ]+ \) w4 U3 L" ?* Z, V% C
    12" [& ^. M- B; K( D* `% ~# z
    132 R! i3 n5 [. b
    14
    - h! T( W3 j! D/ r- D9 T) D/ j, D15
    5 _: Y0 j; ~( O5 m5 T4 g16
    7 `7 d- d2 y. S17
    5 l) u# W+ l  A; [18
    0 U) g7 o) |( @: @( O* d19
    ! G5 {1 _, i8 E1 H6 Y# W203 G+ w; p, [& a7 n
    21
    0 }: n9 j! m% I22
    8 D" s5 x% v% i8 K23) D# H3 I$ o3 K0 y1 X6 E4 S
    24
    ) G" z+ V: L) B1 ~( \( A# 验证一下数据是否已经被处理完毕
    1 t# k6 M4 W( wdataloaders
    - o9 T; t$ u: g1
    7 C/ V4 t2 I! R+ l! _) Z4 `( _7 r, o2& x. W% ?2 d' @1 l/ h4 n
    {'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
    5 x% Y! W8 O/ \4 Q, { 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
    4 d4 R: ~" J( g( i6 |5 x# s1
    5 g1 e% E. ]& n) a- U$ A7 U2" J& a( ^, |9 v( x
    dataset_sizes& ~  i  C! |; w* E! M
    14 D, r9 Z9 m6 N4 h: ]  E( i
    {'train': 6552, 'valid': 818}
    $ I( q. u: n2 d4 B  ~" ^1
    9 P( r) h' ?+ {6 b8 ~读取标签对应的实际名字! t# S7 Y6 |! W  y& y6 T
    使用同一目录下的json文件,反向映射出花对应的名字
    : ]- n" L  a/ o- l2 R) q4 Q; m0 b) g5 C' c) D1 k# {) f/ G
    with open('./flower_data/cat_to_name.json', 'r') as f:
    ! }5 U" Q; `" X9 _5 v) ]4 e' W    cat_to_name = json.load(f)+ D$ F, u6 h6 X% ?) r
    1+ }6 f1 a( S8 D( \5 Q0 b
    2
    * D" h9 P* n9 R9 ]8 Pcat_to_name
    + E% U9 h7 K/ F: k1. F4 P9 _  m2 X. a7 d
    {'21': 'fire lily',
    ! n3 p  W5 e! @' ~' g( `) b '3': 'canterbury bells',
    % `9 H: a- r* N2 f '45': 'bolero deep blue',
    & {! F0 O5 m. M: X$ S& w '1': 'pink primrose',: ]' v0 C0 s/ r/ r
    '34': 'mexican aster',3 `/ y& e. i  `; L& p% v# ^
    '27': 'prince of wales feathers',
    ) p+ k4 a0 d0 [ '7': 'moon orchid',# i& Q; |3 q$ O; g
    '16': 'globe-flower',
    2 Y7 R9 {# G( k+ _- C; L$ A '25': 'grape hyacinth',% X- q8 P, P3 c+ F8 e3 S
    '26': 'corn poppy',- O; c4 c+ {9 b& d
    '79': 'toad lily',
    : Q( j& _+ B3 X0 x: c '39': 'siam tulip',2 z  `5 Y3 t# t3 p
    '24': 'red ginger',
      }- u( {; @- [; U; n '67': 'spring crocus',
    ' S6 G! Y8 c, j5 X3 D( o- a '35': 'alpine sea holly',
    7 j0 [# p& e' l6 b8 R' N '32': 'garden phlox',7 _7 m1 @$ y/ N' w
    '10': 'globe thistle',
    $ x4 ~" K: c; C0 ? '6': 'tiger lily',
    7 t6 h, w. X& d" ~ '93': 'ball moss',9 W/ |+ C4 P3 p; Q
    '33': 'love in the mist',5 U) L# m2 Y7 o9 w4 o" Z: C6 }8 v
    '9': 'monkshood',  i0 M3 C2 Q; ]( e. }" q
    '102': 'blackberry lily',
    % e) |( ~6 V) w9 z '14': 'spear thistle',
    5 X0 q2 E& x" h! A '19': 'balloon flower',7 d2 o3 H- x5 y0 y5 S( d+ y+ S
    '100': 'blanket flower',1 ^0 B2 M7 V# ^+ m
    '13': 'king protea',
    / h5 j3 i" ]: A% Q& M! K( x '49': 'oxeye daisy',7 x+ Z. q) M' H& `
    '15': 'yellow iris',) u+ j& Z# I$ g  f
    '61': 'cautleya spicata',* ^- o' u0 U7 n
    '31': 'carnation',+ k& E- X5 m  ]& t8 U/ m* Y, `7 q
    '64': 'silverbush',/ q( w4 N2 I. m, g8 W& ^
    '68': 'bearded iris',) F5 |7 S! ?. h$ x4 J7 y1 L
    '63': 'black-eyed susan',
    1 g2 K, w4 ?' }* { '69': 'windflower',
    % J; ?% h6 O& U. E& x' N '62': 'japanese anemone',; g+ l  |' Y4 y
    '20': 'giant white arum lily',' \  t, x9 K. \; v( U( u
    '38': 'great masterwort',
    9 |, }! u6 U* [9 h% X. Y& l1 e! ~. v/ l7 f '4': 'sweet pea',* Y- G3 X8 j0 k- N. K2 j& K
    '86': 'tree mallow',
    2 w5 {' m8 _4 }# b( C '101': 'trumpet creeper',1 l7 z9 f' @4 \& g$ h
    '42': 'daffodil',
    & d7 U3 t/ W# z6 A '22': 'pincushion flower',+ ]- V) j4 c" k- R$ N( ~
    '2': 'hard-leaved pocket orchid',. ?( s, t* A  |: A- Q& R* `, X; ^' o
    '54': 'sunflower',
    7 x' q% T7 R, {/ f '66': 'osteospermum',
    . y( O. H0 J( Y2 q4 E '70': 'tree poppy',
    % ~8 Z! S$ l+ x- c, b$ X '85': 'desert-rose',7 B+ U1 {) h9 ?& A
    '99': 'bromelia',5 ]% b; B, _. p) d: o, P
    '87': 'magnolia',
    - n- `" g& i5 V4 ~0 m7 B '5': 'english marigold',5 C% \/ t( x2 I& N0 v7 V
    '92': 'bee balm',
    % Q) J. o1 a$ j. i9 C/ Q. f! W! u+ K( ^ '28': 'stemless gentian'," j: f# g" d3 a
    '97': 'mallow',; W; F2 K) W4 d5 Z% k3 e6 b# u
    '57': 'gaura',2 E8 e/ q' q) y' l- g; C6 Z
    '40': 'lenten rose',
    1 }0 ~9 f/ V6 [ '47': 'marigold',3 x6 f1 @, a3 y4 V: Y* Q2 N" Y
    '59': 'orange dahlia',8 n& [* H$ I# c2 z( F5 }3 O/ k) d
    '48': 'buttercup',9 s3 d" |5 x- w8 |( k, f$ g, u
    '55': 'pelargonium',- q8 O. E& J9 _7 ~; M
    '36': 'ruby-lipped cattleya',( y* |3 a; \+ t! M1 Q
    '91': 'hippeastrum',
    * f: i4 g+ k' y '29': 'artichoke',0 O! @- d& U5 V
    '71': 'gazania',6 @2 u- y: o' z2 }) X- t* c: q
    '90': 'canna lily',
    , Z4 D: S9 k2 U, I5 I '18': 'peruvian lily',) O6 W: O1 T7 M6 ]! _$ `
    '98': 'mexican petunia',( T4 q7 w0 ~; l6 f
    '8': 'bird of paradise',1 o9 Q$ ~- j* o9 x6 \  B$ @
    '30': 'sweet william',1 Z- |: d- Y8 U) k( X
    '17': 'purple coneflower',# b+ i# R6 S' F8 O( l
    '52': 'wild pansy',& W2 A4 j4 e  m
    '84': 'columbine',
    5 r3 x" Z* H1 J! o; n/ l' `, z' ] '12': "colt's foot",# v+ t" T' L1 w) P  |& c! @5 I5 I& ~
    '11': 'snapdragon',
    ! w. B+ k: Z* g% I0 C0 i8 J '96': 'camellia',
    5 F# g' z  m$ Q+ K# K& N3 P '23': 'fritillary',
    " I/ t# _0 J6 o8 W '50': 'common dandelion',
    3 }5 l6 A9 a, D; S  o$ h '44': 'poinsettia',! F$ ^. e  m/ {2 K
    '53': 'primula',
    5 e; y# ~; {2 v9 a" P '72': 'azalea',
    5 L( u, j! X; q8 x '65': 'californian poppy',
    - A' M9 m- z  |, B* b '80': 'anthurium',8 ^7 d. F" F2 |7 S
    '76': 'morning glory',
    6 G: v# B$ [" Y  f '37': 'cape flower',2 v+ B- K+ Y! Z6 T
    '56': 'bishop of llandaff',
    2 k1 a1 J; L: K9 I" f0 c '60': 'pink-yellow dahlia',) z) [( u3 [% I$ Z. l: J1 M, u
    '82': 'clematis',& v9 h# t- r# Y/ }* x1 m
    '58': 'geranium',6 ^; G0 U  p: g/ I7 d$ M: t
    '75': 'thorn apple',' G: k( i% D; n3 ^& `
    '41': 'barbeton daisy',
    * c$ ^9 Z7 L' g- j9 ?. i) W* S '95': 'bougainvillea',
    : w7 @3 b$ F: M- \8 j& |+ Q '43': 'sword lily',
    6 T: S/ w# f. x8 n- L0 ? '83': 'hibiscus',/ L5 V+ x% |. r1 K8 @, i1 P
    '78': 'lotus lotus',/ v) r1 E$ m& l$ w
    '88': 'cyclamen',3 A9 H' M4 E2 a3 f$ b) Q
    '94': 'foxglove',
    - }) p5 H9 W$ c: [ '81': 'frangipani',
    2 ~& x) M* t5 I) M9 u '74': 'rose',
    . o2 m/ A9 o7 Q1 u  \  Q+ ? '89': 'watercress',
    * n# N% k5 h( U0 G5 K3 P/ E+ f7 q '73': 'water lily',
    2 Z4 G2 d6 h/ `$ `! O: h# a '46': 'wallflower',
    + S( w6 G/ d6 c '77': 'passion flower',
    . i* n0 }6 H* _$ l" h '51': 'petunia'}5 S% D/ {* r2 @+ I/ M5 E

    2 l2 s8 d& |; Y1
    . W. x/ {% p' r% m$ @& u2" `2 O' g  U$ L" |; Z
    3
      T3 u# `3 C6 ~6 N- U, y5 }4
    0 M# ~8 d5 t+ ~& y! z7 N, T, n5
    # g2 C  y, L6 T! ]9 P/ M7 G6 f6, W; u5 p* L( Q+ D
    7
    7 b! T' Q+ v) B- a  v' J4 B$ F) s: R8
    ( U9 {4 _' O% ^, r9+ K$ J& O: M- x& K
    10
    & p5 Y) e. }1 x% `% D# \11& h; e! `. e% Z# p0 y- n% Y
    12
    . P" Y' ~# d4 L' A, l! [9 S( q138 G( d, l$ e' }+ \. @; H+ W
    14
    3 w' ~" q- G2 n# G, M2 w15; `9 s- p( b/ _" j' u: D: e
    16
    5 a, L2 `3 j. l7 P9 m, m5 r8 [% @17
    7 ?, c; B( O: @18
    7 T: X6 U. s1 w3 a19
    3 t% F- G! ~/ \- W$ b205 p5 h# F# l; H  o  e
    21; v9 o( I7 ~$ H
    22& _: M* A; b& ^$ _& b
    23
    ! w8 m7 ~- i5 {' ~/ U24
    3 E& U. _6 e, ~3 s! j/ ~) ~3 ~25
    , h1 N, j' N* |7 B26
    1 P' r0 @; n. M' R7 q( s& H27: v7 c+ m6 R; f, E! s% \0 G# r
    28
    / r" J# ], w% `! A3 w29- \3 n; d- L7 E; g/ R
    30
    & {7 F. H- `( k5 r. }, \6 z. z8 n31
    4 ]  B' `) v" ^32
    - _0 k7 c9 W  B5 V& R33$ s) L: x* B8 t4 I+ w8 D9 u( X! O
    34$ b+ X* _! ^6 i: C+ c# U
    35
    ) M9 V$ v; _4 R; i: o# `- h36% X3 X. E; |# I. W
    37% |7 V! L+ v  {# [$ [( f) F
    38
    * \' b* {2 y, L39- r5 l; T# o# y9 a
    40& n  o1 d0 p! y3 V$ B. Q2 J
    41* K% y; F4 b! h2 o, \2 _( T
    42
    # B% M# Z% A. `9 ?! ]* W1 {9 y+ z; ~43
    ) T; K" E3 o) k0 P( F44
    . [" P. ?" p7 e6 k$ w: k4 X45# w9 w( x) {  n" U1 F/ M% {, \/ y
    46+ d) Y  }6 U6 u; N
    47
    ! Y* M4 r  o, `( b8 z48
    # T" C! ~4 B  Y; u: D49) U( J/ B1 {; V% [8 K, t- n
    50$ Z0 W8 ~$ v; L+ W
    51
    ( ]' G) P& ^0 r3 e5 b: {  `+ p52. Q, z& y( A3 J8 ?
    53
    9 V  W& o' e% x! y54
    ( z$ \9 B7 _- b/ x0 t7 x# T3 z: x55
    , d7 d6 X0 ~0 ^' t% c56& G9 j% I) Q/ C/ j$ T4 h) N
    57( q. J. U( T7 d
    58
    0 x% V6 n) A' m1 e) R59
    / p7 v. k: M& x5 k60
    ! q, O' L& ~7 ~% V* w61
    9 @6 r+ {' w! V: U$ h1 z2 A62, D6 H' P) f( c8 s, U4 P
    63& g2 g3 S& z! l# D
    64
    * }2 ^) Q& X6 M65
    ) J% q% i  Z3 N+ ^- n6 g9 ?66. h2 G+ z( v  \+ ?: m, u
    67
    7 R  a  d7 I& |# t# g68: L# \1 Q% s8 T  X3 G
    69$ Y2 m) T7 a4 I- w/ p5 Z$ u
    70* V. X' l% g1 f+ T" |# s
    71
    ( O& M! t4 V2 p- u/ ~5 ?, g0 C/ |72
    . J! ]' Z& r; t( i, R730 {. H! k, q, A* W( V
    74. ~: c8 `& N& F. J" w3 D- y* @( d  o
    75
    3 B7 ^/ x/ A8 ~6 n76
    , v/ N* w, R' f5 x( s& A77
    $ p+ Q# J1 }& M% a& R1 o2 `78
    / q: u4 z  \* y7 x' h6 a79
    * d, P1 H+ v6 [, S7 }80
    , y6 Y( B5 q+ Y0 Z817 N) x- h: G" [
    82
    ( @& a- F! D0 J$ |* h5 z83
    8 @* ^  O4 J/ {0 F4 s847 N/ P, }& d) v1 J
    85
    4 M9 R& S8 i- d8 t& K6 }' K3 [869 c8 w+ K# w$ _0 c6 ]$ y  w% v7 U
    87
    $ r* [7 T9 i; ~  _7 ~88' |/ g# M; z5 s" B1 c+ S  {
    89
    * T& r( p2 O7 L* r8 K( t, g* y90* [5 h: W' B0 H$ s
    91
    8 q: V2 X# l  B% l6 X92$ h9 Q7 f: l- G. v3 `
    93
    : {3 q' [6 _+ v94( M- u' m% G$ v( w
    95
    # F" j# _* k$ {6 d; P. L96
    + X3 q" \% m# C5 K% J; m) G97
    ( ?( f" L5 W+ i, I2 n' D$ p% f( t6 V98; {7 o, y! G8 x" a% Q
    99
    " {, b0 Y% C7 X% Y  R* g( u1003 a) }  w: Q: x' Z6 V- Q
    1015 J9 a! C, w+ G- |2 q
    1028 a/ `4 r' F# F+ ]; Q( S6 N# g
    4.展示一下数据  u; R( m" m" F8 x: t! {
    def im_convert(tensor):. {) e( Y' ]1 ?, V7 V6 Z1 S; a
        """数据展示"""5 x8 ?5 G! r! ~8 |9 i# @5 _
        image = tensor.to("cpu").clone().detach()7 @6 H( f8 u  ]+ W; w! b( U& A1 C! x
        image = image.numpy().squeeze()
    , N2 C4 q4 A0 T5 I. H  Y    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
    8 a  O3 P- ]$ Q6 q0 g$ p    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)1 [  o" v0 d+ _
        image = image.transpose(1, 2, 0)
    : Q  y& j' G( [3 R2 Q3 z) T    # 反正则化(反标准化): A: L  }  A1 b5 [" X
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    ; j) b) F; R  C% N+ L; ^) b' |9 o  g% N
        # 将图像中小于0 的都换成0,大于的都变成1, C  b$ u- o, J& U9 `+ J8 m. l
        image = image.clip(0, 1)
    : J7 A5 e3 d8 N- D, a" Q! K# V$ n+ S" u
        return image, P8 s2 S. w0 E( w0 F4 B% l# i( o6 m
    1
    8 c2 V4 ~' W2 a3 L4 F+ U27 {8 [+ f7 k, n& |5 O
    3$ i0 E: c+ o; E
    4' o! U1 {! D* X0 n* I$ Y: O- M
    5
    / e: O/ d( L, b6
    & Z9 b* M# E! A$ ]/ [7* k& Q+ T# c% h9 S. K
    8
    7 c2 l% V, p4 [9 Y' a. K8 h0 p: H9
    9 e+ c$ q1 \9 g' h$ M4 v* b10
    ! t( r7 b8 u% b( p: G8 j% L11
    # c1 v2 X9 f4 \12
    ; I/ o. o, v* |; ]4 A' W' k13
    ; t+ K( k# P! E/ c+ Z; r5 j  ]14
    % i3 b' _" z1 g# 使用上面定义好的类进行画图
    - \, m3 B% R8 q, Efig = plt.figure(figsize = (20, 12))3 W3 B) q# i% X* p  U
    columns = 4
    ( o% a$ {# P1 a# y$ z4 Crows = 2. u9 C: n1 C2 N: m9 }1 X
    % A0 [' u: a% Y- n; K
    # iter迭代器
    3 w* W& k! t6 d' o5 k3 z# 随便找一个Batch数据进行展示& ?4 u; ^  _' ]5 z! ]# A4 C
    dataiter = iter(dataloaders['valid'])5 t% z+ B. l# [+ s0 c$ u
    inputs, classes = dataiter.next()
    % S$ c9 j  T& m4 C" B# x, Q
    ' Z, n$ @0 @) {: q' F# [for idx in range(columns * rows):
    3 ~3 F+ A7 B' T) M    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    2 I. t, Q# k! J4 g    # 利用json文件将其对应花的类型打印在图片中
    + h: P7 }1 D9 N" N2 X9 m8 V    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])* N& o9 v- A/ X$ F& }4 L3 n
        plt.imshow(im_convert(inputs[idx])). L9 k8 U) T# l5 u
    plt.show()# C: D0 Z( L$ v" A

    / P. p# `( l# O# ~( y5 Y13 \/ S9 o  x2 V1 o; S6 `/ f$ M! Q
    2
    ! Y0 |/ M+ K! Q7 R3
    0 j# D$ d8 T7 n4
    / G# I6 ?4 f$ L( K5
    3 l2 v- c7 A% W, ?1 G! @8 \6
    8 R! z+ Q2 k& n0 e7/ @9 I: y2 M2 Y9 r* z) L6 ]* p
    83 H* h2 H/ x  [+ o, I
    95 @  P- p; _# r3 S
    10
    - q9 s% f1 |, X- z" r9 {11
    " [. u* M0 w/ ^7 G! H. |12" D" S4 e# F6 R* [* x
    13
    : T# b. c& j) K# r& g140 I) i' P. z: _' c8 k
    159 u' c, o' S% ^0 K) c+ w
    16
    9 h" o  k( V& U; h1 w  Y' {3 b8 k0 S9 `; N5 S( ?0 A

    3 @8 B% x* R( c! R) Q5. 加载models提供的模型,并直接用训练好的权重做初始化参数
    7 |/ Q4 A  i- [- J% pmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']% T' J2 I+ q; y7 [# |
    # 主要的图像识别用resnet来做
    # n! o- j" V8 O& A7 o5 O9 \# 是否用人家训练好的特征: q! D! q, P3 R+ `! p, z  t8 [
    feature_extract = True- C4 w! D1 r, U' J6 H6 F( o
    1
    . K3 m" ]4 R  G/ o  J2" R7 ], ]2 A7 ~6 y" k5 j' B
    3
    , U: O+ Z  X- t) ]0 D$ x6 ?4) Z: Z( ^2 T' x: J/ X
    # 是否用GPU进行训练1 B! k7 x% J& \4 R+ H) b
    train_on_gpu = torch.cuda.is_available()
    . j9 E0 m& F! n; k2 ]0 d
    ; V8 [- }% n. V1 J  |& P( a( }if not train_on_gpu:4 z* V- l/ Q$ }) B* O( N4 F
        print('CUDA is not available.   Training on CPU ...')0 h9 N. ?6 g- O$ j& ^3 N
    else:
    ) N% B& R5 t9 N    print('CUDA is available! Training on GPU ...')
    7 E5 _# Z9 F& ]
    # p5 E% ]" z  y3 idevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    9 \+ w3 t! h6 |* @1& I: J$ ?0 L* P4 Y# e
    2
    ! B& [6 `1 s" c+ _/ N4 ?- k3
    6 \! k" [5 z2 Q5 C4, w; \% }. T! m# {- N' I5 l0 C6 [4 O
    5
    % @5 y% `' M! c3 F% E6( R2 w* q9 |" `! m  J* b
    7
    % o- u# Q, t1 ^; B7 M: Z8
    # e8 A) P1 j: C6 }. @0 O91 Q7 t: f9 v0 k- T, x
    CUDA is not available.   Training on CPU ...
    6 a3 e3 r9 f" I* l2 j& _6 h1
    # [7 j3 {; }* ~" x2 ~( C# 将一些层定义为false,使其不自动更新
    3 Z* d8 c; U( t/ _- ?9 E3 t( d5 ~def set_parameter_requires_grad(model, feature_extracting):
    - g6 g. d0 F8 D7 K7 |, Z    if feature_extracting:' E& V( a& L4 {: F2 o' `; W
            for param in model.parameters():
    + A9 T6 I9 C, r$ @$ g3 l2 T+ e            param.requires_grad = False+ @1 f# ]) \2 g2 d  b
    1
    & v6 V  L: y; U27 d" Q0 l9 F# W! M# }% h
    39 V+ O- A# t# @* _5 y
    4
    5 r1 I* J( W' v* V4 R4 F5, n' a- J3 F" w' q3 t; H5 C
    # 打印模型架构告知是怎么一步一步去完成的2 }/ z. x0 [& \6 u: D( J7 U. l
    # 主要是为我们提取特征的& L. |% ~3 r8 @

    & k# \, K' G6 J% Cmodel_ft = models.resnet152()
    1 @0 v  T# F8 [; l( A. H/ a9 Jmodel_ft% B% o- \, `/ R9 }
    1
    2 z0 V& _) H7 s6 D" Z2
    3 }$ Y0 B  Y  ^  S. ^8 X3
    ! P, a# X, F6 E1 S, O6 y& \! M* ^4
    * K  R. |3 m4 m4 z1 ^4 U$ C1 ?50 h' R3 z0 q5 d$ a
    ResNet(
    / K6 R' Y+ S$ `  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    3 H5 o  [0 F% Z$ F  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)- ~: ?% u0 z$ K4 y  a) S5 D8 c* {- x
      (relu): ReLU(inplace=True)
    4 h: i+ S6 \3 `9 H/ z# s  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    3 n/ O$ r% P( F5 y. `4 t' l. ?  (layer1): Sequential(8 u) ]: n5 c; Q0 L! z# M
        (0): Bottleneck(7 q. ~3 M4 O1 j/ A% \
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    $ [& _1 p" I9 v3 g- B      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)% W% x) _) k4 e. i  c  r$ y
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False); ?5 u( m% ~, B0 L, l% I
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)- m5 u: ?, h+ X* N
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    9 l* [/ p$ {$ G2 W/ Q$ F      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): s4 i" n1 t$ X
          (relu): ReLU(inplace=True)
    / o5 |+ E  [7 r: B5 K      (downsample): Sequential(
    6 U: J6 V0 R! s        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)  H9 j' L, ]) |: }8 H
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ( m- i, `8 y0 a      )4 |3 U! _4 l& a; Q0 E5 a
        )( N3 x1 Q. h3 q
    中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。0 n! k: U* e( x+ N3 i
        (2): Bottleneck(
    2 K9 e* _+ c& ^* i# \  H" u# G      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)8 e- \' }! ]  |: d1 B
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)9 m# a0 L( M' G) E! ?' j9 v
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)' z( \/ K9 F# f. \( p0 k% y9 m
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)5 r% |" x) v6 Y3 N
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)- V" [; a7 S0 \# s- y* d
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    . Y, b- u6 Y- P3 e/ J      (relu): ReLU(inplace=True)
    ) q1 a8 E& L1 E1 g1 J6 N% t    )( M+ H/ M2 Q8 A/ k
      )
    , q" y+ }9 u8 |" t. C  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)), r: i: X1 [" p1 L6 U" t
      (fc): Linear(in_features=2048, out_features=1000, bias=True)2 h& ?* X8 a) a
    )% P" n6 E' S- P/ {" X) [' f

    + t+ L  f# r8 v3 ]. E9 W) ?5 y1. b6 ~1 F7 q3 R: R+ z% `" ]
    2" J+ e. t9 V: A+ [9 C& j
    3
    4 y2 ?( I6 Q" N1 j7 p( u- @4
    9 c0 V; y: Y; ^* U9 g" I/ h; {  d  Q5
    $ _4 R! @9 s8 t6 b. J6 a  E( k69 j- i0 n2 o* E# E# k: d5 D
    7& \% g( D/ n+ Z" v2 l9 m3 E
    8
    9 d3 N+ i( F6 \/ a9
    ) t7 [" s' C1 @6 J# j' }10
    7 Q( o  `: f$ ~" W111 H3 r- C3 h; e3 c: d$ h8 v! p
    12
    ! p8 O% g  ^$ }/ v- \  u131 K' w) W1 l9 u" l. g% ~, R
    14
    / }) E3 N7 e& K+ X* [" ]15
    % A# C6 E; b- n16
    8 R  w2 s4 P  z1 }* ?171 z/ q0 R+ y3 t& t
    18: r4 f/ a$ b1 L3 ^/ F
    19! s' S  ?0 S% }. T& S' n0 G
    20+ W; m' C+ h) Z4 F" r
    21
    " S: w, |& B+ S! q8 H6 a5 A223 n9 Z- \- j( K0 K6 G3 m
    23
    9 t! C0 v! m3 M* g6 p24% ?& `! n, D( N+ g7 Y% u+ G
    258 p: G6 l) `* y3 C2 y, T
    265 y3 P. R4 d. c& c: I, d
    27! d! E! l! a8 E+ d% E
    28, R3 F# Y/ ~; a: e  H% W- p
    29! i( b( Z) ?; |4 ~
    30
    8 K4 w* `3 I0 Y$ G! z31
    : ?, G  ^. n% u8 i7 a32* E8 I3 v0 d9 ^, Z& J' d% G
    33
    6 G: y) X  W& T& X* d, p最后是1000分类,2048输入,分为1000个分类
    + ]6 {) f$ U* s+ ^, o: i; L而我们需要将我们的任务进行调整,将1000分类改为102输出- D$ d9 C1 G8 {) W! O

    ' Q# H- J8 C+ K9 S( V/ P9 S# m7 d6.初始化模型架构
    4 }2 s- ?' I' ?步骤如下:! \* M0 K  j  O& C( Z, E" j& \- R

    $ F* l! T$ Y% C8 q' i& K将训练好的模型拿过来,并pre_train = True 得到他人的权重参数6 m3 L! M1 b. N
    可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
    6 }8 U! G. O/ [, r无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
    4 G1 G% @8 H6 O' O4 ~官方文档链接4 y& F, U8 X$ }
    https://pytorch.org/vision/stable/models.html( T- w- e2 H8 y1 h# S& q
    2 }3 V1 n6 P6 P, _+ i1 i
    # 将他人的模型加载进来& U, _" n) V/ T4 Z9 P/ @
    def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    4 m8 @" U' M# z9 h    # 选择适合的模型,不同的模型初始化参数不同8 {, [  A1 m- K4 `" `8 K
        model_ft = None
    6 B% P4 R( O. r" [/ F    input_size = 0
    * M9 h. v; c* K- Y  {, c% m, n' }) k3 v4 M  L; C' `6 h+ b
        if model_name == "resnet":
    # k: Z3 w* C+ r" ^( I6 w        """
    ) o) C3 J5 c) M        Resnet152( G9 X8 u+ q+ U/ B5 c
            """* I; c# c% N+ R' [( `. R
    3 @- R% F, e+ J, a+ o* y4 J
            # 1. 加载与训练网络* v$ K8 A, y8 {6 W. J
            model_ft = models.resnet152(pretrained = use_pretrained)$ q0 p  S! T& J9 |! }- X
            # 2. 是否将提取特征的模块冻住,只训练FC层
    7 b3 C( i1 Q* N: A9 Q+ F# _! R        set_parameter_requires_grad(model_ft, feature_extract)8 I! x% s7 d+ [( T2 a4 l
            # 3. 获得全连接层输入特征: u  _2 z2 |( @: J2 x
            num_frts = model_ft.fc.in_features. q; y& c5 b# ]$ E
            # 4. 重新加载全连接层,设置输出102
    . ^( V! j- {& c$ ^5 g0 P8 y        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
    4 Y( U% z6 R" j" M, U                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1* ~  X* }- Y) j2 v7 z; h
            input_size = 224
    7 Z$ `; [7 m+ E0 `- X' b# @  j$ \' F2 L7 |* Q
        elif model_name == "alexnet":2 l: [- D' a) C$ _9 S
            """% i4 }$ c7 n7 @. A; l8 w% G
            Alexnet! E5 E- z4 K7 }
            """
    " H2 ^3 h& s9 Z9 I        model_ft = models.alexnet(pretrained = use_pretrained)' I1 D3 W$ f1 R# [  m5 F" E
            set_parameter_requires_grad(model_ft, feature_extract)+ w6 k2 w$ B% g$ Y' O6 X3 R, I

    . i9 B" m  F+ h- V( v( k        # 将最后一个特征输出替换 序号为【6】的分类器+ I! W" J7 r% q+ J" Q0 d
            num_frts = model_ft.classifier[6].in_features # 获得FC层输入- n4 j4 P# T7 ^- R3 O
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    , m; O( M; H8 H" v; m' m3 f9 {7 S) [        input_size = 224
      m8 q3 B# J- P% h+ e9 K' Q9 Z2 S4 v7 u8 d+ f7 q7 L% G+ z
        elif model_name == "vgg":
    8 U( v/ a4 x8 ?5 `' e0 Y" T) x        """
    " W; y7 c. w* p$ `0 x; ^& G7 H% z        VGG11_bn
    8 R; B' J0 r% V3 ^  |& g        """- x+ ]2 s  P+ [
            model_ft = models.vgg16(pretrained = use_pretrained)  K' I- ]; a: ^$ B0 F8 r
            set_parameter_requires_grad(model_ft, feature_extract); Q4 i6 z! I1 G9 z9 i
            num_frts = model_ft.classifier[6].in_features2 W  W1 i" U+ m* E
            model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
    ' M5 q  ~3 v. ]( B: y5 v$ B        input_size = 224' }! {% V0 n6 V2 ^

    - w2 U  Z* k6 C4 H    elif model_name == "squeezenet":
    / ~% z3 Y( a: M4 S9 A        """4 e7 n2 O* m' U# c1 @  U
            Squeezenet! M/ Y% X& ~5 j* S: p
            """
      W" X3 w/ B. c4 Q% K/ C( @- E        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
    7 q4 t% p; W! p, k        set_parameter_requires_grad(model_ft, feature_extract)" W8 b% h- Y& q! @
            model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
    ) z4 y+ s: j% o2 q5 j  o        model_ft.num_classes = num_classes: c$ e/ Y; v* b& n0 }
            input_size = 224
    + J  ~: y; G" N2 K* k6 t0 r) Y1 z+ B; r0 U/ `/ O; T; T7 f, i) y1 n
        elif model_name == "densenet":
    8 o/ N$ R8 v$ M        """
    ; n. b. G( q/ W* S3 _- i1 |$ `        Densenet
    * V" ]! r5 D% B  `$ t& s, y        """1 J! U  O4 {8 r2 b  n3 T  y3 \
            model_ft = models.desenet121(pretrained = use_pretrained)
    1 y, }) n9 D# ~        set_parameter_requires_grad(model_ft, feature_extract)
    $ C( {! O; q5 H        num_frts = model_ft.classifier.in_features
    * g. _  d& @& I        model_ft.classifier = nn.Linear(num_frts, num_classes)
    ) F# d& L: f, i4 x        input_size = 224
    % g9 r/ Q( @4 a0 v. \) ?+ o  [+ a
        elif model_name == "inception":% r# p& F; i$ X) P
            """
    1 f8 |0 w9 j# |/ S        Inception V3+ N  A+ ?5 w/ o, j% H
            """
    2 _' c# h: G  F7 E4 ^( ?        model_ft = models.inception_V(pretrained = use_pretrained)8 g  U- @0 n/ r  y, [! m4 A/ w
            set_parameter_requires_grad(model_ft, feature_extract)
      @! m' x& F& R* ]- i' _: {& r* c* c7 z  z; e
            num_frts = model_ft.AuxLogits.fc.in_features0 @+ B; o& L% H9 v- i/ @
            model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)' V; {: @$ j2 M# t& s* Q

    9 X3 ?7 {# O1 I% Q        num_frts = model_ft.fc.in_features, }) c; O. }/ z" U9 O/ x
            model_ft.fc = nn.Linear(num_frts, num_classes)
    5 P8 J- h! u) ]3 g% o  i        input_size = 299
    2 `/ s5 l+ m/ C  r8 m& d" D/ q1 F4 N4 m5 |& g
        else:
    3 L# ^2 I' {4 @& L7 a9 q+ ^        print("Invalid model name, exiting...")
    7 u* j: Z. g4 q6 U. I# ^9 i        exit()
    ( P( b7 ]4 S" T% v0 `) _
      o1 L- D8 c/ N0 x    return model_ft, input_size
    ! p) h# C* c9 t7 ~1 F1 ^
    4 I, \+ K+ k/ I1 C6 ^0 |9 s: h1
    6 ?5 M$ p/ b5 w1 m2 n# _2
    ' d+ t8 r6 T. [: [; C32 W1 ~& e8 E+ |' k' j
    4
    ; Q- l1 ^7 C$ U6 @1 I$ W8 H5
    * Z+ O- W& D& t7 m+ E) G7 f9 W6
    7 h  x0 N7 D" o8 w- H3 o/ P7
    / {) x4 E# e2 V2 {8' X2 |% U! n# ]! }& q4 R  W' Q
    9
    ! t6 l. Z- ?6 Q. I) U9 K* ?10
    8 ^! e8 t* S! K; g11  z' q7 [* V& o/ v$ O
    12
      w6 `- h% l  h% m/ \' D13
    6 B/ u( K+ X& F" U, ]% B' a14
    : M4 |0 f! z: D15% x' M9 T- g( l
    16
    ! F8 h+ B% B5 L9 E/ U0 b. u4 k17
    % T, D8 j  s: r: L) ^+ a- W$ z18
    + e* B" |- S7 z4 I) b8 b1 n19
    6 ?9 {9 o: Q/ u" V' M9 |20
    / d3 C+ ?% S) h) b! i% ?) }21
    4 f. I5 M$ |  x, k4 c22  G1 ~) o5 b& p9 V
    238 E" j  c" w# h, X8 ^2 a" x' a" ^
    24$ J( k% P$ k9 |) t  y5 ]: Q0 ]. N
    25, ]9 t+ }" ^; d: }! r+ L9 E
    26) r% J  W, \7 y+ T6 K
    27
    : E3 n' ?# M3 B; f0 c* K28
    - i2 h7 L1 L& E, P! Z/ Q29- ^6 h3 ]0 z$ Y5 \* j# I
    304 }9 J8 Q2 ]: B- r* n
    31+ B' C4 f* ~4 b5 S% @
    32
    . H/ W7 _: S  x/ _+ a33
    9 ~; y6 q% ^2 J, x1 C  O* r34+ ~- |$ Y6 d" M- Z+ ^. L2 ^) G
    35
    # d+ p+ Z% P+ `; R369 M# |6 U9 A6 Y
    37
    % O* y- A( T- E4 M38) R. i. l5 a; h
    399 Q+ ]  w2 q) j
    40, d/ N4 Z' c6 n$ p" l7 a
    41: a6 Z: X0 d6 ^; h/ _5 ]
    42
    ' m& A" O+ p/ [+ I43
    0 `/ |' \# J3 k* K44
    . T! P2 ?+ A6 [7 \" Z' s45
    & d1 q1 {& {) D3 Y46" N" V: b: x! g9 u
    47
    0 L6 a# g6 s' @- C' |/ [48
    5 e5 T: a% t( e7 x2 w, r49
    / F5 ?7 C9 h: f+ d502 N+ |% A+ u+ |; p
    51
    8 k  s3 S" v' g2 V" K+ V52$ ^6 X1 ?+ Q% u9 M( ]9 t' Q, A+ s
    53
    % i- B, _5 m3 l" `& ]  C54
    8 r6 h1 _$ c+ i554 x, A$ b' F7 U. b* l' A
    56
    4 f' l1 Q9 s! S3 q5 K  _: c57
    2 I4 l' M. L0 e4 e" d& s5 K. _& d1 O+ o58
    + {/ @# M0 T: n3 ^: R0 e59& }2 o: n: B8 z% X8 j$ i% g; Z
    60
    ! a% `" v) \( C* {+ j/ V: f2 Q  b& b61( a% ^2 a$ v4 s; B+ j. q
    62* ^. X, m. s0 O7 q. D& v0 z
    63
    % x( I% y' M6 ~8 H4 a; M4 u2 `64
    ; l% n- s* t" w. o/ ?# k# r; E65
    + J( y; I: w+ W* D. _66
      Z7 k1 d" m$ F8 ?; N67& D: D& t& u' d5 ]) N
    68  F- `0 `4 y( \6 S* F3 O
    69* |! j5 _+ E( C8 w) M' v
    70: ^8 _) _7 c: D* h4 ^0 o8 c+ }! q
    71
    9 L. X4 Z  T( }) Z) _$ ?/ \72
      ]/ \7 _6 X" A% |6 T# \73
    % F+ l& r3 X& \% ~; W74- p* M& p5 j3 S* r! }- f
    75
    2 ?3 l1 [! j5 j% C& g- `76
    , ^3 E; ?$ l' N/ [& P) |77
    8 C) o- c' I7 v78
      A; u; W) A1 u2 y79
    2 T- k2 M  L7 F* D8 D80* p' y7 S8 P% I- d
    81( M  k; r! j% w6 A
    829 h8 D" l' c& E4 S: P  u
    83
    7 z. n) r1 \6 C' P7 z7. 设置需要训练的参数
    / a+ O) o! h9 p# 设置模型名字、输出分类数
    ; r  R; z( s7 b; {$ S: G1 ~model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)( S, a* g9 y" k, V8 a
    & _- W0 j8 D$ E# v& h: c
    # GPU 计算- j+ t, c1 [1 G1 Z; z# v
    model_ft = model_ft.to(device)
    / I$ T' W; a: g: z# i) N
    4 `$ }4 ?8 r5 b4 f! h- p# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
    0 ~9 I' p8 a) n: A7 ]$ a, Lfilename = 'checkpoint.pth'
    ( V9 X2 u  J& \6 V! _# R0 s
    $ R3 {. l  v6 }$ U9 u# Z! c0 c. i# 是否训练所有层# V- s+ w/ A( o' X
    params_to_update = model_ft.parameters()5 K+ A/ S4 p8 n
    # 打印出需要训练的层: b+ N( l' ^& `: s) L& J
    print("Params to learn:")$ u3 M0 e: h1 ?, E. Q
    if feature_extract:# W/ [' _+ v- F5 z8 ]! a
        params_to_update = []
    & J" }, e% t  m4 E- c    for name, param in model_ft.named_parameters():
    " H/ T) I' K, p) Z& ?1 h3 \# Z        if param.requires_grad == True:# U' q& G9 }0 p9 W- _. s
                params_to_update.append(param)
    / x) ?6 T  A: L' @# w* ~, K            print("\t", name)
    4 z. p2 C# q1 c$ J2 P2 H4 R: Ielse:2 n9 o8 l! N0 F8 o1 H
        for name, param in model_ft.named_parameters():! c: c5 ?8 X4 {0 B2 Z9 z, ]
            if param.requires_grad ==True:
    . ~$ T2 F( h( e  C& V. M' |            print("\t", name)
    / B- r! L3 Q- {3 T  V# v8 u
    % E( ?! r' p( Q3 Q. n. |5 e- |11 Z, j; I3 l$ i" h8 V& T# e
    26 ?" J, n; V: i  }, M
    3
      ~" T1 c, N. A. i4
    1 Q- g( O% O* q2 q- r/ k5
    & R8 E5 j/ q. K/ c  f5 f. U4 \+ N6; O9 }3 }9 }, l# X
    7/ A. D; m2 i/ l
    83 Y% z6 w4 F! u" O0 I6 b- a
    9* ]2 c& n& [" O) h2 S# n
    10
    0 @4 J' f2 n) a6 w7 E11
    ! q9 i% K& q7 u& W$ ~+ J12
    5 Q- M0 e2 }* R0 V6 @& j6 L& c  s13; }- q! S5 N5 K; M
    14
    3 d5 l2 O" H, a15" B; o: ]3 Z6 H4 ?, O2 d- A
    16- V6 _$ _5 i; Y* a
    17
    ( V5 X! H8 v. l7 d, S18
    9 N( g, o1 L: u7 G5 F+ X: _4 R19
    6 I, h( K! C, p1 S9 j: Z8 q20% i& ]6 I! h- H7 t
    21+ m' s& b# D( `) h
    22
    . J$ F& I  s/ v1 R! V' C23
    4 j  H7 x2 G! d; ?8 ?8 O  WParams to learn:
    9 ?, W& Z! f5 y5 b" Q4 L1 ]4 i' |         fc.0.weight# t" t7 `# I7 L  K; B# S
             fc.0.bias- x* ?) V& ^. E3 O
    1
    - O# z; u  Z. N" U7 [3 I4 f1 ~2
    ' k& j# r& Y8 b7 o0 d3
    $ p5 H! ]3 [  u! a, M7. 训练与预测* h) _9 ?& T% r
    7.1 优化器设置6 k6 v. q  T. m7 X
    # 优化器设置
    ' G, q& r4 z8 L1 m" T. Boptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
    8 ]7 d  _. x( f. ^6 O+ E& G: A# 学习率衰减策略
    1 u6 k" e# P: }. o2 w7 jscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    * m( G! \+ H$ j* l5 D) w7 z7 Y# 学习率每7个epoch衰减为原来的1/10$ o  e. F. |2 X, D! N/ l: D
    # 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
    ) k6 T7 ]4 l# I/ A+ X. b+ |% [4 N4 K! \2 c
    criterion = nn.NLLLoss()* u( }* n4 L0 P0 v* \
    1
    % X  H! o& @, b2 g  ]! i4 \% f2
    , u! J# ^; o6 c7 ^4 L8 F3
    : f) D9 |, A& V5 y$ \9 E1 i. e$ b4
    7 y( o: E3 z' F; e+ E" ~51 B1 ^5 ^% T9 T$ X6 ?# K6 J
    6
    " I3 R/ M$ I% V9 K+ c7- V- x9 V1 I4 @9 ^# X; T! i
    8
    - b! ^/ Y0 u# d+ u7 d# 定义训练函数
    . ?1 r5 t3 w% d9 t& a#is_inception:要不要用其他的网络* E  Y' g+ K' `
    def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
    - C/ R) Q( q, U0 E3 T- N: m2 H, w" j    since = time.time()
    6 [# K2 k# n6 A$ n  Z. b    #保存最好的准确率
    7 @) d3 G" v9 }5 e    best_acc = 0) ^7 B$ ]3 k5 J) y5 Y  D7 a9 k
        """
    ( A8 v* u) k2 j7 F    checkpoint = torch.load(filename)
    9 q/ |: p. G/ w5 A    best_acc = checkpoint['best_acc']# i6 j8 R+ C1 r: W5 j
        model.load_state_dict(checkpoint['state_dict'])
    : u1 ?5 _- _0 Z4 m& r    optimizer.load_state_dict(checkpoint['optimizer'])3 b2 k$ ^5 r& u
        model.class_to_idx = checkpoint['mapping']) m0 u" x2 e' ?# }; e  y
        """! z) b: r9 j6 G- L
        #指定用GPU还是CPU; y' v% k' L4 v
        model.to(device)
    3 F/ H4 O* B' w    #下面是为展示做的
    + ~( @% l3 A0 x/ T! X9 w& R    val_acc_history = []
    + x# S8 G% S' z( N! v+ e    train_acc_history = []8 i) b* G! @+ d$ j
        train_losses = []5 k+ T0 C+ ~1 x! M: O/ ^
        valid_losses = []
    % K; f" e9 T& k5 U1 K9 t# W: i    LRs = [optimizer.param_groups[0]['lr']]+ q7 l# f/ y' u2 H. ]7 W
        #最好的一次存下来8 `& R3 Q6 X, ?2 F9 Q* l! a8 e
        best_model_wts = copy.deepcopy(model.state_dict())
    . a) l3 o! V( m. \
    5 s  d% {4 ^! Q( Z) A+ b% B' E    for epoch in range(num_epochs):
    ) N, `: ]7 @6 s        print('Epoch {}/{}'.format(epoch, num_epochs - 1)), ?8 C# |3 h! ?: z6 T. ~+ V2 v
            print('-' * 10)
    * p! P  [3 C2 g3 k, ^0 T5 ?2 U0 [+ M# g
            # 训练和验证
    & b& U  h" A% R        for phase in ['train', 'valid']:# i& v4 e( N; S( A! @3 F0 N
                if phase == 'train':
    5 V8 l" a$ `- f1 Z, u* ~- J                model.train()  # 训练
    " q) l; j$ q+ U            else:
    " v8 f& d2 y# B8 w# z. t2 f                model.eval()   # 验证
    0 m; n' m9 t1 J5 x  l* y  Q
    ) l" S* ^4 {) M" w" M            running_loss = 0.01 W* \( @4 n* Q0 Y
                running_corrects = 0
    $ F+ h: N. x9 w
    6 B& c1 `6 R$ d& X1 ~$ z  y            # 把数据都取个遍$ t' T3 r7 e) F5 q
                for inputs, labels in dataloaders[phase]:
    * `; @! ]) l0 j, F                #下面是将inputs,labels传到GPU
    3 P0 Q6 j! K9 g4 @                inputs = inputs.to(device)& q6 ~1 S, ?( d% R" x: z9 }
                    labels = labels.to(device)
    9 F. m. w7 b2 L: M' ?0 r
    : h* b( @+ S! M  F1 z                # 清零: j$ T+ q4 x$ m8 H' v5 U; S3 q
                    optimizer.zero_grad()
    3 u  i1 D6 p& _& @) U                # 只有训练的时候计算和更新梯度- X& }+ O  a6 y% R/ H2 r
                    with torch.set_grad_enabled(phase == 'train'):
    6 `5 b- ~" a. C% m% X                    #if这面不需要计算,可忽略+ b# u  J) ~9 A+ U  \+ b
                        if is_inception and phase == 'train':# s3 w$ L% o/ s) N: ^- o  s( E
                            outputs, aux_outputs = model(inputs)
    ( Z7 j# V- t) t3 ]# f                        loss1 = criterion(outputs, labels)# [* |1 v1 X% A& u% Z% a) m; C8 W
                            loss2 = criterion(aux_outputs, labels)
    " E1 c' E. }9 W: Q# t' f, [                        loss = loss1 + 0.4*loss2! Q( O6 C0 p+ a/ T  A/ }
                        else:#resnet执行的是这里1 I) a9 b7 D$ U0 M
                            outputs = model(inputs)- A) `7 u6 ]: z+ Q- ]/ v% G# r
                            loss = criterion(outputs, labels)0 ~5 C6 ?8 E' m0 v0 s5 y; ^$ w2 u& L

    3 [, Z3 H6 G' R! P, k. F( B1 t                        #概率最大的返回preds
    ( E; K& X8 f5 C7 a, q- L                    _, preds = torch.max(outputs, 1)$ F6 t% G# _# x5 e

    2 R, a, @' g! d' U- M% i                    # 训练阶段更新权重
    - q) t  D1 J( f' o& c                    if phase == 'train':
    6 W& v) k* O9 S6 [/ i: S' P                        loss.backward()% A' b! P' F, Z1 z& R( F
                            optimizer.step()* q6 x7 a; F2 Z0 t0 [

    ; D1 J) b9 c* F2 i) d- X                # 计算损失2 V' ]* m5 m- S6 N0 G0 Q" U
                    running_loss += loss.item() * inputs.size(0)1 s9 }9 W7 {* t4 |( [3 d4 i
                    running_corrects += torch.sum(preds == labels.data)4 V: H0 c% u4 N2 X# a- Y% x
    1 ?5 `( H+ r8 j3 Z0 W
                #打印操作# ~5 Z1 v% I! y) Y. ~- p) h6 b8 R: t1 z* t
                epoch_loss = running_loss / len(dataloaders[phase].dataset)" g: y% j' Q* {- t( U. f
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)/ f& _; M3 I8 c1 A

    & w( E  ]9 J% W5 U2 V8 i" c& p6 v5 J- F$ \8 F& r3 [3 S
                time_elapsed = time.time() - since7 U% F9 ~! G/ I! Q0 X. x
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))/ q8 v7 s. A7 U+ {: ^( h
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)): u% \$ `6 Y* A9 c3 C% U
    ' s1 w$ ]$ U& X

    9 v9 h  }$ i  O            # 得到最好那次的模型9 H7 z5 `/ s+ Z6 B1 b4 K% c
                if phase == 'valid' and epoch_acc > best_acc:6 S* v% p; v/ d) d1 b1 R# m# W: O9 h
                    best_acc = epoch_acc  A1 Q1 R5 ?4 X" ]- Q. a: D7 f' ?
                    #模型保存( F  y, o+ a  L5 W
                    best_model_wts = copy.deepcopy(model.state_dict())
    7 O! o7 E0 [1 F; P% M                state = {
    # a% ^* @3 H* ]8 ~& k                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数) o# w; |; R3 u
                      'state_dict': model.state_dict(),( y8 P" V6 g8 S' c4 a4 G
                      'best_acc': best_acc,0 X8 m, r1 j# s" B) U  C# l" j
                      'optimizer' : optimizer.state_dict(),
    ' k5 Z7 A' `4 K- N8 r/ B                }* h. m$ w3 S2 L* o$ C7 Z$ t
                    torch.save(state, filename)
    ) _6 G. E5 K5 I7 ^+ l) A& v- y6 i% G            if phase == 'valid':
    5 G5 X7 ~2 ~8 R" d: N                val_acc_history.append(epoch_acc)) ~, `, G" X/ }& T
                    valid_losses.append(epoch_loss)9 f& }( b; V6 A2 f
                    scheduler.step(epoch_loss)3 v; V3 m# O  o
                if phase == 'train':
    2 }4 E6 }: a* w  E' S" \                train_acc_history.append(epoch_acc)
    5 g$ ?) F( k( z( \' x' v- ^                train_losses.append(epoch_loss)
    4 i( a! @/ `) L# t' b' u* B/ C( V& n9 j. }  c2 r. F
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
    3 w# g' v7 T6 p1 t) h        LRs.append(optimizer.param_groups[0]['lr']), l" V) k  n" W8 e2 U: X
            print()- q* b: U9 }, a; k
    7 @8 h4 t+ ^  ?: p2 p1 Q
        time_elapsed = time.time() - since! S2 G6 x& ?1 R6 Y5 ^" K1 r7 V
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))- v5 `9 t2 F* j5 ]# a  w5 _, }
        print('Best val Acc: {:4f}'.format(best_acc))5 Q) N5 Y' g0 X% u2 e1 n  N
    6 h# e# T" o& m% Z4 |: K
        # 保存训练完后用最好的一次当做模型最终的结果3 Z+ }& W; n& L" F, `5 C4 r
        model.load_state_dict(best_model_wts)
    / G9 g( ]  M6 @/ c; k0 f* k    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
    : U9 ^  j! X5 ~; C8 a. L. L; r: o% p" W! i+ W- W4 [$ k$ N2 \

    % r- `. [" u. C1, E+ z5 w8 _* ~# b% M; s
    2& [2 u+ q5 |! @" E& E7 B
    3* m( f+ O5 A% Q! @7 f
    4
    ( Z' u& F& y$ e  ~5
    , ^: S! j, m' t% K9 Q# g, s6# Y$ h- W! j) P1 v( u
    7
    2 ]3 C! y1 L5 `* t1 E3 r. Y$ t3 R8
    # U& J8 `, m5 ^6 T& ~! Y9
    0 A0 b+ u: _' l$ ]8 d. c106 p7 i& X  ^0 W
    114 e8 t: v4 X" h5 @9 ]
    12# b3 T! r. K: a
    13, F3 x' D7 r4 C9 U4 |
    149 N- Z$ e/ J) V( o8 X
    15% y5 m7 _, V+ P' M+ V( _# E7 H
    163 w& J* S" |8 p2 F4 n
    17
    ; M2 W2 r2 \) M2 Y# y; y  `7 k18
    . V. k+ Y# W- H9 d: e7 J19
    4 Z. j( r3 J6 X+ \2 t4 G3 q20
    # H; @5 m7 F6 P$ |3 Y210 q0 v' u+ n/ D2 ^6 B2 L& y' @
    227 F* v/ }% S8 }9 @: D+ Q, F) u
    23( [4 q! }- n( ^! h6 h7 I
    24( K5 e. x- N, @; k$ B% C8 p
    25; Y0 T2 N. F+ n" P
    26+ V( r/ F+ K" V
    27
    $ y  i) _' }( O- V3 x; b28+ c$ @- r! B3 e9 t. e
    29
    4 V! \# w/ O* p9 ~4 Q% a& G, @30
    ; T$ J, j) o8 x  s0 z) e31
    ; F" k% a, f, P; Q& O! n6 I32
    . s  }* M5 N7 ~7 L+ i33) q; v, C, F8 h' v2 ^/ F& l, }
    34
    ; Z4 l! w6 S+ w8 @35) ?# J, T, q; [
    366 }4 ]) q& }7 h0 P8 {) L
    37
    - }6 J- V9 m, S( ~38
    + N; N5 |6 h( M. l! C7 P39
    ! t' b& u& H3 N( R* `+ V: m40
    / I9 b( }5 f; V& C, D' G! g41
    ; N* z) L$ g2 X" Z42
    0 q1 I' S! C% _1 Y5 T43
    ( l1 D& m/ x& ]6 w6 v# E' _0 W% j440 H7 Z- |9 n0 m* `
    45: e  a: X" P/ ~. E! K
    46
    ; l$ D  L! y2 o" m6 k* `9 Q47. v# G0 y" f$ J, A4 I
    48
    : ~2 R; E- ]1 W1 ^9 @* U491 @4 R. C, \  @6 R+ _7 W
    50
    ' b8 R, h( t* Y) a+ \4 M4 \: z51
    - I  l) \3 A: v8 P* N1 r521 `0 s# F; ]4 N$ U8 n; }
    53
    3 P* R5 ~- A- ~' p7 r; q; _! h54( p: x  W$ [1 u, w# k
    55
    6 ~" }" o- Z0 _9 G0 Y7 ]& a% E565 @1 p; v" \, q8 L- z$ d
    57
    % j- \3 Q3 m3 l6 b- k4 L" q58
    6 d! z6 x* g1 S6 v8 O6 O& S59# w* C+ P* e; d; l
    60
    : g- G& |# R4 h# E8 A% A5 d/ t) c61+ t( V3 f" N7 U. q0 _+ R. Y
    62
    $ r8 E3 T8 B' |' g63
    ) M1 }9 ]" p. g+ ^8 p: u64
    * G8 }; N8 N- w+ f65
    ' U, o) k" }- l, B+ o: x: o0 y66
    % s' r# V1 M; [7 E+ v- Q67. f7 o4 z% o1 i4 M' y+ g
    684 }2 @# p; y7 W  K, `
    69
    . _, H( M* n* l6 O/ {# v4 ^' o70
    1 h" B' ^/ v( u: C( U; W715 q: J8 D: {( {! i8 R; }
    724 S- I: I! J1 }
    73
    * s6 f5 b" \7 f! U7 O; O/ m  ^74
    : G# i( R$ }: u& n4 Y8 v5 `75
    7 `7 [- k2 n' ~9 g76- ]/ Z/ }2 k2 Q  C' N2 K
    77
    5 k/ K* o6 r( t6 D4 g2 I7 p& w3 |787 Q1 g" T$ h8 q( a1 A" E0 X% _
    79/ X) b0 g* {6 I+ Z. e1 f
    80/ {  f$ u8 r: Q- p7 ^1 G
    819 A/ r# I6 @/ ~7 @
    829 v# T8 j* ?* \5 b, H, o! C9 R
    835 u; D% d; u& a/ a4 x8 p9 V$ n
    84
    , E  o  D9 ~$ R851 |8 B! D& j# N  S9 o% t3 G
    867 }- h9 w# p8 K* b
    874 h) D8 D1 G& j3 ~' [
    88
    1 E  D4 f& q" R" {9 R89
    + U8 \; N+ P, m, h) y90
    ; V* o( \7 ]' h* Q: Z910 `  F8 |( B. X+ i, p2 v& d
    92
    0 E. q- d; x! c; l93
    . G+ r3 ]) O. N; e! i- z- d* k8 s944 X8 |4 m% _5 w/ M$ A1 l$ J
    95  b' p3 K- {; q  h. s7 r0 I* w3 r
    96
    % u/ o, ]4 h$ T97/ [  q4 M1 [/ g) K6 n2 l2 ]
    98" @4 Q% ?$ I. ~* A& h0 ^
    99
    ) I  [. p! Q* h' y0 u" ?100  h4 g  }  A% g6 w  Z
    101) [0 S: c' O! F& X, ^% Z6 G+ Q
    102. ^2 [9 {8 E* x: U# u4 }" ^! y1 M
    103
    9 g$ A7 l, O& |# n6 K. R! w" n2 J104
    " }7 f- C) v" H( d105
    9 R1 X6 Z$ l0 Y: c% l( Y106( X) q) h' R+ _4 W' |8 M1 D/ o+ _
    107+ D2 m1 m! ^9 d  ]1 M4 y
    1087 s/ J4 M2 K. }, K3 l' y, V
    109& B6 A2 X' y( |# }: Y7 W* l8 }
    1102 T9 h, G% ~6 {1 h. V) W8 ^3 P
    111! W7 X0 l3 s5 \' s
    112
    * }! h9 e5 `5 p+ H7.2 开始训练模型
    ) ?% y! R( e, u( K我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
    0 W3 z' f8 ?5 t% {, I& m8 k1 l# b- J! n# r
    #若太慢,把epoch调低,迭代50次可能好些
    0 e8 l+ c, |8 L6 V#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
    . ?0 Z8 b) {* D7 H0 U/ R7 emodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
    ' R8 _' Q+ c! S& O1 {8 P
      z8 C, A6 E) ?7 i1
    & r5 ^3 r3 v) T; y$ X& @% n+ i2
    * J% Y, _0 M5 u. _& i3
    ! {% b' s, ~) w# P/ Z/ U7 |4% k9 B& I0 i5 v7 [# e0 c7 Q. v+ V5 s
    Epoch 0/4
    & I3 L  s! ]9 f* E----------
    - s2 @2 G' j) `0 g; kTime elapsed 29m 41s
    . W% Y& P& _* r" g4 U3 ltrain Loss: 10.4774 Acc: 0.31471 A! T8 |" _; N+ w& k
    Time elapsed 32m 54s
    + d  H3 ], G: j) g( n2 m4 M9 f" Hvalid Loss: 8.2902 Acc: 0.4719# l/ W& |! e9 b3 \2 ?) ?7 x9 I
    Optimizer learning rate : 0.0010000
    4 W. j  C% T: E* d9 y. D1 t2 B( g" C! y  X) i
    Epoch 1/45 Z% a. A, Z( ~  h8 z
    ----------
    7 |; e  }  m, STime elapsed 60m 11s
    - r6 o* V" h% v) ztrain Loss: 2.3126 Acc: 0.70539 K- u1 S/ _" z' |6 n; M5 s* j) y
    Time elapsed 63m 16s6 V+ l) B+ m. c: @
    valid Loss: 3.2325 Acc: 0.6626& ?' t" V9 \) M6 S$ \+ U: j
    Optimizer learning rate : 0.0100000% ?$ O9 A, l) y& I
    ) d# k0 C  Q7 Z, V( m
    Epoch 2/4
    3 I9 Y/ [' N1 x----------% Y; p. I% O" ]/ B0 M. `' Y
    Time elapsed 90m 58s. _, A! K$ K" h' l
    train Loss: 9.9720 Acc: 0.4734: T) Q8 c5 Z5 f6 N. x
    Time elapsed 94m 4s
    # P4 ?$ u; s3 Q0 r) w: avalid Loss: 14.0426 Acc: 0.4413
    ' @/ b1 V+ Z" z9 e8 I+ {Optimizer learning rate : 0.0001000
    . n; A, Z  t9 V- o) ~. U4 B* U: t  G5 Q9 A8 W5 h9 ], E1 j
    Epoch 3/4  b* U) h, x4 N% l$ m: S
    ----------! l$ F: h  T0 d- Y& x% T" `% x+ U
    Time elapsed 132m 49s
    0 o" ^3 D# C2 L/ k5 E: l; ~8 ltrain Loss: 5.4290 Acc: 0.6548/ u) W+ F0 D3 u/ H  U
    Time elapsed 138m 49s, F6 n0 _8 ~2 i0 I( y! n
    valid Loss: 6.4208 Acc: 0.6027
    & k" }+ Z3 r$ @* E: x+ `Optimizer learning rate : 0.01000006 v# B# X6 L7 i) }9 z; P4 G1 ~* S
    8 Z' i1 J8 @# \4 n8 I# K1 [( C* V; z
    Epoch 4/4" K2 i" p4 i& ~4 I
    ----------
    6 x: ]; O. [' r  Q8 \( l5 m0 T7 VTime elapsed 195m 56s* n# v9 Y6 J4 F( y( `, s% @/ c
    train Loss: 8.8911 Acc: 0.5519
    2 O% w0 m3 l2 c  e- BTime elapsed 199m 16s( r% H, ?6 p1 e
    valid Loss: 13.2221 Acc: 0.4914
    8 E' `7 B6 I, f% d$ Q. KOptimizer learning rate : 0.00100000 Q! |$ z% t' F. w& m
    7 {" j0 \: r9 R: V5 Z$ s- e/ ]
    Training complete in 199m 16s
    ; z- @9 F& _9 i) o" JBest val Acc: 0.662592
    / ?! E! n1 a6 q% N) \6 A' |$ Q8 |: x) M  G( i- ]0 A
    1
    3 n  Z. P+ f* L* J2 N5 K! G2( a  ~8 A* E; D; _
    3
    3 c1 X3 [+ R& e  B6 u$ V3 B$ I4* D- l7 u7 T" T
    5
    2 \* q) {6 ~9 n4 r3 K% t3 `68 `3 S, ?* R% a8 m) i
    7# i3 U/ P  v5 g  L- w( p( m
    80 ]8 ~/ u1 Y$ @* |8 H
    9
    2 w5 U0 K% N3 r3 t# \$ j7 T10
    8 m0 Q/ J: t4 x11
    6 Z$ z7 N8 m$ a1 J12
    - d( d; W& K0 Z# ~' J13
    ) F6 r/ i8 C+ L14+ `! i2 J$ }* _- Y
    15
    0 Y+ d7 u+ f& A9 O, X: j16
    8 u5 v+ O" Y5 e3 E: L17
    8 V% p4 ]$ ^( _; o& a# J- _' f18
    3 e& W( J4 Z* F% G19
    0 m7 ]; t5 o6 t3 _/ w20
    ; ]9 i$ k9 B- Z& N21# j, J: v& |3 ?
    22, {. J- K5 K7 d9 m
    23, G- E. T3 Z: Y* y4 H; }+ K) I
    248 O9 k# d8 b) r( e7 w
    25) r# I* V+ t# e4 H& e
    26
    ' k- `. O3 ~: c/ J1 A/ s27
    - Z' ]. w# N+ Y' D28' k7 t# y+ h6 _- [# H$ Q" X
    29" z$ f) Z, F7 g
    30' k& ^6 V) b. n9 \8 a" R& O
    31; G1 y+ c6 U* s7 K2 C9 i0 n
    32
    3 y( |; M7 z3 S# T- H' P# {# P33$ u! ?2 `3 G6 }* G
    34
    ' g3 o' p- m, {! Q35% O$ f( m( X1 d. h6 g9 `
    36; z& p& X  t& u! g+ g- @$ b
    373 J/ g+ M4 c' |1 |
    38" t- e! |  H' R5 A
    39
    . {7 D; L, j# T1 U40+ O+ \( Z8 o0 |% q% R) H3 y1 ?$ e
    41
    $ Q$ j  \. [- A( q& B- u$ ?5 @42. B; d. t% f. S. ~3 d( c$ V
    7.3 训练所有层
    $ a" N9 t. C+ M" }9 L# [# 将全部网络解锁进行训练
    . W/ Z3 r( l4 @* A* W* j' m( R( dfor param in model_ft.parameters():6 b1 I9 u7 w! v3 J+ j
        param.requires_grad = True
    # J* U  A! H2 a4 Z4 w6 |) C2 H2 S
    * u6 X" P+ `! L2 i# X* c. c# 再继续训练所有的参数,学习率调小一点\
    - W: W4 ]: h- t1 v/ ?# Toptimizer = optim.Adam(params_to_update, lr = 1e-4)! _9 i  q# I5 ~: V& f6 \1 G
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)% f: A/ f; H- Y  @) c- I, u
    8 B5 r1 f! `7 B: T% [
    # 损失函数
    $ S- F  V3 B. k' G( m# }criterion = nn.NLLLoss()
    1 q+ T6 w) p5 S+ p1. q4 t: f+ p" S
    29 `+ Z' X# @" R
    30 |/ @! t2 c% d% E: ~
    4
    0 j1 ^2 N5 v: ~# ^& S9 z5
    * h! m& U0 w8 V( l6) k- D2 F% B+ y6 _  B% `
    7$ c: H7 X, a9 Q8 D% Q0 E
    8
    . {) _# x7 T% n$ p7 ^5 h' k" U94 H2 ?/ s) {6 j- W
    100 {- t" `: M( l; s, m& \4 i
    # 加载保存的参数
    0 L# |9 O# O/ J3 C/ c; h: A& p# 并在原有的模型基础上继续训练. j' o0 J8 L; z6 }
    # 下面保存的是刚刚训练效果较好的路径
    $ M( U6 F" _3 X7 y' {* o5 Echeckpoint = torch.load(filename)" a1 s- j# l) k7 P8 K# u/ R
    best_acc = checkpoint['best_acc']
    9 Q9 w7 E! o; @- z$ s2 x) ]model_ft.load_state_dict(checkpoint['state_dict'])
      i( Q5 P, r* `* Loptimizer.load_state_dict(checkpoint['optimizer'])
    # d$ v1 w- z" k9 \" ^! G1
    : _, C0 b7 ?7 |( t, g/ |8 Y2
    5 n7 a9 W) M0 J, ^! ?31 Q% M. g: T5 n( D- ~) L( l
    4
    ' V2 O0 b1 m* G. @9 R5
    : s( l" [: s2 T1 S6: w' W1 R$ m! z
    7/ y' ?5 {0 r6 ^, n) M% @* y, y0 J
    开始训练
    ! p  ?& ~5 Q0 r9 @$ O; [# O1 y& t注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考" t$ D" A! b, ~( l
    9 `2 g5 T* f, _* \7 I0 `
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
    . Y# M, _. H" W3 h8 a" e) R) G  ?# c1
    2 x) ], M- F3 }0 z: R2 \3 h" W$ U) kEpoch 0/1  w& c7 r" K9 ]0 D4 W* m
    ----------' s2 N* X! p1 j6 `
    Time elapsed 35m 22s9 G7 l' d/ D3 _0 o" t/ C& E, ]; ?9 N
    train Loss: 1.7636 Acc: 0.7346
    # `" a# v/ Z& e8 ?/ gTime elapsed 38m 42s
    & V2 a7 R! N; I6 v* J6 U% avalid Loss: 3.6377 Acc: 0.6455/ o7 \' F( [: e3 F: G1 g$ s
    Optimizer learning rate : 0.0010000) i* S9 A# V2 K6 I0 @

      }1 r7 G( w) Y5 t% REpoch 1/1
    / I: e7 E( y" `----------
    / {+ T2 B3 V" JTime elapsed 82m 59s
    ' E* h+ f8 K" M4 w3 Jtrain Loss: 1.7543 Acc: 0.73400 o+ w1 D6 N5 y0 Y+ w* r* N
    Time elapsed 86m 11s2 |: b1 `. W3 ]! a" t; u
    valid Loss: 3.8275 Acc: 0.61377 W! O* k9 e# s  }+ U( V( P
    Optimizer learning rate : 0.0010000. l$ P6 k/ ]: j2 `: d9 L" R
    $ |3 b1 ~& }; T
    Training complete in 86m 11s
    - O/ U* D. E. k- F/ v/ L' r7 |Best val Acc: 0.645477
    , z" n6 K( N: }4 _
    ( F  L( H5 g% r1, g7 w9 [4 P2 |) b2 i& ]  k
    21 v7 X3 j3 ~9 S9 h& K3 L( R! b7 U
    3% x5 `& i; f* \/ |2 c+ v
    4, k" K) e. v0 R1 f( {$ r
    5
    4 z, w2 k9 p: }" ^' D/ i3 N1 n6 R6/ ]7 u9 w7 M) }  l. A) d
    7
    / C% G* ~" }& H* o4 x5 W. n* d# j& h8* @0 L0 V! y( Q: C+ [
    9
    6 o2 O+ o  ?; ?! `+ f  t10
    . c4 L2 p0 u, R# E115 M; J& G" ^8 j
    12$ T  V+ E3 l7 |7 B$ N6 r/ v
    13+ E$ j/ _$ i3 M9 p' _/ Q6 l
    14
    0 q" x  y: @$ Z15' `: |8 z0 }) r' ^  F& h& o
    16! ^5 ?6 ~6 W) T! E4 x3 I/ |6 q
    17- W4 ?0 w3 a5 y) u, I) ?3 L' p
    18: d" S# U" Y" ~) n9 @  h
    8. 加载已经训练的模型1 l) @7 Z) W. k2 S+ G
    相当于做一次简单的前向传播(逻辑推理),不用更新参数7 R+ f% a$ D. {( D* X+ E

    ; {3 k3 G2 i2 P6 f/ ]3 Ymodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)* ]- I4 v$ x6 ^# X& X1 z  w

    ' T6 W% w' u' v# o6 T; g# GPU 模式
    9 a; ]8 C& C: O6 k$ _' h* o4 bmodel_ft = model_ft.to(device) # 扔到GPU中
    / i* m* M5 n( U/ ?" t
    ( q! D- W/ {8 c9 Y8 p) Z6 q: A9 g# 保存文件的名字' f; ^% y3 Q$ i8 j% H# s5 y
    filename='checkpoint.pth'
    1 Z/ G5 ~! O, M! x  A, [  @1 J2 ~: B, t! M5 {
    # 加载模型
    2 V) C" @: ~7 n) i' Ucheckpoint = torch.load(filename)
    6 Y6 _+ J  x  z, L0 I9 s& k$ Jbest_acc = checkpoint['best_acc']
    - t+ U5 K9 N: Y4 U: r: x5 Rmodel_ft.load_state_dict(checkpoint['state_dict']); J% v' g; r5 U* T0 e" ?
    1
    4 k5 |# R; N: `; \' o2
    ' j- @, N9 u) V! v, G) V6 [; t3
    ' p9 Q& ~$ @; G6 [) o! y43 F3 b; y* [! X4 _+ ~
    5
    , V% @) R1 {- A& P4 f+ w, O6
    ) o0 h5 f1 `. q- N3 ]* ?2 y  A- s7: n' l, F7 \: V4 `/ O2 J
    8
    , x( u6 B' o5 T& m) ^0 `+ s7 z9
    1 B7 X9 j1 X+ {1 j9 s10; y7 s4 j) }: ~6 [) o& S. N3 P
    119 [1 {" Y9 B$ k+ q& r
    126 B$ J+ K8 F4 q
    <All keys matched successfully>
    2 A/ g4 ^& J3 l2 {1 ^% g3 C2 }, z! P19 e; X7 ]) m! O/ m
    def process_image(image_path):
    2 ~+ o+ ^: W4 Y* U- J% }6 J    # 读取测试集数据
    2 q$ i, e% K0 f  _. y# y    img = Image.open(image_path)
      c2 @6 Z( J; f% _4 G    # Resize, thumbnail方法只能进行比例缩小,所以进行判断* ]( f" `. d% h- ?! h' g$ ^. {/ `# G/ z
        # 与Resize不同
    " z3 Z& h! y1 Q! q0 p" ~) j6 b    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小9 B6 n9 X7 \( M. ^2 D
        # 而且对象调用方法会直接改变其大小,返回None
    $ R: z0 K" I& B) h6 I+ w    if img.size[0] > img.size[1]:; S) r& R. n; G  p5 Y
            img.thumbnail((10000, 256))
    ! }- @4 ?( j( T+ ]    else:( y: }. D# N/ i% Q3 @; Z4 ~  ~5 z
            img.thumbnail((256, 10000))' o. Q; H# U  A

    8 Y0 B% ^7 D. U% ~; W    # crop操作, 将图像再次裁剪为 224 * 224
    6 C: A4 C1 u1 B" A    left_margin = (img.width - 224) / 2 # 取中间的部分: ]& z# U. l' Q% F# b4 J! F" D
        bottom_margin = (img.height - 224) / 2 - d/ U$ O8 d& d1 @
        right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
    . X) R7 U8 R: C! J8 }+ f    top_margin = bottom_margin + 224/ j& i( r  Z  O1 A: z/ ~

    . e6 M+ m) C6 V# n    img = img.crop((left_margin, bottom_margin, right_margin, top_margin)), F, M- M6 n  r8 |1 ^5 w& w" Y
    / M+ Y# s9 m+ F, u! D( z, G
        # 相同预处理的方法& b: u8 V  U. m! F  _
        # 归一化& l! u$ x* C6 Q- _
        img = np.array(img) / 255
    - K& ^5 l: m1 o5 G% ~1 Q    mean = np.array([0.485, 0.456, 0.406])
    4 ^3 V  Z- O% v# I& D; z) k+ w    std = np.array([0.229, 0.224, 0.225])# }. Z1 n4 x- h" g2 Y
        img = (img - mean) / std
    5 P* M2 \( H% o, G2 r" f. O" I" q9 a' F
        # 注意颜色通道和位置
    5 O! U  I" R$ U# e* M6 X+ ^. c    img = img.transpose((2, 0, 1))
    $ o! f- N4 Z3 X0 L* _
    " L  i( F- w! w0 R7 v4 N( T    return img- S; {# A) K  T1 e& P$ d% h

    # B$ w% F! x0 {# `0 N* mdef imshow(image, ax = None, title = None):
    ( R* R$ |. S. b3 F/ u0 ~0 d4 N$ c6 y    """展示数据"""
    8 b% z7 E0 ?" F0 s( Q( f    if ax is None:  J" ]5 Q3 c6 U) k7 R' b' I! Z
            fig, ax = plt.subplots()8 t% X, M* S$ E

    4 y. h. T& Q; x* P4 a: K/ |9 p    # 颜色通道进行还原1 F0 t0 ~: p2 G( u* d3 _5 o4 b% y
        image = np.array(image).transpose((1, 2, 0))
    / Y" V2 ^/ t5 y( X% F/ x6 ^* Q' C& g
    ( W9 J" i5 _) _0 n6 {    # 预处理还原) [1 @/ T5 l. @! z4 `# K) m- z8 x
        mean = np.array([0.485, 0.456, 0.406])
    ' u1 F6 t  p& t" `# X    std = np.array([0.229, 0.224, 0.225])
    , t! I1 p; S2 `* K    image = std * image + mean
    ) I3 D& R0 e6 L+ p" E* [0 B" x    image = np.clip(image, 0, 1)9 b+ Y0 c* M4 v+ |; C
    9 {$ a% j' E1 d, a# E6 b
        ax.imshow(image)2 Z+ W% E- K- ]
        ax.set_title(title)+ M& k* _9 g+ A/ e7 R
    # J  o1 w( B4 o  Y2 M1 F
        return ax# ~8 i/ v0 ~: a# T! ~
    & D- G" \- Q( C" t
    image_path = r'./flower_data/valid/3/image_06621.jpg'/ j, x" D' I  B3 G
    img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
    2 z: e" m, a! \; \& t! q: rimshow(img)7 _! C4 }2 ]7 p6 g0 f
    ' o( R: P( Z1 P/ r4 _
    1/ _6 O6 C  Z! j# {1 V2 ]
    2" I* x* [) @9 {. G( q- W( {
    31 M$ P5 o+ P: H8 K  s* x* b' n% L
    4" l- B1 F7 o0 |- X# |0 z( l) `0 n  @
    58 l( M2 r6 }: y
    6
    . S  j$ R8 M" Z* a7! Y. r* ]) {: W6 r7 g! t6 t  |2 F
    8
    2 R6 j" X0 @! W8 H. a5 q9
    2 b; G' f: c" m, x: ?; a10
    , a8 o+ A6 c4 j/ x  L. g11
    0 d3 y0 ~9 B7 z7 R123 G  @8 L. e" j5 b
    13
    , @& t) v' [$ ]/ h) i) ]2 B14. @9 l4 ?  m  h) Z
    15
    - S; S$ s+ r' I% F' r160 _! A4 G/ u* G8 R7 x; E3 x
    171 E, Q* t7 |$ c: [2 h( R
    18
    5 K/ P" h) K3 E! v& H% X19
      n* \7 y% P2 B20
    4 n3 r5 w+ r: l2 O! `* e21+ ]+ c: N9 l1 \/ m& ?' S) U
    22
    + @2 e4 X* ]+ W5 E23
    8 Z0 g. `7 R: {+ |& H8 @2 u, V24
    2 X* H% J% }6 @$ u3 N0 ]" s25- ^, l, V# P" Y; G: j
    26  r" g) |* N3 [
    27
    " W% _2 \. b0 k, o( s5 f4 J28
    $ O3 Z. V4 u  t/ P) @$ t29
    + ~% k9 J7 a+ v; g$ b30" `; m, b, E* n, a0 R
    31
    / y: C0 q: O; Z' k32/ E! y. S& |0 ?+ a
    33& h) q( v+ K; r& R
    34
    5 a( _  o4 C1 U& ]( g35! y1 g- x  o3 u
    36
      B  y$ U, Z/ w5 y% c3 M37
    # y* {! w0 e; s1 R38" k: t! H; I+ p+ G& i7 w1 _8 `0 ~
    391 T! e/ g$ U4 j) M( j$ D. g: Z
    40
    . }1 {9 T0 l5 Y; p7 x. ~412 j% S) n0 u# T6 Q7 p& r
    42
    1 S: w+ l9 ]. ?* J3 C43" r( N2 n" T1 }( Y
    44
    ) c+ B5 A8 \# S9 R+ d0 M" n1 c45+ w5 P5 q* i2 q0 F) Y( @! C# L0 h
    46
    , f* W, Z" W, z47! d8 I! \! d# h! T1 x7 Z! d8 Q, X1 S
    48
    0 R3 K: M5 Y0 m! _2 @1 ]4 |4 Y496 ?7 _) [% h. I8 j& {: s
    50
    + {) l! _; ~/ w; l- S& ^  H5 N# J513 R' K& S  g8 p: `) h3 i6 I
    52' m9 T* e  j9 Y2 a- v9 A6 m
    53
    ( _# ~. q5 E# F, x% b0 W: }  l' E54
    , z/ |* M- e# {- b  ?6 i/ O" B# y<AxesSubplot:>
    / _/ I4 j  a  z/ a6 Y, s1
    6 z; b) f: G# Q7 l% a$ n2 \' u: b
    7 x0 |9 i5 N8 o上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确. ^  S- N+ ]  D
    ' Z5 D4 N9 G9 y/ y) f3 y
    img.shape
    $ h' x8 g4 X1 m# N; V17 E6 e9 L# P3 D. P, e/ }. n  E; c
    (3, 224, 224)) @/ K# E0 J# ]* z' o
    15 s$ g, O$ L4 s$ f3 J& Q
    证明了通道提前了,而且大小没改变
    : l1 d2 l% I( N  u6 n$ A
    ' |3 V6 i7 e6 D2 ~- f1 O! O9. 推理0 z' U3 T! O! s8 k' |( e' k
    img.shape
    * Q: V& p! C/ ^/ p7 {) A; O2 X" i3 H( V  g% x  n8 K. a* [* x8 T  m2 u
    # 得到一个batch的测试数据
    9 G$ h7 }0 `& Y4 h( J8 u, w% ldataiter = iter(dataloaders['valid']), J, Z9 M* l8 l( @2 l: E
    images, labels = dataiter.next(), K2 j/ d8 z! [3 e1 F2 I

    ) V  P' j; s4 W/ n* F) Lmodel_ft.eval()
    2 h: u1 r- P$ ~  _4 F2 R+ {9 e! x# i) J2 T  A  f  ]7 z
    if train_on_gpu:* R2 M1 d3 _; M4 a1 ]2 y( W
        # 前向传播跑一次会得到output/ e6 G3 X2 f1 Y+ i9 {* ]" [
        output = model_ft(images.cuda())$ l" O; A, a; ~) {5 {3 k+ t2 L
    else:
    4 {  D2 w2 t' b$ i- L2 t$ ]    output = model_ft(images)
    / ?9 J: ?- o9 W* k/ y2 S+ I9 I# z
    # batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
    6 e1 c4 L+ V% C6 r. y0 U2 }* ^0 doutput.shape
    ! c1 d, b' H  v/ f) N7 W+ W4 r
    " D' [( G+ O/ L: Z  c1
    8 f! F! E4 E2 E4 j9 s7 x2: q( w5 b" o: G0 F  o
    3
    8 H: }& M+ i' _& I6 H' c46 D. l9 f! y% {/ |$ a
    5( L4 j! I' q8 D2 h. O
    6
    / b, M; c& ^+ ^' x: Q9 [2 Z7
    0 e9 g' E, _$ {. W8 B8
    % a0 p6 _; w& o9
    9 D3 s1 {) L% t# l, ?! s/ {106 K: p9 T( W2 [+ A# i7 p
    11
      D, R4 q2 \1 d2 d5 k12
    5 `7 Q( q* b* ~' B/ p1 e8 ~13& p8 F8 h4 |2 t6 S3 Z
    14
    . U# l& U4 x+ g. t15
    ) K& |: {/ T! |; ^16, K7 y  k) h! [
    torch.Size([8, 102])
    1 c: h* ~# V+ J0 X* ]3 A/ Z1& q0 f* i* Y2 p) `- p
    9.1 计算得到最大概率; n: h- ^, a) X8 A1 ?: e- ?
    _, preds_tensor = torch.max(output, 1)
    ! q. c, f8 u5 U" e: i; a( h0 |9 ^2 b! A% Z9 _3 x
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
    8 q+ S; _+ Y7 ^, V1 E" I1
    + E- [$ j% _0 b7 h7 N2
    ; u, l8 \) `9 ~6 q3
    3 [! t' `, e9 D9.2 展示预测结果! a3 [, w9 p- M& O  f+ x' a
    fig = plt.figure(figsize = (20, 20))+ e& G* y) h* H
    columns = 44 |6 U# m* k; B  R! X8 H- c+ S
    rows = 2
    . z- c$ D/ ]0 y  x* r" U2 q
    + ]- r: b! T" B, Afor idx in range(columns * rows):
    . v, @; S6 D# P; x# A; L. W6 m    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])# u9 `6 t' G: N/ b* r! a
        plt.imshow(im_convert(images[idx]))' v+ s4 m1 z/ e& `" h  y) m1 D. K
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
    . l- F0 c& B8 p' [6 A                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))6 l; b+ L" o2 T4 G8 F; _( |4 K
    plt.show()
      C. \' N, h+ p6 O# 绿色的表示预测是对的,红色表示预测错了: _  p3 E+ i$ D8 b3 w, u
    1
    1 i+ ?) ?1 X5 O7 }- L  V3 N" P21 |3 S5 H+ \: `  C) c8 Q
    33 K; J5 I4 ]/ O! \* j, |3 f) f
    4
    8 Q8 @; Q: Z/ s: @5
    6 P. o; A; Z; o! u6
    3 G2 f, V) F5 u- l7
    % ^0 r, v9 ^0 z: F3 p5 d8' S5 W# C! J: P" C* ~! w5 x" P4 i
    9
    8 P0 F5 w# V1 ~( L102 d' f) h  u/ }2 o  A7 k( P7 f
    11$ X8 x3 k9 c; V% @# s% u% y  @

    7 @* }+ y6 N! f5 u- V5 J" B  K
    + X# f* P; B* C- C$ t" Q$ x; Y  M
    ————————————————" j4 M: \+ t; y: G' |6 ~2 a5 E
    版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    5 W1 o/ d- ~8 i0 |0 v/ o' d" n原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
    ; C! _: D7 M) {) r. D* T. p+ y: `' \2 @: M! t6 E* ^- N: q4 k

    5 |" L: o8 W2 p  e8 H1 F
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-6-14 08:44 , Processed in 0.502476 second(s), 50 queries .

    回顶部