[Deep Learning] Hands-On Image Recognition: 102-Category Flower Classification (Flower 102)
Table of Contents

Convolutional network in practice: classifying flowers
    Data preprocessing
    Network module setup
    Saving and testing the network model
    Data download
    1. Import packages
    2. Data preprocessing and operations
    3. Build the data source
        Reading the real names behind the labels
    4. Display some data
    5. Load a model from torchvision.models and initialize it with pretrained weights
    6. Initialize the model architecture
    7. Set the parameters to train
    7. Training and prediction
        7.1 Optimizer setup
        7.2 Start training the model
        7.3 Train all layers
        Start training
    8. Load the trained model
    9. Inference
        9.1 Compute the top probability
        9.2 Display the predictions
    Closing remarks
Convolutional network in practice: classifying flowers

This article tackles a classification task on the Oxford 102-category flower dataset. It builds a fairly general-purpose neural-network pipeline (implemented mainly with ResNet) and ties together common PyTorch operations: preprocessing, training, model saving, model loading, and so on.

The data folder contains 102 kinds of flowers; our task is to classify them.

Folder structure:

flower_data
    train
        1 (class id)
        2
            xxx.png / xxx.jpg
    valid

The work splits into the following major modules:

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model, reusing one of torchvision's classic architectures
    The pretrained task may be, e.g., 1000-way classification (not necessarily the same as ours), so the head must be adapted to our own task
Saving and testing the network model
    Model saving can be selective

Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the files and put them under the same root directory, and you're set.

Below is my data root directory.
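As a quick sanity check, the expected layout can be created and inspected with the standard library alone. This is a minimal sketch: the miniature `flower_data` tree below is a temporary stand-in, not the real dataset.

```python
import os
import tempfile

# Build a miniature stand-in for the expected layout:
# flower_data/{train,valid}/<class id>/<image files>
root = tempfile.mkdtemp()
for split in ("train", "valid"):
    for class_id in ("1", "2"):
        os.makedirs(os.path.join(root, "flower_data", split, class_id))

# Each split should contain one sub-folder per class id.
train_classes = sorted(os.listdir(os.path.join(root, "flower_data", "train")))
print(train_classes)  # ['1', '2']
```

This mirrors the structure ImageFolder expects later: one folder per class, images inside.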
1. Import packages

import os
import matplotlib.pyplot as plt
# inline plotting: figures render without an explicit plt.show()
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations

# path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

Reference: "python目录点杠的组合与区别" (on combining dots and slashes in Python paths).
Note: that article explains the ./ and / notation used above.
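A small illustration of the dot/slash behaviour mentioned above, using only the standard library. The concatenation mirrors the `data_dir + '/train'` line; note that it produces a double slash, which is harmless on most filesystems but untidy.

```python
import posixpath

data_dir = './flower_data/'

# Plain string concatenation, as in the snippet above, duplicates the separator:
train_dir = data_dir + '/train'
print(train_dir)  # ./flower_data//train

# posixpath.join inserts exactly one separator:
joined = posixpath.join('./flower_data', 'train')
print(joined)  # ./flower_data/train
```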
3. Build the data source

data_transforms specifies all image-preprocessing operations.
ImageFolder assumes the files are organized into folders, with each folder holding images of one class.

data_transforms = {
    # two parts: one for training...
    'train': transforms.Compose([transforms.RandomRotation(45),  # random rotation between -45 and +45 degrees
                                 transforms.CenterCrop(224),  # crop 224x224 from the center
                                 # flip or not with a 50/50 random chance
                                 transforms.RandomHorizontalFlip(p = 0.5),  # random horizontal flip
                                 transforms.RandomVerticalFlip(p = 0.5),  # random vertical flip
                                 # argument 1 is brightness, 2 contrast, 3 saturation, 4 hue
                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
                                 transforms.RandomGrayscale(p = 0.025),  # convert to grayscale with small probability
                                 # the grayscale result still has three channels, but R == G == B
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, std
                                ]),
    # ...and one for validation: resize to 256, center-crop 224x224, convert to tensor, then normalize
    'valid': transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean/std as the training set
                                ]),
}
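Normalize subtracts the per-channel mean and divides by the per-channel std. A hand-computed check of that formula in pure Python, using the same ImageNet statistics as above:

```python
# channel-wise standardization: out = (in - mean) / std
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

pixel = [0.5, 0.5, 0.5]  # one RGB pixel after ToTensor (values in [0, 1])
normalized = [(p - m) / s for p, m, s in zip(pixel, mean, std)]
print([round(v, 4) for v in normalized])
```

im_convert later inverts exactly this operation (`image * std + mean`) before plotting.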
batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# inspect the datasets
image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
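With batch_size=8 and shuffle=True, the DataLoader hands out shuffled batches of 8 samples each epoch, with a possibly smaller final batch. The chunking itself can be sketched without PyTorch:

```python
import random

samples = list(range(20))  # stand-in for a 20-image dataset
random.shuffle(samples)    # shuffle=True reorders the samples each epoch

batch_size = 8
batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
print([len(b) for b in batches])  # [8, 8, 4]
```

Every sample appears exactly once per epoch; only the order (and batch membership) changes.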
# verify that the data has been processed
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

{'train': 6552, 'valid': 818}
Reading the real names behind the labels

Use the JSON file in the same directory to map each class id back to the flower's actual name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
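One subtlety worth spelling out: ImageFolder sorts class-folder names as strings, so label index 1 corresponds to folder '10', not '2'. That is why the plotting code below looks up `cat_to_name[str(int(class_names[label]))]` rather than using the label directly. A minimal pure-Python sketch (only four classes, names taken from the map above):

```python
# ImageFolder sorts folder names lexicographically to assign class indices.
folders = ['1', '2', '10', '100']
class_names = sorted(folders)        # ['1', '10', '100', '2']
cat_to_name = {'1': 'pink primrose', '2': 'hard-leaved pocket orchid',
               '10': 'globe thistle', '100': 'blanket flower'}

label = 1                            # a label index produced by the DataLoader
folder = class_names[label]          # -> '10', not '2'
print(cat_to_name[str(int(folder))]) # globe thistle
```

Without the `class_names` indirection, label 1 would wrongly display 'hard-leaved pocket orchid'.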
4. Display some data

def im_convert(tensor):
    """Convert a tensor into a displayable image."""
    image = tensor.to("cpu").clone().detach()
    # squeeze drops singleton dimensions so the array is ready for plotting
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is (c, h, w), restore it to (h, w, c)
    image = image.transpose(1, 2, 0)
    # de-normalize (invert the earlier standardization)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clamp everything below 0 to 0 and everything above 1 to 1
    image = image.clip(0, 1)
    return image
# use the function defined above to plot a grid of images
fig = plt.figure(figsize = (20, 12))
columns = 4
rows = 2

# iter gives an iterator over the loader
# grab an arbitrary batch for display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # older PyTorch used dataiter.next(), since removed

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    # use the JSON mapping to print each flower's name above its image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model from torchvision.models and initialize it with pretrained weights

model_name = 'resnet'  # many options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# resnet is our main image-recognition backbone
# whether to reuse features someone else already trained
feature_extract = True

# whether to train on GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

CUDA is not available. Training on CPU ...
# set requires_grad to False on some layers so they are not updated automatically
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
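What freezing does can be mimicked without PyTorch: every parameter carries a requires_grad flag, and the function above simply switches them all off. The `Param` and `TinyModel` classes below are hypothetical stand-ins for torch's parameter and module objects, just to show the mechanism:

```python
class Param:
    """Hypothetical stand-in for a torch parameter: only a requires_grad flag."""
    def __init__(self):
        self.requires_grad = True

class TinyModel:
    """Hypothetical stand-in for an nn.Module with three parameters."""
    def __init__(self):
        self._params = [Param(), Param(), Param()]
    def parameters(self):
        return self._params

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

model = TinyModel()
set_parameter_requires_grad(model, feature_extracting=True)
print([p.requires_grad for p in model.parameters()])  # [False, False, False]
```

Layers added after this call (like the new fc head) default back to requires_grad=True, which is why only the head gets trained.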
) b) B7 F' C& n# Z# e' P# 打印模型架构告知是怎么一步一步去完成的. ^ }: k; x# c8 A. [3 r5 ]
# 主要是为我们提取特征的
* I' y+ [* ^9 H) f: }0 K3 v; m
0 n# j' u2 Y1 D$ B, vmodel_ft = models.resnet152()
0 V: M5 C$ J2 W3 \model_ft
; Y/ s$ ^% F, x" _1
7 i& u. i9 J ?: b* W' r, i+ v% U2
9 K) T0 q2 d& M6 f3
5 w6 q# t3 R) W3 U/ ^4
/ e9 M, L t; R5. K5 D* D/ k! g% I0 O
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

There is a lot more output in between; we only need the first and last blocks of the architecture, so the middle is abbreviated...

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The final layer is a 1000-way classifier: 2048 input features mapped to 1000 classes.
We need to adapt this to our task by changing the 1000-way output to 102.

6. Initialize the model architecture

The steps are:

Take the pretrained model and pass pretrained=True to reuse its weights.
Optionally decide which layers to freeze; for the frozen ones, set requires_grad to False so they stop updating.
Whatever the task, classification or regression, replace the final FC layer with one sized for it.

Official documentation: https://pytorch.org/vision/stable/models.html
# load someone else's model
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    # pick the right model; different models are initialized differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. load the pretrained network
        model_ft = models.resnet152(pretrained = use_pretrained)
        # 2. optionally freeze the feature-extraction layers and train only the FC layer
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. read the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. replace the fully connected layer with a 102-way output
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
                                    nn.LogSoftmax(dim = 1))  # dim = 1 normalizes across each row, so each row's probabilities sum to 1
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # replace the final layer of the classifier head, at index [6]
        num_frts = model_ft.classifier[6].in_features  # FC-layer input size
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
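The nn.LogSoftmax(dim=1) attached to the new resnet head computes log(exp(x_i) / sum_j exp(x_j)) over each row of class scores. A small hand-rolled version of that formula over one row (pure Python, for intuition only):

```python
import math

def log_softmax(logits):
    # subtract the max first for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]

row = [2.0, 1.0, 0.1]                # one row of class scores
log_probs = log_softmax(row)
probs = [math.exp(v) for v in log_probs]
print(round(sum(probs), 6))          # the exponentiated values sum to 1
```

Exponentiating the outputs recovers the probabilities, which is exactly what the inference step does later when it takes the top probability.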
7. Setting the parameters to train

# set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# run on the GPU
model_ft = model_ft.to(device)

# checkpoint file: the trained model is saved here so it can be loaded back directly later
filename = 'checkpoint.pth'

# whether to train all layers
params_to_update = model_ft.parameters()
# print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction

7.1 Optimizer setup

# optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# learning-rate decay: every 7 epochs the learning rate is multiplied by 0.1
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# the last layer uses LogSoftmax(), so nn.CrossEntropyLoss() cannot be used here;
# NLLLoss is the matching loss for log-probabilities
criterion = nn.NLLLoss()
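The pairing of LogSoftmax with NLLLoss can be checked without PyTorch: cross-entropy computed from raw logits equals the negative log-likelihood of the log-softmax outputs. A minimal pure-Python sketch of that identity (the logits and target below are made-up illustrative values):

```python
import math

def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    return -log_probs[target]  # negative log-likelihood of the true class

def cross_entropy(logits, target):
    # cross-entropy from raw logits == NLLLoss applied to LogSoftmax outputs
    return nll_loss(log_softmax(logits), target)

logits, target = [2.0, 0.5, -1.0], 0  # made-up example values
print(round(cross_entropy(logits, target), 4))
```

This is exactly why the model's LogSoftmax head is trained with nn.NLLLoss instead of nn.CrossEntropyLoss, which would apply the log-softmax a second time.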
# training function
# is_inception: set to True when using the Inception architecture (it has an auxiliary output)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # keep track of the best validation accuracy
    best_acc = 0
    """
    # to resume from a checkpoint, uncomment this block:
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # run on GPU or CPU
    model.to(device)
    # histories kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # keep a copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # iterate over all of the data
            for inputs, labels in dataloaders[phase]:
                # move inputs and labels to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the gradients
                optimizer.zero_grad()
                # compute and update gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        # Inception has an auxiliary classifier, so combine both losses
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds is the class with the highest score
                    _, preds = torch.max(outputs, 1)

                    # update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # keep the model from the best validation epoch
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step()  # StepLR steps once per epoch and takes no metric argument
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # after training, load the best epoch's weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
7.2 Start training the model

I only trained for 5 epochs here (training really does take a long time); feel free to raise the number of epochs when experimenting yourself.

# if it is too slow, lower the epoch count; around 50 iterations may give better results
# while training, check whether the loss falls and the accuracy rises, and how far
# validation lags behind training — a large gap indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
7.3 Training all layers

# unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# continue training, now over all of the parameters, with a smaller learning rate
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# loss function
criterion = nn.NLLLoss()
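The StepLR schedule is easy to verify by hand: the learning rate is multiplied by gamma once every step_size epochs. A small plain-Python sketch with the same settings as above (base learning rate 1e-4, step size 7, gamma 0.1):

```python
base_lr, step_size, gamma = 1e-4, 7, 0.1  # same settings as the StepLR above

def lr_at(epoch):
    # StepLR decays the learning rate by a factor of gamma every step_size epochs
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 6, 7, 14):
    print(epoch, lr_at(epoch))
```

So epochs 0-6 run at 1e-4, epochs 7-13 at 1e-5, and so on.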
3 N s& C1 ~, v) I1
6 W3 ~! h2 i/ f: G! I3 L! Q% m2
# u7 Q, P# v; T" N6 Q8 {34 M4 e3 V Y# |, g, q) X
4$ Z1 W9 {0 A0 A7 s. ]5 g U
5
2 A- y# o; r0 F6
2 z3 t, V9 {8 x5 `- m7* p/ J: A/ Y" P* k* Y
8) h# \0 N# X0 f8 s- r2 E5 {
98 N4 m# p6 t# _6 C; U4 D# [# K
10
# load the saved parameters and continue training from the existing model
# (the file below is the checkpoint that just achieved the best result)
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
Start training

Note: training gets noticeably slower at this stage. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Loading the trained model

This amounts to a single forward pass (inference); the parameters are not updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move the model onto the GPU

# name of the saved checkpoint file
filename = 'checkpoint.pth'

# load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # read a test-set image
    img = Image.open(image_path)
    # Resize: thumbnail() can only shrink proportionally, so check the orientation first.
    # Unlike resize(), whose size argument sets the exact output size,
    # thumbnail() scales down while preserving the aspect ratio;
    # it also modifies the image in place and returns None
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # crop the image down to 224 x 224
    left_margin = (img.width - 224) / 2    # take the central region
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224       # add the crop width of 224 to get the right edge
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # same preprocessing as in training
    # normalization
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # mind the channel order and position
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # restore the channel order
    image = np.array(image).transpose((1, 2, 0))

    # undo the preprocessing
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # this function can be reused to preprocess any image
imshow(img)
<AxesSubplot:>
Above we applied the preprocessing to a test-set image; checking shape lets us verify the image size and confirm that the preprocessing function is correct.

img.shape

(3, 224, 224)

This confirms that the channel dimension was moved to the front while the spatial size is unchanged.
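The reordering itself can be sketched with NumPy alone: PIL delivers arrays as height × width × channels, while PyTorch expects channels × height × width (the zero array below is just a stand-in for a real image):

```python
import numpy as np

hwc = np.zeros((224, 224, 3))    # image as loaded: height, width, channels
chw = hwc.transpose((2, 0, 1))   # move channels to the front, as PyTorch expects
print(hwc.shape, chw.shape)
```

This is exactly the `img.transpose((2, 0, 1))` step in process_image, and imshow undoes it with `transpose((1, 2, 0))`.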
9. Inference

# get one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

if train_on_gpu:
    # one forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# the batch holds 8 images; each gets 102 values, one score (log-probability) per class
output.shape
torch.Size([8, 102])
9.1 Computing the most probable class

_, preds_tensor = torch.max(output, 1)

# convert the rank-1 tensor to a NumPy array (moving it off the GPU first when necessary)
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
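Because the network head ends in LogSoftmax (see 7.1), each row of `output` holds log-probabilities; exponentiating recovers probabilities that sum to 1, and argmax gives the same class either way. A small NumPy sketch with made-up values standing in for one row of `output`:

```python
import numpy as np

log_probs = np.log([0.7, 0.2, 0.1])   # stand-in for one row of `output`
probs = np.exp(log_probs)             # back to probabilities
pred = int(probs.argmax())            # same index that torch.max(output, 1) returns
print(pred, float(probs.sum()))
```

So torch.max on the raw output already selects the most probable class; torch.exp is only needed if the actual probabilities are wanted.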
1% |. [9 L1 b) B: |
2+ }9 A& U0 K. F U( s1 X. F( q
34 e2 q7 O% [9 ~7 k# {3 Z+ x
9.2 Displaying the prediction results

fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# green titles mark correct predictions, red titles mark wrong ones
————————————————
Copyright notice: this is an original article by the CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940