数学建模社区-数学中国

标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例 [打印本页]

作者: 杨利霞    时间: 2022-9-8 10:41
标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例6 C2 c9 M; u3 F) y) F6 g

) j4 x2 V: I+ f5 h) k9 i9 m& Z文章目录
! h/ y0 F6 I) {! _卷积网络实战 对花进行分类, h$ ^5 p/ d9 @6 D, b, z
数据预处理部分2 ?% k% X$ v. v. Z
网络模块设置
, y/ |* d2 p" ?) X4 o网络模型的保存与测试
) Z: o2 C2 y7 u数据下载:
! l: v4 C1 M# q4 Z' _6 m1. 导入工具包* M! N, _1 z  @9 Q
2. 数据预处理与操作
' I8 i5 i9 E$ e2 o3 Z3. 制作好数据源+ Q" A" u# O% e9 c* R$ n* W( a
读取标签对应的实际名字
7 B1 A' N5 t! ?7 _5 c- F4.展示一下数据
- o6 N$ {9 U  a; U/ o, j5. 加载models提供的模型,并直接用训练好的权重做初始化参数( }* G. x* W' p
6.初始化模型架构  c$ _" g3 J5 q6 C8 Y
7. 设置需要训练的参数
% Z- p! M* D' w4 l' d" c7. 训练与预测
$ B1 P$ D2 E4 ~3 P+ l  D7.1 优化器设置
2 j2 O( V& }3 G& k! N7.2 开始训练模型; j% [5 g* P. o4 P
7.3 训练所有层, S& _# U- d1 I6 s0 A- z
开始训练
! v6 g; _. d) C. U9 K7 J5 m" g8. 加载已经训练的模型" c3 M' _( T' p9 t
9. 推理
3 z" \  [' \. x( l, s3 E9.1 计算得到最大概率
9 e: N+ g" B2 `% R, E& _5 ?9.2 展示预测结果7 X" ^% O' y. J2 X$ e
写在最后9 ~: O, W# ^1 D2 Q/ f
卷积网络实战 对花进行分类
) ~4 N; t  e* `. D本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
. s3 [' V2 @3 N% `2 \: \/ H: X( T4 o9 ]( ?# N: l
在文件夹中有102种花,我们主要要对这些花进行分类任务# |7 F1 g+ v: V4 O+ m0 O
文件夹结构
& }% U$ b4 R9 {( I' g' T: @! ^0 J* T6 X6 R' x% k* d0 l
flower_data  d) \/ Y: g" w) h0 ~
" d: W. l# ]7 G, g$ d
train
7 T4 U2 y$ _4 l- u: U  I4 [- i: S& N) _& m. B& u9 \% t3 w
1(类别)9 I0 U2 k! e. H! U+ h& B' T2 \. Q0 [
2
: U; Y6 R% g& L) Bxxx.png / xxx.jpg8 X# m, v. H/ B+ V; a$ X
valid
# n: N2 k8 B6 @, E& [! P
/ C/ [- p8 ]. I  S8 p9 F- n! G主要分为以下几个大模块" l2 s0 v; N6 Y0 p" v. \
. Y7 V- K( Q$ V, v1 k4 e4 h
数据预处理部分
. A! E$ O" d, Z5 J0 M: j+ R& T数据增强: m7 k/ ~2 n$ R
数据预处理0 x1 A7 K' T+ ~: g) J) z) a; M
网络模块设置
- A2 J0 x: N% U/ Y  ?4 d加载预训练模型,直接调用torchVision的经典网络架构. Q# |- S$ e5 U( \8 H  e$ L4 c3 H
因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
. U/ E2 y: ?; K# x0 ]6 n3 m) j* `  ^网络模型的保存与测试5 L) V, o, X; _' }; ~* N; s: O' ]( `
模型保存可以带有选择性
/ n: L. N# ~5 x) `9 _" }% g# w数据下载:8 Z# E. {$ g: e# g9 n- E; A
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
6 `1 L2 U0 H& }; r5 |, Q* S3 p' }9 W2 g+ ~: e0 E4 |: M
改一下文件名,然后将它放到同一根目录就可以了. V+ k  b, Y/ c4 [1 i& f

3 O% A$ q4 L0 |& J% ^下面是我的数据根目录
) V: J2 m' t) R. a. ^/ |  M. W
* W# ~6 f' l# R! W' r$ w) B" l5 ^" Z% T: Q0 ^; q
1. 导入工具包
! J0 X, I- h$ |import os
! a' |+ N1 Y8 I& k# uimport matplotlib.pyplot as plt& w# V* D9 l* Q' R  i4 g9 h
# 内嵌入绘图简去show的句柄' A6 ?2 t* n7 T' ~) a, g; A* S  X
%matplotlib inline " y2 x  O+ c# J) S1 J8 O! e& o
import numpy as np+ @( p1 ~$ r. B  G  ^
import torch
2 U9 e: u; I& o% C* Yfrom torch import nn6 |, `, V: U  b2 ]5 Q, J& r7 D

