[Deep Learning] Image Recognition in Practice: 102-Category Flower Classification (Flower 102) Case Study
Contents
Convolutional Networks in Practice: Classifying Flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download:
1. Import the toolkits
2. Data preprocessing and operations
3. Build the data sources
Read the actual names corresponding to the labels
4. Display the data
5. Load a model provided by torchvision.models and initialize it directly with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer setup
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the maximum probability
9.2 Show the prediction results
Closing remarks
Convolutional Networks in Practice: Classifying Flowers
This article performs a classification task on the Oxford flower dataset (Flower 102). It builds a general-purpose neural network pipeline, implemented mainly with ResNet, that combines common PyTorch operations: preprocessing, training, saving the model, and loading the model.

The dataset folder contains 102 kinds of flowers, and our task is to classify them.
Folder structure:
flower_data/
    train/
        1/   (one folder per class)
        2/
            xxx.png / xxx.jpg
    valid/
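ImageFolder depends on the layout shown above. Before training, it is worth confirming that the download was renamed and placed correctly. Below is a minimal sketch using only the standard library. `check_flower_layout` is a hypothetical helper, not part of torchvision. The checks it makes are assumptions chosen for illustration: a `train` and a `valid` folder, each holding non-empty class folders. To run without the real dataset, the sketch builds a tiny fake layout in a temporary directory.

```python
import os
import tempfile


def check_flower_layout(data_dir):
    """Hypothetical helper: list problems found in an ImageFolder-style layout."""
    problems = []
    for split in ('train', 'valid'):
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            problems.append(f"missing folder: {split_dir}")
            continue
        # Every sub-folder of train/ or valid/ is treated as one class
        class_dirs = [d for d in os.listdir(split_dir)
                      if os.path.isdir(os.path.join(split_dir, d))]
        if not class_dirs:
            problems.append(f"no class folders in {split_dir}")
        for d in class_dirs:
            if not os.listdir(os.path.join(split_dir, d)):
                problems.append(f"empty class folder: {os.path.join(split_dir, d)}")
    return problems


with tempfile.TemporaryDirectory() as root:
    data_dir = os.path.join(root, 'flower_data')
    for split in ('train', 'valid'):
        for label in ('1', '2'):
            os.makedirs(os.path.join(data_dir, split, label))
            # An empty placeholder file stands in for a real image
            open(os.path.join(data_dir, split, label, 'xxx.jpg'), 'wb').close()
    ok_problems = check_flower_layout(data_dir)

    # A class folder with no images in it
    os.makedirs(os.path.join(data_dir, 'valid', '3'))
    bad_problems = check_flower_layout(data_dir)

    # Pointing at the wrong directory reports both splits as missing
    missing_problems = check_flower_layout(os.path.join(root, 'no_such_dir'))

print(ok_problems, bad_problems, missing_problems, sep='\n')
```

Catching a misplaced or empty folder this early is usually easier than tracing the error that surfaces later, when the datasets are built.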
The pipeline is divided into the following main modules:

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model by calling one of the classic architectures in torchvision directly
    The pretrained model was trained for someone else's task, which may be, for example, 1000-class classification with classes different from ours, so its output layer must be changed to fit our own task
Saving and testing the network model
    Models can be saved selectively
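"Saving selectively" usually means writing a checkpoint only when a monitored metric improves. Below is a minimal sketch of that idea, under several assumptions. `save_if_best` is a hypothetical helper. The metric is validation accuracy. The checkpoint is a dict with keys `'state_dict'` and `'best_acc'`, which are names chosen here for illustration. A one-layer `nn.Linear` stands in for the real network, and the accuracies are made-up numbers. The file name `checkpoint.pth` matches the one used later in the article.

```python
import os
import tempfile

import torch
from torch import nn


def save_if_best(model, val_acc, best_acc, path):
    """Hypothetical helper: write a checkpoint only when validation accuracy improves."""
    if val_acc > best_acc:
        torch.save({'state_dict': model.state_dict(), 'best_acc': val_acc}, path)
        return val_acc, True
    return best_acc, False


model = nn.Linear(4, 2)  # stand-in for the real network
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'checkpoint.pth')
    best_acc = 0.0
    saved_flags = []
    for val_acc in [0.50, 0.72, 0.65, 0.80]:  # made-up accuracies, one per epoch
        best_acc, saved = save_if_best(model, val_acc, best_acc, path)
        saved_flags.append(saved)

    # Reload the best checkpoint into a fresh model of the same shape
    checkpoint = torch.load(path)
    restored = nn.Linear(4, 2)
    restored.load_state_dict(checkpoint['state_dict'])

print(saved_flags)             # [True, True, False, True]
print(checkpoint['best_acc'])  # 0.8
```

Epoch 3 (0.65) does not beat the best so far (0.72), so nothing is written for it. The file always holds the best model seen.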
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the folder to flower_data (the name used in the code below), then put it in the same root directory as your code.

[Figure: my data root directory]
1. Import the toolkits
import os
import matplotlib.pyplot as plt
# Jupyter magic: draw plots inline in the notebook, so plt.show() is not required (notebooks only)
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations
# Path settings
data_dir = './flower_data/'  # the flower_data folder inside the current directory
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
Reference: how dot and slash combine in Python paths, and how the forms differ
Note: that article explains what the dot-slash (./) and slash (/) forms mean.
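Below is a small sketch of how these path forms behave. It uses `posixpath`, the POSIX flavour of `os.path`, so the results are the same on every OS. On Windows, `os.path` would use backslashes instead. The values mirror `data_dir` above. Note that plain string concatenation of `data_dir + '/train'` would produce a doubled slash, which `os.path.join` avoids.

```python
import posixpath  # POSIX flavour of os.path, so the output is identical on every OS

data_dir = './flower_data/'

# Plain concatenation keeps both slashes
concatenated = data_dir + '/train'          # './flower_data//train'

# join() inserts a separator only when one is missing
joined = posixpath.join(data_dir, 'train')  # './flower_data/train'

# normpath() collapses '//' and drops the leading './'
normalized = posixpath.normpath(concatenated)  # 'flower_data/train'

# '..' means the parent of the current working directory
parent_relative = posixpath.join('..', 'flower_data', 'train')  # '../flower_data/train'

# Pitfall: a component starting with '/' is absolute and discards everything before it
absolute_trap = posixpath.join(data_dir, '/train')  # '/train'

for label, path in [('concatenated', concatenated), ('joined', joined),
                    ('normalized', normalized), ('parent_relative', parent_relative),
                    ('absolute_trap', absolute_trap)]:
    print(f"{label:16} {path}")
```

The last case is the one most likely to bite. Writing `os.path.join(data_dir, '/train')` silently throws away `data_dir`.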
3. Build the data sources
data_transforms specifies all of the image preprocessing operations.
ImageFolder assumes the files are already organized into folders, with each folder holding the images of a single class.
data_transforms = {
    # Two pipelines; the first one is for training and includes data augmentation
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # rotate by a random angle between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop a 224 x 224 region from the center
        # each flip below is applied at random with probability 0.5 (a 50/50 chance)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # convert to grayscale with probability 0.025; the result still has three
        # channels, but the R, G and B values are identical
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel mean, std
    ]),
    # Validation: resize the shorter side to 256, take the central 224 x 224 crop,
    # convert to a tensor, and finally normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training pipeline
    ]),
}
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
# Inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
# Check that the data loaders have been created
dataloaders
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
dataset_sizes
{'train': 6552, 'valid': 818}
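These sizes determine how many batches one epoch contains. A DataLoader keeps the final, partially filled batch by default (`drop_last=False`), so the count is the ceiling of size / batch_size. That is also what `len(dataloaders['train'])` should report. Below is a small worked sketch using the sizes printed above.

```python
import math

dataset_sizes = {'train': 6552, 'valid': 818}  # values printed above
batch_size = 8

# With drop_last=False the last, smaller batch still counts, hence the ceiling
batches_per_epoch = {split: math.ceil(n / batch_size) for split, n in dataset_sizes.items()}
print(batches_per_epoch)  # {'train': 819, 'valid': 103}

# 6552 = 819 * 8 divides evenly; 818 = 102 * 8 + 2 leaves a last batch of 2 images
last_valid_batch = dataset_sizes['valid'] - (batches_per_epoch['valid'] - 1) * batch_size
print(last_valid_batch)  # 2
```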
Read the actual names corresponding to the labels
The JSON file in the same data directory (cat_to_name.json) maps each label back to the name of its flower.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
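One subtlety links these names to model outputs. ImageFolder builds `classes` by sorting the folder names as strings. The model's index 0 is therefore folder `'1'`, but index 1 is folder `'10'`, not `'2'`. This is why the display code goes through `class_names` before looking up `cat_to_name`. Below is a sketch that simulates the sort without the real dataset, assuming the folders are named `'1'` to `'102'` as in this dataset. Only a three-entry excerpt of `cat_to_name` is used.

```python
# Stand-in for the 102 class folders named '1' ... '102'
folder_names = [str(i) for i in range(1, 103)]

# ImageFolder sorts class folder names as strings, so '10' comes before '2'
classes = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(classes)}

print(classes[:5])        # ['1', '10', '100', '101', '102']
print(class_to_idx['2'])  # 14

# Predicted index -> folder name (class_names[idx]) -> flower name (cat_to_name[folder])
cat_to_name_excerpt = {'1': 'pink primrose', '10': 'globe thistle', '2': 'hard-leaved pocket orchid'}
predicted_idx = 1
folder = classes[predicted_idx]
print(folder, cat_to_name_excerpt[folder])  # 10 globe thistle
```

Looking up `cat_to_name[str(predicted_idx + 1)]` directly would look natural, but it gives the wrong flower for most indices.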
4. Display the data
def im_convert(tensor):
    """Convert a normalized image tensor back into a NumPy image that matplotlib can display."""
    image = tensor.to("cpu").clone().detach()
    # squeeze() drops any dimensions of size 1, which makes plotting easier
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is (C, H, W), but matplotlib expects (H, W, C)
    image = image.transpose(1, 2, 0)
    # Undo the normalization: x * std + mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # Clip values below 0 to 0 and values above 1 to 1
    image = image.clip(0, 1)

    return image
# Plot with the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# Create an iterator over the validation loader and take one batch to display
dataiter = iter(dataloaders['valid'])
# Use the built-in next(); recent PyTorch versions removed the iterator's .next() method
inputs, classes = next(dataiter)

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to show the flower's name as the subplot title
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
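`im_convert` is the inverse of the last two preprocessing steps. ToTensor + Normalize turn an (H, W, C) image in [0, 1] into a standardized (C, H, W) tensor. `im_convert` transposes back and computes x * std + mean. Below is a NumPy-only sketch that checks the round trip on a tiny random "image". `normalize_chw` and `denormalize_hwc` are hypothetical names used only for illustration.

```python
import numpy as np

MEAN = np.array((0.485, 0.456, 0.406))
STD = np.array((0.229, 0.224, 0.225))


def normalize_chw(image_hwc):
    """Mimic ToTensor + Normalize: (H, W, C) in [0, 1] -> standardized (C, H, W)."""
    return ((image_hwc - MEAN) / STD).transpose(2, 0, 1)


def denormalize_hwc(tensor_chw):
    """Mimic im_convert: (C, H, W) -> (H, W, C), then x * std + mean, clipped to [0, 1]."""
    image = tensor_chw.transpose(1, 2, 0)
    return (image * STD + MEAN).clip(0, 1)


rng = np.random.default_rng(0)
original = rng.random((4, 5, 3))  # tiny random "image" with values in [0, 1)
normalized = normalize_chw(original)
restored = denormalize_hwc(normalized)

print(normalized.shape, restored.shape)  # (3, 4, 5) (4, 5, 3)
print(np.allclose(original, restored))   # True
```

The clip in `im_convert` matters for real batches. Augmentations such as ColorJitter can push values slightly outside [0, 1], and matplotlib warns about such values.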
5. Load a model provided by torchvision.models and initialize it directly with pretrained weights
model_name = 'resnet'  # many models to choose from: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is used for the main image recognition task
# Whether to reuse the features the pretrained model has already learned (True = freeze them)
feature_extract = True
# Check whether a GPU can be used for training
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# Set requires_grad to False so these layers are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
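Feature extraction combines this helper with a new output layer. First freeze everything, then replace the head. The freshly created layer has `requires_grad=True` by default, so it is the only part that trains. Below is a minimal sketch on a toy network. The small `nn.Sequential` (4 → 3 → 2) is a hypothetical stand-in for ResNet, and `count_trainable` is a helper introduced only for this illustration.

```python
from torch import nn


def set_parameter_requires_grad(model, feature_extracting):
    """Same helper as above: freeze every parameter when feature_extracting is True."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Toy "backbone + head" network standing in for ResNet
toy = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
before = count_trainable(toy)          # (4*3 + 3) + (3*2 + 2) = 23

set_parameter_requires_grad(toy, feature_extracting=True)
frozen = count_trainable(toy)          # 0: everything is frozen

# Replacing the head: the new layer's parameters require gradients by default
toy[2] = nn.Linear(3, 2)
after_new_head = count_trainable(toy)  # 3*2 + 2 = 8

print(before, frozen, after_new_head)  # 23 0 8
```

The order matters. If the head were replaced first and the whole model frozen afterwards, nothing would be trainable.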
# Print the model architecture to see how it is built, step by step
# (here the network's main job is to extract features for us)
model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
(The output in between is very long. We only need to look at the top levels of the architecture, so it is omitted here.)
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The last layer is a 1000-class classifier: it takes 2048 input features and maps them to 1000 classes.
For our task we need to adapt it, replacing the 1000-class output with a 102-class output.
6. Initialize the model architecture
The steps are as follows:

Take the already-trained model and pass pretrained=True to get the weights someone else has trained.
Optionally choose layers to freeze; for those layers, turn off gradient updates (set requires_grad to False).
Whether the task is classification or regression, replace the final FC layer so that its output matches the task.
Official documentation:
https://pytorch.org/vision/stable/models.html
# Load a model someone else has trained
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Choose a suitable model; each architecture needs slightly different initialization.
    # Note: torchvision 0.13+ prefers weights="DEFAULT" over the deprecated pretrained=True.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet-152
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature-extraction layers so that only the FC layer is trained
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_ftrs = model_ft.fc.in_features
        # 4. Replace the fully connected layer with one that outputs num_classes (102 here)
        model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes),
                                    nn.LogSoftmax(dim=1))  # dim=1: log-probabilities over the classes of each sample (each row)
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Replace the last layer of the classifier, classifier[6]
        num_ftrs = model_ft.classifier[6].in_features  # input size of the FC layer
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0: the classifier is a 1x1 convolution rather than a Linear layer
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3: expects 299 x 299 inputs and has an auxiliary classifier
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Replace the auxiliary classifier's output layer
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Replace the main classifier's output layer
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        raise ValueError(f"Invalid model name: {model_name}")

    return model_ft, input_size
7. Set the parameters to train
# Model name and number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU if one is available
model_ft = model_ft.to(device)

# Checkpoint file: it stores the trained model so that it can be loaded directly later
filename = 'checkpoint.pth'

# Decide whether all layers are trained
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
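Why only fc.0.weight and fc.0.bias show up: in feature-extraction mode the pretrained layers are frozen, and only the newly created head keeps requires_grad=True. The sketch below reproduces this on a small hypothetical TinyNet standing in for ResNet; TinyNet and prepare_feature_extractor are illustrative, not the article's initialize_model, and the head is assumed to be nn.Sequential(nn.Linear(...), nn.LogSoftmax(dim=1)), which matches the fc.0.* names printed above.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet: a small backbone plus an `fc` head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.fc = nn.Linear(8, 1000)  # original 1000-class head

    def forward(self, x):
        return self.fc(self.backbone(x))


def prepare_feature_extractor(model, num_classes=102):
    # Freeze every existing (pretrained) parameter
    for param in model.parameters():
        param.requires_grad = False
    # Replace the head; newly created layers have requires_grad=True by default
    in_features = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(in_features, num_classes), nn.LogSoftmax(dim=1))
    # Keep only the parameters that still require gradients
    names = [n for n, p in model.named_parameters() if p.requires_grad]
    params = [p for p in model.parameters() if p.requires_grad]
    return names, params


model = TinyNet()
names, params = prepare_feature_extractor(model)
print(names)  # ['fc.0.weight', 'fc.0.bias']
out = model(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 102])
```

Only params would be handed to the optimizer, so the frozen backbone acts as a fixed feature extractor.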
7. Training and prediction
7.1 Optimizer setup
# Optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: multiply the learning rate by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The last layer already applies LogSoftmax, so the matching loss is nn.NLLLoss;
# nn.CrossEntropyLoss applies log-softmax itself and expects raw scores
criterion = nn.NLLLoss()
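A quick check of the loss choice: nn.CrossEntropyLoss on raw scores gives the same value as nn.NLLLoss on LogSoftmax outputs, so a model whose last layer is LogSoftmax pairs naturally with nn.NLLLoss. The random logits and targets below are made up purely for the comparison.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 102)             # raw scores for a hypothetical batch of 8
targets = torch.randint(0, 102, (8,))    # hypothetical class labels

log_probs = nn.LogSoftmax(dim=1)(logits)        # what the model's last layer outputs
nll = nn.NLLLoss()(log_probs, targets)          # the loss used in this article
ce = nn.CrossEntropyLoss()(logits, targets)     # log-softmax + NLL in one step

print(torch.allclose(nll, ce))  # True
```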
# Define the training function
# is_inception: True when the model is Inception, which also returns auxiliary outputs during training
# Note: the learning-rate scheduler used below is the global variable scheduler
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy so far
    best_acc = 0
    # Optional: resume from an existing checkpoint (disabled)
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    """
    # Move the model to the GPU or CPU
    model.to(device)
    # Histories kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all batches
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Only compute and track gradients during training
                with torch.set_grad_enabled(phase == 'train'):
                    # The Inception branch is not used here and can be ignored
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds is the index of the highest-scoring class
                    _, preds = torch.max(outputs, 1)

                    # Update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model seen so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learnable weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # StepLR counts epochs itself, so call step() without an argument;
                # passing epoch_loss here would be read as an epoch number
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, use the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
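Why running_loss adds loss.item() * inputs.size(0): the criterion returns the mean loss of a batch, so weighting each batch mean by its batch size and dividing by the dataset length gives the exact per-sample average, even when the last batch is smaller. The per-sample losses below are made-up numbers that only illustrate the arithmetic; epoch_average is a hypothetical helper.

```python
def epoch_average(batch_means, batch_sizes):
    # Undo the per-batch averaging, then average over all samples
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical per-sample losses split into batches of 4, 4 and 2
samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
batches = [samples[0:4], samples[4:8], samples[8:10]]
batch_means = [sum(b) / len(b) for b in batches]  # what a mean-reduced loss returns
batch_sizes = [len(b) for b in batches]           # what inputs.size(0) returns

naive = sum(batch_means) / len(batch_means)       # biased toward the small last batch
print(epoch_average(batch_means, batch_sizes))    # 5.5, the true per-sample mean
print(naive)                                      # about 6.17
```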
7.2 Start training the model
I trained for only 5 epochs (num_epochs=5, logged as Epoch 0/4 through 4/4) because training really does take a long time; raise the number of epochs when you try it yourself.

# If training is too slow, lower num_epochs; around 50 epochs would probably give better results
# While training, check that the loss goes down and the accuracy goes up, and compare validation with training: a large gap means overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592

Note: the learning rates in this log jump between 0.01, 0.001 and 0.0001 instead of following a StepLR schedule. The log was produced with scheduler.step(epoch_loss), and StepLR treats that argument as an epoch number, so the rate became 0.01 * 0.1 ** (epoch_loss // 7): a valid loss of 8.2902 gives 0.001, 3.2325 gives 0.01 and 14.0426 gives 0.0001. Calling scheduler.step() without an argument gives the intended decay every 7 epochs. The large swings in the loss are probably also related to the fairly high initial Adam learning rate of 1e-2.
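With scheduler.step() called once per epoch and no argument, StepLR(step_size=7, gamma=0.1) keeps the learning rate at 1e-2 for epochs 0 to 6, then 1e-3 for epochs 7 to 13, and so on. The sketch records the learning rate used in each epoch; the one-layer model and the single backward pass per "epoch" are placeholders for real training.

```python
import torch
from torch import nn, optim


def steplr_schedule(num_epochs=15, base_lr=1e-2, step_size=7, gamma=0.1):
    torch.manual_seed(0)
    model = nn.Linear(4, 2)          # placeholder model
    x = torch.randn(2, 4)            # placeholder batch
    optimizer = optim.Adam(model.parameters(), lr=base_lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    lrs = []
    for epoch in range(num_epochs):
        lrs.append(optimizer.param_groups[0]['lr'])  # lr used during this epoch
        optimizer.zero_grad()
        model(x).sum().backward()    # stand-in for one epoch of training
        optimizer.step()
        scheduler.step()             # no argument: advance the schedule by one epoch
    return lrs


lrs = steplr_schedule()
# 0.01 for epochs 0-6, 0.001 for epochs 7-13, 0.0001 for epoch 14
print([round(lr, 6) for lr in lrs])
```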
7.3 Train all layers
# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate.
# Build the optimizer from model_ft.parameters(): in feature-extraction mode params_to_update
# only holds the fc layer, so the newly unfrozen layers would never be updated.
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# Attach the scheduler to this new optimizer (train_model uses the global scheduler)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved parameters and continue training from them;
# filename is the checkpoint saved from the best epoch of the previous run
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# checkpoint['optimizer'] is not restored: it belongs to the fc-only optimizer, so its parameter
# groups do not match this optimizer, and restoring it would also replace lr=1e-4 with the saved rate
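Why the optimizer has to be rebuilt over all parameters after unfreezing: an optimizer only ever updates the tensors it was constructed with, and Optimizer.load_state_dict raises ValueError when a saved parameter group has a different size; when the sizes do match, it also restores the saved learning rate. The sketch uses a hypothetical two-layer network in place of ResNet; n_params and can_load are illustrative helpers.

```python
import torch
from torch import nn, optim

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))  # placeholder network
head_params = list(model[2].parameters())  # like params_to_update in feature-extraction mode

head_opt = optim.Adam(head_params, lr=1e-2)
full_opt = optim.Adam(model.parameters(), lr=1e-4)


def n_params(optimizer):
    # Number of tensors the optimizer will actually update
    return sum(len(group['params']) for group in optimizer.param_groups)


def can_load(optimizer, state):
    # load_state_dict raises ValueError when parameter-group sizes differ
    try:
        optimizer.load_state_dict(state)
        return True
    except ValueError:
        return False


print(n_params(head_opt), n_params(full_opt))    # 2 4
print(can_load(full_opt, head_opt.state_dict()))  # False

fresh_head_opt = optim.Adam(head_params, lr=1e-4)
print(can_load(fresh_head_opt, head_opt.state_dict()))  # True
print(fresh_head_opt.param_groups[0]['lr'])             # 0.01: the saved lr replaced 1e-4
```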
Start training
Note: training becomes noticeably slower here. My GPU is a GTX 1660 Ti; treat these timings only as a rough reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477

Note: this log was produced by code that built the optimizer from params_to_update and restored checkpoint['optimizer']. As a result the learning rate shown is the 0.001 stored in the checkpoint rather than 1e-4, and only the fc layer was actually being updated, which fits the almost unchanged training accuracy. Also, train_model resets best_acc to 0, so this run overwrote checkpoint.pth with a model (0.6455) worse than the one it started from (0.6626).
8. Load the trained model
From here on we only run a plain forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU
# Name of the saved checkpoint file
filename = 'checkpoint.pth'

# Load the model; map_location lets a GPU-saved checkpoint load on whichever device is in use
checkpoint = torch.load(filename, map_location=device)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
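A self-contained round trip of the checkpoint format used in this article (state_dict, best_acc, optimizer). It is a sketch under stated assumptions: it writes to an in-memory buffer instead of checkpoint.pth, uses a placeholder linear model, stores a made-up best_acc, and loads with map_location='cpu' so that a checkpoint saved on a GPU could also be opened on a CPU-only machine.

```python
import io

import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(4, 3)                       # placeholder for the trained network
optimizer = optim.Adam(model.parameters(), lr=1e-3)

state = {
    'state_dict': model.state_dict(),
    'best_acc': 0.66,                         # hypothetical value
    'optimizer': optimizer.state_dict(),
}
buffer = io.BytesIO()                         # stands in for 'checkpoint.pth'
torch.save(state, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location='cpu')
restored = nn.Linear(4, 3)                    # freshly initialised, different weights
restored.load_state_dict(checkpoint['state_dict'])
restored.eval()                               # inference mode, as in section 8

x = torch.randn(2, 4)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(checkpoint['best_acc'], same)           # 0.66 True
```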
def process_image(image_path):
    # Read a test image (convert to RGB so there are always 3 channels)
    img = Image.open(image_path).convert('RGB')
    # Resize so the shorter side is 256, keeping the aspect ratio.
    # thumbnail() only shrinks and keeps the aspect ratio, so the longer side gets a large bound.
    # Unlike resize(), whose size argument fixes the exact output size and which returns a new image,
    # thumbnail() modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Crop the central 224 x 224 region
    left = (img.width - 224) / 2
    top = (img.height - 224) / 2
    right = left + 224
    bottom = top + 224

    # PIL's crop box is (left, upper, right, lower)
    img = img.crop((left, top, right, bottom))

    # Same normalization as in training
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel first: HWC -> CHW
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image"""
    if ax is None:
        fig, ax = plt.subplots()

    # Restore the channel order: CHW -> HWC
    image = np.array(image).transpose((1, 2, 0))
    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # call process_image the same way for any other image
imshow(img)
Above is a test image after preprocessing. Check its shape to confirm that the preprocessing function works:

img.shape
(3, 224, 224)
The color channel now comes first, and the image has been cropped to 224 x 224 as intended.
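The arithmetic behind process_image and imshow can be checked without a real photo: the crop box is centred, and imshow's denormalization exactly undoes the normalization (clipping has no effect on values already in [0, 1]). The random array stands in for a cropped photo, and center_crop_box, normalize and denormalize are hypothetical helpers mirroring the steps above.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def center_crop_box(width, height, size=224):
    # (left, upper, right, lower): the box format PIL's Image.crop expects
    left = (width - size) / 2
    top = (height - size) / 2
    return (left, top, left + size, top + size)


def normalize(img_hwc):
    # HWC in [0, 1] -> normalized CHW, as in process_image
    return ((img_hwc - MEAN) / STD).transpose((2, 0, 1))


def denormalize(img_chw):
    # CHW -> HWC in [0, 1], as in imshow
    return np.clip(STD * img_chw.transpose((1, 2, 0)) + MEAN, 0, 1)


print(center_crop_box(384, 256))  # (80.0, 16.0, 304.0, 240.0)

rng = np.random.default_rng(0)
img = rng.random((224, 224, 3))   # stand-in for a cropped photo scaled to [0, 1]
chw = normalize(img)
print(chw.shape)                          # (3, 224, 224)
print(np.allclose(denormalize(chw), img)) # True
```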
9. Inference

# Take one batch of validation data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

# One forward pass gives the output; no gradients are needed for inference
with torch.no_grad():
    if train_on_gpu:
        output = model_ft(images.cuda())
    else:
        output = model_ft(images)

# The batch holds 8 images; each row has 102 values, one per class.
# Because the last layer is LogSoftmax, these values are log-probabilities.
output.shape
torch.Size([8, 102])
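Because the last layer is LogSoftmax, each row of output holds log-probabilities: torch.exp turns them back into probabilities that sum to 1, and topk returns the most likely classes. The random tensor below only stands in for a real model output.

```python
import torch

torch.manual_seed(0)
output = torch.log_softmax(torch.randn(8, 102), dim=1)  # stand-in for model_ft(images)

probs = torch.exp(output)                  # log-probabilities -> probabilities
top_p, top_class = probs.topk(3, dim=1)    # 3 most likely classes per image

print(probs.sum(dim=1))                    # all ones, up to float error
print(top_class[0], top_p[0])
```

Since exp is monotonic, the top class from probs is the same as torch.max(output, 1) used in 9.1.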
9.1 Get the most likely class

_, preds_tensor = torch.max(output, 1)

# Move to the CPU if needed, convert to NumPy and squeeze out size-1 dimensions to get a 1-D array
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
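A caveat when turning preds into flower names: assuming the datasets were built with torchvision's ImageFolder (consistent with the flower_data/train/1, 2, ... layout at the top of the article), class indices follow the string sort order of the folder names, so index 0 is folder '1' but index 1 is folder '10', not '2'. Since cat_to_name is keyed by folder name, map index to folder name first. The sketch rebuilds that ordering in plain Python; the cat_to_name entries and index_to_flower are hypothetical.

```python
# Folder names "1".."102", sorted the way ImageFolder sorts them (as strings)
classes = sorted(str(i) for i in range(1, 103))
class_to_idx = {name: idx for idx, name in enumerate(classes)}

# Hypothetical excerpt of cat_to_name (folder name -> flower name)
cat_to_name = {'1': 'flower #1', '2': 'flower #2', '10': 'flower #10'}


def index_to_flower(idx):
    # class index -> folder name -> flower name
    return cat_to_name[classes[idx]]


print(classes[:4])          # ['1', '10', '100', '101']
print(class_to_idx['2'])    # 14
print(index_to_flower(1))   # flower #10, not flower #2
```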
9.2 Show the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2
# Folder names in class-index order (assumes the validation set is a torchvision ImageFolder);
# cat_to_name is keyed by folder name, not by class index
class_names = dataloaders['valid'].dataset.classes

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    pred_name = cat_to_name[class_names[preds[idx]]]
    true_name = cat_to_name[class_names[labels[idx].item()]]
    # Title: predicted name (true name)
    ax.set_title("{} ({})".format(pred_name, true_name),
                 color=("green" if preds[idx] == labels[idx].item() else "red"))
plt.show()
# Green titles are correct predictions, red titles are wrong ones
————————————————
Copyright notice: this is an original article by CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA. When reposting, include a link to the original and this notice.
Original article: https://blog.csdn.net/LeungSr/article/details/126747940