[Deep Learning] Image Recognition in Practice: A 102-Category Flower Classification (Flowers 102) Case Study

Contents
Convolutional networks in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import the required packages
2. Data preprocessing and operations
3. Build the data source
Read the actual name for each label
4. Display some of the data
5. Load a model provided by models and initialize it with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer setup
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the highest-probability class
9.2 Display the prediction results
Closing remarks

Convolutional networks in practice: classifying flowers
This article tackles a classification task on the University of Oxford flower dataset. It builds a general-purpose neural network pipeline (implemented mainly with ResNet) and combines common PyTorch operations: preprocessing, training, model saving, and model loading.

The folder contains 102 kinds of flowers, and the task is to classify them.
Folder structure

flower_data
  train
    1 (category)
    2
      xxx.png / xxx.jpg
  valid

The work is divided into the following main modules.

Data preprocessing
  Data augmentation
  Data preprocessing
Network module setup
  Load a pretrained model by calling a classic network architecture from torchvision directly
  Because the original training task may have been, for example, 1000-class classification (and its classes need not match ours), the model must be adapted to our own task
Saving and testing the network model
  Models can be saved selectively
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder (the code below expects flower_data) and place it in the same root directory as the code.

(The original post showed a screenshot of the author's data root directory here.)

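Before building the datasets it can help to confirm that the download really has the layout described above. The sketch below is only an illustration: the helper names count_images_per_class and build_demo_tree are hypothetical, and it runs against a tiny temporary stand-in for flower_data/train rather than the real data.

```python
import os
import tempfile


def count_images_per_class(split_dir):
    """Return {class_folder_name: number_of_files} for one split directory."""
    counts = {}
    for entry in sorted(os.scandir(split_dir), key=lambda e: e.name):
        if entry.is_dir():
            counts[entry.name] = sum(1 for f in os.scandir(entry.path) if f.is_file())
    return counts


def build_demo_tree(root):
    """Create a tiny stand-in for flower_data/train with two class folders."""
    for class_id, n_files in (("1", 2), ("2", 3)):
        class_dir = os.path.join(root, "train", class_id)
        os.makedirs(class_dir)
        for i in range(n_files):
            open(os.path.join(class_dir, f"img_{i}.jpg"), "w").close()


with tempfile.TemporaryDirectory() as demo_root:
    build_demo_tree(demo_root)
    demo_counts = count_images_per_class(os.path.join(demo_root, "train"))
    print(demo_counts)  # {'1': 2, '2': 3}
```

Pointing count_images_per_class at the real ./flower_data/train folder should report 102 class folders if the data is laid out as expected.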
1. Import the required packages
import os
import matplotlib.pyplot as plt
# Jupyter magic: render plots inline in the notebook, so plt.show() is optional
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2. Data preprocessing and operations
# Path settings
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = os.path.join(data_dir, 'train')  # avoids the doubled slash produced by data_dir + '/train'
valid_dir = os.path.join(data_dir, 'valid')

See also: "Combinations and differences of dots and slashes in Python paths"
Note: that article explains how the dot and slash path notations work.

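When a base directory already ends in a slash, string concatenation and a path join treat the separator differently. The short sketch below shows the difference; it uses posixpath only so that the printed results are the same on every operating system (os.path behaves the same way on Linux and macOS).

```python
import posixpath

data_dir = "./flower_data/"

# Plain concatenation keeps both slashes.
concatenated = data_dir + "/train"

# join() adds a separator only when one is missing.
joined = posixpath.join(data_dir, "train")

# normpath() collapses repeated slashes and drops the leading "./".
normalized = posixpath.normpath(concatenated)

print(concatenated)  # ./flower_data//train
print(joined)        # ./flower_data/train
print(normalized)    # flower_data/train
```

A doubled slash is normally harmless when opening files, but joined paths are easier to read in logs and printed dataset summaries.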
3. Build the data source
data_transforms specifies all the image preprocessing operations.
ImageFolder assumes the files are organized into folders, with each folder holding the images of one class.

data_transforms = {
    # Two parts; the first is for training
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop a 224 x 224 region from the center
        # Each flip is applied with probability 0.5 (a 50/50 chance)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # Arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # Convert to grayscale with probability 0.025; the result still has
        # three channels, but R, G and B hold identical values
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # per-channel mean, standard deviation
    ]),
    # Resize so the shorter side is 256, take the central 224 x 224 crop,
    # convert to a tensor, and finally normalize (standardize)
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # same mean and standard deviation as the training set
    ]),
}

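The validation pipeline resizes the shorter side to 256 and then takes the central 224 x 224 window. The sketch below works through that arithmetic in plain Python for a hypothetical 750 x 500 image. The helper names are made up for illustration, and the rounding follows my understanding of torchvision's behavior, so treat it as an approximation rather than the library's exact code.

```python
def resize_shorter_side(width, height, size):
    """Scale (width, height) so the shorter side equals size, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# Hypothetical 750 x 500 input image (width x height).
resized = resize_shorter_side(750, 500, 256)
box = center_crop_box(*resized, 224)
print(resized)  # (384, 256)
print(box)      # (80, 16, 304, 240)
```

The crop discards 16 pixels at the top and bottom and 80 pixels at each side, so flowers near the edge of a wide photo can be cut off.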
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
# Inspect the datasets
image_datasets

Output:
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

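The printed transform shows ColorJitter(brightness=0.2, ...) as the range [0.8, 1.2] and hue=0.1 as [-0.1, 0.1]. The sketch below reproduces that conversion in plain Python: a scale factor is sampled around 1, while hue is sampled around 0. The function names are hypothetical, not part of torchvision.

```python
def factor_range(value):
    """Range for brightness, contrast, or saturation: centered on 1, never below 0."""
    return [round(max(0.0, 1.0 - value), 3), round(1.0 + value, 3)]


def hue_range(value):
    """Range for hue: a shift centered on 0."""
    return [-value, value]


ranges = {
    "brightness": factor_range(0.2),
    "contrast": factor_range(0.1),
    "saturation": factor_range(0.1),
    "hue": hue_range(0.1),
}
for name, bounds in ranges.items():
    print(name, bounds)
# brightness [0.8, 1.2]
# contrast [0.9, 1.1]
# saturation [0.9, 1.1]
# hue [-0.1, 0.1]
```

A factor of 1 (or a hue shift of 0) leaves the image unchanged, so these settings apply only a mild color perturbation.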
# Check that the data loaders were created (the transforms are applied lazily, when batches are drawn)
dataloaders

Output:
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

Output:
{'train': 6552, 'valid': 818}

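With 6552 training images, 818 validation images, and a batch size of 8, the number of batches per epoch follows directly. This is a small arithmetic sketch that assumes the DataLoader keeps the final partial batch (drop_last=False, which is the default).

```python
import math

dataset_sizes = {"train": 6552, "valid": 818}
batch_size = 8

# Number of batches each loader yields per epoch.
batches_per_epoch = {
    split: math.ceil(n / batch_size) for split, n in dataset_sizes.items()
}

# Size of the final batch: a full batch when n divides evenly, else the remainder.
last_batch_size = {
    split: (n % batch_size) or batch_size for split, n in dataset_sizes.items()
}

print(batches_per_epoch)  # {'train': 819, 'valid': 103}
print(last_batch_size)    # {'train': 8, 'valid': 2}
```

The validation loader therefore ends with a batch of only 2 images, which matters for any code that indexes a fixed number of items per batch.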
Read the actual name for each label
Use the JSON file in the same directory to map each category ID back to the flower's name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name

Output:
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}

4. Display some of the data
def im_convert(tensor):
    """Convert a normalized image tensor into an array that can be displayed."""
    image = tensor.to("cpu").clone().detach()
    # squeeze() drops any dimensions of size 1 (for example a leading batch dimension)
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is stored as (c, h, w), and plotting needs (h, w, c)
    image = image.transpose(1, 2, 0)
    # Undo the normalization (de-standardize): multiply by the std, then add the mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # Clip values below 0 to 0 and values above 1 to 1
    image = image.clip(0, 1)

    return image

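The de-normalization step in im_convert is the exact inverse of transforms.Normalize: normalization computes (x - mean) / std per channel, and the display code computes x * std + mean. The plain-Python sketch below checks that round trip on one hypothetical RGB pixel; the helper names are illustrative only.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(pixel):
    """Per-channel standardization, as transforms.Normalize does."""
    return tuple((v - m) / s for v, m, s in zip(pixel, MEAN, STD))


def denormalize(pixel):
    """Inverse operation used before plotting."""
    return tuple(v * s + m for v, m, s in zip(pixel, MEAN, STD))


def clip01(pixel):
    """Clamp every channel into the displayable range [0, 1]."""
    return tuple(min(1.0, max(0.0, v)) for v in pixel)


original = (0.2, 0.5, 0.9)  # hypothetical pixel with channels scaled to [0, 1]
restored = clip01(denormalize(normalize(original)))
print(tuple(round(v, 6) for v in restored))  # (0.2, 0.5, 0.9)
```

The clipping only guards against small floating-point overshoot; for a pixel that started inside [0, 1] it changes nothing of substance.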
# Plot using the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2
# iter() creates an iterator over the loader
# Take one batch of data to display
dataiter = iter(dataloaders['valid'])
# next(dataiter) works across PyTorch versions; dataiter.next() fails on recent releases
inputs, classes = next(dataiter)

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to show the flower name as the title of each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

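The title lookup goes through class_names because the label index produced by ImageFolder is not the folder number. ImageFolder sorts the folder names as strings, so index 1 is folder '10', not folder '2'. The sketch below imitates that ordering in plain Python; the three dictionary entries are copied from the cat_to_name output above, and the rest is illustrative.

```python
# Folder names '1' .. '102', sorted the way strings sort (lexicographically).
folder_names = [str(i) for i in range(1, 103)]
class_names = sorted(folder_names)
print(class_names[:6])  # ['1', '10', '100', '101', '102', '11']

# A three-entry subset of the cat_to_name mapping shown earlier.
cat_to_name = {"1": "pink primrose", "10": "globe thistle", "100": "blanket flower"}


def index_to_name(label_index):
    """Map a dataset label index to a flower name via the folder name."""
    folder = class_names[label_index]
    return cat_to_name[folder]


names = [index_to_name(i) for i in range(3)]
print(names)  # ['pink primrose', 'globe thistle', 'blanket flower']
```

Looking up cat_to_name with the raw label index would silently show the wrong flower for most classes, which is why the extra class_names step is needed.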
5. Load a model provided by models and initialize it with pretrained weights
model_name = 'resnet'  # several models are available: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is the main choice for image recognition here
# Whether to use the pretrained features as they are (freeze the feature extractor)
feature_extract = True

# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

Output:
CUDA is not available. Training on CPU ...

# Set requires_grad to False on the existing parameters so they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Print the model architecture to see how it is built up step by step
# Its main job is to extract features for us

model_ft = models.resnet152()
model_ft

Output:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    [Much of the output in between is omitted; we only need to look at these two parts of the architecture.]
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

The final layer is a 1000-class classifier: it maps 2048 input features to 1000 classes.
For our task we need to adapt it, changing the 1000-class output to 102 outputs.

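Swapping the 1000-way head for a 102-way head also shrinks the only part of the network that is trained when the feature extractor is frozen. The sketch below counts the weights and biases of a fully connected layer in plain Python; linear_param_count is a hypothetical helper, not a PyTorch function.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Weights (in x out) plus one bias per output unit."""
    return in_features * out_features + (out_features if bias else 0)


original_head = linear_param_count(2048, 1000)
flower_head = linear_param_count(2048, 102)

print(original_head)  # 2049000
print(flower_head)    # 208998
```

With the backbone frozen, only these roughly 209 thousand parameters receive gradient updates, which is why training just the head is feasible even on a CPU.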
6. Initialize the model architecture
The steps are as follows:

Take the pretrained model and set pretrained=True to obtain someone else's weights.
Optionally freeze some layers; the layers to freeze are specified by setting their gradient updates (requires_grad) to False.
Whether the task is classification or regression, replace the final FC layer so that its output size matches the task.
Official documentation:
https://pytorch.org/vision/stable/models.html

# Load someone else's model
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Choose the model; each architecture needs its own initialization.
    # Note: torchvision 0.13 and later prefer the weights= argument; pretrained= is kept to match the original post.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet-152
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature extractor so that only the FC layer is trained
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer; the output size is num_classes (102 here)
        #    dim=1 applies log-softmax across the classes of each sample (row),
        #    so the exponentiated outputs of a row sum to 1
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim=1))
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last classifier layer, which has index 6
        num_frts = model_ft.classifier[6].in_features  # input size of the FC layer
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0: the classifier is a 1 x 1 convolution rather than a linear layer
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3: expects 299 x 299 inputs and has an auxiliary classifier
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the auxiliary classifier head
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        # Replace the main classifier head
        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

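The ResNet branch ends in LogSoftmax over dim=1, meaning the normalization runs across the class scores of each sample. The sketch below reproduces that computation for a single row of hypothetical logits in plain Python, to show that exponentiating the result gives probabilities that sum to 1.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of class scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


logits = [2.0, 1.0, 0.1]  # hypothetical scores for three classes
log_probs = log_softmax(logits)
probs = [math.exp(v) for v in log_probs]

print(round(sum(probs), 6))      # 1.0
print(probs.index(max(probs)))   # 0
```

Because the model already outputs log-probabilities, it is normally paired with nn.NLLLoss; nn.CrossEntropyLoss applies log-softmax internally and would apply it a second time.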
7. Setting the parameters to train
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU (or CPU) selected earlier
model_ft = model_ft.to(device)

# Checkpoint file: stores the trained model so it can be reloaded later
filename = 'checkpoint.pth'

# Start from all parameters; narrow the list down when only extracting features
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
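
The output lists only the two tensors of the classifier head, which is what feature extraction should produce. As a minimal sketch of the same mechanism, the snippet below freezes a small stand-in model and collects the parameters that still require gradients. The model here is hypothetical (a single "backbone" layer plus an `fc` head), not the article's ResNet.

```python
import torch.nn as nn

# Hypothetical stand-in: one "backbone" layer plus a classifier head named fc
model = nn.Sequential()
model.add_module("backbone", nn.Linear(16, 8))
model.add_module("fc", nn.Sequential(nn.Linear(8, 102), nn.LogSoftmax(dim=1)))

# Freeze everything, then re-enable gradients for the head only
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

# Collect what an optimizer should receive
trainable_names = [name for name, p in model.named_parameters() if p.requires_grad]
params_to_update = [p for p in model.parameters() if p.requires_grad]

print(trainable_names)
```

Passing only `params_to_update` to the optimizer keeps the frozen layers out of the update step entirely.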
7. Training and prediction
7.1 Optimizer settings
# Optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: multiply the rate by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The last layer is LogSoftmax(), so nn.NLLLoss() is the matching loss
# (nn.CrossEntropyLoss() expects raw logits instead)
criterion = nn.NLLLoss()
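
To make the relationship between the two losses concrete, here is a small check on random data (the tensors are made up for illustration). It shows that `nn.NLLLoss` applied to `LogSoftmax` output gives the same value as `nn.CrossEntropyLoss` applied to the raw logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 102)            # raw scores for a batch of 8 samples
labels = torch.randint(0, 102, (8,))    # integer class indices

log_probs = nn.LogSoftmax(dim=1)(logits)

loss_nll = nn.NLLLoss()(log_probs, labels)        # expects log-probabilities
loss_ce = nn.CrossEntropyLoss()(logits, labels)   # expects raw logits

print(torch.allclose(loss_nll, loss_ce, atol=1e-6))
```

Which pairing to use therefore depends only on whether the network already ends in a log-softmax layer.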
# Define the training function
# is_inception: whether the model is Inception, which also returns an auxiliary output during training
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy so far
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or CPU selected earlier
    model.to(device)
    # Histories kept for plotting
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the whole dataset
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the same device as the model
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Reset the gradients
                optimizer.zero_grad()
                # Track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    # Inception branch; not used here and can be ignored
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the index of the highest-scoring class
                    _, preds = torch.max(outputs, 1)

                    # Update the weights in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Report the results for this phase
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the model with the best validation accuracy
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learnable weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # StepLR counts epochs itself; an argument passed here would be read as an epoch index
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, load the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
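
The checkpoint written inside the loop is a plain dictionary, so it can be read back with `torch.load`. The following sketch, which uses a hypothetical one-layer model, a placeholder accuracy value, and a temporary directory in place of the article's network and `checkpoint.pth`, walks through one save and restore cycle.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)                        # hypothetical tiny model
optimizer = optim.Adam(model.parameters(), lr=1e-2)

# Save the same three entries that the training loop writes
state = {
    'state_dict': model.state_dict(),
    'best_acc': 0.66,                          # placeholder value for this sketch
    'optimizer': optimizer.state_dict(),
}
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
torch.save(state, path)

# Restore into a freshly constructed model
restored_model = nn.Linear(4, 2)
checkpoint = torch.load(path)
restored_model.load_state_dict(checkpoint['state_dict'])

weights_match = torch.equal(model.weight, restored_model.weight)
print(weights_match, checkpoint['best_acc'])
```

Storing `state_dict()` rather than the model object keeps the file independent of the class definition, at the cost of having to rebuild the architecture before loading.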
7.2 Start training the model
I only trained for 5 epochs here (training really takes a long time). Feel free to raise the number of epochs when you try this yourself.

# If training is too slow, lower num_epochs; around 50 epochs would probably give better results
# While training, check whether the loss falls and the accuracy rises. A large gap between training and validation results indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))

Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
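
The learning rates in this log jump between 1e-2, 1e-3 and 1e-4 instead of staying at 1e-2, which is what a 7-epoch step schedule would give over 5 epochs. One explanation that fits the numbers, assuming a PyTorch version in which `StepLR.step()` still accepts an optional epoch argument, is that the run passed the validation loss to `scheduler.step()`, so the scheduler read the loss as an epoch index. The plain-Python sketch below applies the closed-form StepLR rule to the logged losses; `step_lr` is a helper name introduced here for illustration.

```python
# Validation losses and learning rates copied from the log above
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]
logged_lrs = ['0.0010000', '0.0100000', '0.0001000', '0.0100000', '0.0010000']

base_lr, gamma, step_size = 1e-2, 0.1, 7


def step_lr(epoch_index):
    """Closed-form StepLR value: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch_index // step_size)


# Treat each validation loss as if it were the epoch index
computed = ['{:.7f}'.format(step_lr(loss)) for loss in valid_losses]
print(computed == logged_lrs)

# Schedule when the scheduler counts the epochs itself
expected = ['{:.7f}'.format(step_lr(epoch)) for epoch in range(1, 6)]
print(expected)
```

Under this reading, a high validation loss lowers the learning rate and a low one restores it, which would also account for the unstable accuracy from epoch to epoch.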
7.3 Training all layers
# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved weights and continue training from that model
# filename points to the best checkpoint from the run above
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# The saved optimizer state covers only the fc parameters trained above,
# so it is not loaded into this optimizer, which now holds every parameter
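
A checkpoint written while only the classifier head was trainable stores optimizer state for those parameters alone. The sketch below uses a hypothetical two-layer model rather than the article's ResNet. It shows that such a state loads into an optimizer with the same parameter layout (and brings the saved learning rate with it), but is rejected by an optimizer built over all parameters.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in for a backbone plus a classifier head
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
head = model[2]

# Optimizer state as it would be saved during feature extraction (head only)
head_optimizer = optim.Adam(head.parameters(), lr=1e-2)
saved_state = head_optimizer.state_dict()

# Same parameter layout: loading succeeds and overrides the learning rate
same_layout = optim.Adam(head.parameters(), lr=1e-4)
same_layout.load_state_dict(saved_state)
restored_lr = same_layout.param_groups[0]['lr']

# Optimizer over every parameter: the group sizes differ, so loading fails
full_optimizer = optim.Adam(model.parameters(), lr=1e-4)
try:
    full_optimizer.load_state_dict(saved_state)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True

print(restored_lr, mismatch_detected)
```

When fine-tuning the whole network, it is therefore simplest to reload only the model weights and let the new optimizer start with fresh state and its own learning rate.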
Start training
Note: training becomes noticeably slower here. My GPU is a GTX 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))

Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Loading the trained model
This amounts to a simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU

# Name of the checkpoint file
filename = 'checkpoint.pth'

# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read a test image
    img = Image.open(image_path)
    # thumbnail() can only shrink an image and keeps its aspect ratio, so choose the bound by orientation
    # Unlike resize(), whose size argument fixes the exact output size,
    # thumbnail() modifies the image in place and returns None
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Centre-crop the image to 224 x 224
    left_margin = (img.width - 224) / 2    # left edge of the central region
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224       # add the 224-pixel crop width to get the right edge
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Apply the same preprocessing as in training
    # Normalisation
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the colour channel to the first axis
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Move the colour channel back to the last axis
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalisation
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # the function can be called repeatedly to process more images
imshow(img)
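
The difference between `thumbnail()` and `resize()` that the comments describe can be checked on a synthetic image. This is a minimal sketch with Pillow; the 1000 x 512 and 100 x 80 sizes are made up for illustration.

```python
from PIL import Image

# Synthetic 1000 x 512 landscape image (width > height)
img = Image.new('RGB', (1000, 512))

# resize() returns a new image with exactly the requested size (aspect ratio not kept)
squashed = img.resize((256, 256))

# thumbnail() shrinks in place, keeps the aspect ratio, and returns None
result = img.thumbnail((10000, 256))
thumb_size = img.size

# thumbnail() never enlarges: a small image is left unchanged
small = Image.new('RGB', (100, 80))
small.thumbnail((10000, 256))

# Centre-crop box (left, upper, right, lower) for a 224 x 224 window
left = (img.width - 224) / 2
upper = (img.height - 224) / 2
cropped = img.crop((left, upper, left + 224, upper + 224))

print(squashed.size, result, thumb_size, small.size, cropped.size)
```

Because `thumbnail()` never enlarges, an image whose shorter side is below 224 pixels reaches the crop step at its original size, which is worth keeping in mind for small inputs.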
<AxesSubplot:>
The image above is the result of preprocessing a test image. We check its shape to confirm that the preprocessing function works correctly.

img.shape

(3, 224, 224)

This confirms that the channel axis has moved to the front and that the image is still 224 x 224.
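
As a small NumPy sketch of what the shape check relies on, the snippet below normalises a random array standing in for a cropped image, moves the channel axis to the front, and then inverts both steps the way the display function does. The random input is an assumption made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical 224 x 224 RGB image with values already scaled to [0, 1]
image_hwc = rng.random((224, 224, 3))

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Forward: normalise per channel (broadcast over the last axis), then HWC -> CHW
normalised = (image_hwc - mean) / std
image_chw = normalised.transpose((2, 0, 1))

# Inverse, as done for display: CHW -> HWC, then undo the normalisation
restored = std * image_chw.transpose((1, 2, 0)) + mean

print(image_chw.shape, np.allclose(restored, image_hwc))
```

The normalisation has to happen while the channels are on the last axis, because that is the axis the three-element `mean` and `std` arrays broadcast against.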
9. Inference
img.shape

# Fetch one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

# Inference only, so no gradients are needed
with torch.no_grad():
    if train_on_gpu:
        # One forward pass produces the output
        output = model_ft(images.cuda())
    else:
        output = model_ft(images)

# The batch holds 8 samples; each has 102 output values, one log-probability per class
output.shape
torch.Size([8, 102])
9.1 Finding the most probable class
_, preds_tensor = torch.max(output, 1)

# Move the predictions to the CPU if needed; squeeze drops any size-1 dimension, leaving a 1-D array of class indices
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
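
Since the network ends in LogSoftmax, the values in `output` are log-probabilities. The sketch below uses random data in place of the real model output to show what `torch.max(output, 1)` returns and how `exp()` recovers probabilities without changing the predicted class.

```python
import torch

torch.manual_seed(0)
# Stand-in for the model output: log-probabilities for 8 samples and 102 classes
output = torch.log_softmax(torch.randn(8, 102), dim=1)

# torch.max over dim 1 returns (largest value, its index) for every row
max_log_probs, preds_tensor = torch.max(output, 1)

# exp() turns log-probabilities back into probabilities that sum to 1 per row
probs = torch.exp(output)
row_sums = probs.sum(dim=1)

# The index of the largest log-probability is also the index of the largest probability
same_argmax = torch.equal(preds_tensor, probs.argmax(dim=1))

print(preds_tensor.shape, same_argmax)
```

Calling `torch.exp` is only needed when an actual confidence value should be reported; for picking the class, the log-probabilities are enough.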
9.2 Displaying the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # Title format: predicted name (true name)
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions, red titles mark wrong ones
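
One point worth checking when turning predictions into names: torchvision's `ImageFolder` assigns class indices by sorting the folder names as strings, so with folders named '1' to '102' the predicted index is not the folder number. The plain-Python sketch below reproduces that ordering. `idx_to_class` is a helper name introduced here, and whether `cat_to_name` is keyed by folder name is an assumption about the earlier part of the article.

```python
# Folder names as in flower_data/train: '1', '2', ..., '102'
folder_names = [str(i) for i in range(1, 103)]

# ImageFolder sorts the names as strings before numbering them
classes = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Inverse lookup (hypothetical helper name): predicted index -> folder name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print(classes[:6])
print(idx_to_class[1], class_to_idx['2'])
```

If that assumption holds, looking up `cat_to_name[idx_to_class[pred]]` instead of `cat_to_name[str(pred)]` would give the intended flower name; the green and red colouring is unaffected either way, because prediction and label go through the same lookup.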
————————————————
Copyright notice: this is an original article by CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940
|
|