- [3 M5 Q* ]8 R) x2 g6 S5 wimport torch.optim as optim
7 u" c* `8 A2 n; g, g  yimport torchvision
9 D0 T7 G/ z& U7 q* B$ f$ x+ j+ dfrom torchvision import transforms, models, datasets
2 B9 D0 s% l: z& A1 p; U9 `/ Q* ?4 n2 S* u0 z9 v$ U+ P0 B
import imageio5 E* U1 o- c8 f: \" o- r; w) |
import time
6 C( \; ^# g3 oimport warnings1 n+ m0 J* G$ E# c6 w  K3 m
import random
" m4 j& Y/ ?, U: mimport sys8 ^. E& o) d" I
import copy! G- m: o' }6 \% ~
import json% \  f: [) d/ q, q% k
from PIL import Image
' W6 @7 a; C5 D8 r" B2 d* G8 @* @9 _/ ?0 a. p$ v; B6 S$ |6 ]
8 h0 t  K. ^9 x0 O
1
% W- @6 v' r5 x7 f2
5 B- X) \" m, V) f0 M37 B( D/ j+ v4 ^: y# z7 @( I
41 r; T8 Y& @2 k
5
# `6 X7 @( b6 }5 h7 f$ w2 g9 r9 c6
- }  D: p! G; o7 d. P7
1 d: K$ \* q% ]4 I8" t/ B9 b  @3 K4 k1 G: Y
98 i. g- O8 U- Q
10: D: @( j1 R0 v! I4 I$ O0 w
11
2 e, u9 h( U7 J  q, q# @12: k9 o( M, P4 C2 [4 ~9 X
13% C) L  u: W1 b" ?
14
7 v; h# L; Z' \; @) v15' T$ l" S" T- z/ ~& Y3 p; s' f$ j
16
; U6 G! S9 n. P( \6 U6 v17; N# `: L0 _  }
18. `; f( ^! H- O0 W
19
7 c% J& d/ j; s- ~20' M; b( Z7 t/ I
218 r% M% z7 V" [( N$ b( `( P5 K
2. 数据预处理与操作3 F0 \$ m: f  K/ g1 ^" D
#路径设置
, i- t- g. o: F# F8 @data_dir = './flower_data/' # 当前文件夹下的flowerdata目录
" y8 W' M8 U5 h6 C" s5 gtrain_dir = data_dir + '/train'+ \# r! N2 h6 Z4 [7 l
valid_dir = data_dir + '/valid'
+ w0 S" c& `1 {& q1$ @0 e+ @2 B/ g( I
2: ?& C% Z5 n7 n" @" Q; u2 A
3
' `8 u( a4 \) ?' J9 T4  J( F" i" z6 B2 O
python目录点杠的组合与区别
' A/ ]$ e% t# b8 a4 L9 |' G注: 里面注明了点杠和斜杠的操作
2 a0 s1 d5 A+ d  [4 ?+ ?+ e& o8 O1 @+ n3 o' p+ w. Y! _7 p& ~* U
3. 制作好数据源; B0 V$ w  `2 R: f- l4 E
data_transforms中制定了所有图像预处理的操作
2 ?4 |( H& j1 p7 ?ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片! s/ }0 P; M( L9 H$ u! s4 N/ T8 ?! i
data_transforms = {. C- y" f6 y& r8 n/ ?% k7 X) X+ d
    # 分成两部分,一部分是训练
/ n) ^8 o7 Q. u; `: C* L9 m    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间
( c* t- v5 v3 W1 G, w% q5 \                                 transforms.CenterCrop(224), # 从中心处开始裁剪+ d; e. K" D9 @3 A* t
                                 # 以某个随机的概率决定是否翻转 55开9 C, |7 z; D5 g( Y3 r$ v6 O
                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
# f/ X0 M" G! |5 _2 T( U1 q4 u                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
  _7 \9 g# N7 i* ~7 O( X                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相! U! ^: e# }2 B# F
                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),+ E. E9 U0 e* ^2 @* v9 X) J: {
                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB+ \- w6 E+ b; a+ S. W2 e
                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的8 T! \. ?* F0 d& ]1 j2 U* {
                                 transforms.ToTensor(),- k: X$ T6 s: p1 O2 I
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差  S7 N. R* X' t" G
                                ]),
4 e8 a1 O& [) O- U" K    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
; F5 R: l. v5 R. o5 H$ J3 l    'valid': transforms.Compose([transforms.Resize(256),. Y' x- S( z4 T
                                 transforms.CenterCrop(224),, Z) e6 S% `/ T: A
                                 transforms.ToTensor(),: v; n! k1 F) c/ P1 J$ p. n
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同6 T" A  h  K! A% [8 I. k% \
                                ]),' p+ j" k$ u! X$ ^
}
; M/ N. k4 D( g' t7 b8 s. s2 H  w4 r* ~
1" S% J8 }8 R* H7 z0 F. ?! w  H' t
27 Z, \) H5 o1 s6 A
3
" B3 l( h) v1 {" L9 `+ n+ ?4 V4
6 w/ C; q' S$ G  ?: O5
. l4 H+ ~9 \( V6# ]: I; a; N: h, d# E6 S
7" q( h+ o2 m( l* U9 q
8
1 X$ ~- H3 w, n$ u8 _7 O" G9
  \' f6 d$ G. D7 ^10
6 R2 `! X9 b) I11
$ p3 s! D' W  y1 V+ f8 K12( r! t& B. z3 n" S# l( I
13
, {- {: x+ x/ z1 Z; m, B2 R14# U/ ]" P5 ^. m5 U
15& |! F. m/ O" H& Z- @
16
  l- o  }/ Y2 _) m& I/ e17  R2 L9 k! l$ n6 I' H# S; X1 Q5 w
18
- B1 W: N7 G: ?9 O" t; q$ F; _8 K191 v8 m" G" x3 m2 S% P5 P3 B
20# e7 J; @- E: W! D2 {* B
21
4 e/ y: v4 r9 b' t  sbatch_size = 8
- {* @/ U! }/ Y! @# r& q( F+ Oimage_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}4 ?5 l5 v2 S  a. v# ^
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
7 M  G1 @% v- j3 c7 X) Y9 bdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
! q" W# F  i1 {  s9 Eclass_names = image_datasets['train'].classes
' }0 U/ g' E# ~3 d% |- g4 q1 ?: Q4 E& V6 _7 E
#查看数据集合  h6 @  r, K  ^2 }# l/ y2 N
image_datasets; v$ M& W3 _0 Z( _

% |2 A0 N& ^( L+ {$ ~3 Z0 s1
. Q3 @% Q$ c; b: U: }+ ~2
! @  k# }8 D8 @3
: [( U$ ~2 {5 g, Z  C1 b4; u. R! K) _* {# X* z/ E" m% R- `
5. E# O3 o4 C8 B% B# ^' H# ]
6
& C8 O0 l( Y+ W0 M7
  k* o8 K. V, E; b1 I5 i8" }' ?% l; N) V  Z  F
9" ~' i9 X2 ]% M+ O4 J( }
{'train': Dataset ImageFolder
5 Q6 L: K' d1 \* ]& V: r     Number of datapoints: 6552- P; c2 L  O2 R( c
     Root location: ./flower_data/train, T0 I3 s% [+ v& g2 N
     StandardTransform
  p! v' S2 C1 Q Transform: Compose(
" ]- n. W0 A+ S4 ^                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
6 F. R+ n: E6 {! i5 z( S                CenterCrop(size=(224, 224))
% S2 y! }' S+ C: p                RandomHorizontalFlip(p=0.5)% q) ^8 w( Q7 N' a! i( d  W
                RandomVerticalFlip(p=0.5)
7 x7 e: r$ C- d* Q+ X, p- u" G                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])* t3 c# B2 u7 C
                RandomGrayscale(p=0.025). `; c. H# [. }/ T' Z  J: n
                ToTensor()$ D" m$ u  |3 |* @
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
* t, ?- I4 _0 \& W5 e6 b+ u            ),( {, ~; ^4 P4 y; Y- c4 ^# C+ u2 |
'valid': Dataset ImageFolder7 ?4 i1 U; i& X! ]
     Number of datapoints: 818" A& l: N5 o6 ?3 S
     Root location: ./flower_data/valid
1 \& K( u  v3 s5 [     StandardTransform1 `: e1 G. r' q3 A& \
Transform: Compose() [' l6 \3 N$ r  b# m. N0 V
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)  @/ _8 o8 C4 a& G% R" P8 g0 S
                CenterCrop(size=(224, 224))' j$ e$ o6 S6 S, E: y& A! A
                ToTensor(); g' n  p# u' ]
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
. g  W: R- J. ?) H( a            )}3 V% n- I2 }5 I8 S) D' D+ `
2 c* m) Y- R* r/ N; k- w
1
( j* o1 A& P1 j4 W22 [% m" n* Y; w; N$ \- W
35 x6 g* l# ?  N, `: @5 r
4
8 N7 Q& k3 Q+ `& I2 P: H5
; x1 D* o; k  ^5 Z; X* _6( h& i" g# z0 ^: i
7
0 u1 ?1 q4 r5 C0 z. F# k8
, T; ~! j! r$ J& y& l  K: Z6 [- [9. J7 M: E$ ^  O, C
104 }4 X, \7 x3 L, Y
11  ~" \) q2 H+ m
12) ~& q4 b+ r  e- M
136 z8 r1 |5 B7 ]" k  w2 @
145 u/ m" N& n8 Z6 q
155 l, s8 ^! f2 H& m  N  k
16' Y$ Y# P6 P4 `: R+ d3 c! S' G
17
1 o, K4 X) z/ [9 R18/ [/ z9 G1 P9 l0 O7 c7 o
19
% h3 R. E% k9 Q; J4 N20# r: ^6 _8 G! B, t1 U4 T
211 J' B, v9 s- S1 Y
22
! A5 v! V% Q8 i! ^. V( l23! Y8 c0 v" y/ d3 h' j% p
24
4 W% @* X- G# x# 验证一下数据是否已经被处理完毕9 q$ O# y) J9 T2 k2 O0 ?
dataloaders
0 \& ~6 v+ l( N5 D1% n" H3 u3 Y$ ~3 u7 r! w. |; b+ ?
23 n1 @( @# W! g$ ?8 L
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
2 W4 r- ]; m: b7 n/ W$ y 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}# F1 `0 ~5 S; H" @( I" P! [
1
! a6 M* s1 b  g* b& L/ J2
. ^( L0 `, D2 }% n% Udataset_sizes# O- v- _/ u* z) e" K0 a0 N
1
0 u3 S- U. b) ]+ p# Q0 e0 q/ a# ?{'train': 6552, 'valid': 818}5 n" E# c/ _6 R
1% R4 P& U+ G) F1 V. `# j: {
读取标签对应的实际名字* T- Q* I  G# L
使用同一目录下的json文件,反向映射出花对应的名字
+ d) M  x% z$ q2 J. U& }2 w* x. @0 g3 |% d3 H6 z+ `
with open('./flower_data/cat_to_name.json', 'r') as f:7 x, u' T, i" S, h  K) j2 O7 F
    cat_to_name = json.load(f)* }  j2 ]5 {; t* y6 @, `( L
1
9 r  N) l( F4 [5 P9 {* [8 l$ S2
' V  e7 C  @3 B# ?* \/ B- mcat_to_name1 M, s/ l" V) d4 ?  I. r1 ?8 |9 ]
13 T# y5 w3 l: D4 e! @
{'21': 'fire lily',: k, `8 J% h( @0 b- a
'3': 'canterbury bells',5 X) u( C2 k$ W
'45': 'bolero deep blue',
( M$ D" U0 ?- r4 y '1': 'pink primrose',
) X! l2 y! V6 i* L: V '34': 'mexican aster',8 g- C' V2 v5 e5 ]0 @: d: |8 T6 ]
'27': 'prince of wales feathers',( g$ |5 ^  v* `. U- {
'7': 'moon orchid',6 a+ s$ h  ]2 r! J+ T
'16': 'globe-flower',% d0 `- W6 V+ w' w
'25': 'grape hyacinth',
0 w3 _2 j7 i5 a1 x1 | '26': 'corn poppy',
1 d% `3 z; R" s, i; f# o0 ]8 b+ d '79': 'toad lily',3 [5 [  _- P5 J% @9 L
'39': 'siam tulip',/ a/ T) Y0 M7 n: {/ R8 `2 e2 r3 d$ Q
'24': 'red ginger',
  B; f- \/ U% |2 H% f2 r3 _ '67': 'spring crocus',
8 G+ _' ~+ ~8 s9 K7 {' {/ O7 R( _ '35': 'alpine sea holly',. T& T/ H, X& T$ }5 W! T) u
'32': 'garden phlox',
: S, N/ p3 G. q8 q9 I$ Y  I '10': 'globe thistle',
: A# z) a5 e+ v. L4 A, e '6': 'tiger lily',
! k3 `6 V% b, s/ V '93': 'ball moss',
' K1 j9 X3 p, b2 Y; ~ '33': 'love in the mist',! e+ W  z& [8 F/ B, R; t
'9': 'monkshood',& ~# s) t* _8 R2 ]
'102': 'blackberry lily',, e: p( J5 I  h
'14': 'spear thistle',' j' \& @8 s8 U+ w! A
'19': 'balloon flower',
2 j$ P6 M+ i  V$ b6 b '100': 'blanket flower',
6 {# D2 z8 ^  t, @8 x# t '13': 'king protea',' d! ~+ V2 z9 p# i' F
'49': 'oxeye daisy',! U$ N. e' Z" ^  K6 Y, s8 Y2 ?
'15': 'yellow iris',$ Q( C; D; a0 d6 L. U
'61': 'cautleya spicata',
4 r; F& h2 Z4 h '31': 'carnation',2 a9 H/ K! J, x. }2 M% F
'64': 'silverbush',
. h0 V( s( M0 C0 @1 ]) F '68': 'bearded iris',
$ I9 A/ C* v( F6 K  E '63': 'black-eyed susan',
( U7 C" Z3 Y" V: Y '69': 'windflower',
. C- `8 R0 y: r* y' u" v '62': 'japanese anemone',
5 u9 X, h% }0 n& Z% ^ '20': 'giant white arum lily',
  L6 y9 ?2 P: ^8 z! f" b! K '38': 'great masterwort',7 r8 d) p* x% ^2 G' R; q
'4': 'sweet pea',. z. S: x( I. Y( I  v' l0 u; x
'86': 'tree mallow',
! f9 r' B6 I8 V* M+ h '101': 'trumpet creeper',  w" l/ W% u  e, @- _! a8 g
'42': 'daffodil',( \8 `& ]9 w# g* e: L
'22': 'pincushion flower',
5 x1 [$ m! r8 A% R; H '2': 'hard-leaved pocket orchid',- |* ]& V7 w6 Y* A. `5 w0 \  m
'54': 'sunflower',& M. D. h) g4 X4 _
'66': 'osteospermum',( v, q4 J3 |% {9 G
'70': 'tree poppy',- {- A4 Z- h7 p. n: f2 R9 m
'85': 'desert-rose',
7 i9 }9 M8 H! i3 R% e. D( g '99': 'bromelia',
9 c* Z* F6 F0 E0 u/ v' d '87': 'magnolia',
* u6 Z$ z+ Q: a) I5 v9 N! g* S '5': 'english marigold',% r. W4 S7 O5 L1 N& H# t+ q- B  t
'92': 'bee balm'," d) _% J6 x* S  _* k8 s
'28': 'stemless gentian',1 a+ Z1 l, m# B% m! E4 x, z' [" ^& a
'97': 'mallow',  ~- {" o( O& t
'57': 'gaura',# K; |) K; m) N+ R" D* u" j: t! t
'40': 'lenten rose',) Q, p) f( a( A' N# m4 V5 E
'47': 'marigold',
/ N/ b5 q7 E7 _ '59': 'orange dahlia',+ B6 _5 D+ Z/ S' |! S
'48': 'buttercup',  O8 n# z+ n) h( ?; _) i% l( Q
'55': 'pelargonium',$ U( B+ u8 T$ |+ V& ?4 G, _
'36': 'ruby-lipped cattleya',
# Z& b7 Z" ]+ s1 D+ R- _! _ '91': 'hippeastrum',6 h0 i  d2 k2 @0 E4 o4 x0 X/ Q
'29': 'artichoke',
% w4 |0 r; U- S( o3 R& J- K- J '71': 'gazania',
6 R6 X2 V7 ~& U0 N* A- A: N '90': 'canna lily',
" O* _) a) _9 Y. ]/ ]2 }0 k '18': 'peruvian lily',% L3 Z  D3 b  l2 \
'98': 'mexican petunia',
" x" j' F- u0 R5 K/ @' O7 T* u '8': 'bird of paradise',: ]! _" S" [0 m" f: B5 f2 N
'30': 'sweet william',2 N0 J7 q1 Y; a- _1 d
'17': 'purple coneflower',, z  }2 U1 _5 _4 S: `# R' y
'52': 'wild pansy',
. Y+ c% R. q5 }# b1 n! o, P8 j '84': 'columbine',8 s8 S% {' ~  `0 g( j6 x, W
'12': "colt's foot",$ j# ~" n( Z$ s3 K2 N+ D
'11': 'snapdragon',
4 A& b' A" W2 y# \$ [+ I '96': 'camellia',
  `4 V' ~. \" d1 u/ d  E '23': 'fritillary',, Y7 [% V2 F5 i& ~2 y
'50': 'common dandelion',
8 J. C7 u. j: t '44': 'poinsettia',
8 v% a7 X4 A7 h9 f  O: g9 q9 U0 | '53': 'primula',# {. ~) G$ p2 ?1 h
'72': 'azalea',1 o% k! n6 u0 d- L* x
'65': 'californian poppy',
6 N7 q6 Q* s) }7 u0 M( Q '80': 'anthurium',) a6 i# P  W, y1 m9 M  e
'76': 'morning glory',5 ~( C: ]" B0 N: h5 b
'37': 'cape flower',
& w4 k  }' ]7 @ '56': 'bishop of llandaff',
5 t# Q9 F; v% O1 u+ ?' Q '60': 'pink-yellow dahlia',* o1 T8 g4 H' D: h6 _; {! b# ~3 {
'82': 'clematis',
9 |$ t# \/ W% a8 L3 r. L '58': 'geranium',
) Z% z6 c3 n: r5 X& ^ '75': 'thorn apple',
+ C  ~# [- ~$ ]4 S7 G# c '41': 'barbeton daisy',% d9 C2 c0 v+ B" j
'95': 'bougainvillea',
( Q- `6 c( ?8 S7 ?1 o' Z  h1 A '43': 'sword lily',
3 e+ R, _  H6 i5 i; A; |* Q1 R '83': 'hibiscus',. t. q  l. V6 t& ?2 v
'78': 'lotus lotus',
: M' ?8 E: H6 i& s, t5 A '88': 'cyclamen',( X7 w9 ]; D& F0 Y" K3 ]# G) U
'94': 'foxglove',
1 g. x7 ^- i% y1 k) @) M# P '81': 'frangipani',
3 ~. i+ @! R. E2 V6 @+ U '74': 'rose',
6 @" Z* N3 j# o/ G' c, a( H '89': 'watercress',
' S, S' z4 C+ Z5 A( ~7 V '73': 'water lily',
4 |, R" F1 q* e0 p  ]# ^ '46': 'wallflower',* b! b8 X' S& E% z8 G2 H) P
'77': 'passion flower',6 Y" {! Q$ _, {- F* D+ ]5 i
'51': 'petunia'}, T' w1 w# |$ d7 d' D2 Q
. p2 W* E+ U" d. h/ `. \
1& D& |5 v+ h0 }& q% b: N
2
  o# L. l3 x# M5 P31 W: ]. {; u7 [+ f) _
40 x9 v, v8 B: N. N* c1 K* J
5- P+ J& ]2 I7 T7 W6 T
6
. m" e6 y+ x' W: c! a) s3 l7
& d6 u4 r6 E+ x0 D9 M85 J6 w- R! L' b. c/ s0 V' I0 W* E
9
, W) W7 B! J8 V( G- j: b7 a  N10
8 L& M7 ]8 `0 U; L; g117 M: S# ^5 k. T) C4 W/ z8 g
12
" p6 {# a7 a  B1 X  R  V13
& g5 f- n$ \3 h14
' s  N/ d' o' }! H) l15' i/ Z, {( G9 u2 V9 w  g
16
7 v: L0 e* u5 x& m+ ?175 W" h% D) [# y2 Z" i# g
18
! _4 J  ^4 f8 t# U3 q" V: P19
! N0 _4 B( G' F# w# _9 ?, f20( F4 n% b& E9 P
217 S& E+ W" r) X/ D$ z
221 v3 A; @' X! ]
23( t& b! f. ^) h6 x
24
$ F4 F# N1 G9 D25' K: M1 x/ g1 [$ D9 a
260 t  y' T" q% u# ], A' V1 F
27; x8 V: Z! B2 O9 J  |
28
2 g- d$ `- B2 o! I+ }/ }, `! o292 R) j8 [) ^4 A% F# R
30" G* v/ O1 a1 m' w! ?
310 M  d: a5 w6 T8 n. `- T1 ?4 M
327 w! E# z' S. n: h* A9 M* J
334 D$ P+ o0 |  `& ^# W4 ^5 r
34
8 {; i. _  Y0 k' C" m4 Q* j; x358 O0 [& J# U1 I+ t# [, T
36, \' S( l% n, w0 N7 v7 ?
37- E* O$ z% Z1 Z( ~& B
387 G- z/ y* c* \( O  q# C6 T
39
5 M( Q/ i& X, U; `8 R6 f40! I7 H! c6 y5 g! t3 M% l! ]  I
41: x8 P; S" F1 c0 b
42  o+ N/ ?# I6 z6 e( k
43
* d6 a; l1 ]; Q# }7 \$ L44- t: u1 y8 D( f( E+ h# i6 v# ^7 R
45
4 |' {+ I6 g4 W5 r  H: N46) t: i; y8 l( A. t
47; }5 o$ D/ E* }
48
7 O1 i6 C+ m$ Z( h/ k5 X) s492 ?  @. x" d6 E4 X
50. T# \, u- \* \" z* g' D* f
51
3 Q& s: F- v0 e! g52
- |" @8 L7 K# {. z3 D0 @$ ?8 w5 h9 b6 g537 {- @' ^* A0 v0 r/ |" ?! r% `# b
54
" b; n5 f3 s6 Q/ L1 D55; W( @& t# M* M& l
56" p0 N/ w4 N7 ~$ e
57$ Q" a& C2 k$ _+ E; [; h
58
; h8 B) R9 z8 k! o" g59
+ p5 O0 ^: p5 k' j" I8 C' O6 i" Z60; Z9 [7 {# \% s! ?6 l& d1 q
61
3 U9 l  d7 ]4 E  M623 D) S; q, |: M0 z' ?
63
9 i; ?3 K' G! k6 b3 h64
6 F, {: J$ m* j" m% y5 h! s, |65
8 I; K6 Z, }5 g. R66
2 ^1 s( ~3 ^% a. {# T* l( {6 @: m677 N, }; B9 T% s% {- @
68
7 X9 }4 w6 Q: H! \3 ]) [$ l6 N+ C69
- s8 Z! s) d7 d8 |% i0 V9 @70
# a* R, D8 d+ v) Y! g71
, @5 z! U# j3 N* s72
1 L- C- Q+ N6 k! B& A# J6 D73
; b0 j" T( q- [) h3 l74$ s' ?4 D0 t# x' ]
759 S' n* H" W/ ~* t- c, h
76
4 i3 o0 L% @& _/ X9 i+ ^3 b1 g77
3 q3 z2 R  d' G+ X& F- m! V78
$ H  j& P# C* C: J79
* e& N: P2 F" k8 ^80# F" p, d4 F. `6 b; X% |. i9 ^) t
81
/ I2 C1 w5 n5 ?& {" P/ e3 U4 u82
& }( E' E: f( n' V# @4 ^83) G/ G9 V; D: W9 n
84- x$ R1 z2 d' d- `# d, F
85* w0 v# n8 c# b* u* H9 P& N
86, A+ J; C& ^" H: d
87. y. z( h9 Q: N' ^
88
$ }; B& c4 ^; H8 ?/ j" c7 I. M6 m: W89
0 Q6 P, t! X7 f; A90! P2 a( f2 h6 ]; ~- r+ J/ M/ [8 B
91' f' c1 i( Q  i$ p9 \
92
" G. V- _. `( Q* I% A, O93% T6 c/ D$ L2 I3 ~2 m
94
" g" r# L* q" C$ O8 U959 o, K. N( J2 x# `% a8 |
96% u9 ]. e2 f- g; `: L
97
; ^; S/ `" l2 Z$ E4 V4 U98
' H8 Y  A$ y' G7 b. n" L% t99
+ R  n9 V3 |) `  w: |100
& D0 e0 z7 U9 Z0 f, [, N: A1011 }3 k; ~0 f+ _+ y) E
1029 C( z' j! [. V6 \" ?! w
4.展示一下数据1 f+ S6 K7 P$ N, X2 T3 v4 ]
def im_convert(tensor):
, O$ R) S6 K& X    """数据展示"""1 P! \1 c: h8 s9 |6 X! T+ H
    image = tensor.to("cpu").clone().detach()
+ F) h- q- z8 }, u5 t    image = image.numpy().squeeze()2 q6 N+ T: ~, T8 u& g4 z$ j9 ^
    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
: b* q4 y) e$ i8 e% n    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)+ m  U5 x! |7 Q
    image = image.transpose(1, 2, 0)
2 Q9 y" e3 X1 `) {    # 反正则化(反标准化)
  g1 b. N, [9 j$ J9 j3 s8 m    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406)); |/ C$ |# e' p* @
" B5 T, u9 C3 F( g& r5 X7 K& ]
    # 将图像中小于0 的都换成0,大于的都变成13 p' o) D: [! t8 e
    image = image.clip(0, 1)
3 ~/ ?4 P5 E. ^2 r% U4 d" ], _
3 l7 |9 a) e8 k1 u. `    return image
: i8 s6 P$ D6 W1
  y. f- c; z9 _6 {1 ?2! w4 ^! J) G6 V; [- A: {
37 e! n) }) m. ]3 e7 w5 R5 C
4) S3 q+ _4 Y, B: a
5
2 p5 B3 i) j0 F$ E- R6
4 M% z# ^. h& a, P7; E* A/ M2 l# _- y& `4 ^
85 W& z( {4 V% s0 J! L' j
9- H" ~- i! v" C+ I
10
3 u3 y; \2 Z$ l6 G: L11
9 B4 G. Q: o. k' z; O* O121 {4 N) [8 H2 W1 X% l' M6 ]
13
  f. E& i# e. N1 e# v2 }/ }% ^14" p2 B) X$ P" G, U& b. ~
# 使用上面定义好的类进行画图
' r- k( P5 ^8 A, w3 G  S. c  tfig = plt.figure(figsize = (20, 12))
  Z8 @9 j+ U9 I1 ?2 Z8 a" qcolumns = 4  V6 J4 o9 `1 I
rows = 29 w/ f: U) D* c$ }$ r: }

