数学建模社区-数学中国

标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例 [打印本页]

作者: 杨利霞    时间: 2022-9-8 10:41
标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例1 ?% A) t6 L6 W5 b4 ~& F

+ W( r6 {9 O5 }3 k* U  [: P文章目录! u& L3 T: V4 D) R1 N( O5 O; F
卷积网络实战 对花进行分类1 Z& |% R; j9 Z
数据预处理部分
+ Q3 c* g2 W+ K8 G" s8 ~) @网络模块设置0 P+ ^/ q% j& w* H+ L- P
网络模型的保存与测试5 Q3 w( x" y/ K5 T/ o
数据下载:
: Q4 i; [6 }0 U; U0 o5 b' g1. 导入工具包
- m+ S9 Y! I% c! a2. 数据预处理与操作' U, y: S( i( }. m) K
3. 制作好数据源
: U0 S( r  @% h7 ^5 j3 m读取标签对应的实际名字' z8 e+ T. r( J, |5 @0 T" ^
4.展示一下数据
9 f1 n9 t: i, O5. 加载models提供的模型,并直接用训练好的权重做初始化参数
$ A6 x0 `- L; D8 m! Q. Q3 D6.初始化模型架构
( N4 d' Y, @, s+ n7. 设置需要训练的参数
6 ]$ ]" n- m6 H0 K4 b1 v7. 训练与预测
( [, x3 T* U# f2 U) W( |7.1 优化器设置' g! R) b3 s' Z4 z# V( ~- B
7.2 开始训练模型, R9 V$ g6 ^: v
7.3 训练所有层) [8 i- g1 S1 x% O! r
开始训练
$ }: D- k3 n% I2 Y! [4 \- R8. 加载已经训练的模型9 J5 R& N" C- h4 U/ Q2 g
9. 推理
. W" T* x, j  |8 E) G0 r8 V9.1 计算得到最大概率
  g2 I" t6 g) g9.2 展示预测结果. d1 Z) i6 W5 {% c: h$ p7 C
写在最后. h9 `; J) v# \9 _
卷积网络实战 对花进行分类
) g* I% x( V" ]本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能& ^4 v5 y9 H% s
4 O; m- C9 U7 ?' l4 ?& T3 {: Z. p
在文件夹中有102种花,我们主要要对这些花进行分类任务' Q) A' b2 O' P% j! n3 ?6 H
文件夹结构! e, j$ W3 l6 z! e: j

) E! ]1 i+ c. n  p" Vflower_data
8 S  O$ S$ @- l1 \4 [: |* e- d: k7 O$ w1 B
train# L( L+ H6 ]! L( F( I
' }( J# L9 H" j1 X0 `, J0 `  J' w
1(类别)
: Q. {3 I# Z5 e  r, [( y2
3 |+ u" s/ B  O" U8 S8 X4 [- nxxx.png / xxx.jpg
1 ^, l" L5 c% [3 @valid
) f% R, h9 f* {; [
: i  G3 N! b4 x3 f, ?; y( N" I主要分为以下几个大模块  Y; C/ a) ~) o+ k- ]5 E
5 Y' c6 _/ H+ @, i- K
数据预处理部分- J' |: g2 ~7 r
数据增强
$ b- h; }# I. U0 k2 a: u数据预处理
" P: O0 i8 _3 R- n网络模块设置  p0 N6 u: ?# h% G% Y  R: {
加载预训练模型,直接调用torchVision的经典网络架构
, f4 \' L- Q6 ]* [/ i* Y2 S4 r因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务: n1 J+ M& z" d/ ]: G6 |
网络模型的保存与测试
2 t* \* X. Z$ i& z! T模型保存可以带有选择性1 H' D" R" l( D" ?- s. @7 b
数据下载:' W* [& D7 Y% }& z3 T
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset) n) }- L; _, }$ ~
7 P8 R0 E& s$ p3 l) ]+ N
改一下文件名,然后将它放到同一根目录就可以了2 {" u2 j/ H. }. e

  k$ \0 p7 |7 y! u' [下面是我的数据根目录
. D* m5 |4 a+ [& k( P
# s" C7 n: J( R$ d0 e5 F& {! }
$ Q* v6 T7 ?9 X" R1. 导入工具包, g5 `- p; f. j" d: O' d
import os
& i! X" S6 c2 {$ eimport matplotlib.pyplot as plt3 G0 {2 G$ N" P
# 内嵌入绘图简去show的句柄
/ m& s* {8 _& m- j- Q%matplotlib inline
8 ^! D' R- U. p; R7 e  h6 iimport numpy as np
/ b8 `2 [. C8 P# e$ d  @% q; Wimport torch
; f4 O, u8 G$ Vfrom torch import nn
7 z% K# h8 G4 W; [) x
/ y+ u% g5 S, H' ?+ _: Vimport torch.optim as optim
1 B) d2 {" ^- `2 _" @( c: zimport torchvision; I/ _/ s/ l0 Y  I* p5 B
from torchvision import transforms, models, datasets/ x7 b$ b+ B( |$ y# G  O4 m! o9 B" T; v2 H

& u# U- V8 [# x5 n: @' ?( B: K* Yimport imageio1 ~1 ]( t6 B+ T/ \9 O- M" H
import time, L: q3 S/ ]+ h+ y+ p% P
import warnings! {' J& s# S4 j8 [. a' t3 x
import random
3 _* J8 L/ \7 I6 p, s8 ?& m" @import sys
' s9 Q7 G0 p* l9 K5 c) m. pimport copy
; F+ D" W* _* \) d! F% Z/ i+ Vimport json; t  F* ~2 i% y( N  j9 Z
from PIL import Image
3 J+ h/ s7 Q  Y- P* q- w- w: x
2 F+ H: H' B! r* l$ p6 \4 D6 T0 T$ @& W7 j$ [) E* v
1, t& p! J- G  m* U9 O* J
2& w# U1 }: t( Z! P) j: d" P& U
3
5 h5 Z  I1 P( b3 k" B4 e) }! p4; f$ U+ V; G6 r0 m3 P
52 K$ T4 N8 h0 p2 f
6
( f* x( n  L6 P7
8 H0 B. i4 }4 C  ]8- Q- ?! K) o: O4 N! k
9
3 K6 b7 ]. i8 Q) Y0 `102 A+ a2 j. L% N  t2 Q0 E; \
11
2 S. h1 P! d6 B3 X9 g12# t, {5 {. @9 T$ }
13
7 D- c9 U8 L& s0 }* g5 P6 c5 o14
4 }* R, I+ g( A4 M0 I1 T15
6 K' Z) Q1 k5 {/ q4 J3 ?16
! D5 {- A' m0 }- r9 n9 K, z8 C' j17. V) s/ C8 v' \' \6 W. u, {" f
18
$ K3 j0 z% r( v7 |4 n1 k" p+ K/ X19( I: @$ p# \9 D' S/ Z$ ~4 x# U
207 z0 E) `2 F7 G0 ~0 V! t; w! I
21) N) s/ z5 ~& @, v  H5 i$ R5 \
2. 数据预处理与操作3 d8 u6 Y& s' T9 [/ d9 y
#路径设置* m! S7 T9 A, P2 ?" _
data_dir = './flower_data/' # 当前文件夹下的flowerdata目录; P* }! s, L" C
train_dir = data_dir + '/train'/ t- Z$ B/ ^' e* \3 W% @9 L, ?
valid_dir = data_dir + '/valid'
& F  j  }7 l% f7 F1 X& L5 I1  N$ }8 \9 g% Z) q5 F5 }
2
# ~. O* o# W* F: t0 V( j3# M5 m( |9 L2 B4 L5 J% `
4+ |# F) D% r3 F0 g6 ~: J3 D! h
python目录点杠的组合与区别
( X" g# G+ |! q# p  b* p, f7 J8 f注: 里面注明了点杠和斜杠的操作8 t5 a1 {' f+ o- X5 B/ g  H/ w
8 ]4 h4 H( q0 B0 ~$ I1 }
3. 制作好数据源# ?$ o. j' L, G2 T9 w
data_transforms中制定了所有图像预处理的操作) i5 o) r& A# `
ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
1 q5 X' G9 W; O: K9 Z% _" fdata_transforms = {9 Q. @2 l, W5 C4 W
    # 分成两部分,一部分是训练
5 j, Z$ s! n) g/ z4 z    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间: P& n, b1 K& G2 @; u/ d+ c
                                 transforms.CenterCrop(224), # 从中心处开始裁剪
