- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 563328 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174221
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例' S4 S0 ?2 x h) M8 {
/ a: T6 y+ F' l) z! T1 l& E' T( |( L' b
文章目录
: ]3 W* J# P0 y卷积网络实战 对花进行分类
, X/ ~- a, O- W% D数据预处理部分
% d7 r% E" h& T: H* c- |网络模块设置# i9 a( U+ o9 J/ F6 c
网络模型的保存与测试
# X$ c; a3 S% Q |% T9 D/ U4 K数据下载:, I5 v0 F( ?% a$ M
1. 导入工具包, E- C; k- ~2 l* Z' G( I, Q1 J
2. 数据预处理与操作
4 B# {; A+ E( E! Z0 Z3. 制作好数据源( @/ f! C: P: R: \8 v& N2 C1 @7 B
读取标签对应的实际名字
. l( N% B7 e" v2 N! P4.展示一下数据
( `$ ^% u$ M# k6 d5. 加载models提供的模型,并直接用训练好的权重做初始化参数5 {3 ~9 l, R: _0 g% B; D
6.初始化模型架构
' Z8 V. h9 J" j& s* S, ^7. 设置需要训练的参数% C. [+ Y9 }7 R( ^/ ~1 w4 K q
7. 训练与预测
# C+ P% a! ?: M) B8 n" \7.1 优化器设置
+ U+ Z9 z0 T: H* G7 p7.2 开始训练模型( p# m; N* @: s
7.3 训练所有层
! i& g5 l6 ]/ ]( k5 |) b8 A" z开始训练
# B9 g& a5 _0 h8. 加载已经训练的模型
: t- c% o9 `7 C* [! w7 G+ g0 q6 x1 V9. 推理
3 q) c9 |+ ?, E: N9 p, x9.1 计算得到最大概率
. D9 ]: u: T1 a' R% H. ]9.2 展示预测结果0 B) F+ C6 C- P0 A6 a
写在最后0 `! D" R6 h( U& U% B$ z
卷积网络实战 对花进行分类
! R5 s# [! q: g" D' d本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
* K t/ b* W. |
9 z6 ]9 T$ G7 `; c& X% ]# I在文件夹中有102种花,我们主要要对这些花进行分类任务
5 r# _/ ~ |9 D$ c/ r文件夹结构: L6 f$ E. q& B# }8 Y
) N5 ~5 K, v4 V! M/ q6 f
flower_data
: h5 x. O4 Q8 d* e8 [
" {" K) [6 \% j/ k9 o' b7 I, Gtrain' M1 Z8 R7 n3 J# A) C
/ H% Y# o- d" ~4 c! C1(类别)8 O, i ]: \4 |) S# t+ Y
2/ A: x% c/ g) G2 [
xxx.png / xxx.jpg
( S3 _* \; X2 q. i) |valid9 m3 m+ e. u- `" y
! a; p k9 b5 P( p主要分为以下几个大模块
% m! {% P3 R! X! o& j
3 t: j2 y/ m7 S数据预处理部分
' c! F$ |2 s) P( n0 }数据增强
" _; y, [- K- t4 U- \) j数据预处理
( ] }$ O# S6 [4 u0 r. e G( ^网络模块设置( P9 b4 R4 f! D- Z. a( U
加载预训练模型,直接调用torchVision的经典网络架构* B1 v6 k0 D( U6 \
因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
+ X4 s: B; _+ k; ^, d9 M网络模型的保存与测试+ j/ Y; r+ P P" e0 N! r
模型保存可以带有选择性
$ l! G7 H0 H, E `数据下载:
$ l! ?& T) H. rhttps://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
6 M8 ~. r+ W5 N. @9 v8 [8 J3 V& g3 `, n* |7 d8 V' \
改一下文件名,然后将它放到同一根目录就可以了
& }' S$ A( _ ]5 `/ B+ ^) y4 |+ G( w4 H! ~% M& J( ` H
下面是我的数据根目录1 v1 V6 K# `. Z6 L: `4 \
( u; B! {/ D, E8 K+ S
6 ~4 M* @+ d% y2 s: l% h1. 导入工具包
! O2 \" P5 c! Z$ A7 U7 \import os ^* q# {" v i) g3 Q) L
import matplotlib.pyplot as plt, b3 J9 @) V$ O
# 内嵌入绘图简去show的句柄& J/ J0 `2 r) D- p8 U) }1 ?% c7 I
%matplotlib inline
g) v! L1 a* W- j. O1 Qimport numpy as np
- N8 Y9 S1 r& R+ n' X; `8 g6 s Q2 nimport torch0 T, Q& A) n5 B* m/ L: W4 P* R. n9 j& R5 @
from torch import nn8 E1 ~0 J: g: R b) ]9 W
5 y" m) W) P# L( O4 Vimport torch.optim as optim) A% B" E: l$ _' D
import torchvision! x( F' l& I4 s/ E2 A2 V
from torchvision import transforms, models, datasets9 D2 ^+ K3 U5 s! q
/ A1 q6 X& d' |0 M. R; C$ h$ S( A) G
import imageio
5 Z2 \. a! `" w- l' ?3 l# H% d5 y8 cimport time
( }. I1 \' t% O, {* }import warnings
5 r$ D0 K( B. `3 oimport random
! V# O9 E) v2 z. cimport sys: L" Z$ g/ s9 E2 b9 z' h6 c
import copy
) O' \ Z+ F9 A+ Jimport json( ^! v3 B- h1 E3 r9 a
from PIL import Image
5 [ g' o, f4 _# _1 n. v9 z3 r# c Q: X
3 X3 V% B M9 A/ W% D8 E6 [1, ^. c" K6 n% a% ~0 [
25 r+ M7 l' x, X* C% L9 d- p
36 a" N' h1 e$ K1 Y. J4 X
4' K9 a! t* S4 k: h5 U* a4 [1 X) A
5( }) v$ @/ E5 H/ u+ I% b+ s) o* u
67 [! ?, h) m$ z& A% f" K+ i& c
7% f' U5 b! p3 d9 c. h* X: U
8
# v! H1 S4 z! Z2 J9& h; ^7 p- H& P- ^' j0 g
10
0 a* _. i0 W) j- f11/ X! _# K1 u7 I7 y& n
12
5 |# Z0 @5 }7 ] t13
, }4 [& J! Y5 ]. P4 G14
/ W2 g, c* X) U. e9 e2 q$ N150 K/ j! a. Y/ \1 E" C* k' r" [
16' H8 L" l; c4 t2 t
17
# V+ W! a2 W6 Y; Z8 Y0 p8 e18
8 ]4 m; y% {4 D6 C8 a& G19. v) ?& J7 _6 J
20 V- V5 a/ E* e9 N6 w0 J
21
9 ^- C9 U8 r7 J* y4 j7 N b2. 数据预处理与操作
/ H* Y! u& y' }. W+ r#路径设置4 I9 I R# v! j! Q
data_dir = './flower_data/' # 当前文件夹下的flowerdata目录* @ P* h/ e8 M9 z# ^
train_dir = data_dir + '/train'
1 N. u9 N+ a- E1 x- Pvalid_dir = data_dir + '/valid'
) `. B' U6 Y% {. U' Y1
) J2 q" a. y2 g2
: j r- Q1 x: l' l: _0 Z3
, g+ P+ Q& t) P4& n) _/ [7 I+ u) a# A
python目录点杠的组合与区别0 S& ^1 `6 K. |* C" b" g
注: 里面注明了点杠和斜杠的操作5 ^- C' Q* N. @4 Q
! B0 r1 o9 `# O! Z5 ~
3. 制作好数据源
& f/ \1 X# L, t3 `! [/ h$ adata_transforms中制定了所有图像预处理的操作- b7 ~5 c5 I, E1 X0 q! ?
ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
O3 y" n0 y6 P: U6 z0 cdata_transforms = {
4 T D9 k) p* z; {( Q' c # 分成两部分,一部分是训练
9 ]: l* I3 w, B* _: A( z 'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间3 D* v- z! j6 h5 r2 r
transforms.CenterCrop(224), # 从中心处开始裁剪
4 X3 C& h4 s' E- f; f6 ~* w. j # 以某个随机的概率决定是否翻转 55开1 v! y- J' Y" Y W% y
transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转+ B* c3 n9 Q+ u8 }9 G! {- l
transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转
7 w+ P# p/ n" H0 L0 A8 B; g # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相/ L( w9 p: {4 V- m$ z
transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
6 X$ T' m# t, e8 ? transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
1 F1 a$ |* d _! h; a # 灰度图转换以后也是三个通道,但是只是RGB是一样的
7 m; j, V( g; E" D transforms.ToTensor(),% Y$ n0 E- p6 f3 J& n
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差; g1 C9 G( O# Z [ K
]),
: f& A6 @' o+ `/ V8 o# S9 A! U& e, | # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化3 |5 x' c( `9 D2 z" d0 C: {
'valid': transforms.Compose([transforms.Resize(256),
) _. S) d8 T8 G! m( ?6 A transforms.CenterCrop(224),2 p: V# Q( B# G, G* a0 Z
transforms.ToTensor(),1 y# Y0 m, n1 }
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
8 U7 k) q8 c _" ^0 ? ]),
3 G% d3 [5 Z2 D# g}
# {7 r) w9 c7 X* L3 e
t4 V' I. q2 S1% q7 C% k5 p4 x% N
2
( W1 [" r6 K) ~1 r3
9 A& z' r( r. f- t4
( ]; b9 T# q, N5 F% k7 Q$ ^4 Z' ]5
' Q, j* t4 i' Z, c/ V6% X% A& U, R! t; e
7
) K) f, s2 d0 y" N. K8
, f. d8 f! r: q8 Z90 H1 h# u3 \/ J- u/ \9 A; O# e
10% q9 W( K& K) A8 T4 Y$ ?
113 y* r; }8 N. n' y0 u5 g) I7 O
129 @0 c2 U3 n+ Z V2 X& c4 l4 t1 z
130 }$ y4 }3 I! j0 Y' z8 C# A$ P
14# E; s" v+ L, Y" k# N+ e
15' c1 K0 }+ I" T4 N# R% k
16# o0 k4 [5 F* A" l: j6 b
175 R* w( x# h# ]4 ~/ L9 Z
18
( Y! a& ]* j4 b. w, G19
0 {) k# [; e+ h6 u" |+ ?) P3 `( D: d20( g ~* q# r( T7 r* Z) ^) ]
21
7 I" d. k6 Y: O& b0 g! Dbatch_size = 86 d3 _3 s9 S: {, V' w+ `+ [
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}3 i. q. P. ^/ e1 `3 Z! D
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
5 S7 s& g* ^1 P5 m0 ?) [. Ydataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
+ T( y. `/ P6 {7 F! j# Nclass_names = image_datasets['train'].classes) V0 ?# _% h! P
( _, e1 f% z3 `( w0 z3 F/ c$ b
#查看数据集合- O2 R7 o1 s. i( o; v& M9 j
image_datasets
# c8 l' v* s# ?
& L( P0 o3 ]- z7 s6 @" J1( W, J' Q9 {" w0 u& N/ z8 o9 H g
2
5 c# b, j/ K" Q* k+ h3
8 H& K& P; Q' ^1 T1 [( {% o4 U, j4
7 y3 `- F* N6 B2 S, g, m# U53 w4 q ~' w; z3 }8 A( {8 a0 j
6
4 k6 f/ n" t, A$ f! u9 O78 I2 M# n( e/ n5 J+ A& w2 }, _! v
8
+ ]& D+ Z a1 w! f9
: z" d# }/ ^. k) T) c{'train': Dataset ImageFolder: K) f4 J$ A1 z
Number of datapoints: 65526 v# A/ e- E1 q5 g
Root location: ./flower_data/train
; w, e2 `+ ^; v, i9 F StandardTransform
! d! W ?( ^8 s; U+ m: i Transform: Compose(! I7 F2 n+ \' x
RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)4 w6 Q6 c4 y' x3 m7 c$ K% T' Z
CenterCrop(size=(224, 224))
8 h" D% c( V! ~* E( Y2 B RandomHorizontalFlip(p=0.5)
Q6 A; O: ]" V" m+ p6 R RandomVerticalFlip(p=0.5) ?+ s7 ?+ a; \" S+ U" }
ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])# d* ]% H7 H/ k) Z" ^$ d! _. n
RandomGrayscale(p=0.025)
2 k# m4 f# o* i7 [ ToTensor()
3 h0 v9 ?% B! w L Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])/ b- d' V$ S* D3 Q) ?" K+ j* d2 g
),& {, R( ?" ^5 ?$ P- u$ {* b
'valid': Dataset ImageFolder% s5 L& G: r3 T* P: r4 r* Q
Number of datapoints: 818
/ }' w. M6 A* i- n3 h/ o Root location: ./flower_data/valid
3 \2 ?6 o5 w5 K2 S" {4 h4 _3 v3 U StandardTransform
& n" s, Q2 [4 H* b& P- d1 x5 D Transform: Compose(
' K% ~7 L9 F) i, Y) | W) C Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
" [2 T% n, K: U2 v# c5 D% J$ C+ ] CenterCrop(size=(224, 224)) V- l5 F" u$ ?! y7 u: K0 [! D; e5 ?- H
ToTensor()' d9 J. y& O7 w# k9 }. \
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])8 p8 M& y, ^/ o- F
)}
% @- c7 x I" V5 U) M* q# E+ |) t& e0 e
1
" l# ^4 p( t( L7 P" K# f( G9 T8 q, m0 u2
' e+ O2 C$ t D3 A. r, V36 W1 B" i' B" r- z5 p' K$ Z% ^
47 N9 Q- x) S: [% K; ]
5: @: _3 `0 K' n) a" L7 a. W
6* |# ^6 J- a3 p& s' l! C1 j6 K7 {
73 }- V6 [' g3 T5 X( I
8
+ G' \: s' F f6 `9
/ x7 Z% G6 O8 u2 M- m9 |& T1 F/ D10
0 A4 x3 Y, U/ r7 }" }% V11
5 z9 Q7 \9 y M$ _122 s& S# O5 m3 L5 C6 g4 `, y
13
9 R- C: y+ X# R14
7 I8 o: G# w: F; N6 V r& Q. ?15/ D6 v/ x7 t: h4 i1 o# r( S4 k
16
/ j8 W) j1 z4 a$ o# y$ r1 F17
/ ]/ I1 l" w& ^. ]+ l8 [18$ p3 B7 ?# p4 g
197 P5 w1 y% g+ j7 z$ Q) X! c8 Z
20- m7 P6 F, p. B) |% V7 a# Q7 Y
21
. u7 e7 O1 j' G' c( X* o22. s4 e# O- ?1 E# F7 T
23- h D& u) |1 H. t
24
& v8 W+ q V4 p# 验证一下数据是否已经被处理完毕
5 a1 _5 k% }' [8 J% ]dataloaders3 P P1 t, u+ _, |1 d
1
" |5 ~9 n% D( i. H, e5 j2
, M/ K8 |; Y' e; y: N. z6 }& C( t{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,& D. i5 a+ b* B1 L
'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
- B/ x B* S) A- ]# u; `/ p( X2 k1; J/ V5 ^5 E/ N* v3 J
2
; g) d& J4 _# p Bdataset_sizes
- A6 ~" ]7 b& v' }6 `1
; ^) \- n8 Q5 a{'train': 6552, 'valid': 818}3 J: c! r1 \* ^& P
1
/ A! z' x L% K读取标签对应的实际名字
: X, P2 t S ?. K使用同一目录下的json文件,反向映射出花对应的名字1 a6 j2 ]. r; U% ~; F2 h
( A4 t& B8 N5 {/ q: a
with open('./flower_data/cat_to_name.json', 'r') as f:" j6 Y5 i6 k# @: L( g5 O' m
cat_to_name = json.load(f)
9 _9 _3 v- J- s3 ?: `" f, V1
4 |* u1 E/ l7 ]/ h- f2
# ~4 z. G2 }: w _* Scat_to_name
) J2 T5 ~/ D( {1& Q& w4 @2 o, P( r
{'21': 'fire lily',
. Z, O8 L0 H- ^# S7 ?4 r" S9 e6 f '3': 'canterbury bells',
$ i2 s& e H) D" {' w8 k8 |+ { '45': 'bolero deep blue',
( C4 E( C, x& E- E9 ]# E y '1': 'pink primrose',. z* I% @9 |. C2 F& R" O
'34': 'mexican aster',' B7 p# V' [7 i0 m* k- d/ D
'27': 'prince of wales feathers',
. Q% c* D) r# {$ T '7': 'moon orchid',8 S9 v4 [+ t) R" l/ A
'16': 'globe-flower',
2 V, E" P* Q) t7 I '25': 'grape hyacinth',
% v+ _' B7 I8 A7 \% N4 M: O '26': 'corn poppy',
8 Z5 E2 Z+ Y9 L$ Y0 K( ^( v '79': 'toad lily',4 E* `2 G T. O1 U% _
'39': 'siam tulip',5 U# K+ I& N6 n4 I1 U% U8 o
'24': 'red ginger',
3 d& _4 A4 I9 N '67': 'spring crocus',. [2 O4 S9 \# S3 s; Q0 S9 M( m
'35': 'alpine sea holly',1 K# Y% z2 _0 h) x2 s& T$ ^
'32': 'garden phlox',
# a r. d& N; v- K" C2 l '10': 'globe thistle',
( v u( s$ J2 e( n '6': 'tiger lily',1 ^3 U6 f' A' F4 @
'93': 'ball moss',9 i9 p8 X9 s. s1 o4 C4 x3 s, {
'33': 'love in the mist',
! l9 r9 `# ]0 j6 s& s" W' b '9': 'monkshood',1 u; c0 V' C# n) g" N8 ]- v
'102': 'blackberry lily',4 H( F& k2 E% G& ^( X+ p
'14': 'spear thistle',! B+ b& O7 P) Z$ B- E/ \7 H
'19': 'balloon flower',
1 b3 U+ Z' Y; J7 R" b4 Q '100': 'blanket flower',
- f9 A7 \: C( Y+ I6 r+ r S1 S! R '13': 'king protea',6 S0 ?# w, N7 I' R( P5 `
'49': 'oxeye daisy',0 T* r% ~5 K! w; h
'15': 'yellow iris',
0 T4 B! M0 ?$ Z) K1 m: e, a '61': 'cautleya spicata',$ D* x7 s/ z! U- F. s
'31': 'carnation',4 f. @( \. J8 {6 [' @, a* c
'64': 'silverbush',
: W2 l6 d# a( i+ x8 S '68': 'bearded iris',7 G! y& n! k5 Y6 l: Q/ y, _
'63': 'black-eyed susan',
7 Q" U3 Q+ u. |- G8 K. F '69': 'windflower',
7 Q- i) R! [- {+ ^" U0 @ '62': 'japanese anemone',& G0 z3 w8 m* u" A4 u6 k
'20': 'giant white arum lily', \/ A6 h, R8 b/ l% f& v* w
'38': 'great masterwort',1 @& d) D: R" [4 O; h- ]9 ?
'4': 'sweet pea',
; k; x Y" M( N2 a- l# Z '86': 'tree mallow',% e$ u* [( {9 [0 Q7 _9 b, B
'101': 'trumpet creeper',
/ u) n0 z. B) s: k '42': 'daffodil',9 A9 [' a. f# }' x/ L$ W0 T( O
'22': 'pincushion flower',: o/ X; D* W7 z- g4 o
'2': 'hard-leaved pocket orchid',, G* l6 \: l( h; x! L
'54': 'sunflower',7 o7 N) j: s" Y7 i8 C
'66': 'osteospermum',6 K H1 n, ?& T8 D: z1 v; { y
'70': 'tree poppy',( y& c2 E4 h% w# @- ]! u5 @' h
'85': 'desert-rose',
/ e2 A3 N3 a& j5 [ '99': 'bromelia',
- ]5 _: r9 b5 e# U- X2 V! r '87': 'magnolia',: |- D! B6 O( A" w' z. Z4 H. Z* t
'5': 'english marigold',
/ S* O% u' I% O0 F2 }& G '92': 'bee balm',
( y3 O& g! B. S6 s/ f '28': 'stemless gentian',8 e7 x. C \% w! j& O' y* y7 i
'97': 'mallow',, g7 O! o* n/ W s8 X2 m
'57': 'gaura',
( y) M& k d5 o, N& x# `+ P- D" h '40': 'lenten rose',% H" F% U) {+ e* M
'47': 'marigold',; Z# j3 m1 A: z( W& _3 H( |
'59': 'orange dahlia',9 n6 X. z, m( | e' |1 h
'48': 'buttercup',4 O* v U4 D' ?" @+ P4 P) N3 L
'55': 'pelargonium',
7 q3 m6 ?1 q6 g. { k1 b! n% w '36': 'ruby-lipped cattleya',) y+ ~5 Z) }- n, a# @
'91': 'hippeastrum',; e) P, D' Q; A$ Z* k) N8 t: E
'29': 'artichoke',6 s8 g0 a7 y/ A1 i y5 s% t' ?2 ?
'71': 'gazania',
' _. v) j) W; e4 L7 S& L" l4 l '90': 'canna lily',
* o9 G" G1 N3 f z% F) u- v8 | '18': 'peruvian lily',$ n, Y: g m* j3 k9 [6 I7 J
'98': 'mexican petunia',
6 V {7 O( ^; _" A& @ '8': 'bird of paradise',
$ K- @' |2 j# k% B* D1 c% U( q3 m '30': 'sweet william',3 p$ w2 B% n0 x9 I* z5 y
'17': 'purple coneflower',
0 u5 c# `1 b6 i '52': 'wild pansy',
8 i7 j. t3 o5 @1 n$ y3 \ '84': 'columbine',
# x, B0 [) g' J '12': "colt's foot",
- D" p' E) n5 @4 Q '11': 'snapdragon',
/ ~3 p2 x! r6 J7 k. M2 p '96': 'camellia',
, U# m/ J" H3 E7 h2 L: j) g '23': 'fritillary',% |; h; Q( _0 n- N+ q1 f0 z) |
'50': 'common dandelion',
( P/ k7 k& S7 I [ '44': 'poinsettia',7 x- |9 K' q2 z0 u3 c
'53': 'primula',- U3 \( K% q0 t( F" g, [
'72': 'azalea',
7 T1 y& s* F; s8 ~1 o, r '65': 'californian poppy',
/ r" b z) o8 U' J9 C$ F '80': 'anthurium',$ [, T/ I9 S3 {6 e- p, y+ B, N+ c2 t
'76': 'morning glory',
8 I& [9 I" H' ]% p2 S- q9 C' ~ '37': 'cape flower',
q5 t( ]# N1 M, @' k; z; {3 U '56': 'bishop of llandaff',
" X) u- S. @9 X5 S6 c' s9 \+ t7 y" b '60': 'pink-yellow dahlia',
. q+ g4 i0 _7 K' y' ]- ] '82': 'clematis',& P4 P, c @3 p& N/ x: K
'58': 'geranium',+ h% K1 P5 S& u5 n% C* P" B- K
'75': 'thorn apple',
0 H; U+ E3 H- w6 W7 y5 y) ]& d '41': 'barbeton daisy',
4 s* I$ g |! [6 z+ f& Y$ ^/ @ '95': 'bougainvillea',
' A1 W5 b' X- {4 Q* ~ '43': 'sword lily',( F& w* A7 v/ j# S6 T
'83': 'hibiscus',
; @& ~) D+ l# s/ `5 T7 v; I4 B '78': 'lotus lotus',$ i* I" E% V6 a x, Y4 F
'88': 'cyclamen',
) R0 i6 k) O7 k7 A+ \- N '94': 'foxglove',
+ v' y1 T% L7 j3 S8 ~8 p3 }, j '81': 'frangipani',2 x+ i" C- q, N( y3 H# G
'74': 'rose',4 L0 b+ }3 c+ J! @
'89': 'watercress',. T7 R- v; M- |0 [0 G# d/ @
'73': 'water lily',
' Y: C5 a2 r0 @, m9 c; }, o '46': 'wallflower',4 Q$ J% h+ N0 @( P# U. Z4 Q
'77': 'passion flower'," C8 [! k% P* p$ |0 S8 V2 y" B3 |
'51': 'petunia'}% _: G: f8 M0 ~6 l- M
! @1 r; }" s7 S) n" q7 [1$ O r1 V7 @: n
2
$ @% H9 r$ p/ D3 p) m3
& u) ?3 e! Z2 R' |4
9 ?3 ?/ a2 R; S1 s! ]* j% ?7 c5: O1 `5 V5 Q# h. F. q
6
7 c- O( u- }' ?/ l1 o1 i5 t7
, J6 `1 u0 N/ O8 M9 t8
# Y& a9 |8 Y( Y$ `: V i8 H9
5 [( B6 I% G: T10
4 O7 E* V6 B& Y8 U4 t110 e2 V: \6 p4 t1 K8 Z
12
% U* C2 E0 V2 k3 ~ Y5 m+ w13
4 l( \2 s5 P3 i( q% P14
?1 N) V8 C: c, w15
" U! _0 b) k8 z* q16; A& z& C1 u: Z
17+ [5 |8 T$ Z2 v0 [0 g, x+ G
18; B3 O/ n6 _4 q, q4 a/ y3 D
19
# g: G6 h* M* F/ r7 ~0 e20
+ G: r! O6 o) v3 B* k21
; U7 m4 E- X( ?22% m( t: d* s* |9 `# T, X
23
0 W) i M/ h/ Y24
" X8 |' G9 M9 c) |- Y4 }1 D25* Z5 P0 l! u, z# x# {! e4 v2 q: O
262 l" Q ~ @0 ~
27' H& L6 b& o. z
28% w: E* n/ C. _8 T
29& i/ {8 T* g& g+ G [
30. N) W6 |- O9 S, v, R, n
31
4 G _5 Q" M- w$ r: s7 W( S32, Z6 q- T: T. B; l N& B" }, }* I6 \
33. k( K0 v- U8 r9 v$ Y
34
2 B, W8 W( e4 v. p6 ^35
: q& B: M! ?8 n) y8 [+ c& u36
: {* B) \3 M0 C/ U37
0 k3 D3 Z2 |5 i2 P382 i/ y5 `- j3 R( S
39
6 w3 t, }% |& h( H40( n! D/ j2 o1 i; @* B+ e( t* W
41, B) ]- L' ~* y
42
3 r5 \! Q; D7 |) ]+ E, L430 c" p( T/ f6 t4 X! x7 Y% h- Z
44
: l8 h/ W* F- x: B7 f0 }45' q2 v' p7 ]4 U1 {
46/ w! t8 l6 G) C
47 J# ]9 |0 D, T* N/ K
48* k% X& d. J' h
49& J* t( k! L- T+ }! y0 ?
50$ b4 |6 l: Z- w. @5 l, O! O, W
517 L4 Q) G8 i; d9 o s
52* C! b% h) N; n: K& {( Q; R
53
2 E' O N" W& Z: l* T54
( }8 J. v' {/ r55
V9 |" m/ E$ y- Q! f" `56
2 p* u3 ^0 B& V- i573 X- t* i0 b( [' r! j
58
1 w! S7 I2 {8 i% O2 h59
: `4 P2 f! b! E x* l& m9 F- o60
: |( {: K7 F3 I! {+ z61' e8 b+ a, X0 a) s6 w
62
. Q. r$ k1 l3 R L635 k" G9 R4 D w- i0 h# Q$ C) e
64# m2 Y0 ~. u7 i' p! H/ p
65
0 [' A$ ]& M4 x66
" D3 Q; v& l" f; U$ {& Z67
7 u# T4 W6 B x1 a5 y1 L68
t j% b, `% I( o+ g# n# {1 s69
! A: D# P: T# J4 y( X70
5 m, X$ X C2 X( ^* [0 L9 W* E71
; {1 F" P) A4 H2 a72" i& j- F7 O1 h. C! e. O
736 u# I; v$ T0 q. [
74/ i& M9 \9 }: {, d- p
75
6 Z- k7 j! g6 `76
?3 X7 G5 @0 K% k" T; M77( O6 K9 ?" A- m: e6 ~
78
, @# |$ V. u2 K& V( N j79& s& ^3 ]) A7 q, \. P( j
80
5 ~9 O8 M' M3 Y8 B) B E81% {" a7 w& ?3 |/ Y3 }, \6 @
824 Z' T) }+ b$ b8 S0 p) f
83
. U* A/ ]4 R4 W2 _2 X+ y84' _9 k6 \% E7 _. [% z
85
) U" p+ F4 d- R6 G* ]0 e3 v# C! `86 z! M, K7 W5 s9 V5 O( a
87. t; L9 ]7 G( W2 m! s' Y) W. e6 M6 C
88
; r6 @; T, t: s. m" D* k89
, M: Y3 {3 t& ?4 {- @) H900 E9 |" o! u0 J! g. W
916 R5 x) c2 A$ I1 d! y4 C
924 h# F8 M/ S7 W+ S( g
93. t9 }" {1 _1 H9 U" T2 N, f
940 t* |3 x5 R% Q/ l
95
: C+ V3 o* W6 |- |% E3 h. K0 C964 a: i& [! j t9 M3 C
97
, }3 B1 h) n' Y# U987 d6 ?3 q5 d, B: O* W. b
99
5 U- o1 {# X# k0 s k. V100
3 ?" s1 h, q, W) [2 M k X" R) ?101
0 [, s/ n% H7 o# W* ^& j102
# w9 W1 N# N& H; |9 R' p4.展示一下数据4 ~9 N8 G) t z3 u R
def im_convert(tensor):; Z; t/ t6 m! ` G+ j! Q) R! N/ `7 S
"""数据展示"""
0 Z, A3 \1 l! Q- X' z image = tensor.to("cpu").clone().detach()
; \+ N+ P5 R8 J1 I image = image.numpy().squeeze()/ r, ?: d) l3 ? y
# 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
- y' w3 U7 ~7 X # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
* [; [( V V+ ]# u' C image = image.transpose(1, 2, 0); I" K: @0 _+ g# s
# 反正则化(反标准化)
& ?" ]$ v5 u( ]: `+ @' Z image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))& D' y: f, U; \/ I& `" w1 I
5 \* l3 K7 l G/ g2 n' G; n
# 将图像中小于0 的都换成0,大于的都变成1
/ D8 ]. n5 h4 u% W. @ image = image.clip(0, 1)/ ~8 p3 L: A3 v/ O' ]6 O4 w' _% F
4 C6 Y( F* m8 O" G4 _6 ^! g
return image
' Q! h& L+ f" x7 Q: a5 f) h1
' d0 n& W# q' Q! T2 x+ d* m4 b8 k- I8 V2
- _9 X! ? e7 l) h) v3
H4 r. j% S+ }& \8 y$ c7 Q40 ^; ^! O4 q, G6 i* J" r& [
5
7 p1 a- J. w; W1 U& T* m6
% T; S4 t) Z3 B% c3 Z, i7' R- @- V2 T @; c& z g! T# Z! S
8/ x4 E! P: F% _
9$ O* V/ ] \" y, P* Y e
104 u1 k9 f9 b& `
11" ` _2 p. m) L# w8 T& d C
12
* P1 I) d" a* k1 G13
- _( W' w% @$ P7 O8 ?$ s14
" Q0 `: [+ P$ m# 使用上面定义好的类进行画图
; Y1 U$ s& D* p! }) Efig = plt.figure(figsize = (20, 12))# }0 q+ K9 s! N5 r
columns = 49 n$ C9 o* W! l, e% W4 g9 b
rows = 2
6 u' k/ `4 ?* W/ ~; F, ^6 R, p! d0 @& T4 b+ {# l! @' Q/ b% X
# iter迭代器5 C. `' e; M3 {: E; R
# 随便找一个Batch数据进行展示: Q% o6 u# e$ L7 r
dataiter = iter(dataloaders['valid'])% P* b0 W/ j6 u8 n( L9 ~, w
inputs, classes = dataiter.next()
# h9 ?/ X+ T: H1 w! H0 K* c8 z U" f2 _7 g7 B$ e9 x; H9 I4 ~
for idx in range(columns * rows):
5 U# y _, p. w) v% K7 j ]" H ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
5 l: d$ _3 u/ y4 {8 L: k # 利用json文件将其对应花的类型打印在图片中
8 k% f, A6 ?" t. m ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
6 Y# e- o$ O' g; h0 M plt.imshow(im_convert(inputs[idx]))
* ^' B- C2 ?: M$ r8 Dplt.show()
& p$ C5 C" K. J& o0 A- V0 c) j. j5 @
% V* A1 c3 C4 c* e4 I8 c* B; S# D/ @' Q10 }1 r K3 h$ j* e( O
2
7 ~5 ?+ D: o0 J# u+ E) l7 I) V3 s3. J5 a1 |* L0 L" o& _
4
" t- k2 S8 ~7 z" x52 [6 H4 c- T v* i& b
6
. L/ d/ G0 U1 j3 @" y! Q6 Z7% v! _* W# e6 b: N* `5 ?; P
8+ W. V! I" ~% {; x) i' h- ` o
9
( B% c" U" b& C( J* K9 p( d109 {! w2 a5 N% ^8 ^, x% K& ~1 Y
110 r' C7 G) @$ C6 A: t* r6 j( M/ N% U
12
+ j- i2 Y, z! \: V- ~13# |) {' i' {5 w* Z
14
: H: i k- \) B [) b" D15
4 o5 O4 _# J" f/ E3 k4 x16
' P! k7 H4 o' N3 o ~) a7 g% d% f' f2 P/ B
* L" l h' {% R5. 加载models提供的模型,并直接用训练好的权重做初始化参数, B6 g0 e% G) q/ q7 j7 Z( |
model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
% m" p/ R# v5 A# n7 u# E* b# 主要的图像识别用resnet来做
! c4 d6 `$ ?/ j2 `# 是否用人家训练好的特征1 f6 N7 P, ~$ I, V. U5 _
feature_extract = True& A* \) T e l& y
1
- e4 r) q( @# Y6 G2. H2 Y) h2 _7 A9 a
3
$ p! K V% h2 Q& c6 }0 R9 \4
3 c* q! ^ Q5 h- I: A3 r# 是否用GPU进行训练
& I- m! w+ B+ L; T8 Ftrain_on_gpu = torch.cuda.is_available()# |: K! G3 z: O8 n
( T" L# L. H9 _6 rif not train_on_gpu:* q. H, w# ?% v* r4 H
print('CUDA is not available. Training on CPU ...')) `% {5 V$ E1 f0 U, x; e
else:
$ h: i$ @! b5 W+ F V- t/ R/ D) w2 } print('CUDA is available! Training on GPU ...')0 ?) |6 A' F! Z- N/ O
m. H; u3 `- h: M( @ r
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')8 r' P* w3 [1 @
1" C/ V" c& q" P% D
21 |, M$ _( I+ I0 _: d( X
3" d: J7 C: u2 W" \& u7 W- _$ i
4& N2 X/ \. W3 b9 {
5
( J' H% E7 O5 B' z; ]) q67 n ^" f: z" g' T9 e: C- f
75 W$ V: ^& E/ c9 {+ H
84 f3 O/ ]; p& o
9
" n: {1 x8 x6 i8 |& }# R+ e2 ?CUDA is not available. Training on CPU ...% g" \6 V/ Y8 P" R, X
1
) j+ s; @9 Y2 B* X# 将一些层定义为false,使其不自动更新7 ~5 |' \ T! }1 l8 g9 u4 @1 I
def set_parameter_requires_grad(model, feature_extracting):" f$ t2 b m# o U0 g( D
if feature_extracting:" m" F% j- _, l0 T
for param in model.parameters():# b/ M. C" h- Q: _
param.requires_grad = False0 B* g9 [" o% f$ N. y
1; L3 r/ q+ M" w2 k9 t
2
, l. |3 u; G0 T6 b$ V3- }. b, @* Y- M
4" K5 P' f8 S, u9 z" L9 g5 J+ O
5" F0 `, d3 q% x& T+ {) S) d$ C }
# 打印模型架构告知是怎么一步一步去完成的
' E, F! V5 \! B( V6 w: f( }+ I' }# 主要是为我们提取特征的
+ S8 E* F7 c( F" t7 F, c% U3 i; i+ ?0 G) S9 f) u8 n; E% G
model_ft = models.resnet152()
, E1 c7 D+ Z$ j- C. z, u" Q- lmodel_ft
* b0 f" A+ K' e& {1' `$ W( X- e7 I' P
2
9 F% V! W' X9 p2 m/ F3
' O1 N8 `- H7 v% ]; M4
, G3 g4 [; c3 K) X7 {" b; E# X$ w% q5
- w* `4 a, _# w; B3 p( M2 |ResNet(7 W9 O7 I" q& D, x0 V: s$ L
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
) t& `! ~4 A+ @ (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)/ }: ^; y% G* F7 I, Z+ b& V9 f
(relu): ReLU(inplace=True)
" O4 k' U$ z5 d" U (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)" d5 l- j3 K; q
(layer1): Sequential(, @# U* F3 m, h4 T5 I
(0): Bottleneck(' {* B; I2 x; U6 w
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)" u- R# T; ~8 E4 e. H0 j
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
0 ^9 w. D3 S4 S& K. F7 E (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)1 c# q% l8 O; H# Q
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
2 V) @9 q$ S( M- W; p( n6 @5 d0 o0 J8 x (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)5 k, ?7 ]+ `+ x, p5 v. {/ C, x- |* a
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)& u T N# d& d+ F
(relu): ReLU(inplace=True)$ q# ]1 j: N# n
(downsample): Sequential(
8 U! ?2 w9 f4 L# r# t (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)1 h6 \4 ?1 i2 U9 x1 a/ c
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
5 y" F" G1 V( ^$ s7 N )! d0 e4 b7 x; q% B( e
). ^/ w* F3 w# ^/ m5 Q B
中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。% ~- c) [. |0 m5 X' s) v6 J5 M2 A
(2): Bottleneck(
+ u- w" e% K( d7 i( j6 J1 i1 m! Y1 K (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False): o/ Q4 z/ \/ }# ]: h# M/ u F
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)9 s& U9 [* N* O' q) s+ ~1 m% N" G
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
- Q) n" f/ U9 C- Z' P1 D8 R8 }8 e (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
4 Z1 |* T5 f; z% E% d+ i (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)1 Q) b9 @: ]2 v. Z! B: [
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
- _' y) f3 M& ^$ r% N6 | (relu): ReLU(inplace=True)4 D% ]# g m! e+ x/ F' @
)9 `3 _' ^# F5 z' B, e/ b+ F
)% C: {4 e5 O. C3 {6 m& e
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
( G6 w9 L1 S9 g X9 P, I! i (fc): Linear(in_features=2048, out_features=1000, bias=True)
1 @/ a+ ~% x: q' Y4 v+ l4 I6 x)& Q+ K3 _: m9 H2 Z
9 V, ]# j4 T6 ]$ _
1
) Z' ?, F- X1 P, ^* {/ @2
0 B; R" K+ J7 N" j% S3/ q6 N2 c# `8 Q5 ^) t# B1 A3 @) k
4
' v; B5 a3 \: E' S7 N58 _# S/ Z) {7 `: C3 |! D- M
6* [# x; j5 i i( x/ V0 r
7
, d1 c1 k9 G- w7 |0 O8
- ~( j% A3 ~3 |1 Y94 B9 X" H4 T8 N7 O$ |" Y2 G7 S$ P# X
10& c; C/ U; q# ^( X( e) z5 m- \
116 [& e! d7 h% f: x# T! g( ^
129 |; ~8 E! V$ a7 D A3 M( ~
13
$ U' J/ q9 w/ X$ B14$ K& I1 @/ a: `! \" Y1 c1 w& C
150 U3 n1 T; U; }. N2 k; n
16% u1 s$ e. I% ]8 q2 z5 `- ^+ \; j
17
/ b q0 Z8 s! E; n( m4 l18
. G1 g; Q! h! b( e U. `19$ S W3 @& u+ ]7 y
208 v" L0 `7 ^4 _" I J/ e
21
9 V! ]) ?2 i2 y7 z& {22
' _* j# K4 s6 P, D. J0 F23
! h9 F$ J, v! Q& J24
& Y6 T: k. {5 f; M1 ]) K) B25& \1 W9 a. M3 n- y" e
26
: [( T1 `* P% O' I3 W# ^- i8 J27
& G1 {5 Y& k+ b; _0 N28, Z! d* D" m" d5 J8 t6 K o7 N$ O
29
7 ^ y# W6 Q T, R; g$ Q30( p. l! W6 V3 b4 b
315 t5 {/ [! {- C, B- ~
32
P9 M: {: m0 y" E! R33
6 i3 s# H7 p9 T9 \6 z最后是1000分类,2048输入,分为1000个分类2 }4 j7 }" I E
而我们需要将我们的任务进行调整,将1000分类改为102输出
- }+ E) n$ S; r- }. [
; ]* Q$ E i$ i6 g0 B% \ m6.初始化模型架构
3 U- w- N" R, m) O) S7 k: ~6 c步骤如下:8 D3 ~) B5 m& j& \6 V" i
" ^3 _* Y' p9 l0 H
将训练好的模型拿过来,并pre_train = True 得到他人的权重参数" s% m; ~( Z# q7 T l, u; y& _. o
可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
: y& S1 V! u( V' H无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数4 ?! K3 ~6 R( u# }2 X
官方文档链接
. Z: b- F' k+ U- {" thttps://pytorch.org/vision/stable/models.html
* L- _. {" N7 S6 ?; Z2 j0 G: L# P; C0 g5 O
# 将他人的模型加载进来7 G J3 k; w. i! _8 ^
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
# u8 Z- i1 Q3 {9 q# X2 ]8 E8 h/ } # 选择适合的模型,不同的模型初始化参数不同7 i7 Y! Z1 ~' ~6 ^$ A g2 ]
model_ft = None7 L: T. s. B6 F7 Z( G: W1 [! B
input_size = 06 K- k5 g) `% m0 h; @5 Y* r
7 J9 K& s$ V6 }4 l, o' z if model_name == "resnet":9 H; u* l+ A1 r+ l- F
"""
% [, r4 z) v( m6 I$ N5 X Resnet152
$ u7 S0 A1 u" m2 b5 a """
, i+ }& d8 b' X I3 g( a
% h% p$ ^9 `* t0 P0 G" a # 1. 加载与训练网络3 G1 P/ b1 H4 I3 [6 s
model_ft = models.resnet152(pretrained = use_pretrained)) ?3 ~/ |7 }) x8 L* p
# 2. 是否将提取特征的模块冻住,只训练FC层
% _: h' w. U) c" I set_parameter_requires_grad(model_ft, feature_extract)& Q1 U9 L$ ]: U% r
# 3. 获得全连接层输入特征
$ l* z/ a, w$ J5 ? A$ x num_frts = model_ft.fc.in_features
( o4 k- M) C1 H0 x! _ # 4. 重新加载全连接层,设置输出102
- A& q2 l, G7 b6 Q& |, A model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),- R/ Z, L7 W: d6 M
nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
9 a5 }4 B$ T# {& m input_size = 2245 E; Q6 T, @2 N X( i
9 K, P0 o; e7 H/ [2 v0 ^' L* f& C+ F elif model_name == "alexnet":
3 S- F' L6 S( O$ [ H7 v2 ~ """
7 m# S9 d2 _" l# D Alexnet* u- r% M' V2 f4 c+ e4 |4 {
"""
5 w X& X# r8 F7 H: v; \% Z" f model_ft = models.alexnet(pretrained = use_pretrained)6 f% k: E5 M+ H+ g4 ~4 K
set_parameter_requires_grad(model_ft, feature_extract)
4 X* b6 P& ^$ e2 E
1 z7 E& W! O6 z+ u # 将最后一个特征输出替换 序号为【6】的分类器. N. I5 H0 Y! k
num_frts = model_ft.classifier[6].in_features # 获得FC层输入0 D) @+ G: z' I1 O- D5 l
model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
! h/ x4 K' @8 L [4 w7 D input_size = 224
! @4 @" u4 |" ?1 d( J: x; a5 D) C2 ?
elif model_name == "vgg":" T. Q' g2 Q' U
"""$ q5 g( `, p2 h+ Q2 U3 J
VGG11_bn% \6 x _7 X6 q( k7 M: R
"""8 f$ ?) ^3 q- m6 ~2 `/ X
model_ft = models.vgg16(pretrained = use_pretrained): A# V' e9 g9 A ^3 O
set_parameter_requires_grad(model_ft, feature_extract)
7 n0 T6 J, |* S7 |! y# y$ p$ T: q num_frts = model_ft.classifier[6].in_features
5 C3 w& l R9 g# j# e: ^; K model_ft.classifier[6] = nn.Linear(num_frts, num_classes)& x5 `9 k1 Q, N' h" f
input_size = 224- G. w$ I0 J0 R
0 P4 q! `. ~- v* |4 Q elif model_name == "squeezenet":8 F) D: H7 h& T$ W
"""
5 g; H8 X& C4 z) z Squeezenet. m L0 w1 `# V$ w
"""
, k/ w2 v: Q$ C" i/ S1 m model_ft = models.squeezenet1_0(pretrained = use_pretrained)
# w; b, j, k" e8 v- W9 C g; X set_parameter_requires_grad(model_ft, feature_extract)( Q4 S& A4 y" i$ D
model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))* z) d" E3 K) S% u/ u0 _2 Z
model_ft.num_classes = num_classes3 }+ G+ x% l8 ~8 C! d, t
input_size = 2244 h* B% S; T* q, W( F- t" O# z
n0 }7 M p2 u0 X5 {* P
elif model_name == "densenet":+ `7 `# p6 a$ E a2 Y7 j: N. C
"""
9 R" Q% K& B" a8 q6 \ Densenet
0 X5 W) P9 Z# R' Z' E9 d$ ` """
- o: q& x7 c6 q( n% l/ n model_ft = models.desenet121(pretrained = use_pretrained)+ Y9 \* _5 f1 G3 n3 f U" U% n
set_parameter_requires_grad(model_ft, feature_extract): _+ d) H8 B2 Q7 i; }0 Y+ k7 p
num_frts = model_ft.classifier.in_features6 W& P N$ S+ o1 Z! n' v% h
model_ft.classifier = nn.Linear(num_frts, num_classes)7 O# C6 X) U# k! w. J+ M
input_size = 224
3 X, O i$ t0 n5 G. I
+ R( `7 E- ?, n3 G elif model_name == "inception":
0 L' C" `; K! x4 B5 ?$ g """
a# \: P; [0 d, h% y Inception V3, R/ O' v8 @2 q, V+ O
"""
- D) Y) z2 J" b2 e+ I' ^5 J7 e& I model_ft = models.inception_V(pretrained = use_pretrained)
2 l g! u8 x7 F4 r# K set_parameter_requires_grad(model_ft, feature_extract)
1 e# j/ a! [, c) `+ R: u& j) }. U+ w. _2 E" n7 C
num_frts = model_ft.AuxLogits.fc.in_features% ?: v' F! Q4 u7 y( N- N
model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)2 s7 D3 k& U: R) U' S5 q# q1 l/ `3 C
. M- n9 _2 M. M, y
num_frts = model_ft.fc.in_features: @ v, f6 z: Z8 J. G
model_ft.fc = nn.Linear(num_frts, num_classes)
8 d" X# E2 q8 m% D; a input_size = 299
0 O" `0 v1 Z5 t' ^4 o" L. ~( x
7 N. e" j q" V0 \ else:! ]' b' @9 L7 i
print("Invalid model name, exiting...")& Z, P" i/ y* E# F y- l6 C
exit()- o5 A0 c+ V) F" T$ R
+ s) Z& I; I0 |. w) L return model_ft, input_size2 m4 k8 b4 P% k% G% u+ R
4 l/ A3 S; [9 w& ~4 s# Q) `1
! D. s; s' ?! ]/ {$ P21 M7 N# F) i. b5 y" `$ s& C
34 J- d7 u( [: L
4
0 Y; r, B c2 e6 x6 h; G5
7 [, X! n! i8 y7 j- a% I) `7 ^6- N+ W+ I( u. p. |+ V3 N& M; V! {
7
, \) ?" |( [8 E! K8- U+ H7 B, G+ U
94 ^! e# o5 @+ }' ]3 Q
10: t# ?& _& U4 _5 y4 i% v8 x
11
* n1 t4 L: q: G/ h; c12
& L7 Z" c1 @* \: \ V. p/ o13* e' j: y. q# j0 w% N" Q$ J
14( @; f# [% }4 a9 r3 u
15
6 a" N4 I$ w, y* ^1 F4 n16* f0 Q* j! v: l% h& L; g
17
+ x' v7 k: _( h; L; J5 F18; I/ t U. d0 H2 m; c- b7 S8 _
19; d2 J, _5 {* P( b. |" G$ e
20, ]( a' |& e7 ]
21
- n, d# U. a% v22, M& v& Y! B% x# T) p) M3 Q
23
( y1 B4 `! Z. A, w3 L246 H" n9 v* t. O- W4 G) w5 k B& M
25. U( Z! j1 _% i5 @* i
26
. f" k7 S/ ]: n N277 L8 A5 t% y+ X& }: T
28
$ p- j5 P, H1 W9 I* N29
$ d. w- {. b K4 |4 E/ A2 Z30
+ C. s1 `4 @$ D3 \- t, D31
+ g6 z) x) ]- d1 m& s+ z' V; \. z. _) O& ?32
' l7 m( |, ?* v2 E33% V5 |/ l" A6 ^) f- Z! @
34" t9 l: j5 [& A1 P$ D
35
6 E1 o% _; r1 i364 E! F# E' f3 u' J. z0 W" \7 R* ~
379 v! ^7 w" t6 b1 D7 K8 X9 \
388 ^" g3 b J. L8 j8 ^
39
; S: m0 H% g4 t, N40
# [5 E' M( R) t9 h2 }* @) f: m6 m41- s; w0 c5 n& {
42
- b6 r" K, ]9 V5 s4 _$ ~8 R43
- |7 K* y/ t9 L- e: d8 O! W# R44
' y" c7 ?* C* O. s/ W45
& z4 i# D+ r4 Q9 Y. [3 j46
; I3 p8 x- h/ U* D0 k( D47
, U! j9 Q& I. m7 }) u48
' B, B/ P0 B% R- v) c, {: P( O7 r% z1 U49; k* N$ ^; j6 O: ^( v: V' b( _
504 R9 I! E+ m# H: X, @8 T6 V' |) G
51
! u. j7 z, } c. v( t52 a* G9 u$ M1 H' T* { @- u5 L
536 E' ^$ c2 ]5 z$ g; D: n
547 [ I6 Z& [* R/ [
55
* r* w+ k3 o! `0 _/ f ~- o4 C1 g56
2 D" {' L& J9 k" J/ N2 w57
7 N7 \3 O$ Z: [ R) e1 P& C0 H( A588 W* p3 E5 ?+ {# t4 A9 s$ \4 [! a- H
59
# `- K1 A; h: l! I3 [, h, m! {60
, V' m" [: s" q' {: r8 |9 V* j61
' y5 t# C2 f2 t& x- I: P7 }62
6 g% a$ }, @* \, W# J3 u# o& M7 P63
H# P& y8 Z. I* s1 H; B g64
7 C2 c( T8 }6 {( ^/ N3 ^4 K65
C% m# n/ ~- N665 V( \8 R' Y& T9 I# r6 F9 I% e
67
! }" s6 ~' Y5 P$ x68
% E- v% T" \8 x+ F4 ?5 ?/ F6 {3 ?1 G69
3 ~7 }" M6 X( j6 r( z4 p- k! |5 l70
5 b5 H2 t* s" O$ O" \5 L. D71
7 _# g0 t8 `3 l# E* m, a( T72
5 S+ G4 R5 U0 h/ v& I+ J" e9 _' X1 Y* p73
' G# k( U& J" k) g( ?) G74
0 Q) @. m1 _" i( h& D3 }6 B" [75$ f5 N6 J/ W* _( D! U
76
9 s+ w) s% M( y: O5 A# j2 o77
. H9 v. g ]$ k" b781 D- Q& S! A0 d" t# g& i
79
/ a! V' {: X" X8 S& \/ o7 R* f802 T8 F& U- C& c/ T1 d
81
5 P' l6 n6 g' Q, @5 z82
0 M* g4 Q U9 f; q7 C; E83
( [* U6 o ~* s, X" f( @" \/ C7. 设置需要训练的参数( i/ f6 e" x6 {6 ~; e: u% A
# 设置模型名字、输出分类数/ d. v! x$ i$ O+ G
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
+ }* R+ E3 {. J/ X1 E. Y2 i$ ^, b1 p9 m
# GPU 计算
. E B6 s9 u0 ]( d% m6 Kmodel_ft = model_ft.to(device); z7 C8 z5 y1 x0 n' s
" ]/ ]: ^0 k1 |5 q: d# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
" j" m% |$ v! m. R% J' Wfilename = 'checkpoint.pth'
9 ?' z& T/ t z! S, p! s
- L7 L+ ^) V) B+ E# D% X# 是否训练所有层
- U" K* D% U0 R9 v0 |- P, v' ^- qparams_to_update = model_ft.parameters()
& y! C+ e- D5 V# 打印出需要训练的层
1 M& d9 |$ }1 m: r) Iprint("Params to learn:")
* g. d! \& a }' C: b4 E: eif feature_extract:6 J p7 G! k; e4 C) a" G: z
params_to_update = []( |; {0 D' G% F7 h# {
for name, param in model_ft.named_parameters():
" i1 v7 z K! _& C; m if param.requires_grad == True:1 D0 L3 C0 D2 e# P6 T6 ?4 G
params_to_update.append(param)
3 M' u/ ]; y ~ print("\t", name)
2 w9 j" D9 p1 jelse:
! \# N9 e4 G* M' b* q8 O4 |! D for name, param in model_ft.named_parameters():6 I/ r h; I1 o6 I
if param.requires_grad ==True:9 I, D( _+ N% H. E$ e5 m
print("\t", name) ^6 K; X( Y6 s+ l& U
% E/ S$ q: B; o6 x$ K17 o$ v1 N7 `2 e
20 n7 E. [/ D5 i
3
2 p, q; A7 x c3 n0 U7 n0 b3 J9 n48 }0 W# R1 q, E/ `( B. v) a. V! X
53 J/ k. B; ~' R/ [9 u
6
) `7 i) j+ z- d: h9 U3 F6 `- o3 Z7
2 c4 c2 ^) r4 x& j7 U' k p, {8. v) e- w- l4 u+ B g
9
5 a+ f6 S0 c2 H10
2 M, \; i. X2 D110 o( u0 c; a- e
12
9 V3 j+ O9 q3 Q( E; E13' i6 d, \/ X3 B% Z3 C3 P
14" x" d) A, b5 ]8 g7 t
15
/ g- f6 X* l# l1 U8 L16
9 L( E) n# A1 q7 {6 B2 p% I17+ r o6 ?* Z4 b, I
18
8 M% m2 Y t) s" @! q! G! A" z, L8 d19* y8 s8 k7 C6 J7 P: X0 w4 D# q& i
20$ ]9 N0 i3 w4 f& y e; f2 R
21% [( v8 C4 g1 \! g
22
* c' |/ |) g2 @- l! T' x& R9 Y23
8 G$ e2 n w# N, k! T6 n, yParams to learn:
+ w( m3 C- b v! d/ R% F- v fc.0.weight
* [; o/ h5 H; C! ?0 K fc.0.bias
& {' j7 \! X- M; J1% @. V4 `6 Q# X a' [- _6 w
2
; d& C# v& f) D/ J$ j" Y# K3
" H# U) u! s, Q7 s" ~$ c7. 训练与预测) i8 _: I4 w- o$ n! i+ A. ~; f
7.1 优化器设置
3 }5 E; D O# e: {( \# 优化器设置9 ?5 g! I$ ^: l# r$ N4 W* M. ]& Z0 O
optimizer_ft = optim.Adam(params_to_update, lr = 1e-2)
& {9 N0 Y5 ^+ ^9 y% x& E" v# 学习率衰减策略/ k8 u6 b# G* m' J/ r" M
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
% g% R; N1 _- P9 f6 d( ]# 学习率每7个epoch衰减为原来的1/10+ w( ?/ `6 A: C' D
# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
! J0 |. D9 B3 u2 W! y6 C; D4 t0 d% h+ I- i' J! M0 ?1 p- J
criterion = nn.NLLLoss()( z: f0 o5 L: A, [: M& O
1% j8 |1 r+ S' B, e, Z: g) C
2% h$ S; E9 t0 ^
35 j# E' D# X) K6 u. ^0 s
4
6 Y8 j+ k( O* j) b5
8 Y* D/ ~: R( C; u4 q' y6
: t1 O( b! C) r! `+ Q3 h* l1 n1 o72 G9 J+ l& C) e. U- c
8
6 ^ }: ^3 P( v3 F# 定义训练函数4 A7 ^) Y, d6 l$ u
#is_inception:要不要用其他的网络* G* k$ u' o4 X0 N4 x- t
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
1 A& \1 v8 A k. m4 A" U! }; e since = time.time()) N! ^& M+ K$ Y# I1 Y* |3 t
#保存最好的准确率! I' B9 p* s! S$ _4 j
best_acc = 04 ` l% G0 a/ p- W+ ]; o
"""
0 F) P+ T/ [8 K i6 b- F0 ^7 I checkpoint = torch.load(filename)
0 B+ z( w" \# K: N6 u: S `5 w best_acc = checkpoint['best_acc']: s5 L y6 G2 m0 z+ `; }' d
model.load_state_dict(checkpoint['state_dict'])
. m* ~: t2 @$ D& S8 U2 P. q8 Z# k3 S optimizer.load_state_dict(checkpoint['optimizer'])
* i. |( g6 ^+ t# Z3 I model.class_to_idx = checkpoint['mapping']
* h2 B! M! _5 W& e """
2 G u) f* r5 g #指定用GPU还是CPU4 L* T! o- j' g1 d$ v* G
model.to(device)
* C& L- l- R* K1 X #下面是为展示做的
; Z5 B v; M+ \# X. ]0 ` val_acc_history = []
& t; o) U- d: P4 ]' [ train_acc_history = []9 R7 D& J7 V6 |6 c3 b% X. n2 z+ V
train_losses = []
0 i5 K s4 ^- i2 n }9 s& S valid_losses = []
* m3 W; z9 c; b! k8 s# f1 X9 W LRs = [optimizer.param_groups[0]['lr']]- K4 {5 \5 p; ]# R
#最好的一次存下来
* y" z5 C8 R7 Z+ P, V best_model_wts = copy.deepcopy(model.state_dict())" \( {/ ^! R/ ?' N9 s- g$ j
. ]9 P; Z7 S+ I/ R: A6 t5 g* @ for epoch in range(num_epochs):: x0 T m' ^; u1 |# L
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
0 S; T5 a/ {; v print('-' * 10)1 y+ S. u2 R: l- J1 Y: K+ s; P' O6 ]
4 g: ~1 P* t8 m # 训练和验证4 y+ {1 ?& ]7 o0 i" ]$ g
for phase in ['train', 'valid']:" {; V1 Z, s# |0 ^% V, J$ B
if phase == 'train':
4 @! c* \/ I3 g) W( K" j* R# ] model.train() # 训练7 }, {# x5 Y+ Y; i3 b0 r4 J: B
else:
- `2 O0 S$ a& B model.eval() # 验证( g/ C* P! i- e
% Y ~( i1 v2 y) y/ f* w5 s running_loss = 0.0: _, K+ R9 O2 w9 q* f, ?
running_corrects = 00 e0 ^9 t: T& s' _. y7 t4 Z. r' O
6 m2 K/ h+ @9 A3 a* G' Y
# 把数据都取个遍' G: }* l0 a3 E9 g4 Y V2 G
for inputs, labels in dataloaders[phase]:
( ?; e0 o$ O+ ` #下面是将inputs,labels传到GPU
9 I# h# E8 ~! E inputs = inputs.to(device)
+ R( O2 Q3 m$ O, F- h labels = labels.to(device)
?2 i4 Q' ^5 Y0 S
, Y! |6 m4 Y, |: i& {1 w # 清零+ d$ `+ d* h: E0 X2 ]
optimizer.zero_grad(): r# L- r0 Q$ r! _, t* I
# 只有训练的时候计算和更新梯度
9 D* W/ ^; o5 X D3 ?$ E' B with torch.set_grad_enabled(phase == 'train'):
) Z& }$ n& h. p" H4 w, _ #if这面不需要计算,可忽略5 [$ q* U% v/ B7 g$ B4 k8 b
if is_inception and phase == 'train':
$ H( z/ [/ P7 ^- u. l outputs, aux_outputs = model(inputs)# |7 i) B* x3 _" L" ~, t8 N r
loss1 = criterion(outputs, labels)
, K& P' i% j; K. S loss2 = criterion(aux_outputs, labels)
: p8 S2 p# J6 o loss = loss1 + 0.4*loss2% w! z2 E/ {4 v% g9 N
else:#resnet执行的是这里
+ q. c% g5 P; f$ A$ z' ^- H/ k- t# r outputs = model(inputs)
& O1 I" w `! M& k2 v$ m1 d loss = criterion(outputs, labels)+ r, w& G# c" B/ b8 U& N- h( }: z
: s/ s$ b6 F7 x% P, q- o7 P& i9 j- i) H
#概率最大的返回preds
* M3 u$ |( Y" A* [$ r _, preds = torch.max(outputs, 1)) F- u! g( M" L: e! P3 L0 q
" v1 ]4 A! X) T/ I # 训练阶段更新权重! o& G( O; m' O7 ^) @* r6 c
if phase == 'train':
1 Y ]( V* A9 c1 v0 X loss.backward()
( y3 z" v, T6 T$ M8 c S9 H optimizer.step()
2 B2 C7 U7 V! r9 i1 z4 n+ c1 q, c# q2 t8 I
# 计算损失- J0 v3 o8 a/ Q
running_loss += loss.item() * inputs.size(0)
) o6 T/ Y# ~& k- d. X6 X running_corrects += torch.sum(preds == labels.data), E V+ ^' n5 c; Z5 ^5 J( p
. p! g" D; o# {# _" b% @ #打印操作
. R8 O b+ C3 g% L# Y: V epoch_loss = running_loss / len(dataloaders[phase].dataset)
; Z8 m$ e. y) @1 X2 W' B: b2 w/ ^( s epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
$ X; b, ]; c/ {& c# ?% b6 o1 z# z5 F/ `0 ]4 K1 D$ l: \
( }4 C7 c: v H8 m; D' s time_elapsed = time.time() - since
9 r3 E) p$ V8 w print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))$ X/ m3 x7 H0 ]$ d( b7 |& `5 ]9 j& H T
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
6 y4 F8 o! ?! r) [: }/ B: ] y/ V4 u8 S& U3 k6 o
/ d# F c' m) m4 _ # 得到最好那次的模型% W! ?/ Z3 u8 ]9 c# F
if phase == 'valid' and epoch_acc > best_acc:
% x- {+ p6 H" P) M) H9 d& R best_acc = epoch_acc
' Q4 ?( D" s U5 P: e' \' m' o #模型保存
6 \/ A1 N+ h* T1 Y5 t7 H ~ best_model_wts = copy.deepcopy(model.state_dict())
" f# M& a, {- u" X: t+ N" t5 p# z state = {
4 Z$ X. o0 l* j1 m( V4 @+ M8 e #tate_dict变量存放训练过程中需要学习的权重和偏执系数; @7 C1 [3 R7 [& e0 _" N
'state_dict': model.state_dict(),
$ x4 Z; _ v- I/ ~. ?# _; u 'best_acc': best_acc,
. U+ R. N% E" Q7 i. _ 'optimizer' : optimizer.state_dict(),1 e- i" l7 t. e5 A
}
: H; l8 _. p" c3 [ torch.save(state, filename), ?% g8 \+ G B; C
if phase == 'valid':
% K* K' ]! E, ?* e, X1 ]+ e val_acc_history.append(epoch_acc)
! l1 A* x7 Y7 d, r# c valid_losses.append(epoch_loss)
- p$ V1 [2 j6 } scheduler.step(epoch_loss)
7 w, Q' ]: S; L# V8 S if phase == 'train':
5 Y* v1 |- F, X! ? train_acc_history.append(epoch_acc)9 L( r; C% O% g- }
train_losses.append(epoch_loss)$ z; z9 N3 c/ z: C; G* l
* |4 a4 E6 o i- w& S
print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
, x8 _6 j. w. c6 z$ ? LRs.append(optimizer.param_groups[0]['lr'])
7 H0 v6 W' j* r' Z print()2 i6 _8 h& c: W q4 M
* r+ `* h4 X9 m, o4 N
time_elapsed = time.time() - since$ S% w* a$ B( W+ Q
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
- Z" ]9 e) Q1 M/ |' t' y1 A print('Best val Acc: {:4f}'.format(best_acc))
1 W0 |) n# \* X* ?/ ?
6 t2 m" [ Z! X; y( }: v # 保存训练完后用最好的一次当做模型最终的结果" m4 ?/ |, |& a! X& D+ T
model.load_state_dict(best_model_wts)
9 T% x: m7 ?7 e; ^ return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 4 n$ c- f5 S" A/ b
0 y2 T- O. M8 O, }' O
; \ r* x9 y: T8 x7 |) [& p' [
1, \% {/ `+ k1 v" J3 C/ _( e* y
22 M. N/ d# w; F8 h
3( Z/ G% f' k# k2 {
4" v; X9 ^. M+ D/ ~8 C. ~
5
2 W% Q! N$ k H8 |* W6) u: _) i7 ?% `- P& R
73 h# ~* ?" |% ^7 b! w E4 A- M. W
8
1 S0 s/ N6 V4 x6 X" F9
( C/ U& x& e3 ], r, S0 I# b, s10. N8 a4 E- q" D& R+ h3 Q4 X# z" l
117 ^* H' X) t. F
12
- ]5 r) {; @9 D+ f0 k13. X( O3 l- m2 k; }, ~* N6 c
14( u) f2 m9 M" Q: N
15% D% a5 ^. @: H, T1 Y2 |% G
16$ ^: q/ R, m4 b
17
0 {% O6 ?9 a/ U2 G18
+ W3 p6 Q) @6 c3 ]6 Q19
* I4 N7 i7 k8 c$ o+ I# T& c20% R5 `. c5 o& ]1 r8 o) K( X3 E G
216 ]; d9 ^- R, p5 o, N
22
* ` Q {3 g5 G3 _/ M5 ?234 O3 W% y w, f4 p G
24
' L7 c3 G1 ]$ B1 g& c25
& d6 \! O! @ G+ V( } [26
9 |2 t2 @+ B! u/ m+ R2 b/ C27
7 k3 M* R, n9 w3 \281 O; f) I# n6 f5 O8 \* k
29
, P3 x& | j7 J/ e6 S$ ]30" Z% q+ e5 i2 ~
31' ^8 I+ m/ o' h1 {4 Y9 c0 [5 O5 V" ]
32" A- X0 Y( s4 e: {! A0 z; S
33
b8 |4 U7 X7 h4 z34. D) f% n- Q6 j! e
35: {- H$ h7 ]' M6 M G( A( a" |, u0 y
36
( ^9 C+ t# L) n0 t37
# B1 n. |" _& ^3 t' O38
, `2 v0 Q3 P5 S0 e5 d- P5 ^. r39
5 f9 X h/ w0 D8 V1 I5 e* v z$ m1 u40
& c5 H3 _4 x5 w41
+ }8 {7 }: Z c0 z42 a+ z; |8 _! S+ t% p" E/ P
43+ s- s! Z8 K" R7 k( M3 i# h" y( N# \
44. ~& j9 g" U5 _% U
45
+ I0 @0 @8 ` H3 l' g46
* n* _/ |$ K; }" }1 v. x4 l* I47
" n) N& M$ G1 P F+ k482 K! g8 y$ B+ Y( r
495 [8 M4 z. p( L/ o Z! m1 u
500 |6 J0 P/ m5 M& [3 P" R% O6 S
51
[3 X! A @6 D w- ^$ d, ~52
8 P( {/ Z( T2 p* V53
6 y! j+ Q4 |* \. y8 ^54' `9 H1 p7 u- F R1 T
555 }2 H# B) e& Y' N/ ~
56; t( y$ q/ h. K. U1 B8 J
574 K- {' Q( s, f3 l
58& o4 ^! T, j' n2 q' T
59
! Y9 v& O% y) c4 I- B60
$ X! B+ J' p7 s5 Q Q. U610 g& }1 q( ^/ n
624 O) ^0 }! W7 c2 |3 N3 ?
63
6 l4 B1 ~; x& m# T64: z3 Z( }4 D: w; b1 j
654 w- J; j2 u6 ~* j4 g* X
66
" P2 x# Q& v: E* I _/ J E678 q: c, O7 ?- g
682 u% G; w% z- U5 j
694 @+ q$ Q) `0 G6 r- H6 t7 R s
70: T- ~- x* ^" [3 M$ ^
71
2 m4 v* Y. f }/ G3 R0 n72
# ~. R2 N4 S$ }4 c# n# Y% \730 X% D9 U* J2 H
74: U9 h* m7 B0 D) Z% H1 E: ?- w/ }
75& w7 Z% S i' [; g) I" W- H
76) a9 z+ I1 _! ?: y
77
7 M$ `- X* a- V) {4 @78
3 a0 f7 T0 ?/ m& V+ |# q9 |79
; {5 P7 J; j" [80* l2 N; A! h# R+ v. K
81' x& k. u7 Y; z. V3 v* Z' o1 s
82 E# s7 B1 X. Q) m$ y" q8 ]
83
+ }) t/ z) C9 C5 T+ [84. v5 R, {' h6 w# G5 e
85
* {+ O* ?9 h# N' a6 |1 p R86$ a' s) w2 R, D, g9 c
87* w2 m' d* `/ D' V2 [3 w
883 I/ F! f! Y% I5 C B
89! r: X y) x; H7 `7 L
90
* u P5 T8 T% M: p$ Q91
# u3 \ C+ ?" l6 W6 Q1 N92
. F: D9 m+ L {93
% k+ v; @. ^2 }: P0 Y; `94% ^0 i9 t* W2 X, |- q# G
95- P' P3 W: c* s3 X) ?# Z. [! n& T5 H4 y
96
. e$ M6 k6 s/ W) L8 h/ Y5 M97# o: F! D; ^; ]& q/ r& b
98
8 I1 u0 w2 x3 G/ |$ T! s- |0 E99
: ?2 A( ]0 e4 W9 j. ?# ]5 }; s" V1009 k; C( t& z% F+ R
1010 M# Y7 i+ a/ Q7 Q
102
# ^* J* G* S- R; w% m103$ N4 z* g/ P. P$ c7 u! z
104
4 Q- H% W7 k* Q! A105
8 |7 T4 x- } l) {106- _$ U# B# u: j8 s
1075 b& Z; Z# z' V- o& I3 x( L
108) F, A- C& s' F( u5 S$ O+ {* Y# d
109
. g! R& A/ _' j2 i1 h* c. l% m6 v1103 c J! V: N% {% I, ]
1118 G" }2 p r0 E, O
112
+ n3 `0 ?) K( d4 V9 K0 M; [7.2 开始训练模型
7 C- \8 E' G2 Y) @7 I+ _( X' [" p我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
' a, m. d6 f4 o
7 H- s0 O" v( D" E6 y; k#若太慢,把epoch调低,迭代50次可能好些: q0 C: m/ r3 K2 d7 V0 [; h2 G
#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合0 _! p2 T% C/ d9 q
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception")); w' x* M1 p" Z1 N
! P& O( w5 E; O# B: V1
) \" g! {) Y9 C0 q4 j24 D5 q( {/ `1 m* z6 m
3
' C1 e6 Z! l, s4
+ Y- U! N$ M! OEpoch 0/4
$ c! E' B8 B4 L----------
' ]( J/ e: x0 G8 O$ U5 iTime elapsed 29m 41s" t% h& L/ ^, P3 j
train Loss: 10.4774 Acc: 0.31473 l1 Q( Y4 O9 c( f
Time elapsed 32m 54s8 @, S: N% E. D$ [* h
valid Loss: 8.2902 Acc: 0.4719' h" o8 k0 A" U
Optimizer learning rate : 0.0010000
7 o6 @4 [( f" [8 p4 D
1 b8 B, Z" s' fEpoch 1/4$ N! I, m& @9 t4 F( Q) T7 ~) O4 b
----------& b% i' l: T: Y
Time elapsed 60m 11s; H4 ^, U0 i* \$ P3 m
train Loss: 2.3126 Acc: 0.7053. Q7 U' H8 y! j; G, \
Time elapsed 63m 16s
" x; p& O6 O3 c2 \valid Loss: 3.2325 Acc: 0.66262 i# u8 s I) X9 y6 t- G- t) s
Optimizer learning rate : 0.01000007 s: g+ \( T2 Q( {7 Q0 w9 Y
+ B- D( M0 j7 g x* Q8 _Epoch 2/4* m- _5 {# X$ d2 z8 S1 J
----------
) X" Q2 D6 Z- S* X L: bTime elapsed 90m 58s
: L8 m* \( W; n5 J r( dtrain Loss: 9.9720 Acc: 0.4734; @: m0 F3 C5 J3 v9 T( G( `$ K7 t8 N
Time elapsed 94m 4s+ f3 I( H3 F7 ^2 @1 ]
valid Loss: 14.0426 Acc: 0.4413
, G) o( e% X: ?5 v( q+ h4 H* XOptimizer learning rate : 0.00010009 J2 Z ?- v8 K$ T" h' ^
% x" N4 z( ]$ r- n) oEpoch 3/4
, _6 c% z9 E1 Y4 [----------% t: t7 N' S! Y! o. ^
Time elapsed 132m 49s8 O) A+ a( X; d9 w: ~9 @
train Loss: 5.4290 Acc: 0.6548
3 O& C* H% I; F# ~Time elapsed 138m 49s- [5 p7 }( O1 H/ w
valid Loss: 6.4208 Acc: 0.6027, f. }# w' l. L9 `
Optimizer learning rate : 0.0100000
9 U% _* z: @7 J( W; O, \* ~* E+ U
6 }2 T1 x( F+ P9 d, U& c8 fEpoch 4/4
1 `! u( t! H3 b: X) ]* b----------% z. c6 e3 ~" V9 V" c" U* u6 t; V
Time elapsed 195m 56s l7 ?- O5 T9 f" l$ `: x
train Loss: 8.8911 Acc: 0.5519
w" i, i& y; gTime elapsed 199m 16s# Y0 w4 s7 j4 ~- g% `
valid Loss: 13.2221 Acc: 0.4914
: p0 f8 n6 H6 S4 [Optimizer learning rate : 0.0010000
5 q `' v5 O& U& w6 A7 `$ g% g3 s0 {0 L( |# M0 ~
Training complete in 199m 16s
8 q: p+ h* V9 a5 P; ^/ sBest val Acc: 0.662592
( l" y$ c, N. O% e1 x4 W
* K( M; R$ Z: L* z( Q' O1
4 v7 K' v, ?! r, r! K0 T; P2
; H" V; j+ m5 [ Z- M, ]: b) O) m3! ^# T7 c" @' k4 i
4* l/ h$ k( l: ^/ l. N
5
" \& v# \% F, e1 E3 _6! T& {% `3 [2 j3 a2 k8 V
74 W9 P% Q- h8 |' @% }' ~
8( X( @/ x9 ~3 q3 R$ ^
9 _, y- Q! P; G
10) p3 K3 a- F7 l! L/ f" S- p
11
- w3 Q! e; y- A12
# l$ ^* \" h x! p) H, E. e) P, j131 }1 u) u% o( U& H$ I5 {; u# E
14* x+ |+ N5 v- y+ j* X3 f" c6 B& J9 A9 X
15( [- U! Z5 n! |0 {" x- W C
16 g* f5 `3 i5 N" E/ y
17
) E! z# S$ n+ \4 g7 j. j' o: [$ c7 ~18' ]- [2 h" O" I! X8 w6 @8 C
19
- f9 h& [' h$ B5 h5 k( f* H20
, h* y0 n- s. r& y215 b1 q! V) z; O( h% S
22& z/ b3 T- I% b$ w/ N/ G% K: s2 T
236 a( N3 y: U7 P
24
# n) ]0 n2 N- Z: H25
& B6 v3 O+ p. L9 x! z/ r! i26
0 b! o1 u. V4 S. a; o4 J27
; o5 s) i+ R F2 v) a6 r7 a28
4 q4 d3 R5 M2 P4 f8 v7 k" a, [3 [29
1 N0 [* x" d# C; I( o% c( F30
! r; U* ?; D' k K4 O8 S& ?( U31
% E4 D/ V1 P H2 ^5 e; D: y32, u+ N: f( z7 A; @ H) A
33
3 b$ J" \& q1 ]' F' V+ a3 `34$ }! y7 Y5 S& w' R+ w" g. i6 Q
35% @) l* U2 g8 R, N( t
365 b' S8 X$ I' H( T" g
376 V5 j% Z' e" w
38- W- o1 @, |; U" l8 f2 ]3 l
39
8 A* a) v; D: @, t" r' _40/ u% B* C3 Z/ {3 c, @6 g/ g
41
& G3 K6 z: D5 Y+ I42" T1 [0 y) t! i2 e% T: \4 v
7.3 训练所有层
+ ]/ v6 a9 W- q8 r, j! N# 将全部网络解锁进行训练
: V- `% y, F+ X# e$ i6 v- ?+ w5 U7 V/ ofor param in model_ft.parameters():
3 A( K* m; p: m param.requires_grad = True
9 Z; m# N( A# P% I, w; z* p" h; w: G; w7 ]9 x! j _
# 再继续训练所有的参数,学习率调小一点\" H" a$ d/ O) I% ]# K# S- c# w
optimizer = optim.Adam(params_to_update, lr = 1e-4)0 e# b9 _7 m- @6 r: E7 x
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1) {4 l5 ?1 K$ I. N* w' ?
$ _- Y0 D3 m8 }" D% {
# 损失函数
: a7 k- d( V1 Dcriterion = nn.NLLLoss()' e. P: Z7 G7 S5 C% w) ^
1
2 Y6 m9 K) I8 b21 k# r, I5 X2 r1 O p
3
% r4 m3 p1 |+ v z) Y4$ n# A% W) k0 J2 @
54 S- }# [ r/ }# `& T+ {0 u3 O
6
2 e) e# ^9 C# ]3 }) ~/ Y" b! R72 ?9 H% W/ a$ G% i. L4 ?# v
8 R2 O" K5 d2 S% g! j* U7 t: V( y9 s
9
6 ~+ A6 U! {; D! G; \10
, A! D# J, T+ z- k2 I4 w& }/ H# 加载保存的参数4 j+ s0 i# t1 r: h0 g& b
# 并在原有的模型基础上继续训练4 q- Y N! P+ e" _% f$ ]
# 下面保存的是刚刚训练效果较好的路径; h, y: T! a: h% t% p; M
checkpoint = torch.load(filename)+ V& W1 M9 V: f
best_acc = checkpoint['best_acc'] F. |( N) {& l5 X8 s; g
model_ft.load_state_dict(checkpoint['state_dict'])
6 f: G u2 l) I& q! w' ^optimizer.load_state_dict(checkpoint['optimizer'])+ e& W* `8 L6 q9 \7 V
1
! q9 j' F9 R Q& a& ~2 {% O+ ]# Z) a1 z- V v# }8 _
31 v/ I% c" z4 t/ A4 L/ M
45 h& t- s6 c* m7 }
55 V7 {5 J1 {1 D4 p( F
6$ a1 ]7 o) y6 L3 Q' p; ^2 T' d
7
7 c, V5 n8 w. F1 v& g1 s5 w2 j开始训练
6 A1 @0 n. x. K5 q% L# h注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考9 {' h: E$ [# n2 h' R
2 s- f. ?4 }' v) V+ l5 k/ v/ Xmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
7 n1 G% `4 J3 ?; ?" g1; G& Q' b7 I0 N9 @
Epoch 0/1
9 W I) r# E! I! Q/ L& u----------
+ K1 C5 E5 x2 Q0 D* g. TTime elapsed 35m 22s0 N+ S4 q( [- I ^# V2 ]5 L
train Loss: 1.7636 Acc: 0.7346+ ~9 M& Y, L' ?4 u
Time elapsed 38m 42s
( E0 \1 f6 C( K; c! k* _! Vvalid Loss: 3.6377 Acc: 0.6455) ^" ~- W9 W) c/ R* O/ W0 c' j2 ~
Optimizer learning rate : 0.0010000
: W$ E5 o. v4 T9 F' V2 T
$ @4 C9 ?8 S: U u& i; gEpoch 1/1
4 u4 C) k# l/ }1 p+ P) w----------
4 B. U6 h9 @9 T X0 N+ NTime elapsed 82m 59s
" Z$ C* ?3 d+ \6 y" strain Loss: 1.7543 Acc: 0.7340) M0 a/ d5 Q3 O% F
Time elapsed 86m 11s
$ A1 f0 k- v2 a6 a Z0 Kvalid Loss: 3.8275 Acc: 0.61377 h$ W! s. E# r, k: T' o
Optimizer learning rate : 0.0010000
b. G. C( h6 V3 q7 j& C2 e. {( q7 Q0 }3 w6 t
Training complete in 86m 11s7 F; s& N. {4 \+ k
Best val Acc: 0.645477
& v4 E; `% `; @' Y$ p8 q
, f# R( ?, ~ }% P3 d1* p7 }8 g0 O& z
2
5 R, z7 d# t( h3
9 z; P" t. O y w2 s' m4
8 ~ i* J) K- J+ |: _, L+ Q5 W' q4 R1 {- n) r2 c+ J6 F
6( e9 {+ a! U1 k0 C& |; |* W
7
9 ]8 ~9 r9 \& m5 n8
" J% K) E% O% C9
6 Y, v {7 R2 @2 l9 l# R, c10
, C& H. t$ e/ h: b& w11$ I# B6 T# n& v! G1 \( k2 ]
12
$ @% F. b- f5 [8 U3 ^/ K4 Y. }: p13% Z, R' A# ~) B! _# l/ a
14
# N1 h% J. o t- w0 E" {15
}& @* ^# `% _165 p9 I( J! W9 q6 F7 j$ k9 W& w: ~
175 x" @3 i- o/ P* z. x2 L
185 y1 m9 g3 ]6 Q% t* \5 m
8. 加载已经训练的模型
+ [( k3 C, w- @6 _相当于做一次简单的前向传播(逻辑推理),不用更新参数' _3 u% H |" s |
7 {, |0 n6 ?' K2 v2 h8 o7 [! ymodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
, S! N* ~, l1 i I
; F# U4 Z7 e; _8 ]0 O# GPU 模式3 V; E$ Q: A \
model_ft = model_ft.to(device) # 扔到GPU中
7 c2 ]* R7 h( r% `' y! T5 H9 l; U) ]. l0 j% P% t1 s# `; n' P
# 保存文件的名字( }3 B1 i9 X6 T* b- _
filename='checkpoint.pth'
, @: ?% H& S3 V
- z; D- U+ w }8 ~9 a# 加载模型
. f8 v2 d0 u0 S! A4 n2 {checkpoint = torch.load(filename)
# h M. H5 p9 x8 @& k1 b) ybest_acc = checkpoint['best_acc']
5 M8 s _9 J; B* e2 {9 n' {# Dmodel_ft.load_state_dict(checkpoint['state_dict'])
( t# u# c9 \% e! @1
. z8 B0 X* K! a) p' p |; f7 h29 Z- P8 h# u" @% C1 Z3 s
3! a% }1 U/ L/ @9 Z% }' `9 T
4
3 Z' U) e( O/ D3 S2 T5
A+ E8 U6 T/ s8 n2 F69 h6 o1 F( H% L6 g i2 _" z
7" h" z$ Y J' T' n. u+ w
8/ P8 e R! x; n/ b3 G- y$ m
9# J; w$ S+ F, q2 e I4 h5 b
10, h' Q3 }" @- g9 M
11# k- x+ H$ [: K! k' T
12
4 l6 y, Z8 _; ^5 F4 d% f1 |5 Z<All keys matched successfully>- R5 r' p- q9 H7 X0 m
1. Y; j5 u: f' y2 I; Y
def process_image(image_path):
- h# L" ~ _, ~7 B. F; | # 读取测试集数据" T( q+ S6 m; f b: @3 p: \/ _2 a( g
img = Image.open(image_path)2 `& T7 u4 _# C. j9 c ]& S
# Resize, thumbnail方法只能进行比例缩小,所以进行判断5 d+ S; r2 o3 |1 @1 I
# 与Resize不同# ]+ I6 z2 B5 D
# resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
. h& \, W. |. w/ o$ @ # 而且对象调用方法会直接改变其大小,返回None
; k+ i8 {5 O; l L. Z if img.size[0] > img.size[1]:
U5 g A7 W/ X( G, c img.thumbnail((10000, 256))+ v5 [& z$ F7 a) `* V: j! b! N
else:
# ~0 @3 N% g5 B# o+ g img.thumbnail((256, 10000))/ |. c* Z5 Z: O0 z
. d) d% y E2 X- Y1 i
# crop操作, 将图像再次裁剪为 224 * 224$ a0 L4 ^8 b; |8 S9 s3 m/ l, h4 `
left_margin = (img.width - 224) / 2 # 取中间的部分( ~( s: r9 j- E5 O0 v
bottom_margin = (img.height - 224) / 2
+ m6 U: C1 T0 X+ {' {( J$ O& ~ right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度
* r: n' i; U' j8 K" { top_margin = bottom_margin + 224& v$ H* D3 X. T( q0 n
! V6 @0 g# R; I- S' |4 m Y
img = img.crop((left_margin, bottom_margin, right_margin, top_margin))
; ?; G% ^! F. R \9 X2 g7 _- ?8 o5 w! ~+ x5 Q
# 相同预处理的方法6 v7 o* ?+ y& \5 ]
# 归一化
6 p, |$ s# k3 J: M img = np.array(img) / 255
9 ^( l$ H4 v5 u# Z# s6 E; \+ @ mean = np.array([0.485, 0.456, 0.406])3 a( Y7 X, j3 G
std = np.array([0.229, 0.224, 0.225])
$ F$ `$ \0 F0 C* z1 y1 l8 e, j img = (img - mean) / std1 y, N1 r$ p' }+ I. `
T( M a2 |0 ~
# 注意颜色通道和位置/ A% f$ Y* S' Z& d) m6 H
img = img.transpose((2, 0, 1))
7 O+ n6 s5 D/ V* @6 K1 j/ c& {3 ^% W7 p% r$ h/ B) l6 `
return img
/ K1 Z. n. w9 l, s; E" q: n5 @* }, p/ S4 f0 R9 W
def imshow(image, ax = None, title = None):# L+ Z) T4 J1 e; J' G2 v0 a1 g
"""展示数据"""6 e8 ~9 M: k: r* J2 J1 B) |
if ax is None:
8 J5 @$ l; v& K$ t# f fig, ax = plt.subplots() O1 G. z3 s% r
% K+ x2 l! K5 o% e O/ ?) @1 z
# 颜色通道进行还原
. P) L# e* O- p: ` image = np.array(image).transpose((1, 2, 0))
8 L6 }4 _- |3 a6 N+ g, E' o& F) b2 p9 V* p l3 r/ m
# 预处理还原
8 S7 a3 W/ J2 W0 n# n& B! ] mean = np.array([0.485, 0.456, 0.406])- B) {# |8 U; P. k
std = np.array([0.229, 0.224, 0.225])% ]/ ~& P* r) i
image = std * image + mean
) S: a) F, D% ` image = np.clip(image, 0, 1)
& F) ^) h2 i# `: K7 M
0 E% u1 I( |2 f C, h' i4 M# _ ax.imshow(image)
' `6 i# @1 j; t1 X ] ax.set_title(title)& e( `8 _+ R4 G/ P/ P& ]% ~
6 c* `* [7 e) c1 H5 d/ v: B
return ax# m; |- n% f! c+ g$ P) s4 S
- X& L F* t- r* y& F
image_path = r'./flower_data/valid/3/image_06621.jpg'5 w1 ]9 d2 p- g. P, D3 k& |
img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理, [) l+ P4 X+ A; E; |$ V
imshow(img)
4 O- v1 ]- H% i- ]$ G5 ~9 X1 ?- {8 B; N, D* V/ x+ f' {# B4 Z
1
& U6 k( W `9 a) b# P2 N% B3 [: }2
9 D# E/ \# ]9 F ?1 d8 \( @3$ B N: E. Y) e: q7 E
4
9 S0 v% z( Y& E* E+ D51 e' _; R+ c1 P" y8 ~! ]8 `/ G
6; B M; k: Z0 l4 M8 \1 t4 E! X4 P
7, [2 C7 m" g. M
86 v, ]& V0 b# D( x
9
9 \) R* Y4 r6 Y- M$ V x" [' y10
9 f# N) H6 }0 v) [" ^" @4 v+ L11
r9 Y) E# L7 o0 f, |1 E8 j12% l. b3 ^& O4 t' x; r9 J, U+ Y
134 F/ D& J3 R# E. a5 O
14- M3 O0 N; d( `3 D+ l
15
- b% L8 l: `9 [16. Y) N4 b7 D5 T
179 J ~0 `9 Q/ [) e. ~7 [% N+ H( P
18
& f+ i/ g" E$ A) z0 C' m19
) ~7 ^; f. T& i% V1 d+ B20
% b( f# `9 y, M Y# w W, |6 z21
& B/ H/ K0 R1 _0 l5 f/ Z: Q$ d222 X9 K& ~4 H8 d7 a$ q
23
; n4 U: D2 k) }* q7 ?7 P+ W3 e24! y! E6 a4 H) J- j5 I; E* P
25/ d) r/ _1 O% u! c1 n% ?3 `& O6 j
264 W! V& ^" W5 a' L6 w8 |+ c" u
27
8 ]' r5 a5 I% I$ J" j" a% [28
+ u' _: X" F/ h, T j1 B3 Q: S0 V& {29
8 q! h" J7 N0 A30
, P! s. x4 A& N5 d+ n: j31
. f9 W9 u, n6 h: \ i32
. W3 @% ^% e7 g3 o1 b) O- b33+ f+ J; x1 z" L( G. [. m1 S6 b
345 c. W8 p. Z5 H% ^+ F& B
35
# }* _1 c: Q, L' D1 H u7 W# s36
+ E- y$ X5 f* d6 D# I& v+ F t37' ?) d: p. G9 a: ]1 Z
38
0 _. W7 e1 Y$ K! a' A w39
( G. _* ~2 M9 d" y/ C40
' n, P I/ D$ b/ ]. x4 ~, [0 A) j411 N4 P$ o0 h7 W6 T# i
42
# U9 z: `! {& T43, Y' q9 N t1 l* ]/ |
44
) b5 f8 I2 ]% B2 T4 r45
8 e8 U& t! P& b) p1 C* H46
; r# ?9 g* y0 t! f6 S/ ]3 G! j. c/ [47
7 j/ y) C x, b7 W0 V' i) e48
$ W/ r* r# ^& }9 j% q1 z) Z, R49
* S% I1 ^5 e# [; a50( K7 n/ G- {1 I7 F; R- o* T0 f" u
51: r& R3 M$ @( `& G( a3 g; E
52
* v8 M: D( U3 @4 e& D5 f- C. K7 y9 E53
' X1 `$ C4 ?$ [- i+ d54
; Q2 X6 @& W: _6 t9 Y n7 I<AxesSubplot:>
- i4 O2 t' S ^) H5 u# a$ q; {1. Y; a1 a) L4 Z7 X% M( ]. Q
5 D M) {8 u/ {
上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确1 g" G! U7 s: f2 C% t( q C7 { {
1 O5 H8 A' d' L! z( u; limg.shape' n5 M+ ?3 r$ }6 F. ~* S; H
12 t9 W; L2 `+ y/ ~+ F5 ]
(3, 224, 224)
" T9 n( Y" X0 L3 _1
$ ~+ k# r2 x; ~! b/ N证明了通道提前了,而且大小没改变
. r) B8 _* c& b7 d Q2 [ D8 g: s5 M3 R3 ?+ \# O! _
9. 推理
9 {: S* o; p, S8 `img.shape
) m9 I3 Z0 o1 q5 g( M8 e+ {6 A* f T1 V
# 得到一个batch的测试数据
. l/ h, d3 B; ldataiter = iter(dataloaders['valid']); T' O: D5 r, v, V* Q
images, labels = dataiter.next(): c6 G6 I0 r$ u; d
& o' q! Z; M$ o O# X1 l
model_ft.eval()) f& h' }: n4 f; N9 W! A1 C
7 A! N6 }- U6 [% T2 eif train_on_gpu:* N, g0 p% P! O
# 前向传播跑一次会得到output
# J7 ], f3 K9 v output = model_ft(images.cuda())5 D7 N0 `7 A# ?1 l5 C8 Y
else:
1 h* O# k; T( F5 \: Z% k7 z output = model_ft(images)
6 L4 C/ Q5 O' o; k! c
6 F ], |& W/ v; @* r; G8 }# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值2 M* m) i* a- y: \5 w( j& H
output.shape
: ^* V: D2 a( I* `; G: k! Z+ F
" n k, ]5 k* F11 w) F0 A& n- ]0 a
20 J1 b! ?1 Z0 u" M9 ^
3% W/ p6 E" r, @
4$ D9 ?* n1 H. v
5
( U" y: w% Z& q6
% L9 \7 x9 n) @+ i1 J S! B0 k72 E( W, T1 v3 z" B; P3 f$ p. E
88 F# |3 Z0 y4 L D
9" K$ x3 ~8 F& C. K1 n- ]8 G
10; _* D; p* d/ u
11
# d6 c! W( F- P# o) M0 K/ y12
; x5 b2 l0 \- \2 R/ Z8 Q* L& g137 X. E$ O+ I/ `" R* R) A
14( d. K1 m& l' `' Z
15
( J) N; H- k* i; y, g% z# \16
$ h7 q4 b9 X6 V9 d# Xtorch.Size([8, 102])
5 ~- t& p2 V' s0 b1
0 r& R- G( U" a+ }( ~9.1 计算得到最大概率
! a @0 T: D! E% o/ S. F4 S_, preds_tensor = torch.max(output, 1)" I+ q7 I6 @" c
) F/ G. i0 F# C0 W
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量& I/ \8 l, h; r2 g6 V
1
& h3 P$ _* d/ ?0 g X: i0 C+ l: Y& U2
5 x! J- v- a8 o' l$ H m& }* P) x3
+ r" A% B0 V: S* x% G! L9.2 展示预测结果
, V3 ]/ ]0 p( [% r% b4 g# ?fig = plt.figure(figsize = (20, 20))
- o* x- F0 _( u+ P$ C) ?- ]columns = 4. _! ~+ g4 u: i1 |
rows = 2
9 p3 Q; \& X. n9 G/ H: b! A: z# d+ b/ d# n; {- G# v: w) {% D
for idx in range(columns * rows):/ Z& m* B% O; p5 P- p! z. y0 g% f
ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])9 {2 ` _3 S* d* [+ Z3 s7 }4 c
plt.imshow(im_convert(images[idx]))" O1 i0 ]4 L$ ~: @. B( B+ p, O
ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), & v# c# C0 |" y
color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
$ ~7 g6 ]. t5 `* lplt.show()
n/ J1 t+ H1 n3 G W9 M# 绿色的表示预测是对的,红色表示预测错了
a' q6 W& O- o" W1
9 f1 r& C ?" \: {& _' v. `2
0 I% N4 ?9 U0 H+ Z3$ @. A! ?+ Y. L) ]/ m$ ^# x5 D! i( o6 x
48 I, }+ L: I1 k; i4 F
55 w6 J. m8 f/ p- ?* G( g
6+ J6 k/ @; Q& [: y9 `% n
79 J5 m6 C! Z/ \
80 o0 r* P% ~- ^ w$ ~
9
7 ^- M! d, Y& a, U$ d# K7 S9 q10
1 I8 Y6 s# m: S112 q& d } K; A% n T
1 A7 s; H K6 E1 D9 ~
4 X6 h0 w* _2 G; Z: T e6 g m) l: u, F0 c; \2 |/ g* m ?5 }5 Z; F
————————————————2 D z$ y' h7 a0 B
版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。2 m5 W: [; v' w" o) n' V+ q) T
原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
) k% ~- b- ?+ u1 h, H$ x* H# w; Q9 f6 f, @( T+ W9 b- {, _
7 q4 W# L; ?6 y& x7 |
|
zan
|