9 n' v$ t4 H# ]! y# iter迭代器  ]6 u" r/ M, l, T' Z5 A8 f
# 随便找一个Batch数据进行展示
5 x$ q+ R0 o/ Ndataiter = iter(dataloaders['valid'])
0 u4 X7 _* z3 Y, y2 ~" l; Uinputs, classes = dataiter.next()
0 ?7 Y8 ?1 h. d$ q* d# h9 r+ t9 s' b  k' d, l1 z. \# b3 f. f
for idx in range(columns * rows):+ w) a& k( U1 U( c# E1 t3 ~9 P
    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])4 J" K1 H( a1 A2 G0 T3 T
    # 利用json文件将其对应花的类型打印在图片中8 P6 d$ m0 k7 T: P; u
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])8 R! h' A% P. A7 t2 z- x" w, ^2 @
    plt.imshow(im_convert(inputs[idx]))
& a' g' g$ S2 d# w6 c7 Cplt.show(), l9 \4 A% W) _% R5 R. f, V

% E  o2 g# F& W- Z' ~% _1/ N" [1 H  Y8 h8 j& w3 D  o
2
5 r3 N" z- S8 G4 \33 c/ x0 K' j/ j" e+ s  \% Z
4. }1 s. t+ a9 b
5% r" ]+ U! }' t8 P) D6 H
6
1 ~4 i4 W3 O1 G: D% o7 x8 x8 t74 ~/ [" i$ t% G/ P
80 y& o4 b- w: b: R4 C0 b7 v' ]
9) P$ E5 }8 t% _
10$ I6 U' ~: t3 s. H
11
, ?! s5 ]0 ^( D& f% ~. W6 Q0 _12
4 e: Q, M9 Y- X! [13
. v- ~$ ~# C/ ~; y14
( O( V+ N3 W( |15
- `1 e4 T7 [$ w% e9 ^; w! d, \# _+ T16
9 w, @- M. {/ u+ Y( L# ~/ Y: r9 M( J
; Z! D2 t5 X; @: ~! i6 |/ r* l. U0 ?. u# l' U& C% M
5. 加载models提供的模型,并直接用训练好的权重做初始化参数- M$ w7 f! {; k1 i: h# T& K
model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']& Q) R: P  U& f0 U8 O. X. b
# 主要的图像识别用resnet来做7 I" `/ L# Y1 L6 |& P& c3 e; D' v
# 是否用人家训练好的特征' w3 @+ U0 ?/ p* q1 L% p3 R
feature_extract = True- P! G$ z# |' q- S3 h
18 i+ n$ H8 o8 H% @% ^
2
& n; |& ~) R3 D. }* a2 O3
& F" X  Z# i; U& h3 ^4/ X5 S3 W& c! [- w( F4 e( C
# 是否用GPU进行训练$ D3 }) `8 O8 M- o2 t! s8 r; P6 c% K
train_on_gpu = torch.cuda.is_available()
% A& a* V% V" r' Y+ k5 B+ P4 X! W# A0 j9 W. _- _! J+ c: e
if not train_on_gpu:
4 U, o9 ~: i1 w  |    print('CUDA is not available.   Training on CPU ...')  b( j- D* D* }0 j
else:! X0 M. S4 t5 z) O
    print('CUDA is available! Training on GPU ...')
) M* x; f' j6 ]7 c. m2 U7 S0 T3 A' a  }% O" ]+ u' F# n% y
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
, i: g  L" n0 }( K/ V2 J  y$ w8 u1( i+ }0 R& S4 ?
2
# C/ }" V3 z- f  x+ d1 t: `3
5 x4 y4 K7 y$ ?& i2 A4
7 }9 N/ \" d7 k& `5
& l. _4 B$ y# S' E5 U" D69 Z+ M7 m" [& V9 Y8 W
7! M7 E3 W7 @% b9 \, ?
8
1 `$ c* G: `: N9
+ Y8 N% i$ v0 ~( i0 ~/ @CUDA is not available.   Training on CPU ...
6 }( m1 w& `( k1
' J4 P. P3 N  r. P9 m) h$ |4 {, }# 将一些层定义为false,使其不自动更新, |3 H( b( V+ L) ?3 A: R
def set_parameter_requires_grad(model, feature_extracting):; V! u  {2 ^8 g  p
    if feature_extracting:! C0 R( t& }& H  V% W
        for param in model.parameters():; e9 ~3 T: }: w+ f. u4 d
            param.requires_grad = False! O, s" o0 D# V% R, s
1& [7 D7 a8 \1 L- y
27 |( d! r) y2 o9 }8 `
3
. \/ t1 x& O& t8 [/ V1 `4# Z' S) H3 k0 d5 ^+ e5 r( M& b
5) \" r: n' Z) {6 Z" ]
# 打印模型架构告知是怎么一步一步去完成的
" N7 L' A! T3 Q# B- o( a9 z0 K# 主要是为我们提取特征的
* M6 y& e+ L+ x
5 `8 R' w! M* l0 b" fmodel_ft = models.resnet152()
* Y% Z! _2 @& q3 q7 f& g6 Vmodel_ft
- R2 ~0 w  H0 K3 G/ Z/ H& W( w1
" F, d) I& v! N) X; h5 E+ w5 p2
1 T" S  [: C: ]0 _3% x8 I! P8 @- k5 u; \6 i3 \1 g
4
. t7 K  h7 N0 p5
2 w1 P/ d# ]# I5 c+ cResNet(, Z4 E/ `% L9 |  ]' t! n
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  f& L+ D- x( o- g  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), R* ?, B- Q9 b( t! N
  (relu): ReLU(inplace=True)
: _% L1 w5 C! Y- [" A  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)! f$ Y( }* g0 x% p
  (layer1): Sequential(
2 o* Q! M# ~3 G( j2 v    (0): Bottleneck(! I* l' @( J5 U6 s; X, k' [
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)& }- k$ m) z& c% |" n, E
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
% V$ m1 B% ]! Z9 ~      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
9 Y3 _# x1 p6 ?+ [      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
% T' H0 h: T: w' m" ]# @# C      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
% G+ f9 O* N" D: v      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
; r4 J8 i- e' c8 ]+ V2 i% I' n6 S: \      (relu): ReLU(inplace=True)
, C( B! o( [: R! o( C& V      (downsample): Sequential(# \: R  q( |' Q/ n! }7 y: |1 I" r
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)! Y, C* M$ A4 c# c% i# i
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
& t0 w  N* ~+ Y' M1 B3 R      )/ ~" @( x  D& u" f4 k
    )$ Y* w. b) o9 W9 @
中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。  z. R3 C$ C8 B
    (2): Bottleneck(
4 M( ~% _0 F, D& o! Z! \      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)+ e, ^: p6 [6 u& F$ f$ R- X
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)7 n; g0 `" ?; z7 Q
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
! l! Y3 q$ X& Z% k      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)9 ?7 e) H  j  @0 g; Z* z+ c
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)% J8 g# x) T# i+ [5 N  x
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
" G  Q5 Y# V% ^      (relu): ReLU(inplace=True)
0 m  Z6 q% b9 i- W1 ~' a4 d; b    )
0 x3 P5 N* k  j+ o/ |3 X" m9 D" V. L  )8 X: X) p/ ?: S; _) U
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
( x% w) O) H& D( @  (fc): Linear(in_features=2048, out_features=1000, bias=True)
1 K: \( Y1 m8 g- @)6 {/ H. y; f: t! g; A5 b

4 Z- Q, D$ R" ]: S/ ^* b1- \: z2 [& a5 _! C+ d, o3 W$ n, J# l, O0 l
2
6 T% z( ?/ Y4 n- p+ D3
) {$ i$ w6 Z$ E; k; v: x1 u& Y; a4
9 f2 O+ |( M$ ?% a6 J3 g5
) \. s3 j+ d" A' C6
/ C( f3 K3 S; g- u1 [. q6 B) J75 |4 K- N$ Y2 s! Z, ?
8% @! ?0 e, ]6 ~: Z
9
; I$ X' g6 b" r7 S7 b10
% F, z. b6 b( v7 `, r5 P11
% `5 N# D+ ?3 M4 J/ T12( K4 ?0 q  w2 p& [2 I6 _
13
3 v5 g3 R/ N* ?* t0 `8 g147 d( `3 M# b9 t: v
157 k8 T) j. y& a/ x7 }* n1 |- T
16
5 W# W. \: T% E- @177 c  n; L, Y' E1 V# ?, [; K
18/ t- S3 E3 G; S- G
19
: |( \5 k* R6 Q9 X; o, ~3 A20
& H+ Y* T" d/ Q( j! D1 i* ?8 g+ Y. @21. ]! [4 L- g# c
22
  f/ [1 k8 j$ s, h3 P; x- s. O/ N1 R23
0 `0 W% n- w4 d+ L7 ~, O/ w247 [8 I8 {2 Z/ d# e6 s$ K
25
$ x9 w) g& g, H7 Z# e+ I! \: d26
  s9 \6 E+ H& W$ x8 Y27* y% m5 ^6 X; X+ ]1 L0 n
283 y% e1 w* T) H) [% @
29; S1 R5 B& b6 |7 V( u& B, W+ ~1 T
305 f8 n# C) j% v. [, h  i) z0 ~
31, T# p, |2 ]0 l0 q# e* r3 A. |! }
322 s$ Q+ b8 m$ }: G/ Y! T- Z
33! U% x& j  Y* r# ~3 o9 l
最后是1000分类,2048输入,分为1000个分类- y1 T2 B$ t. E) y! `4 \% P
而我们需要将我们的任务进行调整,将1000分类改为102输出" F5 a. K' E2 j8 Q! c+ [! V0 E9 g$ X

4 M! C* p7 A$ R3 c7 f$ p6.初始化模型架构
% ^$ i/ }- ^6 F! k9 Y% f步骤如下:6 Q- ?# R5 \4 {  d
) F6 R, e( T3 K. P
将训练好的模型拿过来,并pre_train = True 得到他人的权重参数" Y2 G* w0 A7 ~9 ?
可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)9 A: P2 Q0 ~1 j" O$ a& e# Q, ?
无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
: j3 K) a# c4 M官方文档链接
- U$ f  N6 O# ^https://pytorch.org/vision/stable/models.html8 q! Q& F% V2 M9 i' W% _. o' X

: ?0 x* p1 m6 F) s1 \) I# 将他人的模型加载进来& m/ J, D% u0 }5 I
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):& T! n# p/ `# h! Y7 z$ q) L( T
    # 选择适合的模型,不同的模型初始化参数不同% f$ C! X8 z" Y
    model_ft = None- c- ?! b4 g% x% X* \5 J
    input_size = 0
3 V. Z0 _2 s6 A9 E3 s4 T! c# E, K5 }0 e& V6 M  o) ?) u
    if model_name == "resnet":" z' d8 I( `& c( k. u
        """
* @* g- V, r6 N$ E" ^( g        Resnet152
6 `, Y% m- o7 y5 I. `, Q' a$ A" T        """8 @1 c8 Q/ t/ s8 g" w+ ~! d

3 F$ W% ?5 y8 o; o7 ~9 I. s/ T        # 1. 加载与训练网络, h+ W7 b! `$ @3 Q; v. S
        model_ft = models.resnet152(pretrained = use_pretrained)
7 ^% S8 E+ Q' `% D2 i        # 2. 是否将提取特征的模块冻住,只训练FC层1 t6 H$ Z! i; W( y
        set_parameter_requires_grad(model_ft, feature_extract)" e7 l5 F- E6 J+ u9 X' }
        # 3. 获得全连接层输入特征# O: F' F  r3 g+ Y3 l
        num_frts = model_ft.fc.in_features+ M& C6 _$ K* C# `* e$ T- j4 T# f/ P
        # 4. 重新加载全连接层,设置输出102
+ q1 F& t# m7 Q' a        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
6 ?) ^- ]5 P- }2 O: u" W# Y                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
, Y) [, X/ u  a  w) ?1 o9 [        input_size = 224
$ _5 L0 C( q) N! q" \
9 P5 w2 j* G- C5 d0 W7 z    elif model_name == "alexnet":0 D; t" i& c# Q# i
        """* `" t1 C5 N. {+ q6 {
        Alexnet
) G1 m/ z% Q' A0 L; L        """
( {9 t3 ?1 u9 [        model_ft = models.alexnet(pretrained = use_pretrained)
6 l) t# b  e- ^! E3 c, j2 Z/ N  V; {        set_parameter_requires_grad(model_ft, feature_extract)
3 C' V& d: d! G& G( R+ _; Y: h! n: m
        # 将最后一个特征输出替换 序号为【6】的分类器
+ Z7 A! L* g$ O0 s5 Y' j, R        num_frts = model_ft.classifier[6].in_features # 获得FC层输入
$ y$ {# u9 v7 j  i/ n& w. d4 ?        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
- g& z  V$ n& ?6 H, ]' }- B# k) x' G        input_size = 224
+ ]9 f4 o- U  Z2 b3 f& M. d5 o/ r
; y" l' `% j9 k# P* W0 [    elif model_name == "vgg":0 Y: R! `. r9 P  _) o9 k  i
        """
9 G. O: |. J) x6 a$ _        VGG11_bn
! A* Q. R3 s5 R* D        """
3 @: s$ G0 [, ^/ w: m# e* w& |        model_ft = models.vgg16(pretrained = use_pretrained)
- A6 d. X. p4 c8 Y9 j$ @: w+ v! {% C        set_parameter_requires_grad(model_ft, feature_extract). N% [  \. g9 y9 L+ v
        num_frts = model_ft.classifier[6].in_features
- G# u% @. |0 `* h        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
# N% ~) f% F3 ]1 y5 `. |7 Q        input_size = 2245 ^4 S8 E  d2 }+ E9 [3 f: ]# ?

. `3 Y  Y, B$ n, [, X, B5 H6 }) u    elif model_name == "squeezenet":1 B; C7 Y5 ^- y) N5 V: M
        """) q3 q# n( M8 H. o4 Y- \: H
        Squeezenet
) v7 p1 F, P  h) h2 v; [        """
6 K  F) J% M) {, B8 V" _. y% `7 v        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
2 S- M$ l2 p* f2 \$ f' W1 m! l+ s$ T        set_parameter_requires_grad(model_ft, feature_extract)
! e( a2 S  K2 x& {        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))5 h7 J+ s8 [* s1 B, V, x- K
        model_ft.num_classes = num_classes1 j) @7 z1 B* m) W, |
        input_size = 224
' c) S" j1 a4 F; u( W
: P- U: g8 H# N0 _    elif model_name == "densenet":1 A( I9 e! I: c) O
        """: F& J3 T' p5 j4 q
        Densenet
  X3 L: p$ M: a5 Q& o9 R5 K        """1 H- e% i% Q9 E6 N" h1 z; T. a
        model_ft = models.desenet121(pretrained = use_pretrained)  }2 }# y  B. F" K7 s$ ~9 w2 V
        set_parameter_requires_grad(model_ft, feature_extract)5 i" ]5 s# g; a  Y. [
        num_frts = model_ft.classifier.in_features2 a+ R5 }1 E: A- k; C
        model_ft.classifier = nn.Linear(num_frts, num_classes)
- y) u( P4 O7 w; q! ?9 D        input_size = 224
% p( G( ?# ?0 b: T4 I; T& n3 m
8 @9 [: g: `7 A) N& s    elif model_name == "inception":4 L' n  \& Q8 ^, W- O: v, b9 B$ a
        """
# F* {" q8 I+ A        Inception V3
/ u) [5 e, E4 m0 S( f+ S) ^        """
: t& W. E7 Q) |1 R' T        model_ft = models.inception_V(pretrained = use_pretrained)$ V& s; i$ W) [7 ]
        set_parameter_requires_grad(model_ft, feature_extract)
* @8 Z2 N. P5 [* ~0 Z4 G! o' @! E- c0 E; V& ~
        num_frts = model_ft.AuxLogits.fc.in_features: U: C, L' @0 |( m) Z9 r
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)% h; H1 M6 l2 b2 L  j) t1 X
$ p* a+ }7 B% B8 b; d  F& s
        num_frts = model_ft.fc.in_features0 E4 s' e- E& p- f1 t$ T/ W5 r
        model_ft.fc = nn.Linear(num_frts, num_classes)
9 i! [: T, t. g# u# W7 \! X        input_size = 299# G% s  l" I; o" I" q

6 E! z  }$ i* p; M" P    else:
7 M4 E) c9 K" y3 }        print("Invalid model name, exiting..."); g2 y- t9 e9 X% |
        exit()" s" T0 A- W2 U3 o
& J" d! N5 s2 G1 i' `. d' r
    return model_ft, input_size
' e( K9 E" j& e+ K9 ]7 L3 j8 M
' w: Y. H. X2 A2 j/ H1
1 i: ^" w% A- n9 ~4 M$ D1 y7 ?9 K2
3 m: ^! _0 `1 g* }3 w6 _% o! T; X3
3 F0 S5 d& N) ^6 }4
7 }- F9 [( ~, @8 _) Y2 T1 b8 D1 J) n' e52 ]# m3 ^; C& i4 r4 {$ m) X( g5 X
6
) n/ q- f" m2 Q6 h7 C7
& y8 s2 l, D  g$ R8
( D9 s7 o! ], X; P  D+ v: ~9& t! ~# A  I& y  ]$ x4 V
10
/ T  m7 H$ ?' M11% p" a8 I8 h  c' `6 q& v
125 F  h: S# b+ z1 G" `+ e4 o
13& R, s6 J6 `$ e. l
140 ^# r( U4 Q: k7 f3 m1 ~! t
151 l; P5 h" p% j5 o2 M
160 p. n3 G7 g' M* h
17. K8 K, Z' H: ]8 c
18
6 ]1 c7 N8 l- R; X' u19
" e+ _* f6 Y9 q8 n209 O) L5 h$ e/ ?+ V) ^) \/ ^$ l+ U5 @
21  s. h7 L/ S6 C' F9 Z
22
$ k7 B2 Y! P7 e) E23# h# t2 [! a7 O. K) O8 j& w2 A6 g7 j8 x
24: Y* j+ Z4 P, A
250 B# k/ O; E0 w% P3 K- W
26. N% Q1 L, m+ M  Y& x7 i
27( a, _0 c$ @9 O1 n4 H, n% b
287 `7 z' U0 p7 I  k, N
29. H: M+ Y/ Q$ B& \
309 o2 H3 o$ U& W- E8 U
31
9 ^( m9 s1 E6 }: [9 G8 E+ K32
$ @, v" {+ Q$ J3 T: {33- V' B: P7 o9 u  ~& U. e
34
( T6 m2 C; x" i$ L/ k35& M: o4 p# Z, G  V6 A$ e
36  s" J3 f. l9 ^9 D# U
377 h& R9 H. `2 l* V" S
38
) q3 _( ]& e0 M; S  T39
# }1 T! o2 X/ g8 R403 U- X6 P$ e  f$ I& Z9 J8 ^! x$ u
41
1 Z' C* ^" w/ f6 D42
7 ^  t1 u! o  z# \. W43
( J0 v1 W) e  S! J- p( [44# L7 t6 y; n6 W& t2 v, q
45
& a; @; M" k; K/ E/ I; [& o" Z46
0 D( d# y: ?2 U1 V  n* {/ M47
( W" }) U2 ]9 I0 z+ W( r7 r/ O480 \) H5 u$ ?% y3 S. ?% f
49
9 @8 H9 o. L9 d+ z503 v5 p& L, j) }8 n# Q. x  o9 c# g
51
$ V& {) X/ G: K52: a" z; I1 h! ^; W( O3 r
536 |% A- g+ _) _6 O
54: b( U/ A1 }7 a/ ~- K, u7 \) G
55, ?, s1 }- Q( ^; ?
56
0 f5 }0 b; z* }$ E57
. H* H; ^. k% ~- K- V58- w. z# a) H1 @/ a9 z) e- T6 g
59
4 k8 h# L( X6 Z2 m60
. u' o5 \8 g7 z5 C8 U61
# h  _8 O7 z8 i62$ J3 t+ I$ S7 j' B6 J4 ?
63. e9 t: |+ d+ A* }
64# M9 Z: w1 L" q  d( F2 X
65! j/ q) x6 v( x" H) U2 P
66. b4 y" A) w* d: M" f; Z8 @, d
67* h  m& ]4 E/ d
688 l% g/ E# z- C; q7 c
69& m* F$ E0 O6 Q/ @
704 v6 z- |* D! z
71$ A% K0 k2 K" B! C, a1 u
722 M  |+ l1 r; A( l& n
730 L: d0 ]5 _$ `: e& e/ K
74: f% z* h" w/ ~" V
75
6 @  g2 r; ?8 H  H$ R$ J76
7 [! r; Y5 J& |' y# s% Z! [; v1 V77! [# E( i3 J8 E$ E; k6 o: j3 Q0 V* L
78% c. V8 B; C- R% W
79
) A4 N$ @" R7 b80, d9 `4 X6 D, w9 o4 z9 J* t% A' Q
812 o  n0 E6 y* B8 ?5 h
82
8 F: m# x- N# X0 {# J83# W( w" T; w, G9 W
7. 设置需要训练的参数, L( ]/ ~) a1 w* u: u
# 设置模型名字、输出分类数0 \* m# ^) ^" |5 I9 b' J. U/ _
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True): w! L4 u" B+ o  z

# ^' p' \- c. ]# GPU 计算
% \/ ]; o+ V# _# ~) ^4 h$ Emodel_ft = model_ft.to(device)
" s7 }4 x$ b" P$ b9 B) Z9 G. R6 [+ r4 g& \; P9 K1 W9 t
# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取+ y# C% O3 w. j2 [; b' u
filename = 'checkpoint.pth': z9 `0 ]$ K, w

  E& {* R6 B5 C; o, R* x; d- Q# 是否训练所有层
! f& I* r! b$ b- R+ O9 Fparams_to_update = model_ft.parameters()6 g; k0 c; ~9 u) V) W! M
# 打印出需要训练的层% D' Z3 x* o( C; P8 f" B
print("Params to learn:")0 d$ V; L0 B  n
if feature_extract:9 F1 _# w4 ?7 i' R0 e9 H3 x
    params_to_update = []
0 `- j2 I. x  {2 e    for name, param in model_ft.named_parameters():
% [4 t5 J1 }% p6 O        if param.requires_grad == True:- B$ E+ d  U. M5 }
            params_to_update.append(param)6 ?$ H  w8 x4 Q( k9 J
            print("\t", name)
, J- Y1 R0 B( w; ?, p1 ~- i7 pelse:. R5 q( S. S8 e5 |' `+ f: z. h
    for name, param in model_ft.named_parameters():2 [3 I7 t6 ?) V* ?# k& C
        if param.requires_grad ==True:2 U) s+ C' i8 w4 f4 g
            print("\t", name)' H) D. V' c' M- K1 C9 k
. O( f6 z6 W% `1 M# S
14 Z- e& Y1 h5 G
2
% J! ]2 ]& k0 R0 [8 y( q3: Z4 e1 g" w: P8 p& @4 A: w
4
9 X* u7 ?8 J8 w6 H8 t. q, Q5
0 R+ @1 m2 k7 |6
& m" @% M1 [: J2 H, A/ U3 ?; @9 a; C7& a* Z7 B. N! j1 G  M+ d4 ?
8
6 b& Y" V' V* w, ?+ ~9
# |. ~2 I7 M; P+ w$ _: c10/ s. u4 D  A( f, l
11
  \; u1 b: F  l3 I1 V! t# T( @127 c4 _! v0 ]4 C8 N, k# P$ Q
13
) `6 B, h' w) L$ y, [) s9 e14/ ]0 Z, m/ k6 U& r7 o" e
15  U2 {- v' p, l/ q  Z
165 W3 v7 S/ j8 u% Y% T
17
" w5 e6 d6 d5 S9 r& v4 V8 `( c18
) m: T! P/ V- ~: w3 k19  R, g) c0 ], ^2 @7 j  O
20
$ P$ L( b4 X0 ]/ @( A/ ^0 s21
& C1 }) |0 f" q5 k1 K22
8 l  p0 \4 W7 _" q4 Y23
" ~& O3 O. Y6 U# aParams to learn:
) k. y4 M" M1 P$ x# h" F' K         fc.0.weight% |# A! S7 w& c; B) R0 T
         fc.0.bias, W7 _6 I) R. L- J' {/ C
1' J, Q# v) ~1 @6 |* @
27 E5 a+ u# Y! r3 R+ N8 S. X
3( R" t% H" F& H& _
7. 训练与预测& }/ B. K3 G3 a' O6 d* O5 F7 O& }
7.1 优化器设置
( P* z8 }' t. M( u# 优化器设置
6 j$ q/ h0 q' T* p- eoptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
" n. u: @  `8 H9 ~) ~8 [8 K# 学习率衰减策略
- N' i# e: l* Q% s" b* _- N6 n% Hscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
8 j1 \. ~  e0 l' B4 t# 学习率每7个epoch衰减为原来的1/102 ?* H) l/ p. R) f+ t, o
# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
6 i9 R) L& a* W% e: U" `) _  q3 j% [1 X' K2 C
criterion = nn.NLLLoss(), j: f  c  ?' E1 V
1
. R- `; d- e( ~* `21 Y  O2 j2 ~5 D( ^
3% e/ ~- L' @6 b1 b0 P; G2 }  ?4 T
4
9 O; ~$ I5 e! I% j1 R5 ]' C) Y" n51 n, z/ L9 {% Q7 D/ e6 ^) J
6
; z/ s, ?  W9 ]5 j, V2 f* a7
& y3 H$ w$ s) N: K; p- Y8
& E1 {% E3 u2 n$ k, S7 B# 定义训练函数  {2 K+ L+ V- u% `/ C* \9 Z: z
#is_inception:要不要用其他的网络# m( U+ V% G+ E1 T$ |
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
- P( u+ l; s# _) {& q' w7 N& j    since = time.time()
" a" v/ W! S& q- H: |* N8 L) O    #保存最好的准确率0 S+ i- y$ t' \" e4 ~# [. B
    best_acc = 0
. K) k8 M* D; v- F7 E# }    """6 a6 @! o2 i4 I4 w0 E$ I
    checkpoint = torch.load(filename)6 `3 z. N0 w8 t. ^& p
    best_acc = checkpoint['best_acc']
4 ~* d5 ^  h  P3 Y    model.load_state_dict(checkpoint['state_dict'])) [. \# Q6 A# [6 k
    optimizer.load_state_dict(checkpoint['optimizer'])$ t( ?4 n. R6 O% W$ V; F8 U" o: ^5 z  G
    model.class_to_idx = checkpoint['mapping']' ]6 J8 K$ R% B- K
    """
( k+ q2 d  P: `( v  U: f1 \    #指定用GPU还是CPU2 c1 `  Z7 _9 i' d( p# \
    model.to(device)
0 W+ h6 W  x- w* y+ @8 M; H3 K9 h    #下面是为展示做的
$ q5 X3 X- U  s" M- S1 W    val_acc_history = []
" J8 i: K- q1 ?9 x7 O    train_acc_history = []
( P7 J0 g1 N# I% R8 l8 l: Q. z    train_losses = []
2 h# r% V( w9 O2 n    valid_losses = []
& A* D, S. _, @' l( v9 C    LRs = [optimizer.param_groups[0]['lr']]
: d" V( w/ d8 `3 u( B    #最好的一次存下来
1 Z* d9 Q$ f7 {; X: H& u6 k  {2 R) D    best_model_wts = copy.deepcopy(model.state_dict())
# _; Q+ ^1 }3 Q5 w1 ~" d. j0 {$ S7 m7 b7 ~
    for epoch in range(num_epochs):
$ P% o' x$ S: O        print('Epoch {}/{}'.format(epoch, num_epochs - 1))1 \  `" F- e7 z! A8 |" L# E
        print('-' * 10)
6 |, E0 I% A4 q2 H$ o
/ v- J) d4 J: H. a! q        # 训练和验证
( ]& S. ~6 f1 G- s& \        for phase in ['train', 'valid']:
" _& E5 a( r% d  s" |: L  W  X3 y            if phase == 'train':& Q8 O( ^$ a) V0 C# }( f
                model.train()  # 训练' f: e; @8 f$ l$ g
            else:
9 k6 B6 u% i6 u; I7 z' Q                model.eval()   # 验证; b( y; _2 M+ t& K7 d& b% \
8 |* g. k6 t5 N- X# [. ]
            running_loss = 0.0
- p7 i' \# H- E" E            running_corrects = 0" x5 I, Y$ t9 n: k$ F- ~

# J" X3 |6 t* q- E            # 把数据都取个遍  c( Y" g: K) S* r4 t" u: ]/ {
            for inputs, labels in dataloaders[phase]:% {. V( x# U6 z8 x* C8 ?8 \1 n
                #下面是将inputs,labels传到GPU  V+ @! x4 p  L) v! p9 c3 w; k, ^
                inputs = inputs.to(device)+ d% f2 `4 `; S  H( t) k8 d
                labels = labels.to(device)5 P! g* {0 N3 @  R

, P7 V; ~2 u9 r/ J/ I                # 清零
. `8 Z  }8 U( D! y+ I! C                optimizer.zero_grad()+ b4 A; B* y) a+ d
                # 只有训练的时候计算和更新梯度+ X# o$ G  ?1 L( Q
                with torch.set_grad_enabled(phase == 'train'):2 ~7 E8 v# y5 J2 ^3 ?, `7 X
                    #if这面不需要计算,可忽略
- w) x/ S* H9 C7 d6 b; Z3 W                    if is_inception and phase == 'train':; h: g! v* @6 _
                        outputs, aux_outputs = model(inputs)
5 o3 ?: L9 d, K4 ^3 O0 p                        loss1 = criterion(outputs, labels)
7 b4 u$ ~6 z8 u4 q* f4 J( Z                        loss2 = criterion(aux_outputs, labels)
3 y) V' p+ d2 {7 s* |& b. g; W                        loss = loss1 + 0.4*loss2! p2 ^: H' F6 |/ I+ ^" ]: D1 A
                    else:#resnet执行的是这里* r1 |. a- S" n4 E2 Q
                        outputs = model(inputs)$ s0 f/ t, i; x* `5 A
                        loss = criterion(outputs, labels)* M% m( K& e+ |; Q7 C; \

' x4 t9 O6 W$ q. h* q4 k( h# m7 G                        #概率最大的返回preds
: }2 ]7 w8 ~6 T6 Z3 Y% ?                    _, preds = torch.max(outputs, 1)
  r0 B. y( F7 H* i: s* p8 l; P- H0 i" k/ Z
                    # 训练阶段更新权重
  I# E4 O% E6 r7 w                    if phase == 'train':8 H% P! W$ p8 Q; D4 \
                        loss.backward()
- s& `  v7 u2 {& i                        optimizer.step()4 C9 T$ c9 O, K% P  q9 c% ~% n. `
5 d  w% h+ d  C) R) v0 W/ u- j
                # 计算损失
4 X/ e( P* A& n0 k: x4 k2 r                running_loss += loss.item() * inputs.size(0)4 l) e7 c( s9 @+ x
                running_corrects += torch.sum(preds == labels.data)8 o5 R' z* K4 r. z+ G" I

' a. k3 ], F1 H) e& g' i8 ^            #打印操作
/ ?# B% Z1 m/ m, L6 L( R$ _9 |            epoch_loss = running_loss / len(dataloaders[phase].dataset); Y' E& k, ]2 f0 Z6 Y( i5 q( I
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)( c8 c' [8 L1 x5 [3 A
+ _7 Q% q  ]1 o- J
, v! a  R9 R: U
            time_elapsed = time.time() - since
4 |  A" Q5 V1 C5 _! m! T- m  N            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))- m# n2 ^& a9 p, _4 u/ J. _
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))  b6 M# I# }. V3 }, K

2 F) Z6 i/ G7 B- H% }8 H6 t  c, t2 ]$ y6 F  y
            # 得到最好那次的模型
# V1 @0 u: I# C/ k            if phase == 'valid' and epoch_acc > best_acc:
# e& }9 D) j) K                best_acc = epoch_acc
1 h( V2 H- s  j1 c, q                #模型保存# a- G* w$ J1 K* k7 ?9 J# W
                best_model_wts = copy.deepcopy(model.state_dict())
! \% W& u& `) Z4 K& P                state = {
9 B( p: c. L1 j4 M$ D                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数
0 f) g3 q, G& X2 f9 A                  'state_dict': model.state_dict(),
6 U& J. J) N9 ]6 G, o5 B                  'best_acc': best_acc,6 T+ O9 s/ i/ O" e! s& m
                  'optimizer' : optimizer.state_dict(),
8 @0 W. i5 [1 A" |$ @6 j                }( f* X- L# Q9 v, T
                torch.save(state, filename)
  Z0 i' n2 l7 b- m- o8 @8 E            if phase == 'valid':0 z8 z. G' e. a5 P# W/ _1 @
                val_acc_history.append(epoch_acc)
  `0 e; ]0 X" d                valid_losses.append(epoch_loss)
5 s3 z3 @( y/ Q% Y' d$ v4 q                scheduler.step(epoch_loss)
, u% |' w* s; }- D% f            if phase == 'train':' y! K& Q/ [1 G' T
                train_acc_history.append(epoch_acc)& ~# M! V0 ~! ~# ~9 A! _
                train_losses.append(epoch_loss)5 l1 G' o7 A$ w& ?

4 L7 b, ]9 W% g. ?  Q        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))! ]9 Z( B8 V. U
        LRs.append(optimizer.param_groups[0]['lr'])
3 _5 I8 G: e9 D- q) J& F        print(): m( R  e" h4 }0 i  n  e  [

/ p: s( m7 J; K" g0 M    time_elapsed = time.time() - since: U2 K/ h/ u) F' p4 z, _
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
2 Q9 E7 M" T) F) p% Q5 _7 f8 c) S    print('Best val Acc: {:4f}'.format(best_acc))
/ b7 u4 x" \$ i% n9 m  Y
2 `6 |3 T4 \& e4 G2 w8 }    # 保存训练完后用最好的一次当做模型最终的结果2 k( _# c. d; Z
    model.load_state_dict(best_model_wts)2 H- ~# ]& S1 f
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
5 }& g! M& T, k. `  C+ b# Y6 T* v- o  p8 G3 f# E, v& b

6 m4 H( G( `! ?8 G1% }' ~! W9 v% f# ]- I# _. E' r
2
: ?3 b; u& {8 X* r% b3 X3! T+ U' k: K! j, v7 K5 s
4
+ o: a9 `! }1 _% g: e5
! {, ?( G! A2 t  e! G6
5 U' g" @( ?1 _0 V* H7 g. p/ {1 p7& S0 M/ \5 o7 y, p5 ~
8
* _- u' K% G9 Z) ^* u2 p, o& r9
8 H( g/ t6 q; @- K0 z- V10* w2 L. n' j, x8 Q' L7 Z" S$ e
11
8 g5 ~7 ?9 V! u* L7 W$ z7 [12
# q5 u& K' {( D  }* r13
$ _7 ?2 g# w1 U/ L14
. d! n2 y8 _0 v& A  }7 m) O, D9 B7 M! m15/ j* f& V5 h1 n3 @/ g3 m
16
3 ^9 h5 x! f2 U% y/ D7 {  r1 }5 V17; A) w' {% H$ p7 D1 u* Q% x8 n
18
% n" b1 e+ d5 v1 e/ u19: b* m% [/ X$ B1 E8 h
20
4 X& e1 D- P9 k# k216 D/ q7 e5 N$ |. z& Q. N* g3 h" f7 W
221 s2 ]. }# V! ]: l; @  F
23% R- x1 K5 @4 _: b8 Z
244 M& K* g- R! x2 g! a7 @) ?
252 t* k4 y/ }2 G& ~
26
, P. C+ V, ?. A0 U5 n4 P6 T27
  p8 A9 s$ C" M28
) L! Z3 [* W# X; [$ P292 l* `5 c  z3 j
30
6 @+ Z& p5 E- f0 z& \/ y5 @1 L$ B$ R31, s1 p2 M- n) I3 @$ ?
32
6 p' A, X1 B0 p0 Y9 A! q33$ r+ C8 N* P$ m4 F
34: |7 E( H9 `" e: |1 S
35
& Z/ C+ L& s6 {5 X3 e+ R3 L! {4 q36
6 R% \+ J! J/ O+ ^3 y( \4 l37- A& f4 _' |: N
38
& C$ C$ T) J* c- |396 i  Z( e5 v) V; g) r1 p  S! \
40
- s' u3 o  F& K: y: k! @. |9 J( D9 r41  ]7 X2 y& B( c7 o: B( A# f
42, R! t. C4 {; _$ E8 i) Z& E
43' u" y4 N5 j( c& W9 f
44% |1 }# w8 f! D6 u$ d5 l8 S) y) h& C
45
1 ]/ s; y" i, X. u5 |46
# r6 l% d6 g: ^& ?# r477 `: ~3 V# G5 w, T. r
48
+ x8 r6 I4 t4 Z2 a+ D( O: k49
9 q6 b# a  }) w+ I& Y; M. k50
6 j" @9 ^- O! x# N$ R2 D! F# ~# X51, e, u7 U1 D4 H- H
52
$ P2 s# `/ L& Q+ A8 M; ^/ |4 d53
: @' ^. n& S& U3 C& [541 [1 O: q& ~  T+ u9 r7 S
55# W% w+ O1 o6 a# `2 Y0 V, K. s
56
( X; B! ^2 u" N/ X+ X' N- G579 K. D, y! z% J, L/ _
58
, a! b0 `4 s7 n4 M( [593 r; F+ y3 ~# }( ?$ p
60: z1 k% ~2 ?; k/ G6 b
61
3 c8 c" g8 [( g  k& f' u1 F62
$ |/ F; }  _, o7 }( j) h# B63. z( [: C  u  P" x
644 B4 c' }" I- k5 Z+ G4 J
65
6 X( C( u; P/ z66' T6 {& L+ ~# Z' d$ D( [
67$ J: e; M1 Z; b' R9 z/ S
68
& ]) g. \7 X5 Q9 v2 L69
$ Q$ @* w  b+ S% z' |- A701 h0 l, N: s7 a2 x' |3 z8 B
716 Y  j7 K  o/ l( y
72
( h& J2 h) y$ q( U; _73
, ?' l# D' O' @2 q5 s! o1 _9 m74  c( L- C3 o; W
75) \* H2 @# n' f3 o/ }/ w
767 Z4 d5 x) k2 L. Y" L3 R0 F2 ^
77
- G7 T, T# m- e/ R5 ~  F78
% q) ~0 T3 G8 Z2 W79
6 A& |) ~' Y, z1 m/ }2 L- u* c0 r) U80
+ O4 W+ S. }! n8 ?$ W, [3 L. J+ L81/ [- E2 G5 V( m6 D  d* v8 l
823 t, ^! E; H/ Z
83" _, f6 h  C9 o
84
  _0 |3 D- G+ G0 r% R6 F6 N5 q1 M$ {85
) b! _, U/ E' e860 ^) U& C, P, K
87
+ x& c$ X$ p6 q- J' S# a; Q2 V2 I88/ C6 g' H3 J' y- b( p, W
892 R- h' D" i8 b+ H7 }& x
90
. [# ]* Q' |, S$ }. G8 ?91
6 b$ x  ^5 @) r. l92, r7 d$ Y( t2 P6 o
93# V5 f( t0 b3 [' m( B5 K& s
941 F# B% i' m  C1 n' Y. n
95
7 o6 u$ @  j! {9 O/ m960 @# w+ C6 ~9 i" o$ @* u- s. M
97
( f4 o/ I  n; f: J/ U98
7 W( K& F/ d  |  r4 D99
; J/ H; P2 R1 s& e2 q100: a& h$ z( x! y- a0 D) Y
1010 s/ W( J0 H* D/ {
102/ U5 f& W3 v% |! J0 O3 m+ N! Z' W
1034 x8 i: x1 R% _6 g- {0 N  T
104
2 r/ L+ j& L# S/ C2 U105
: j2 r0 G/ b$ K" ^0 U1064 W, B" z5 S# L0 P* K: q
107
$ k/ [- T# u% k4 p: \! H2 j1 u1 P8 k108
: r: ~+ e  E% S- F$ ~# s3 O109
1 U0 m6 E& V) ~9 t" N  G* V- J3 v5 [1105 [0 W# }6 o3 o) E0 V
111
/ d$ C+ ]! F, m1 y8 a1 R, S* a112$ i8 i% D6 W* r! J! S2 E
7.2 开始训练模型
( r) Z8 q# N# v, y: c我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次! y( Z) I! ]2 ^: r' k

( S( c! U8 N4 i. A#若太慢,把epoch调低,迭代50次可能好些: H& |3 k& C) Q/ w
#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合* b' t  F% B, ?) o, C/ z
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))" k% w  x; K8 n; M, T! K$ R
* z( ?4 v+ [. B/ J) {( e& N% C2 k
1
" A1 L. t' P: i" E- L6 R2
' X/ m; z, P$ |4 e) G/ v) Q; n7 E3
2 c# J9 T6 g# a5 k4' `. J6 m& d5 k3 A' S3 d
Epoch 0/4) l# I2 Q9 C% t/ C, x; K/ o
----------
9 D+ ^# P% q; ?2 ?, |. C3 o1 V9 [Time elapsed 29m 41s
0 Z. ^2 }( N- B7 e3 ]# ftrain Loss: 10.4774 Acc: 0.3147, V: n5 Z1 o2 i: r
Time elapsed 32m 54s
; h) E4 l8 R/ l& t4 \5 U: ^valid Loss: 8.2902 Acc: 0.4719* C! D" J( l: M) k  h% U+ q& z
Optimizer learning rate : 0.00100006 @- m' A0 r( c, ~5 C, z* E; c, z
8 ~0 l9 X/ i/ A- c& T
Epoch 1/4  p5 b. e. X& c5 W+ v$ x- D
----------+ ]: [, F6 P' w+ ]# P
Time elapsed 60m 11s
; V- i  c9 W. d4 T1 ftrain Loss: 2.3126 Acc: 0.70535 G4 {$ h" s" `/ j- _2 E' ]( b
Time elapsed 63m 16s+ d; ~8 D: R2 Q" v1 S; b) G7 N
valid Loss: 3.2325 Acc: 0.6626
/ Q, Q7 V2 n7 h; r* `, U& L3 HOptimizer learning rate : 0.01000002 Z$ z* \/ N- s. R
( o; R5 J% `5 A4 p) f: z
Epoch 2/4- r+ z( s  G; B! N+ z
----------
( b4 Y8 X3 u+ M5 v; yTime elapsed 90m 58s
: P; v; n# L! l; E* ytrain Loss: 9.9720 Acc: 0.4734
; |3 r% g  G& n% J1 aTime elapsed 94m 4s
7 U+ Q! D. g. U- A, U& r' gvalid Loss: 14.0426 Acc: 0.4413
+ ^/ c$ U* Q2 t" c; P2 O" f/ HOptimizer learning rate : 0.0001000
- L1 G1 P8 Z. ]; j
/ C$ K( Z5 R" BEpoch 3/4) f2 @/ d- n' q( N+ R2 N
----------
$ R" [) N6 e# N( R% g, LTime elapsed 132m 49s
1 m& a8 O! N% D6 v( y  }- ztrain Loss: 5.4290 Acc: 0.65481 Q0 G1 `- }* D4 W+ R
Time elapsed 138m 49s
, N4 J: O1 {# W6 G& nvalid Loss: 6.4208 Acc: 0.6027+ u- ]! g: y: e& r- |+ E
Optimizer learning rate : 0.0100000$ f: w" t) u, Z3 m0 o, T% Y
# U, `) L" j9 b
Epoch 4/4
+ x$ A6 S6 ?* N, ]( w/ h/ x' i----------& |4 I6 X; Q3 a
Time elapsed 195m 56s
: }: t9 n$ n4 T0 Z/ R8 htrain Loss: 8.8911 Acc: 0.5519; P0 T9 z4 c  Q) F/ K
Time elapsed 199m 16s/ z' [$ \+ i& u
valid Loss: 13.2221 Acc: 0.4914
9 b, T" ]- U2 V( P0 E# B5 s5 GOptimizer learning rate : 0.00100002 u2 J0 _/ k8 _3 q. q

! K- Z3 B5 Q0 k5 F6 M- V' JTraining complete in 199m 16s( ^9 H; r9 Z; m: |" a
Best val Acc: 0.662592$ E+ i  w9 o7 Y
" W' H1 Y" J  @8 N9 z0 |
1
4 X  v" ]2 t) f  K2& g* U- l: `4 D; F! J! p2 ~, x
3( Q! u3 ^1 i- S/ ^
4
6 q9 E% u' c( _1 q, r5
1 @) x% i9 a% N$ \6
) N0 K$ V1 m( ^1 ]$ @; W1 F7
/ _- r% X( p: g; K' K8
$ a; G) e, P% j' O- l2 F8 P% Q, h1 V9
( ^) t  v, O. t+ b/ X' v; E6 T10( O" q: S2 c* a* d) {5 ]# y
115 G8 q: z! {% `* Q
12
1 J# }6 Y3 q( B( f& i; I$ I1 ^136 V9 {" g" X2 {7 {: N
14
, q; O, K4 u$ i* U% B+ `4 E150 z% s$ m" M* P& f, V2 N
16
/ j( h) Y+ V6 S! f- j# q. ]17  }6 v2 d  m% a/ w
18# n$ B8 M. p) Z! X4 B" A1 I
196 J) s: z( `! Q2 d
202 z9 }* [% i- y% i( a$ r
219 l3 g- C3 L& r
22
" V/ J6 z+ B( I% _* C23
2 g7 I7 b! C$ A- K6 k3 \" f- T24/ c9 H/ K$ S) q/ m1 R; Y  Y
25' D/ y* D" {7 v9 m
26: C3 N8 C& K2 S) T+ _. ~1 J9 q5 h2 U
27
" B6 C8 r/ s4 Z: n6 N% k6 X28
' A+ F& T% h1 I, m0 c29
+ [5 U- S! e  y& @30( G6 M: z# `- m6 r/ T$ C
31; X# C0 {# F- M, L, [) j
32
( O0 n1 _9 ~  T4 J331 z6 S# _! P" U
34
6 W% }+ G8 e: f/ z355 A; D8 i6 {7 i% A. o
36
. Q5 x; M8 W3 O* _% N! N37
# `* j6 m, e. U' B3 p- T38
. w4 r9 V3 J8 g39
( T5 a1 w6 ^/ O7 X* u: h40
( \- j* i! A, n# W0 M. W41
2 Z! j2 a) s3 K" D  m1 u; |) e' I42# F4 D& ?- i: W# v$ e2 L
7.3 训练所有层& {, }/ }$ }% D: e
# 将全部网络解锁进行训练
6 z" U. S& @* R6 U9 U4 Zfor param in model_ft.parameters():
% [8 R+ Q, [, P* X8 w7 \    param.requires_grad = True: s# \$ O. J2 V- V
  G" H% M8 U$ _$ M! w" m* d
# 再继续训练所有的参数,学习率调小一点\
4 y! {5 T* b5 U4 h9 C  z+ a' t" M5 D6 Ooptimizer = optim.Adam(params_to_update, lr = 1e-4)* j2 Q% p0 D! T
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)2 G) r. W3 K" l: m8 y$ ?

% ]2 Y* k' Q6 b) i2 X# 损失函数4 O1 l, r( d; K3 K; y& h
criterion = nn.NLLLoss()
5 d8 V0 [: F, z1  u5 Q: `, E5 w$ {1 Y( Q
2
9 n* F( M8 f. X4 c; P" S( `3- Z8 D' {: Z/ t
4( O9 _4 _, D. v; E) X0 k* c5 q
59 C+ ~4 J! h$ K; E2 N: ~; {
6
6 E6 M- \7 ^; t; v  H6 k+ x% n7
8 C% S; G. A8 _8
% c# F; g. ?4 A; D  w/ ?: E9) Y7 a! V7 h/ O1 z. [
10+ W- F2 A# E8 H' T- `/ k  M+ F
# 加载保存的参数( k5 }1 @, }# w- s$ [
# 并在原有的模型基础上继续训练
* |5 Z1 i7 ?4 _3 c$ @5 K$ O# 下面保存的是刚刚训练效果较好的路径' W* J5 e1 V$ D+ q
checkpoint = torch.load(filename)$ O' U" V; w8 c$ ~- W# j
best_acc = checkpoint['best_acc']
' ~) {% {6 r, ?7 C. z9 Nmodel_ft.load_state_dict(checkpoint['state_dict'])$ e6 ^5 e9 q8 |& U* M- e
optimizer.load_state_dict(checkpoint['optimizer'])
; ^+ B* q  K3 J1 L/ d/ @) q9 {# W1
7 m! B9 C, |+ r7 t8 t$ s' }+ ^23 ^& H) t5 Z$ l5 M0 j
3
6 J/ A3 L0 f: y; D. G/ G41 ~( F1 p' g, O
5
0 ]3 u6 i7 p0 R68 z- I9 x$ G) D$ S% N
7# b& [5 O5 O7 D. S+ {  Z
开始训练& }; f- h, c- G) C1 ~. Z
注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考! q4 n; y: N2 v; ?' t) _
  d9 l" W  @1 y: y# A
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
6 M3 z5 o4 H  }* H  J( I/ N# L1
* a5 {: [3 e2 `# ?( u0 GEpoch 0/19 i+ Z6 W" \$ A' e& T+ Q! {
----------  l! q  i1 j: T0 W3 T% W
Time elapsed 35m 22s
; f( |5 `+ v# a( ~( Vtrain Loss: 1.7636 Acc: 0.7346
! X+ W: @- a/ v: w% s8 mTime elapsed 38m 42s
9 A! X/ Z3 J7 i2 D9 avalid Loss: 3.6377 Acc: 0.6455
" a( X. n& ]' T' s9 [Optimizer learning rate : 0.0010000
- r! l: S" s, j) K* T" C0 c) U8 p. u: M' O+ }! W  N$ C
Epoch 1/1
4 p: y2 w) K/ y6 Z6 c+ Q2 r  U, O----------
# q, }1 I. }  A6 ^& J* k  pTime elapsed 82m 59s9 }: s1 u! a9 I
train Loss: 1.7543 Acc: 0.7340$ w0 ?$ m. c* e% X8 x, F0 S
Time elapsed 86m 11s) O# i. ?' |6 t( e! c0 L9 P
valid Loss: 3.8275 Acc: 0.6137
  x7 w( _1 m& F) F1 `Optimizer learning rate : 0.0010000
* T! r# [( b. {" _* d
  w# }5 ~2 e/ A9 e8 A% v& n" ]Training complete in 86m 11s
+ \4 c9 }5 P" t  PBest val Acc: 0.645477' d  @4 p7 ~$ z, `2 i& Z! ^( Q' A
8 L' l( C, J8 I/ @2 r0 S
11 p* V, k1 J* |% R/ P
2
9 v6 ?5 A- z: w! K3+ P+ m2 J& g, q# \) U& {; f2 T  Z
4
! D6 J0 t$ k' j4 s* A5' o/ Y, ]8 C8 P7 {
6* }3 ~/ S9 Q0 n+ o
7
2 K( X  B; Q& B4 c5 U- E8
% O% O7 N  O5 D+ o94 W9 m' k+ d7 i, w
10/ `& c; N8 i4 N" z+ P* ~$ n3 d
11- F6 G5 E! o) v2 J5 ]
12
" J% W& C' X( L# A7 h9 G13
) i6 t( q, O& P) L5 {  |4 E+ m6 [14
3 Z: S) v6 n/ h, C, m& s15
" m  {# u7 L* o) U16
6 t7 h; e: I6 p* \8 M; K( c6 @2 S! G2 J17  I9 D. y( |) K5 O: C: E
18% C, y- V& A  z" d4 z
8. 加载已经训练的模型
3 z! m: a6 M3 `% Z相当于做一次简单的前向传播(逻辑推理),不用更新参数
3 W/ t+ }3 Z1 x( I  o# K( c2 N* c  R/ G+ s5 h
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
" H6 }; L, _: j6 n+ i  W% l" {; L
0 c% c. ^( y& `9 [+ g; |* `- Q# GPU 模式
7 ^' P- G$ p" R1 `- V0 imodel_ft = model_ft.to(device) # 扔到GPU中2 N" V. f* S- [  H+ ]

2 n1 I7 J' {$ g1 T' Q9 Z( C# 保存文件的名字3 R/ o% o3 ?  {/ W" \
filename='checkpoint.pth'
9 i  P: \# _! ?! k
# x2 w: v* |* l" x3 |' n, t* V# 加载模型  v7 ~" v" M  v( ?, l/ t9 Q
checkpoint = torch.load(filename)
' |+ @7 `6 H* B' ~- ?  |: ebest_acc = checkpoint['best_acc']" Y) ]2 P! F( W+ b6 D  w
model_ft.load_state_dict(checkpoint['state_dict'])
. z% w8 |& b0 G7 K# T0 s1
9 e2 @" T6 s# g8 a5 {" Q3 N* A- [27 P8 y) n9 g% W* J+ U3 J( M# y% q
3
$ G4 Y7 C% Y2 E/ w5 L5 S4 Q) _, N8 O44 V. V2 U3 H2 e" S+ d7 X' d' W5 r
5' w2 f, V* |7 X& S, v
6
7 C$ ~- `' y0 D7: k9 Y+ a8 K0 ~- f  ?) t
8
1 w( t( p5 b5 @9
9 C6 m9 ~5 O% O10
) s* p+ G; Q' n0 \+ ^( d1 _11
7 W9 o) p/ w  ]7 N. r125 I4 G0 r% I: ?. ]; H+ d% Q
<All keys matched successfully>
" G1 N% [3 ]" w6 h1
* Z  Q+ W, |7 H/ Tdef process_image(image_path):7 _+ O1 F" H1 T# e
    # 读取测试集数据+ k: s$ y. K0 B7 M+ f" J7 ]
    img = Image.open(image_path)
6 R, L; |  ^0 M+ J% k5 E" K    # Resize, thumbnail方法只能进行比例缩小,所以进行判断
- v0 b! y0 |: L! ]: d! O! y    # 与Resize不同  b0 ?, _0 |4 Y1 U" A. }
    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
+ z. z  J2 a0 v5 p- u    # 而且对象调用方法会直接改变其大小,返回None
* _1 M) `1 a; C+ G' V    if img.size[0] > img.size[1]:
4 @( q) ]& ^+ X( S4 l        img.thumbnail((10000, 256))
8 Y* x& k# w$ k2 w4 e- R# m    else:5 s* z* [) \4 q' G# O5 e
        img.thumbnail((256, 10000))- x/ O- ^! @; N& j% m
2 W# M" j' R" I# r
    # crop操作, 将图像再次裁剪为 224 * 224* o% q' x! M) d" J! C, Y2 l
    left_margin = (img.width - 224) / 2 # 取中间的部分+ i, s/ x# h1 V  ~, ?6 N+ [
    bottom_margin = (img.height - 224) / 2 8 I9 i/ ~' Y* \, c. R
    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
& }$ f, b- S7 Q: d. f    top_margin = bottom_margin + 224
& W9 g/ C5 R3 A4 b0 z: d4 D7 g+ z' H: M% l4 n, N
    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))8 i2 e) C( Z6 T# t/ E1 t, O

! }. S2 _" o* }5 k2 _0 r, j' u    # 相同预处理的方法
9 ~% y2 S' ]+ f& P$ R# G; y( a2 P    # 归一化3 M" k* F. y$ T$ \
    img = np.array(img) / 255. U5 s  T5 E& D+ ?/ k8 y
    mean = np.array([0.485, 0.456, 0.406])
