[Deep Learning] Image Recognition in Practice: 102-Category Flower Classification (Flower 102) Case Study
Contents
Hands-on convolutional networks: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download:
1. Import packages
2. Data preprocessing and operations
3. Build the data source
Read the actual names for the labels
4. Show some data
5. Load a model provided by models and use its trained weights as the initial parameters
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer settings
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the maximum probability
9.2 Show the predictions
Closing remarks
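Hands-on convolutional networks: classifying flowers
This article tackles a classification task on the University of Oxford flower dataset. It builds a general-purpose neural network pipeline (implemented mainly with ResNet) and uses common PyTorch operations along the way: preprocessing, training, model saving, and model loading.

The folder contains 102 kinds of flowers, and the main task is to classify them.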
Folder structure

flower_data
    train
        1 (class)
        2
            xxx.png / xxx.jpg
    valid
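A quick way to confirm that a local copy follows this layout is to count the image files under each class folder. The sketch below is only an illustration: `count_images_per_class` is a hypothetical helper that does not appear in the original post, and it assumes that only `.jpg`, `.jpeg`, and `.png` files count as images.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumption: only these count as images


def count_images_per_class(split_dir):
    """Return {class_folder_name: number_of_image_files} for one split (e.g. train)."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if not class_dir.is_dir():
            continue  # skip stray files that sit next to the class folders
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

Calling `count_images_per_class('./flower_data/train')` should return one entry per class folder; an empty or missing class shows up immediately as a count of 0 or an absent key.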
The work is divided into the following main modules

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model, calling the classic network architectures in torchvision directly
    The original training task may have been, for example, 1000-class classification (its classes need not match ours), so the model should be adapted to our own task
Saving and testing the network model
    Model saving can be selective
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder (the code below expects flower_data), then place it in the same root directory as the code.

Below is my data root directory
[Image: data root directory]

1. Import packages
import os
import matplotlib.pyplot as plt
# Render plots inline in the notebook so plt.show() can be omitted (Jupyter only)
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2. Data preprocessing and operations
# Path settings
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')

Related reading: "Combinations and differences of dots and slashes in Python paths"
Note: that article explains how dot-slash and slash paths behave

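Because `data_dir` already ends with a slash, appending `'/train'` by string concatenation would produce a doubled separator. The short sketch below compares concatenation with `os.path.join` and `os.path.normpath`; it uses only the standard library and the directory name from the post.

```python
import os

data_dir = './flower_data/'

# Plain concatenation keeps both slashes, so the path contains '//'.
concatenated = data_dir + '/train'

# os.path.join does not add a second separator after a trailing slash.
joined = os.path.join(data_dir, 'train')

# normpath collapses redundant separators and drops the leading './'.
normalized = os.path.normpath(concatenated)

print(concatenated)
print(joined)
print(normalized)
```

Most operating systems accept the doubled slash, so the difference is mainly about keeping paths tidy and portable.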
3. Build the data source
data_transforms specifies all of the image preprocessing operations
ImageFolder assumes that the files are organized into folders, with each folder holding the images of one class
data_transforms = {
    # Two parts; the first one is for training
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop starting from the center
        # Flip with a given probability (50/50 here)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # Argument 1 is brightness, 2 is contrast, 3 is saturation, 4 is hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),  # convert to grayscale with this probability
        # The grayscale result still has three channels, but R, G and B are identical
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, standard deviation
    ]),
    # Resize so that the shorter side is 256, take the central 224 x 224 crop,
    # convert to a tensor, and finally normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
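The Normalize step applies (value - mean) / std separately to each channel. The following sketch reproduces that arithmetic for a single RGB pixel in plain Python, using the mean and standard deviation values from the post; `normalize_pixel` is a hypothetical helper written for this illustration, not part of torchvision.

```python
MEAN = (0.485, 0.456, 0.406)  # per-channel mean used in the post
STD = (0.229, 0.224, 0.225)   # per-channel standard deviation used in the post


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel whose channels lie in [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


# A pixel equal to the channel means maps to zero in every channel;
# a pure white pixel maps to values a little above 2.
print(normalize_pixel(MEAN))
print(normalize_pixel((1.0, 1.0, 1.0)))
```

After this step the inputs are roughly zero-centered, which matches the statistics the pretrained weights were trained with.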

batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# Inspect the datasets
image_datasets

Output:
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

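The printed ColorJitter line shows intervals such as brightness=[0.8, 1.2], although the code passed a single number (0.2). A single strength is expanded into an interval around 1 (or around 0 for hue). The sketch below mirrors that arithmetic; `jitter_range` is a hypothetical helper for illustration and is not torchvision's own code.

```python
def jitter_range(value, center=1.0, clip_at_zero=True):
    """Turn a single jitter strength into the (low, high) interval it stands for."""
    low, high = center - value, center + value
    if clip_at_zero:
        low = max(0.0, low)  # scaling factors cannot be negative
    return (low, high)


settings = {"brightness": 0.2, "contrast": 0.1, "saturation": 0.1}
for name, value in settings.items():
    print(name, jitter_range(value))

# Hue is an offset centered on 0, so its interval may be negative.
print("hue", jitter_range(0.1, center=0.0, clip_at_zero=False))
```

For each training image a factor is drawn at random from the interval, so brightness=0.2 means the image is scaled by a factor between roughly 0.8 and 1.2.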
# Check that the data loaders have been created
dataloaders

Output:
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

Output:
{'train': 6552, 'valid': 818}

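With the dataset sizes printed above and batch_size = 8, the number of batches per epoch follows from a ceiling division. The sketch below works through that arithmetic; it assumes the DataLoader default drop_last=False, which is what the code in the post uses implicitly.

```python
import math

dataset_sizes = {"train": 6552, "valid": 818}  # values printed above
batch_size = 8

# With drop_last=False the final partial batch is kept, so the number of
# batches per epoch is the ceiling of size / batch_size.
batches_per_epoch = {k: math.ceil(n / batch_size) for k, n in dataset_sizes.items()}

# Size of the final batch: the remainder, or a full batch if it divides evenly.
last_batch_size = {k: (n % batch_size) or batch_size for k, n in dataset_sizes.items()}

print(batches_per_epoch)  # {'train': 819, 'valid': 103}
print(last_batch_size)    # {'train': 8, 'valid': 2}
```

The validation loader therefore ends with a batch of only 2 images, which matters for any code that assumes every batch holds exactly 8.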
Read the actual names for the labels
Use the JSON file in the same directory to map each label back to the name of the flower
with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name

Output:
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
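One detail is easy to miss: the integer labels produced by the DataLoader are positions in `class_names`, not category ids. ImageFolder sorts the folder names as strings, so position 1 is folder '10', not folder '2'. The sketch below reproduces that ordering in plain Python, assuming all 102 folders named '1' to '102' are present, and uses two entries copied from cat_to_name.

```python
# Folder names are the category ids '1' ... '102' (assumption: all are present).
# Sorting them as strings gives the same order as image_datasets['train'].classes.
class_names = sorted(str(i) for i in range(1, 103))

# Two entries copied from cat_to_name, enough for the demonstration.
cat_to_name = {"1": "pink primrose", "10": "globe thistle"}

label_index = 1                         # an integer label as it comes out of a batch
category_id = class_names[label_index]  # first map the position to the folder name

print(class_names[:5])                              # ['1', '10', '100', '101', '102']
print(category_id, "->", cat_to_name[category_id])  # 10 -> globe thistle
```

Looking a label up in cat_to_name directly, without going through class_names first, would therefore name the wrong flower for most classes.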
4. Show some data
def im_convert(tensor):
    """Convert a normalized image tensor into an array that matplotlib can display."""
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    # squeeze() removes any size-1 dimensions, leaving a (c, h, w) array
    # transpose reorders the axes: the tensor is (c, h, w), plotting needs (h, w, c)
    image = image.transpose(1, 2, 0)
    # Undo the normalization: multiply by the std and add the mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    # Clip the values: anything below 0 becomes 0, anything above 1 becomes 1
    image = image.clip(0, 1)

    return image

# Plot using the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# Build an iterator over the DataLoader
# and take one batch of data to display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # dataiter.next() is not available on recent PyTorch versions

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to print the flower name above each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
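What im_convert does can be traced by hand on a tiny example. The sketch below is a plain-Python illustration (nested lists instead of NumPy arrays) of the same three steps: reorder C x H x W to H x W x C, undo the normalization, and clip to [0, 1]. The helper name `denormalize_chw` is hypothetical.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def denormalize_chw(image_chw):
    """Convert a normalized C x H x W nested list to H x W x C values clipped to [0, 1]."""
    channels = len(image_chw)
    height, width = len(image_chw[0]), len(image_chw[0][0])
    return [
        [
            [
                min(1.0, max(0.0, image_chw[c][y][x] * STD[c] + MEAN[c]))
                for c in range(channels)
            ]
            for x in range(width)
        ]
        for y in range(height)
    ]


# A 3 x 1 x 2 "image": zeros map back to the channel means, large values clip to 1.
tiny = [[[0.0, 10.0]], [[0.0, 10.0]], [[0.0, 10.0]]]
print(denormalize_chw(tiny))
```

The first pixel comes back as the channel means and the second is clipped to white, which is why displayed images never contain values outside the valid range.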

5. Load a model provided by models and use its trained weights as the initial parameters
model_name = 'resnet'  # several options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is the main choice for image recognition here
# Whether to reuse the features that have already been trained (freeze the backbone)
feature_extract = True

# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

Output:
CUDA is not available. Training on CPU ...

# Set requires_grad to False for the existing layers so they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Print the model architecture to see how it is built up step by step
# Its main job is to extract features for us

model_ft = models.resnet152()
model_ft

Output:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    ... (a large part of the output is omitted here; looking at these two levels of the architecture is enough) ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

The final layer is a 1000-class classifier: it takes 2048 input features and produces 1000 outputs.
For our task this layer has to be adjusted so that the 1000 outputs become 102.

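To get a feel for how small the replaced layer is, the number of values in a fully connected layer can be computed directly: one weight per input-output pair plus one bias per output. The sketch below is simple arithmetic using the 2048, 1000, and 102 figures from the text; `linear_param_count` is a hypothetical helper.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of trainable values in a fully connected layer: weights plus biases."""
    return in_features * out_features + (out_features if bias else 0)


original_head = linear_param_count(2048, 1000)  # the 1000-class head printed above
flower_head = linear_param_count(2048, 102)     # a replacement head for 102 flowers

print(original_head)  # 2049000
print(flower_head)    # 208998
```

When the backbone is frozen, these roughly 209 thousand values are the only ones that training has to update.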
6. Initialize the model architecture
The steps are as follows:

Take an already trained model and set pretrained=True to obtain its trained weights
Optionally choose which layers to freeze (set requires_grad to False for them)
Whether the task is classification or regression, change the final FC layer to the matching output size
Official documentation
https://pytorch.org/vision/stable/models.html

# Load a model that someone else has trained
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Pick the requested model; each architecture is initialized differently.
    # Note: torchvision 0.13 and later prefer the weights= argument over pretrained=.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet-152
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature extractor so that only the FC layer is trained
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_ftrs = model_ft.fc.in_features
        # 4. Replace the fully connected layer so that it has num_classes outputs (102 here)
        model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes),
                                    nn.LogSoftmax(dim=1))  # dim=1: per sample (row), across the classes; the exponentiated outputs sum to 1
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last classifier layer, which has index 6
        num_ftrs = model_ft.classifier[6].in_features  # input size of the FC layer
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception V3
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the auxiliary classifier
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Replace the main classifier
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        raise ValueError(f"Invalid model name: {model_name}")

    return model_ft, input_size
7. Set the parameters to train
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the device (GPU if available)
model_ft = model_ft.to(device)

# Checkpoint file: stores the trained model so it can be loaded directly later
filename = 'checkpoint.pth'

# By default, train all layers
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    # Feature extraction: only parameters with requires_grad=True are updated
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction
7.1 Optimizer setup
# Optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: multiply the learning rate by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The last layer is LogSoftmax(), so use nn.NLLLoss() here;
# nn.CrossEntropyLoss() expects raw scores and applies log-softmax itself

criterion = nn.NLLLoss()
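The loss is NLLLoss because the model already ends in LogSoftmax. A minimal pure-Python sketch (the helpers `log_softmax`, `nll` and `cross_entropy` are written for this illustration and are not PyTorch functions) shows that negative log-likelihood applied to log-softmax outputs gives the same number as cross-entropy computed from raw scores:

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross-entropy computed directly from raw scores."""
    m = max(logits)
    return -(logits[target] - m) + math.log(sum(math.exp(x - m) for x in logits))

logits = [2.0, 0.5, -1.0]   # hypothetical raw scores for 3 classes
target = 0
loss_a = nll(log_softmax(logits), target)
loss_b = cross_entropy(logits, target)
print(loss_a, loss_b)
```

The two printed values agree up to floating-point error, which is why LogSoftmax is paired with NLLLoss and raw scores are paired with CrossEntropyLoss.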
# Define the training function
# is_inception: set to True only for Inception, which returns an auxiliary output during training
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy seen so far
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or the CPU
    model.to(device)
    # Histories kept for plotting
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the whole dataset
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    # Inception-only branch; not used for ResNet
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # The predicted class is the index of the largest output
                    _, preds = torch.max(outputs, 1)

                    # Update the weights in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss (weighted by batch size) and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-sample statistics for this phase
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the model from the best validation epoch
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer' : optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # `scheduler` is the global StepLR defined above. StepLR.step() takes no metric;
                # the original call, scheduler.step(epoch_loss), makes StepLR read the loss as an epoch index
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, use the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
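In the training function, each batch loss is multiplied by the batch size before it is summed, and the total is divided by the dataset size. A small sketch with made-up numbers illustrates the reason: the last batch is often smaller, and a plain mean over batch losses would give it too much weight.

```python
# Hypothetical per-batch results: (mean loss of the batch, batch size, number correct)
batches = [(0.90, 8, 5), (0.70, 8, 6), (0.20, 2, 2)]

running_loss = 0.0
running_corrects = 0
for batch_loss, batch_size, correct in batches:
    running_loss += batch_loss * batch_size   # undo the per-batch mean
    running_corrects += correct

dataset_size = sum(size for _, size, _ in batches)
epoch_loss = running_loss / dataset_size      # per-sample mean loss
epoch_acc = running_corrects / dataset_size   # fraction of correct predictions

# Unweighted mean over batches, shown only for comparison
naive_loss = sum(loss for loss, _, _ in batches) / len(batches)
print(epoch_loss, epoch_acc, naive_loss)
```

Here the weighted loss is higher than the naive mean because the small, easy final batch no longer counts as much as a full batch.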
7.2 Start training the model
I only trained for 5 epochs here (num_epochs=5) because training takes a very long time; feel free to raise the number of epochs when you try it yourself.

# If this is too slow, lower the number of epochs; around 50 iterations may give better results
# While training, check whether the loss is falling and the accuracy rising, and how far validation lags behind training; a large gap means overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))

Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
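The learning rates printed in the log above (0.001, 0.01, 0.0001, ...) do not follow a "divide by 10 every 7 epochs" schedule. One plausible explanation, offered here as an assumption rather than something the original post states: if the validation loss is passed to `scheduler.step(...)`, StepLR reads it as an epoch index. The sketch below applies StepLR's closed-form rule, lr = base_lr * gamma ** (epoch // step_size), in plain Python (`step_lr` is a hypothetical helper, not a PyTorch function) to the validation losses from the log:

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Closed-form StepLR rule: decay by gamma once every step_size epochs."""
    return base_lr * gamma ** int(epoch // step_size)

# Intended schedule: the epoch counter increases by one each epoch
intended = [step_lr(1e-2, epoch) for epoch in range(5)]

# Validation losses copied from the log, treated as if they were epoch indices
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]
observed = [step_lr(1e-2, loss) for loss in valid_losses]

print(['{:.7f}'.format(lr) for lr in intended])
print(['{:.7f}'.format(lr) for lr in observed])
```

The second list reproduces the logged learning rates epoch by epoch, which is consistent with that explanation. Calling `scheduler.step()` with no argument once per epoch would give the first list instead.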
7.3 Train all layers
# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate.
# Build the optimizer from model_ft.parameters(): params_to_update only holds the fc layer
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# train_model() reads the global `scheduler`, so bind it to the new optimizer
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved checkpoint (the best model from the previous run)
# and continue training from those weights
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# The saved optimizer state only covers the fc parameters and would overwrite
# the learning rate set above, so it is not loaded into the new optimizer
Start training
Note: training becomes much slower at this stage. For reference, my GPU is a GTX 1660 Ti.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))

Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Load the trained model
This amounts to a simple forward pass (inference); no parameters are updated.
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU

# Name of the saved checkpoint file
filename = 'checkpoint.pth'

# Load the model weights
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read an image from the test data
    img = Image.open(image_path)
    # Shrink the image so that its shorter side is 256 pixels.
    # Unlike resize(), whose size argument sets the output size directly,
    # thumbnail() only shrinks and keeps the aspect ratio, hence the branch below.
    # It also modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Center-crop the image to 224 x 224
    left_margin = (img.width - 224) / 2      # left edge of the centered crop
    bottom_margin = (img.height - 224) / 2   # upper edge (PIL's box is left, upper, right, lower)
    right_margin = left_margin + 224         # add the crop width of 224
    top_margin = bottom_margin + 224         # lower edge

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Same preprocessing as during training:
    # scale to [0, 1], then normalize with the ImageNet mean and std
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the front: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax = None, title = None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Move the color channel back: (C, H, W) -> (H, W, C)
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)
    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # the function can be reused for any other image
imshow(img)
<AxesSubplot:>
The image above is a test image after preprocessing. We can inspect its shape to confirm that the preprocessing function works as intended.

img.shape
(3, 224, 224)
This confirms that the channel dimension has moved to the front and that the image is 224 x 224.
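The crop arithmetic in `process_image` can be checked without loading an image. A minimal sketch, assuming the image has already been shrunk so that its shorter side is 256 pixels (`center_crop_box` is a hypothetical helper written for this illustration; it uses integer division, whereas the code above uses plain division):

```python
def center_crop_box(width, height, size=224):
    """Return the (left, upper, right, lower) box of a centered size x size crop."""
    left = (width - size) // 2
    upper = (height - size) // 2
    return (left, upper, left + size, upper + size)

# Hypothetical sizes after the thumbnail step: landscape, portrait, square
for width, height in [(341, 256), (256, 384), (256, 256)]:
    left, upper, right, lower = center_crop_box(width, height)
    print((width, height), '->', (left, upper, right, lower))
```

In every case the box is exactly 224 pixels wide and 224 pixels tall, which matches the (3, 224, 224) shape reported above.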
9. Inference
# Get one batch of validation data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # dataiter.next() is not available in recent PyTorch versions

model_ft.eval()

if train_on_gpu:
    # One forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)
# The batch holds 8 images and each gets 102 values, one per class;
# since the model ends in LogSoftmax, each value is a log-probability
output.shape

torch.Size([8, 102])
9.1 Get the most probable class
_, preds_tensor = torch.max(output, 1)
# .cpu() is a no-op for CPU tensors, so this works with or without a GPU;
# np.squeeze drops any size-1 dimensions, leaving a 1-D array of class indices
preds = np.squeeze(preds_tensor.cpu().numpy())
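`torch.max(output, 1)` returns, for each row, the largest value and its index; the index is the predicted class. Because the network ends in LogSoftmax, the values are log-probabilities, and `exp` turns them back into probabilities. A plain-Python sketch with made-up numbers for one sample and three classes:

```python
import math

# Hypothetical log-probabilities for one sample over 3 classes
log_probs = [math.log(0.1), math.log(0.7), math.log(0.2)]

# Index of the largest log-probability is the predicted class
pred = max(range(len(log_probs)), key=lambda i: log_probs[i])

# exp() recovers the probability of that class
confidence = math.exp(log_probs[pred])
print(pred, round(confidence, 4))
```

Taking the maximum over log-probabilities picks the same class as taking it over probabilities, since the logarithm is monotonic.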
9.2 Show the predictions
fig = plt.figure(figsize = (20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    plt.imshow(im_convert(images[idx]))
    # Title format: predicted name (true name)
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions, red titles mark wrong ones
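One thing worth checking in these titles: torchvision's `ImageFolder` assigns class indices by sorting the folder names as strings, so a predicted index is not necessarily the same number as the category folder. A minimal sketch, assuming the class folders are named '1' to '102' (`idx_to_class` is a name introduced for this illustration):

```python
# Folder names as strings, sorted the way ImageFolder sorts them
classes = sorted(str(i) for i in range(1, 103))
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Invert the mapping to go from a predicted index back to the folder name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print([idx_to_class[i] for i in range(6)])
```

With such a mapping, the key for `cat_to_name` would be `idx_to_class[pred]` rather than `str(pred)`. Whether this applies here depends on how the datasets were built in the earlier sections.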

----------------
Copyright notice: this is an original article by CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940