[Deep Learning] Image recognition in practice: a 102-class flower classification (Flower 102) case study
Contents
Convolutional networks in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the model
Data download
1. Import the toolkits
2. Data preprocessing and operations
3. Build the data source
Read the actual names behind the labels
4. Display some data
5. Load a model provided by models and initialize it with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer settings
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the maximum probability
9.2 Display the predictions
Closing remarks
Convolutional networks in practice: classifying flowers
This article works through a classification task on the Oxford flower dataset (Flower 102). It builds a reusable neural-network pipeline (implemented mainly with ResNet) and combines common PyTorch operations: preprocessing, training, saving the model, and loading the model.

The data folder contains 102 kinds of flowers, and the task is to classify them.
Folder structure:

flower_data
    train
        1 (class)
        2
            xxx.png / xxx.jpg
    valid
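Before training, it can help to confirm that the folders really follow the one-folder-per-class layout described above. The following is a minimal sketch using only the standard library; `count_images_per_class` and `IMAGE_EXTENSIONS` are hypothetical names introduced for illustration, and the tiny temporary tree stands in for the real `flower_data/train` directory.

```python
import os
import tempfile

# Assumption: images are stored as .jpg/.jpeg/.png files.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def count_images_per_class(split_dir):
    """Return {class_folder_name: number_of_image_files} for one split."""
    counts = {}
    for entry in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, entry)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files next to the class folders
        counts[entry] = sum(
            1 for name in os.listdir(class_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


# Build a throwaway tree so the sketch runs without the real dataset.
with tempfile.TemporaryDirectory() as root:
    for class_name, n_files in (("1", 2), ("2", 3)):
        class_dir = os.path.join(root, "train", class_name)
        os.makedirs(class_dir)
        for i in range(n_files):
            open(os.path.join(class_dir, f"img_{i}.jpg"), "w").close()
    demo_counts = count_images_per_class(os.path.join(root, "train"))

print(demo_counts)
```

On the real data, pointing the function at `./flower_data/train` should list 102 class folders if the download was unpacked as expected.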
The work is divided into the following main modules:

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model by calling one of torchvision's classic network architectures directly
    The pretrained model was built for someone else's task, possibly a 1000-class one (the classes need not match ours), so it has to be adapted to our own task
Saving and testing the model
    Model saving can be selective
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder and put it in the same root directory as the code.
Below is my data root directory.
[Figure: the author's data root directory]

1. Import the toolkits
import os
import matplotlib.pyplot as plt
# Render plots inline (Jupyter magic), so plt.show() is not needed in every cell
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2. Data preprocessing and operations
# Path settings
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
Related reading: "Combinations and differences of dots and slashes in Python paths"
Note: that article explains how the dot and slash notations work.

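Because `data_dir` already ends with a slash, how the sub-folder is appended matters. The sketch below compares plain string concatenation with path joining. It uses `posixpath` (the implementation behind `os.path` on Linux and macOS) only so that the printed results are the same on every operating system; that choice is an assumption made for the illustration.

```python
import posixpath

data_dir = "./flower_data/"

# Plain concatenation keeps both slashes.
concatenated = data_dir + "/train"

# join() inserts a separator only when one is missing.
joined = posixpath.join(data_dir, "train")

# Pitfall: a component that starts with "/" is treated as absolute,
# so everything before it is discarded.
restarted = posixpath.join(data_dir, "/train")

print(concatenated)  # ./flower_data//train
print(joined)        # ./flower_data/train
print(restarted)     # /train
```

The doubled slash usually still resolves to the same folder, but the joined form is easier to read in logs and error messages.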
3. Build the data source
data_transforms specifies all of the image preprocessing operations.
ImageFolder assumes the files are organized in folders, with each folder holding the images of one class.
data_transforms = {
    # Two parts; the first one is used for training
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop 224 x 224 around the center
        # Flip with a given probability (here 50/50)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # Arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),  # convert to grayscale with probability 0.025
        # The grayscale result still has three channels, but R, G and B are identical
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, standard deviation
    ]),
    # Resize so that the shorter side is 256, take the central 224 x 224 crop,
    # convert to a tensor, and finally normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
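The geometry of the 'valid' pipeline (Resize(256) followed by CenterCrop(224)) can be traced by hand. The sketch below does so in plain Python for an assumed 500 x 375 input, a made-up size chosen for illustration. `resize_shorter_side` and `center_crop_box` are hypothetical helpers, not torchvision functions, and they round down; torchvision's own rounding may differ by a pixel.

```python
def resize_shorter_side(width, height, size=256):
    """Scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


# Assumed example input: a 500 x 375 landscape photo.
new_w, new_h = resize_shorter_side(500, 375)
box = center_crop_box(new_w, new_h)

print((new_w, new_h))  # (341, 256)
print(box)             # (58, 16, 282, 240)
```

The point of the two-step design is that every validation image ends up as the same 224 x 224 window, regardless of its original size, without distorting the aspect ratio.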
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# Inspect the datasets
image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

# Check that the data has been processed
dataloaders
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
dataset_sizes
{'train': 6552, 'valid': 818}
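The dataset sizes and the batch size together determine how many batches each epoch iterates over. A small arithmetic sketch, assuming the loaders keep the last incomplete batch (which is the `DataLoader` default, `drop_last=False`):

```python
import math

dataset_sizes = {"train": 6552, "valid": 818}
batch_size = 8

# Number of batches per epoch when the final partial batch is kept.
batches_per_epoch = {
    split: math.ceil(n / batch_size) for split, n in dataset_sizes.items()
}

# Size of the final batch: a full batch if the split divides evenly.
last_batch_size = {
    split: (n % batch_size) or batch_size for split, n in dataset_sizes.items()
}

print(batches_per_epoch)  # {'train': 819, 'valid': 103}
print(last_batch_size)    # {'train': 8, 'valid': 2}
```

So the last validation batch holds only 2 images, which matters for any code that assumes every batch has exactly `batch_size` items.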
Read the actual names behind the labels
Use the JSON file in the same directory to map each class label back to the flower's name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
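One detail worth spelling out: the integer class index that the data loader yields is not the folder number. ImageFolder sorts the class folder names as strings, so index 1 is folder '10', not folder '2'. The sketch below reproduces that ordering in plain Python, assuming the class folders are named '1' through '102' as in the folder structure above; `index_to_name` is a hypothetical helper, and only four entries of `cat_to_name` are copied in to keep it short.

```python
# Folder names '1'..'102', sorted the way strings sort (lexicographically).
class_names = sorted(str(i) for i in range(1, 103))

# A few entries copied from the cat_to_name mapping shown above.
cat_to_name = {
    "1": "pink primrose",
    "10": "globe thistle",
    "100": "blanket flower",
    "2": "hard-leaved pocket orchid",
}


def index_to_name(class_index):
    """Map a class index to a flower name via the folder name."""
    folder = class_names[class_index]
    return cat_to_name[folder]


print(class_names[:6])             # ['1', '10', '100', '101', '102', '11']
print(index_to_name(0))            # pink primrose
print(index_to_name(1))            # globe thistle
print(class_names.index("2"))      # 14
```

This is why label lookups should always go through `class_names` first and only then through `cat_to_name`.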
4. Display some data
def im_convert(tensor):
    """Convert a normalized image tensor into an array that can be displayed."""
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    # squeeze() drops any size-1 dimensions, which makes plotting easier
    # transpose swaps the axes: the tensor is stored as (c, h, w) and must be restored to (h, w, c)
    image = image.transpose(1, 2, 0)
    # Undo the normalization (de-standardize)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # Clip the image: values below 0 become 0, values above 1 become 1
    image = image.clip(0, 1)

    return image
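The de-standardization step in `im_convert` is simply the inverse of `transforms.Normalize`: normalizing computes `(x - mean) / std` per channel, so multiplying by `std` and adding `mean` recovers the original value. A minimal per-pixel sketch in plain Python; `normalize`, `denormalize` and the sample pixel are illustrative assumptions, not part of the original code:

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize(pixel):
    """Per-channel (x - mean) / std, as Normalize does for each pixel."""
    return tuple((v - m) / s for v, m, s in zip(pixel, MEAN, STD))


def denormalize(pixel):
    """Inverse step: x * std + mean, then clip into [0, 1] for display."""
    return tuple(
        min(1.0, max(0.0, v * s + m)) for v, m, s in zip(pixel, MEAN, STD)
    )


original = (0.2, 0.5, 0.9)  # an arbitrary RGB pixel in [0, 1]
restored = denormalize(normalize(original))

# Equal up to floating-point rounding.
print(all(abs(a - b) < 1e-9 for a, b in zip(original, restored)))
```

The final clip only guards against small numerical overshoot; for pixels that started inside [0, 1] it changes nothing beyond rounding error.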
# Plot with the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# iter() creates an iterator over the loader
# Take an arbitrary batch for display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # dataiter.next() no longer works in recent PyTorch versions

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to print the flower name as the title of each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

5. Load a model provided by models and initialize it with pretrained weights
model_name = 'resnet'  # many options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# Image recognition here is done mainly with resnet
# Whether to use the features someone else has already trained
feature_extract = True
# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# Set requires_grad to False for the layers so that they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
# Print the model architecture to see how it is built up step by step
# Its main job is to extract features for us

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
... (much of the output is omitted here; we only need to focus on the first and last levels of the architecture) ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

The last layer is a 1000-way classifier: it takes 2048 input features and maps them to 1000 classes.
We need to adapt it to our own task by changing the 1000-class output to 102 outputs.

6. Initialize the model architecture
The steps are as follows:

Take the already trained model and set pretrained = True to obtain the weights someone else has trained
Optionally freeze some layers; the layers to be frozen are specified by setting their gradient update (requires_grad) to False
Whether the task is classification or regression, replace the final FC layer with one sized for the task
Official documentation link:
https://pytorch.org/vision/stable/models.html

# Load a model that someone else has trained
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Pick the requested model; each architecture needs its head replaced differently
    # Note: torchvision 0.13 and later prefer the weights= argument over pretrained=
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet152
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature-extraction layers and train only the FC layer
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer so that it outputs num_classes values (102 here)
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim=1))  # dim=1: per sample, across classes; exp of each row sums to 1
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last layer of the classifier, the one at index 6
        num_frts = model_ft.classifier[6].in_features  # input features of the FC layer
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception V3
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the auxiliary classifier head
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        # Replace the main classifier head
        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        raise ValueError("Invalid model name: " + str(model_name))

    return model_ft, input_size

7. Setting the parameters to train
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Run on the GPU when one is available
model_ft = model_ft.to(device)

# Checkpoint file: stores the trained model so it can be loaded directly later
filename = 'checkpoint.pth'

# Train all layers by default
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    # Feature extraction: keep only the parameters that still require gradients
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction
7.1 Optimizer settings
# Optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: multiply the learning rate by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The last layer is LogSoftmax(), so use nn.NLLLoss() rather than nn.CrossEntropyLoss(),
# which would apply log-softmax again
criterion = nn.NLLLoss()
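To see why NLLLoss is the right partner for a LogSoftmax output layer, the relationship can be checked by hand. The following is a minimal pure-Python sketch on made-up scores for a three-class toy problem. The helper names are illustrative and not part of the article's code or of PyTorch.

```python
import math

def log_softmax(logits):
    """Return log-probabilities for one sample's raw scores."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross-entropy computed directly from raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum - logits[target]

logits = [2.0, 0.5, -1.0]   # made-up scores for a 3-class toy problem
target = 0
loss_a = nll_loss(log_softmax(logits), target)
loss_b = cross_entropy(logits, target)
print(abs(loss_a - loss_b) < 1e-12)
```

Log-softmax followed by negative log-likelihood gives the same value as cross-entropy on the raw scores, so a model that already ends in LogSoftmax only needs the NLL step.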
# Training function
# is_inception: True only for Inception, which also returns an auxiliary output during training
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy so far
    best_acc = 0
    # To resume from an earlier checkpoint, uncomment:
    # checkpoint = torch.load(filename)
    # best_acc = checkpoint['best_acc']
    # model.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    # model.class_to_idx = checkpoint['mapping']

    # Move the model to the GPU or CPU
    model.to(device)
    # Histories kept for plotting
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Training and validation
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the whole dataset
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Track gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # Inception-only branch; it can be ignored for ResNet
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # The index of the highest score is the predicted class
                    _, preds = torch.max(outputs, 1)

                    # Update the weights during training
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model seen so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # scheduler is the module-level StepLR. It advances once per epoch and takes
                # no metric: step(epoch_loss) would be read as an epoch index and make the
                # learning rate jump around.
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, load the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
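The training function multiplies each batch loss by the batch size before dividing by the dataset size. A small sketch with hypothetical numbers shows why: the criterion returns a per-batch mean, so a short final batch would otherwise count as much as a full one. The function name and the values below are illustrative only.

```python
def epoch_average(batch_losses, batch_sizes):
    """Average per-sample loss over an epoch from per-batch mean losses."""
    running_loss = 0.0
    for loss, size in zip(batch_losses, batch_sizes):
        running_loss += loss * size   # undo the per-batch mean
    return running_loss / sum(batch_sizes)

# Hypothetical numbers: two full batches of 8 and a final batch of 2
losses = [1.0, 2.0, 6.0]
sizes = [8, 8, 2]
weighted = epoch_average(losses, sizes)
naive = sum(losses) / len(losses)
print(weighted, naive)
```

Here the weighted average is 2.0 while the plain mean of the batch means is 3.0, because the small last batch happens to have a large loss.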
7.2 Training the model
I trained for only 5 epochs here because training takes a very long time. Increase the number of epochs when you run it yourself.

# If training is too slow, lower the number of epochs; training for about 50 epochs may give better results
# While training, check whether the loss falls and the accuracy rises, and compare validation with training: a large gap indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name == "inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
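The learning rates in the log above do not follow a steady decay. One explanation that fits the numbers: if the validation loss is handed to StepLR's step() as its optional epoch argument, the rate becomes base_lr * gamma ** (loss // step_size). The sketch below assumes that closed-form rule. The function is a stand-in written for illustration, not PyTorch's implementation.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Learning rate a step-decay schedule assigns at a given epoch index."""
    return base_lr * gamma ** (epoch // step_size)

base_lr, gamma, step_size = 1e-2, 0.1, 7

# Intended schedule: the rate stays at 0.01 for epochs 0-6, then drops tenfold
intended = [step_lr(base_lr, gamma, step_size, e) for e in (0, 6, 7, 14)]

# Validation losses from the log above, used as if they were epoch indices
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]
logged = [step_lr(base_lr, gamma, step_size, loss) for loss in valid_losses]

print(["{:.7f}".format(lr) for lr in logged])
```

The formatted values reproduce the five learning rates printed in the log, which suggests the schedule was driven by the loss rather than by the epoch count. Calling step() with no argument once per epoch gives the intended decay.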
7.3 Training all layers
# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate.
# Pass model_ft.parameters() here: params_to_update only holds the fc layer collected in step 7.
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# Attach the scheduler to the new optimizer, not to optimizer_ft
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved parameters and continue training from that model.
# filename points to the best checkpoint from the previous run.
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# The saved optimizer state covers only the fc parameters and would restore the old
# learning rate, so it is not loaded into the new optimizer.
Start training
Note: training gets especially slow at this stage. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name == "inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
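Comparing training and validation accuracy per epoch is a quick way to look for overfitting. The sketch below applies that check to the accuracies in the log above. The 0.05 threshold is an arbitrary assumption for illustration, not a value from the article.

```python
def generalization_gaps(train_acc, valid_acc):
    """Per-epoch difference between training and validation accuracy."""
    return [round(t - v, 4) for t, v in zip(train_acc, valid_acc)]

# Accuracies copied from the two-epoch log above
train_acc = [0.7346, 0.7340]
valid_acc = [0.6455, 0.6137]

gaps = generalization_gaps(train_acc, valid_acc)
threshold = 0.05   # hypothetical cut-off, tune it for your own task
flags = [gap > threshold for gap in gaps]
print(gaps, flags)
```

In this run the gap grows from about 0.09 to about 0.12 while the training accuracy stays flat, which hints that the model is starting to overfit.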
8. Loading the trained model
This amounts to a plain forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU

# Checkpoint file name
filename = 'checkpoint.pth'

# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
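The message above means that the checkpoint's parameter names lined up exactly with the model's. As a rough illustration of that check, the sketch below compares two lists of key names. The key names and the helper are hypothetical, and this is not how PyTorch implements load_state_dict internally.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, in the spirit of a strict state-dict load."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected

# Hypothetical key names; fc.0.* matches the layer printed in step 7
model_keys = ["conv1.weight", "fc.0.weight", "fc.0.bias"]
good_checkpoint = ["conv1.weight", "fc.0.weight", "fc.0.bias"]
bad_checkpoint = ["conv1.weight", "fc.weight", "fc.bias"]

print(compare_keys(model_keys, good_checkpoint))
print(compare_keys(model_keys, bad_checkpoint))
```

A mismatch like the second case typically appears when the model is rebuilt with a different head than the one that was saved, so the model must be initialized the same way before loading.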
def process_image(image_path):
    # Read a test image
    img = Image.open(image_path)
    # thumbnail() only shrinks and keeps the aspect ratio, so choose the bound by
    # orientation to bring the shorter side to 256. Unlike resize(), whose size argument
    # sets the exact output size, thumbnail() modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Center-crop the image to 224 x 224
    left_margin = (img.width - 224) / 2     # left edge of the central region
    bottom_margin = (img.height - 224) / 2  # passed to crop() as the upper edge
    right_margin = left_margin + 224        # left edge plus the crop width
    top_margin = bottom_margin + 224        # passed to crop() as the lower edge

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Apply the same preprocessing as in training
    # Scale to [0, 1], then normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the first axis
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Move the color channel back to the last axis
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # call this function for each image to be preprocessed
imshow(img)
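The resize-then-crop geometry in process_image can be worked through with plain arithmetic. This sketch assumes a hypothetical 500 x 375 input. The helper names are illustrative, and Pillow's own rounding inside thumbnail() may differ by a pixel.

```python
def shorter_side_resize(width, height, target=256):
    """Size after scaling so the shorter side equals target, never enlarging."""
    scale = min(1.0, target / min(width, height))
    return round(width * scale), round(height * scale)

def center_crop_box(width, height, size=224):
    """(left, upper, right, lower) box for a centered size x size crop."""
    left = (width - size) / 2
    upper = (height - size) / 2
    return (left, upper, left + size, upper + size)

w, h = shorter_side_resize(500, 375)
box = center_crop_box(w, h)
print((w, h), box)
```

The box is given in the order crop() expects, so the width and height of the cropped region are both 224 regardless of the input's aspect ratio, provided the resized image is at least 224 pixels on each side.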
<AxesSubplot:>
The output above shows a test image after preprocessing. We check its shape to confirm that the preprocessing function works as intended.

img.shape

(3, 224, 224)

This confirms that the channel axis now comes first and that the image has the expected 224 x 224 size.
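The normalization and the channel reordering can be traced on a tiny example without NumPy. The sketch below uses a hypothetical 2 x 3 image stored as nested lists. The mean and standard deviation are the values used in process_image, and the helper names are illustrative.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Scale one 0-255 RGB pixel to [0, 1], then standardize each channel."""
    return [(v / 255 - m) / s for v, m, s in zip(rgb, MEAN, STD)]

def denormalize_pixel(norm):
    """Invert the standardization and clip to [0, 1] for display."""
    return [min(1.0, max(0.0, n * s + m)) for n, m, s in zip(norm, MEAN, STD)]

def hwc_to_chw(image):
    """Reorder a nested list from height x width x channel to channel x height x width."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [[[image[y][x][c] for x in range(width)] for y in range(height)]
            for c in range(channels)]

# A hypothetical 2 x 3 image with three channels per pixel
image = [[normalize_pixel([255, 128, 0]) for _ in range(3)] for _ in range(2)]
chw = hwc_to_chw(image)
shape = (len(chw), len(chw[0]), len(chw[0][0]))
restored = denormalize_pixel(image[0][0])
print(shape, [round(v * 255) for v in restored])
```

The shape goes from (2, 3, 3) to (3, 2, 3), mirroring the (3, 224, 224) result above, and undoing the normalization recovers the original pixel values, which is what imshow relies on.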
9. Inference
# Get one batch of validation data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

if train_on_gpu:
    # One forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 images; each gets 102 values, one log-probability per class
output.shape
torch.Size([8, 102])
9.1 Getting the most probable class
_, preds_tensor = torch.max(output, 1)

# Move to the CPU if needed, convert to NumPy, and squeeze away size-1 dimensions to get a 1-D array
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
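Taking the maximum along dimension 1 picks, for each image, the class with the highest log-probability. A pure-Python sketch of the same row-wise operation on hypothetical values for 2 samples and 3 classes:

```python
import math

def predict(batch_log_probs):
    """Return (best score, best class index) for each row, like a row-wise max."""
    results = []
    for row in batch_log_probs:
        best = max(range(len(row)), key=lambda i: row[i])
        results.append((row[best], best))
    return results

# Hypothetical log-probabilities for 2 samples and 3 classes
batch = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.3), math.log(0.6)],
]
preds = [index for _, index in predict(batch)]
confidences = [math.exp(score) for score, _ in predict(batch)]
print(preds)
```

Exponentiating the winning log-probability gives the model's confidence as an ordinary probability, which can be handy when displaying predictions.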
9.2 Showing the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # Title format: predicted name (true name)
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions; red titles mark wrong ones
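One caveat about the name lookup: torchvision's ImageFolder assigns class indices from the folder names sorted as strings. Assuming the datasets here were built that way, the index the model predicts is not the folder number, so the flower name should be looked up through an index-to-folder mapping. The sketch below reproduces that ordering in plain Python. The helper is illustrative, and `idx_to_class` is a name introduced here rather than one from the article.

```python
def build_index_maps(folder_names):
    """Assign indices to class folders in sorted string order."""
    classes = sorted(folder_names)
    class_to_idx = {name: i for i, name in enumerate(classes)}
    idx_to_class = {i: name for name, i in class_to_idx.items()}
    return class_to_idx, idx_to_class

folders = [str(n) for n in range(1, 103)]   # '1' ... '102', as in flower_data/train
class_to_idx, idx_to_class = build_index_maps(folders)

print(idx_to_class[0], idx_to_class[1], idx_to_class[2], idx_to_class[3])
```

With such a mapping the title lookup would become cat_to_name[idx_to_class[pred]], assuming cat_to_name is keyed by folder name. The green or red coloring is unaffected, since the prediction and the label share the same index space.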
----------------
Copyright notice: this is an original article by CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license. When reposting, include a link to the original source and this notice.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940
|
|