0 y6 T% t4 f* r/ \    std = np.array([0.229, 0.224, 0.225])% o% Y/ q# X2 [5 S* v/ q
    img = (img - mean) / std: Y6 D9 j( |+ W# x

* F; ~( I" }! u5 a( }! }    # 注意颜色通道和位置
& }6 B3 `* [1 h6 {+ P/ |    img = img.transpose((2, 0, 1))8 b; U8 u4 G3 c4 ~" S  a! G6 S  c2 ^( i

, Y; u+ ^. f# C. f    return img
% z: a; h/ t) S3 l3 ]! O( W# R3 F, q& E3 n7 z4 w* l
def imshow(image, ax = None, title = None):
+ |3 e1 r4 P6 Q# k; z    """展示数据"""' r. C9 d* B% b1 Z; z  t& \
    if ax is None:3 m' s2 G/ c4 T. u
        fig, ax = plt.subplots()
& X& K( B9 {) c) a- k6 W3 o! T% T  w5 R& p$ G6 r$ f1 `9 a& l' V7 L4 ?
    # 颜色通道进行还原  Q  h& v5 E8 S# t! n
    image = np.array(image).transpose((1, 2, 0))' c1 I9 o, Z4 h' }- p* J
+ f) x& t- ?% ~% J$ Y
    # 预处理还原
