[Deep Learning] Image Recognition in Practice: 102-Class Flower Classification (Flower 102)
Table of contents

Hands-on CNN: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import packages
2. Data preprocessing and operations
3. Build the data source
Reading the real names behind the labels
4. Display some data
5. Load a model from torchvision.models, initialized with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer setup
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the highest probability
9.2 Display the predictions
Closing remarks
Hands-on CNN: classifying flowers

This post works through a classification task on the Oxford flower dataset (Flower 102), building a reasonably general pipeline (implemented mainly with ResNet) that exercises the usual PyTorch workflow: preprocessing, training, model saving, and model loading.

The data folder holds 102 flower species, and the task is to classify them.

Folder structure:

flower_data
    train
        1 (class id)
        2
            xxxx.png / xxx.jpg
    valid
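Before building datasets, it is worth verifying that this layout is actually in place. A minimal sketch (the helper name and the expectation of 102 folders per split are mine, not the original post's):

```python
import os

def count_class_folders(data_dir, splits=('train', 'valid')):
    """Count the class subfolders under each split; expect 102 per split."""
    counts = {}
    for split in splits:
        split_dir = os.path.join(data_dir, split)
        counts[split] = sum(
            os.path.isdir(os.path.join(split_dir, d))
            for d in os.listdir(split_dir))
    return counts

# count_class_folders('./flower_data')  # should report 102 folders for both splits
```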
The pipeline splits into the following major modules:

Data preprocessing
- data augmentation
- data preprocessing

Network module setup
- load a pretrained model, calling one of torchvision's classic architectures directly
- the original training task was probably 1000-way classification (and in general not the same as ours), so the model must be adapted to our own task

Saving and testing the network model
- saving can be selective (e.g. keep only the best checkpoint)

Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the folder and place it under the project root directory.

Below is my data root directory (screenshot omitted).
1. Import packages

import os
import matplotlib.pyplot as plt
%matplotlib inline  # render figures inline so plt.show() is not needed
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations

# path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

Note: './' is a path relative to the current working directory; the original post links to an article explaining the difference between './' and '/' in Python paths.
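As a small aside on that note: the string concatenation above works only because of the slashes already present; os.path.join is the more robust idiom (a sketch, not from the original post):

```python
import os

data_dir = './flower_data/'

# Plain concatenation keeps every slash you type (a doubled one still resolves):
train_dir = data_dir + '/train'               # './flower_data//train'

# os.path.join inserts separators only where needed:
valid_dir = os.path.join(data_dir, 'valid')   # './flower_data/valid'
```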
3. Build the data source

data_transforms specifies every image-preprocessing operation.
ImageFolder assumes images are grouped into one folder per class.

data_transforms = {
    # two pipelines: one for training...
    'train': transforms.Compose([transforms.RandomRotation(45),  # random rotation between -45 and +45 degrees
                                 transforms.CenterCrop(224),     # crop from the center
                                 transforms.RandomHorizontalFlip(p=0.5),  # flip horizontally with probability 0.5
                                 transforms.RandomVerticalFlip(p=0.5),    # flip vertically with probability 0.5
                                 # arguments: brightness, contrast, saturation, hue
                                 transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
                                 transforms.RandomGrayscale(p=0.025),  # occasionally convert to grayscale
                                 # the grayscale result still has three channels, just with identical RGB values
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, std
                                ]),
    # ...and one for validation: resize to 256, center-crop to 224, convert to a tensor, then normalize
    'valid': transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean/std as training
                                ]),
}
batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
# verify the data has been wired up correctly
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

{'train': 6552, 'valid': 818}
Reading the real names behind the labels

Use the JSON file in the data directory to map class ids back to flower names.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
4. Display some data

def im_convert(tensor):
    """Prepare a tensor for display."""
    image = tensor.to("cpu").clone().detach()
    # squeeze drops any singleton batch dimension so the array is plottable
    image = image.numpy().squeeze()
    # transpose rearranges (c, h, w) back into (h, w, c)
    image = image.transpose(1, 2, 0)
    # undo the normalization (de-standardize)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clamp values below 0 to 0 and above 1 to 1
    image = image.clip(0, 1)
    return image
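The de-normalization step can be checked with a round trip: normalize a fake image the way transforms.Normalize does, then apply im_convert's inverse arithmetic. A self-contained sketch:

```python
import numpy as np
import torch

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

pixels = np.random.rand(4, 4, 3)                          # fake (h, w, c) image in [0, 1]
normalized = (pixels - mean) / std                        # what transforms.Normalize does
tensor = torch.from_numpy(normalized.transpose(2, 0, 1))  # stored as (c, h, w)

restored = tensor.numpy().transpose(1, 2, 0) * std + mean # im_convert's inverse steps
print(np.allclose(restored, pixels))  # True
```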
# plot a grid using the helper defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# grab an arbitrary batch from the validation loader via an iterator
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # .next() is gone in recent versions; use the built-in next()

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # use the JSON mapping to print each flower's name above its image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model from torchvision.models, initialized with pretrained weights

model_name = 'resnet'  # several options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# we do the main image-recognition work with resnet

# whether to reuse the pretrained features
feature_extract = True
# check whether a GPU is available for training
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# freeze selected layers by turning off their gradient updates
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
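A quick demonstration of the helper on a toy model (the toy names are illustrative): freezing turns off every existing parameter, while a head replaced afterwards is trainable again — which is exactly how the FC swap below ends up as the only trained part.

```python
import torch
from torch import nn

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

toy = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
set_parameter_requires_grad(toy, True)
assert all(not p.requires_grad for p in toy.parameters())

toy[1] = nn.Linear(8, 2)  # a freshly created head trains as usual
trainable = [n for n, p in toy.named_parameters() if p.requires_grad]
print(trainable)  # ['1.weight', '1.bias']
```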
# print the architecture to see how the model is built step by step
# (it mainly serves as our feature extractor)

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    ... (a long stretch of intermediate output is omitted; the first and last blocks are what matter) ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The final layer is a 1000-way classifier: 2048 input features mapped to 1000 classes.
For our task we need to adjust this head, changing the 1000-class output to 102.
6. Initialize the model architecture

The steps are:

- Take the pretrained model, passing pretrained=True to get its learned weights.
- Optionally freeze chosen layers (set their gradient updates to False).
- Whether the task is classification or regression, replace the final FC layer with one sized for our outputs.

Official documentation:
https://pytorch.org/vision/stable/models.html
# load someone else's pretrained model and adapt it
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # pick the right model; each architecture is initialized differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. optionally freeze the feature-extraction layers, training only the FC head
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. read off the number of input features to the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. replace the fully connected layer with a 102-way output
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
                                    nn.LogSoftmax(dim=1))  # default is dim=0 (over columns); dim=1 normalizes each row so it sums to 1
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # replace the last classifier module, index [6]
        num_frts = model_ft.classifier[6].in_features  # FC input size
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # the auxiliary head also needs resizing
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
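One detail worth noting about the resnet branch above: a LogSoftmax head produces log-probabilities, which pair with nn.NLLLoss rather than nn.CrossEntropyLoss (which applies log-softmax itself). A minimal sketch with made-up scores:

```python
import torch
from torch import nn

logits = torch.randn(4, 102)              # raw scores for a batch of 4
log_probs = nn.LogSoftmax(dim=1)(logits)  # log-probabilities per row

targets = torch.tensor([0, 5, 101, 33])   # ground-truth class indices
loss = nn.NLLLoss()(log_probs, targets)

# Equivalent result without the explicit LogSoftmax head:
loss_ce = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(loss, loss_ce))  # True
```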
7. Set the parameters to train

# set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# move to GPU (or CPU)
model_ft = model_ft.to(device)

# checkpoint file: the saved, trained model can be reloaded directly later
filename = 'checkpoint.pth'

# decide whether to train all layers
params_to_update = model_ft.parameters()
# print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction
7.1 Optimizer setup
# optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# learning-rate decay strategy:
# the learning rate is multiplied by 1/10 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# the last layer is LogSoftmax(), so nn.CrossEntropyLoss() cannot be used here
criterion = nn.NLLLoss()
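The reason nn.CrossEntropyLoss() is ruled out is that it already combines LogSoftmax and NLLLoss, so applying it after an explicit LogSoftmax layer would apply the log-softmax twice. A minimal NumPy sketch (not from the original post) showing that NLLLoss on log-probabilities equals cross-entropy on the raw logits for one sample:

```python
import numpy as np

def log_softmax(logits):
    # numerically stable log-softmax
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())

logits = np.array([2.0, 1.0, 0.1])
target = 0

# NLLLoss applied to log-probabilities, as when the model ends in LogSoftmax
nll = -log_softmax(logits)[target]

# cross-entropy computed directly from the logits
ce = -np.log(np.exp(logits[target]) / np.exp(logits).sum())

print(abs(nll - ce) < 1e-9)  # → True: the two losses agree
```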
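With step_size=7 and gamma=0.1, StepLR multiplies the learning rate by 0.1 every 7 epochs. A plain-Python sketch of that schedule:

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    # learning rate after `epoch` epochs under a StepLR-style schedule
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 6, 7, 14):
    print(epoch, step_lr(1e-2, epoch))
# epochs 0-6 keep 0.01, epoch 7 drops to 0.001, epoch 14 to 0.0001
```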
# define the training function
# is_inception: whether the model is Inception (which has an auxiliary output)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # best validation accuracy so far
    best_acc = 0
    """
    # optionally resume from a checkpoint:
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # run on GPU or CPU
    model.to(device)

    # bookkeeping for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    # keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # iterate over all the data
            for inputs, labels in dataloaders[phase]:
                # move inputs and labels to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the gradients
                optimizer.zero_grad()
                # compute and update gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # the Inception branch below can be ignored for ResNet
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the class with the highest probability
                    _, preds = torch.max(outputs, 1)

                    # update the weights in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # accumulate the loss
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # print epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # keep the best model seen so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learnable weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step()  # StepLR is stepped once per epoch; it takes no loss argument
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # after training, keep the best epoch's weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
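Note how train_model multiplies each batch loss by the batch size before summing, so epoch_loss comes out as a per-sample average even when the last batch is smaller. A sketch of that bookkeeping with made-up (loss, batch_size) pairs:

```python
# hypothetical per-batch mean losses and batch sizes; the last batch is smaller
batches = [(0.9, 8), (0.7, 8), (0.5, 4)]

# weight each batch loss by its size, as running_loss += loss.item() * inputs.size(0)
running_loss = sum(loss * n for loss, n in batches)
epoch_loss = running_loss / sum(n for _, n in batches)

print(round(epoch_loss, 4))  # → 0.74
```

Averaging the batch means directly would give a slightly different (biased) number whenever batch sizes differ.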
7.2 Start training the model
I only trained for 5 epochs here (training really does take a long time); feel free to increase the epoch count when you experiment yourself.

# if training is too slow, lower num_epochs; around 50 epochs would likely give better results
# while training, check that the loss falls and the accuracy rises, and watch the gap between
# training and validation: a large gap indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
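The overfitting check mentioned in the comment can be as simple as comparing the train/valid accuracy gap. A hypothetical helper, not from the original post (the 0.15 threshold is my own arbitrary choice):

```python
def looks_overfit(train_acc, valid_acc, gap=0.15):
    # a large train/valid accuracy gap suggests overfitting
    return (train_acc - valid_acc) > gap

print(looks_overfit(0.95, 0.60))  # → True
print(looks_overfit(0.71, 0.66))  # → False
```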
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
7.3 Train all layers
# unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# continue training all of the parameters, with a smaller learning rate
# (pass model_ft.parameters() here: params_to_update may contain only the fc layer)
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# the scheduler must wrap the new optimizer, not optimizer_ft
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# loss function
criterion = nn.NLLLoss()
# load the saved parameters
# and continue training on top of the existing model
# filename points at the checkpoint of the best run just saved
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
Start training
Note: training becomes much slower from here on. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Load the trained model
This amounts to a simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move it onto the GPU

# name of the saved checkpoint file
filename = 'checkpoint.pth'

# load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # read a test-set image
    img = Image.open(image_path)
    # Resize: thumbnail() can only shrink proportionally, so check the orientation first.
    # Unlike resize(), whose size argument sets the new dimensions directly,
    # thumbnail() scales down in place, keeping the aspect ratio, and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # crop the image down to 224 x 224, taking the center part
    left_margin = (img.width - 224) / 2
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224  # add the crop size 224 to get the far edge
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # apply the same preprocessing as during training:
    # normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # note the channel order: H x W x C -> C x H x W
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # restore the channel order
    image = np.array(image).transpose((1, 2, 0))

    # undo the preprocessing
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # this function can be reused to process any image
imshow(img)
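The crop-box arithmetic in process_image just centers a 224 x 224 window inside the resized image. The same computation as a standalone sketch:

```python
def center_crop_box(width, height, size=224):
    # (left, top, right, bottom) of a centered size x size crop,
    # the same arithmetic as the margins in process_image
    left = (width - size) / 2
    top = (height - size) / 2
    return (left, top, left + size, top + size)

# e.g. a landscape image that thumbnail() shrank to height 256
print(center_crop_box(341, 256))  # → (58.5, 16.0, 282.5, 240.0)
```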
<AxesSubplot:>
) q0 \% P! }% H1 l上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确
6 w; k7 O! l7 |, ]. R' r8 b2 D* b* R" R; ?6 f
img.shape/ m0 a: a6 T4 N
16 t1 Z: j( k7 Y1 t- u1 D
(3, 224, 224). C2 a: I( U. o5 c$ C8 l+ @$ z
1* c7 z. u H9 r5 k2 b- W8 G
证明了通道提前了,而且大小没改变
" Z6 L& P2 b2 Y3 A0 G! q
9. Inference
# get one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # next(dataiter) works across PyTorch versions

model_ft.eval()

if train_on_gpu:
    # a single forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# the batch holds 8 samples; each gets 102 scores, one per class
output.shape
torch.Size([8, 102])
9.1 Compute the most likely class
_, preds_tensor = torch.max(output, 1)

# move to CPU if needed and squeeze the rank-1 array into 1-D predictions
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
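torch.max(output, 1) returns (values, indices) along the class dimension, and the indices are the predicted classes. The same idea in NumPy with a toy 2 x 3 output:

```python
import numpy as np

output = np.array([[0.1, 2.0, 0.3],
                   [1.5, 0.2, 0.1]])  # 2 samples, 3 classes

preds = output.argmax(axis=1)  # index of the highest score per row
print(preds)  # → [1 0]
```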
9.2 Display the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# a green title means the prediction was correct, red means it was wrong
————————————————
Copyright notice: this is an original article by the CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA; when reposting, please include the original source link and this notice.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940