[Deep Learning] Image Recognition in Practice: 102-Category Flower Classification (flower 102) Case Study
Contents:
Convolutional network in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the model
Data download:
1. Import the libraries
2. Data preprocessing and operations
3. Build the data source
Read the actual names for the labels
4. Display some data
5. Load a model provided by models and initialize it with the pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer settings
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the most probable class
9.2 Display the prediction results
Closing remarks
Convolutional network in practice: classifying flowers
This article classifies the Oxford flower dataset ("flower"). It builds a general-purpose neural network pipeline (implemented mainly with ResNet) and combines common PyTorch operations: preprocessing, training, saving the model, and loading the model.

The folder contains 102 kinds of flowers, and the task is to classify them.
Folder structure

flower_data
    train
        1 (class)
        2
            xxx.png / xxx.jpg
    valid

The work is split into the following modules

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model by calling one of torchvision's classic network architectures directly
    The original training task may have been a 1000-class problem (its classes need not match ours), so the output must be changed to fit our own task
Saving and testing the model
    What gets saved with the model can be chosen selectively
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Change the file name, then put it in the same root directory.

My data root directory is shown below:
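As a quick sanity check on the layout described above, the following sketch counts the class folders and image files under one split, using only the standard library. The helper name `summarize_split` and the list of extensions are illustrative assumptions, not part of the original article.

```python
import os

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")  # assumed extensions, adjust as needed


def summarize_split(split_dir):
    """Return {class_folder_name: image_count} for one split such as flower_data/train."""
    summary = {}
    for entry in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, entry)
        if not os.path.isdir(class_dir):
            continue  # skip stray files sitting next to the class folders
        images = [name for name in os.listdir(class_dir)
                  if name.lower().endswith(IMAGE_EXTENSIONS)]
        summary[entry] = len(images)
    return summary


if __name__ == "__main__":
    train_dir = os.path.join(".", "flower_data", "train")
    if os.path.isdir(train_dir):
        counts = summarize_split(train_dir)
        print(len(counts), "classes,", sum(counts.values()), "images")
    else:
        print("Directory not found:", train_dir)
```

If the data is in place, the train split should report 102 class folders.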
1. Import the libraries
import os
import matplotlib.pyplot as plt
# Render plots inline in the notebook, so an explicit show() call is not needed
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations
# Path settings
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
See: combinations and differences of dots and slashes in Python directory paths
Note: that article explains the dot and slash notation
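To make the note on dots and slashes concrete, here is a small sketch comparing string concatenation with path joining. It uses `posixpath` so the printed results are the same on every operating system; `os.path` behaves the same way on Linux and macOS. The variable names are illustrative.

```python
import posixpath

data_dir = "./flower_data/"

# Plain string concatenation keeps the doubled slash
concatenated = data_dir + "/train"

# join() inserts a separator only when one is missing
joined = posixpath.join(data_dir, "train")

# normpath() collapses "//" and drops the leading "./"
normalized = posixpath.normpath(concatenated)

print(concatenated)  # ./flower_data//train
print(joined)        # ./flower_data/train
print(normalized)    # flower_data/train
```

The doubled slash is usually tolerated by the operating system, but joining the parts avoids it altogether.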
3. Build the data source
data_transforms specifies all the image preprocessing operations
ImageFolder assumes the files are organized into folders, with each folder holding the images of one class
data_transforms = {
    # Two parts; the first one is for training
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop a 224 x 224 patch from the center
        # each flip is applied with the given probability (50/50)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # convert to grayscale with probability 0.025; the result still has
        # three channels, but R, G and B hold the same values
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, standard deviation
    ]),
    # Resize so the shorter side is 256, take the central 224 x 224 crop,
    # convert to a tensor, then normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and standard deviation as the training set
    ]),
}
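The Normalize step above applies (value - mean) / std to each channel. The sketch below reproduces that arithmetic for a single RGB pixel in plain Python so it runs without PyTorch; the function name `normalize_pixel` is an illustrative assumption.

```python
MEAN = (0.485, 0.456, 0.406)  # per-channel means used in the article
STD = (0.229, 0.224, 0.225)   # per-channel standard deviations used in the article


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


# A mid-gray pixel after ToTensor() has the value 0.5 in every channel
print(normalize_pixel((0.5, 0.5, 0.5)))
```

A pixel equal to the channel means maps to zero, and one standard deviation above the mean maps to one.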
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
# Inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
# Check that the data has been prepared
dataloaders
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
dataset_sizes
{'train': 6552, 'valid': 818}
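With the dataset sizes above and batch_size = 8, the number of batches per epoch follows directly. The sketch below works it out in plain Python, assuming the DataLoader default of keeping the final partial batch (drop_last=False); the helper name `batches_per_epoch` is illustrative.

```python
import math

dataset_sizes = {"train": 6552, "valid": 818}  # sizes reported above
batch_size = 8


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields in one pass over the data."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


for split, size in dataset_sizes.items():
    print(split, batches_per_epoch(size, batch_size))
# train 819
# valid 103
```

The last validation batch holds only 2 images, since 818 = 102 * 8 + 2.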
Read the actual names for the labels
Use the JSON file in the same directory to map each class label back to the flower's name
with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
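One detail worth spelling out: ImageFolder assigns class indices after sorting the folder names as strings, so index 1 is folder "10", not folder "2". The sketch below shows the index to folder to flower name lookup in plain Python. It assumes the folders are named "1" through "102"; the helper `index_to_name` and the two-entry excerpt of the mapping are illustrative.

```python
# Folder names are the category ids "1" .. "102", stored as strings (assumed layout)
folder_names = [str(i) for i in range(1, 103)]

# ImageFolder sorts the folder names as strings, not as numbers
class_names = sorted(folder_names)
print(class_names[:5])  # ['1', '10', '100', '101', '102']

# A two-entry excerpt of cat_to_name from the output above
cat_to_name_excerpt = {"1": "pink primrose", "10": "globe thistle"}


def index_to_name(class_index, class_names, cat_to_name):
    """Map a class index to its folder name, then to the flower name."""
    return cat_to_name[class_names[class_index]]


print(index_to_name(1, class_names, cat_to_name_excerpt))  # globe thistle
```

This is why a class index should be passed through class_names before it is looked up in cat_to_name.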
4. Display some data
def im_convert(tensor):
    """Convert a normalized image tensor into an array that can be displayed."""
    image = tensor.to("cpu").clone().detach()
    # squeeze() drops any dimensions of size 1, which makes plotting easier
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is stored as (c, h, w)
    # and has to be restored to (h, w, c)
    image = image.transpose(1, 2, 0)
    # Undo the normalization (de-standardize)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    # Clip values below 0 to 0 and values above 1 to 1
    image = image.clip(0, 1)

    return image
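The last two steps of im_convert undo the normalization and clip the result. A minimal plain-Python sketch of that arithmetic for a single pixel follows; the name `denormalize_pixel` is an illustrative assumption.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def denormalize_pixel(normalized_rgb):
    """Invert (value - mean) / std, then clip the result to [0, 1]."""
    restored = (v * s + m for v, m, s in zip(normalized_rgb, MEAN, STD))
    return tuple(min(1.0, max(0.0, v)) for v in restored)


# A normalized value of zero maps back to the channel mean
print(denormalize_pixel((0.0, 0.0, 0.0)))     # (0.485, 0.456, 0.406)
# Values far outside the usual range are clipped into [0, 1]
print(denormalize_pixel((10.0, -10.0, 0.0)))  # (1.0, 0.0, 0.406)
```

Clipping matters because imshow expects floating-point image values between 0 and 1.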
# Plot using the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# iter() creates an iterator
# Take one batch of data to display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # next(dataiter) replaces dataiter.next(), which newer PyTorch releases no longer provide

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to print the flower type as the title of each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model provided by models and initialize it with the pretrained weights
model_name = 'resnet'  # several models are available: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is the main choice for image recognition here
# Whether to use the features someone else has already trained
feature_extract = True
# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# Set requires_grad to False on the parameters so they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
# Print the model architecture to see how it is built up step by step
# Its main job is to extract features for us
model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
A lot of output in between is omitted here; it is enough to focus on these two levels of the architecture ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The last layer is a 1000-class classifier: it takes 2048 input features and produces 1000 class outputs.
Our task needs a different output, so the 1000-class layer has to be changed to 102 outputs.
6. Initialize the model architecture
The steps are as follows:
Take the trained model and set pretrained = True to obtain the weights someone else has trained
Optionally freeze some layers; the layers to freeze can be specified (by setting requires_grad to False)
Whether the task is classification or regression, the final FC layer has to be changed to match it
Official documentation link
https://pytorch.org/vision/stable/models.html
# Load a model that someone else has trained
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Pick the requested model; each architecture needs its own head replacement
    # Note: newer torchvision releases prefer the `weights=` argument over `pretrained=`
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature extractor so that only the FC layer is trained
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer; num_classes is 102 for this task
        # dim=1 normalizes over the classes of each sample (each row), so the
        # exponentiated outputs of a row sum to 1
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim=1))
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last classifier layer, which has index 6
        num_frts = model_ft.classifier[6].in_features  # input features of the FC layer
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # The auxiliary classifier head also has to be replaced
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
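The ResNet branch above ends in LogSoftmax, which outputs log-probabilities rather than probabilities. The sketch below computes a log-softmax by hand in plain Python to show that the exponentiated values of one row sum to 1. The function is an illustrative re-implementation, not the PyTorch code. In PyTorch, log-probabilities are what nn.NLLLoss expects as input; the loss the article actually uses is not shown in this part.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    shift = max(logits)
    log_sum = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_sum for x in logits]


log_probs = log_softmax([2.0, 1.0, 0.1])
probs = [math.exp(lp) for lp in log_probs]

print(log_probs)   # every value is <= 0
print(sum(probs))  # approximately 1.0
```

To get probabilities back from such a model, apply exp to its output.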
7. Setting the parameters to train
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Run on the GPU (or whichever device was selected)
model_ft = model_ft.to(device)

# Checkpoint file: stores the trained model so it can be loaded directly later
filename = 'checkpoint.pth'

# Train all layers by default
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    # Feature extraction: collect only the parameters that still require gradients
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction
7.1 Optimizer setup
# Optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: multiply the learning rate by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The final layer is LogSoftmax(), so the loss is nn.NLLLoss() rather than
# nn.CrossEntropyLoss(), which would apply log-softmax a second time
criterion = nn.NLLLoss()
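The pairing of LogSoftmax with NLLLoss can be checked by hand. The sketch below is framework-free and uses made-up scores for three classes; the helper names (log_softmax, nll_loss, cross_entropy) are illustrative stand-ins, not the PyTorch functions themselves. It shows that taking the log-softmax first and then the negative log-likelihood gives the same number as computing cross entropy directly from the raw scores.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of one list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross entropy computed directly from raw scores."""
    exps = [math.exp(x) for x in logits]
    return -math.log(exps[target] / sum(exps))

logits = [2.0, 0.5, -1.0]  # made-up raw scores for three classes
target = 0                 # index of the true class

loss_two_step = nll_loss(log_softmax(logits), target)
loss_direct = cross_entropy(logits, target)
print(loss_two_step, loss_direct)
```

Both values agree up to floating-point error, which is why a model that already ends in LogSoftmax only needs the negative log-likelihood step.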
# Define the training function
# is_inception: set to True only for Inception, which also returns auxiliary outputs while training
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy so far
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or the CPU
    model.to(device)
    # Histories kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Weights of the best model so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all the data
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    # Inception-only branch; it is not used here
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the index of the highest-scoring class
                    _, preds = torch.max(outputs, 1)

                    # Update the weights in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the model from the best validation epoch
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # scheduler is the global StepLR object; it is stepped once per epoch
                # without an argument. The runs logged in this post called
                # scheduler.step(epoch_loss), which makes StepLR read the loss as an
                # epoch index and most likely explains the jumping learning rates there.
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, use the best epoch's weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
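Inside train_model, each batch loss is multiplied by inputs.size(0) before being summed, and the total is divided by the dataset size. With the default mean reduction the criterion returns a per-batch average, so this weighting is what makes a smaller final batch count proportionally. The sketch below uses made-up batch losses and a hypothetical helper name (epoch_average) to compare that weighted average with a naive mean over batches.

```python
def epoch_average(batch_means, batch_sizes):
    """Per-sample average loss over an epoch, from per-batch mean losses."""
    running_loss = 0.0
    for mean_loss, size in zip(batch_means, batch_sizes):
        running_loss += mean_loss * size  # turn the batch mean back into a batch sum
    return running_loss / sum(batch_sizes)

# Made-up numbers: two full batches of 8 and a final batch of 2
batch_means = [1.0, 1.0, 3.0]
batch_sizes = [8, 8, 2]

weighted = epoch_average(batch_means, batch_sizes)
naive = sum(batch_means) / len(batch_means)
print(weighted, naive)
```

The naive mean lets the two-sample batch count as much as an eight-sample one, so it overstates the epoch loss in this example.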
7.2 Start training
I only trained for 5 epochs here because training takes a very long time; feel free to increase the number of epochs when you run it yourself.

# If training is too slow, lower the number of epochs; around 50 iterations may work better
# While training, check whether the loss is falling and the accuracy is rising. Is there a large gap between validation and training? If so, the model is overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
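The learning rates printed in the log above do not fall steadily, which is not what a decay of 0.1 every 7 epochs should produce. A possible explanation, sketched below, is that a loss value reached the scheduler where an epoch index was expected. The sketch assumes StepLR evaluates the closed form base_lr * gamma ** (epoch // step_size) when it is handed a value; that is my reading of its behavior in PyTorch versions that accept an argument to step(), not something stated in the original post. The function name step_lr is a stand-in, not the library class.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Closed-form step decay: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)

# Intended use: an integer epoch index
for epoch in (0, 6, 7, 14):
    print(epoch, '{:.7f}'.format(step_lr(1e-2, 0.1, 7, epoch)))

# Validation losses copied from the log above, used in place of the epoch index
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]
lrs_from_loss = ['{:.7f}'.format(step_lr(1e-2, 0.1, 7, loss)) for loss in valid_losses]
print(lrs_from_loss)
```

Under that assumption the five computed values reproduce the five logged learning rates, so the schedule in that run was driven by the size of the validation loss rather than by the epoch count.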
7.3 Training all layers
# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate.
# The optimizer is built from model_ft.parameters(): params_to_update only holds
# the fc-layer parameters collected while feature_extract was True.
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# Attach the scheduler to the new optimizer rather than to optimizer_ft
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved parameters and continue training from that model.
# filename points to the best checkpoint from the previous stage.
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# The saved optimizer state covers only the fc parameters and would also restore
# the old learning rate, so it is not loaded into the new optimizer. The run
# logged below did load it, which is why that log reports a rate of 0.001.
Start training
Note: training becomes especially slow at this stage. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
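The comment in 7.2 suggests comparing training and validation results to spot overfitting. A minimal sketch of that check follows, using the accuracies from the two-epoch log above; the helper name generalization_gaps is hypothetical, and no particular threshold for "too large" is implied.

```python
# Accuracies copied from the two-epoch log above
train_acc = [0.7346, 0.7340]
valid_acc = [0.6455, 0.6137]

def generalization_gaps(train, valid):
    """Per-epoch difference between training and validation accuracy."""
    return [round(t - v, 4) for t, v in zip(train, valid)]

gaps = generalization_gaps(train_acc, valid_acc)
# Epoch whose validation accuracy is highest, which is the one train_model keeps
best_epoch = max(range(len(valid_acc)), key=lambda i: valid_acc[i])

print(gaps)
print(best_epoch, valid_acc[best_epoch])
```

Here the gap widens from the first epoch to the second while validation accuracy drops, which is the pattern the comment warns about, and the best epoch is the first one, matching the reported best validation accuracy.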
8. Loading the trained model
This amounts to a simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU

# Name of the checkpoint file
filename = 'checkpoint.pth'
# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read one image
    img = Image.open(image_path)
    # Resize. thumbnail() can only shrink an image while keeping its aspect ratio,
    # so check which side is shorter and bound that side to 256.
    # Unlike resize(), whose size argument sets the output size directly,
    # thumbnail() modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Crop the central 224 x 224 region
    left_margin = (img.width - 224) / 2   # left edge of the central region
    top_margin = (img.height - 224) / 2   # upper edge (y grows downwards in PIL)
    right_margin = left_margin + 224      # left edge plus the crop width
    bottom_margin = top_margin + 224      # upper edge plus the crop height

    # crop() takes (left, upper, right, lower)
    img = img.crop((left_margin, top_margin, right_margin, bottom_margin))

    # Apply the same preprocessing as in training
    # Scale to [0, 1], then normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the front: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image"""
    if ax is None:
        fig, ax = plt.subplots()

    # Move the color channel back to the last position
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # the same function can be reused for any other image
imshow(img)
<AxesSubplot:>

The above shows the result of preprocessing an image from the validation folder. We use shape to check the image size and confirm that the preprocessing function works correctly.

img.shape

(3, 224, 224)

This confirms that the channel dimension has moved to the front and that the image has the expected 224 x 224 size.
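The crop and normalization arithmetic in process_image can be followed on plain numbers. The sketch below is framework-free; the function names and the 341 x 256 image size are hypothetical choices for illustration, while the mean and standard deviation values are the ones used in the function above.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def center_crop_box(width, height, size=224):
    """Return (left, upper, right, lower) for a centered size x size crop."""
    left = (width - size) / 2
    upper = (height - size) / 2
    return (left, upper, left + size, upper + size)

def normalize_pixel(rgb):
    """Scale 0-255 channel values to [0, 1], then standardize each channel."""
    return tuple((value / 255 - m) / s for value, m, s in zip(rgb, MEAN, STD))

def denormalize_pixel(normalized):
    """Invert normalize_pixel, returning values in [0, 1]."""
    return tuple(n * s + m for n, m, s in zip(normalized, MEAN, STD))

box = center_crop_box(341, 256)         # hypothetical size after thumbnail()
pixel = normalize_pixel((255, 128, 0))  # one made-up RGB pixel
restored = denormalize_pixel(pixel)
print(box, pixel, restored)
```

The box is always 224 wide and 224 tall regardless of the input size, and denormalizing recovers the scaled pixel, which is the same round trip that imshow performs before plotting.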
9. Inference
img.shape

# Get one batch of validation data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # the built-in next() works across PyTorch versions

model_ft.eval()

if train_on_gpu:
    # A single forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 images; each gets 102 values, one per class
# (log-probabilities, since the final layer is LogSoftmax)
output.shape
torch.Size([8, 102])
9.1 Computing the most probable class
_, preds_tensor = torch.max(output, 1)

# Convert to a NumPy array (moving to the CPU first if needed) and drop any size-1 dimensions
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
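torch.max(output, 1) returns both the largest value in each row and its index; the code above keeps only the index. Because the model ends in LogSoftmax, the largest value is a log-probability, and exponentiating it gives the probability of the predicted class. The sketch below shows the same idea on plain lists with made-up numbers; the helper name predict is hypothetical.

```python
import math

def predict(log_probs_batch):
    """Return (class index, probability) of the top class for each row."""
    results = []
    for row in log_probs_batch:
        best = max(range(len(row)), key=lambda i: row[i])  # argmax over classes
        results.append((best, math.exp(row[best])))        # log-probability -> probability
    return results

# Made-up log-probabilities for 2 images and 3 classes
batch = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.3), math.log(0.6)],
]
predictions = predict(batch)
print(predictions)
```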
9.2 Displaying the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Each title shows the predicted name with the true name in parentheses; green means the prediction is correct, red means it is wrong
----------------
Copyright notice: This is an original article by CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license. When reposting, please include the original source link and this notice.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940