- w0 {6 P. A$ C9 E) ~1 h+ W% {    mean = np.array([0.485, 0.456, 0.406])
; w& Q* W( {4 a    std = np.array([0.229, 0.224, 0.225])" Q" q6 Z+ N1 E1 ~
    image = std * image + mean; C* A" w! s. x. C9 r
    image = np.clip(image, 0, 1): j  t: t; N5 K! o7 P

9 S( v9 o/ r6 b, s    ax.imshow(image)
+ _% o; U0 y, [" ^4 ~    ax.set_title(title)$ p7 h/ r, H2 r* k( H1 r3 c- J

  K4 j6 l, U% ?; T# e( ]# A    return ax9 y/ k% q7 K& o, {" Q
5 L: Z: W. |0 G- v0 i1 h
image_path = r'./flower_data/valid/3/image_06621.jpg'
& F/ U  V7 {! Z1 h* y$ `" ximg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理5 c. ^# Z* B4 Y. @2 \0 J
imshow(img)7 G! W: L. Q, K4 N
0 x$ Y- r, Z# I7 N" S
1
9 n4 y1 l# E, n) L2
! B( p6 \- o! }! S6 P/ d# i9 g34 L3 q" s& d, H( ?! ~
4
7 {% V: N7 f. D+ \4 o( j6 r5
! j5 O6 D8 v/ C% [6
7 `( `/ _; Z0 Y) S# V/ ]7
- q* Y0 B1 p4 \" N6 F8  X( M' m4 R; N0 x) |) x. ]
9
: Y4 e. Q( ~* {; ?10
; e  }: P8 s8 q+ B. k11
5 i. v: G  g$ ?" l: E12
: a% |7 K" Q, k" \, y! x2 p13' J5 C; _) a1 w) [) W; q9 Y; `
148 [* p- `% x( ?. S
152 I) U0 {/ k7 X4 b; e
16
# C: L4 j2 i3 |* D9 w& }17
6 ?2 I8 j* M/ L! X18. J# v( P/ W8 S' Z% f
19! \6 S) H4 R' ^; P7 [. K& z( P1 ?
20/ V( k4 ^3 o2 W( n
21
! p7 t" }7 r6 ?6 B22
- Z; Q- D/ N8 a1 ~' q! C23
* P5 l7 w" @( F# r; U( b2 |245 y+ U# O, {/ i3 `" q4 w. d; M
25* {/ Q+ W8 a9 [0 R6 Z# J* Y; f  Y
26
  T) }' R- |; U, c+ W1 u27; a5 J  z1 L3 j2 d
28
' F! o6 R% f- K& U! f29
' h! Y9 p; e7 a# D2 w0 c; u30
& ?3 t! x* e4 \# g- k% Z" p31, m2 t+ A9 m) {# T! b  V
32
3 D* `7 I8 y; ?33
- \) h- A# o* N" d$ p  x1 B34
' F3 S9 @( p! X& n$ r7 `+ f" V35
5 X! e" ^- _8 s0 F, H1 ]363 ]- \% ^' G- x
37
, v" j$ e8 ^1 K. Y38' @$ }: o1 f4 u! c+ C
39+ I% v$ {8 ~  O5 X, [0 ^. i  h
40
4 n4 ]1 Y- b  x! P6 }6 \. Q41
* n8 T+ a' ^- }. H4 @* X2 i42" ]+ o4 X7 i3 K8 x0 o& c" t4 |
43
' l0 m/ E6 [; m3 ~44
; [) w0 l: r) F! H45: v( f: K2 P+ I4 m5 o; i* x
46
- Z# m* U( O; B4 l4 r- {47
! [8 ], b1 G4 K2 M; t484 \: _6 P; C4 M
49
, g0 A# {7 O( c+ ]( A3 x50
* I: N+ e/ ~3 d0 Z51& [' S5 M0 q' O- o
52! p, k1 r- X$ ^' W) _6 E/ e
53
; t5 x) K! x  [54
' D  M$ h4 U1 t7 T; ^: n<AxesSubplot:># C( v* o; a0 y; ^/ d) ^( r; J
1: @+ F" F! p! i1 Q" B

9 e8 ~7 u! K% q& u上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确% y" o1 J$ B% K% ]8 ]

- t1 `" L5 ~$ c$ l& Gimg.shape
0 y+ U# i" M: f5 g* U# t' U, P12 C) \+ _3 Q) n6 W( a9 w7 A7 A
(3, 224, 224)" _) {% Y, ]! T8 @3 w
1+ p/ T5 [( p' h- M- f, M
证明了通道提前了,而且大小没改变
4 B; }* u  M) p- @7 Z, a
* o/ b$ E: e; z* T, V& g7 y9. 推理6 ?* j/ g7 m* l* A+ J& q
img.shape% K# w# b+ A$ f( p5 V/ R

6 E: D; Y" P/ |; y0 e# 得到一个batch的测试数据. w. a2 {) p/ }. Z
dataiter = iter(dataloaders['valid'])
0 ~1 o/ l! x0 Vimages, labels = dataiter.next()6 s: u/ }  l# m& A

9 F  w; }* _( ~7 Mmodel_ft.eval()9 O5 }5 S3 E; p" x) V
" ?9 @- r" z: B& N! [
if train_on_gpu:
3 I) Q: y9 [4 D% ~" w+ z$ M    # 前向传播跑一次会得到output( d# m9 Z) `/ F9 `6 @$ s
    output = model_ft(images.cuda()), I6 R  F5 d7 g+ Q2 d
else:, ~% c$ c0 X, s! L& Y
    output = model_ft(images)# E5 [  Z7 y8 |: p

' @" x& f% B; g* p) N- J# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
+ h0 `2 D" {$ E" C4 m$ Coutput.shape* U+ a; u( T+ `& E' ?- }" s
4 M, r- B! I! M4 }$ d2 T, x+ T$ e6 M" m
1; }0 ^! b1 ]% Q$ t9 ^+ ~. V4 m
2
5 G/ Z# o  f& H1 W' {0 |. L) `9 u37 m( O+ F  Q' |* o1 v! Z
4
7 T! a" i5 \# h" ^' i& z4 T5
5 E- K) h4 D) [, R6/ p5 s- Z. N9 F9 [2 J
7
3 A! A, B" c& Q" `* j8
3 t$ S# E- K5 T9 h1 ~3 J" ~9
/ Y- L; n" i, M& U8 p: I10
; D5 ~  t9 v, O% ~- S11$ u+ O, ^, r. D
128 x+ s2 x! P' t2 m- p
13
, m) S" K2 x+ N14, Z( Q8 t: }( Y" g
15
- P4 ]% J* H2 ^& J, @1 z3 U16
" y- v8 e' `& Ytorch.Size([8, 102])
. V& J+ Q' I( H. _  _1' ~: T% P' \4 ^* o7 S
9.1 计算得到最大概率  S* E* S5 g2 w+ t3 M8 X; L
_, preds_tensor = torch.max(output, 1)# n+ H4 `6 O; c  d
8 D: v& a' y% h- M5 x
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
/ Q! g! N* }/ T9 F# G( a7 r! E; z3 s1
  f+ D/ J2 U( l2 o0 ?; u2
$ j4 q: U: Q; D; a; J2 n9 S4 v3
, J$ v2 g, H9 D& _5 r9.2 展示预测结果$ j# K. G3 ]7 g8 r- G$ o
fig = plt.figure(figsize = (20, 20))
# a" D, p/ y$ F5 H( g7 z$ mcolumns = 4
% \- Q. v' q3 u7 p; R* \/ s8 F. Q% xrows = 2
3 J1 @- a8 r& J' j( n% ~
( \  e1 D9 M: Zfor idx in range(columns * rows):8 g1 y6 l+ N3 D( z2 J1 h
    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])- i/ u& B7 F7 ]) {! i6 s
    plt.imshow(im_convert(images[idx]))
( P, Y; D+ w5 h1 m    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
. u% d9 {$ A! W/ K& K                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
2 S4 |5 ~" X6 w: gplt.show()+ H; W( w' ~' \6 Z% C% J$ x
# 绿色的表示预测是对的,红色表示预测错了- Z; F& Q  g+ h' o& {0 }
1# ]1 j+ n( o! R  E0 Y9 k- Y' |
2( e3 Q0 L" n; H7 D/ \% l7 C
3, r6 h8 g# _9 b( c( _5 U
4
9 K: q4 V  Q8 [5
7 U: ?: N! C3 `& f66 `8 V$ l, n4 G0 Q- U8 r, U( I
7+ `1 z, P: K! _. y1 [$ \# y8 S
8# S( ~8 s9 {, A  I% b$ Y  g
9- r5 g( W& z" p+ P  P
10
  r' W% J& i' I& N4 q! ^- E11
5 b0 f2 V9 G- x% s9 \/ U6 K( E2 R: W1 _8 s3 N: X3 g; o
, ?$ Q5 t2 {1 e& L

$ G" V  X2 B' J9 J& c7 }7 @( v+ T* g————————————————
/ _/ o" a' q6 q8 O' U. S0 G: p版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。- H: z* t8 E8 ], b# j  n* X
原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
9 X/ z) ^1 i9 ?% Y) v' R+ a3 g
7 T7 X+ r( p7 u, l8 L- r9 q8 l) f) v8 a& ?1 @6 o% K6 s





欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5