[Deep Learning] Hands-on Image Recognition: a 102-Class Flower Classification (Flower-102) Case Study
Table of Contents

Hands-on CNNs: classifying flowers
Data preprocessing
Network module setup
Saving and testing the model
Data download
1. Import packages
2. Data preprocessing and operations
3. Build the data source
    Reading the real names behind the labels
4. Visualize the data
5. Load a torchvision model and initialize it with pretrained weights
6. Initialize the model architecture
7. Set the parameters to be trained
7. Training and prediction
    7.1 Optimizer setup
    7.2 Start training the model
    7.3 Train all layers
    Start training
8. Load the trained model
9. Inference
    9.1 Take the most probable class
    9.2 Show the predictions
Closing remarks
Hands-on CNNs: classifying flowers

This article tackles a classification task on the Oxford flower dataset. We build a fairly general-purpose training pipeline (implemented here with ResNet) that exercises the common PyTorch workflow: preprocessing, training, model saving, and model loading.

The dataset folder contains 102 flower classes, and our job is to classify them.

Folder structure:

flower_data
    train
        1   (one folder per class id)
        2
            xxx.png / xxx.jpg
        ...
    valid
        (same layout as train)

The pipeline breaks down into the following modules:

Data preprocessing
    Data augmentation
    Preprocessing/normalization
Network module setup
    Load a pretrained model straight from torchvision's classic architectures
    The pretrained task is usually 1000-way classification (not necessarily the same as ours), so the head must be adapted to our own task
Saving and testing the model
    Saving can be selective (e.g. keep only the best checkpoint)

Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the extracted folder and place it in the same root directory as the code.

Below is my data root directory.
1. Import packages

```python
import os
import matplotlib.pyplot as plt
# render figures inline in the notebook, no plt.show() needed
%matplotlib inline

import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
```
2. Data preprocessing and operations

```python
# path setup
data_dir = './flower_data/'      # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
```
See also: "Combinations of dot and slash in Python paths" (a linked post explaining the `./` and `/` operations).
3. Build the data source

data_transforms specifies all the image-preprocessing operations.
ImageFolder assumes images are organized into folders, one folder per class.

```python
data_transforms = {
    # two pipelines: one for training, one for validation
    'train': transforms.Compose([
        transforms.RandomRotation(45),           # random rotation between -45 and +45 degrees
        transforms.CenterCrop(224),              # crop from the center
        transforms.RandomHorizontalFlip(p=0.5),  # flip horizontally with 50% probability
        transforms.RandomVerticalFlip(p=0.5),    # flip vertically with 50% probability
        # brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.RandomGrayscale(p=0.025),     # convert to grayscale with small probability;
                                                 # the result still has 3 channels, with R = G = B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, std
    ]),
    # resize to 256, center-crop to 224, convert to a tensor, then normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean/std as training
    ]),
}
```
```python
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# inspect the datasets
image_datasets
```
```
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
```
```python
# verify that the dataloaders were built
dataloaders
```

```
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
```
```python
dataset_sizes
```

```
{'train': 6552, 'valid': 818}
```
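The batching behaviour of these loaders can be checked in isolation: with `batch_size = 8`, a DataLoader splits N samples into batches of 8 plus a remainder. A minimal sketch with toy tensors standing in for the real images:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 20 fake "images" with random labels drawn from the 102 classes
ds = TensorDataset(torch.randn(20, 3, 224, 224), torch.randint(0, 102, (20,)))
loader = DataLoader(ds, batch_size=8, shuffle=True)

batch_sizes = [x.shape[0] for x, y in loader]
print(batch_sizes)   # [8, 8, 4]
```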
Reading the real names behind the labels

The class folders are numbered, so we use a JSON file in the same directory to map each label back to the flower's actual name.

```python
with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
```
```
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
```
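One subtlety worth knowing before using cat_to_name: ImageFolder assigns class indices by sorting the folder names as strings, which is why the plotting code later goes index -> folder name (`class_names`) -> readable name (`cat_to_name`). A toy mapping (a stand-in for the real JSON, not the full 102 entries) shows the chain and the string-sort pitfall:

```python
# toy stand-in for cat_to_name; the real file has 102 entries
cat_to_name = {'1': 'pink primrose', '2': 'hard-leaved pocket orchid', '10': 'globe thistle'}

# ImageFolder sorts folder names as strings, so '10' comes before '2'
class_names = sorted(cat_to_name.keys())
print(class_names)               # ['1', '10', '2']

pred_idx = 1                     # an index as the model would predict it
folder = class_names[pred_idx]   # '10'
print(cat_to_name[folder])       # globe thistle
```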
4. Visualize the data

```python
def im_convert(tensor):
    """Convert a normalized tensor back to a displayable numpy image."""
    image = tensor.to("cpu").clone().detach()
    # squeeze drops any singleton batch dimension so the image can be plotted
    image = image.numpy().squeeze()
    # transpose swaps axes: the tensor is (c, h, w), plotting needs (h, w, c)
    image = image.transpose(1, 2, 0)
    # undo the normalization (de-standardize)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clip values below 0 to 0 and above 1 to 1
    image = image.clip(0, 1)

    return image
```
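The de-normalization step in im_convert is the exact inverse of the Normalize transform: Normalize computes x_norm = (x - mean) / std per channel, so multiplying by std and adding the mean recovers the original pixels. A numpy round-trip on a fake image makes this concrete:

```python
import numpy as np

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

x = np.random.rand(224, 224, 3)   # fake image in [0, 1], (h, w, c) layout
x_norm = (x - mean) / std         # what transforms.Normalize does
x_back = x_norm * std + mean      # what im_convert undoes
print(np.allclose(x, x_back))     # True
```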
```python
# plot a grid of images using the helper defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# grab an arbitrary batch from the validation loader for display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)   # dataiter.next() was removed in newer PyTorch

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # use the JSON mapping to print each flower's name above its image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
```
5. Load a torchvision model and initialize it with pretrained weights

```python
# several architectures are available: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
model_name = 'resnet'  # we use ResNet as the main image-recognition model

# whether to reuse the pretrained features
feature_extract = True
```
1
) b$ V4 `3 h# Y6 F4 F: s26 A! c) S- T6 A5 A/ A1 |2 ^
3
/ U' c) _/ D. b. W `3 O( T4
```python
# check whether a GPU is available for training
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```
```
CUDA is not available. Training on CPU ...
```
```python
# freeze selected layers by turning off gradient updates for their parameters
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
```
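The effect of this helper can be verified on a tiny stand-in network, with no pretrained download needed. The helper is repeated in the snippet so it is self-contained; after the call, every parameter reports requires_grad == False:

```python
import torch
from torch import nn

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# a small throwaway network standing in for the ResNet backbone
tiny = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
set_parameter_requires_grad(tiny, True)
frozen = all(not p.requires_grad for p in tiny.parameters())
print(frozen)   # True
```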
```python
# print the architecture to see how the model is assembled step by step;
# this backbone will mainly serve as our feature extractor
model_ft = models.resnet152()
model_ft
```
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

  ... (much of the output is omitted here; the first and last blocks are what matter) ...

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
```
The final layer is a 1000-way classifier: 2048 input features mapped to 1000 classes.
For our task we need to adjust that head, changing the 1000-class output to 102.

6. Initialize the model architecture

The steps are:

1. Load the pretrained model (use_pretrained=True) to reuse the weights others have trained.
2. Optionally freeze chosen layers (set their gradient updates, requires_grad, to False).
3. Whether the task is classification or regression, replace the final FC layer with one sized for the task.

Official docs:
https://pytorch.org/vision/stable/models.html
```python
# load a pretrained model and adapt its head to our task
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # pick the architecture; each one is initialized slightly differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. optionally freeze the feature-extraction layers and train only the FC layer
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. read the number of input features of the FC layer
        num_frts = model_ft.fc.in_features
        # 4. replace the FC layer with a 102-way output
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
                                    nn.LogSoftmax(dim=1))  # dim defaults to 0 (over columns);
                                                           # dim=1 normalizes each row so it sums to 1
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # replace classifier[6], the final layer of the classifier block
        num_frts = model_ft.classifier[6].in_features  # FC input size
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Inception has an auxiliary head; replace both classifiers
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
```
7. Set the parameters to be trained
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Run on the GPU
model_ft = model_ft.to(device)

# Checkpoint file: the trained model is saved here and can be loaded back later
filename = 'checkpoint.pth'

# Decide whether all layers are trained
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction
7.1 Optimizer setup
# Optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: every 7 epochs the learning rate drops to 1/10 of its value
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The last layer is LogSoftmax(), so nn.CrossEntropyLoss() cannot be used here;
# nn.NLLLoss() expects log-probabilities as its input
criterion = nn.NLLLoss()
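The comment above is the key constraint: nn.NLLLoss expects log-probabilities, which is why the classifier head ends in LogSoftmax. That LogSoftmax followed by NLLLoss together equal ordinary cross-entropy can be checked with a small numpy sketch (the logits and target below are made up for illustration):

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax: shift by the max before exponentiating
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())

def nll_loss(log_probs, target):
    # Negative log-likelihood picks out the log-probability of the true class
    return -log_probs[target]

logits = np.array([2.0, 0.5, -1.0])  # made-up scores for 3 classes
target = 0

loss_a = nll_loss(log_softmax(logits), target)           # LogSoftmax + NLLLoss
loss_b = -logits[target] + np.log(np.exp(logits).sum())  # cross-entropy computed directly
print(np.isclose(loss_a, loss_b))  # the two formulations agree
```

This is exactly why a model whose head already takes the log-softmax must not be paired with nn.CrossEntropyLoss, which applies log-softmax internally a second time.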
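StepLR with step_size=7 and gamma=0.1 multiplies the learning rate by 0.1 every 7 epochs, provided scheduler.step() is called exactly once per epoch. The resulting schedule is simple enough to reproduce by hand (starting from the lr=1e-2 set above):

```python
def steplr(lr0, epoch, step_size=7, gamma=0.1):
    # StepLR keeps the learning rate piecewise constant:
    # lr(epoch) = lr0 * gamma ** (epoch // step_size)
    return lr0 * gamma ** (epoch // step_size)

lrs = [steplr(1e-2, e) for e in range(15)]
# The rate drops by 10x at epochs 7 and 14
print(lrs[0], lrs[7], lrs[14])
```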
# Training function
# is_inception: whether an Inception-style network (with an auxiliary output) is used
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy seen so far
    best_acc = 0
    # Uncomment to resume from a saved checkpoint:
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or the CPU
    model.to(device)
    # Histories, kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over every batch
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Compute and update gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # The inception branch also produces an auxiliary output
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the class with the highest score
                    _, preds = torch.max(outputs, 1)

                    # Update the weights in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model seen so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # StepLR.step() takes no loss argument; call it once per epoch
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, use the best epoch's weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
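The checkpoint written by train_model is just a Python dict serialized to disk (torch.save pickles it under the hood). The save/restore round trip can be sketched with a plain dict and the standard library; the keys mirror the ones used above, and the values below are stand-ins for real state_dict() results:

```python
import os
import pickle
import tempfile

# Stand-ins for model and optimizer state
state = {
    'state_dict': {'fc.0.weight': [0.1, 0.2], 'fc.0.bias': [0.0]},
    'best_acc': 0.66,
    'optimizer': {'lr': 1e-2},
}

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
with open(path, 'wb') as f:
    pickle.dump(state, f)        # plays the role of torch.save(state, filename)

with open(path, 'rb') as f:
    checkpoint = pickle.load(f)  # plays the role of torch.load(filename)

# Everything comes back intact, ready for load_state_dict-style restoring
print(checkpoint['best_acc'])
```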
7.2 Start training the model
I only trained for 5 epochs here (training takes a very long time); increase the number of epochs when you experiment yourself.

# If it is too slow, lower num_epochs; around 50 iterations would likely give better results
# While training, check that the loss keeps dropping and the accuracy keeps rising;
# a large gap between training and validation indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
7.3 Train all layers
# Unlock the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all of the parameters, with a smaller learning rate
# (pass model_ft.parameters() here, since params_to_update may hold only the classifier head,
#  and give the scheduler the new optimizer, not optimizer_ft)
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()
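Both the params_to_update filter in section 7 and the unlocking loop above hinge on the requires_grad flag. The pattern can be sketched with plain stand-in objects in place of real torch parameters (the layer names below are invented):

```python
from types import SimpleNamespace

# Hypothetical named parameters of a frozen backbone plus a trainable head
params = {
    'conv1.weight': SimpleNamespace(requires_grad=False),
    'layer4.conv.weight': SimpleNamespace(requires_grad=False),
    'fc.0.weight': SimpleNamespace(requires_grad=True),
}

# Feature extraction: the optimizer only sees parameters with requires_grad=True
to_update = [name for name, p in params.items() if p.requires_grad]
print(to_update)  # only the classifier head

# Fine-tuning: unlock everything, exactly like the loop above
for p in params.values():
    p.requires_grad = True
print([name for name, p in params.items() if p.requires_grad])  # now all layers
```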
# Load the saved parameters and continue training on top of the existing model
# filename points at the best checkpoint saved during the previous run
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
Start training
Note: training becomes noticeably slower here; my GPU is a GTX 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Load the trained model
This amounts to a simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)

# Name of the checkpoint file
filename = 'checkpoint.pth'

# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
" J% D9 a }) F<All keys matched successfully>: ~+ Y; K4 r; l i9 N1 F. K2 Z7 c
1- Q6 w% [4 u T& g6 C3 A; U* k$ O
def process_image(image_path):
    # Read a test image
    img = Image.open(image_path)
    # Resize: thumbnail() only shrinks proportionally, so check the orientation first.
    # Unlike resize(), whose size argument sets the exact output size,
    # thumbnail() keeps the aspect ratio, modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Crop the central 224 x 224 region
    left_margin = (img.width - 224) / 2
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Apply the same preprocessing as in training:
    # scale to [0, 1], then normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the front: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image"""
    if ax is None:
        fig, ax = plt.subplots()

    # Restore the channel order: (C, H, W) -> (H, W, C)
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # this function can be reused for any test image
imshow(img)
<AxesSubplot:>
The above shows a test-set image after preprocessing. We can check shape to confirm the image size and verify that the preprocessing function is correct:

img.shape

(3, 224, 224)

This confirms that the channel dimension was moved to the front and the spatial size is unchanged at 224 x 224.
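process_image and imshow are inverses of each other as far as normalization and channel order go, which a quick numpy check confirms (a random array stands in for a real image file):

```python
import numpy as np

rng = np.random.default_rng(0)
fake = rng.random((224, 224, 3))  # stand-in for a loaded RGB image scaled to [0, 1]

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Forward path, as in process_image: normalize, then (H, W, C) -> (C, H, W)
x = ((fake - mean) / std).transpose((2, 0, 1))

# Inverse path, as in imshow: (C, H, W) -> (H, W, C), then un-normalize
restored = std * x.transpose((1, 2, 0)) + mean

print(x.shape, np.allclose(restored, fake))  # (3, 224, 224) True
```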
9. Inference
# Get one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

if train_on_gpu:
    # One forward pass produces the outputs
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 images, and each image gets 102 scores, one per class
output.shape
torch.Size([8, 102])
9.1 Compute the most probable class
_, preds_tensor = torch.max(output, 1)

# Move the predictions to the CPU if needed and squeeze away the singleton dimensions
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
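torch.max(output, 1) returns both the highest score and its index along the class dimension; the index is the predicted class. The same reduction in numpy, on a made-up 2x3 score matrix:

```python
import numpy as np

scores = np.array([[0.1, 2.3, -0.5],
                   [1.7, 0.2, 0.9]])  # 2 samples, 3 classes (made up)

values = scores.max(axis=1)    # the highest score per sample
preds = scores.argmax(axis=1)  # its column index = the predicted class

print(preds)  # [1 0]
```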
6 v, O+ p" k- O( g8 L9.2 展示预测结果+ L% d0 Z* n1 C. b+ F4 l& ~
fig = plt.figure(figsize = (20, 20))
0 } ~; \- x7 |/ ^columns = 4
. Y! K2 @7 g. O+ `9 ^, `5 d, hrows = 2
2 M! T9 L1 C5 B0 B
+ {& w7 q% T+ n# h% {for idx in range(columns * rows):
3 w T7 B4 c/ W ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
& q% @' q* N4 w/ t; O4 g5 b- F plt.imshow(im_convert(images[idx]))4 B+ k, I( y/ F: d$ s
ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
6 p( X/ ]: G, f J, p6 b color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))" [4 d. ] @. V& A
plt.show()
+ M' ?; e$ G, a- h* ]4 k2 `- K( [# 绿色的表示预测是对的,红色表示预测错了
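cat_to_name (built earlier from the dataset's JSON mapping) maps class indices, as strings, to flower names; the title logic above boils down to a dict lookup and a comparison. A self-contained sketch (the mapping entries and indices below are invented for illustration):

```python
# Hypothetical excerpt of the index -> name mapping
cat_to_name = {'1': 'pink primrose', '3': 'canterbury bells'}

pred, label = 1, 3  # predicted vs. true class index for one image
title = "{} ({})".format(cat_to_name[str(pred)], cat_to_name[str(label)])
color = "green" if cat_to_name[str(pred)] == cat_to_name[str(label)] else "red"
print(title, color)  # the prediction shown against the ground truth
```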
" ?& t9 w2 w# ~8 e+ N1. O; i3 v) q3 c# _: H/ K' U
2+ `! w8 ^: T6 s& g$ Z# s& q+ u
3
' o3 x+ t) e2 f, j; u4# _& V6 e7 B R1 b& M
5
7 b' w6 X5 }3 j! l4 _9 ]( l' b6& z( W# j! t# }* {1 d) K; g+ N' @. |
7
0 D1 C5 Q* ?" }' ?+ {8. C( F8 X$ o, V. X# X) l# x
9; K/ U6 y( Y! x& d! n4 k1 _
10* V4 ]/ i8 h/ l/ x' y: a
11
& V6 D: }) q( V1 C
9 s8 W( c: m0 B' g) B! d4 y0 T; {9 l/ _; P B4 E1 c
2 `" t8 q# I A! _5 Q
————————————————
Copyright notice: this is an original article by CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license; please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940