[Deep Learning] Image Recognition in Practice: Classifying 102 Flower Species (Flower 102), a Worked Example

Contents
Convolutional networks in practice: classifying flowers
    Data preprocessing
    Network module setup
    Saving and testing the model
    Data download
1. Import packages
2. Data preprocessing and setup
3. Build the data source
    Read the real names for the labels
4. Display some data
5. Load a model from models and initialize it with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
    7.1 Optimizer setup
    7.2 Start training the model
    7.3 Train all layers
    Start training
8. Load a trained model
9. Inference
    9.1 Find the class with the highest probability
    9.2 Display the predictions
Closing remarks

Convolutional networks in practice: classifying flowers
This post runs a classification task on the Oxford flower dataset. It builds a general-purpose network setup (mainly on ResNet) and walks through common PyTorch operations along the way: preprocessing, training, saving the model, and loading it again.

The data folder holds 102 kinds of flowers, and the task is to classify them.
Folder structure:

flower_data
    train
        1 (category)
        2
            xxx.png / xxx.jpg
    valid
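
A quick check of this layout can catch a misplaced folder before any training starts. The sketch below is only an illustration: `count_images_per_class` is a hypothetical helper that does not appear in the original post, and it builds a tiny throwaway directory tree so that it runs without the real dataset.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(split_dir):
    """Map each class folder name under split_dir to its number of image files."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


# Build a miniature flower_data/train tree to demonstrate the helper.
demo_root = Path(tempfile.mkdtemp()) / "flower_data"
for class_name, n_images in [("1", 2), ("2", 3)]:
    class_dir = demo_root / "train" / class_name
    class_dir.mkdir(parents=True)
    for i in range(n_images):
        (class_dir / f"image_{i}.jpg").touch()

demo_counts = count_images_per_class(demo_root / "train")
print(demo_counts)
```

Pointed at the real `./flower_data/train`, the same helper should report one entry per category folder.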

The work splits into the following main modules.

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model by calling one of torchvision's classic architectures directly
    The original training task was probably a 1000-class problem (and not with our classes), so the output layer has to be changed to fit our own task
Saving and testing the model
    Model saving can be selective
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder (the code below expects flower_data) and put it in the same root directory as the code.

My data root directory looks like this:
[Figure: the data root directory]

1. Import packages
import os
import matplotlib.pyplot as plt
# Jupyter magic: render plots inline, so an explicit plt.show() is not required
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2. Data preprocessing and setup
# Path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
Related post: "Combinations of dots and slashes in Python paths, and how they differ"
Note: that post explains the dot and slash operations used here.

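
The path code concatenates `'./flower_data/'` with `'/train'`, which leaves a doubled slash in the result. That is usually harmless, but it helps to see what the standard library makes of it. A small sketch, using `posixpath` (the POSIX flavor of `os.path`) so the printed results are the same on every operating system:

```python
import posixpath

data_dir = './flower_data/'

# Plain concatenation keeps both slashes.
concatenated = data_dir + '/train'

# normpath collapses the doubled slash and drops the leading './'.
normalized = posixpath.normpath(concatenated)

# join inserts exactly one separator between the parts.
joined = posixpath.join(data_dir, 'train')

print(concatenated)  # ./flower_data//train
print(normalized)    # flower_data/train
print(joined)        # ./flower_data/train
```

All three strings point at the same directory; `join` simply avoids having to think about where the slashes go.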
3. Build the data source
data_transforms specifies every image preprocessing operation.
ImageFolder assumes the files are already sorted into folders, with each folder holding the images of one class.
data_transforms = {
    # Two parts; the first is for training
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop 224 x 224 from the center
        # flip with a given probability (50/50 here)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),  # convert to grayscale with probability 0.025
        # the grayscale result still has three channels; R, G and B are simply equal
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, std
    ]),
    # Resize so the shorter side is 256, take the central 224 x 224, convert to a tensor, then normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}

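
In the validation branch, `Resize(256)` with a single integer scales the shorter side to 256 while keeping the aspect ratio, and `CenterCrop(224)` then cuts out the central square. The arithmetic can be sketched in plain Python. The two helpers below are illustrative stand-ins rather than torchvision code, and torchvision's own rounding may differ by a pixel.

```python
def resize_shorter_side(width, height, size):
    """Scale (width, height) so the shorter side equals size, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """Return (left, top, right, bottom) of a crop x crop box centered in the image."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


resized = resize_shorter_side(512, 768, 256)
box = center_crop_box(*resized, 224)
print(resized)  # (256, 384)
print(box)      # (16, 80, 240, 304)
```

Note that the training branch applies `CenterCrop(224)` without resizing first, so on large photos it keeps only the central 224 x 224 pixels of the rotated image.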
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# Inspect the datasets
image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

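
The printed transform shows `brightness=0.2` as `brightness=[0.8, 1.2]` and `hue=0.1` as `hue=[-0.1, 0.1]`: a single number is expanded into a sampling range around 1 (around 0 for hue). A rough sketch of that expansion follows. `jitter_range` is a hypothetical helper written for illustration, not the torchvision implementation.

```python
def jitter_range(value, center=1.0, clip_at_zero=True):
    """Expand a single jitter strength into a [low, high] sampling range."""
    low, high = center - value, center + value
    if clip_at_zero:
        # Brightness, contrast and saturation factors cannot go below zero.
        low = max(low, 0.0)
    return [low, high]


brightness = jitter_range(0.2)
contrast = jitter_range(0.1)
# Hue is an offset around 0, so it is centered differently and not clipped at zero.
hue = jitter_range(0.1, center=0.0, clip_at_zero=False)

print(brightness, contrast, hue)
```

Each training image then gets a factor drawn at random from these ranges, which is why the same photo looks slightly different in every epoch.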
# Check that the data loaders have been created
dataloaders
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
dataset_sizes
{'train': 6552, 'valid': 818}
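
With `batch_size = 8` and these dataset sizes, the number of batches per epoch follows from a ceiling division. The sketch assumes the DataLoader default of keeping the final, possibly smaller batch (`drop_last=False`).

```python
import math

batch_size = 8
dataset_sizes = {'train': 6552, 'valid': 818}

# One epoch visits every sample once, so the batch count is the size divided by
# the batch size, rounded up.
batches_per_epoch = {
    split: math.ceil(n / batch_size) for split, n in dataset_sizes.items()
}
# A remainder of 0 means the last batch is full.
last_batch_size = {
    split: n % batch_size or batch_size for split, n in dataset_sizes.items()
}

print(batches_per_epoch)  # {'train': 819, 'valid': 103}
print(last_batch_size)    # {'train': 8, 'valid': 2}
```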
Read the real names for the labels
A JSON file in the same directory maps each category folder number back to the flower's name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}

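
One detail is easy to trip over: the class index a DataLoader returns is not the folder number. ImageFolder sorts the folder names as strings, so index 1 is folder '10', not folder '2'. Getting a flower name therefore takes two lookups, index to folder name and folder name to flower name. A minimal sketch of that chain; `index_to_name` is a hypothetical helper, and the small dictionary holds only a few entries copied from the `cat_to_name` output.

```python
# Folder names as ImageFolder would see them: the strings '1' ... '102'.
folder_names = [str(i) for i in range(1, 103)]

# ImageFolder sorts the folder names as strings, not as numbers.
class_names = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(class_names)}

# A few entries copied from the cat_to_name output.
cat_to_name = {
    '1': 'pink primrose',
    '10': 'globe thistle',
    '100': 'blanket flower',
    '2': 'hard-leaved pocket orchid',
}


def index_to_name(class_index):
    """Translate a dataset class index into a flower name via the folder name."""
    return cat_to_name[class_names[class_index]]


print(class_names[:6])    # ['1', '10', '100', '101', '102', '11']
print(class_to_idx['2'])  # 14
print(index_to_name(1))   # globe thistle
```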
4. Display some data
def im_convert(tensor):
    """Convert a normalized image tensor into an array that matplotlib can display."""
    image = tensor.to("cpu").clone().detach()
    # squeeze() drops any dimension of size 1 (for example a batch dimension of 1)
    image = image.numpy().squeeze()
    # the tensor is laid out as (c, h, w); transpose it back to (h, w, c) for plotting
    image = image.transpose(1, 2, 0)
    # undo the normalization: multiply by the std, then add the mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    # clip values below 0 to 0 and values above 1 to 1
    image = image.clip(0, 1)

    return image
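
The un-normalization step in `im_convert` is the exact inverse of `transforms.Normalize`: normalizing computes `(x - mean) / std` per channel, and multiplying by `std` and adding `mean` brings the value back. A plain-Python round trip on one RGB pixel shows this. The two helpers are illustrative and use the same mean and std values as the post.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(pixel):
    """Per-channel (x - mean) / std, the formula Normalize applies."""
    return tuple((v - m) / s for v, m, s in zip(pixel, MEAN, STD))


def denormalize(pixel):
    """Per-channel x * std + mean, then clip into the displayable range [0, 1]."""
    restored = (v * s + m for v, m, s in zip(pixel, MEAN, STD))
    return tuple(min(max(v, 0.0), 1.0) for v in restored)


original = (0.2, 0.5, 0.9)
normalized = normalize(original)
round_trip = denormalize(normalized)

print(normalized)
print(round_trip)
```

The clip only matters for values that drift slightly outside [0, 1] through floating-point error; matplotlib expects float images in that range.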
# Plot using the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2
# Create an iterator and take one batch of data to display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # dataiter.next() fails on recent PyTorch releases

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # use the JSON mapping to put the flower's name in the subplot title
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

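
Fetching a single batch relies on Python's iterator protocol: `iter(...)` creates the iterator and the built-in `next(...)` advances it. Python 3 iterators expose `__next__` rather than a `.next()` method, and a `.next()` alias on DataLoader iterators is not something to count on, since recent PyTorch releases dropped it. A tiny illustration, with a plain list standing in for the DataLoader:

```python
# Stand-in for a DataLoader: each element plays the role of one (inputs, labels) batch.
batches = [("inputs_0", "labels_0"), ("inputs_1", "labels_1")]

dataiter = iter(batches)
inputs, classes = next(dataiter)  # first batch
print(inputs, classes)

# Python 3 iterators have no .next() method; the built-in next() is the portable way.
has_next_method = hasattr(dataiter, "next")
print(has_next_method)  # False
```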
5. Load a model from models and initialize it with pretrained weights
model_name = 'resnet'  # many models to choose from: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# resnet is the main choice for image recognition here
# whether to reuse the features someone else has already trained (freeze the feature extractor)
feature_extract = True
; p: [7 G5 C7 A% w1
! }: @' @7 ^/ m0 X7 p2/ Q: A2 h) r0 l
3, K5 ^ f8 g; I' Q) \' z
4
# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...

# Set requires_grad to False on the existing layers so they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
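
To see what freezing does, it helps to count the trainable parameters before and after the call. The sketch below repeats the function from the post and applies it to a tiny stand-in network. `toy_model` and `count_trainable` are assumptions made for the demo, not part of the original code.

```python
from torch import nn


def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def count_trainable(model):
    """Total number of scalar parameters that still receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Tiny stand-in network (an assumption for the demo, not the ResNet used in the post).
toy_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

trainable_before = count_trainable(toy_model)  # (4*3 + 3) + (3*2 + 2) = 23
set_parameter_requires_grad(toy_model, feature_extracting=True)
trainable_after = count_trainable(toy_model)

print(trainable_before, trainable_after)  # 23 0
```

After the call nothing is trainable, which is why a new output layer has to be attached afterwards: freshly created layers have `requires_grad=True` by default.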
# Print the model architecture to see how it is built up step by step
# Its main job for us is feature extraction

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
... (a lot of output in between is omitted; the first and last blocks are enough to see the structure) ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

The last layer is a 1000-way classifier: it takes 2048 input features and outputs 1000 classes.
For our task this has to be adjusted, changing the 1000-class output to 102 outputs.
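
Swapping the head also shrinks the part that needs training. A fully connected layer has `in_features * out_features` weights plus one bias per output, so the sizes of the old and new heads can be worked out directly. `linear_param_count` is a small helper introduced here for the arithmetic.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of parameters in a fully connected layer: weights plus optional biases."""
    return in_features * out_features + (out_features if bias else 0)


original_head = linear_param_count(2048, 1000)  # the 1000-class head printed above
new_head = linear_param_count(2048, 102)        # the 102-class head needed here

print(original_head)  # 2049000
print(new_head)       # 208998
```

With the feature extractor frozen, those roughly 209 thousand parameters are all that gets updated.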
6. Initialize the model architecture
The steps are as follows:

Take the trained model and set pretrained=True to get the weights someone else has trained.
Optionally choose which layers to freeze; for the frozen ones, gradient updates are switched off (requires_grad set to False).
Whether the task is classification or regression, replace the final FC layer so that it matches the task.
Official documentation:
https://pytorch.org/vision/stable/models.html

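
The three steps above can be tried on a toy network without downloading any weights. In this sketch `TinyNet` is a hypothetical stand-in for a pretrained backbone with a final `fc` layer; a real run would load a torchvision model in step 1 instead.

```python
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone with a final fc layer."""

    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        return self.fc(self.features(x))


# Step 1: obtain the model (a real run would load pretrained weights here).
model = TinyNet()

# Step 2: freeze every existing parameter.
for param in model.parameters():
    param.requires_grad = False

# Step 3: replace the head; a freshly created layer is trainable by default.
num_classes = 102
model.fc = nn.Linear(model.fc.in_features, num_classes)

trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable_names)  # ['fc.weight', 'fc.bias']
```

The order matters: freezing after replacing the head would freeze the new layer as well.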
# Load someone else's pretrained model
# Note: newer torchvision releases prefer the weights=... argument over pretrained=...
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Pick the model; each architecture keeps its output layer in a different place
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet152
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature extractor so that only the FC layer is trained
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer so it outputs num_classes (102 here).
        #    dim=1 applies log-softmax across the classes of each sample (each row),
        #    so the exponentiated outputs of a row sum to 1
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim=1))
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last layer of the classifier, which sits at index 6
        num_frts = model_ft.classifier[6].in_features  # input size of the FC layer
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Replace the auxiliary classifier
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
        # Replace the main classifier
        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        raise ValueError(f"Invalid model name: {model_name}")

    return model_ft, input_size
7. Set the parameters to train
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Compute on the GPU
model_ft = model_ft.to(device)

# Checkpoint file: it stores the trained model so it can be loaded directly later
filename = 'checkpoint.pth'

# Whether to train all layers
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
     fc.0.weight
     fc.0.bias
7. Training and prediction
7.1 Optimizer settings
# Optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay policy
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The learning rate drops to 1/10 of its value every 7 epochs
# The last layer is LogSoftmax(), so nn.CrossEntropyLoss() must not be used here
# (it would apply log-softmax a second time); NLLLoss is used instead

criterion = nn.NLLLoss()
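The comment above pairs LogSoftmax with NLLLoss. As a minimal pure-Python sketch of why that pairing works (the function names are illustrative, not PyTorch APIs, and the scores are made up): log-softmax followed by the negative log-likelihood of the true class gives the same number as cross-entropy computed directly from the raw logits.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross-entropy computed directly from raw logits."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits)) - logits[target]

logits = [2.0, 0.5, -1.0]   # made-up scores for three classes
target = 0
loss_a = nll_loss(log_softmax(logits), target)
loss_b = cross_entropy(logits, target)
print(round(loss_a, 6), round(loss_b, 6))
```

Feeding LogSoftmax outputs into nn.CrossEntropyLoss would apply log-softmax twice, which is why the network's last layer and the loss have to be chosen together.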
# Define the training function
# is_inception: whether the model is Inception, which also returns an auxiliary output while training
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy so far
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or the CPU
    model.to(device)
    # The lists below are kept for plotting
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Training and validation
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all the data
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Compute and update gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # This branch only applies to Inception and can be ignored here
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the index of the most probable class
                    _, preds = torch.max(outputs, 1)

                    # Update the weights in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            # Report the epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # Uses the global scheduler defined above. StepLR takes no metric:
                # a loss passed here would be read as an epoch index and make the learning rate jump
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, load the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
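The training loop multiplies each batch's mean loss by the batch size before summing, then divides by the dataset size. A small sketch with made-up batch numbers shows why: the last batch is often smaller, and a plain average of batch means would weight it too heavily. The helper name `epoch_metrics` and all values are illustrative only.

```python
def epoch_metrics(batches):
    """batches: list of (mean_loss, num_correct, batch_size) tuples."""
    total = sum(size for _, _, size in batches)
    running_loss = sum(loss * size for loss, _, size in batches)
    running_corrects = sum(correct for _, correct, _ in batches)
    return running_loss / total, running_corrects / total

# Two full batches of 8 and a final partial batch of 2 (made-up values).
batches = [(1.0, 6, 8), (2.0, 4, 8), (5.0, 0, 2)]
epoch_loss, epoch_acc = epoch_metrics(batches)

# Averaging the batch means directly over-weights the small last batch.
naive_loss = sum(loss for loss, _, _ in batches) / len(batches)
print(epoch_loss, epoch_acc, naive_loss)
```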
7.2 Start training the model
I trained for only 5 epochs here (training really takes a long time); feel free to increase the number of epochs when you try it yourself.

# If it is too slow, lower the number of epochs; around 50 epochs would likely give better results
# While training, check whether the loss falls and the accuracy rises, and how far validation lags behind training; a large gap indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
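For reference, a step-decay schedule follows a simple closed form: the learning rate is the initial rate times `gamma` raised to `epoch // step_size`. The sketch below restates that rule in plain Python rather than calling PyTorch, and `step_lr` is an illustrative name. With `step_size=7` and `gamma=0.1`, a 5-epoch run should stay at 0.01 throughout. The jumping learning rates in the log above are consistent with the same formula being fed the validation loss in place of an epoch index.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate after `epoch` epochs under a step-decay schedule."""
    return base_lr * gamma ** (epoch // step_size)

# Intended schedule: integer epoch indices 0..4 never reach step_size=7.
intended = [step_lr(1e-2, epoch) for epoch in range(5)]

# Validation losses copied from the log above, used as if they were epochs.
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]
from_losses = [step_lr(1e-2, loss) for loss in valid_losses]

print(["{:.7f}".format(lr) for lr in intended])
print(["{:.7f}".format(lr) for lr in from_losses])
```

If the learning rate should react to the validation loss instead, `optim.lr_scheduler.ReduceLROnPlateau` is the scheduler whose `step()` accepts a metric.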
7.3 Train all layers
# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate.
# params_to_update still holds only the fc parameters collected earlier,
# so the optimizer is built from model_ft.parameters() instead.
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# The scheduler has to wrap the new optimizer, not optimizer_ft
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved parameters
# and continue training from the existing model.
# The file below holds the best weights from the training run above
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# The saved optimizer state covers only the fc parameters, so it cannot be
# loaded into the new all-parameter optimizer and is left out here.
# optimizer.load_state_dict(checkpoint['optimizer'])
Start training
Note: training becomes noticeably slower at this stage. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
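One thing worth checking in the two logs: the first stage reached a best validation accuracy of 0.662592, while the second stage reports 0.645477, and both stages write to the same `checkpoint.pth`. Because `train_model` starts every call with `best_acc = 0`, the second run appears able to overwrite a better checkpoint with a worse one. A minimal sketch of carrying the best score across stages, using a hypothetical `BestTracker` helper that is not part of the original code:

```python
class BestTracker:
    """Remembers the best validation accuracy seen so far across runs."""

    def __init__(self, best_acc=0.0):
        self.best_acc = best_acc
        self.saved = []  # stands in for calls to torch.save

    def update(self, epoch_acc, tag):
        """Record a checkpoint only when the accuracy improves."""
        if epoch_acc > self.best_acc:
            self.best_acc = epoch_acc
            self.saved.append(tag)
            return True
        return False

stage1 = [0.4719, 0.6626, 0.4413, 0.6027, 0.4914]  # valid Acc, first log
stage2 = [0.6455, 0.6137]                          # valid Acc, second log

tracker = BestTracker()
for i, acc in enumerate(stage1):
    tracker.update(acc, "stage1-epoch{}".format(i))
for i, acc in enumerate(stage2):
    tracker.update(acc, "stage2-epoch{}".format(i))

print(tracker.best_acc, tracker.saved)
```

In the real training function this would correspond to starting from the `best_acc` stored in the checkpoint instead of resetting it to 0 on every call.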
8. Load the trained model
This amounts to a simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU

# Name of the saved file
filename = 'checkpoint.pth'

# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read an image from the test set
    img = Image.open(image_path)
    # Resize. thumbnail() can only shrink while keeping the aspect ratio, so branch on orientation.
    # Unlike resize(), whose size argument sets the output size directly,
    # thumbnail() scales the image down proportionally, modifies it in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Center-crop the image to 224 x 224 (integer pixel coordinates).
    # PIL's origin is the top-left corner, so bottom_margin is the upper edge of the box.
    left_margin = (img.width - 224) // 2  # take the central part
    bottom_margin = (img.height - 224) // 2
    right_margin = left_margin + 224  # add the crop size of 224 to get the opposite edge
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Apply the same preprocessing as in training
    # Normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Mind the position of the color channel: move it to the front
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Move the color channel back to the last axis
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)
    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # the function can be called repeatedly to process more images
imshow(img)
<AxesSubplot:>
Above is the test image after preprocessing. We can check its shape to confirm that the preprocessing function works as intended.

img.shape
(3, 224, 224)
This confirms that the channel dimension was moved to the front and that the spatial size is 224 × 224.
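The geometry in `process_image` can be checked without opening an image. The sketch below, in plain Python with illustrative function names, mirrors the two steps: shrink so the shorter side is 256 (only if the image is larger, since `thumbnail` never enlarges), then take a centered 224 × 224 box. It also shows that the normalization and the inverse used in `imshow` cancel out. Pillow's own rounding of the resized size may differ by a pixel from this arithmetic.

```python
def shrink_short_side(width, height, short_side=256):
    """Scale (width, height) down so the shorter side equals short_side."""
    shorter = min(width, height)
    if shorter <= short_side:        # thumbnail-style: never enlarge
        return width, height
    scale = short_side / shorter
    return round(width * scale), round(height * scale)

def center_crop_box(width, height, crop=224):
    """Return (left, upper, right, lower) of a centered crop box."""
    left = (width - crop) // 2
    upper = (height - crop) // 2
    return left, upper, left + crop, upper + crop

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize(pixel):
    """pixel: (r, g, b) values in [0, 1] -> standardized values."""
    return [(v - m) / s for v, m, s in zip(pixel, MEAN, STD)]

def denormalize(values):
    """Inverse of normalize, as done before plotting."""
    return [v * s + m for v, m, s in zip(values, MEAN, STD)]

size = shrink_short_side(1024, 512)      # a made-up landscape image
box = center_crop_box(*size)
restored = denormalize(normalize([0.2, 0.5, 0.9]))
print(size, box, [round(v, 6) for v in restored])
```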
9. Inference
img.shape

# Get one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch releases

model_ft.eval()

if train_on_gpu:
    # A single forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 samples, each with 102 values: one log-probability per class (the last layer is LogSoftmax)
output.shape
torch.Size([8, 102])
9.1 Get the most probable class
_, preds_tensor = torch.max(output, 1)

# Move to the CPU if needed and squeeze into a 1-D array of predicted class indices
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
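Since the network outputs log-probabilities, the index of the largest value is the predicted class, and exponentiating that value recovers its probability. A plain-Python sketch for a single made-up output row with four classes (`argmax` is an illustrative helper; `torch.max(output, 1)` does the same per row of the batch):

```python
import math

def argmax(values):
    """Index of the largest element."""
    return max(range(len(values)), key=lambda i: values[i])

# Made-up log-probabilities for one sample over four classes.
probs = [0.1, 0.6, 0.2, 0.1]
log_probs = [math.log(p) for p in probs]

pred = argmax(log_probs)
confidence = math.exp(log_probs[pred])
print(pred, round(confidence, 4))
```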
9.2 Show the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # Title format: predicted name (true name)
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions, red titles mark wrong ones
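One caveat when turning predictions into names, assuming the data were loaded with torchvision's `ImageFolder`: class indices are assigned by sorting the folder names as strings, so index 1 is folder `10`, not folder `2`. Looking up `cat_to_name` directly with the predicted index can therefore show the wrong flower name, and index 0 has no entry at all if the keys run from 1 to 102. A sketch of the extra mapping step, using a small subset of folders and placeholder names instead of the real `cat_to_name.json` contents:

```python
# Folder names as they would appear under flower_data/train (subset).
folders = ["1", "2", "3", "10", "11", "100"]

# ImageFolder-style indexing: sort names as strings, then enumerate.
classes = sorted(folders)
class_to_idx = {name: idx for idx, name in enumerate(classes)}
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

# Placeholder names; the real ones come from cat_to_name.json.
cat_to_name = {name: "flower-" + name for name in folders}

def name_for_prediction(pred_idx):
    """Map a predicted class index to a flower name via the folder name."""
    return cat_to_name[idx_to_class[pred_idx]]

print(classes)
print(name_for_prediction(1))
```

The green/red coloring in the plot stays correct either way, because predictions and labels are both indices; only the displayed names depend on this mapping.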
----------------
Copyright notice: this is an original article by the CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA. Reposts must include the original source link and this notice.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940