[Deep Learning] Image recognition in practice: 102-class flower classification (Flower 102), a hands-on case study
Table of contents

Convolutional network in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import packages
2. Data preprocessing and setup
3. Build the data source
Read the real names behind the labels
4. Show some data
5. Load a model from torchvision.models, initialized with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
8. Training and prediction
8.1 Optimizer setup
8.2 Start training the model
8.3 Train all layers
Start training
9. Load the trained model
10. Inference
10.1 Get the class with the highest probability
10.2 Show the predictions
Closing remarks
Convolutional network in practice: classifying flowers

This article tackles a classification task on the Oxford flowers dataset. We build a fairly general-purpose pipeline (implemented mainly with ResNet) that ties together the common PyTorch workflow: preprocessing, training, model saving, and model loading.
0 V# ^& a5 g; P/ h9 i3 o! u0 n5 V/ i9 `/ p, v* b' D
在文件夹中有102种花,我们主要要对这些花进行分类任务
, V$ w7 q# L3 _) d& z文件夹结构! C, t. a# j4 p |' X
9 }$ A! ^. {' M1 {9 X
flower_data# X# A3 a B. g
3 [0 u' L- B1 Y9 p. H# i& h" e9 m8 m
train
' f$ Y* F# W% J! P" }/ q; t( `; J: n6 s
- f* |# U* s6 H. K1(类别)
1 U1 B" z/ M" D+ f# T% [2
; o. c9 [. t6 s/ P6 _! vxxx.png / xxx.jpg
1 H C' f2 F( R# m0 q( U+ A3 evalid& g1 j* \9 E/ D* T
The work splits into the following major modules:

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model, reusing one of the classic architectures from torchvision
    The pretrained task is usually 1000-way ImageNet classification (not the same as ours), so the head must be adapted to our own task
Saving and testing the network model
    Model saving can be selective, e.g. only keeping the checkpoint with the best validation accuracy
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the extracted folder and put it in the same root directory as the code.

(A screenshot of my data root directory appeared here in the original post.)
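As a quick sanity check (a minimal sketch, assuming the flower_data/train and flower_data/valid layout shown above), you can count the class folders and images before going further:

import os

data_dir = './flower_data/'
for split in ['train', 'valid']:
    split_dir = os.path.join(data_dir, split)
    # each subfolder is one flower class
    class_dirs = [d for d in os.listdir(split_dir) if os.path.isdir(os.path.join(split_dir, d))]
    n_images = sum(len(os.listdir(os.path.join(split_dir, d))) for d in class_dirs)
    print(split, ':', len(class_dirs), 'classes,', n_images, 'images')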
1. Import packages

import os
import matplotlib.pyplot as plt
# inline plotting in the notebook, so plt.show() is not strictly needed
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and setup

# path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
Note on combining './' and '/' in Python paths: './flower_data/' is relative to the current working directory, and appending '/train' still resolves correctly even though it produces a doubled slash.
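For illustration (a small sketch, not part of the original article), the string concatenation above produces a path with a doubled slash, which the OS tolerates; os.path.join avoids it:

import os

data_dir = './flower_data/'
print(data_dir + '/train')              # ./flower_data//train  -- still works
print(os.path.join(data_dir, 'train'))  # ./flower_data/train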
3. Build the data source

data_transforms specifies all the image preprocessing operations.
ImageFolder assumes the files are organized by folder, with each folder holding images of one class.

data_transforms = {
    # two pipelines: one for training, one for validation
    'train': transforms.Compose([transforms.RandomRotation(45),  # random rotation between -45 and +45 degrees
        transforms.CenterCrop(224),  # crop from the center
        # flip with a 50/50 chance
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),    # random vertical flip
        # arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),  # convert to grayscale with a small probability, keeping 3 channels
        # after grayscale conversion there are still three channels, they just carry identical values
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean and std
    ]),
    # resize the short side to 256, center-crop 224 x 224, convert to a tensor, then normalize
    'valid': transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
# check that the dataloaders were created
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
dataset_sizes

{'train': 6552, 'valid': 818}
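The labels ImageFolder produces are integer indices into class_names, not the folder names themselves (the folders '1' to '102' are sorted as strings). A quick look at the mapping makes the later label-to-name lookups easier to follow (a small illustrative check, not in the original post):

print(class_names[:5])  # folder names sorted as strings, e.g. ['1', '10', '100', '101', '102']
print(list(image_datasets['train'].class_to_idx.items())[:5])  # folder name -> integer label used by the loaders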
Read the real names behind the labels

Use the json file in the same directory to map category labels back to flower names.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
4. Show some data

def im_convert(tensor):
    """Convert a normalized tensor back into a plottable image."""
    image = tensor.to("cpu").clone().detach()
    # squeeze drops the singleton batch dimension so the array can be plotted
    image = image.numpy().squeeze()
    # transpose swaps the axes: the tensor is (c, h, w), matplotlib wants (h, w, c)
    image = image.transpose(1, 2, 0)
    # undo the normalization (multiply by the std, add back the mean)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clamp values below 0 to 0 and above 1 to 1
    image = image.clip(0, 1)

    return image
# plot a few samples using the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# grab an arbitrary batch from the validation loader and display it
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # dataiter.next() in older PyTorch versions

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # use the json mapping to print the flower name above each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model from torchvision.models, initialized with pretrained weights

model_name = 'resnet'  # several options are available: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# we do the image recognition mainly with resnet
# whether to reuse the pretrained features (freeze the backbone)
feature_extract = True
# check whether a GPU is available for training
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# freeze layers by setting requires_grad to False so they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
# print the architecture to see how the model is built, layer by layer
# the backbone mainly serves as a feature extractor

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
... (many intermediate layers omitted; the two ends of the architecture are what matter here) ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
; @* ^9 |8 x ]+ Q0 O最后是1000分类,2048输入,分为1000个分类
; N3 R7 T/ i0 }4 q7 }( i! x1 n而我们需要将我们的任务进行调整,将1000分类改为102输出" J3 E- L6 J' X) V9 a
6. Initialize the model architecture

The steps are:

Take the pretrained model and load its weights with use_pretrained = True.
Decide whether to freeze certain layers; the ones to freeze get their gradient updates disabled (requires_grad = False).
Whether it is a classification or a regression task, replace the final FC layer with one sized for your own outputs.

Official documentation:
https://pytorch.org/vision/stable/models.html
" J# t* H1 w1 y& H# 将他人的模型加载进来# B% R3 `- Q8 K! P$ S
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):- y9 p: p7 J2 B/ c
# 选择适合的模型,不同的模型初始化参数不同0 |7 ^/ s( C+ t2 w$ |* N! f
model_ft = None
' @& e1 g& D" n! V input_size = 0% _3 M; V/ g& X0 t( |' e
& T. i" b" j+ j if model_name == "resnet":& K5 j; U( C8 z0 F9 M W( I" Q
"""
$ K1 t6 Q7 g5 o$ w$ M Resnet152
; v4 C" ^7 T& G* U+ R8 ` """7 U J! f$ U% w7 z# w' P
8 I' `* w. W% W$ J/ W. d
# 1. 加载与训练网络
, N& u1 {7 W* s model_ft = models.resnet152(pretrained = use_pretrained)
" A$ m0 g# r$ D E2 F # 2. 是否将提取特征的模块冻住,只训练FC层& e$ L% L: T' h* E
set_parameter_requires_grad(model_ft, feature_extract)
( B2 z% j- C& e6 V # 3. 获得全连接层输入特征' {$ S( F8 y, l- @ ]
num_frts = model_ft.fc.in_features
+ n3 O% j9 R: h6 \ # 4. 重新加载全连接层,设置输出102
/ k0 W2 s! U; ^! w; W model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),0 r- g3 J0 W+ s1 g3 a, K
nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为18 s1 U! P# J9 e* G* E) B
input_size = 224
. l8 i5 [# J) S% M, n7 m
$ _5 G8 G4 Y4 v+ X' R7 L elif model_name == "alexnet":
1 Y! j6 l, j8 W4 X0 [, O& W3 _ """- D" M% U6 M" ~
Alexnet
# h1 b2 I! f6 l# `0 W' g """! r" z& U$ d! |
model_ft = models.alexnet(pretrained = use_pretrained)
/ ?: A3 u; `. c9 s P9 S" [% o4 ~ set_parameter_requires_grad(model_ft, feature_extract)! x2 p1 M+ u' T+ v8 P! R
' K; F& w- ?7 `" s& _8 Z # 将最后一个特征输出替换 序号为【6】的分类器' s, J" a7 ?0 ` I; K |
num_frts = model_ft.classifier[6].in_features # 获得FC层输入: ^ |- C5 |5 ^+ j5 s* b
model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
3 ~3 e4 ?, E% ^5 M$ N input_size = 224/ Z0 e( k! U0 V9 [$ ^
- @* m5 Z& \5 @/ M elif model_name == "vgg":% k" c2 F8 Y. _# \# ~- c
"""
+ F9 M3 T$ @6 ?2 e( a! \$ f VGG11_bn
2 G2 l$ i# F, m1 E1 [, w """
, K- t8 U# M/ Z. ^6 x" h1 z model_ft = models.vgg16(pretrained = use_pretrained)
4 _2 {( ^0 t: Q. B8 g, t( G set_parameter_requires_grad(model_ft, feature_extract)
' W. j$ b- D7 G% \ num_frts = model_ft.classifier[6].in_features: e8 q0 ?4 l& r/ V" @
model_ft.classifier[6] = nn.Linear(num_frts, num_classes)" m5 e) m) A* M _& Z( l+ H
input_size = 224+ u1 b* _4 j2 T- Y, i9 y" ?
4 {( y+ y* ~: E+ {" E: L/ n
elif model_name == "squeezenet":- }+ h7 C7 N& R$ R- N, @
"""6 k' A: h5 s& Z" o* C* M$ u* X. @: w
Squeezenet
9 S4 O0 B( G( P """
, M( C& q: A( D4 r1 a" } model_ft = models.squeezenet1_0(pretrained = use_pretrained)( m8 Q( S; n- }: v) O8 s* @
set_parameter_requires_grad(model_ft, feature_extract)
/ F: s) l y, u) B% q model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))7 z5 ~, k1 q2 B5 K
model_ft.num_classes = num_classes4 N6 C! k( `/ x, E7 _, O( K% U3 b
input_size = 2242 @" c& S$ ?5 p, T/ t! [
% y) r" C* m6 g8 a U
elif model_name == "densenet":" d) g; x0 ]: l- C5 r
"""
1 A; [- g' Q: V( m. A. ]+ S Densenet) }3 m0 V6 ^2 S, M; q T
"""
: Z( B' L* ?2 O7 ^8 ~ model_ft = models.desenet121(pretrained = use_pretrained)
5 N9 n6 O( b v; a set_parameter_requires_grad(model_ft, feature_extract)$ m- |$ Z$ [7 h6 C0 ~: m1 o0 C% }
num_frts = model_ft.classifier.in_features- p, g, E4 M. Z+ a' g5 [9 |% V
model_ft.classifier = nn.Linear(num_frts, num_classes)
: W% a! U, d% }3 l7 e$ V input_size = 2243 A; n% i: t0 a5 j' s
N) S8 r% R" b& X& W/ A7 T" Z5 R
elif model_name == "inception":# l' j& U: O8 W* C R8 w3 E# Q5 O
""" j& o; l! f5 o3 m5 f" E5 B
Inception V3
/ j2 V2 A* }3 X8 Z( {& {4 c """
5 k; _, s& K- y& S model_ft = models.inception_V(pretrained = use_pretrained); U; E: p% j* y4 F
set_parameter_requires_grad(model_ft, feature_extract)2 l! J$ v: R# D. d. p" K
( J7 L( ~1 X C7 |& D9 v num_frts = model_ft.AuxLogits.fc.in_features! }# s8 N2 |9 v1 Q3 f% x4 k
model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes). d0 y, |% ^% q' v7 [. m. F
- X' s' z" s' b/ l4 [" U
num_frts = model_ft.fc.in_features
$ Y/ V; Y a3 s4 z# U: @ model_ft.fc = nn.Linear(num_frts, num_classes)
: E$ M7 ?" q0 ^+ K input_size = 299' `! k9 P# A) d* U l+ q
6 J5 \+ p6 _ V. J: L% m
else:
# X' x( x- G- b+ \6 J4 ?7 l! O print("Invalid model name, exiting...")
% E4 Y4 l5 V! @ F exit()
1 E9 m/ [- c' y0 g
" v& a0 s5 H. {, U1 y3 ^, g return model_ft, input_size; s6 g# C; z8 T8 Q. i# [
9 h+ z6 Y# T9 ~+ g# |4 ~
7. Set the parameters to train

# set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# move the model to the GPU (or CPU)
model_ft = model_ft.to(device)

# checkpoint file: the trained model is saved here so it can be loaded directly later
filename = 'checkpoint.pth'

# decide which layers are actually trained
params_to_update = model_ft.parameters()
# print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
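To double-check that the backbone is really frozen, you can compare the trainable and total parameter counts (a small sanity check added here, not part of the original post):

trainable = sum(p.numel() for p in model_ft.parameters() if p.requires_grad)
total = sum(p.numel() for p in model_ft.parameters())
print('trainable:', trainable, '/ total:', total)  # with feature_extract=True only the new fc layer is trainable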
8. Training and prediction

8.1 Optimizer setup

# optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# learning-rate decay schedule
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# the learning rate decays to 1/10 of its value every 7 epochs
# the last layer is LogSoftmax(), so nn.CrossEntropyLoss() cannot be used here
# (CrossEntropyLoss already applies LogSoftmax internally); NLLLoss expects log-probabilities

criterion = nn.NLLLoss()
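A quick way to see why this combination works (an illustrative check, not from the original article): LogSoftmax followed by NLLLoss gives the same value as CrossEntropyLoss applied to the raw logits.

logits = torch.randn(4, 102)                    # a made-up batch of raw scores
targets = torch.randint(0, 102, (4,))           # made-up labels
log_probs = nn.LogSoftmax(dim=1)(logits)
print(nn.NLLLoss()(log_probs, targets))         # same value ...
print(nn.CrossEntropyLoss()(logits, targets))   # ... as this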
# define the training function
# is_inception: set to True when the model is Inception (it returns an auxiliary output during training)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # keep track of the best validation accuracy
    best_acc = 0
    """
    # to resume from a saved checkpoint, uncomment:
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # run on GPU or CPU
    model.to(device)
    # histories kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # remember the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # iterate over all the data for this phase
            for inputs, labels in dataloaders[phase]:
                # move inputs and labels to the GPU (or CPU)
                inputs = inputs.to(device)
                labels = labels.to(device)

                # reset the gradients
                optimizer.zero_grad()
                # compute and update gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # the inception branch can be ignored for resnet
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # resnet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the class with the highest score
                    _, preds = torch.max(outputs, 1)

                    # update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # report this phase
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # keep the best model seen on the validation set
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the weights and biases learned during training
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # note: scheduler is the StepLR defined above; passing a loss to StepLR.step() is interpreted
                # as an epoch index, which is why the learning rates printed below jump around
                scheduler.step(epoch_loss)
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # after training, load the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
! X6 l1 u4 F( H9 V7 `
1$ s4 b0 x- w! U3 \% b1 B
2 I! w8 Y' i% f
3
6 }- o4 `3 b0 X' W4) t, F5 k3 W- Z _8 X- O' q
5 i& J3 k0 ]( F2 y
64 k1 M1 m' i" h( W$ M* k1 D, T: C
7) Q* Z( j' U, s8 x9 M1 I n
8
& d( T! ~, F, c0 B9& R, ]' a# z- q* N* G
10( w$ R5 ~" ?9 \6 a* V
11
: x; j4 A* Q6 ?9 M8 e12' C) Y* q# y& E0 G# ^. z' \ [
13' f7 c" h/ p$ ]3 u. ]6 I
14
* g% W0 @1 Q: p9 c15& e- O1 _( D4 d6 D7 m7 C; h/ `, R
16: V& b; [* ?: y! [; ?1 \
179 {0 v9 v% I* E; [9 g. ^
18
: a/ y; h0 G7 V/ j2 }; n1 h( O19
$ n4 o- w6 ^8 J' Q2 u8 p20
$ L4 K% A* K% b& v' w: X21: i6 R- k& R) A
224 n" C; t" j" y/ v# e9 N
23
$ y2 M' M2 E6 }+ A/ r; v' p24: l/ g; ~: ?! Y( e: Y
25 y8 t# R/ q4 R1 Y+ n( @4 S: w9 I: v$ z
267 p1 y, v1 ?7 K' t' e" I8 N
274 P! ]" J# L1 i3 f% \1 S
282 y. Q) B* b' D, m; u) }& l
29
( Y7 y! X. t2 }& D: e. ]30& q* G* a {6 {
319 r9 r- n! N0 c, v( u" ~8 h$ m$ Q
32" h+ T+ k) W0 y9 N
33
- I- S) ?/ S. @! A34
1 ~( X, E# w& l5 O) o$ G35
$ J3 c k. O1 e368 }/ X, c2 l- x* u% ?
37
1 U5 _/ B" t1 _5 w. }38! n) u1 Y( c1 M9 |) |% R
39
( ~. z, Q2 v+ U- }# Y0 i406 Y. `( q( M4 K5 M, `% Z$ U
41
5 ?& [' g* z' M* t3 [422 @% T0 z5 j2 l) u: z
43! J5 l v4 J/ f
44
- {; d) J4 J5 p1 s0 r45
2 l% v' Z$ {2 r* K46+ {; |3 Q6 L0 H% x1 X; d
47
6 u0 d2 x8 `& F5 s481 l0 d( y4 |. V% D
49
" w8 c& W- K- }50% ^* B4 M9 }% k* Z6 O* n1 ^4 p
51
0 N4 u; _' W* ]- T7 t521 ?0 R, H# Y8 p; {: Q- k0 V9 X
53
: A- I4 ?8 Y+ W9 V+ f! q$ y54
' n: ^' z7 R8 s8 G& S" x558 E4 J5 i$ b* L5 w( I' G
56
7 @1 C7 E {3 }* i, [573 R5 @' [9 T2 _( O2 N$ v' U& T/ D
58
/ ]% @- C# u& O; i59
( t6 g( T0 n: C4 _+ f6 `+ o60
' l% A; O6 V6 U. ]61
0 P( u- K+ @2 B+ K4 V3 o6 p62
) i. r; c; I2 y- \6 q% |" @63) j) d7 Z) t3 y: o9 P
64
/ F/ B2 ]5 v+ f% {; x2 B65
3 D6 I8 m4 |8 ^3 g( u66# K# l/ q# s) C, k; E
67
1 u j7 O" Z' s. J% c5 w68# X2 R! o7 o% h+ W8 n
69( K" \+ }7 s$ |3 @7 m0 j f/ n; C
70% a. ?$ b+ x: |' k
71& i1 p( F; [9 _8 i# x
72
! v$ i1 ]) \" g [7 P6 R7 y/ c73
/ ]$ A( P8 g1 {2 w3 V A74( W9 f# }- Y2 q0 E2 a
75
- p( ], Y5 U- k0 y+ ]* S76
) @* t2 t# i# v3 u3 O% o U( m77
2 N0 [+ y2 m4 k6 w* x: R6 @' w ~# f785 r: g( ~+ g8 ~% D
79: Q4 @ y. [/ l
80
/ ]3 U8 e/ B' L7 r3 ]/ ?81
3 P0 d8 r' _+ ?' j5 {# Y82" b/ ?, p% B, H7 W4 X D
83. }, Y- Q5 Y C5 W, x. i; }2 {
84
/ I5 z! H( ]1 o! H z; T" L85
8 t8 F( p9 f$ _6 w& f- o% Q _86
5 S: I& P* G. i87" F; @8 k! X4 w: R) H7 l) j( S5 h
88
3 x4 d: @: d# {# k1 T; \89
/ n% k( G; O1 @90
: U6 g: |# [, L* l& n0 S& T915 l6 P3 r2 `, c' a
922 y* E) x9 o! g/ ]! v: v" g \0 y& ?. g, N5 S
935 w' \+ X1 H8 z0 k1 r5 P8 v1 C
940 B$ \: l( m8 k% K
95
0 _" p4 U* L9 d% p96$ |/ h* e2 M; G7 c- _
97; k4 K) e3 y1 F9 j: A
98- a2 V( G, g* ?( c
99
/ [/ }9 @# Z; ^# s; F5 Y- r100
9 Y# `! d# I, B) Q101
0 }3 s% I b8 m" O3 k102
; K( m3 {# _8 c( q5 u7 x103; a) G+ v, e O, d; t* n( z
104
, Q2 m% z6 M& O$ F1055 _4 N& z. z5 v2 M
106
+ s3 g. U1 Q6 J- G1 n5 l2 y107
5 p( L Q$ K' d- e8 I1 A2 \108' t; {0 j$ ^7 Z. \9 K, X& W
109
* T( A' x0 L+ J& N& k7 [110# R. P2 O" F, J9 e- E; c: ]
111
& a" y! ?0 I6 \% W: V% I* u2 X& ~: r1126 M' g/ ~9 S# O9 F
8.2 Start training the model

I only train for 5 epochs here (training takes a really long time); feel free to raise the number of epochs when experimenting yourself.

# if it is too slow, lower num_epochs; with more time, something like 50 epochs would likely work better
# watch whether the training loss falls and the accuracy rises, and how far validation lags behind training; a large gap means overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
8.3 Train all layers

# unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# keep training, now with a smaller learning rate
# note: params_to_update still only contains the fc parameters, and the scheduler below still points at
# optimizer_ft, so as written this step mostly keeps refining the classifier head; pass model_ft.parameters()
# and the new optimizer instead if you want a true full fine-tune
optimizer = optim.Adam(params_to_update, lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# loss function
criterion = nn.NLLLoss()
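When unfreezing everything, a common refinement is to give the pretrained backbone a smaller learning rate than the freshly initialized head. A minimal sketch using optimizer parameter groups (an optional variant, not part of the original post; the variable names here are illustrative):

backbone_params = [p for name, p in model_ft.named_parameters() if not name.startswith('fc.')]
head_params = list(model_ft.fc.parameters())
optimizer_diff_lr = optim.Adam([
    {'params': backbone_params, 'lr': 1e-5},   # gentle updates for the pretrained backbone
    {'params': head_params, 'lr': 1e-4},       # larger steps for the new classifier head
])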
1) }2 n5 `$ U6 X
2( ^6 V _ _+ j1 ~! O7 C
3; [% l1 @5 R4 h3 }% ~6 r* U3 g- B
4
V* m# _1 h- T- E5 a6 `# X9 ~5; I5 l) t6 O) K/ `
66 R2 S3 i# Q: @7 h- Z2 h s
7, w3 @( g9 \: w1 r
8* |, t+ j' j- J
9
/ O& F' f" m- O' Y" D: _, |10
# load the saved parameters
# and continue training from the existing model
# filename points to the checkpoint with the best result from the previous run
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
Start training

Note: training becomes noticeably slower here; my graphics card is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
9. Load the trained model

This amounts to a simple forward pass (inference); the parameters are not updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move to the GPU (or CPU)

# name of the saved checkpoint file
filename = 'checkpoint.pth'

# load the checkpoint
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # read a test image
    img = Image.open(image_path)
    # Resize: thumbnail() can only shrink an image while keeping its aspect ratio, so check the orientation first
    # unlike resize(), whose size argument sets the new dimensions directly, thumbnail() scales proportionally,
    # modifies the image in place, and returns None
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # crop the image down to 224 x 224
    left_margin = (img.width - 224) / 2    # take the central region
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224       # add the crop size 224 to get the opposite edge
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # same preprocessing as the training pipeline:
    # scale to [0, 1], then normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # move the color channel to the front: (h, w, c) -> (c, h, w)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display a preprocessed image."""
    if ax is None:
        fig, ax = plt.subplots()

    # put the color channel back at the end
    image = np.array(image).transpose((1, 2, 0))

    # undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # the same function can be reused to preprocess any image
imshow(img)
<AxesSubplot:>
The above is the preprocessing applied to a test image. We check its shape to confirm the preprocessing function worked correctly.

img.shape

(3, 224, 224)

This confirms that the channel dimension was moved to the front and the spatial size is the expected 224 x 224.
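As a side note (a sketch of an alternative, not in the original post), the manual steps in process_image roughly reproduce what the 'valid' transform pipeline already does, so a single test image could also be prepared like this:

img_tensor = data_transforms['valid'](Image.open(image_path))  # apply the validation preprocessing directly
print(img_tensor.shape)  # torch.Size([3, 224, 224])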
10. Inference

# grab one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # dataiter.next() in older PyTorch versions

model_ft.eval()  # evaluation mode; wrapping the forward pass in torch.no_grad() would also save memory

if train_on_gpu:
    # one forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# the batch holds 8 samples, each with 102 output values (log-probabilities, one per class)
output.shape
torch.Size([8, 102])
1 W# D: l; A! M' U( |1 `9.1 计算得到最大概率
- }) a w* _* E1 X6 I1 M" u2 R3 B! N L_, preds_tensor = torch.max(output, 1)
2 I. M; S1 s5 C- v+ k. k+ N2 E
5 |& j& B" u* i! ]; Q: Apreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量
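Since the model outputs log-probabilities, exponentiating them recovers actual probabilities, and torch.topk gives the most likely classes per image (a small extension, not in the original post):

probs = torch.exp(output)              # log-probabilities -> probabilities
top_p, top_idx = probs.topk(5, dim=1)  # top-5 probability and class index for every image in the batch
print(top_p[0])
print([cat_to_name[class_names[int(i)]] for i in top_idx[0]])  # names of the 5 most likely flowers for the first image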
10.2 Show the predictions

fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # map the predicted/true class indices back to folder names via class_names before looking up the flower name
    pred_name = cat_to_name[class_names[preds[idx]]]
    true_name = cat_to_name[class_names[labels[idx].item()]]
    ax.set_title("{} ({})".format(pred_name, true_name),
                 color=("green" if pred_name == true_name else "red"))
plt.show()
# green titles mean the prediction is correct, red means it is wrong
————————————————
Copyright notice: this is an original article by CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940