数学建模社区-数学中国

标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例 [打印本页]

作者: 杨利霞    时间: 2022-9-8 10:41
标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例4 M1 e7 M: ?0 l8 M) Q
6 B/ i- P9 F8 ?
文章目录
) [8 d& W; a" G% K卷积网络实战 对花进行分类5 a0 Z( K; F6 w5 c
数据预处理部分
. i5 A. M3 b6 p* w网络模块设置
1 N, c9 x- p, h& X5 i7 M网络模型的保存与测试: |  S; a5 V$ ?0 S8 h8 `& q7 H
数据下载:
: N/ G4 R7 r+ h6 [( d* z1. 导入工具包
' S5 k4 d3 K! ~) J+ E! ]# B2. 数据预处理与操作
1 L- z8 L9 X: ?5 s2 w) P5 \3. 制作好数据源
# M6 l7 @0 E2 t4 l* F2 U1 N  _) G读取标签对应的实际名字; u, ]( d! j- p' }# J5 t
4.展示一下数据. v, W! G# X: \% @- ~  e
5. 加载models提供的模型,并直接用训练好的权重做初始化参数
, O2 d4 o* n5 w# [+ {% ]6.初始化模型架构
' N( W) A: ]5 L8 \7. 设置需要训练的参数
0 w  d7 p& V' |" {* s7. 训练与预测: w% n+ v9 y/ L; q
7.1 优化器设置
! @4 c1 f4 W  y- a! {8 p7.2 开始训练模型4 b' ]! u& W, s. z7 u6 [: u$ e
7.3 训练所有层# w/ t9 s7 K0 M7 O& u9 t
开始训练
% s! s& g$ Y; B$ u8. 加载已经训练的模型
% E: w0 S- Q8 b4 U9. 推理  P  e( ^- `& f- g( Y3 D! ?
9.1 计算得到最大概率
/ L, J) C( T3 k9.2 展示预测结果- K, l) _+ a1 ^. G1 x
写在最后
+ [5 r3 g7 U2 o# }* u% U2 a卷积网络实战 对花进行分类
& z1 t; Z, N, J6 G  Y% s+ `0 o本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
, q* B" D& E; `# F/ N0 z$ M- E! l/ O8 B. Q1 m9 u4 y# H, b
在文件夹中有102种花,我们主要要对这些花进行分类任务
* M; o1 o$ q* D7 u0 q0 `文件夹结构
! I8 y$ P! s, N8 U$ }- ]- e0 ~/ ]# `9 C  D& G3 o; V
flower_data5 ~. g6 e. H4 Y! i4 f1 M2 `

. r2 I- H6 k+ U6 Z1 Gtrain. \& j* B" S: X* Y
+ m1 m5 l+ M& m5 Z- X$ E" g; @
1(类别)8 _0 j% z0 F! }8 c& l% v/ `* V8 q
29 f/ w. e: _# v" d
xxx.png / xxx.jpg- ^# M: K' w. F+ ~6 ^0 W; l
valid# ]% \' H6 z2 S* i! o3 _
. J- o* }1 G$ j& P9 @% z0 p
主要分为以下几个大模块& K3 y& ~8 q5 y( B& g& d- |
+ A$ j+ D! X% |0 ?7 s: a# S
数据预处理部分1 }" F$ R$ O) R, d9 O. l
数据增强3 l0 b' l7 F+ p2 C) f- {
数据预处理
( J8 s% S8 @* `2 `网络模块设置6 m6 y, M* [2 X
加载预训练模型,直接调用torchVision的经典网络架构
2 x: O& R) n& ]7 q* R, E因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务  J; ]5 k" t4 F+ ~) C. J
网络模型的保存与测试
0 ]) S& S; X( ]% |% \/ q模型保存可以带有选择性
$ ^* `# y4 r. a3 l2 `# G数据下载:( q' A3 Q! q7 S+ f' g
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset$ u  q# R2 N' Y8 n# k9 f3 \

