[Deep Learning] Hands-on image recognition: 102-category flower classification (Flower-102) case study
Table of contents

ConvNet in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import packages
2. Data preprocessing and operations
3. Build the data source
Read the real names behind the labels
4. Show some data
5. Load a model from torchvision.models and initialize it with pretrained weights
6. Initialize the model architecture
7. Set the parameters to be trained
7. Training and prediction
7.1 Optimizer setup
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the already-trained model
9. Inference
9.1 Compute the most likely class
9.2 Show the prediction results
Closing remarks
ConvNet in practice: classifying flowers

This post tackles a classification task on the Oxford 102-category flower dataset. It walks through a fairly general-purpose training pipeline (implemented mainly with ResNet) built on PyTorch's usual building blocks: preprocessing, training, model saving, and model loading.

The data folder contains 102 kinds of flowers, and our job is to classify them.

Folder structure:

flower_data
    train
        1 (one folder per class)
        2
            xxx.png / xxx.jpg
    valid
        (same per-class layout as train)
The pipeline breaks down into a few major modules:

- Data preprocessing
  - data augmentation
  - preprocessing / normalization
- Network module setup
  - load a pretrained model by calling one of torchvision's classic architectures
  - since that model was trained for a different task (for example 1000 classes), its head has to be adapted to our own task
- Saving and testing the network model
  - model saving can be selective (keep only the best checkpoint)

Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
Rename the downloaded files as needed and put everything under the same root directory. My data root directory follows the folder structure shown above.
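As a quick sanity check before moving on (a minimal sketch, assuming the layout above and that cat_to_name.json sits directly under flower_data/ as the later cells expect), you can confirm the dataset landed where the notebook will look for it:

import os

data_dir = './flower_data/'  # assumed root, matching the layout above

# confirm the two splits and the label-name file are in place
for item in ['train', 'valid', 'cat_to_name.json']:
    path = os.path.join(data_dir, item)
    print(path, 'found' if os.path.exists(path) else 'MISSING')

# each split should contain 102 class sub-folders
for split in ['train', 'valid']:
    split_dir = os.path.join(data_dir, split)
    if os.path.isdir(split_dir):
        print(split, 'classes:', len(os.listdir(split_dir)))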
1. Import packages

import os
import matplotlib.pyplot as plt
# render figures inline so plt.show() is not needed for every plot
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations

# path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
Note on Python path notation: the separate write-up on "./" versus "/" combinations in Python directories (linked in the original post) explains what the leading dot and the slashes mean here.
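Purely for illustration (not part of the original code): os.path.join builds the same paths without having to think about duplicated or missing slashes.

import os

data_dir = './flower_data/'
print(data_dir + '/train')              # './flower_data//train' - still valid, just untidy
print(os.path.join(data_dir, 'train'))  # './flower_data/train'
print(os.path.join(data_dir, 'valid'))  # './flower_data/valid'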
3. Build the data source

data_transforms specifies all image preprocessing operations.
ImageFolder assumes the images are organized into folders, one folder per class.

data_transforms = {
    # two parts; this one is for training
    'train': transforms.Compose([transforms.RandomRotation(45),  # random rotation between -45 and +45 degrees
                                 transforms.CenterCrop(224),     # crop from the center
                                 # flip with some probability, 50/50
                                 transforms.RandomHorizontalFlip(p = 0.5),  # random horizontal flip
                                 transforms.RandomVerticalFlip(p = 0.5),    # random vertical flip
                                 # arguments: brightness, contrast, saturation, hue
                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
                                 transforms.RandomGrayscale(p = 0.025),  # convert to grayscale with a small probability
                                 # the grayscale image still has three channels, they just carry identical values
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, std
                                ]),
    # resize to 256, take the central 224 x 224 crop, convert to a tensor, then normalize
    'valid': transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
                                ]),
}
batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
# verify that the data has been processed as expected
dataloaders
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
dataset_sizes
{'train': 6552, 'valid': 818}
1& V; [3 f6 m0 F0 ^7 x& j( [
Read the real names behind the labels

Use the JSON file in the data directory to map each label back to the flower's actual name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
* K u- r8 c2 D' e6 _3 i1
# ~ x, }' q) ?; j2
cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
4. Show some data

def im_convert(tensor):
    """Convert a tensor back into a displayable image."""
    image = tensor.to("cpu").clone().detach()
    # squeeze drops the singleton batch dimension so the array can be plotted
    image = image.numpy().squeeze()
    # transpose swaps the axes back: the tensor is (c, h, w) and we need (h, w, c) again
    image = image.transpose(1, 2, 0)
    # undo the normalization (de-standardize)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    # clamp everything below 0 to 0 and everything above 1 to 1
    image = image.clip(0, 1)

    return image
# use the function defined above to draw a grid of images
fig = plt.figure(figsize = (20, 12))
columns = 4
rows = 2

# grab an arbitrary batch from the validation loader and display it
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # dataiter.next() in older PyTorch versions

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    # use the json mapping to print the flower name above each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model from torchvision.models and initialize it with pretrained weights

model_name = 'resnet'  # several architectures are available: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is the main architecture used for the image recognition here
# whether to reuse the features someone else already trained
feature_extract = True
# whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# freeze some layers by setting requires_grad to False so they are not updated
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
# print the model architecture to see how it is built step by step
# this backbone mainly serves as our feature extractor

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

(many intermediate layers omitted; the first and last blocks are enough to see the overall structure)

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The final layer is a 1000-way classifier: 2048 input features mapped to 1000 classes.
For our task we need to adapt this head so that it outputs 102 classes instead of 1000.
6. Initialize the model architecture

The steps are:

1. Take the pretrained model and load its weights (use_pretrained = True gets someone else's trained parameters).
2. Decide which layers, if any, to freeze; frozen layers get requires_grad = False so their gradients are not updated.
3. Whether the task is classification or regression, replace the final FC layer with one sized for our own outputs.

Official documentation:
https://pytorch.org/vision/stable/models.html
# load someone else's pretrained model
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    # pick the requested architecture; each one is initialized slightly differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """

        # 1. load the pretrained network
        model_ft = models.resnet152(pretrained = use_pretrained)
        # 2. optionally freeze the feature-extraction layers and train only the FC head
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. read the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. replace the fully connected layer with a 102-way output
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
                                    nn.LogSoftmax(dim = 1))  # dim defaults to 0 (column-wise); we want row-wise outputs so each sample's class probabilities sum to 1
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # replace the last classifier layer, the one at index [6]
        num_frts = model_ft.classifier[6].in_features  # FC input size
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG11_bn
        """
        model_ft = models.vgg16(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
7. Set the parameters to be trained

# set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)

# move the model to the GPU (if available)
model_ft = model_ft.to(device)

# model saving: the checkpoint file stores the trained model so it can be reloaded later
filename = 'checkpoint.pth'

# whether to train all layers
params_to_update = model_ft.parameters()
# print the layers that will actually be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction

7.1 Optimizer setup

# optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr = 1e-2)
# learning-rate decay: every 7 epochs the learning rate drops to 1/10 of its previous value
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# the last layer is LogSoftmax(), so nn.CrossEntropyLoss() cannot be used here; NLLLoss is the matching loss
criterion = nn.NLLLoss()
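A small aside (a sketch, not part of the original code): nn.CrossEntropyLoss is exactly LogSoftmax followed by NLLLoss, which is why a model that already ends in LogSoftmax has to be paired with NLLLoss rather than CrossEntropyLoss.

import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 102)            # fake raw scores for a batch of 4 images
targets = torch.randint(0, 102, (4,))   # fake labels

loss_ce = nn.CrossEntropyLoss()(logits, targets)                # applies LogSoftmax internally
loss_nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)  # explicit two-step version
print(loss_ce.item(), loss_nll.item())  # the two values match (up to floating-point error)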
# define the training function
# is_inception: whether the Inception architecture (with its auxiliary output) is being used
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # keep track of the best validation accuracy
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # run on GPU or CPU depending on what is available
    model.to(device)
    # bookkeeping used later for plotting
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # keep a copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # iterate over the whole split once
            for inputs, labels in dataloaders[phase]:
                # move inputs and labels to the GPU (or CPU)
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the gradients
                optimizer.zero_grad()
                # compute and update gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # the if-branch only matters for Inception and can be ignored otherwise
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:  # ResNet goes through this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the class with the highest score
                    _, preds = torch.max(outputs, 1)

                    # update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # epoch-level statistics for printing
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # remember the best model so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learnable weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer' : optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # note: StepLR.step() normally takes no argument; the original code passes the
                # validation loss (as if this were ReduceLROnPlateau), which is why the printed
                # learning rates jump around between epochs
                scheduler.step(epoch_loss)
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # after training, load the best weights seen as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
7.2 Start training the model

I only trained for five epochs here (training really does take a long time); feel free to raise the epoch count when you experiment yourself.

# if it is too slow, lower the epoch count; something like 50 iterations may work better when you have the time
# while training, watch whether the loss keeps falling and the accuracy keeps rising, and how large the gap
# between training and validation is: a large gap means overfitting
# (a short plotting sketch after the training log below shows how to turn the returned histories into curves)
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
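To answer the question raised before training (is the loss still falling, and how far apart are the training and validation curves?), the histories returned by train_model can be plotted directly. A sketch using the variable names from the call above; the accuracies were accumulated as tensors, so they are converted to plain floats first.

import matplotlib.pyplot as plt

epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, label='train loss')
plt.plot(epochs, valid_losses, label='valid loss')
plt.xlabel('epoch'); plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, [float(a) for a in train_acc_history], label='train acc')
plt.plot(epochs, [float(a) for a in val_acc_history], label='valid acc')
plt.xlabel('epoch'); plt.legend()

plt.show()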
7.3 Train all layers

# unfreeze the whole network so every layer can be trained
for param in model_ft.parameters():
    param.requires_grad = True

# continue training with a smaller learning rate
# note: params_to_update still only holds the FC-layer parameters collected above, so as written the
# optimizer keeps updating just the head; to genuinely fine-tune every layer, build the optimizer
# from model_ft.parameters() instead (and then skip reloading the old optimizer state below)
optimizer = optim.Adam(params_to_update, lr = 1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 7, gamma = 0.1)  # schedule the new optimizer, not optimizer_ft

# loss function
criterion = nn.NLLLoss()
6 p, f, n Y Z! j8 K/ y1
5 U8 T( \" q" e/ G3 u/ O2
+ Y6 H+ N3 Z x; O7 F3
3 V+ c3 K% `7 h48 N/ a6 |! ~- F* k
5
7 e+ |* N. Z: l0 J% u6
( W( K2 H+ x. B/ z7
! W/ t' Q! w$ W, ^( m86 [4 h7 c: a3 w" \
9
0 q( m! I8 z$ K) A( h10) _0 b7 r6 f, q0 y
# load the saved parameters
# and keep training on top of the model we already have
# the file below is the best checkpoint saved during the previous run
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
$ w" Q. K4 p+ f: L1
5 e: [# \$ w0 F" \1 K& }22 z; W0 v8 s8 Q, i" X
3+ i0 `8 [2 e. H( s
4
8 ? G4 N( H- X3 W) L- T! N5
0 i" X1 S4 A: d# r, X1 _6, u3 g' \+ V* H
7
Start training

Note: training gets noticeably slower from here on. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Load the already-trained model

This amounts to a simple forward pass (inference); no parameters are updated.
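Since nothing is updated here, the forward passes in this section can also be wrapped in torch.no_grad() to skip gradient bookkeeping. A minimal sketch (the dummy input is only for illustration):

model_ft.eval()                    # evaluation mode: no dropout, use running BatchNorm statistics
with torch.no_grad():              # no autograd graph is built, which saves memory and time
    dummy = torch.randn(1, 3, 224, 224).to(device)   # placeholder input, just to show the call
    log_probs = model_ft(dummy)
print(log_probs.shape)             # torch.Size([1, 102])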
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move to the GPU (if available)

# name of the saved checkpoint file
filename = 'checkpoint.pth'

# load the trained weights
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # read one test image
    img = Image.open(image_path)
    # resize step: thumbnail() can only shrink while keeping the aspect ratio, hence the check below
    # unlike resize(), which takes the exact target size, thumbnail() scales down proportionally,
    # modifies the image in place and returns None
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # crop the central 224 x 224 region
    left_margin = (img.width - 224) / 2   # take the middle part
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224      # add the crop size 224 to get the opposite edge
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # the same preprocessing as before:
    # scale to [0, 1], then normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # watch the channel order: move the color channel to the front
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax = None, title = None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # move the color channel back to the last axis
    image = np.array(image).transpose((1, 2, 0))

    # undo the preprocessing
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # this function can be reused for any image we want to test
imshow(img)
<AxesSubplot:>
The cell above shows a test image after our preprocessing. We can check its shape to confirm the preprocessing function behaves as expected.

img.shape

(3, 224, 224)

This confirms that the channel dimension has moved to the front while the spatial size is unchanged.
9. Inference

# grab one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # dataiter.next() in older PyTorch versions

model_ft.eval()

if train_on_gpu:
    # a single forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# the batch holds 8 images; each one gets 102 values, one score (log-probability) per class
output.shape
torch.Size([8, 102])
9.1 Compute the most likely class

_, preds_tensor = torch.max(output, 1)

# squeeze to a plain 1-D array of predicted class indices (move to CPU first when running on GPU)
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
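Because the network ends in LogSoftmax, output contains log-probabilities; exponentiating turns them back into probabilities, and torch.topk lists the most likely classes with their confidence (a sketch, not in the original post):

probs = torch.exp(output)                # log-probabilities -> probabilities
top_p, top_idx = probs.topk(5, dim=1)    # top-5 classes for every image in the batch

# map dataset indices back to folder names, then to readable flower names (first image only)
for p, idx in zip(top_p[0], top_idx[0]):
    print('{:>25s}  {:.3f}'.format(cat_to_name[class_names[int(idx)]], p.item()))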
9.2 Show the prediction results

fig = plt.figure(figsize = (20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    plt.imshow(im_convert(images[idx]))
    # preds and labels are ImageFolder indices (0-101), not folder names,
    # so map them through class_names before looking up cat_to_name
    ax.set_title("{} ({})".format(cat_to_name[class_names[preds[idx]]], cat_to_name[class_names[labels[idx].item()]]),
                 color = ("green" if preds[idx] == labels[idx].item() else "red"))
plt.show()
# a green title means the prediction is correct, red means it is wrong
————————————————
Copyright notice: this is an original article by the CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license. Please include the original link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940