[Deep Learning] Image Recognition in Practice: 102-Category Flower Classification (Flower 102), Case Study 5
Table of Contents
CNN in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import packages
2. Data preprocessing and operations
3. Build the data source
Read the real names behind the labels
4. Show some of the data
5. Load a model provided by models, initialized directly with pretrained weights
6. Initialize the model architecture
7. Set the parameters that need training
7. Training and prediction
7.1 Optimizer setup
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Take the class with the highest probability
9.2 Show the predictions
Closing remarks
CNN in practice: classifying flowers
This post walks through a classification task on the Oxford 102-category flower dataset, building a fairly general network architecture (implemented mainly with ResNet) and exercising the usual PyTorch workflow: preprocessing, training, model saving, and model loading.

The data folder contains 102 kinds of flowers, and our job is to classify them.

Folder structure:

flower_data
    train
        1 (class id)
        2
            xxx.png / xxx.jpg
    valid

The work breaks down into these major parts:

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model, reusing one of torchvision's classic architectures
    The pretrained task may be a 1000-way classification (not necessarily the same split of classes), so the head must be adapted to our own task
Saving and testing the network model
    Saving can be selective
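Selective saving can be sketched as follows; this is a minimal example, assuming we only keep the weights (the state_dict) plus a little bookkeeping, with `best_acc` as an illustrative value rather than anything from this post:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # stand-in for the real network
best_acc = 0.87           # illustrative bookkeeping value, not from this post

# Save only what we choose to keep: the weights, not the whole model object.
checkpoint = {
    'state_dict': model.state_dict(),
    'best_acc': best_acc,
}
torch.save(checkpoint, 'checkpoint.pth')

# Restoring later only needs a model with the same architecture:
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load('checkpoint.pth')['state_dict'])
```

Saving the state_dict rather than the whole model keeps the checkpoint portable across code refactors.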
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the files as needed and place everything under the same root directory.

Below is my data root directory.
1. Import packages

import os
import matplotlib.pyplot as plt
# render figures inline in the notebook (no explicit plt.show() handle needed)
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations

# path setup
data_dir = './flower_data/'      # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

Note on dot and slash combinations in Python paths: './' is relative to the current working directory, and the concatenation above produces a doubled slash.
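To make the note above concrete, here is a small sketch of how './' and the doubled slash behave (illustrative only):

```python
import os

data_dir = './flower_data/'      # './' means: relative to the current working directory
train_dir = data_dir + '/train'  # plain concatenation doubles the slash...
# ...which most filesystems tolerate, but os.path.join is the cleaner habit:
joined = os.path.join('./flower_data', 'train')

# normpath collapses redundant separators and the leading './',
# so both spellings resolve to the same path
print(os.path.normpath(train_dir))
print(os.path.normpath(joined))
```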
3. Build the data source

data_transforms specifies all the image preprocessing operations.
ImageFolder assumes images are organized into folders, one folder per class.

data_transforms = {
    # two pipelines: one for training...
    'train': transforms.Compose([
        transforms.RandomRotation(45),           # random rotation between -45 and +45 degrees
        transforms.CenterCrop(224),              # crop from the center
        transforms.RandomHorizontalFlip(p=0.5),  # flip horizontally with probability 0.5
        transforms.RandomVerticalFlip(p=0.5),    # flip vertically with probability 0.5
        # arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # convert to grayscale with small probability; the result still has
        # three channels, the R, G, and B values are just identical
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, std
    ]),
    # ...and one for validation: resize to 256, center-crop 224, to tensor, normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean/std as training
    ]),
}
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
# verify that the data has been processed
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

{'train': 6552, 'valid': 818}
Read the real names behind the labels
Use the JSON file in the same directory to map each label back to its flower name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
4. Show some of the data

def im_convert(tensor):
    """Prepare a tensor for display"""
    image = tensor.to("cpu").clone().detach()
    # squeeze drops the singleton dimension so the array is plottable
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is (c, h, w), matplotlib wants (h, w, c)
    image = image.transpose(1, 2, 0)
    # undo the normalization
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clamp everything below 0 to 0 and everything above 1 to 1
    image = image.clip(0, 1)

    return image
# use the helper defined above to plot a grid of images
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# iter gives an iterator over batches;
# grab an arbitrary batch just for display
dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()  # on newer PyTorch, use next(dataiter) instead

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # use the JSON mapping to print each image's flower name as its title
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model provided by models, initialized directly with pretrained weights

model_name = 'resnet'  # many options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# resnet is the main workhorse for image recognition here
# whether to reuse the pretrained features
feature_extract = True
# whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

CUDA is not available. Training on CPU ...
# freeze selected layers by setting requires_grad to False, so they are not updated
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
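A quick sanity check of the freezing helper on a toy model (the two-layer net here is an illustrative stand-in, not the tutorial's ResNet):

```python
import torch.nn as nn

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# toy stand-in model: two Linear layers with a ReLU between them
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
set_parameter_requires_grad(toy, feature_extracting=True)

# 2 Linear layers x (weight + bias) = 4 parameter tensors, all frozen now
frozen = sum(1 for p in toy.parameters() if not p.requires_grad)
print(frozen)  # 4
```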
# print the architecture to see how the model is built, step by step
# these layers are mainly there to extract features for us

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
There is a lot more output in between; we only need to look at the first and last blocks of the architecture, so it is abbreviated here...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The final layer is a 1000-way classifier: 2048 input features mapped to 1000 classes.
We need to adapt this to our own task by changing the 1000-way output to 102.

6. Initialize the model architecture
The steps are as follows:

Take the pretrained model, with pretrained=True, to obtain someone else's learned weights.
Decide which layers, if any, to freeze; frozen layers have their gradient updates turned off (requires_grad=False).
Whether the task is classification or regression, replace the final FC layer with one sized for our task.

Official docs:
https://pytorch.org/vision/stable/models.html
# load someone else's pretrained model
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # pick the right model; each one is initialized differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. optionally freeze the feature-extraction layers and train only the FC layer
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. read the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. replace the fully connected layer with a 102-way output
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
                                    # default is dim=0 (over columns); we want dim=1,
                                    # over rows, so each sample's probabilities sum to 1
                                    nn.LogSoftmax(dim=1))
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # replace the last classifier layer, at index [6]
        num_frts = model_ft.classifier[6].in_features  # FC-layer input size
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Inception also has an auxiliary classifier that needs resizing
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
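The dim=1 passed to nn.LogSoftmax in the resnet branch is worth a sanity check: with a batch of shape (N, num_classes), dim=1 normalizes across the classes of each sample, so each row's probabilities sum to 1. A toy batch for illustration:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 4.0, 4.0]])   # fake batch: 2 samples, 3 classes

log_probs = nn.LogSoftmax(dim=1)(logits)   # normalize per sample (per row)
probs = log_probs.exp()

print(probs.sum(dim=1))                    # each row sums to 1
```

With dim=0 the normalization would instead run down each class column, mixing the two samples together, which is not what a classifier head should do.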
7. Set the parameters to train

# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU
model_ft = model_ft.to(device)

# Checkpoint file: the best weights are saved here so they can be loaded again later
filename = 'checkpoint.pth'

# Decide whether to train all layers or only the new ones
params_to_update = model_ft.parameters()

# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
" g+ D# |. g2 R- k1
- h) v) N; e5 ~" K" B9 b7 Z# R2
9 a) z6 F, r1 p1 J( `7 s4 H31 }; O9 l: @' t( |2 M1 @- `
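The filtering above can be sketched without PyTorch: treat each named parameter as a (name, requires_grad) pair and keep only the trainable ones. The parameter names below only mimic the `fc.0.weight` / `fc.0.bias` output above; no real model is involved.

```python
# Toy stand-in for model_ft.named_parameters(): (name, requires_grad) pairs.
# With feature extraction, the backbone is frozen and only the new head is trainable.
named_parameters = [
    ("conv1.weight", False),          # frozen backbone layer
    ("layer1.0.conv1.weight", False), # frozen backbone layer
    ("fc.0.weight", True),            # new classification head: trainable
    ("fc.0.bias", True),
]

feature_extract = True
if feature_extract:
    # Keep only the parameters that still require gradients
    params_to_update = [name for name, requires_grad in named_parameters if requires_grad]
else:
    params_to_update = [name for name, _ in named_parameters]

print(params_to_update)  # prints: ['fc.0.weight', 'fc.0.bias']
```

Only these surviving parameters are later handed to the optimizer, so the frozen backbone is never updated.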
7. Training and prediction
7.1 Optimizer setup

# Optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: the LR is multiplied by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The model's last layer applies LogSoftmax(), so nn.CrossEntropyLoss() cannot be used here;
# nn.NLLLoss() expects log-probabilities as its input
criterion = nn.NLLLoss()
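The LogSoftmax/NLLLoss pairing can be checked by hand: cross-entropy on raw scores equals the negative log-probability of the true class after log-softmax. A pure-Python sketch with toy scores for three classes (no torch):

```python
import math

def log_softmax(scores):
    # log softmax(x)_i = x_i - log(sum_j exp(x_j))
    log_sum = math.log(sum(math.exp(s) for s in scores))
    return [s - log_sum for s in scores]

def nll_loss(log_probs, target):
    # NLLLoss returns the negated log-probability of the target class
    return -log_probs[target]

scores = [2.0, 1.0, 0.1]  # toy raw outputs for 3 classes
target = 0

# Cross-entropy computed directly on the raw scores
ce = -scores[target] + math.log(sum(math.exp(s) for s in scores))
# NLLLoss applied to the log-softmax output: the same number
nll = nll_loss(log_softmax(scores), target)

assert abs(ce - nll) < 1e-12
```

This is why applying nn.CrossEntropyLoss() on top of a LogSoftmax layer would apply the softmax twice, while nn.NLLLoss() composes correctly.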
# Define the training function
# is_inception: whether the model is Inception (which has an auxiliary output)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Track the best validation accuracy
    best_acc = 0
    # Uncomment to resume from a saved checkpoint:
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on GPU or CPU
    model.to(device)
    # Histories kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all the data
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Compute gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        # Inception has an auxiliary classifier during training
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the class index with the highest score
                    _, preds = torch.max(outputs, 1)

                    # Update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model seen so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save a checkpoint of the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step()  # StepLR decays by epoch count and takes no metric argument
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
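Note why the loop multiplies loss.item() by inputs.size(0): each batch loss is a mean over that batch, so it must be re-weighted by the batch size before dividing by the dataset size, otherwise a smaller last batch would be over-weighted. A toy sketch with made-up batch losses:

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller)
batch_losses = [1.0, 0.5, 2.0]
batch_sizes = [8, 8, 4]

running_loss = 0.0
for mean_loss, n in zip(batch_losses, batch_sizes):
    running_loss += mean_loss * n  # undo the per-batch averaging

dataset_size = sum(batch_sizes)            # 20 samples in total
epoch_loss = running_loss / dataset_size   # size-weighted mean over the dataset

naive = sum(batch_losses) / len(batch_losses)  # unweighted mean of batch means
print(epoch_loss, naive)  # prints: 1.0 1.1666666666666667
```

The naive average of batch means (about 1.167) over-weights the half-size batch; the size-weighted version gives the true per-sample mean of 1.0.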
7.2 Start training the model

I trained for only 5 epochs here (training takes a very long time); increase the epoch count when you experiment yourself.

# If training is too slow, lower num_epochs; something like 50 epochs would likely do better
# While training, watch whether the loss falls and the accuracy rises, and whether the gap
# between training and validation is large: a large gap indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
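Following the overfitting advice above, the train/valid gap can be read straight off the log. A small sketch using the accuracies from the log (the 0.10 threshold is an arbitrary rule of thumb, not part of the original tutorial):

```python
# Accuracies copied from the training log above
train_acc = [0.3147, 0.7053, 0.4734, 0.6548, 0.5519]
valid_acc = [0.4719, 0.6626, 0.4413, 0.6027, 0.4914]

for epoch, (t, v) in enumerate(zip(train_acc, valid_acc)):
    gap = t - v
    flag = "  <- possible overfitting" if gap > 0.10 else ""
    print(f"epoch {epoch}: train {t:.4f}  valid {v:.4f}  gap {gap:+.4f}{flag}")
```

Here no epoch exceeds the 0.10 gap, so the wild swings in these numbers look more like learning-rate instability than overfitting.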
7.3 Train all layers

# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training, now over all parameters, with a smaller learning rate
# (note: pass model_ft.parameters() here, not the old params_to_update,
#  which only contained the fc head)
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()
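StepLR's effect can be sketched in plain Python: with step_size=7 and gamma=0.1, the learning rate is multiplied by 0.1 after every 7 epochs. A minimal sketch, assuming the fine-tuning base LR of 1e-4 set above:

```python
base_lr = 1e-4   # the fine-tuning learning rate set above
step_size = 7
gamma = 0.1

def step_lr(epoch):
    # lr(epoch) = base_lr * gamma ** (epoch // step_size)
    return base_lr * gamma ** (epoch // step_size)

for epoch in [0, 6, 7, 13, 14]:
    print(epoch, step_lr(epoch))
# epochs 0-6 keep 1e-4, epochs 7-13 drop to about 1e-5, epoch 14 to about 1e-6
```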
# Load the saved parameters and continue training from the previous best model
# filename still points at the checkpoint saved during the first training run
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
Start training

Note: training gets noticeably slower now that all layers are updated. My GPU is a GTX 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Load the trained model

This amounts to a single forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)

# Name of the saved checkpoint file
filename = 'checkpoint.pth'

# Load the model weights
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
, D9 D0 v6 n. |( f1
" d8 e8 W( a8 f; D% Q4 P2. E5 R* `; n+ g2 ^ s/ R
32 M, g; d% ?: H9 C& P
4
r; R: P: g u$ V56 s5 o8 `8 q% W+ f2 D
64 n# r* Y4 D! H8 H
7
, O# J, ?9 V2 V5 J! T( Y( N `8
0 ]+ K3 a. H! x9( ~5 Z, J, e. U; p# N6 I
100 o, [, b7 U4 C# G4 c Y- G. I4 \
11
3 A4 s% p! ~' a12
<All keys matched successfully>
def process_image(image_path):
    # Read a test image
    img = Image.open(image_path)
    # Resize: thumbnail() only shrinks, keeping the aspect ratio, so first check
    # which side is longer. Unlike resize(), which sets the exact output size,
    # thumbnail() scales proportionally, modifies the image in place, and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Crop the center 224 x 224 region
    # (PIL's crop box is (left, upper, right, lower); "bottom_margin" here
    #  is in fact the upper coordinate of the box)
    left_margin = (img.width - 224) / 2
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Apply the same preprocessing as in training:
    # scale to [0, 1], then normalize with the ImageNet mean and std
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the front: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img
5 P8 w8 M8 I3 T4 `/ p' l$ I
3 k' a) E# b& K5 hdef imshow(image, ax = None, title = None):
9 y5 @4 j: ?9 l6 I! \7 y# m """展示数据"""' M; l- z! E7 C. q
if ax is None:
. `) V m6 `2 c* ^, [ Y& Z fig, ax = plt.subplots()( t6 r% R- h1 V+ ^" r
8 E: x7 p+ {4 ~! \, Q
# 颜色通道进行还原6 k+ w; p0 \1 U+ _
image = np.array(image).transpose((1, 2, 0))$ D5 ]3 I) O; \, ~" Y9 _3 u8 Q* o
* o6 Y0 X C$ j4 r& |( F# P' A% w) @
# 预处理还原
/ u) \" i; n1 U) P+ x( c/ d mean = np.array([0.485, 0.456, 0.406]) P; p) M/ p/ b7 h3 A
std = np.array([0.229, 0.224, 0.225])
) J; C0 ~1 [( T4 q$ w. t. a image = std * image + mean
* u: v7 v" \1 v image = np.clip(image, 0, 1)* f4 T+ i3 E: b3 V. ~
, z% w( `- i; O! p
ax.imshow(image)
. b& Q' u+ H% f, Y3 v ax.set_title(title)2 `3 U5 ]5 |; n0 a2 V
' I7 I0 c0 Q( a8 J+ G return ax
* o3 L) w3 M( [. |
3 r7 n5 [. M7 Q7 rimage_path = r'./flower_data/valid/3/image_06621.jpg'/ u, N# E3 Q! C
img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理. `" K, p$ [9 B6 ]. l
imshow(img) N9 U% A0 J: q: B0 q# t0 w
<AxesSubplot:>
The above preprocesses a test-set image; check its shape to verify that the preprocessing function is correct:

img.shape

(3, 224, 224)

This confirms the channel dimension was moved to the front while the spatial size is unchanged at 224 x 224.
9. Inference

# Get one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch

model_ft.eval()

if train_on_gpu:
    # A single forward pass produces the outputs
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 images; each one gets 102 scores, one per class
output.shape
torch.Size([8, 102])
9.1 Compute the most likely class

_, preds_tensor = torch.max(output, 1)

# Move the predictions to the CPU if needed and squeeze them into a 1-D array
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
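torch.max(output, 1) reduces along dimension 1, returning for each of the 8 rows both the best score and its class index. The same per-row argmax in plain Python (toy scores with 4 classes instead of 102):

```python
# Toy batch of 2 score rows with 4 classes each
output = [
    [0.1, 2.5, 0.3, 1.0],  # best class: index 1
    [1.7, 0.2, 0.4, 1.6],  # best class: index 0
]

# For each row, pick the index of the largest score
preds = [max(range(len(row)), key=row.__getitem__) for row in output]
print(preds)  # prints: [1, 0]
```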
9.2 Display the predictions

fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions, red titles mark incorrect ones
————————————————
Copyright notice: this is an original article by CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940