[Deep Learning] Image Recognition in Practice: Flower 102 Classification (102-category flower dataset) Case Study
Table of Contents

Convolutional network in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import packages
2. Data preprocessing and operations
3. Build the data sources
Reading the real names behind the labels
4. Visualize some data
5. Load a model from torchvision.models and initialize it with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer setup
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the already-trained model
9. Inference
9.1 Compute the top probability
9.2 Display the predictions
Closing remarks
Convolutional network in practice: classifying flowers

This article works through a classification task on the Oxford flower dataset (102 categories). We build a reasonably general neural-network pipeline (implemented mainly with ResNet) that exercises the common operations of the PyTorch framework: preprocessing, training, model saving, model loading, and so on.

The data folder contains 102 kinds of flowers, and the task is to classify them.

Folder structure:

flower_data
    train
        1 (class id)
        2
            xxx.png / xxx.jpg
    valid
The pipeline breaks down into these major modules:

Data preprocessing
    Data augmentation
    Data preprocessing (tensor conversion and normalization)
Network module setup
    Load a pretrained model, reusing one of torchvision's classic architectures
    Someone else's training task was probably 1000-way classification (not necessarily the same as ours), so the head must be adapted to our own task
Saving and testing the network model
    Model saving can be selective (for example, keeping only the best checkpoint)
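The "selective" saving mentioned above can be sketched as follows: write a checkpoint to disk only when validation accuracy improves. This is a standalone sketch, not the post's training loop; the tiny `nn.Linear` stand-in, the `maybe_save` helper, and the tracked accuracy values are illustrative assumptions (only the filename matches the one used later in the post).

```python
import torch
from torch import nn

model = nn.Linear(4, 2)       # stand-in for the real network
best_acc = 0.0
filename = 'checkpoint.pth'   # same name the post uses later

def maybe_save(model, epoch_acc, best_acc, filename):
    """Save a checkpoint only when validation accuracy improves."""
    if epoch_acc > best_acc:
        checkpoint = {
            'state_dict': model.state_dict(),  # weights only, not the whole object
            'best_acc': epoch_acc,
        }
        torch.save(checkpoint, filename)
        return epoch_acc
    return best_acc

best_acc = maybe_save(model, 0.83, best_acc, filename)  # improves: saved
best_acc = maybe_save(model, 0.50, best_acc, filename)  # worse: skipped

# Reloading later only needs the architecture plus the saved state_dict
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(filename)['state_dict'])
```

Saving the `state_dict` rather than the whole model object keeps the checkpoint portable across code changes.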
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder and place it in the project's root directory, and that is all the setup needed.

My data root directory looks as follows (screenshot omitted).
1. Import packages

import os
import matplotlib.pyplot as plt
# Inline plotting, so plt.show() is not strictly needed in the notebook
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations

# Path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

Note: see "python目录点杠的组合与区别" for how the dot and slash combinations in these paths behave.
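As a quick standalone illustration of those dot-and-slash conventions (not from the original post): `./` means "relative to the current working directory", plain string concatenation can leave a duplicate separator, and `os.path` has portable helpers for both problems.

```python
import os

data_dir = './flower_data/'

# String concatenation, as in the snippet above: works, but yields '//' in the middle
train_dir = data_dir + '/train'

# normpath collapses the duplicate separator (and drops the leading './')
clean = os.path.normpath(train_dir)

# The portable alternative to concatenation
joined = os.path.join(data_dir, 'train')

print(train_dir)  # ./flower_data//train
print(clean)
print(joined)     # ./flower_data/train
```

PyTorch's `ImageFolder` accepts either form; `os.path.join` is simply the safer habit.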
3. Build the data sources

data_transforms specifies all the image preprocessing operations.
ImageFolder assumes the files are organized into folders, with one class per folder.

data_transforms = {
    # Two parts: one for training...
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation between -45 and +45 degrees
        transforms.CenterCrop(224),     # crop from the center
        # Flip with some probability: 50/50
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),    # random vertical flip
        # Arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),  # convert to grayscale with some probability
        # The grayscale image still has three channels, but R == G == B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, std
    ]),
    # ...and one for validation: resize to 256, take the center 224x224 crop,
    # convert to a tensor, then normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# Inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
        RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
        CenterCrop(size=(224, 224))
        RandomHorizontalFlip(p=0.5)
        RandomVerticalFlip(p=0.5)
        ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
        RandomGrayscale(p=0.025)
        ToTensor()
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
        Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
        CenterCrop(size=(224, 224))
        ToTensor()
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )}
# Verify that the data has been wrapped into loaders
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

{'train': 6552, 'valid': 818}
Reading the real names behind the labels

Use the json file in the same directory to map each class id back to a flower name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
4. Visualize some data

def im_convert(tensor):
    """Convert a tensor back into a displayable image."""
    image = tensor.to("cpu").clone().detach()
    # squeeze drops any singleton batch dimension so the array can be plotted
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is (c, h, w), imshow wants (h, w, c)
    image = image.transpose(1, 2, 0)
    # Undo the normalization (de-standardize)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # Clamp everything below 0 to 0 and everything above 1 to 1
    image = image.clip(0, 1)
    return image
# Plot a grid using the helper defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# Grab an arbitrary batch from the iterator for display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # dataiter.next() was removed in newer PyTorch

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the json mapping to print each flower's name above its image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
; V; K1 [* H% t3 h1 D0 s5. 加载models提供的模型,并直接用训练好的权重做初始化参数
i, t$ Z$ w- p9 Fmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception'] i5 X) j! e, c# A$ |4 n
# 主要的图像识别用resnet来做
3 P% T. y. e' f! n% B" L/ R! F# 是否用人家训练好的特征/ L. @5 Q% D1 w3 d# F5 \7 A% \
feature_extract = True: C2 Q- d; \# j3 {/ G' A1 |' S
# Train on the GPU if one is available
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# Freeze layers by setting requires_grad to False, so they are not updated
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
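A quick way to convince yourself the freeze works, sketched here on a toy model (the tiny `nn.Sequential` backbone is a stand-in assumption, since instantiating the real resnet152 is heavy): freeze everything, attach a fresh head, and list what is still trainable.

```python
from torch import nn

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Toy stand-in: a "backbone" whose features we reuse...
backbone = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
set_parameter_requires_grad(backbone, True)   # freeze it

# ...plus a new classification head, trainable by default
model = nn.Sequential(backbone, nn.Linear(8, 3))

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only the new head's weight and bias remain
```

This is exactly the state the resnet code below ends up in: everything frozen except the freshly created fc layer.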
# Print the model architecture to see how it is assembled step by step
# These layers mainly act as the feature extractor

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

... (many layers omitted here; the first and last blocks above and below are the two levels of the architecture that matter) ...

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The final layer is a 1000-way classifier: 2048 input features mapped to 1000 classes.
We need to adapt this to our own task, changing the 1000-way output to 102.

6. Initialize the model architecture

The steps are as follows:

Take the trained model over, with pretrained = True to obtain its learned weights
Optionally freeze certain layers: frozen layers have their gradient updates turned off (requires_grad = False)
Whether it is a classification or a regression task, replace the final FC layer with one sized for our outputs

Official documentation:
https://pytorch.org/vision/stable/models.html
# Load someone else's pretrained model
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    # Pick the right model; different models are initialized differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained = use_pretrained)
        # 2. Optionally freeze the feature-extraction layers and train only the FC layer
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer with a 102-way output
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
                                    nn.LogSoftmax(dim = 1))  # dim = 1 runs the softmax over each row, so every row sums to 1
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last classifier layer (index [6])
        num_frts = model_ft.classifier[6].in_features  # FC layer input size
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained = use_pretrained)  # fixed typo: was models.desenet121
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained = use_pretrained)  # fixed typo: was models.inception_V
        set_parameter_requires_grad(model_ft, feature_extract)

        # Inception also has an auxiliary classifier; resize both heads
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
7. Set the parameters to train
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU
model_ft = model_ft.to(device)

# Checkpoint filename; the trained model saved here can be loaded directly later
filename = 'checkpoint.pth'

# Whether to train all layers
params_to_update = model_ft.parameters()

# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction
7.1 Optimizer setup
# Optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: every 7 epochs the learning rate is multiplied by 0.1
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# The model's last layer applies LogSoftmax(), so nn.CrossEntropyLoss() must not
# be used here (it would apply the softmax a second time)
criterion = nn.NLLLoss()
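Why NLLLoss here: applying NLLLoss to LogSoftmax output computes exactly what nn.CrossEntropyLoss computes on raw logits. A minimal NumPy sketch, with made-up logits (not produced by the model), illustrating the equivalence:

```python
import numpy as np

def log_softmax(x):
    # subtract the row max for numerical stability
    shifted = x - x.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def nll_loss(log_probs, targets):
    # mean negative log-probability assigned to the true class
    return -log_probs[np.arange(len(targets)), targets].mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 1.2,  0.3]])   # hypothetical raw scores
targets = np.array([0, 1])

loss = nll_loss(log_softmax(logits), targets)
# same value as cross-entropy computed directly on the raw logits
cross_entropy = -np.log(np.exp(logits[np.arange(2), targets])
                        / np.exp(logits).sum(axis=1)).mean()
```

This is why the two must stay paired: feeding raw logits into NLLLoss, or LogSoftmax output into CrossEntropyLoss, silently applies the softmax zero or two times.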
# Define the training function
# is_inception: whether the model is Inception (which has an auxiliary output)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Track the best validation accuracy
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on GPU or CPU
    model.to(device)

    # Bookkeeping for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    # Keep a copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all the data
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Only compute and update gradients during training
                with torch.set_grad_enabled(phase == 'train'):
                    # The if-branch only applies to Inception; ResNet takes the else-branch
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the class with the highest score
                    _, preds = torch.max(outputs, 1)

                    # Update weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate loss and correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Save the best model so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learnable weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
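One detail worth noting in train_model: loss.item() is the mean over one batch, so it is re-weighted by inputs.size(0) before being accumulated; dividing by the dataset size then yields a true per-sample average even when the last batch is smaller. A pure-Python sketch with made-up batch losses:

```python
# hypothetical per-batch mean losses and batch sizes,
# mimicking the running_loss bookkeeping in train_model
batch_losses = [0.9, 0.7, 0.4]   # loss.item() per batch
batch_sizes  = [32, 32, 8]       # inputs.size(0) per batch

running_loss = 0.0
for loss, n in zip(batch_losses, batch_sizes):
    running_loss += loss * n     # re-weight the batch mean by its size

epoch_loss = running_loss / sum(batch_sizes)

# a naive mean of batch means would overweight the small final batch
naive_mean = sum(batch_losses) / len(batch_losses)
```

The two values differ whenever the final batch is smaller than the rest, which is the common case when the dataset size is not divisible by the batch size.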
7.2 Start training the model
I only trained for 5 epochs here (training takes a very long time); increase the epoch count when you experiment yourself.

# If training is too slow, lower the epoch count; around 50 iterations may work better.
# Watch whether the training loss falls and accuracy rises, and how far validation
# lags behind training: a large gap indicates overfitting.
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
0 r8 S& ^" Z: c" f7.3 训练所有层
! i# y# A# l9 t: i" ~# 将全部网络解锁进行训练
- ^% h9 C9 r& H! J/ V, ffor param in model_ft.parameters():
5 @0 ~$ L! N# Z R2 y3 b: m param.requires_grad = True
( J6 l" P1 v# G% W+ s$ Y
8 D/ t' Q3 O) H$ h' J3 Z; \) H# 再继续训练所有的参数,学习率调小一点\
' B( ^9 m. K: Voptimizer = optim.Adam(params_to_update, lr = 1e-4); J; Y/ E3 P# C1 t# r5 H: f: z
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
2 r. s5 d- ]) }' D3 }. @. e: W" f: h& O. S
# 损失函数
% ~4 W+ c% m F$ W, v/ ~/ D% Ccriterion = nn.NLLLoss()
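StepLR multiplies the learning rate by gamma every step_size epochs, so the schedule has a closed form. A quick sketch of the schedule configured here (base lr 1e-4, step_size=7, gamma=0.1):

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    # learning rate in effect at a given epoch under StepLR
    return base_lr * gamma ** (epoch // step_size)

# lr stays at 1e-4 through epoch 6, drops to 1e-5 at epoch 7,
# and to 1e-6 at epoch 14
schedule = [step_lr(1e-4, e) for e in (0, 6, 7, 13, 14)]
```

Within the short fine-tuning run below the rate never decays; the schedule only matters for longer runs.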
# Load the saved checkpoint and continue training
# from the best model saved in the previous stage
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
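The checkpoint is just a Python dict of state; torch.save and torch.load serialize it (via pickle under the hood). A sketch of the round trip using plain pickle and a hypothetical stand-in state dict:

```python
import io
import pickle

# hypothetical stand-in for model/optimizer state
state = {
    "state_dict": {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]},
    "best_acc": 0.66,
}

buf = io.BytesIO()
pickle.dump(state, buf)            # torch.save analogue
buf.seek(0)
checkpoint = pickle.load(buf)      # torch.load analogue
best_acc = checkpoint["best_acc"]
```

Because the saved dict also holds the optimizer state, resuming restores Adam's moment estimates rather than starting them from zero.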
Start training
Note: training gets noticeably slower in this stage. My GPU is a GTX 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Load the trained model
This amounts to a simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)

# Checkpoint filename
filename = 'checkpoint.pth'

# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read a test image
    img = Image.open(image_path)
    # Resize: thumbnail() only shrinks proportionally, hence the aspect check.
    # Unlike resize(), which sets an exact target size, thumbnail() scales
    # the image in place (modifying it and returning None)
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Center-crop the image to 224 x 224
    left_margin = (img.width - 224) / 2
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Apply the same preprocessing as in training:
    # scale to [0, 1], then normalize with the ImageNet mean and std
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the front (HWC -> CHW)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image tensor."""
    if ax is None:
        fig, ax = plt.subplots()

    # Restore the channel order (CHW -> HWC)
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # reusable for any number of test images
imshow(img)
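process_image normalizes with the ImageNet statistics and moves the channels first; imshow inverts both steps. A quick NumPy check of that round trip on a random stand-in array (no real image needed):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

img = np.random.rand(224, 224, 3)                   # fake HWC image in [0, 1]
normed = ((img - mean) / std).transpose((2, 0, 1))  # normalize, then HWC -> CHW

# invert as imshow does: CHW -> HWC, then undo the normalization
restored = normed.transpose((1, 2, 0)) * std + mean
```

The per-channel mean/std broadcast over the last axis, which is why normalization happens before the transpose and un-normalization after transposing back.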
<AxesSubplot:>
1
, g, ]1 I9 l% J' E7 B& p8 B
The above preprocesses a test-set image; checking the shape verifies that the preprocessing function works:

img.shape
(3, 224, 224)

The channel dimension has moved to the front and the spatial size is 224 x 224, as expected.
9. Inference
# Get one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

if train_on_gpu:
    # A single forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch has 8 samples; each gets 102 values, one score per class
output.shape
torch.Size([8, 102])
9.1 Get the most probable class
_, preds_tensor = torch.max(output, 1)

# Convert the result to a 1-D NumPy array
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
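torch.max(output, 1) returns both the maximum values and their indices along the class dimension; only the indices (the predicted classes) are kept. The NumPy equivalent on hypothetical scores:

```python
import numpy as np

# hypothetical scores: 2 samples x 4 classes
output = np.array([[0.1, 2.0, 0.3, -1.0],
                   [1.5, 0.2, 0.9,  0.4]])
preds = output.argmax(axis=1)   # index of the highest score per row
# preds -> array([1, 0])
```

Because the model's last layer is LogSoftmax (monotonic in the logits), taking the argmax of the output picks the same class as taking the argmax of the probabilities.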
9.2 Display the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions; red titles mark wrong ones
————————————————
Copyright notice: this article is an original post by CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940