' b; i- y- V' \6 d/ W  b改一下文件名,然后将它放到同一根目录就可以了* i$ {$ [8 b; o& F: E

& Q! U' C. |# W. ~/ H下面是我的数据根目录
+ k& Y/ ^% j, }' c6 s4 T( }& v2 {7 Y1 M; n8 f  J  s0 ^

6 M+ a4 @# p3 b' X9 v( _1. 导入工具包; C& [' D: h! x
import os4 d+ D, K4 }' c. `
import matplotlib.pyplot as plt
3 J, k- O. l6 t9 s% G. _" u8 A* T8 ]# 内嵌入绘图简去show的句柄
  N. |9 X  z3 e( O4 R% A%matplotlib inline
' R) B0 H; }" `: R! J0 Y3 [9 e5 Wimport numpy as np
) r& r' N5 F. \0 eimport torch8 O, x1 m6 }1 m, U& t" I/ ?
from torch import nn
) d9 U8 s* ]. @( z' e! P, E  m
2 C9 H  ]" b( j- h6 iimport torch.optim as optim) X& p0 v- W  b4 L" e" A
import torchvision
) h& E3 z2 d  U! C7 K7 sfrom torchvision import transforms, models, datasets; o" {- l( ]7 J9 i  c' B0 D
" T6 }4 O' B- J" c
import imageio3 j9 z, \. r( @
import time9 R' x8 ?! _3 z" q3 k1 V
import warnings
  ^( z9 y1 V" ?+ x$ |: f6 Bimport random
3 ~2 G5 D( C% j/ m; u9 w3 Himport sys
; f6 ~* Y4 T7 @+ H6 Kimport copy' }$ l0 @' E+ _/ h9 Z
import json
  i" R8 W1 F6 F1 O9 u$ k& S  D3 kfrom PIL import Image
6 H$ u9 [9 P6 z% h, M! F
' b9 z2 o! C! Z& A
  c1 e0 H; m1 ~8 E. ~0 {1$ u/ _0 b/ D- _5 P4 ]- x+ a0 |) H, L
28 l3 F; l0 i& j3 T
3' \1 f, N! R2 e& y' ]
4* n* j: c, I% [2 M! Q8 Z
5) L$ B9 Y% D3 l) R2 y6 D% \" C
6
) I% |9 i3 Y7 Q0 s& F+ ]' U7. u/ d+ u' B: T
8, v3 _7 W, B8 `$ q0 c0 P5 _
92 ?6 I1 O' E% w0 ~
10
" v7 L, S; {- d  w; H2 `11
7 Z) J. U/ n: \12
0 T5 A, {9 p' E9 ]13
5 L6 t- N' Y6 e! f143 u! }; Y. w: Y: Q7 n. q
15
5 Z& J# E( Q/ T16
8 P! c* I, L4 C) Z/ _179 J4 b% J& M: k0 _* ^  {
18! e3 Z1 V& W* v
19
: G/ u- p# _3 [1 k20
0 M' F# ]- |9 O2 u# T  g21
* |: ]3 I+ ?* N: F* h2. 数据预处理与操作/ s- }' s) N1 M" s, C) q
#路径设置
$ A  w- o/ P/ o1 N) m. |data_dir = './flower_data/' # 当前文件夹下的flowerdata目录
/ g+ ^7 i" m7 q3 Z  [7 U4 C) Z4 strain_dir = data_dir + '/train': X4 `7 z! c/ b$ c% t2 F
valid_dir = data_dir + '/valid'
& G; ~! E; w9 N" C, j1
+ W* U6 S$ F) z. q& y21 ]! @$ D5 H( w3 a# W3 k% o+ Q
3
8 l( R+ D+ w) c4 l4
+ {$ @/ @7 k; N- P. Rpython目录点杠的组合与区别
; \  j7 J, H8 W: {& i注: 里面注明了点杠和斜杠的操作
$ G( N, N7 Y) M1 h( ~4 v9 k& E* ~9 ]' s( {# g
3. 制作好数据源
9 O. U9 z4 F( B" b% Qdata_transforms中制定了所有图像预处理的操作' b) ?2 A/ E- X
ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
" ?9 J' o2 @1 w  T9 N. q. Idata_transforms = {
4 ~$ v; G4 @- x! A( C9 V    # 分成两部分,一部分是训练
% @0 ~; ~. S5 {3 `; X    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间0 Y# _* t: b/ `
                                 transforms.CenterCrop(224), # 从中心处开始裁剪
, T! b5 E4 C$ ]4 p9 T                                 # 以某个随机的概率决定是否翻转 55开
$ B% Z# B- V1 r, y+ f5 C: N/ \  q                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转& }8 d+ \( n  u2 r. w
                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转. ]) Q' {7 I% O- g1 g2 h
                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
2 y( M& K" E) w5 V6 z                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),8 Y9 \& p% A" z3 t# M
                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
2 P% a5 ~. t3 b1 @7 f                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
, ?- E0 A. Z6 y% \$ W5 E                                 transforms.ToTensor(),/ T5 T; ]+ r- O, m( Q
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差% n+ {/ e$ I& q7 R, I* L
                                ]),
2 W& M7 ~6 p$ D/ [! W8 W* @4 `0 T    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
( i2 m$ L6 P) ^0 }8 R    'valid': transforms.Compose([transforms.Resize(256),# T! B; A, }+ D
                                 transforms.CenterCrop(224),( D$ Z7 E5 m# s: d; A4 G/ t8 n8 O
                                 transforms.ToTensor(),
+ R/ f, y3 i0 v+ j- T4 e+ ?                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同9 r+ Z% _: b' i  o/ E( K
                                ]),+ Y0 s% O& u) w! _
}" d! c+ k3 P- ]/ D3 I4 E
7 _% C5 H; O5 k1 v: S. z- T# ~. [& Y
19 X& `. F8 k" h# F8 E4 @: d
2
9 I  {. U3 s, Z/ g7 R9 A3- A3 C' j, q" h: t9 [  v
4
0 Z* o( L( i) p5 R  ~51 g, m* K9 a- n% ^* _" i5 K/ t0 V
67 c" K/ \; K1 c) x& m
79 D) Z5 ~/ g8 w
8
7 p9 t# M0 |1 X* m( F' H" }6 k, x97 v+ f5 \6 {4 w$ g
10
  P* n6 A1 D6 t# Y  R5 I/ G0 S11
7 n+ R. L+ w2 f" t12
8 l, N8 a9 v- }13
" {8 B1 j' M1 K4 X6 z7 W* @14- L0 ~1 R+ I. \9 t. \- p$ j9 P
15% l( ?- M9 C- K: i  I9 ?( J# u
166 q$ s3 S0 s- A9 `% u# a; f( T
172 C( s, n" o( A# H8 E
181 ~& N. `' `6 Y+ @" ~0 h6 X2 S
19
) t7 W( w( Y6 A- M  z, t3 s. O202 j0 [; W# ?7 H8 W
21
5 s/ S9 S. v$ r& R4 Mbatch_size = 82 g! l" x- w; D" O3 S' n4 e; h# W( k
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}( T# `& y; E* S
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
- Z% c; n& Z# h. b3 i! adataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} . r: Y  o' A4 D& W
class_names = image_datasets['train'].classes
  l3 F- u& S! o* ~2 I3 Q6 [5 h& g* x& b+ q
#查看数据集合
4 X+ _/ t- N3 e! {6 ~4 f* O: B1 I0 himage_datasets
- _3 A( S& s9 T# k' e- X
$ m5 e+ @2 S& a  z7 P& x' F. f17 J5 A2 K( W( G; {$ L1 _
21 V" \) ]2 J& l) T
35 f0 ~! \; r: c3 y1 u5 F- d. ?$ [
4
/ w4 i# e, ]0 K; b0 U+ k& R53 o+ P; _+ v. U" G: u5 _
6! j9 B  k1 O- f1 O, ~* R
7
! ?0 I- r3 m' x" |& z- |% a0 ^8; E& t7 h' S8 y0 E+ @. h0 l" e
9% o- Q# O% s0 z
{'train': Dataset ImageFolder
, `9 Z) b$ y) E' M5 F9 H     Number of datapoints: 6552) B! ^+ Q. T- p) H! V0 J
     Root location: ./flower_data/train0 b4 F- r" {' M& {, S
     StandardTransform' Y2 Y0 `- y8 g" c8 g0 x& E
Transform: Compose(
' f6 i6 ?" N. l4 t5 O, t1 ]0 z7 b4 v# J                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0). O8 M$ c% v5 z8 H8 K# [% `
                CenterCrop(size=(224, 224))( i. A5 x8 t2 a' Y
                RandomHorizontalFlip(p=0.5)
2 t6 r; n' @3 Y                RandomVerticalFlip(p=0.5)
1 N$ S( y$ ^' h4 z6 I! T) [' X+ ^. ]                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])/ n6 p! o% A$ l9 S
                RandomGrayscale(p=0.025)
& E/ A) F% o2 z( {9 W                ToTensor()% @% _* a! U# ?
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])9 m' f4 S6 H6 J* P
            ),0 X! u) r7 N& s: ]0 b* o( }- i
'valid': Dataset ImageFolder6 z% N1 e% q2 R: ]2 i
     Number of datapoints: 818- s( b, e& ?6 r3 w( u. Y: e
     Root location: ./flower_data/valid
5 A8 K. F  H6 r" U0 k, t. V     StandardTransform
  w& `* N9 ]* I) p+ x Transform: Compose(
- `4 s- G- l; d                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
+ e% F- c; i; w8 D& a- B                CenterCrop(size=(224, 224))
0 F0 K" e' \; B5 a0 |* D) Z% l                ToTensor()
) }7 U4 y8 s! ]                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])6 O5 u9 D7 e6 n/ I
            )}) m4 v( A& F2 ~
9 q. d( I4 K. y, R) G9 R5 L. j
1
1 h  z5 g( e+ i' y: x2% M$ P  S8 c3 P1 M! s$ I
3
% T% s7 t0 O! y6 X4
7 y% ^, ^. d# |0 T% f! b8 i& a! d  H9 s5
/ R; H/ A# ]3 o' b* L" k+ B' A" d6 r6
) o. }4 q9 {1 `" A& L7
1 U7 j" o7 o* L* K# N/ h) }87 Y& n$ W7 b+ ~" t* c" O5 x) c+ u
9
! f5 h* C0 J4 p! J, q! ]10: d1 ~. D* W4 S1 D& [4 P
117 M8 }7 o# D0 ?0 I  B6 X
12
. p5 ?! X( y" S. u13- A4 T! w: W" Y1 F
148 k' M/ [- G" W
15
, E) t# k3 V, V& G' r: H5 q16/ Z4 W+ }9 h$ r+ A2 [
17
; c: A' I/ S/ W/ }( U; p: j$ D18$ H8 T& b/ s% |" X
19
2 J9 r9 D  s* K9 A) G) V20
/ @9 I! T0 b4 N# l$ g21& a5 Q) h: O6 F& {
22) D# u: c; U. _1 g% N+ z
236 |/ n. w! H3 ]
24
/ t5 r. \8 [& U% V5 ~6 }# 验证一下数据是否已经被处理完毕5 h! O+ `: z* @+ a: m* A9 e. G; k
dataloaders2 M4 H& A7 a7 L
10 G6 O" ]' f$ T
2  k0 L( n! {, i" i; P  P9 j
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,: L7 L$ w6 H9 T% u- k3 ^2 z
'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}8 ~# C$ N: z* N# _6 q4 T# U
1* S* B$ A% m  I& S* i# ^) h4 L
2
' @- B0 ^' ]& G+ ?+ C$ S. X* ~. sdataset_sizes
' A9 _, @( N3 |' i* v: j- j5 P1
9 P5 |7 |4 W( j' Z0 w{'train': 6552, 'valid': 818}2 `: K9 F) {  h; S
1
, k+ A% I7 x, E+ N! q读取标签对应的实际名字) C7 R  {- @2 ^6 ?1 y+ y
使用同一目录下的json文件,反向映射出花对应的名字9 f4 b+ Z( T: Z6 B

4 G+ K4 Q7 e7 h+ M% ?. p2 twith open('./flower_data/cat_to_name.json', 'r') as f:" P+ ]* L- n8 b
    cat_to_name = json.load(f)/ _% |- H; E1 w$ g
16 X: _* ^, P# z$ O
2
( y$ N1 [9 p( U- Q. w* K5 jcat_to_name
9 G5 }) }5 a* t, r: p1! Q& ~* M+ B# ]4 b
{'21': 'fire lily',
- k! r5 o' q8 t '3': 'canterbury bells',
% M4 h2 o/ l4 r5 P0 s$ s& W9 k1 O7 X/ n '45': 'bolero deep blue',' a$ \, p. b# r
'1': 'pink primrose',/ e5 t$ C8 k' G
'34': 'mexican aster',
5 ~5 ^, P( @9 `' S3 l7 n '27': 'prince of wales feathers',
" h- E& }& w, i# z' Z '7': 'moon orchid',
. a/ @9 J8 o. {6 s6 T" f: B '16': 'globe-flower',
, h1 F+ m! z, R' x& Y4 @ '25': 'grape hyacinth',  T3 P% Y% H$ t4 Y
'26': 'corn poppy',
7 ^6 A3 z7 m. M, M- N '79': 'toad lily',
6 l8 f8 B- h  g '39': 'siam tulip',* |2 m+ k; r5 b: o2 s1 o
'24': 'red ginger',' ]5 \% [; T1 n3 e& U
'67': 'spring crocus',
3 O' |4 B7 x* H; H '35': 'alpine sea holly',
9 c4 A8 n. y& k* D! c& A2 D '32': 'garden phlox',2 ~9 f* z+ K, H) R* T
'10': 'globe thistle',( C& L/ I0 Q* t2 E/ @( |/ {
'6': 'tiger lily',
' C) f/ a  {; [9 b  l& d. ]& y '93': 'ball moss',
0 X% I. ~( S4 F2 x" @* ? '33': 'love in the mist',
9 @. H# X6 L( s7 Y '9': 'monkshood',' Q% r/ \4 O+ W" P, N
'102': 'blackberry lily',. I) e) I0 ^$ A1 s
'14': 'spear thistle',
+ W8 U0 l8 c! h" { '19': 'balloon flower',
" F- \: N: |" `8 [ '100': 'blanket flower',
! A. s- p+ X3 E '13': 'king protea',
; i8 Y& v% e# D9 Q" O0 f9 r '49': 'oxeye daisy',1 z% |1 S& x5 G$ S
'15': 'yellow iris',) J4 H: i! l3 I* Y/ q* j
'61': 'cautleya spicata',! \% v' @+ O! P4 w# Z
'31': 'carnation',
6 X& R2 E1 S5 M: E '64': 'silverbush'," [# k- p4 }7 Z3 H
'68': 'bearded iris',
0 d7 e% {4 Y' g '63': 'black-eyed susan',
/ ]1 u; `# A' m '69': 'windflower',* g, _4 N% N/ P( v
'62': 'japanese anemone',
# R8 T8 s" `' p- C6 X4 x, x: |0 D2 X '20': 'giant white arum lily',. `+ y# Y9 {& ]; j$ T! G8 s
'38': 'great masterwort',
' I# A$ a+ _$ ]. @% O '4': 'sweet pea',
; J4 p# d0 h3 k% [ '86': 'tree mallow',
( ?3 A; U5 ?9 Q3 i '101': 'trumpet creeper',
& J* P' ], B1 }1 D5 T4 W '42': 'daffodil',8 \! B$ u! X5 r5 y% G9 f. A5 L
'22': 'pincushion flower',1 j7 p: j! ]% r9 o
'2': 'hard-leaved pocket orchid',
) Q- \" K( B1 F% L. r% Z  s '54': 'sunflower',5 q8 g: n5 J7 A. l) U
'66': 'osteospermum',' @, f3 M3 |( c6 P
'70': 'tree poppy',
5 |7 ?0 R' ]8 h; ?$ A: C% Y '85': 'desert-rose',
7 b* c( p' m. s '99': 'bromelia',& S. K& M- i4 U5 A
'87': 'magnolia',1 Z( Y9 W: W6 E% x  m- I
'5': 'english marigold',
, x' v! z: b! | '92': 'bee balm',
, P: Q; m6 _5 k: R3 _ '28': 'stemless gentian',
" a) x: `/ x8 z, L '97': 'mallow',
: H$ O5 I8 X3 P" v" D$ m '57': 'gaura',
6 `4 a' \& v6 u  L4 t+ ], s '40': 'lenten rose',
8 Y0 u* J+ b. s( k8 O: `- ^1 I! V; n4 t '47': 'marigold',- J% v; ]" ~" g
'59': 'orange dahlia',
7 T5 w& `/ S4 ~/ x  X6 K$ I '48': 'buttercup',' t" ^% x3 C7 Z% w+ R  O
'55': 'pelargonium',
5 G6 H  R, r0 j# C% G  r '36': 'ruby-lipped cattleya',
- N0 X' `2 r  d6 H" k* O# b0 b' } '91': 'hippeastrum',; I$ ^6 k* O( k* m$ K$ c/ A7 D1 k5 W
'29': 'artichoke',& D5 \5 f" E. Z7 c! _0 F: J
'71': 'gazania',( i7 Z8 g7 V% L/ H' S& @5 \
'90': 'canna lily',# M% W2 z# H, G* d
'18': 'peruvian lily',# m0 P" c+ `8 ~/ }  o
'98': 'mexican petunia',
. G* M/ @8 Q( W1 a5 p4 E1 |- s '8': 'bird of paradise',: n* Y- g) D2 _1 b+ v
'30': 'sweet william',
( l/ m: U3 P, X '17': 'purple coneflower',0 y( n( N, d7 N# a8 B' N: r
'52': 'wild pansy',
2 B) o* `: a9 H! R' }+ b '84': 'columbine',
. G+ z8 z; Z  y  r0 m: B) u6 c '12': "colt's foot",* B( d( d0 _! W
'11': 'snapdragon',1 H* K4 m3 ^) L  z. n5 r' I) S, R
'96': 'camellia',3 s; Z$ `2 R' o
'23': 'fritillary',
/ \1 d7 ~' F6 d8 t/ `+ d '50': 'common dandelion',  m0 m8 p* e! s2 \8 A
'44': 'poinsettia',- V$ I4 k/ Y7 d+ T: `
'53': 'primula',
  g+ m' c2 y* R& p6 i) |, o '72': 'azalea',4 d+ q7 K5 S2 E* E3 Y; R
'65': 'californian poppy',- k9 O: ]3 V: w% g1 e. s
'80': 'anthurium',' i8 {9 l! o# R6 ~/ F  V0 R
'76': 'morning glory',
' X5 V0 D1 X0 V4 \ '37': 'cape flower',
5 N6 Q' }& P) ?  _, B; T/ F '56': 'bishop of llandaff',: L. T  J$ Z' @6 I* v
'60': 'pink-yellow dahlia',9 s0 i9 t- R( L; W7 w6 x, p
'82': 'clematis',6 ]0 S8 w' h& A& p, i
'58': 'geranium',
, M, F) {* `+ H% o '75': 'thorn apple',
5 Y! E, a- u' Z0 ]8 A- a. \ '41': 'barbeton daisy',
; H8 t/ g# g0 z0 W" H) ? '95': 'bougainvillea',+ ?4 ~3 R+ t, |, c
'43': 'sword lily',% Q' |3 Y0 O! W* a8 N( Q& l
'83': 'hibiscus',
/ H6 A/ A( o8 T; r" i( f '78': 'lotus lotus',7 m+ X4 t# B( f9 Y- S8 D
'88': 'cyclamen',/ p, q: k8 P1 U8 c3 o; l
'94': 'foxglove',
0 e9 x9 K" i$ d: k) p1 H '81': 'frangipani',! {8 D, d# V) v9 s, x
'74': 'rose',
8 f  B- r$ }' Y/ K '89': 'watercress',
  a* e$ D  u% M '73': 'water lily',- z, ], c6 V3 V' r" _0 Y" x2 L' B
'46': 'wallflower',
+ \" u6 J+ @% N' w '77': 'passion flower',
& T5 N8 u7 V& y '51': 'petunia'}/ q. f, J2 |/ l( [
( D/ O9 `6 f6 y/ u! u
1
7 c) y8 l% _- r' V0 d* Y* V9 I2
4 {: s9 Q8 H' q2 ^6 [4 O8 B" u/ p) |30 |) `5 w/ t( n. {  \9 L$ S
4* E4 n$ p+ i2 \3 s  @
5! |" l% k! g. a3 H9 A
6. ^) }8 ^, ]$ e$ s6 X  n5 ^- _4 ^
7! x/ E8 n2 ~; X* s: u4 x+ u
8
: ^4 h6 U$ o' D$ Q, _9$ n) \" w1 Q+ _  D
10
+ w% K: ]; d$ |/ y) {# y11
8 l8 O1 S+ E6 K6 `. v5 W7 W1 Q) n129 y! E8 V5 W( L3 o; \) l
13
' {2 Q1 w( A5 ?% S/ A; }14
: i9 N. Y; y; t, p4 O15
( {) K: v+ P% A7 X$ W, k, B16
$ {4 U, o4 `1 G& J5 i) D17$ J$ G. ^1 Z$ T; a+ t& ~, ~
182 c* X8 I; w9 O4 a- J0 p4 ^
19/ C# w8 D: m3 a: Z: b3 ]+ `
20$ M9 n0 p- Z) @- c
21
7 n2 |. e9 _, D, K1 w# U228 F: }/ F. Z1 Y. }3 m% ]9 a
23' m- u; s$ L; u2 x
24
2 |# c" M$ x/ T0 ?; {9 w: z. C3 T25
+ M3 E) y" w: c: @7 {& I% C26
3 j' D, c8 m/ U4 d* S3 m5 x27
1 o2 t- x1 K5 `6 Y28" E! ^( L# o* ]2 h! w
29
4 R* Q6 \- x3 d* G9 }. n8 ^: a30( L0 Y6 T. G' N- j" y% K
31; i7 i( O% q1 V
32
6 j; Y0 o/ i) m# t' ~+ `! U33
( q1 ]; J9 s6 k- B* Y! V0 w5 ^' S34
; F2 M- O* \2 X7 T6 ~+ h( F0 u35- |; h. z) j* Z; o7 E
36
& x) S; J' R; a) s37, c2 z3 Y4 S2 t2 v6 F" S
38
# \/ Q3 ~& ?8 @1 Q0 S) Q39
# e/ f+ `/ W) E( H. N! Z, y40
# O8 T6 }% ^% t6 i# F; X$ C416 N" y' b2 c# p* h4 U/ G
42  J0 C% H9 S& H5 |  c) n: {
43
( [7 }: Y7 f1 U: V) G- I6 |. B44
, L1 V$ n) I. n45
- [) k% s  ~8 n, d8 ^: J46
% R' T5 W/ ~7 _# T0 k476 I8 }% a9 c  V' Z+ |; U2 |6 n
48
6 u0 J# R6 X- Z  q! k49
0 X- l6 }7 M9 Y% Y* y503 X7 E( j& w! b
51
+ ~' v( a. V9 ]4 q, }4 R+ o52
' W# l. e7 E7 ~2 a" s' r" @53$ S% x6 s1 c" t+ x8 C6 W2 g
54, r, x  w7 E% {7 `$ @! I
556 D  L5 N( i' o
56
7 ^; X+ h& d3 x5 h0 @- n57/ \5 y  V0 D# d* L  h
58* O. g  Z, b' h: O: z" x. t: S
59! b8 d1 f3 R$ C
604 F' a3 V0 A$ C
61/ R1 o: C# H+ o6 Y- A$ W0 ^  d
62! r& T- a0 Q3 Y3 b' a- F
63
9 p# d8 W. ?7 P  P6 x64
4 K% V4 Z& S8 ]7 L9 y& ?653 F& F' i$ M. N' K
66
: E7 h2 [  [& d1 f, Y675 D8 r/ [' }8 f; ~- A
68
" o5 H( w' O% d0 a$ X69
7 z8 ?( S3 z3 J# Q- z9 e70
" D3 f- i! \) ^3 a2 T1 [) G' u71
6 t+ }* x) w  ^) F: H72
5 R3 s, Y2 `) O2 ]- ^73
3 n9 ^. U0 B0 @) X/ J, R) f74+ r( T0 P; a  L2 z  v5 Q
759 v8 P) B: \- G1 b  x2 e8 z" w
76
4 l8 z; \8 s+ x7 }& k- s77
1 c$ U! x* v* g782 M! a) N  [$ g
79
  P+ Z! _$ z. [, M! M80
+ U/ [) f5 L; J, B- }$ N  I81# R8 H% h$ S  V, S" D# t6 v- Q
82
! j: n  v$ A4 n9 Y/ _& J5 `# t6 R83. w' V0 ^& s# I, D
84! P: u# f! i7 ~* f( M' C: p
85/ j. b$ f6 l: \; t+ R1 j) T$ q4 h8 r
86
9 B8 P8 L: i0 j2 B- `% h6 ?87
2 U; r) V7 ]/ b9 k! j. m3 V88
9 J& _$ g1 p0 L7 S8 k89" i; I6 {( L4 Y1 g' t: R2 b' k
90( x) Z3 b2 s: @: S5 R, U8 }
918 d- I+ P2 t# {" T" f
928 }: k/ z% W" K! p- C
931 C% M6 d, y  D7 I* s) q
94- ~4 X. t2 K, x  T/ S) b
95
; z1 x1 V, D( `" ]96
6 h/ d! N% D1 ?! R9 N) l# l/ B97
2 R1 _) B, M. b98
) C$ j) A' ?9 p; ^; B: `. o99+ p, U3 x  _3 E5 q) v7 t! g- o4 k
100
+ M5 u  E! G8 r* I101
: k: P1 o. B9 r3 Z" l102
4 P0 f# ~! ]2 x: a' s9 I4.展示一下数据: q) D$ g8 B; {7 R" M
def im_convert(tensor):" ~  H/ H; [; x4 W
    """数据展示"""