9 [, Y9 Z; |" d$ ~+ ]                                 # 以某个随机的概率决定是否翻转 55开
. G9 e6 `5 G% c9 p/ G5 d, _/ E                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转( a! {; X+ l# ~* {- C- B/ ~
                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转" H! P& U5 @" C) F( ]# s
                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相2 |: O6 x' y9 a5 c( Q; Q6 |" P5 e5 N
                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
( E- H" a5 C8 Q; y                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB- C$ N, ]  a+ G9 \
                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
! u: d9 C7 Z. p! A0 `4 H3 l1 ~; s                                 transforms.ToTensor(),
9 C+ P. r+ ^2 ~, v! f; X4 z: E. W                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差
* S& @% _; _. x& B7 T) V                                ]),
9 A) @% k: a8 @8 p! @    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
/ `# w; s: T8 e: H3 @8 t8 c# ~; P    'valid': transforms.Compose([transforms.Resize(256),# _, O) {% c. Y1 r* m8 Z
                                 transforms.CenterCrop(224),
$ P) }: @3 S* w6 c  @. b                                 transforms.ToTensor(),+ g. Z7 X" F4 u: a( A
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
+ ?0 D! D) }2 V, w4 D, g                                ]),
5 |! q& J6 d- H. ~2 D+ \}
5 C# L0 v0 i4 e( b- L
4 k- y4 Z, O: v1% E; @  ]3 P/ A8 m2 E; }# Y2 \
2
4 S2 N  }  L* x7 _( k3
) H  }& |* ]1 I0 v5 O4
: M2 \' n) D9 g5 W) v; \' x$ ]; h5
; H2 ]& n$ x5 m& R7 b* o. `* m# q6, C% W5 r/ Q+ i+ \+ v
7
7 N# E* t, Q+ H/ r' r8
0 G( x7 _" t% ~2 n3 L; \6 }/ m9
- O( R% _$ Y$ d% ?9 Y8 V10
# V! h* l4 d& S  K  C0 B  y% {11! [& @. z5 H$ V- o1 g
12
- m; J( _: R2 p4 Z7 i13# ~, m, @6 `- S! G% r' @- W
14
+ S! _/ x; J5 i- R! ~15
' }0 Y1 T3 h  c2 T9 t; Q16$ ?% V4 v' ]& K6 V  Y1 ^
17' w' a0 [4 z) u. Q9 ?1 b
18% h' x3 D& I# \: e$ V3 W
19) {! J: S8 Z9 {8 Y; W  d- u: ?
20
# k' {( E$ p0 C21. y+ Z- @% \5 ?$ k2 t9 R
batch_size = 8  |% j, }3 R- A
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}# W9 j1 o& P6 j3 `* C) p2 J  h
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
1 s& d; ~7 a- J; j! e, U0 Bdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
) I, [$ W: {* Z* ?6 [9 h2 ^: o- X+ Oclass_names = image_datasets['train'].classes
' Q1 ^# X+ h! d' N4 H) K3 z5 W; ?8 A# Z3 W! f8 G" n* e
#查看数据集合- f( k" z# q4 p0 y" Z6 }4 l( u3 E
image_datasets, Z  |( }, q& Z3 Q
. g' F% k  v, X+ k
1
3 ~: [) e5 X* {8 E0 Q2
1 C8 p) Q+ V3 E! T0 k, ~; y3
1 b/ q/ [+ i& J. T9 v4
  Y3 r  }& i& x$ A! p8 q5
1 M! j" L3 j; w) z' m* i, d6: B# w. K* W% g1 w. ?2 }
7  J0 u! L* y* K: F9 H
8
9 Q" E8 i  {8 o8 Y& ~+ Z* l9
4 \. L+ [& y" y5 s) W{'train': Dataset ImageFolder
: N7 I. m  T% C3 z' ^4 F1 \8 \/ W     Number of datapoints: 6552
0 J; ]/ w' \) Z6 N/ D1 P     Root location: ./flower_data/train. e0 a- \! I8 J8 \, x0 a
     StandardTransform
8 H- \! `& o" t. i) ]1 O7 S. R Transform: Compose(% \3 D7 \0 Q* H) [) b
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
- T; T+ H5 n+ f+ M" ?                CenterCrop(size=(224, 224))0 q1 ~& W6 M: A4 B7 q. y; M
                RandomHorizontalFlip(p=0.5)! N6 r- a2 C0 M/ g! r& ]3 J' u
                RandomVerticalFlip(p=0.5)
& ?& Y9 O3 x6 V2 w. A, ]( H                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
1 e/ s7 i$ @1 A4 v- V% _$ B                RandomGrayscale(p=0.025)$ ?0 \( A7 f; j1 j5 I& [& K2 O
                ToTensor()3 g& @' T* h1 `& W( j8 u- {( Z
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])+ c1 x* `7 H, x
            ),, ?$ g# D# [2 ]  R, K
'valid': Dataset ImageFolder2 N; e( h+ X% R1 r9 i7 W0 q) r
     Number of datapoints: 8189 c7 Q: P; j6 H
     Root location: ./flower_data/valid
" P) w& Q& F5 Q4 y& u% ^! q     StandardTransform# e. m- V* i& |& r$ ]7 }5 A  _" g
Transform: Compose(" w! C/ z5 e+ f& n( y! V
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)# P3 y. z! b3 g6 @( C: _
                CenterCrop(size=(224, 224))6 S3 Z+ d4 n& A7 Q+ D
                ToTensor()
" Y% i. P* t$ t                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1 j) y( d" I  w            )}& d, T) t% r8 F8 ^
8 M: t' E: o- Q
1& q2 ^+ r9 E2 Y
2: ~$ }4 y! M: `8 a- V
3) ?# t% x$ G& V5 D9 e
4
8 c2 D: A" m! W. m/ x5
# {7 t. \5 O: v/ W  N# }6
; z3 s: L% a, C5 S2 T/ e4 x* U7
0 j% s) ?2 z9 \; w7 R* c8
$ @8 j6 G0 @0 h" R2 s  h" L# B9
' n$ x5 [" M; R# y8 Z& l! F10
' j* T" f2 t8 ^: W. S116 o' e* H3 ]; e
12
7 X3 Q, O  i+ U13
( n! K7 ?3 \+ Y14
- p0 d0 O; n' `! C# p* t15
7 F& y- Y/ @8 u" v16
: l, z, j0 @, {6 ]% J( _0 r1 b17
& K, W4 q2 ^3 l" v: n4 h' T18' B; c4 {0 d9 Q
19' J# @$ G2 I8 H
20" a7 B- L6 x4 T  v
21
. p! R! I! y5 o1 s# a$ V# C226 |7 Y# K4 y" i: s. @* Q
23$ J. P, A) p- ~8 ^7 V5 {& ]; K# j
243 }& \& G# Q: f) D9 S' a% {( D% z
# 验证一下数据是否已经被处理完毕0 x+ v: S. }  u# W% [
dataloaders, S) W* c0 B# h0 c
1+ x1 c- T* A' y9 E  Q6 e" c
2* J7 P1 R3 C* J: [1 T) `1 g+ r
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
( k# x* w9 x- ~. ?6 S! P 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
4 r- p9 u5 ^- B1 ~1( Z' i! g  O( S' p+ J/ c% @
2$ @9 i+ A2 L: w+ v4 M3 J1 Q
dataset_sizes
2 o: D; n# g( A, _! i+ M1
) b* M! ~' s0 H% M' a{'train': 6552, 'valid': 818}
: J9 L; w8 e; L0 C1 b, T1
$ L3 k! R, [2 H# d# f' l9 E读取标签对应的实际名字
8 s/ ^. U% h; p) C, x5 _4 n4 d使用同一目录下的json文件,反向映射出花对应的名字
# j( o: T# T. E/ U. e9 q& p( N0 Y* |9 \+ o) ?
with open('./flower_data/cat_to_name.json', 'r') as f:1 z# [+ I" x2 Y* L) [# n0 O! e
    cat_to_name = json.load(f)
1 ?* Y( R7 [8 Y1
+ _$ Z) d/ B  t2
2 ]/ X- ]* m4 v9 }  R4 j/ Fcat_to_name
: J& @1 s. T1 d0 u. @, k1
4 A3 U0 F' J+ r5 r7 x{'21': 'fire lily',7 g- K1 V4 l2 I* F/ {
'3': 'canterbury bells',* Z) W* k& [2 Q- ?, X) E2 k
'45': 'bolero deep blue',
# F9 S, B7 U+ [. R/ l* ^ '1': 'pink primrose',
# T2 J2 N; {. k* e% T' A '34': 'mexican aster',* x) K: K, ?; _
'27': 'prince of wales feathers',
" [/ w: ]( g+ j$ k0 t) U( W$ U '7': 'moon orchid',2 [( Q8 h6 y% V/ }& [4 A( d# k
'16': 'globe-flower',7 Q" S. c4 N4 p
'25': 'grape hyacinth',
3 }: D+ [7 s( M" W& b* f '26': 'corn poppy',/ O- ~( C0 o- s8 O5 V. ]) k
'79': 'toad lily',
" B1 i( Z. \+ }9 u( d$ X '39': 'siam tulip',3 S" U( [# |3 K# L. p% E* H
'24': 'red ginger',5 Y8 |+ P, W: W3 _; W/ W8 ~: j
'67': 'spring crocus',
1 N7 E9 }8 x: J, `) P' r. b3 r. ^ '35': 'alpine sea holly',7 a8 X' C8 [7 _$ ?$ [
'32': 'garden phlox',
" x/ R  j6 r# H3 d '10': 'globe thistle',% t& x, M- s" |( N1 T, q
'6': 'tiger lily',1 a1 A3 v' K  a) K
'93': 'ball moss',+ w; Z) E3 v- T* E2 K6 R) Y
'33': 'love in the mist',
" J* ?+ c8 L6 [7 C '9': 'monkshood',- g" p6 e- Y9 F; a! u) V
'102': 'blackberry lily',
; ]; @! j" e7 i! C2 r '14': 'spear thistle',
* c" |0 l1 P& Z& |* k8 _2 H8 t+ g* { '19': 'balloon flower',
: p' e: Z# M1 A" M; b* V8 m9 f  b '100': 'blanket flower',6 ^7 f( F+ s( A! K( _* @' Q( l  t
'13': 'king protea',8 L* I! x/ G0 o5 Z
'49': 'oxeye daisy',
# b; n& ]7 L: j% X4 x) M# [ '15': 'yellow iris',' f8 g; ?* h/ C- `% U
'61': 'cautleya spicata',
5 [  G1 w4 B5 f* b '31': 'carnation',
) ]1 \$ r( S" l9 G. R: H& m$ @ '64': 'silverbush',# e% |4 ]" b2 M5 y5 Y
'68': 'bearded iris',
, ^+ P2 R' k4 [- H5 i '63': 'black-eyed susan',9 N7 b- [& c1 F% Z7 P. F
'69': 'windflower',
; m0 p; T6 B1 \ '62': 'japanese anemone',
9 N- s4 V; ]$ r4 x8 C2 f3 S0 s, @ '20': 'giant white arum lily',
& }) C: T, a! @# }. @ '38': 'great masterwort',
1 R( @" v- H3 Z5 i& x '4': 'sweet pea',6 A2 k6 v/ ^" @& G8 i7 P8 Q
'86': 'tree mallow',  r* u  _7 C+ s- e9 a& c3 b
'101': 'trumpet creeper',  H  Z+ H8 Q4 {3 u; O1 f5 |- q# d
'42': 'daffodil',5 \8 X0 G$ G$ `! U# W7 t
'22': 'pincushion flower',
: a( V4 [& w! D) ~! y '2': 'hard-leaved pocket orchid',; \4 W6 n) f9 |7 y
'54': 'sunflower',
) L4 ^1 I$ X. i$ c '66': 'osteospermum',) G6 {7 i  @. C  a, B8 N1 G
'70': 'tree poppy',
7 d2 s! Y5 {" S '85': 'desert-rose',
# `0 d9 G* p" @& I '99': 'bromelia',! N" i  i. T! p2 K, @4 `4 F
'87': 'magnolia'," q  e* k& z: P
'5': 'english marigold',
- G& ^% R/ i( B# K- D+ U7 [; i '92': 'bee balm',
! c) F  `0 w1 O  \) }2 c '28': 'stemless gentian',* ]3 k2 `  W3 ]( j/ r
'97': 'mallow'," [5 l8 o2 l6 X0 y" E
'57': 'gaura',2 `  s7 o% x# f7 J
'40': 'lenten rose',  w1 c- `* j8 V$ n! T. g
'47': 'marigold',
7 Z- n; M0 d' a& E3 G2 |7 K. m '59': 'orange dahlia',
( ^1 M1 ?1 y; s6 h8 z3 ~1 V '48': 'buttercup',
7 O  v9 d# A+ J2 f6 S '55': 'pelargonium',% z" ^& f! K' a5 v
'36': 'ruby-lipped cattleya',# H/ k1 ~9 R; r# E6 J5 F2 i8 o
'91': 'hippeastrum'," e# i4 O- O2 X& Z8 y* a
'29': 'artichoke',. c2 R. \( ]! h
'71': 'gazania',! u3 ~9 W& r# r( I. G, _/ D/ d: v8 ~
'90': 'canna lily',9 \1 j. E: }5 v% i- C
'18': 'peruvian lily',
4 N( B6 h. x0 r, S' X  x '98': 'mexican petunia',9 y  M  Y% d' z! w) u
'8': 'bird of paradise',+ v0 \: a5 X6 k: K7 E& \
'30': 'sweet william',
6 s% z* ?( l) R; _1 F0 a '17': 'purple coneflower',
& M$ Y% _0 }6 Z, O; [4 F '52': 'wild pansy',$ T. Z, ]4 {" j* a& H. d8 ?
'84': 'columbine',
) c  E8 n9 S$ K7 k '12': "colt's foot",
, J: P* [+ e+ a, l* q% K9 @ '11': 'snapdragon',5 m. j6 ~# A" ]* |% x
'96': 'camellia',) P( ^; _- M! C* K# }. j1 @
'23': 'fritillary',. a' m! A# c: Y& G
'50': 'common dandelion',
8 D* {" n  t7 u! E '44': 'poinsettia',. r& m$ ~" z5 j
'53': 'primula',
! }* q5 J  B1 ]2 j' d( C" g. p! U '72': 'azalea',
9 C9 [+ u4 q7 I% i$ \8 `0 {5 [ '65': 'californian poppy',) E1 _7 [& O% }6 F5 h& d& h
'80': 'anthurium',+ U8 l  G. j5 v! |+ S
'76': 'morning glory',
$ G+ s+ ?1 [& S0 n: I4 E9 J '37': 'cape flower',
: x# i: h+ f* k- j& W '56': 'bishop of llandaff',2 L" O8 ?5 @, U0 P* h5 I$ \4 l# a" ]9 _
'60': 'pink-yellow dahlia',; t0 M, c, u) C, _1 U3 a
'82': 'clematis',6 l6 n6 F  _$ Q3 ?, p
'58': 'geranium',
+ h- R7 f, W# k% r- D  M  D '75': 'thorn apple',
  a( ?# G: S6 G  p6 m2 m& u; J' Z '41': 'barbeton daisy',7 x# e; F! a5 l2 l# Y7 @: w/ I
'95': 'bougainvillea',
% q; H3 ?; ^' S. O2 c$ d '43': 'sword lily',
  N! [0 `6 Z' X. K8 B6 O '83': 'hibiscus',
9 ^) H* ]1 J1 s3 C( v '78': 'lotus lotus',# D! |5 j! ]) L9 w' ?+ Z
'88': 'cyclamen',0 M/ P' D& a3 K
'94': 'foxglove',7 Y( s- ^# @/ ]& S* l' N& D& g
'81': 'frangipani',* ~, n. L+ K' U6 {& W6 |: S, V( n0 Z
'74': 'rose',. P' l9 l6 z' @( ^: m! Y
'89': 'watercress',
  K; s4 g; H& f( Q '73': 'water lily',8 y1 g+ ?  _. I: W  I
'46': 'wallflower',
* H6 f6 o3 B/ I '77': 'passion flower',7 G# U" S3 m$ W; Q2 B
'51': 'petunia'}5 Y% b9 x: `) Q4 a0 D

6 E$ o7 K* _% t9 J1
* E  q+ n) e% l2
0 S7 Y6 V; Y& o6 j, |7 q0 F3
) z6 n: ^2 a  h( |) J. }3 {4
$ M0 H; Z8 n) f( C/ s( A5/ B8 u* W7 v& {8 j% }3 _! W6 n
6/ C# Z3 ?0 n8 k0 x0 w& S: m9 O$ p
76 m2 }. K+ Q2 a! I6 S
8
, O8 d& x" t" U  n& w5 ?* r# N  O9
0 j& M0 P* k' [( ?( _10
; q  j, e9 N# i& k5 G" A11$ R* l" w8 l0 p% m9 R. o. X9 M
124 f& \3 `/ a1 B% k/ `, a8 Y7 a* p) A
135 m- V2 ~" X7 d; Z8 I( J5 e3 n" R
14
. C" f6 ^# v: k* n; M8 z15
) F2 P: X; D9 K/ J16+ S# ~. y# b' D2 a! f& y5 _
17
8 x$ a* k* K; E+ c' T" B) v; ], q18
& o( d; l( q# g3 d/ Q9 H; J19) n0 v: K) P0 E; n
20
9 m+ Q$ V2 Z) d/ h  B. S: o21# Y/ }' @  l4 l
22
# z9 [- v/ d1 E/ f* h5 q! Q23
7 h$ a) ?' N+ k! p$ ]# f24
4 I4 N2 Y: J! q# S3 q3 I25
2 o% Z; u+ S; c" y, G# p* w1 F26/ v0 R$ X6 Q: }8 A' ~* G
27
/ ]) p( W, f$ ^7 \6 e, X28
0 x5 _- e* J" u29( h  _. p2 Z# u
30
' V+ U1 |4 m5 q# u+ m7 x31
% S8 [: c5 S. ^7 W! `32
6 S/ K: N* @: y$ a5 }/ V" G/ B33
( B6 M& b7 p; d, J6 N/ y* h& J! G34
: }: R2 I, ?& ~% o5 T1 @35) q0 s: w+ V& Y
36, [& z+ c8 t3 a) Z
37
& Y" X; A) v, T  l& _3 j38
+ {6 D; U6 T) ]6 C0 ?39
* l2 x% C/ F& U. A8 V8 ^/ g( }* L9 o. ]40
. R& ^/ e7 k' w4 N; Y0 Z417 L/ x- P  Z- J! S( n
42
8 @$ x0 D0 Y! p432 K& l/ D/ b, h: n: j+ V" m; H
44: L5 f& Q+ i# o: ^. Y+ @
45
  M7 ]1 C/ K) `) ^1 ?6 O" v2 R46  V- }# Q" x1 q* h$ j2 C1 c
474 I! V. ?! T- V5 m3 ^
48
. E2 W5 z# \9 b4 T! Z498 F' ~5 G" p% S9 L
50
- ~( R# L, m  ~: T1 w5 N51
6 h+ U. \+ U: t& `' x$ Z! m  F52  y1 G2 Z' G. j, w( w, L
53
8 t4 K& _+ ?1 s7 E- A54
, B6 F3 H/ h* J55$ ?7 }6 T# S( x# ^5 o3 ?, h
56
5 _# u' H7 _6 ]$ M) e& Y. |57
* k$ d( I  Q3 ?0 ~7 D1 k* L589 b- c# s& a7 w+ I2 u: l4 b8 G# R" w
59- g- X* x) Z$ ^# ?: z+ G1 [4 ]
60
9 S: O) d/ m2 }61
6 p# N0 F# E4 O( @3 r4 j62
6 r! Q! {5 E: R: S/ l3 B63" C4 g1 n# A6 k; w4 @" K
64. i1 x6 o. I) v8 h# N
65
( U4 o0 H7 O- g2 o, x+ E/ ~5 T663 I0 z4 `- ?3 d, A, s
675 g; z. ]8 n! P8 ?9 G  j1 t  D
68
0 }( c4 C" H6 [: r% s0 c69
: m' x" j/ D7 Z0 g+ C; W70
/ g7 f+ b& \+ Q- ?# w1 x5 B3 d- |71
7 p  R9 o1 c. {72- w) P1 U& m. x8 s
73
* X, R  R$ [9 f. ?0 ]7 M74. O, R$ a( m, T" M$ {
75
& s3 S" |9 m- [1 d$ u76
" N- }4 x/ f) ?# O, ]77% q: \; X: i; D* P
78
% G) o' @2 ]. C" c7 Q0 _, ]79" W) p; H, ], @  n. M5 n# ^
80& {+ T( t8 q' i' G. K4 C4 y
81
: p* g( e% `7 y) f3 M! A82
& z0 }' k6 W" \+ S83
* v7 @1 J4 H' m! u( b5 k! v  K8 x" Q) ~84/ e7 K# j, f" X" A( E- x( L3 d6 n& f
85. h/ _9 E0 M) ]% M4 @
86
8 g1 j- I" g0 Q87
3 G! o" B' _3 |88
8 d0 ^& @4 M2 E  _: P89  W' W6 ]' o& j- y  D
90
& I( T3 V6 o6 H91
. ~; t2 i* d% v6 z; t92
/ h% E/ L6 H- m7 Y0 p93) G! s: x% {5 o6 O5 N! F
945 V5 R6 T$ J* v/ Y6 v! L2 E
95
+ X! e7 |/ C4 L96( q' e  O$ H8 _& ?3 R
97
5 C5 t) M) [# _5 w- _98
5 M- H5 ~' b, o# Y' X99
3 X6 q0 ~# y( j% t6 Z, G- f( E* ~100: T# w' v1 ~0 v7 b- d9 p2 p8 n" }
1019 S1 a7 Z# z" j9 `6 b% p
102
9 N: ^) [8 ~1 A" G4 m) S5 J# ^3 T4.展示一下数据
; I1 N- @; ?& I# [: S  z# O0 xdef im_convert(tensor):
2 m" k  r9 b5 e( `; R6 S% m    """数据展示"""
1 K2 K. B5 r. q9 I    image = tensor.to("cpu").clone().detach()7 R5 z+ d! I: M' v+ d' `
    image = image.numpy().squeeze()
& D9 r4 A( o& r6 X7 `  y2 u    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
! @' D  S/ Q% g4 ~: T, k' F- j" x    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)8 j; c& k, l$ O1 @0 N; `8 A
    image = image.transpose(1, 2, 0)
" e: t7 S1 C6 D# b; V    # 反正则化(反标准化)
$ ~% T& v" Q  C8 {) N+ g: ~    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))+ h* G: M' R9 r. n7 {0 d6 Y( E

, y7 }' N. P5 m" @7 R% S    # 将图像中小于0 的都换成0,大于的都变成1
  f$ H; N: n* e! x0 I    image = image.clip(0, 1)! |% j( l$ w: G$ o3 g7 j

- f. Z( |3 _3 M# L. z  P. B- N" [    return image
. {4 `" }; c# _6 g1
: a: P1 T/ |1 [2 v- z4 j: S2
! A$ \3 m' V! J: P3. x6 U! X  a' K# [+ L  d
4
$ K& _" q0 @) r/ w9 x1 N% J7 s5
" n5 @  `8 y4 c3 @# U) B( S6
' c8 ~# [% a6 O) z9 G7
" i+ b' |/ P. q. _! w86 d2 h" `- s: d
9) p% l7 x2 j. S' r6 z
10& F7 g$ A' u( |) S( S# ]
11
; \1 J& m) ?* l12
& W; z% n) f7 Q; V( ~( q, d13
7 u& P, ?+ Y2 X* i& a14  A; J2 T' W% H. O
# 使用上面定义好的类进行画图
7 g. X- g% r1 @2 Ofig = plt.figure(figsize = (20, 12))
  n; k9 ]( k4 acolumns = 4' M5 |. I  e! _2 w3 {  e$ a5 x5 _
rows = 2
& |* N. u: O* @  C, {9 _
3 s1 A! R: M. u9 g6 Y# iter迭代器, i) }5 o8 y7 _+ J4 [! p! {7 g
# 随便找一个Batch数据进行展示
. M, z  s! N' \9 Ddataiter = iter(dataloaders['valid'])/ S5 v$ v& k. C5 H1 A0 ~
inputs, classes = dataiter.next()7 J1 l9 `8 B$ @0 z! q1 `

' A) H/ S3 i- |2 tfor idx in range(columns * rows):
* a1 O; R% A  N0 W; a2 _    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
. h! s: A& `" V/ T: R% Q    # 利用json文件将其对应花的类型打印在图片中/ R) R! N, g* R& w* b; K
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])2 i! Q- Q1 q& g% ~5 v
    plt.imshow(im_convert(inputs[idx]))" c, G% d4 u/ B% I7 a2 \( R
plt.show()( `  M% U/ M& X- c) E7 D) s
, l% w8 T% h. ~4 P0 @5 W
1' y6 K! ?" D' v  c/ F' M) ?' q5 ^
2
6 ?$ a! Q1 \$ x# y; x3 S4 I$ ]9 w3/ L; O8 [5 s2 Q; j
4
! X3 v/ F! V2 ^: D% z7 I5
6 J" q- [$ D. }# E  d- A6
6 ?/ ?0 F" h" I/ N3 b; c7) ]  W# A5 Z+ w
8
5 l& M, j5 G8 Z) K9
7 B. b* \9 p* d$ F' [( x& K10
% @$ J2 k% d5 t( C# m115 |0 n. W& }0 C' Z2 L, `8 S
12
$ _3 @$ J. ^+ i+ {- b' f2 ^$ w13
/ ^6 d7 Q2 z1 z: l14
; l: `. t8 ^( X& j- ?; x0 x& q1 j15( O( ]" x/ x9 O3 l( ^+ D( |! i
16/ y' I. }) B1 U% d% ?7 @

" @5 w" m. X  N9 v( r8 i5 |
1 t: ~% s; f7 U4 z4 R8 }5. 加载models提供的模型,并直接用训练好的权重做初始化参数' t4 B+ U9 F% i& f) X( H
model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
4 T: X) h5 ~. o5 B. |# 主要的图像识别用resnet来做
- h" O7 G) S8 c1 E& M# 是否用人家训练好的特征& p/ \# ~4 f1 J
feature_extract = True3 z  r6 B& Y% m; b/ k9 u! ]) H
1
! _8 E3 E9 l; i) w- L9 i! H9 Q2
7 x; K9 G- l% M8 l3
, d6 t$ Z/ m' j# m4
( u; x9 w5 W' P7 L# 是否用GPU进行训练
- i  c9 L/ ]9 w  S2 M  w0 q# ftrain_on_gpu = torch.cuda.is_available()" _4 E: v2 {" ?9 b
* I6 ]$ U& m, ?. C
if not train_on_gpu:
4 k4 S1 G* B" S    print('CUDA is not available.   Training on CPU ...')2 @; Q, G6 _, s' c0 X  `, g4 e
else:; f) S; m6 w+ k" L
    print('CUDA is available! Training on GPU ...'), H/ v3 h& ]) C- @
# p1 R* H9 @& }  r! i3 T6 B$ P3 W1 ^
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
% \! |% m7 R1 d: V1
, H/ t! G8 X% Z9 A/ c2
: u1 d, {7 O  J- }/ F+ W+ j3
  J# s  D% j4 U, `8 J) @2 c. U# B3 S4
( Q4 s9 k4 n& h( T. N5
$ X  y1 U) Y' X6
/ v8 p. Y' P! ^1 ]& }' t# _7  X* _1 ?& V8 T$ \. x
82 Q% h. R2 w1 k4 V; p1 d  ~
9+ \2 Y! b/ d8 U
CUDA is not available.   Training on CPU ...
0 V; G. x+ C& P, o5 o  L& h1
' z7 B: ?0 Q6 j- I/ |! x& b# 将一些层定义为false,使其不自动更新/ Z. T7 @1 v3 X1 d
def set_parameter_requires_grad(model, feature_extracting):1 E& `& O, B+ {& j) f+ D2 }# G8 w
    if feature_extracting:% {# ?: r! |4 v/ i4 j  V
        for param in model.parameters():
2 |, D: j, C6 }8 Y% d$ l            param.requires_grad = False
3 y+ @# m: b  c8 I) z' l3 R1
- P4 [7 b1 ]3 W2 c5 b2 E2
( F  [4 j3 f2 p/ _5 Z: V: c3  N, @0 \' k% |$ P. H
4
: t; C( Y; z4 g- k5
$ {1 i- h0 o  E, ~: z0 _) m# 打印模型架构告知是怎么一步一步去完成的2 b9 p5 Q- ?, x: \5 S) x0 _
# 主要是为我们提取特征的1 F( V, L- U  m
# G1 R3 p; P0 Z! p" g
model_ft = models.resnet152()
5 }/ _" j  A0 z3 w* pmodel_ft  J" y0 ^5 O7 h6 h7 o* X6 W) O5 b
1) r- W, F0 f9 s. e5 R) D- q
2, v: P+ I. g: Q
3) v2 p- N7 b. X0 w
43 m0 @: u1 ~8 _( P$ a
5
+ o# }" {% O" _! WResNet(
' [: }4 w  c/ p* t  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)( E) c  w4 p; F9 g! G
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
) a& e; V+ A" F  (relu): ReLU(inplace=True)
/ G/ S  T! f. Q/ R( j0 Y: K$ [  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)* X4 @/ R( E5 I6 S+ A2 F& N
  (layer1): Sequential(
/ d' ~  B: ^5 @. m0 v- h; y    (0): Bottleneck(
; Y# C9 v! p+ n" G/ C& P* z4 |      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)# L0 O4 i: M- ]9 Z
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)( a& d- A5 }" s( Q$ T
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)+ u% W. L; W0 N$ E- C2 t1 r" C* Q
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)/ z; e1 W+ z8 x' y. c7 [
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)5 H. a5 L( d8 w$ J  p; }3 A" K
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)- |" A  ~; V: @# _* ?+ C* S
      (relu): ReLU(inplace=True)
) D3 K9 u$ I& _: v# g% M# }) V: ^      (downsample): Sequential(" f7 L8 A; X$ [6 W( r7 f. @
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)4 V: A; K7 m( n+ o, u- G3 @
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)& X( Z# x( ]+ U- ?+ E
      )
' U4 U. f! c2 G8 Y$ |    )
# G& r* |& C6 U; V中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。6 B# B' ]( r' t) I$ T6 I2 G3 _
    (2): Bottleneck(: Q; |+ s9 V4 x8 x
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  L6 m4 K; |; i9 I# L      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# V. V; |! ?9 ^      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)- H) q) t6 D9 H+ _) |- F
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)9 I: b0 b  W. Z
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
" F+ D* T% `1 c# n, \2 U! s. y' f      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
/ M- z( O; T4 m, G  A3 m  Z; Z      (relu): ReLU(inplace=True)( o5 D- [$ a' S: s6 _
    )! y: d) z. g/ Y
  )( M, S; |; k6 O& k
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
% n0 A# o9 }8 N  X1 A  (fc): Linear(in_features=2048, out_features=1000, bias=True): H6 w! ]3 |: x9 I: }
)
% D3 ]& j' k4 {0 t" J' i1 x* I
, B: Q0 Z7 f& ^1
8 Q1 H1 K2 S8 ]/ J4 E- @% N22 u; o; b% m1 b
38 {( p2 f4 [9 w
47 d4 a& U! ^( Z
5
. e) r+ g- a+ c7 b6
! |- e( f- n' u) o2 {( Y7  G" ^7 h/ z2 ~5 O1 }# q
8
9 R# z+ v* w" m6 O5 M# [( H9
& B" Y* b2 U- K$ C2 l10
0 Z+ N7 F1 O) Z118 E3 k5 c- A- C
12- ^/ F1 q$ v& L
13
' j. _$ S# b0 d/ ~, W1 ^4 d14
0 t& V, D1 P8 n1 I4 |  `( {15
  ~' j6 w" J0 X+ h1 n! I# H16, v& b8 i1 J* q
17& E: n" p' Z: x: @% X4 ?
186 [+ Q) [: i1 b" A6 i+ S
19
2 v( C, t( H; s) R4 G8 a200 b6 q! `; j, u
21
0 {- b( x% \& j! R22
0 y. \9 Q9 M9 r6 H/ r23
8 z8 {9 Z6 W6 q24
; T7 N5 Q: J  z2 l25
7 c) }0 Z3 W# J, k/ Z8 E' P26! I2 H9 S/ s4 i7 d7 Y* V
27
4 y% F8 `; h* [5 {. h" K, G& ~- c/ ?28
. r/ s. g/ B% T) z  X( V29* j. Y3 f7 E6 e2 y8 ~
30
$ O2 r- P* t9 a: z/ V/ a31, A# B% d6 P  E
32
% h& s# V- V- s( [$ `! Z33% d% T6 \5 a$ R4 R- s5 p
最后是1000分类,2048输入,分为1000个分类
  I9 l. ^8 D/ ]! J% g而我们需要将我们的任务进行调整,将1000分类改为102输出
7 }/ l( L5 B7 |0 M3 k2 K
" C4 {5 j% [4 F% ]: K- [6 {' |6.初始化模型架构0 r* O. v4 g' k2 J1 b( `: |& _8 V
步骤如下:
: U4 k' T6 `5 [0 Y3 h7 \; w
8 Y2 E5 R- w" f8 d; N& M' @" k4 V将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
$ h5 R5 U$ N2 n- |可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
3 ]& Z: R9 @& p无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
1 B" z7 E+ d7 w* D3 D官方文档链接
. t; A# x' A& bhttps://pytorch.org/vision/stable/models.html
0 c" q' G! O2 \' C) h: c6 f0 [9 G9 h8 c5 j4 T# ~) p
# 将他人的模型加载进来, G8 h" X. ?& M$ S4 Z& r$ B% T  g5 ?
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):2 N; y2 ^$ T1 W
    # 选择适合的模型,不同的模型初始化参数不同
' v9 c8 S, \2 z$ l5 m+ b    model_ft = None
, I* n( M5 N# ?7 H    input_size = 0
, y# W1 Y9 G) |: k$ C) @. w  |
* G% t5 L$ h. ^$ T    if model_name == "resnet":+ w: _8 V. n3 w  Y# L
        """2 |8 i/ U, B* J2 E# ^
        Resnet1528 d5 }! |3 Q/ V3 H" z, |9 V
        """8 c  _# p& [3 D( Z$ S% K
! ?3 Q( b6 X! b1 ]5 C
        # 1. 加载与训练网络4 |+ }8 \; a1 i( q
        model_ft = models.resnet152(pretrained = use_pretrained)
; u+ z1 z1 D- _# G8 Z2 E        # 2. 是否将提取特征的模块冻住,只训练FC层. x( r5 [2 X* H$ D) y
        set_parameter_requires_grad(model_ft, feature_extract)6 Q$ Z  l  ~- C2 a" k
        # 3. 获得全连接层输入特征
8 c" k/ E, d& \1 m! R$ Z) c8 \' H        num_frts = model_ft.fc.in_features6 H$ }  V! p! ?" p/ C" l& }
        # 4. 重新加载全连接层,设置输出102
( E* Y9 r% k* g        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
# B& ]6 e. M4 \. H                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
$ B+ r8 z. c# y- {% d5 X        input_size = 2244 W2 L2 Q( O% O) t- A9 _+ }

; _* k: j1 Z+ l3 R! [! @. c8 H' y    elif model_name == "alexnet":
9 |% p- s! F0 _* s! w. D% {9 V        """
) w! |3 D% H  O$ ~9 o        Alexnet% r  B. c" \  P( M* S
        """4 c: Z, ~0 `+ }5 K% a0 r
        model_ft = models.alexnet(pretrained = use_pretrained)
# h  \) q" a/ B6 |* n        set_parameter_requires_grad(model_ft, feature_extract)  V% X5 r. C5 }; y

' C. y, O8 t" o$ z8 U6 O        # 将最后一个特征输出替换 序号为【6】的分类器
) I& {+ y9 H7 S2 j  t6 Q+ v        num_frts = model_ft.classifier[6].in_features # 获得FC层输入
. F8 m. h, q4 c1 q( R" l/ H        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)0 j# b. s$ R( ~" [) r
        input_size = 2243 `1 U% d0 ?! `
( `' A$ }9 Y. l. ~& @/ c& c
    elif model_name == "vgg":
4 T; h9 w2 A) C6 N9 w7 ^( r3 M+ m( P        """
  B& M7 x* N/ [4 B8 s+ F        VGG11_bn2 v& a: {4 w- u1 d2 _2 k
        """2 R. t* ]# R7 X8 M! E/ p2 q  W0 _
        model_ft = models.vgg16(pretrained = use_pretrained)  u1 D$ i+ a8 F
        set_parameter_requires_grad(model_ft, feature_extract)
0 j" B' K* R. W9 l& F+ {        num_frts = model_ft.classifier[6].in_features9 D6 e) |7 _1 f% n- \2 ^( \
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
/ h8 I/ ~) Q+ j+ c% h1 O! x: k        input_size = 224
& x- Q1 R0 e1 }6 H+ r/ }) ?4 Y1 H3 |9 i: S) L! Q% H1 c( ]
    elif model_name == "squeezenet":
. Q& n! L) P1 c7 ?8 t/ R        """
# B+ ^7 q; h$ e  d5 R" w4 [        Squeezenet( U" I$ V, l) X9 B3 d$ t2 f. b: |
        """
' @3 m1 n  w  I) X8 t. S        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
& r3 i# q4 v& Z; f. |; T/ o        set_parameter_requires_grad(model_ft, feature_extract)
; T9 M: j& ^4 o; V$ t        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))5 g9 K$ h9 u0 _6 A9 n+ S/ E4 l
        model_ft.num_classes = num_classes
* v9 J: G4 U( o3 \+ ~; N9 F1 @! p! g        input_size = 2247 K; u% M; |) I

0 w" j; j$ |; S+ d; M    elif model_name == "densenet":
  V: M6 N# z; J: B' ]' @' L! o: _( ?" K        """: w4 _& X7 V3 B. g- W* U4 \1 A9 f
        Densenet1 o7 `. Z# k( `( o+ ]/ r+ }
        """
6 y, _$ w' Q# h5 N        model_ft = models.desenet121(pretrained = use_pretrained)
1 X+ H% L2 g& C) R* S; _        set_parameter_requires_grad(model_ft, feature_extract)' F6 Q2 P. O1 C& F# N0 R
        num_frts = model_ft.classifier.in_features+ Q' Q- M+ e1 P& ~7 |- w( ?- v
        model_ft.classifier = nn.Linear(num_frts, num_classes)7 K# r. _2 u4 X) p1 B2 P
        input_size = 224& b) s/ ^+ W8 V9 A& H! }( E

. M1 T) V. W' |2 U; n1 \6 R+ `    elif model_name == "inception":0 i' H: s/ c  H; c  D) W
        """
/ R7 A! E$ H4 b% C) w        Inception V3% `" @8 u2 I' P1 c, K. x
        """
/ J2 }/ E+ }9 S- N  ]+ W3 Q6 A. |1 Q        model_ft = models.inception_V(pretrained = use_pretrained)
  |8 |) E5 o  k5 J- I        set_parameter_requires_grad(model_ft, feature_extract)
+ D  P/ t7 T6 o* o0 T
3 ]- y" j) o& l9 s$ Q6 @' Y        num_frts = model_ft.AuxLogits.fc.in_features! v$ l  z: B/ R- q$ J
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
. u9 ?6 C6 A! p) v; n+ @0 N% U( s
2 ^. Z% l! B! J  M1 |        num_frts = model_ft.fc.in_features  P9 K4 a0 r; p$ k7 b
        model_ft.fc = nn.Linear(num_frts, num_classes)
) z- R, {" t( f% b) w* ^0 T: [        input_size = 299
3 G3 O, f' [: z( Q# _
% ]) W2 x4 K4 ^* Y3 t    else:! ?1 K& T5 M3 y2 |4 y
        print("Invalid model name, exiting...")
+ y' O0 n' ^( f7 R9 d# v* B        exit()
" i5 p3 I' o+ Y- {7 L# X5 A$ l7 F$ Z  j* X. i2 k
    return model_ft, input_size
8 M* ?( y, o, {$ N; k% d4 `$ p
- \5 J9 I( W; L1# h6 g8 g# h8 }+ ]2 Y/ S* s! A
29 l; U( D, y9 u# h3 s
3
% X' w8 i9 h; x* F$ \( x0 p4
: q/ E0 Z5 l' g5
4 F+ K  _9 u0 q3 H6( J4 k) I7 p$ \$ A+ c
7
: P2 N' p+ e5 y% P& o: `* _2 j8" z$ C5 @* M' c' ~
9$ _3 y# |9 E% X" R; h5 Z. S
108 @& q, X8 ?5 b; [5 w! ?7 j! f
11
, x6 D7 T4 ?& k7 t% H# @7 i12
" c  {/ M" ]  E  P134 e- i* p8 e  N
14$ L' l* \2 n& Y* h0 h6 R! `+ h
15* I* i) F$ f1 ~
16  ?0 `8 q; {% Y  R7 v; F: U
17
6 A6 ^" j+ y  _- s" ~- m7 G9 B- Z( M& ~18, F' [7 D* w0 t  p3 J5 {7 @, J+ ?
19
* V) ^+ e/ E7 y1 Y4 ?20
- h* P2 S8 R' f" l% ~. Y5 s21" H0 J4 k8 |& _3 J3 X: f& M
22
# _! J4 O4 _. I/ T* K& H231 ~) V" y$ d. B! U1 h
24- H' j8 L% h) p; \; O6 Y) t' u
25) k& h0 u" @2 a7 Z* W4 k
26
" i9 T8 z0 L+ o27
. w1 S, t" h: S, s' G% n4 E+ L9 t28
4 g) W* n' S! ^29
4 t8 _$ ~3 G3 W* ?5 ~: [5 ]309 o1 e6 a. ?2 Z+ x
31
9 ]/ w. ^/ r! t3 `  @) ?1 i; f5 I327 H+ k; c9 U+ L& Q' [
33
7 H8 j6 W" r% C4 z. `( W& {  _; E34
' c2 _5 R( Y0 e7 _35" R' s2 ~/ C/ d0 S3 J
36  p: X6 U) j3 L9 H
37  w! Q, F& n& n
387 i$ ?$ O6 ?* W$ ]3 r0 l
39
' {  T" t  ^9 B* x+ p* n9 h! p40
1 \+ B2 b* o6 y* c. f3 q% `: n41  K& W6 B% P+ a$ b' |
420 q2 @! |+ ]9 o* Q" m
43. t+ q, p' u" a4 }, x
44
% a& E, H" Y( V- s9 X1 o4 H456 \( U0 H1 l6 N/ G  F) G; d1 b: d7 I
462 w& Y, X, d. M: @$ `& _
47& q/ a1 @: z. J1 Y! A% u% p
489 O  O" p" {% b" Z
49
3 F& A. c1 [9 a2 b50
. q% c# ?; J2 P515 c- }1 f' Y1 b
52
1 W5 Z4 v  Z3 J2 n531 E/ r+ [5 u: N0 p3 C5 k
54
" F4 j! E7 W5 y5 f$ P/ O55& S( A4 o1 j, U' ^
56
3 M3 B+ Y2 m% v, K: _9 L9 f" B57
9 O( }* \$ P5 l58
6 x( C. r) Y) z59
# k8 o7 R% L; T. E, s' Z/ y609 L; h$ e% u; w3 B  V
61
8 w; [; c7 G9 v+ k/ Y62
6 i2 L' i( W) P; [+ L8 F" Q63
: B! \* s1 y0 [9 s  P64( G4 h* B5 Q& c5 q
65
, I4 f5 ?$ @' f. E66
- Q" ~9 E) K& S( n677 `2 A' U  z3 q" {0 r6 A
68& ?- {0 f  D! z( X
69
. d, L3 O: V, |/ m3 \7 U  a, @70
0 [1 _0 e/ J; `9 X% k71$ t( \# c! T0 R2 x! G, W0 R
72
2 }) o; I, u$ s/ T) T+ i7 b4 y0 m6 U73% Y* \9 N+ M$ O2 S  ~$ O+ \4 O
74+ ]) @0 H$ v/ |- ]+ }3 S
75
2 r4 K: o& z$ x76
! @. ?# d1 b% C* B+ F, X5 {) A/ m9 B3 z77" R, j0 [4 L+ I( g0 S+ ?. N1 f
78
+ @; D7 M% J/ m3 M0 I, J2 `" c2 h7 {794 m; ~+ \# S) E7 q& o9 M
80
5 M# ^+ _$ C+ n  N% V81
' u  X( S8 E1 `# l. a  A8 @82
3 o) H% i3 e( d2 p* q4 P( }833 c. f2 y' A0 w; I- k2 }
7. 设置需要训练的参数8 h1 j" D. w; R7 h  c) y. \
# 设置模型名字、输出分类数) d3 J7 T) ~6 ~1 l$ N
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
& m0 [- V- _5 c, J% t) J. c. P7 P5 R6 Q  v  C
# GPU 计算
5 P/ v0 U9 i8 ~2 N" t* _, N& p( imodel_ft = model_ft.to(device)
1 E: {% ?1 s' T+ i
8 {. B5 S; T; ?; q$ I0 ^" B$ f4 F# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
! H+ J. z* Q* P: `filename = 'checkpoint.pth'
- M! G. [' q' @9 Q/ w4 r2 m7 @
5 d. A; a: v# I- N2 U' S# 是否训练所有层/ r% `  L& Y% |0 g/ J
params_to_update = model_ft.parameters()7 ^0 {* ~: T. [, Q
# 打印出需要训练的层
, @7 F; O: ~3 [6 V* `print("Params to learn:")& _! n9 p$ [- m" z  v- ?
if feature_extract:
; J& ~! u7 p9 j    params_to_update = []' n2 m) C5 |6 K, Y& {) i
    for name, param in model_ft.named_parameters():
+ C$ b5 j) o5 x$ I1 C! u7 M8 w        if param.requires_grad == True:
) b. W+ s$ M1 n" a) y$ |5 C+ `/ q            params_to_update.append(param)
# a/ F3 u: E3 p/ `1 }2 M            print("\t", name)
3 w+ B7 ]! V5 W5 m6 b% {else:( x& n) i% P1 [4 m4 W7 o+ Y
    for name, param in model_ft.named_parameters():
$ _1 A/ v& ?, n        if param.requires_grad ==True:
+ I/ I. P, t4 @+ \1 J; ~            print("\t", name)$ t3 `. r/ {5 ~1 {  n" w$ ~
; F5 n% h& x( b, @
1( `9 c( {  X7 E$ ~
2+ p- ]  [8 e; p* F: y  Q
3
1 u/ j5 e# q! U47 ?$ _3 }% S7 M0 ?- ~. N/ ^' K, j
5
# C- Q4 ?$ U2 t7 s! `0 @1 `" a6' s+ w0 w* l  N! h" J3 g! w$ K- Q$ u% F
7
3 g. I* x; u6 }. C7 Q; |8; s+ y( a3 c- P0 Y6 R" P3 F7 u+ M) d
9% A( k8 H/ k1 F
101 O3 r1 d% \( V. S  e# z' m/ V
11
+ l8 `; B* A3 x0 q/ Z: s: r: m12
- ]: M2 m! b8 Y" M6 t13
5 D+ ?0 G+ x" }140 t* U' x" e; N% e  z8 e
15
% z$ R  D2 h8 s16
7 A% E/ W. ]4 s; v% F' f17
3 D/ D) n3 |$ k( ~6 n3 Q1 m0 S188 p  N3 V9 _2 H/ L& `& O8 T
19+ P7 m, x* `) c, e7 g/ m5 G# S4 w3 `: r
20
' I. l: [3 u; X9 R# P$ A, S21
! y( N$ w! `/ n# ^' Y3 E8 \  c" q226 M" l+ l( W# e
23
8 X. l( w( L0 o% f( W9 ZParams to learn:
4 w! D) g. j: `+ {9 q         fc.0.weight4 D# {" C9 T1 j. H! x
         fc.0.bias  r! K# d2 i% B6 p5 c9 F9 x
1
: u* k& J6 R3 v2
. i6 W+ h( \: M0 P" V3; v8 x3 F" `2 Z' k8 q4 x
7. 训练与预测
( h: e9 ~# W! ]. x7.1 优化器设置
7 ~: J5 w$ O! V' @9 V# 优化器设置0 j# V9 K- g. u$ l
optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)) B7 F5 |. s, N2 s+ K7 B
# 学习率衰减策略
7 D2 V0 ]9 K- Xscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)) w2 Q  f; I0 y6 h- P- u
# 学习率每7个epoch衰减为原来的1/109 u* m& q/ F6 L; W: a  k
# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算) f$ f! G/ E" m. R
2 n, {' j8 X, O* U" P& w
criterion = nn.NLLLoss()2 n" C# t6 O4 t1 t5 ?2 F. f% L
1
$ j! {7 [0 m/ U) a* S2
" y) |% B3 \: s9 a8 p3
  t( i0 }9 |; f5 N- p6 T45 G/ }5 n% G: v+ I) g
5
& g" w9 U4 \5 z6 d7 v& I! |! H6, G9 s4 d* d2 i8 q5 l: s
7
5 _* s- Y2 J% C0 ^# p8 U/ k8
2 A4 m2 D+ s8 O, b, X# 定义训练函数% |( {2 ]+ Y) j( b: D( [& f
#is_inception:要不要用其他的网络
* ]: Q5 l% N9 I3 W. E. ddef train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):$ O* c2 z% r3 f
    since = time.time()1 B: Y$ X$ ?4 c* w
    #保存最好的准确率
% t* L4 A. U+ C! w. }    best_acc = 0
4 [! h: R/ \3 G" y0 E) F, G( n0 E    """
' l) h; K( I/ _: y3 A9 G) y    checkpoint = torch.load(filename)( i1 Z6 z6 d7 w; N2 r' \
    best_acc = checkpoint['best_acc']
( I6 k# t7 E/ a# S( r" s) x    model.load_state_dict(checkpoint['state_dict'])2 t0 o" d* D# O) o) [
    optimizer.load_state_dict(checkpoint['optimizer'])
% ~: [/ {/ ]2 u; [: z8 }/ j+ T8 X# ^    model.class_to_idx = checkpoint['mapping']4 N) _8 z5 n7 X) m/ d/ S7 `
    """! K. C' ^2 j! W/ ?
    #指定用GPU还是CPU: E& g) {  Y+ J5 I  o; E' }
    model.to(device)9 D8 S. S6 u* |" c
    #下面是为展示做的
5 g2 s, F: l! T* o2 d' v    val_acc_history = []
* u$ B  r+ ]$ E7 m( V1 a) `6 S    train_acc_history = []6 P1 I& k) P, L7 ?9 v# s/ Y4 z5 Y
    train_losses = []
. p- o9 `. v) b" p0 y$ g    valid_losses = []0 @4 t' g. }/ M' N/ C8 c4 [2 V
    LRs = [optimizer.param_groups[0]['lr']]" Y, X/ ?4 `0 K0 j
    #最好的一次存下来
7 v- B9 _" K# [    best_model_wts = copy.deepcopy(model.state_dict())
5 @0 V' {2 k$ K# ]6 t0 y8 K; k
, g9 S; q1 T9 s5 ~; m' @4 h, O  A, h    for epoch in range(num_epochs):
% G6 _" O) Y" J4 m! U6 i        print('Epoch {}/{}'.format(epoch, num_epochs - 1))3 b% R8 |' h2 U5 S3 Y
        print('-' * 10)" e+ J# O" R# m* X) p# {" Y- Q

7 u/ f" j/ N: o" K/ m$ e        # 训练和验证. ]. H* K: S2 n) ]
        for phase in ['train', 'valid']:4 N" q0 n" \3 x0 q# b( r
            if phase == 'train':5 U) K" {" e) U0 x. U
                model.train()  # 训练
/ i% n. I& ^% _, D# D3 D            else:2 s' F' P/ J/ O$ Z
                model.eval()   # 验证$ a! p' @2 I) ?' G9 {. Z
. k7 ^8 Q2 ~3 g; y, R
            running_loss = 0.0
  \5 v) `$ ]( @            running_corrects = 0
0 Z6 ?: S8 g7 E" ^5 T1 ^) E  C9 d
            # 把数据都取个遍! w; Q6 [0 `2 F% L
            for inputs, labels in dataloaders[phase]:  I6 [  P1 u- L3 F* E5 Y. n1 M5 _
                #下面是将inputs,labels传到GPU
8 M$ ]4 j& P4 H" v/ c2 d7 G3 n                inputs = inputs.to(device)) U/ j7 {" h* v* J' [8 @) A5 N
                labels = labels.to(device)% t% g, B: o1 g0 ]# |" e) I
0 S5 ~4 u7 u& x7 r6 z: s3 {
                # 清零  G/ t+ U! d" {9 Z7 j
                optimizer.zero_grad()& Q1 d3 p7 b2 L+ G# Y" V: l
                # 只有训练的时候计算和更新梯度
. h8 R7 A8 W- z4 n9 ^                with torch.set_grad_enabled(phase == 'train'):! A- v1 j$ ^8 `  S: Y  P9 K
                    #if这面不需要计算,可忽略
0 q4 e( E( R/ x; C  y" f6 K  U3 b' ~                    if is_inception and phase == 'train':
; a; O4 i8 [9 i2 m! N( ]3 i/ n                        outputs, aux_outputs = model(inputs)5 \+ @+ c/ k" O/ E! X) w: L
                        loss1 = criterion(outputs, labels)0 J  i2 j% r0 `6 k' i/ h
                        loss2 = criterion(aux_outputs, labels)
, J8 R/ T1 T: J; ?% A                        loss = loss1 + 0.4*loss2& `* l8 [$ I' ~2 ^- H- @* \# Y
                    else:#resnet执行的是这里
" @! M+ C, C, ]" u                        outputs = model(inputs)
# u$ S) f" C! {, I/ N9 B' z3 `                        loss = criterion(outputs, labels)7 e% A/ ]  q5 i( x8 @, |- C( G

$ z% |5 V0 f/ `7 T                        #概率最大的返回preds
% |1 v+ E+ n) L( P5 }                    _, preds = torch.max(outputs, 1)1 I( O& p" [/ G0 l5 \5 U& U

0 ~$ w2 s3 W/ c# i1 p                    # 训练阶段更新权重" h5 P% E, G2 r& v7 J
                    if phase == 'train':
! b5 j/ o2 e9 k* d, B                        loss.backward()6 D1 b. ~) D: _: ~8 b
                        optimizer.step()6 `- y0 v7 k& b" p. d1 p& ~

9 M* i$ A4 N- s  {" V" q                # 计算损失0 a# s8 N. h) i6 h1 V$ t
                running_loss += loss.item() * inputs.size(0)
. {% d4 {, F: {5 B8 d/ Z                running_corrects += torch.sum(preds == labels.data)$ \# e- Z9 ^7 ~* a

: K9 o. |9 E% ?( C# C9 u9 y. E            #打印操作
* b! }; w, g' S& i& E            epoch_loss = running_loss / len(dataloaders[phase].dataset)
; \7 C3 h$ `: ^' G3 K: N5 y            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
$ w& T8 c2 T7 _5 Y. b* ?& K1 q- e3 @! X$ m

2 f8 ]3 L/ E/ j' f% x% A: U" |+ b            time_elapsed = time.time() - since+ ?) u+ Z2 G/ D( `3 H
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
! L1 w$ n1 c5 O4 h9 D            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
% @3 F3 Y, \$ R
* [, m! Q5 g' Z4 v% f& K# k& Q% |3 N  a; \" j3 b; x- H
            # 得到最好那次的模型1 t/ M0 n, J' C! d. W1 G
            if phase == 'valid' and epoch_acc > best_acc:$ H1 F2 p" D# s+ g/ W
                best_acc = epoch_acc7 H0 E+ W9 L$ |$ a4 {* X. T
                #模型保存, i- o( a& U/ O; J
                best_model_wts = copy.deepcopy(model.state_dict()), [& @/ X9 A2 |7 D# j0 |
                state = {# x, C  n1 \) g7 C) ?( [5 Z% `
                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数
/ l. ]% x7 l8 F$ D6 ?1 ]                  'state_dict': model.state_dict(),
: t  w3 {; g  v0 z4 U5 I' {                  'best_acc': best_acc,3 u& n( Y- K, n" V" r/ a$ f, x% E
                  'optimizer' : optimizer.state_dict(),. k& Y' v2 M, f4 f; X) J
                }$ o/ C& C0 K9 p- y' M% }: ~7 ]* V$ a+ a
                torch.save(state, filename)
- f: D/ h2 T0 N& I1 a7 {            if phase == 'valid':0 K' `3 V& Q+ A6 ^
                val_acc_history.append(epoch_acc)
0 B% ]: v& Q/ R+ P2 f6 h0 D6 g                valid_losses.append(epoch_loss)
8 u5 J7 u' h7 r# P- r1 [! b+ \                scheduler.step(epoch_loss)
; e2 d: j6 R+ g9 X9 K+ w  d* K5 j            if phase == 'train':- B: W9 e/ L. B" J) ^
                train_acc_history.append(epoch_acc)0 a9 D" |5 y2 Y1 F* ?
                train_losses.append(epoch_loss)0 [7 \9 g- l! U4 [9 T
' ~8 H- t; }; X- R# k. d" Y1 p' t9 @
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
; l0 X4 I. z' F- ^1 {        LRs.append(optimizer.param_groups[0]['lr'])  Q# X& t! J9 h* P! {( I
        print()  Y& o4 \& {  B  Q! Q# B
' i! ~, u: p9 b( \
    time_elapsed = time.time() - since' i7 b' c) ?4 G  s0 F, r
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))) q! j9 o' Y) x
    print('Best val Acc: {:4f}'.format(best_acc))4 f/ p4 G9 o9 l& Z1 d8 @  e& @  |
7 b& L/ c2 W/ G6 R- g
    # 保存训练完后用最好的一次当做模型最终的结果
3 E$ R  `5 H' H* T& g    model.load_state_dict(best_model_wts)
& G. q# |$ L% h0 ]    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs ( \' M- k  j. g6 Z6 {6 n8 ]
6 `( j0 O/ \8 @5 l

, b7 A- j- N! H: L4 h) K1( N& F7 k& w* g1 Y. O1 r* ]; G+ ?
2
* R; `. i0 |+ v6 x+ a3
7 y: e% \& D. s7 z3 S/ t4 K. `42 o& x4 O4 s7 J! H
5
0 m. v7 f2 o/ d" z0 T0 t3 E6, P" C- T) y  L9 d2 l+ `
7
- \+ ]# R1 @; P7 g8
$ r: K$ J! o2 O8 v% u9
& `% v; J. H; g- \7 @10
5 ?0 T; V! j/ w4 P7 Q11, Z2 l1 z: p" O
12; H$ t- r4 E, C9 ^
13
: B3 r4 D* n& p% H( _14; d: z( p6 o" z5 X7 A) R6 h2 h0 X1 I) I
15
  e8 w) ?9 C7 l& I( \! z# F16( ]/ R5 |9 [% j( G& `1 Z& \
17. L, z# w; F# K3 }7 _" j& O/ m$ k
18! ]6 }7 @; G! o1 O6 v
19
8 x. c, s+ D9 D# b' h- R20
# Z, d5 S" j4 t6 ]* P7 Z- S21
3 f1 K" B( u' |# X9 l7 L3 N& c22
# Y$ O8 p0 w9 l+ h' v; O/ P6 I6 z23
' v5 D/ B1 l/ `  i$ }2 M( m24
' Q# ~4 b2 `4 }2 h$ y253 c* _2 _8 d" K+ r5 i6 o
26, ]2 E/ v2 f! f4 K; j
27. P) j4 R7 I. |$ o9 X/ c
28& X/ U, B2 m+ Z; I
29
0 _) r6 I6 Q% U' C! p4 ]2 Q, ]30
1 o7 r% H. X4 J: Y" |! s315 i3 U, i' \5 R- O$ m! T) O8 ]# I
32
' l9 b9 @- H$ P+ ^33
) |' g$ x& p$ V34( G! n: V" h  Z  v
35
, _# u4 ?& @6 Q36; T4 ]) v7 W" m3 y$ Z
37
7 m' F) W4 b. j# M* c: W) w- h; B386 L* l7 J$ ~! g3 H" C
39
1 X# M! s- B+ c& z/ ^' U, k40
7 f3 {+ A# G1 c# J41
6 {4 c2 B5 o9 i  }4 J42
# s  c/ p2 H* `* E1 ^$ S43
( Q# G7 ?0 R7 i. P% z7 b44
+ J9 o7 o& h4 [; n: {45, P( D5 J& t3 O
46, B) g: ~) @: r: }/ {+ C! s
47
0 U9 E% D3 K) b: n/ b4 L! f( F481 ^( A  I8 R, o0 Z+ _) V- e
49
$ l( {1 p3 ]8 g/ |' B# {50" ?1 Z& W3 ~% O; h
51
. Y$ u  G" \$ c0 x3 R! B4 D528 [  J" Y6 S( g, g( ]" K7 f& w
53. m7 {/ x; s9 g" H1 S
54  ~" P9 `* }# R: R0 ^: M
55- l6 M* h( `" e; k; A+ ?& a, X
56
/ W! @+ @- E8 a7 P# W572 K2 s$ h/ d" x2 R
58
. K9 f0 P6 G. h59; U' Y5 |% f/ e( c# A( j
600 b( @, I) X; y9 Q2 S+ o
61
0 I. a2 x/ j5 ?5 D62
, T  r* r0 d1 i; P! M) @- \63
2 K4 Z% O1 q4 ?  V5 j64
) x9 V1 u5 c0 T% }65
9 f4 f! p8 t( ?5 F- `66
" B+ G' b% V; b% I+ _67
4 Q' F# z4 d. ?( J/ K/ Q# d68- w2 f2 X) i3 Y: R( @# ]6 K6 ~& \
69  D' I; |9 S* [' X2 C
707 h) F. n' j9 i7 c- O
71' W; l0 ]4 l* p* |4 f4 A% l
72! j' C( u3 |, x
73
. F& I  {8 o2 q/ V  k: r744 y; O5 D2 W% j/ C3 t
75: L4 D. b" t9 W$ t% Y( i  `
767 c( W2 d5 \: D8 f$ w+ T: m
77
2 l$ ?8 G& ]5 a+ D2 |78
  \. K; U5 ?. [4 Q- w. V# v79- Z! f) w$ d  i8 y1 W3 g' U
80- I6 y1 U  A( R0 S2 ]0 y) E( o( K$ _
81
% I( P8 y' j( S+ M. H82
! j; h' b0 n2 M6 T) p5 e" s839 |8 Q- M  r! E8 |4 d
84* W- ?, U& c. R$ y' f: U, ?
85
3 w' o1 D8 d/ H$ x3 W! O86
6 A1 i( L( B' Z87+ `+ ~* }7 a0 K
88
5 a0 n( [( [' q+ |8 |# \8 G89
, ?) f# ]% V( s% x3 t# }90( S3 {6 ^! \  K  q
91
/ m5 ^; [, w8 e2 D; f& k* e92) M1 y# E0 X+ _& M. j
93
- p- Q: S3 f* P* E94
6 `3 Y0 p: s1 ]2 d: G! k5 A9 h1 Y95) ^& l- P0 Z" |1 ]! ]9 b
96
$ a7 U0 I: L2 C8 H97
$ N9 T3 a# n* b0 t" R98% E- L* L, ]4 X( V9 T, |, E5 T& O
99/ m# U# h6 m' Q( F9 ^6 t# D
100
2 l9 w  [; p- ?7 u& A6 Y. Q1015 g7 p! s$ `: Q1 N3 k! b
102
' s( G2 W. Z* e1033 k/ b! t% X# }, Z
104
5 o6 ^) h1 N7 j3 @' v! b7 G% b- P* z1 E105. t  B5 \( Y4 |, f8 v+ v
106
0 F8 F* }( l' H) R, O7 h1071 r7 F: `& E! u: Y4 {
1088 n( H. B3 _  C- A
109
5 V- j, S+ c+ F+ p110
" o) j' ]6 l$ O6 m) l111
! Q( F  e. ~9 H) i7 f3 ~; N1121 h  i" u- N) n+ _$ \
7.2 开始训练模型
, Y1 n' k2 Z! Y" {: f我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次/ Q& j! y9 D% D( m9 a

( M( w" F" Z9 E7 j. {#若太慢,把epoch调低,迭代50次可能好些+ v( r& i( `4 ~  |5 |& e* W
#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合1 Y5 }! K7 X  C0 q5 x$ g: c
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
2 C* x4 a9 I& f8 ^) X& _  F& P  d; p. F% x
1
- X8 F' Q/ C) G: X' d1 o2" X# n- t  h* I( Z$ j
3
4 x5 ?; C8 e# M1 d9 s$ a* O4" v0 |' _& f( A9 n* i. }$ A, Z) p
Epoch 0/4; O3 ?9 {; \3 G
----------
4 j( t# u/ r9 `3 CTime elapsed 29m 41s0 Q4 v; J$ B4 y, H  a" m! Z$ G
train Loss: 10.4774 Acc: 0.3147
" `% a) h9 v4 V& c* Y. m, ?Time elapsed 32m 54s! d; T9 J) K% ~& d
valid Loss: 8.2902 Acc: 0.4719
3 }) ]$ B7 I) f- `% N- i" FOptimizer learning rate : 0.0010000
1 {) l# Y- \2 q) S$ \: A
% J0 K! Q3 E- \Epoch 1/4
1 {! |6 v. v9 P, L) @/ ~4 Z----------2 ?+ ?4 F7 K# C
Time elapsed 60m 11s; w, d4 ?( N  @% T: P6 P
train Loss: 2.3126 Acc: 0.70539 z1 ]  o7 m) d
Time elapsed 63m 16s: ]7 A1 a2 v3 L! G1 O/ G
valid Loss: 3.2325 Acc: 0.6626
* G- v  d# j. G! j& t# p% W5 ~Optimizer learning rate : 0.01000009 K3 O4 X/ \) i! Q6 q- [" q4 x
6 ^) \2 q2 R1 T) d) ~
Epoch 2/4# U& q* s$ b% _1 }5 ]: a3 M7 ^
----------$ p' O$ r- L% a$ s) m3 R
Time elapsed 90m 58s; h7 ^  x" V! F5 [# u$ i
train Loss: 9.9720 Acc: 0.4734
/ G$ F, }6 H7 t% ?, J, ZTime elapsed 94m 4s6 Q; P8 Z8 y8 p) p% g
valid Loss: 14.0426 Acc: 0.44138 V' x/ [$ {9 @
Optimizer learning rate : 0.0001000
  I8 N, O- a+ n: V- `4 g6 @2 f: Q  k, J3 u9 W# W
Epoch 3/41 x0 d, i& O' j) |# o7 f  x3 j
----------: l) }3 ]6 I& }3 W: K" e( h
Time elapsed 132m 49s
' y9 g6 P, Q+ D" H/ ~0 ?: K: J% _train Loss: 5.4290 Acc: 0.6548
3 y8 G1 M7 n( k6 \9 e, N) R1 ?Time elapsed 138m 49s* ^4 o* h; x' L8 a
valid Loss: 6.4208 Acc: 0.6027# k% ]' p  K$ U' K) x
Optimizer learning rate : 0.01000002 \4 u7 O- a6 o3 h* M; Z
7 {- c" e/ w! q. g; j
Epoch 4/49 u+ l2 S4 V7 i5 E0 u
----------5 F8 W$ }) g5 B! ~& I/ {
Time elapsed 195m 56s
; X6 {$ _5 \2 u4 Q5 Ttrain Loss: 8.8911 Acc: 0.55195 g; f  F  F* R: x) Y8 ]( J+ D
Time elapsed 199m 16s& u, I  R# L/ K6 u
valid Loss: 13.2221 Acc: 0.4914
$ O1 G  \7 d! k8 iOptimizer learning rate : 0.0010000% a  ]& _; g7 x/ E# s4 |
! P5 I0 g: E+ I: j* Y6 _- G+ I
Training complete in 199m 16s
8 h+ X3 X3 `: Q4 P; D3 MBest val Acc: 0.662592
1 V8 s7 g6 Z8 [) ?
2 l2 n! N# L3 j5 q3 ~7 O18 ?  V, p# y$ S  J1 Y4 G
2$ n6 J. s) N$ X+ V
3
  g3 B& K- G6 y  x! ~4
& f% ?1 O0 Y4 t5
+ |# P4 e& D1 w/ M! F9 k& _) ?6, e3 }. I6 s' @; c4 G
7
5 I) r4 b- M, H7 R7 x' z# Y8$ i9 `/ H% X  l% u3 t2 `8 u( r
9  u: j! H$ G" E" _! h
10; K/ \( Z7 G1 r* r
11' I4 v$ \+ A- J
12
9 d; j% E0 Z9 c2 ~13% X: d0 h1 R2 }" [0 W6 R8 O8 M" E
14* r: [: Q2 r4 [1 a2 W
15
  A6 R* I" j+ i! s8 F  E169 p8 S8 f6 G' `# g( v' v
17
. k( `! g& G. D& E18
/ t2 L0 A  t! J) f7 _: d19: X8 x0 s7 u4 a8 u0 j  ~
20! m% ~' w% J( v) h: U6 V0 i
21
' A% K+ m. d; }/ H7 K" O22
/ P/ L4 [+ w  u3 s) w6 p236 U' A1 o- X6 b
247 N6 H3 k  C& ?& I: S" c
25
0 b9 l. t9 n* E# m& n26
9 C5 x0 L( c' }& [9 d* ]27& q+ B2 x' q$ y
288 l3 S  H+ k* y" h' o( S' K2 u! \
29* N+ A9 b; w4 ~8 q: d
30
( ~# F$ p, T' n7 s  R. d9 `310 n) F4 }5 g* c+ b5 N/ e
32
' i, W+ ~$ q8 i& m* T. a33
8 k8 b1 s! q) C  ]% f348 W- s3 \& m! d# i5 V
355 x) A& P& {5 l$ y2 l' J
361 [4 i4 j1 e! r6 T& E/ n& T
37
8 V$ m8 w$ n8 s& x/ P- C( Y38. R; K4 k0 w" c' V# y' p) b
39) d( J3 s- S, b1 g$ ]! p; Q6 d
40
% _  u* \: g4 t, h! d41* T$ K9 `. O3 A' u! P$ b, d
42' n8 d$ C$ [* q# C( |. g
7.3 训练所有层
$ \# x  E( t% h( w0 }1 |1 e# 将全部网络解锁进行训练
8 [6 x" l- X8 `& Xfor param in model_ft.parameters():" k5 k) E. P& A  P+ R
    param.requires_grad = True2 g+ _1 U8 m4 Z0 V0 {! U

7 g9 o) v( l5 l8 e5 t# 再继续训练所有的参数,学习率调小一点\
( W; z- @5 R( loptimizer = optim.Adam(params_to_update, lr = 1e-4)# ^( |; P9 }+ D, q, p- S2 r
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
3 ^/ M3 E/ v$ W% N
! g) v2 C$ [8 t0 U$ u) r* Z# 损失函数
, v! b" W, b+ |$ W1 G3 ~& M5 E' `criterion = nn.NLLLoss()% q1 r1 a$ ~  p+ |1 e, ~
19 }( T% o! K; N
2- D0 L+ \8 r# u9 X  `0 J
3% I' T8 l5 C+ h/ ^* Z/ @
44 Y: g" k9 C% ^
5
! {: v6 O) e5 X, x6  {, u% }; G8 I4 X
7, O5 Z$ {1 d1 b9 K$ K3 n. D8 V1 w
8
' {* G( {. s4 Q. P6 O8 e9% M$ Q  W0 J. r% W3 n4 I
10- F, M( M8 S' T+ Q8 u5 Q
# 加载保存的参数  H8 O  ~5 ~1 e4 U- m  u
# 并在原有的模型基础上继续训练
& m# e9 \# J7 }& E5 G6 U# 下面保存的是刚刚训练效果较好的路径
4 k( m# `; E( A1 r7 ~  [checkpoint = torch.load(filename)
4 X% Z2 c3 a/ j  Y. _# Hbest_acc = checkpoint['best_acc']  w! h8 s, R$ }$ }, s: G0 L
model_ft.load_state_dict(checkpoint['state_dict'])
, x2 x$ d9 d( o0 q) poptimizer.load_state_dict(checkpoint['optimizer'])
+ v" i0 @3 f. [1
0 E. f1 M6 C2 l/ |# z2
; v' t- P3 @& P. J- m9 F, b6 |* c3
+ M" }/ {/ L" O! g# d* J6 t4
5 A' Z* A8 W. j$ W+ c5 A5& @) \7 y5 m- [7 U
6
" d0 o' g4 Y5 K1 x: r7
; o1 c- W0 d- H开始训练
# Z, d1 n1 C' h. J* p7 z8 A' b* K9 B注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考- \4 `3 g" K4 E* L

& X8 G% \; e9 r2 n2 [" nmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
' H. V; X2 F+ ~1
& ^; t: \+ X$ D! ]Epoch 0/1
7 n' p( k4 g2 L2 [----------) T) l# T3 U2 W  O' z! @$ k6 i/ {
Time elapsed 35m 22s' M8 _) T) @! L1 f! p
train Loss: 1.7636 Acc: 0.7346. I( l3 U# `8 j2 v2 Z4 d
Time elapsed 38m 42s6 w  R+ D0 e/ G
valid Loss: 3.6377 Acc: 0.64555 R" a0 Z9 g" z4 N
Optimizer learning rate : 0.0010000
% M! s& @' g5 @+ ]' w9 r' L* A6 j8 V$ Z( w7 u# X) B$ ?0 k' T
Epoch 1/1
7 R* z7 ~& e6 O: D; k' E- Y/ E----------
. z# e7 h6 Q/ u; @3 J  hTime elapsed 82m 59s
2 U) o2 b& W  mtrain Loss: 1.7543 Acc: 0.7340, L% ^5 m6 L4 l  s$ Z! y" ^0 |
Time elapsed 86m 11s
! [* f' w. A( \0 |4 @valid Loss: 3.8275 Acc: 0.61375 ]$ m! Y/ ?( Z9 z; Y1 L
Optimizer learning rate : 0.0010000
1 K8 T( p* I2 w* G6 s* G7 n# E+ A. T# P
Training complete in 86m 11s# r  S. |+ S8 D8 q0 X0 T" [
Best val Acc: 0.645477
+ M$ Y; J0 o% E1 [; A+ N( i" m# ?/ I2 n, H
1) [8 S% N2 q1 _4 T
2) Z' K2 ?2 d4 y! E( m
3
# k) \2 L: P/ r) t0 q4# |' t( G& M- g& @* W' W) d- }
5/ x5 u5 a$ ]$ t) b
6
5 @% z* W* \; [6 y0 k4 ^9 }  r7
1 C" z- {3 j+ Z; |% X0 j8( t# `. ]1 K/ e. Y7 C! a$ y0 @, q
97 t1 R9 T# h6 U2 b+ V/ \
10
1 [) G1 J+ g' l) J; Z. t7 u8 L' I11
( I9 a2 t5 ?+ Q/ D" }: {124 h: C* ]+ H; x) n) E+ L# Z
13: {! x: X! P# u6 |
14
4 X7 K4 E5 M3 b; o151 U$ s' J' }, E! t0 m
16, H2 O7 p6 p/ c; M
17
* j) p" r& E/ P. I1 s, ]18
! ^, M/ R; f( m+ `8. 加载已经训练的模型4 K# h7 ^: P; A3 [+ ?! S* E1 U
相当于做一次简单的前向传播(逻辑推理),不用更新参数
6 D2 X% ~6 S. Y1 f4 N% g( e* M  B) _% U5 \
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)4 X2 O- u; H2 Z, C* u& H* h2 ]! C( }

9 ^$ P0 J' Q% g& Z! O& _# GPU 模式
! I% I9 S6 J! L8 O7 Z% N2 Mmodel_ft = model_ft.to(device) # 扔到GPU中
2 C8 I2 s. E+ p9 B3 V) a% t2 q! q( G5 @/ x. d2 d# g4 C+ F7 i
# 保存文件的名字1 K4 j4 n  c2 a0 O& `
filename='checkpoint.pth'- F  r- }1 \+ N' S9 T- {
" H- p6 l/ G7 A# j; |! x
# 加载模型
6 @. F$ {1 \% E8 N' w. ^checkpoint = torch.load(filename)
- e. x: b! V1 j/ P# Hbest_acc = checkpoint['best_acc']7 g) t" [" c5 n+ ]0 Z
model_ft.load_state_dict(checkpoint['state_dict'])
9 d/ I. K* \' ~& }' _( L+ H+ S* P1
  J6 S& l. a& k+ R8 f, U1 |  |6 n2$ g9 S# \! ]' f* k% L. B) u  H
36 x0 `& {4 B9 P; |' ^' i
4& Q  b& x( t$ ~6 d* v# F
5- k; a2 {9 e/ {+ A$ z' V# ]% I0 J
6
5 J( N* V# j, q* h7
7 o3 n/ z1 W+ Y/ s: U9 g& o8
! W' I( x* D0 c- W9& V. x0 _8 O2 r0 x7 e0 k2 z: Q! F
10
. G; t5 X# |% R- y1 D; B" Z11
: u. G' p0 K* T2 p2 q12! @" H) |% L( u  K
<All keys matched successfully>) o2 n1 z' c4 P7 a
1
7 }. v% r: ]. m* q5 Idef process_image(image_path):
/ q9 X* G  G# C/ }: n    # 读取测试集数据0 u: y. H$ g" L$ l/ u; h
    img = Image.open(image_path)9 B( \/ }% i! Z  Q; m  p9 f
    # Resize, thumbnail方法只能进行比例缩小,所以进行判断
, B5 I: ]  p! m$ w' y( Q0 {, {% K    # 与Resize不同. R: P7 u  T8 @
    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小" A3 U5 n" L1 a$ T
    # 而且对象调用方法会直接改变其大小,返回None
. V9 e' I9 c) g: b' o" U    if img.size[0] > img.size[1]:1 G* q7 r; A6 U! K' [
        img.thumbnail((10000, 256))
1 Y& X! O' A- ]0 J# H/ z    else:  f3 s& N/ V. |
        img.thumbnail((256, 10000)): p2 v0 q! f$ i# x+ r3 g
( j* B( Q; d8 P2 r8 [8 D  J, u9 d, K
    # crop操作, 将图像再次裁剪为 224 * 224
1 k, _0 t! }8 K: l% Q    left_margin = (img.width - 224) / 2 # 取中间的部分
0 l3 u2 m: d3 }: R' a    bottom_margin = (img.height - 224) / 2
/ a! ]/ j* V) w    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度5 s% U% F- z8 V* \! h+ \
    top_margin = bottom_margin + 224( I; K* q4 B; V; b, W! Y
( G+ J8 ?5 f; j) I
    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
+ Y7 |4 T  v( I9 p% V* V+ G
3 K  H- l" P% u7 `9 O  L( }# j0 W    # 相同预处理的方法
) r; K  A- B/ B% C5 x, n( }    # 归一化
% Q. t4 J- Z; v2 v6 x# Y    img = np.array(img) / 255
# v% H6 T; L! n+ Z8 E5 I( t    mean = np.array([0.485, 0.456, 0.406])
- V) [3 u7 f/ j* P& Y% k" _7 \5 p    std = np.array([0.229, 0.224, 0.225])
) W8 e( }7 v, E7 S$ v    img = (img - mean) / std
3 R5 F) B0 s/ @/ F2 ~( Z* B2 K6 y+ r( N
    # 注意颜色通道和位置- @; K1 m  F& @: i
    img = img.transpose((2, 0, 1))! R  \) n$ |. E' I; A

$ y: p8 h) Y) Y+ G. V% h    return img
3 ?! d3 d- _4 r0 Y
! G' N$ J1 r; @! }. T6 x* }- P6 }def imshow(image, ax = None, title = None):. c2 t+ q. I! l* `5 P
    """展示数据"""( ~, e$ z+ t  |6 e" M
    if ax is None:
3 m( a( N# f8 h& q; u* G        fig, ax = plt.subplots()
8 g/ x* ~+ X  j! R# S# A+ l. m9 j
/ @- q5 t7 I7 o/ C    # 颜色通道进行还原
5 X' B) v! d1 {. |& ^9 Q    image = np.array(image).transpose((1, 2, 0))
! j6 w* r3 K+ Y: c! N+ F* [: k
& [$ I4 U/ L5 f9 T' z    # 预处理还原: u5 k$ R; U$ l$ d' H$ \
    mean = np.array([0.485, 0.456, 0.406])
7 `5 z6 V9 l" z# D& q2 [    std = np.array([0.229, 0.224, 0.225])7 e0 ?2 j$ Q. T( t8 O0 j
    image = std * image + mean2 S" @" _7 H& A9 _5 M6 J1 @7 r
    image = np.clip(image, 0, 1)
& O2 }( i% S9 }/ g7 `5 I+ w% \0 T7 h  C8 G, F0 U5 b$ m3 m6 T$ W
    ax.imshow(image): N& @0 X! ~6 n
    ax.set_title(title)/ x' e7 n1 M. H: U, t

0 ^3 d5 p4 Y9 a  h6 l" w    return ax8 ^0 f+ g' e3 A2 V5 ~& i6 J

3 T! t& F9 n( X. Q* q4 l, P  i( F( cimage_path = r'./flower_data/valid/3/image_06621.jpg'2 j4 B  x; g; y# L# `% |) O
img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
! t- I/ z+ T; O8 E( Uimshow(img)
! H1 m# z! ?$ {) z  x" E5 G7 i, u0 b: Z; F9 T9 z" {1 y
17 U4 P: m; s1 q; Q$ c3 h+ v
2
6 i" Z5 l, ?- m38 {  x& O9 \& b" ~+ H
4
# t: t. N! J6 E% o. T, b52 E& ?: D9 A0 b, P4 {7 g
6  g+ [4 G* |; h4 m  F& V  D, y6 D- o
7
; Z8 I, P: e1 R  j, \) A( d) o3 Y9 M8
6 `$ G1 {2 h9 G8 T9 P4 G9$ R( |7 K& O/ M. M
10
8 V5 F( k) `# W( ^2 F6 b+ f11
, A, y( ~1 N8 [1 l' h! P& v12
7 k3 W, U! D* q7 U5 K: y+ F13
4 J4 v  `+ W/ P% i6 d0 w- F7 n1 s14- k, ]; `& Y- Y; q
15
9 B  Q- ?  I! Q/ V" A3 z# V3 }16
6 ^7 a6 Q4 w9 s( _* m, A6 }17
- e2 d+ X) E" O8 `/ w, d6 Q18
9 g+ `+ y6 H5 L5 [7 X& S# ?( u  _19& D" |' `3 C6 L) l0 J( Q
20& G: \7 ~7 {( u3 n
21
9 a3 @2 Y; b& x, Q3 ]; ^22
+ C5 w% J/ u( b4 Z# f23* G9 w6 H. P) B8 X4 l
24
# i& l8 m3 V. s( ~25
: B6 I" r4 x& s- f+ Y' |26
3 m* ~( e" c& q# s27: V- O/ Q$ N; p7 {9 l4 {4 b  O6 k
28
) r0 m6 M8 {. d3 Y# |1 r29- m% @% L' Q+ U* s: @2 F& q
307 u* c6 l. W- V. V- s
31
- ^, y. Y9 l% d) i2 T* A  N320 r  D3 w( [9 H9 @8 n; g
330 ^1 }. n! s" h* l
34
3 h. w1 R  v: {( ]6 U- l35. f( R8 t# c8 ^2 X0 y7 l
36" V" `6 H9 p* T
37
: \5 C3 u* }7 F& b" L38
8 \1 @2 Q4 Q& l0 i39. V$ d+ ?7 E$ ?# g' N2 d) y
405 w- G) b  d& y% `7 p6 b
410 @, I% p5 E( z! z
42
. [2 d# ^& @* f9 C' c435 q9 Y5 W: `; @2 R9 O
44
$ y  j3 e' r5 g' z6 z45
" Q" p2 l; Y# [46/ j' D; n( n6 s' e; P
47/ ^! b( }9 D" I! U7 E
48
7 U* O! c0 t* X3 L3 _! x49$ Y, A# u! ~3 n
506 `2 w9 U' K" n9 K5 B; z
51/ S% s- C; {, v: {* I( q
525 }! J) J6 k' q; b: r
53
8 @5 g) z9 J- o3 n5 H- B5 ~548 q$ H/ B* ^; \( m
<AxesSubplot:>
" f1 _$ r% A5 k: y12 g. L$ L1 s1 U" ?, Z
# V+ @8 i' C- a! V1 |8 H
上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确# {5 @9 _. }8 i0 K' K, H: _) x

  W6 U& S) j' i5 l/ w$ t1 |$ S3 Kimg.shape3 S! n) @5 V. E
1
9 R" W+ F0 [6 T' ?, m  A, |( }! Q( j$ ~(3, 224, 224)
3 o8 G0 V7 S! Y1
4 B4 `) {; `) H: l' p: I8 Z证明了通道提前了,而且大小没改变
) |5 O) u9 x6 N0 N- G0 \0 v) m7 q# s" I2 @! U
9. 推理: V: z- V: x8 Z8 s8 V& i5 ^; L
img.shape# Y4 G7 B- h: V5 l# b! v( v6 C
& D6 q6 s  t) b9 M$ W  o+ s
# 得到一个batch的测试数据
. i. u: T5 f  Y/ _+ O3 ^" X. udataiter = iter(dataloaders['valid'])
7 B( s( n! l# R2 T. cimages, labels = dataiter.next()" J' {7 t/ T0 |. n$ F6 m' U$ v/ O
2 Z1 k* R4 l# a' s  @
model_ft.eval()1 y2 r( b; \+ u
# O0 w2 U$ A3 x# N" G* ]1 A
if train_on_gpu:
0 `7 h* W; \1 h* m# n% t5 W7 C    # 前向传播跑一次会得到output6 d8 y( V; p& b/ W5 ]/ G/ o
    output = model_ft(images.cuda())0 s$ X) d$ x' Y- h4 w4 E
else:
# D* ?- q# V( [. h6 z6 a    output = model_ft(images)
* \5 k0 E) i' O2 K3 f- G2 ^) B9 K. m
  v' x( B1 e7 B9 @( {4 z" w# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
( t1 G7 Q* q4 `+ toutput.shape$ x& W3 D: \6 B% z; e

1 c8 [$ x/ f5 L% }5 V* y- ^1# @* s# M1 V, A% S' Y) [' o, N
2+ Y8 ~$ @6 A# o
3
0 U1 d; m( g2 U2 Y- D9 c49 @; j, c* n/ s* b7 Z* i
56 Z" n; O4 J8 d3 h- f$ R9 r& y
6( B% D, e' Y; m2 X/ y; B. r- ^* g
7
& ~/ a6 n- f0 b3 Y* ]; n8' j# S/ a' [6 X) A, F
9
8 L+ ]1 I6 e: O; k* w5 P  b3 M10
# ]8 j0 M; P+ u4 N11& ?! P" p2 A" g' Y8 x4 N5 N6 S% n
12
/ e& [  u& I/ D  t- |( m* R6 Q13
& `8 X0 m7 N( i  E$ _1 e5 J14
! w- L- u' j: Z$ Z8 Q# e# o15
' y  w: v6 |! C, M9 @16
9 p; C+ {+ q4 H  P6 w6 gtorch.Size([8, 102])
0 Q- |4 |! E) f9 v  w  h19 B+ v0 i; K% s+ k
9.1 计算得到最大概率
6 N1 L/ J" @0 k2 r/ L1 X# f1 l_, preds_tensor = torch.max(output, 1)
  H% Z4 x/ `' e5 B% c- A
, K, ]1 p( Z% @# j2 ]& M. Upreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
" a/ j; |2 j. W6 S# i5 l12 O3 c  p8 D4 N- \! @/ W. v- p& G
25 Z' Z% |9 R2 b$ ~) b! ^
3" r6 O" d7 x4 \( |+ M$ k6 F  Q0 X
9.2 展示预测结果
0 o9 r2 c0 L- L1 g: q( b* d. rfig = plt.figure(figsize = (20, 20))
- X8 j* }3 f8 n% {. J- C1 R# w  ccolumns = 4
9 L4 a) l& @1 N- yrows = 2( d- l/ }/ @0 q& @

, `; I- d- q4 [+ \for idx in range(columns * rows):
( Q6 J/ F% h8 v0 H, W2 ^% C    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
$ r$ o- N# w" L+ F8 a/ x0 D; U    plt.imshow(im_convert(images[idx]))
$ f2 o  V: }: ]1 N8 e    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), ) l; ]" p6 `1 _- e, K7 V4 P  b
                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))+ R4 @' H/ \& u  F
plt.show()
8 [( \; R& ^. u+ R$ f* l# 绿色的表示预测是对的,红色表示预测错了
" {4 a- h( @! I1% `/ A  ^, R! N+ j+ o5 a7 u* h
2( J$ L+ \6 K5 r8 C0 m2 i5 ^
3
  r. k- I6 o4 e) C  S  \; n4
- m* f: E8 l0 S5 n$ s5$ D: Z1 K& m4 V2 T% I0 R0 w
6
) ^& g+ }% q% R6 B. e/ h7
3 `# ~) @: {& K- k+ y- N8' S1 t# u# n7 K) u# x
9
- x/ N% s6 U& m10
* w: F9 e3 T1 q! `( X7 N- j11+ k5 z$ R+ s" v- u
, b: h+ g, F: n+ F# O2 c

: y+ ~) {5 \% O( x. _  K$ t
5 h: _7 \! H- o( r  l- t7 C- I————————————————6 F5 r0 l  V2 O7 m5 T: k
版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。& @5 T  Y1 ], [) J5 {
原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
6 f" N9 h" \9 K$ l
2 h  i6 t2 p" A- \" L6 O
( T; B3 q( m8 ^/ p) Y




欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5