7 O. J! L  u: L    image = tensor.to("cpu").clone().detach()
# m+ p: ]; ^8 U$ Y    image = image.numpy().squeeze()
. n' H0 l5 t$ x7 [" L    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图! V- v+ K7 r' |8 w
    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)8 @2 S% p' I2 K3 `0 J
    image = image.transpose(1, 2, 0)! K8 z2 S, R4 Q! `
    # 反正则化(反标准化)
# L# ^. [  ]' b! V# Z) i    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)): `; Q/ Y5 g4 c; _+ V: Q: A" x# N2 Q1 f8 c
) u' V* d+ p- A! ~
    # 将图像中小于0 的都换成0,大于的都变成1
* T+ U4 b* M& I, P6 ]# b4 [5 `    image = image.clip(0, 1)
# j4 B3 e# z0 c, A, ?* C: j0 m" O; E4 g& @' F6 b
    return image# l( u4 b9 O5 Q- i$ x
1
. X; b1 K/ O( c8 X+ K2 @* ]3 Y/ t; ]25 O3 e8 H* G" F+ p, d
3
+ p2 \$ }% q+ R4. }. ~' Y2 R& \) N
5
& r4 ^, n; i) R% s7 {5 L6
4 o' X+ O1 M9 R6 ?8 {8 ?7
) O4 P! [# m5 X9 B8! o8 j. X, V8 V8 M7 f( ]
9) r7 Y! @  Z, l8 M7 r0 N4 @
109 s, y; N; d1 Z2 v. i* e! h1 H
11& {  s( @$ r, W! _
12+ x4 K& A+ L$ G+ U
137 \  A3 t" p  l; Z1 V7 W
145 @; w. k/ w- @1 g
# 使用上面定义好的类进行画图
) ^4 ]/ K6 b, k0 x6 i  d6 ofig = plt.figure(figsize = (20, 12))" u; d1 E7 M/ o5 `
columns = 4
4 A2 ?6 e( S+ Q  r! M2 `; a$ Qrows = 2; {' G. k! X0 e9 R' D% U8 @
; N1 I  V8 f5 G/ J$ U" P
# iter迭代器( g9 i: ?( X- S
# 随便找一个Batch数据进行展示
6 [) I0 a, a+ B; M7 P5 Gdataiter = iter(dataloaders['valid'])( R2 E& o* ~- w+ ]8 q( H/ D
inputs, classes = dataiter.next()& M; d& h. w, e3 H0 a: C

5 _$ z1 O  d( o7 R) u" cfor idx in range(columns * rows):
/ C: Y: {. J, \% e$ H4 j" z    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
* z# b8 G. E" M/ ]" \    # 利用json文件将其对应花的类型打印在图片中
, V6 m( T- }- R) h    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
3 T; y8 A# O' f0 T) Q# L! o    plt.imshow(im_convert(inputs[idx]))
2 i# G/ |" [; S! W3 x# Mplt.show()
$ ]' l  U2 \  k/ e
1 C  M! Q: o0 @1
4 E, ]7 x1 A$ g" V. T5 h2  M  {" P8 ~$ k3 p
3. {8 z, `. m, v) \1 ?
4  T2 e' ]5 p0 q4 ]1 m
5
& w7 X. u8 k* r+ g9 y- p6
3 n& n+ _; r+ ~- U, i' M$ I, g7: d" z3 k- L5 }; G
8
  Z+ O) W# G6 ~7 l+ T$ q" F0 V9
( J# O% Z$ I: P' _! P7 I10
2 V# x' O2 l# Y113 @1 M. D/ b( H- O
12* R9 X5 O8 S+ M
13
. v# L$ m9 n1 X$ n7 ^) ^* M14
0 j9 X2 B% M- D5 b+ y1 E8 F* S157 s3 u2 K# l" M
16- h& f2 \. N  E$ l: m- O* Y: f

3 D( P) w6 W! `/ Q, X6 {: S. O) r/ \5 p. E& y
5. 加载models提供的模型,并直接用训练好的权重做初始化参数
/ l" T, U4 D2 R5 h! ~' z0 W. Hmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
6 ^8 C! ~1 o0 u5 Z+ r# 主要的图像识别用resnet来做
; ^" A/ ~3 B1 m7 s9 c6 p( J+ _$ u# 是否用人家训练好的特征1 M1 |. h! S. D& k, Q& H' A" i8 p
feature_extract = True$ a: E, m7 Z8 P2 G8 K' y. z
1# g2 [3 k, C4 o
2- h" w0 l* l" s7 h
38 A- N3 G6 g6 F1 E
47 C5 K( q# e) v7 x* }
# 是否用GPU进行训练. }" w9 R* i  ]% d& `. }
train_on_gpu = torch.cuda.is_available()
: l+ S$ c" ]) p1 s; N7 o, V3 c' ~
  f4 S: X- @/ n2 R+ \* `9 @, Xif not train_on_gpu:
4 [, ^% T0 l+ d5 A6 `    print('CUDA is not available.   Training on CPU ...')
0 s2 S5 d. q$ M8 e( J& {else:
. |. X2 w6 `" Q2 C+ q1 k    print('CUDA is available! Training on GPU ...')
8 b: I; a6 W" t$ i) i: ^- F9 Y! b) \
9 M! W( O& l9 o: Ydevice = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
# b7 n; ]+ V( P1
: ?" w7 K# Q$ z, F0 l2
( g9 V+ y0 v- w2 @6 W3
* J; E- r; n  V; \- N! S4
" J1 e, Q7 E: a6 u: ^/ T5
7 W& \8 \( q* d. ?5 V* e% Z6/ {5 C; x$ O: q3 k' z
7* m: P* [7 l8 |7 q
8
! w& o( T( I0 _8 Y5 O' u  I' a97 D' p4 _  B4 s2 l2 d
CUDA is not available.   Training on CPU ...* t/ `( \. F* S+ `4 O5 h/ i
1
+ b4 G1 n9 o5 {: j$ {/ D  o# 将一些层定义为false,使其不自动更新
9 S2 ^2 a2 T' D; adef set_parameter_requires_grad(model, feature_extracting):
2 Q' |% h9 }( J3 K    if feature_extracting:
7 s  n+ a, v# z! \        for param in model.parameters():
7 m& R' Y/ t* \5 G            param.requires_grad = False
$ q& ^' x! }7 [) ]+ k1
# T. L1 g- G1 i2
. Y5 t, r: X6 G/ v* l3" L. u9 m2 X+ O( x1 v% `0 T
4
% ?, U/ q. ~6 t8 t: H# o56 U+ ]" f9 p. g' e: |
# 打印模型架构告知是怎么一步一步去完成的
7 n, K( y+ }" b& C& v# 主要是为我们提取特征的
. m1 T/ d9 E3 Z! T8 h1 b$ d7 B0 U$ x" N9 j
model_ft = models.resnet152()5 k/ K8 y. Y- s5 ^
model_ft
2 M/ {# a* ?. k% Y1
$ q: r2 ]% l7 [( S2
& k& A1 D% |! X# \# R5 P" w; d& B$ y3
# i! a+ a! a7 Y! f) x7 b4: }* Q) d. E& _! T& g! P6 b' r
52 z! H' B- s/ M# K% {
ResNet(% u: n$ l2 x4 P0 ^! {
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
9 r1 }6 u) v) j" P  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
" e. g5 [# f+ J& o  (relu): ReLU(inplace=True)2 Y- E! L" _6 C$ S1 [9 }# h2 r7 G
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)% M. E: u% y" s+ V3 ~: ?
  (layer1): Sequential(! e4 t( ~' P1 B. K4 V" p
    (0): Bottleneck(
8 f& [, {+ X$ f8 R      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False); u$ F( |8 K" H! J. d- g3 W! L+ \
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): Y$ o0 q9 P9 g1 \1 p! B
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)$ i( ?# o; ]4 d  o: Y
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
* n+ S* ~3 k+ w7 M1 s, }      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)3 q1 q9 ^9 x- Q/ F
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
; F+ A/ |5 f4 o; ]' P      (relu): ReLU(inplace=True)
+ e/ ]) x2 s$ y, [5 r* U; q      (downsample): Sequential($ R$ C6 H* u, |% ~3 X$ u* K
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)0 R6 f3 o  Y9 p1 \
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
% j) @+ P7 h5 E( Z. X! t      )
! Y$ f7 @) N" p/ y# Y! _4 O    )
& A( N4 o" Z3 ]中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。0 `& z$ j7 u6 f2 t' I; ?- d
    (2): Bottleneck(, ?$ R: ^; q9 w
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)7 m" L- r2 Q& l
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  g. F" h. J9 }: q& l0 u      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)" I3 [" V8 d/ ~1 i8 F5 f) ^$ }# k& l
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
2 j; Q4 s- b- k: d: ?      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False); O1 ^$ u4 z( L$ Y& }& l
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
0 B  n$ y/ p: H4 }( b, v5 ]- m: p      (relu): ReLU(inplace=True)
) x. E- O* U' {% Z/ O+ a    )
8 c5 n3 }& G5 v- B. a9 |0 ^5 k  )
$ K; g( r0 W) m  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))7 N. p0 H8 Y/ Y9 b
  (fc): Linear(in_features=2048, out_features=1000, bias=True)  B2 g/ h8 B8 x2 `8 e1 {7 y* J
)
9 _7 c  N# @7 }3 Q  g2 |3 l5 W4 J* C4 S2 T2 z1 b6 I
1
3 J0 n, y' w, \+ D2! K1 I: Y/ a# @& h
3+ c( \& N7 n( u3 |1 W" M- ]; J/ z
4
5 \& I4 X9 W5 ^6 }/ v8 r5. x5 o( j: c& y3 m1 X0 E4 w
6
+ Q( F1 U% T5 c5 ]7 s' ?4 p- J7; K) h' F+ ]+ h5 s2 Y# f! x
8) h7 R" z% Y+ W' [  m+ y
9
* d. O+ m' O( l8 V) i/ [1 a; o107 r" o5 |( o  q. E/ D: d' T! ^: U! M
11
9 c5 ]9 f. i2 E& J: g+ K1 t12
' a& R7 K6 U& h5 x2 j% o13
+ ~! H4 l4 T6 b1 g; u9 {14/ K4 `1 }7 B5 p) u
15
" a, |9 O; Z( R" f164 D1 w# Q  X2 E
17/ d# e9 K0 Q% w" k9 A, D
18
9 _8 |7 h! A6 O9 K3 @4 Y19! f- \+ ], H- U5 M! l2 ]  U* w
20
' y% S. w2 o8 U6 t( I& N* a4 c' X21
. H# q6 g% N( C4 z  M7 E  y/ h22
8 B7 U8 a. ~' V( H3 t) S& M" H23* u$ O! I- p3 L6 h2 Z4 m
24- r, R& a! b2 g# ]1 |, Y! }  [* }
25
3 D: E* Z  q; p" Y  ?) Y26
: a6 t5 E' ~( O. {9 y27
% w( [9 W- f- q8 _28
+ k' i; T1 x8 U, i. y, c294 X; O( J+ y6 ?- @; o% @* a
301 r+ b: S- w: z7 q+ ^" l6 D
31
" ~+ P0 P% V+ h, s32# L& [3 `! Q+ D, u6 a! f( \! C
33
9 ]* Z* W1 Q5 z9 J. Y$ j: ]最后是1000分类,2048输入,分为1000个分类
9 h7 V" ]0 [% s6 g7 v而我们需要将我们的任务进行调整,将1000分类改为102输出
9 q8 ~& T# F$ i' i. l
- a/ D2 q# f. V$ T6.初始化模型架构0 A; [. W/ K+ |' T
步骤如下:
! B5 J. t' D  j+ t/ Q) }
2 e- ~  k- F6 f- Y3 Z$ Z5 F- ^将训练好的模型拿过来,并pre_train = True 得到他人的权重参数
4 k& ^) a2 p8 |. W3 T/ F; v+ g) B可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
/ C* }; {( v! i$ N# N2 h无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数+ m: h) \' w5 k) }+ A
官方文档链接
. |" d% b- U4 o9 v% k$ _https://pytorch.org/vision/stable/models.html
; A0 {# w+ l5 R* {( k( V8 j# W- J. D/ O8 X' a
# 将他人的模型加载进来
9 \" |( _: f) p9 o, f- Ddef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):$ r* p9 p8 L  P% C
    # 选择适合的模型,不同的模型初始化参数不同
7 t. _* O! a$ d    model_ft = None
1 n/ n( q; T! B6 Q: ?! {4 I    input_size = 0- o/ G: g: b# w- w
' `/ V' y; a; i0 C& B$ [+ Z* ?
    if model_name == "resnet":
5 y; V8 U8 n2 Y" K  C$ E        """
: i* H) Y4 t; }) s& A        Resnet1529 B/ m* A+ |: _- y5 H" R) H
        """" \- D2 L# H/ g) T3 u- h2 b
- f) ^# w3 V, c: T) u
        # 1. 加载与训练网络
7 f* Y, Y* H; K        model_ft = models.resnet152(pretrained = use_pretrained)+ N5 S5 w. R% O7 k
        # 2. 是否将提取特征的模块冻住,只训练FC层) D! r+ ]/ q1 A# e0 Q
        set_parameter_requires_grad(model_ft, feature_extract), M5 t6 }8 H. @
        # 3. 获得全连接层输入特征/ y% O/ J! r) l1 f0 o* u
        num_frts = model_ft.fc.in_features4 V3 @, ?0 l& b
        # 4. 重新加载全连接层,设置输出102( r1 Q4 H, O1 f! v4 J9 Z
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
% K: `# b: D9 }8 W, w7 u3 I7 m                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
7 Z0 P! ]6 p! n' A6 H# t        input_size = 224
( ~5 R, ?) K$ r1 w! X
" c8 D8 D/ f" h) d2 B1 l. }    elif model_name == "alexnet":
( H& Y) E- G! U. w        """/ N' I+ ]6 A1 P0 c- _$ m
        Alexnet
1 E- h/ u. U+ ?' L. \7 O. u        """
: H- N$ E) m! ^# D( Z( f. c        model_ft = models.alexnet(pretrained = use_pretrained)
* |! g- ]; h  q        set_parameter_requires_grad(model_ft, feature_extract)+ j0 [( H3 @; q1 r9 r% T

% n8 [  E; D$ p( r  J; H        # 将最后一个特征输出替换 序号为【6】的分类器
: h5 @. {' O8 Z/ ^& {2 B: T8 B$ h& a        num_frts = model_ft.classifier[6].in_features # 获得FC层输入0 ]2 s4 Z9 t  }" f% b" @
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)- i8 R- l6 X  W
        input_size = 2240 _/ K9 E2 b6 C3 R* R6 u% A2 P
& Z2 P/ L8 c7 j8 U% g+ o
    elif model_name == "vgg":; ?. s* ^( ~1 m* s' N8 T3 L: J
        """; K6 E7 x$ V2 t" M" ^9 q
        VGG11_bn
7 F# D) q" J3 Y) [5 W' N3 Z' [& H+ z        """
! X) t! F0 v- Z* O$ L" G/ X        model_ft = models.vgg16(pretrained = use_pretrained)2 p( {3 S: C  j8 t% P
        set_parameter_requires_grad(model_ft, feature_extract)
# U& @: o/ j4 x8 j$ f% _        num_frts = model_ft.classifier[6].in_features
7 n- s0 M4 F( P        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)' b, H+ n/ r( H' S- n2 L% K0 E# g) P! }
        input_size = 224
) q+ ]! m% E, k& h: o/ F/ {' X2 i1 s$ e, ^4 B* S
    elif model_name == "squeezenet":+ U, w1 S" v! K# S: w
        """1 l; E& k+ J8 L) w- ?0 P
        Squeezenet
1 t) O4 ]/ b8 ^) B' E& Z# \        """8 b2 r6 G1 [9 v; H8 S
        model_ft = models.squeezenet1_0(pretrained = use_pretrained)4 e4 ?' `% E7 P8 n
        set_parameter_requires_grad(model_ft, feature_extract)
' E2 M7 y, F; ]+ C) _: E        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
. h6 W- v( R+ t        model_ft.num_classes = num_classes
) n5 E, T8 n! z* K        input_size = 224
4 @& Y6 A2 X; ?- D' @/ u9 d
, k. v0 {4 H9 Q& q; X- ]    elif model_name == "densenet":+ [! H$ [/ Y/ H0 R
        """7 M" }$ D" v  W) m6 T. Q
        Densenet3 E- b( Y7 b# C% c6 z, a
        """
8 N2 {/ q) l: Q3 Y6 d8 K( X9 e1 V, f4 F4 q        model_ft = models.desenet121(pretrained = use_pretrained)
, n' u" F0 U$ P        set_parameter_requires_grad(model_ft, feature_extract)+ R; q# Z+ x9 j3 D, h
        num_frts = model_ft.classifier.in_features
; D6 v8 L2 \1 [" z        model_ft.classifier = nn.Linear(num_frts, num_classes)( [$ r& i% L( o* _
        input_size = 2240 `9 q3 H; ~2 |. J6 |

- @8 L* p" H0 o$ v) J    elif model_name == "inception":% H6 V- D: ]2 t0 H! d4 |
        """# M" M  d5 P/ ?0 {* `9 h+ U) ?
        Inception V3: T3 A0 u* z+ p2 `& p, S% h
        """
1 v0 w# z, H0 p5 G$ f6 I: J+ k        model_ft = models.inception_V(pretrained = use_pretrained)
! B; R* u, f0 U        set_parameter_requires_grad(model_ft, feature_extract)' n5 S2 ^& h& D( Y3 n8 r: f
4 o1 k0 C4 C1 p4 q$ y
        num_frts = model_ft.AuxLogits.fc.in_features+ y8 d, Z* @7 y
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)! x% K6 D- R+ W
1 {6 d4 h6 D9 N: _
        num_frts = model_ft.fc.in_features9 `( G: ]' K% W+ A
        model_ft.fc = nn.Linear(num_frts, num_classes)
* P$ i$ e+ U/ v' J, \$ Y( u  ^: y        input_size = 299( n8 w$ b2 h0 m( D
+ a2 S- N+ w/ k# C! @9 M- ]
    else:: p9 x3 b; v9 b1 Y
        print("Invalid model name, exiting...")
; J! R# H4 Q. l        exit()5 O2 C4 U5 z! i9 X: {. L* f

: a+ }% b7 l; {' V$ m/ V    return model_ft, input_size6 F. o. u0 C3 T" C7 |
8 p! a8 y% o& G; L$ f, y# y2 T8 K
15 |* U, L# p7 n. Y
2, i& M& d" w# _; C0 Y5 z: h
3: s. ~& i% |% u  |
44 I# p( Y: c9 E6 Q
5
0 n0 S, f0 |  Z$ ^/ W6
7 P) r: B2 O/ {1 Y8 v7
+ ~5 s9 W$ i0 u" X8# {) t$ c/ B+ \' J6 R- ]
9
5 j9 X- ~" W9 q4 z: J% n10
/ k+ R# j! S6 g/ D$ [0 o$ Y11
; |# Y* n9 r: S0 f! }12+ m5 z) g, ?4 n$ h
13! F' m: o$ f: @. S0 f
142 b& P: l5 W: c1 j2 A
15" d5 t9 i; u6 X, L/ R0 A. D
16( o$ Y% r! L& }* a
17
/ [# k* w5 e6 G; H# B' n: E2 ?18
& X, R  r/ ]8 D0 {19
+ K5 O7 ~8 {5 K* W7 K20/ ]: q+ O$ B; w, D
21* z' z' S; C" i
228 S) [3 t& J4 x4 _
231 k( x5 J% g; V7 `$ a% p
24
) t5 q; \  ?) `25- u3 I6 Z3 s& J: }" \. a  V
26; S1 H' {$ M4 c& A% f1 R  _
27
& D4 g4 k: {2 P4 C: I, u28! J) A; C/ D) i) E
296 ^, v( j% @- p! M% j! J
30/ `4 W! }5 M* m1 |( m
31
- ]  |1 L+ [1 b3 i7 P32
7 E1 U& |! u( o% z( q33( J" Z  I7 J) P( }- P
34
; q7 d) Q, S7 l* [' L35
! A  i! k7 V) ?7 O; J' q+ U: j0 Y36
0 [. B$ G7 G) c% p  }+ y# u372 }* z: R( q: _$ o, V8 Q0 |
38
- {% j: b* Z9 |; f) P39
6 ?! p# x' E% i, P7 G409 B0 x: X, r5 u$ n7 k% p4 }1 e
41
5 b& O6 m- R: w8 b42  N: i% Z; \8 H3 ~
43/ y/ g3 f  @0 d2 s
44
) R% L, Z! p* s1 u. p45# x  M6 K# {1 T3 w
46
- ~+ O' O1 [" z7 u$ {" X47" V# m7 B) v8 h
480 E: o+ i; \8 Y9 w! G. v) I
49( O! x0 @5 J3 }# s% w# X  r
50
1 t+ `$ F: g" _; {* |. _& F51
: x: b1 _/ n  D3 r' B52% p6 g* U3 q3 Q4 {0 d- K* z
53
8 g+ l/ F1 H2 X7 H54
) H4 ]3 j  U* S1 K* v! Y9 E8 Q. A8 v55
# U0 K2 j; E2 M56
4 n( r# l& T: s, p578 u" f9 l. @0 q4 @, K" x* J! b
58
  m6 {2 n0 F' [! A7 T596 M6 ~: c" q# B$ m$ t( V8 X. S
60
  l. D* P$ T# h3 t: o. n# n* }61$ p2 L- T0 P" O; h2 @
629 `8 @5 E+ \7 B- I% J
63
) Y1 E( `! u2 X/ Z% N- j9 w% V649 R! c  a$ N' |  p3 j3 z
65  A5 T# q- |' h0 k* ^# P. z
66
. m, S$ C5 }: m67
* v- y* v, A- t" I7 S688 _% K! p( X6 f( ~. T
69
& D) u4 w+ W" ~! h+ M) z70! @( q7 c- i8 @) ]
71! `( a5 p) {1 X, H  F2 A  r2 J
72, M4 h  ^: w, q) Q9 A9 v
73
' w- B$ e& ~" V: q  m3 p74- |% K3 r# A0 p3 m4 H
75
, j7 I/ o  ?. O2 w: F- h: b76
3 E+ @* H+ O6 j% s+ ^% U2 R77
- m9 V& s3 m& ]" }! y78
  ~* u# Q, i* N6 k& s: j1 k% s+ ^79
/ D8 `# N( ?: w; w6 g  p$ v+ a' M80
7 }/ ?3 Q6 G! s3 s5 j! m81! Y6 L5 k2 ]; M9 u
828 ^- ]" c/ n/ q$ j/ h) N8 R
83
9 X5 W2 x2 Q9 }7. 设置需要训练的参数
0 k: \0 p" c7 S+ B: E" X. n# 设置模型名字、输出分类数
6 E* V, r* J& Wmodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
! N8 }4 M# w# E7 U6 N7 q: z* I. C7 |0 a( z1 z% ?
# GPU 计算
; ~* A( w- S4 X1 }' umodel_ft = model_ft.to(device)) s/ p! x4 Q8 T8 e' u' V

0 Q6 Z) _; L7 b& T/ r# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取6 m- Z/ f' e+ Q; |
filename = 'checkpoint.pth'0 w. `2 C4 y# u) V2 f3 b
% ^. Z" l& D/ T4 _6 M7 D
# 是否训练所有层
  g( L8 v% Q& j0 Rparams_to_update = model_ft.parameters()
' a4 \5 o8 b8 }# 打印出需要训练的层
* T3 E$ [! I; f* Cprint("Params to learn:")
* c0 D- J2 i& ]if feature_extract:1 W" i$ p& J; X
    params_to_update = []/ R  g0 _9 h5 z( L
    for name, param in model_ft.named_parameters():
% _' Z+ w- @- n8 S' y' K/ N7 j        if param.requires_grad == True:
$ Q1 U, E4 j/ c( ]7 n, t) e0 i            params_to_update.append(param)6 \1 p0 O; u# d) E. Y1 \
            print("\t", name)5 D4 }; C5 z+ ]1 Y" P  F3 }2 D% f( p
else:
/ @1 Z$ V. {% S! t9 i    for name, param in model_ft.named_parameters():
& Y' k! C) d* u2 V) x) \- K        if param.requires_grad ==True:* ?6 C8 R* @. Y% A
            print("\t", name)
9 @  {7 m3 p6 t6 n7 C7 w* g9 N4 w: u% \7 f" P* X% o0 N
12 ?- m4 O5 j& d
24 N  B6 [! ?( \6 Q7 \( C& Q5 j
3
9 N$ U* L  J0 U+ O8 i" _( h4# M. Q1 w8 r6 }9 @& ]  @
5  }. J* }- w: v3 [0 i( J- Y* h& J
6
( v0 @3 e: ]& v) W4 A: ~7; N2 K7 M9 f0 s$ ]( y8 P* h
8
% W3 A9 \2 y, K7 o/ o% x' f92 u& o3 K7 }+ J! l/ E  t
108 ^" t( |/ k/ N2 e9 t
11+ B. n  @$ p8 b
12
, e1 P8 c, M& K) G13
" R! Y0 Q0 y' a: l1 t5 b% |14) I- Q- c/ m5 D* n6 v/ j( ]% U
15
- y8 h+ x; y5 s; Z4 o$ C3 a! \16
. n/ p3 a6 A$ k# g: ]17
- W# H. K0 ^+ \+ H; ?18" d7 \" z2 u' x* H, I; ]3 I
19
3 w0 ]$ F8 R: J, C# v20- @/ z+ x5 m( x& K4 c* g
219 b+ }8 r7 z/ L4 u
22* i  o# g& b) c. c! j
23
* ]7 ~$ p/ v/ gParams to learn:0 E* I5 N  x5 j5 \" }+ R
         fc.0.weight
# C7 b: s: h/ M# ?* s& R) l         fc.0.bias
" X# b" g, z2 i3 t1
7 Z' {) M1 N/ t+ h8 c( T. u8 x2
  ?6 M4 f1 L" H# \2 v, I' o3
" i$ X3 N. S& Z- Y! _+ s1 L' J7. 训练与预测# ]. L5 R% u; B* p1 g
7.1 优化器设置
% s) f# ]' k% l# 优化器设置0 w; C& X3 o% ]1 K) P3 n
optimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
! Z+ u/ C+ A0 }2 q# h# 学习率衰减策略" H' y, x* P  \8 ]- T" T- M+ x& @  Y* T
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)( ?' r+ ]: @2 D* q- }0 k0 o$ x4 h
# 学习率每7个epoch衰减为原来的1/10
6 j& ^# E7 ^" v6 o; p6 L4 q# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算. i4 y; p. \1 u! Q* x: D& W7 Y
" j8 u+ g1 n2 g" d& ^3 q
criterion = nn.NLLLoss()4 S6 G# Y) x/ g/ z; F- k
16 B; z; W. F7 n$ H) d' G
2
! _3 t/ a% Y  F7 O' _37 P1 R5 H" J  Q3 m6 e7 @4 W! _4 S
4% O# g, W, l/ M0 E
5+ @1 S( E: |* v. U- x% T
6* Y4 g( V8 f1 o5 [% S
7
8 N2 X! o; k- @8 m9 h8
# }) e' ?0 ~. W' j( v* ^* L* A# 定义训练函数, ~+ p. o" T) j
#is_inception:要不要用其他的网络  q. _% P" v7 b* U+ W: ~
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
% u, u. g% K& O' z0 B0 c* r    since = time.time()# N, {8 [8 Y" p6 Q& |
    #保存最好的准确率" b/ a+ l& I0 r! E0 I1 O
    best_acc = 0
5 v/ a) J% I- v+ D; u: g) ^    """
0 n" S- i" J# H, k. h% T    checkpoint = torch.load(filename)
: e$ B- ^4 R) o; D$ o    best_acc = checkpoint['best_acc']* a* b$ k- [6 `8 }7 {/ B
    model.load_state_dict(checkpoint['state_dict'])3 g3 o% _) N% U. H: {4 C0 f" T
    optimizer.load_state_dict(checkpoint['optimizer'])7 ~: s" p' A) q4 ^8 D# j  T
    model.class_to_idx = checkpoint['mapping']! Y3 ]5 p5 p% y1 o
    """# S: H8 x/ A+ V% s% Y) z& o- ~
    #指定用GPU还是CPU
3 L% f# A  U3 \; R    model.to(device)# G# G6 p- E. A: H* \/ Y- [! w
    #下面是为展示做的9 r) _* z0 C+ A: d' Z3 t
    val_acc_history = []
! F& i" y* t4 s    train_acc_history = []' d  d% u  A: ^+ c
    train_losses = []6 c: L7 L7 T3 E
    valid_losses = []( ^* x1 F! ~: r
    LRs = [optimizer.param_groups[0]['lr']]+ q4 Q! F: x+ P; d! Y; k
    #最好的一次存下来
& o6 W) ^9 \- f5 v2 M$ E    best_model_wts = copy.deepcopy(model.state_dict())
7 v+ l7 Y* [# H# `) z% n( r4 S8 J
" _: ~3 Q5 F4 J    for epoch in range(num_epochs):( W' u9 j! C/ {6 m' c( |  J
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))/ Z4 Z/ A  ~  q: |: ]' I: B
        print('-' * 10)) ]$ H+ i5 u' T: f

9 F  C* `' ]' j* l, P7 E        # 训练和验证
5 l. [1 C  N/ A' c        for phase in ['train', 'valid']:
/ o1 `6 [- o. I" {8 A            if phase == 'train':" f0 B+ B0 R2 t2 D, J
                model.train()  # 训练
& H) o( p/ b/ j. k; B            else:
0 C* V+ Y' n; B: ^7 T                model.eval()   # 验证0 N; ]8 }7 O, }  p6 w

" Q5 x% w3 }& w( A8 Y' p            running_loss = 0.00 _* J8 _% w3 `2 k  ?9 |
            running_corrects = 0  w0 ^' }5 @9 z2 M" q3 \( f& l9 a

  O" d4 J; _# ^3 X            # 把数据都取个遍
! J9 U3 k0 j) G$ K8 n. M            for inputs, labels in dataloaders[phase]:
* _1 }3 K, L" N) i0 `+ I                #下面是将inputs,labels传到GPU
3 ]& r# u& U+ y1 V                inputs = inputs.to(device)
) Y$ R6 l% q: P) I5 m                labels = labels.to(device)7 v; ?# W9 T$ `3 Q$ @
) I; _/ N. l0 y( {
                # 清零+ u4 ^; ]# L1 I6 G* a/ a2 x
                optimizer.zero_grad()$ Y7 m6 n3 E5 Y. l& `& s
                # 只有训练的时候计算和更新梯度( G: L$ r0 X& a8 F
                with torch.set_grad_enabled(phase == 'train'):0 |# \% C! D# j" Q
                    #if这面不需要计算,可忽略
. h2 J( e$ B6 E+ w1 M3 ]+ l* b& `                    if is_inception and phase == 'train':* b" E- c$ L# p
                        outputs, aux_outputs = model(inputs): o) Y, d* N& ~0 @: Q( ?
                        loss1 = criterion(outputs, labels)( N2 r/ I: m0 S3 |9 a
                        loss2 = criterion(aux_outputs, labels)' O1 C$ N" Y4 X9 }5 i5 B5 @3 z/ Z
                        loss = loss1 + 0.4*loss2. d5 v& p- J! |
                    else:#resnet执行的是这里% p, w. a+ ^# N. A  l
                        outputs = model(inputs)
5 l$ @# H, K9 ]3 D  ^                        loss = criterion(outputs, labels), K) r* y3 o% H6 M+ v3 W9 E9 X* g
% _# M3 v8 ^4 a* b$ _9 a
                        #概率最大的返回preds
8 g- C# p1 Q# m# W5 N0 p. }7 l                    _, preds = torch.max(outputs, 1)3 ]; Q6 I  u0 \, {' }3 n' ?9 k2 W" B

5 ]: v2 f2 P( }& @9 \                    # 训练阶段更新权重* _7 W% X( p# c  g$ j$ V1 Y4 C1 H
                    if phase == 'train':  ~$ _1 e: `$ J# U4 [
                        loss.backward()9 _  m3 O+ W; B$ a
                        optimizer.step()! ?  [2 N" h! h1 M5 F

1 T! M# Z5 |6 i8 X: G& S                # 计算损失" m: J0 p2 P' ]1 T, j7 B1 k- x4 r
                running_loss += loss.item() * inputs.size(0)
' y" ~8 M6 S7 k                running_corrects += torch.sum(preds == labels.data)
, l8 _* @  f6 h8 {: O- {: b
" k* `' ~  J' o# g4 D7 b" W" Z            #打印操作
0 N; i7 n9 b- E            epoch_loss = running_loss / len(dataloaders[phase].dataset)
. Y0 i4 M) Z; }/ ?. p            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
* r& L- ^1 ~' L% A' R- h& l
/ ]7 V  N! a7 a' g* u* Q0 k
8 w. Z6 }# \8 J% [8 j            time_elapsed = time.time() - since
/ q: Q: c5 d% P5 H/ Q- x1 F: {            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))* `0 G3 V/ U# Z+ I% g
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))  Q* e  v3 F& X
8 K. g' n0 c: d2 C
. W& c3 P1 ]: q! N
            # 得到最好那次的模型4 E6 }; C9 X7 w% O
            if phase == 'valid' and epoch_acc > best_acc:
6 y/ [2 r2 i' c. e& u/ M                best_acc = epoch_acc: i# G: i% {' [5 X
                #模型保存
5 f' |" S& m  k0 I; B) n5 [                best_model_wts = copy.deepcopy(model.state_dict())# ?- g3 i, G# ?$ k1 [& b! L+ N
                state = {
5 b. @3 {6 d8 r4 S( r                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数
2 _( F4 ^- R+ O6 q* N3 }                  'state_dict': model.state_dict(),9 o1 E' w# G& ^* d( h) I
                  'best_acc': best_acc,; L. |$ g- F& w" v
                  'optimizer' : optimizer.state_dict(),7 e# R* N5 h7 M( f7 H) C
                }
  |2 |# y" T4 x: G& C                torch.save(state, filename)+ Z! I) ?  z2 o5 ~
            if phase == 'valid':
( u6 `" n8 H) N# U, {                val_acc_history.append(epoch_acc)- c, \1 G+ N4 c6 f4 J  k
                valid_losses.append(epoch_loss)
$ p; q# ?) g: S7 F' r# s                scheduler.step(epoch_loss)' f) e( @8 s" R  _0 s2 y
            if phase == 'train':3 [6 a! |- h( [" u
                train_acc_history.append(epoch_acc)
$ y( F9 r! f# v                train_losses.append(epoch_loss)3 A% s4 f) q8 q! V  `: n
6 r5 m/ X; ?+ d; U, i
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
1 Q; `% x  P- I6 U        LRs.append(optimizer.param_groups[0]['lr'])
- h; s9 B3 w6 M  \: ^' [3 E        print()9 T! f0 e) B! I4 [" s- X( {

7 U9 O& t7 {  v* F9 d    time_elapsed = time.time() - since
0 ~& g. `  g) E3 v2 h8 z7 R. O    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+ @, ?- G$ M8 m5 j+ M    print('Best val Acc: {:4f}'.format(best_acc))% |( A. I3 o. m  Z' E9 j
2 ^% c! g( R) y* s
    # 保存训练完后用最好的一次当做模型最终的结果6 ~# w2 N* b; X# D; ?( q
    model.load_state_dict(best_model_wts)3 g7 Y9 m( d8 H- [2 P! R
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs ' w4 a- G! `5 F, m  N$ a8 g  s
& x( N2 w5 n; u2 v
$ I* f7 m; W; B! g8 [  V
1
/ _5 \- t+ ]3 p4 ~2  Y. ^* x5 y1 K' E0 L  W5 m/ F
3+ h- J1 E4 P+ ]/ G. o; R" H0 {7 P" D
4& C: E% [4 J5 c) \( q0 g% A
5
+ C( i7 y$ i$ s' l60 u9 D4 p; `! k5 Y, A: y
7
/ {+ I3 t2 q: e8 _: {6 \/ K0 T. k8) i* Y! i' m3 y2 ]
96 x" v8 x! G9 }/ p- b
108 x. D; h7 {9 u3 [
11
9 `5 o& f5 n1 x& X9 r7 o12; e, U  d2 ~* e8 T9 ?
13
6 `; i9 N+ J2 T7 g" T) e9 Q! V% c) p14
) O8 G) F1 X6 k3 B! N2 B3 F151 L7 G5 y# _# f) x
16
. V( e- T5 a" ^4 [6 R17
# D; i3 d& b7 B* w1 [% K: I18
  C3 a- J3 @; J# S19$ L) r$ o; u0 u
20
+ a# `% Z' ^( v2 _' _3 J, n- Y* \21& w) E7 B7 }. n
22
. {  n0 H) `' H6 g/ a( u: x, A23
  m" s3 V4 W: E$ {" }6 z& S2 U5 h245 G6 V' E* \) ?: C9 b) w
25/ d6 z- D% z( f
26! d* ~: t# A  X
275 e% H) a0 [# L
28; [- J0 I4 Y+ K! ]4 P
29+ P( f& n) u5 P; h- @
30
1 k9 n' ^+ `3 c3 n" b315 A3 R% ^! E; E! x" E2 N
324 x5 e* ]4 m/ M# L. B" X# X5 N' o
33
& ]) `6 L4 J, Q' C34
$ }# ?0 p2 K) r35# a' G* `2 C8 _8 B- v' q
36
/ d; M( u; f4 F3 F7 s/ A$ |37
/ ]$ g- V0 L3 [- b  K: Z  p38
+ m/ w3 ^/ I" z+ `* A2 p9 m/ O3 v39
$ O/ U! D' X5 k, w/ }. V3 j40- }8 `8 u' G+ w& o# ]$ |
418 j- l: T- g% ~  y
42
' O* Z3 P' D+ X6 a) O: r1 W! s& e434 f$ j6 Z7 e, v+ |, t& ]" g. I
44
- C: l/ J0 r. x! L45
8 B% o$ N7 |* T465 x7 u' X' G. V$ r. ^9 G, H' R" [
47
+ U7 _  n: `& P5 P/ f48
* d# m/ [& c5 B1 f" u49$ \) D- i1 i% d8 \$ j9 Z6 _7 F
50
' J. b" Q$ N6 P5 {514 [5 X' |5 t0 v# w/ n  @1 g. l
52
# D, S9 Y4 J6 E1 m53
1 X" X" P+ a* N' x54
/ p8 j' H& ?, q- R2 g55
! ]0 F  G8 j% x" m56; y+ Z( Z6 R4 P# \( L5 K8 \
57
+ b1 G) ~; q: D: P3 L/ l5 a58( L; m8 d/ H, a5 z
592 {# F( E/ k- r& |. o6 k( z
60
- Q" F& e- o% g+ i61+ A9 a" j  d2 x7 l4 Q9 I
62
% L1 h  B! G! q9 e63
, U( E, h# [1 y. r  O649 `6 g0 t& x5 l7 v! K0 s
65
# o& V& ~0 ?% `  _  q66- o. h% L; C' S
67
' R' K0 [# {. X2 M, b9 V1 i68
& Y) l- z( V; l69% r) q% a4 ?* n* l
703 z" c  |0 `& X7 e' R
71
4 \/ r* ^/ f! T" v- ]: b8 k720 Q: ?' ^1 W: w: y7 L
73
4 s6 B- c" y9 h6 ]$ S741 H  R6 B( F( n3 p$ v$ ^6 Z, g( H
75- N# M! u( h, h# x5 q& k
76' e% M) h" @9 b( c- z$ Y- S
776 A9 L# n) `# b8 Q3 f( q5 a$ B
78
: K8 g& Z- [. R' @5 v79
' |) _' e8 _- `$ Q! g80
7 H* d+ \, J/ U$ \7 ~. d81
( T- T  w- Z" _  F) p! q& f82: H, A" Q4 g5 I; D
83' L. h7 D4 N& u1 Q9 q
84
, H( d0 }- B+ \85
8 Z& \# }. U% w5 s% [# V861 C/ x2 I$ d/ d% j  }5 U$ c
87
0 J, V4 T" K5 J. B88
7 C8 W( B7 Z1 N89
2 z4 P! a2 m. w, T- r7 S90
1 T7 K# H7 |& F( D7 g91
* S) b$ u+ X3 j92. l: d/ L, [5 S- j: I* w$ N
93
/ `, U; H. n: [1 {9 U94
0 H7 ]. W3 [( l8 p/ E9 ~( I95
7 n" {4 O$ {1 G& s4 X6 g* G( i96
. e9 N. p) Y; {97
& J. L2 y0 {" k$ M5 n* A98
; Z8 O( E" N9 j" w; w99
8 @7 `" y0 O& c100
) ]! a1 D7 Q$ F6 }  D3 @" _5 `101- H; Z$ ^, I3 U: B1 _
102
6 |# ^- r* e0 Q% i103& y2 g- l6 m- q$ z5 D
104
6 w8 O! [6 ]. b" K  m7 m  v$ s% Q105
+ W2 P& Q0 I1 J5 Q: \106
/ e1 \+ q6 a( e# @, T, s% N5 d8 a8 G1076 I8 G. M* L7 k$ F  G
108
( ^5 Q- {; E( s! ]0 `( @% i; w4 D0 `109
, f5 [8 ^8 n2 z0 ^; s* j' T8 h110: o% l  v) o) q! R. S+ J
111
; z; Z. ~6 t9 H' a112
8 t% y: j  v% \, C4 x- X7.2 开始训练模型) s" q' G  a* H( q! n1 z
我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次" `2 d5 M6 ?8 D  O

+ s% S0 [2 R1 `, ]#若太慢,把epoch调低,迭代50次可能好些/ S* h. `9 i. Y0 y; P1 c) V
#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合' d7 A; ?& {5 |" n. N: T0 c* V0 k5 V
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
: G" E9 C# r# I* R6 O7 O0 u0 J# `7 V' ^5 }! |: y; B1 ?
1
& ~. g' i6 w; q& f; x5 y2
; f+ \/ s: o& o, f. _3
* a# M- V# g5 o8 q' V9 c* M2 N4, K! q  `, M7 |1 F8 h
Epoch 0/4
* o6 I: q" `0 _9 p% r5 E$ c) |----------
! t/ X3 E6 s  R! h2 J' \0 qTime elapsed 29m 41s; l( G0 p) L% Q% x  a( ]1 u3 ?" o
train Loss: 10.4774 Acc: 0.3147
2 |, M  D2 G2 HTime elapsed 32m 54s
7 e' G, S7 [8 y1 K5 u( Gvalid Loss: 8.2902 Acc: 0.4719/ [9 B2 i* o6 r4 x; {0 q  C& S
Optimizer learning rate : 0.0010000
! x3 y# x" z) f/ T
5 m2 w) u  K. b2 H5 O6 L$ u+ y! QEpoch 1/4
8 p, w" z0 q2 c: z4 U6 z----------7 |  F6 b- S2 c! M0 Z
Time elapsed 60m 11s* ?$ [( H8 e* U3 C5 D
train Loss: 2.3126 Acc: 0.70539 \9 O1 r4 |7 ?4 T
Time elapsed 63m 16s2 M$ e7 ]) G0 z+ M; [" L; H
valid Loss: 3.2325 Acc: 0.6626
5 M1 i& Z& k0 h4 L: TOptimizer learning rate : 0.0100000
/ }+ @: P+ p( N9 T. j% k, `: F$ {
Epoch 2/42 k2 D+ D$ W- p2 \1 l2 C( Q
----------  @' Y" _! \+ |, L: w% m7 O% }+ K
Time elapsed 90m 58s7 c4 i  Q0 s$ _4 N+ s- t0 J
train Loss: 9.9720 Acc: 0.47344 k( J0 w/ `1 b9 w5 r$ {8 [
Time elapsed 94m 4s) H& c- a! p5 f8 o/ I
valid Loss: 14.0426 Acc: 0.4413
% O% c5 j- s; B; O. bOptimizer learning rate : 0.00010006 ~3 \& ~8 P; f$ B. o

5 B" [+ j+ L4 ~, ZEpoch 3/48 [0 i* o7 {9 G5 `/ }0 Z  I
----------
% P" a7 W6 H' M, R3 {/ I' n: _# `Time elapsed 132m 49s
  p( a: m- y7 A* ^& v) j# I4 z7 J. Vtrain Loss: 5.4290 Acc: 0.6548# w3 a4 ~; l, o% d7 O' R
Time elapsed 138m 49s' _4 n6 t0 t4 |, T8 h( |
valid Loss: 6.4208 Acc: 0.60274 W. E, b" x  f2 L+ c
Optimizer learning rate : 0.0100000
- ~" R& S/ k3 G  H
& w8 r+ y* R- G+ Y) k, R! C; r& kEpoch 4/4
& N! T. j2 s3 @2 @) `( @----------$ p( A3 M# e1 P
Time elapsed 195m 56s: I  P0 Y7 {# @. j/ U: J/ t
train Loss: 8.8911 Acc: 0.5519: Z# q' n2 J4 S7 J0 w4 e3 G- \
Time elapsed 199m 16s
; P. y- V; W% N/ K! Q5 [valid Loss: 13.2221 Acc: 0.4914
( v9 b1 ^! ~# q1 w1 ]6 [4 m+ v+ qOptimizer learning rate : 0.0010000
9 Z" ?; _8 P8 J' K* Y; H- j6 _9 B* r; y# i1 I& I& C! f5 F* a/ _& [) w9 F5 g
Training complete in 199m 16s
9 m3 ~* n6 @: C" `( U+ iBest val Acc: 0.662592
% j" q$ W7 E% Z7 A7 `! l
# H% X% x0 w+ l( F" z( m  X3 T$ A1. p3 _* V; `* R9 _( o' c8 H
2
) a; m1 y, Z. G+ j7 T3; N3 r6 B3 k9 u7 p: h
4* `& u- N* p4 w" ^' k( E
5
- t8 c( l7 y7 i* e- ^' z9 M6; k+ Q( R+ h6 k7 t$ Z2 n
7: X* i! {5 o" \# R
86 u" n5 m* @+ k6 _
9
5 s& b. p/ H0 Q+ N5 g6 I0 ~3 e106 Q" E+ s  u- [# q* U" u& Z8 [
11/ P9 Z( _% R0 Y- x4 v1 U& @
12
2 Y" \& g; o: e. p0 c. l9 c8 R/ v5 O13  O/ K7 b+ {9 \* c2 v
14
$ N# ~4 s' @: \' Q15: Q% ^5 j3 D8 s
160 J) R9 m& `% o; ^7 @7 p
17( {* h; T2 `3 c- k% c$ @& B
181 n2 Z1 d+ i, S# p1 V
19& e( Z" p" M6 s/ ^; D% I
20
4 F- L0 l& s# s9 W' t/ l215 s* v  ?& p- \. @8 ~) V. H
22
: r! ]. P5 k7 Z7 O4 w0 u" o1 N23
) V6 X+ X$ N2 C* Z# C" T% W24
4 \" P! l7 z7 ^) a6 u1 v7 d4 x  z/ b+ R25- L; g; T; H9 u1 J- G. I' j
26
$ H. F+ a1 d; @  m' ~  X  }9 u27/ s. X' \. J( A0 [
286 @( A/ M/ T7 n) |9 i' s1 a
29/ F6 I( F  l, h/ k. I
30
  j7 c+ D" A8 f  `/ m' L  `31
: N& R2 f' L- W, V/ Q32
1 E9 r# Q$ I) v% |4 P& v4 e& X) l33
7 ~' [# S8 c- G34
! ^7 L$ s5 W+ }$ Z2 ^" T: e$ v354 g. s* v4 P" I% I
36
* U9 y* {+ S: s3 A4 `* U% {$ R37( t% C- K3 Y& F. u" ~- v
388 T' E# t, j8 N, E0 @& j% b
39/ i) ?9 {- Y. O" g
40
3 `% }0 ^. c) |* m41! k' {/ g) S3 ~" {  U6 P+ H* X
42' N" _1 }4 c3 Z8 q- f/ |- \9 ~0 U
7.3 训练所有层
8 t4 Q1 b+ N7 i6 F6 I7 q# o. h# 将全部网络解锁进行训练
7 S9 I# M8 F1 n) C6 Qfor param in model_ft.parameters():; s+ \, ^5 L, I$ g: N# p! |
    param.requires_grad = True; w2 n0 t- D) J, T; o( F7 ^
) O9 S6 m% W* ]
# 再继续训练所有的参数,学习率调小一点\
# E( H* i+ `1 F% }  b5 g; f. O4 loptimizer = optim.Adam(params_to_update, lr = 1e-4)
  x! y& o) N% uscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
9 @, _! ~8 n: F, B4 H4 m5 [. w+ {. L( p2 [' X8 N2 |; M
# 损失函数
, z. o3 @: L' V0 o1 V) ]2 z& ?criterion = nn.NLLLoss()
! D$ Y$ r+ x+ J1 @9 X1: @% K9 S/ X3 {* K1 }  t8 z* V3 [
2: w& s' W7 \% S3 d& _5 X/ ]% ?
3
' l. s, j2 J3 s0 k3 Y4  C) p7 s. y- ]8 G3 B1 i' W1 @' N
5+ u5 s) ^+ F6 K- s' m3 W
67 b6 }6 r# y3 D/ h9 B! L
7% |* {# K! N+ B0 n5 }/ X3 t
8- f0 y, ~( h$ G* A; z: R  u/ C
9$ p8 p4 j* y. g  ?. x2 m# r) @; I
107 H3 x; P2 x) k, L' P# Y' J! V
# 加载保存的参数
7 A- E5 @' p0 x2 B# E8 @# 并在原有的模型基础上继续训练9 c6 _$ {3 X& y& `
# 下面保存的是刚刚训练效果较好的路径* S" H2 |5 j$ K+ Z2 b3 U" R
checkpoint = torch.load(filename)
1 _9 [. K# Y; V9 r2 Obest_acc = checkpoint['best_acc']) L5 c7 N2 p5 _; ~* p' [. L
model_ft.load_state_dict(checkpoint['state_dict'])
# Y/ D) K' _/ I6 voptimizer.load_state_dict(checkpoint['optimizer'])4 n! @7 ^' s: V% G7 h
13 \5 {7 f+ s$ M6 H- e/ L
29 j- w2 ]* F& m+ u& c
3
: K6 b5 U% x. G/ t4
% Z. J3 q- x- c4 l8 r" W% Y5' {: \7 k, `, I4 Z& y
6
: z, A! \# J2 }# G9 Q7
" e) }+ |/ A  E开始训练
: G; g2 l: O% k4 Y7 I+ j) L注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
4 Z% W! u& X( \4 y/ b5 Z5 }1 G: S% T. N; W: D
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))+ Q- g5 @2 `3 b: [2 p; }
1
5 w3 j: Q. \: I/ {/ B* ~" SEpoch 0/11 i9 g2 v- K' v+ W  p# |# a6 G
----------2 d$ W3 W9 y5 C
Time elapsed 35m 22s
3 d, l- T5 ~  b" q7 G# }train Loss: 1.7636 Acc: 0.7346
4 f& a2 W- F3 PTime elapsed 38m 42s
2 E8 Q( B4 g3 {9 k. \- Qvalid Loss: 3.6377 Acc: 0.6455
0 K" d, f0 _! B( B( Q# MOptimizer learning rate : 0.0010000
% e: j" L* N: S: U; w; i. c+ B/ w" ^- s) v/ H: t
Epoch 1/1
+ H* t1 M' T6 A2 M1 T: u) X2 t----------4 a3 A* a2 X0 j+ V% M6 V# s4 h9 t# `8 e) h
Time elapsed 82m 59s
6 P3 ^0 ^. |. m% ztrain Loss: 1.7543 Acc: 0.7340
# c  G, i0 ~+ T! P) v; N) [Time elapsed 86m 11s( b( j0 A  b) y* ^; R
valid Loss: 3.8275 Acc: 0.6137
0 P) p: S5 L! M; N" d5 b5 `Optimizer learning rate : 0.0010000# R  w8 z- F6 D
- i: f* I( N7 R- T0 z; O: d4 h
Training complete in 86m 11s
9 ^9 `4 ^8 U; g. v- V! k1 yBest val Acc: 0.645477
* V+ e& @" B( Q7 s* U! \7 f
5 T4 V' q* g& t; o- {4 F. J+ g1  [1 v3 u! T6 b  d5 N3 O5 T) O
2: I  t2 ?' M/ }" o  n. N
3
7 r/ p- R/ D7 x  {2 L4/ _! h. R& |- X' p$ f
5
1 A1 b: z+ t  Z& Q6  ~# a- Q9 }) m7 |
7
2 m3 D& [. h0 R9 Q# w2 L9 Q8 V8; ~$ o# @8 }9 d+ y+ {; T
9
' V' V! O2 Y/ @/ l- a10( f2 |) A7 ^$ R6 Y) ~
11
! Q  e* N$ ^: L: M  C1 y& y121 n. P& X1 Z$ G' d- I) v
13$ O8 k, C# q2 g
14
/ q' K- @9 m5 p# I1 z150 x: u- \- E! F; \: E8 ?3 V) c+ g
16% J9 t; |& S, a, q) x' A
175 C* H! [( y& |
18
7 ~, ^9 E% ?! B9 j# h3 F8. 加载已经训练的模型% E6 X5 B8 c) T( n9 B2 p9 w, ]) ?
相当于做一次简单的前向传播(逻辑推理),不用更新参数6 ~/ M5 H6 x# l3 |2 X
  c" S+ B/ |# Q5 y- k7 Y7 h
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)" K' p2 `6 f. c
6 M; r9 i1 ^$ C
# GPU 模式
. F, j4 s2 W$ f+ a$ J' amodel_ft = model_ft.to(device) # 扔到GPU中
% W) f8 o3 s$ B' y8 ]0 q
3 f$ {3 H# ~" L# 保存文件的名字) B8 O+ k4 c- l8 g
filename='checkpoint.pth'& p& I* G& m; d9 Y
+ u% N, S- L! e  ?, ^$ A
# 加载模型/ Z8 c( y: |  I$ }0 @
checkpoint = torch.load(filename)* d6 h. l& t+ Z& g
best_acc = checkpoint['best_acc']
# n- j# u7 O# w) _( o/ Kmodel_ft.load_state_dict(checkpoint['state_dict'])
4 c7 o6 C8 P- c, y1- b( k0 Y- Q: n0 s0 R. e, a" N* N8 {
2
  M8 @$ g0 F& W' C: e' ~8 m3
5 q! @8 O9 W6 y& ~* h$ r4
6 }6 \! r; F) f9 ~* i5/ C$ H0 Z+ a1 G& S% M4 ]
6+ S( E9 D+ K+ a3 F
79 ^! [: l5 F" l! O, `- v" W7 b
8
, [; K, f; {% f3 b+ X8 l+ Q. T9
, x( v7 T+ ?' z# f1 K) S) Q) t10
8 |( s9 ^6 F1 T6 n+ g11
/ F4 v! U0 m; e. ~12, P; p0 i9 m) ]) G
<All keys matched successfully>4 Y4 a8 c; e" V- h' f$ Y
1
; ]) b( K( [0 K4 b, Wdef process_image(image_path):& p1 V9 K  s& Q; Q8 ^* G1 u
    # 读取测试集数据
! m" J( U- l; K8 @    img = Image.open(image_path)
' J# Q6 Y, w6 V/ T2 m    # Resize, thumbnail方法只能进行比例缩小,所以进行判断8 B4 `6 c) A4 ~' O; F4 `9 W
    # 与Resize不同: m3 }0 i! {+ Z& ^% h) }, ~
    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小% R& v3 R, a1 `9 S0 ~6 w
    # 而且对象调用方法会直接改变其大小,返回None/ o2 |4 V6 ?2 G, Y
    if img.size[0] > img.size[1]:
3 b! a) ]' q5 d+ Q) x& w        img.thumbnail((10000, 256))
$ }4 ~* B! h' ]    else:
, t" Q: l) [' i8 l) I        img.thumbnail((256, 10000))% {3 t  j2 A8 }8 m3 G- k
$ v3 [  z$ h4 D; N3 Z
    # crop操作, 将图像再次裁剪为 224 * 224
2 E- |7 n. w/ P3 l' v    left_margin = (img.width - 224) / 2 # 取中间的部分# `+ }0 r6 V; u
    bottom_margin = (img.height - 224) / 2 ! l( j. A& H% e
    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度4 {  G, ^1 I4 H% S2 Z) Y
    top_margin = bottom_margin + 224
( N& M9 r: p! _2 l
; S4 }7 t. M8 g: E. q3 }2 D    img = img.crop((left_margin, bottom_margin, right_margin, top_margin)). E, n! b( H* l" R$ O& H
9 b9 @) c& ?6 u* m& L. h# h
    # 相同预处理的方法; }6 d2 _7 n# y3 t8 I
    # 归一化, _4 |) g% C! l5 s( r# s
    img = np.array(img) / 255
1 Q- _* o* E: ^$ ^    mean = np.array([0.485, 0.456, 0.406])
$ V' |6 k. [3 P    std = np.array([0.229, 0.224, 0.225]): ]! j3 w4 P5 [7 ]
    img = (img - mean) / std. w* P% ^4 E. T% k! f

2 r$ @& i& _3 N5 D7 i5 V    # 注意颜色通道和位置8 w# R# M! Y' Q" g: o0 Q3 c. @
    img = img.transpose((2, 0, 1))
) |2 ~  f$ u& ?, E; L4 B& ], ?0 D* o1 J: A  N
    return img! y1 d6 Y$ r5 v/ X
. R. H  }( P5 x; b
def imshow(image, ax = None, title = None):
8 t8 M3 e5 w. e& I; Y    """展示数据"""
: ^4 Z& n7 A8 N+ [2 a$ U    if ax is None:2 E/ c. w0 J0 |& U' |
        fig, ax = plt.subplots()
# }3 r0 X5 [% d: }4 P( S0 ^# z+ I6 D9 W' {
    # 颜色通道进行还原
. n2 M/ s/ X" T- R0 l0 P. p    image = np.array(image).transpose((1, 2, 0))3 f' O# w: n, O( }0 x# v' G" f5 F

7 ^2 [1 k9 ]. h& I$ G9 V    # 预处理还原
/ }( z& }/ S3 r" R    mean = np.array([0.485, 0.456, 0.406])
5 s/ W. O. v' {; k0 u0 S: W5 T    std = np.array([0.229, 0.224, 0.225])
+ a6 `" u0 F9 v$ l& Z    image = std * image + mean
& Q, z0 ^" @7 J    image = np.clip(image, 0, 1)" |4 K% w. y# E# D6 |) U1 J3 I

$ N! A5 Z; i. y3 k9 l) R8 @( g    ax.imshow(image)
0 l. U* w& S/ L9 Q# Y, H0 f    ax.set_title(title)/ a* y" q. c; H  B. l+ Y. u

) n6 ^! s0 b2 R9 z$ J    return ax
/ g+ e& Y- ]* y' Z5 b
- s7 o* K7 H& h7 J2 ]/ ]) fimage_path = r'./flower_data/valid/3/image_06621.jpg'
5 w  f; M/ ~0 T6 w) z, wimg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理6 D/ B; T1 s# k2 J
imshow(img)% L# }5 E) Z' S& ]. n
/ m6 i4 ]! I( [1 [" E
1- w. [9 m4 y. D
2' `" W( v8 y. V7 V) g
32 s. z: y; M6 `: P! x
4
, ?/ C1 b8 ?# N. L/ W5
" x* F7 l, H' }6
( ]$ Y* z9 W! z% m  h7% _2 o- s) {1 ~/ z. g& w5 @3 K, c
8( L9 q" @' @6 i9 |. u$ W* U+ [
9
9 ]6 o( e" G0 l- v( i; }8 [8 h102 {" ?" n: F" R) ~9 e2 y5 v
11
) `# Q; c  ?! S/ k1 P/ E- q8 K3 l3 ~6 H12
* A1 w& O" Y6 ]: I8 e1 I133 ?5 K% Q  H+ U) v( U, q9 e
14# o2 g6 {7 w0 h. q7 l8 m* F" n2 P
15
5 P$ d/ g# S/ P" ~: a16: l% Z+ O" m3 @9 u
17
1 I; q: y; j* P& G6 p18+ N% J5 J( O' P, q3 H
19
5 `  G; b. `/ h  f206 Q/ ^' R! M2 u$ l3 i* N  @# \* c
21% m5 {9 \6 {8 s" S' D6 b/ H
22$ V! L. [" J. g
231 _5 {" v8 R  N1 f: n
24
4 n6 C3 Y, L* G  ]" d0 s0 i25
0 s3 g* b% ]3 H: q* j4 u. e26
, a- D2 R0 l, f1 }27
& M3 I% g1 _5 I7 O+ H8 @288 t- i; R- `3 T: _* [! A
29! d  v% N) A4 `& c. \) k
30( |& k5 p3 `8 u: D
31- f. D8 O* w4 z: t4 J8 r
322 ~- a3 `! a  g) I
33
' ~. H% U/ L# O/ a1 I34
" X, u3 \# s, R! j  H0 b1 m8 j35' a- v' \+ U) Q
36
# P6 n! `" G% t) B37
. ^1 z5 K/ v) F2 S, d7 l38
# z7 j. [% S$ F4 D# A0 G9 D" g7 t39
. r5 o) ^, @5 J1 l+ N* _402 S0 Q8 A" b% Y. s* V
410 Q6 n' C/ |" ^( m* y# z% I
42
" f& j5 m, o2 N* ]43, f- i1 q, a; J5 s* g6 H
44% H* T; m+ V, i! \; P% |
455 r. J/ x5 c3 q
46: B$ r! J9 N' V+ m& |% ?
47
" h  `" \. s/ `4 E3 m0 P' [48
$ R: F5 R& d+ C0 ?49, R* j/ o* F$ p3 q, R; `9 o. p! f
50
: B3 S2 y% f1 ?- m. E& q! @513 @& ~" [% _$ \+ L8 K# U7 z
52
" J' ~5 v3 w* N' ~# ?( D/ B53
+ k1 L5 p2 `( i54$ k$ `7 z: x: R7 u# Y9 O) F+ V
<AxesSubplot:>  `$ G' i' N# x+ g1 ^
1. M8 g' O. I+ O+ b3 s$ G
: C# l! `) t8 |$ L; X: Y7 }
上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
) w0 U$ A9 X1 H! Y2 g2 {! S+ Q5 F0 d7 A  v
img.shape/ L' c, {- _9 W" b/ T. t: w
1* T7 s7 Q+ k3 B& B
(3, 224, 224)8 B, B9 A: ?, ?, x$ b, M6 |2 n
1
* a# S, {' U5 K7 K8 x证明了通道提前了,而且大小没改变
$ j3 y! Z3 B9 i" c
- l% C1 c. g! t1 e. Q9. 推理
  X( X6 ^$ B( c, o+ vimg.shape
, n; R- c3 F  T" K( Q2 A
9 n" d* U5 o9 G, b# 得到一个batch的测试数据* {# o! B' N) \# y2 ^' w7 u
dataiter = iter(dataloaders['valid']): s0 D; e/ W3 r" x
images, labels = dataiter.next()$ k5 t" O# G& Y6 m9 }$ H* \

# w9 {+ j1 m# [' }) c/ Cmodel_ft.eval()
1 [+ f- E/ F! z6 t$ o* B
# X* [+ p5 I3 p$ |: ?% }0 K& P1 Yif train_on_gpu:* F' ~. u: G2 l  g- N
    # 前向传播跑一次会得到output
0 w% K+ P6 t) b# R    output = model_ft(images.cuda())$ l5 b, y( s* u% l% r# F
else:/ l1 F+ |: @; w$ R* ~- {
    output = model_ft(images)
0 I) }0 x9 j2 ]* {
7 O* D/ `5 W8 z9 _$ h" _+ Y# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
& V# s! y2 q# h4 R8 X5 E$ Moutput.shape  e# _8 }+ d$ Z+ `8 N  k; n
1 ^8 j) _8 Z: H! X
1
1 `( m* U8 l: V' O2- Y. S& ]  Z  L1 X4 \
3! B( x8 P, z4 D- _5 M' r* |; J
40 `- p) d/ s+ n4 X  i" ?
5& j. b; y: b( n/ M5 h: |! j0 G
6: Z$ z6 Z/ h7 c
7
$ e1 Q) e( S% D& `" g) L- o8; \( [$ r8 t4 [# P4 C
9! D) o9 x2 m1 f! \+ y6 A+ `
10
' U2 T) `7 l$ E; D11
7 c5 A9 ^& ]* _( D12
! L7 p( v, ?3 @6 E" \132 \" z- {' g- V0 c
14
; R6 `# D; x3 S, |* J# P0 d7 o15
% x; ~+ ]9 ?+ Y6 m5 J: f16
; N+ N& G( b/ z& a+ qtorch.Size([8, 102])
+ n; h' \! @6 D/ v1 v4 W1; X( m. b4 _( S: m# S$ ]
9.1 计算得到最大概率
) Q1 Z/ W# G/ Y4 F_, preds_tensor = torch.max(output, 1)7 J+ |- {/ S, F) ?/ K

4 t% o$ ^" j" u1 F# p2 kpreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量) b# ~7 g( ^, l  J$ i
1' O, g, t" b( P( S. O
29 q0 z, B1 y+ M! ~$ C
3+ q5 i) |! C2 }* \- G( _; B
9.2 展示预测结果9 x, w5 o, G9 M
fig = plt.figure(figsize = (20, 20)), y  c; t# ]  u* G$ L1 M: L
columns = 4
. T3 G3 k% a8 k% u- N+ [rows = 2
6 w" G. }5 a* J6 G" f
, S5 F2 ]1 V3 B5 o/ g4 y2 z8 U2 Cfor idx in range(columns * rows):; M0 T$ X  ~7 n4 b, ~8 a0 x
    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
& u, Z2 \* N2 o- B+ C4 Y    plt.imshow(im_convert(images[idx]))) L6 W4 S3 `& U9 Q
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
- J9 }) c7 E; o0 ?2 U$ e; l                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
- U  N) U) b3 M# h" cplt.show()6 K. [4 B5 i2 Y+ F6 `
# 绿色的表示预测是对的,红色表示预测错了5 ~5 _0 O" e) U/ S1 S( ]2 ]4 P
1
) F, b" y6 J3 _! d$ L8 Q  T2
: C9 [: P1 A3 s# `3
: A& x- Y- U  N# w1 f4 y, t4
! H: G/ K1 ]' G& `4 q1 t5# }1 Q+ m# m& q+ y- B" k
6
( v. B% z# q0 `7
# d6 i. ~: L8 U; \82 x  E% G8 x$ A- V, h% i$ A; n5 O
9
: ?" I+ m7 `4 C: P( }10
$ @$ g) F; \& W$ }5 q1 g8 g11
) j3 i; N) x) X' t9 ?9 D: ^7 D& R9 w5 {; e
  ?$ M7 Z2 I1 u5 S2 Q$ b

9 u' H, y6 O; b1 o1 q8 Y- }————————————————# |; B7 U2 n# }/ Y: ~& {) Q) F  V
版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。* V1 g( R' x9 i; p
原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
2 A. ^% J5 G- _" W
. P6 U# Y$ w/ |$ s9 |( m4 z6 ^; `' i, O# v$ ?; Y7 S





欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5