Mathematical Modeling Community - Math China
Title: [Deep Learning] Image Recognition in Practice: 102-Category Flower Classification (Flower-102)
Author: 杨利霞
Date: 2022-9-8 10:41
Table of Contents

CNN in Practice: Classifying Flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import packages
2. Data preprocessing
3. Build the data source
   Reading the real names behind the labels
4. Display some data
5. Load a torchvision model, initialized with pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
   7.1 Optimizer setup
   7.2 Start training the model
   7.3 Train all layers
       Start training
8. Load the trained model
9. Inference
   9.1 Compute the highest-probability class
   9.2 Display prediction results
Closing remarks
CNN in Practice: Classifying Flowers

This article works through a classification task on the Oxford 102-category flower dataset, building a broadly reusable network setup (implemented mainly with ResNet) that combines the common parts of the PyTorch workflow: preprocessing, training, model saving, and model loading.

The data folder contains 102 kinds of flowers, and our task is to classify them.

Folder structure:

flower_data
    train
        1 (category folder)
        2
            xxx.png / xxx.jpg
    valid

The work divides into these main modules:

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model, reusing one of torchvision's classic architectures
    Since the pretrained task was likely 1000-way classification (not necessarily the same as ours), the head must be adapted to our own task
Saving and testing the network model
    Model saving can be selective

Data download:

https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the files as needed and place everything under a single root directory.

Below is my data root directory (shown as a screenshot in the original post).
1. Import packages

import os
import matplotlib.pyplot as plt
# render plots inline in the notebook, no plt.show() handle needed
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing

# path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

Note: the linked article "python目录点杠的组合与区别" explains how the `./` and `/` path notations behave.
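As a small illustration of the note above (a sketch added here, not from the original post): since data_dir already ends in a slash, plain concatenation yields a doubled slash, which the operating system tolerates, while os.path.join plus normpath gives a canonical path on POSIX systems.

```python
import os

data_dir = './flower_data/'

# plain concatenation keeps the doubled slash; most filesystem calls tolerate it
train_dir = data_dir + '/train'
print(train_dir)  # ./flower_data//train

# os.path.join plus normpath yields a canonical form instead
clean = os.path.normpath(os.path.join(data_dir, 'train'))
print(clean)  # flower_data/train
```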
3. Build the data source

data_transforms specifies every image-preprocessing operation.
ImageFolder assumes images are organized into folders, one folder per class.

data_transforms = {
    # two pipelines: one for training...
    'train': transforms.Compose([transforms.RandomRotation(45),  # random rotation between -45 and +45 degrees
        transforms.CenterCrop(224),  # crop from the center
        # flip with the given probability (50/50 here)
        transforms.RandomHorizontalFlip(p = 0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p = 0.5),  # random vertical flip
        # arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
        transforms.RandomGrayscale(p = 0.025),  # convert to grayscale with this probability
        # the grayscale output still has three channels, but R, G and B are identical
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, standard deviation
    ]),
    # ...and one for validation: resize to 256, take the center 224 x 224 crop,
    # convert to a tensor, then normalize
    'valid': transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
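What CenterCrop(224) and Normalize actually do to a picture can be sketched in plain numpy (a hand-rolled illustration, not torchvision's implementation; the layout here is H x W x C, whereas ToTensor additionally moves channels first):

```python
import numpy as np

def center_crop(img, size=224):
    """Crop an (H, W, 3) array to (size, size, 3) around its center."""
    h, w, _ = img.shape
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size, :]

def normalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Per-channel (x - mean) / std, as transforms.Normalize does."""
    return (img - np.array(mean)) / np.array(std)

img = np.random.rand(256, 256, 3)   # stand-in for a resized photo in [0, 1]
out = normalize(center_crop(img))
print(out.shape)                    # (224, 224, 3)
```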
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
         RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
         CenterCrop(size=(224, 224))
         RandomHorizontalFlip(p=0.5)
         RandomVerticalFlip(p=0.5)
         ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
         RandomGrayscale(p=0.025)
         ToTensor()
         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
         Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
         CenterCrop(size=(224, 224))
         ToTensor()
         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     )}
# check that the dataloaders were built
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

{'train': 6552, 'valid': 818}
Reading the real names behind the labels

Use the json file in the data directory to map each label back to its flower name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
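One subtlety worth noting: ImageFolder assigns class indices by sorting the folder names as strings, so index 1 is not necessarily folder '2'. A small hypothetical example (three folders only) of the index -> folder -> name chain used later when plotting:

```python
# Hypothetical three-folder subset; ImageFolder sorts names lexicographically,
# so '10' comes before '2'.
folders = ['1', '10', '2']
class_names = sorted(folders)          # what ImageFolder stores in .classes
cat_to_name = {'1': 'pink primrose',
               '2': 'hard-leaved pocket orchid',
               '10': 'globe thistle'}  # entries taken from the json above

pred_idx = 1                           # index produced by the model
folder = class_names[pred_idx]         # -> '10', not '2'!
print(cat_to_name[folder])             # globe thistle
```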
4. Display some data

def im_convert(tensor):
    """Prepare a tensor for display."""
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    # squeeze drops any singleton batch dimension so the array can be plotted
    # transpose reorders the axes: ToTensor produced (c, h, w), imshow wants (h, w, c)
    image = image.transpose(1, 2, 0)
    # undo the normalization (de-standardize)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clip everything below 0 to 0 and above 1 to 1
    image = image.clip(0, 1)

    return image
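The multiply-and-add line in im_convert is just the inverse of transforms.Normalize; a quick standalone numpy check confirms the round trip (ignoring the final clip):

```python
import numpy as np

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

original = np.random.rand(224, 224, 3)  # image in [0, 1], channels last
normalized = (original - mean) / std    # what transforms.Normalize produces
restored = normalized * std + mean      # the de-normalization inside im_convert

print(np.allclose(restored, original))  # True
```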
# plot images using the helper defined above
fig = plt.figure(figsize = (20, 12))
columns = 4
rows = 2

# grab an arbitrary batch from the validation loader
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # dataiter.next() was removed in newer PyTorch

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
    # look up each image's flower name via the json mapping and use it as the title
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a torchvision model, initialized with pretrained weights

model_name = 'resnet'  # several choices: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# resnet is the main choice for image recognition here

# whether to reuse the pretrained features
feature_extract = True
# check whether a GPU is available for training
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# mark selected layers as requiring no gradient so they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
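The effect can be demonstrated on a toy model (a stand-in nn.Sequential, not the actual ResNet): after the call, no backbone parameter requires gradients, while a head added afterwards still trains.

```python
import torch
from torch import nn

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# a tiny stand-in for the pretrained backbone
backbone = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
set_parameter_requires_grad(backbone, True)

# a new "head", added after freezing, still trains
head = nn.Linear(4, 2)

frozen = [p.requires_grad for p in backbone.parameters()]
print(frozen)                                            # [False, False, False, False]
print(all(p.requires_grad for p in head.parameters()))   # True
```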
# print the model architecture to see how it is built, step by step
# this backbone mainly serves as a feature extractor

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

[many intermediate layers omitted; the first and last blocks of the architecture are what matter here]

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The final layer is a 1000-way classifier: 2048 inputs mapped to 1000 classes.
For our task we need to adjust this head, changing the 1000-way output to 102.
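The replacement head can be tried in isolation (a sketch using random features in place of a real image batch; 2048 matches resnet152's fc input size and 102 is our class count):

```python
import torch
from torch import nn

# the replacement head: new Linear plus LogSoftmax over dim=1
head = nn.Sequential(nn.Linear(2048, 102), nn.LogSoftmax(dim=1))

features = torch.randn(4, 2048)  # a batch of 4 pooled feature vectors
log_probs = head(features)

print(log_probs.shape)           # torch.Size([4, 102])
# exp of the log-probabilities sums to 1 per row (dim=1), i.e. per sample
print(torch.allclose(log_probs.exp().sum(dim=1), torch.ones(4), atol=1e-5))  # True
```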
6. Initialize the model architecture

The steps are:

1. Take someone else's trained model and pass use_pretrained = True to obtain its weights.
2. Optionally freeze selected layers by turning off their gradient updates (requires_grad = False).
3. Whether the task is classification or regression, replace the final FC layer with one matching our own output size.

Official documentation:

https://pytorch.org/vision/stable/models.html
# load a pretrained model and adapt it to our task
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
    # pick the right model; each architecture is initialized differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. load the pretrained network
        model_ft = models.resnet152(pretrained = use_pretrained)
        # 2. optionally freeze the feature-extraction layers and train only the FC layer
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. read the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. replace the fully connected layer with a 102-way output
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),
                                    nn.LogSoftmax(dim = 1))  # dim = 1 applies the softmax across each row, so every sample's probabilities sum to 1
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # replace the last classifier module, at index [6]
        num_frts = model_ft.classifier[6].in_features  # FC layer input size
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained = use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # inception has an auxiliary output; replace its head as well
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size
7. Setting the parameters to train

# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)

# Move the model to the GPU
model_ft = model_ft.to(device)

# Checkpoint file: the trained model is saved here and can be loaded directly later
filename = 'checkpoint.pth'

# Whether to train all layers
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)
Params to learn:
	 fc.0.weight
	 fc.0.bias
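The filtering logic above can be sketched on a toy model (a hypothetical two-layer stand-in, not the actual ResNet): in feature-extraction mode, only the unfrozen head shows up in params_to_update.

```python
import torch.nn as nn

# Toy stand-in for model_ft: freeze the "backbone", leave the "head" trainable
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 3))
for p in model[0].parameters():
    p.requires_grad = False

# Same filtering as above: collect only parameters that require gradients
params_to_update = [p for n, p in model.named_parameters() if p.requires_grad]
names = [n for n, p in model.named_parameters() if p.requires_grad]
print(names)  # ['1.weight', '1.bias'] — analogous to fc.0.weight / fc.0.bias above
```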
7. Training and prediction

7.1 Optimizer setup

# Optimizer setup
optimizer_ft = optim.Adam(params_to_update, lr = 1e-2)
# Learning-rate decay schedule: every 7 epochs the learning rate drops to 1/10 of its value
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# The last layer applies LogSoftmax(), so nn.CrossEntropyLoss() cannot be used here;
# NLLLoss on log-probabilities is the matching loss
criterion = nn.NLLLoss()
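The point about the loss can be verified with a quick sketch (random values, purely illustrative): NLLLoss applied to LogSoftmax output matches CrossEntropyLoss applied to raw logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 102)           # raw scores for a batch of 4 over 102 classes
targets = torch.randint(0, 102, (4,))  # ground-truth class indices

# NLLLoss on log-probabilities ...
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
# ... equals CrossEntropyLoss on raw logits
ce = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(nll, ce))  # True
```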
# Define the training function
# is_inception: whether the network is Inception-style (it has an auxiliary output)
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Track the best validation accuracy
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or the CPU
    model.to(device)
    # Histories kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all the data
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Gradients are computed and updated only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # The Inception branch adds an auxiliary loss; ResNet skips it
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds holds the class with the highest score
                    _, preds = torch.max(outputs, 1)

                    # Update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate loss and correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Remember the best model so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learnable weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer' : optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step()  # StepLR.step() takes no metric argument
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, load the best weights back as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
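The checkpoint format that train_model() writes can be exercised on its own with a tiny stand-in model (the nn.Linear here is purely illustrative, not the real model_ft):

```python
import torch
import torch.nn as nn

# Tiny stand-in for model_ft and its optimizer
model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save a checkpoint in the same dictionary format train_model() writes
state = {'state_dict': model.state_dict(),
         'best_acc': 0.66,
         'optimizer': opt.state_dict()}
torch.save(state, 'demo_checkpoint.pth')

# Restore it later, exactly as section 8 below does
ckpt = torch.load('demo_checkpoint.pth')
model.load_state_dict(ckpt['state_dict'])
opt.load_state_dict(ckpt['optimizer'])
print(ckpt['best_acc'])  # 0.66
```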
7.2 Start training the model

I only trained for 5 epochs here (training really takes a long time); feel free to increase the epoch count when you experiment yourself.

# If training is too slow, lower the epoch count; around 50 iterations may give better results
# Watch whether the training loss falls and accuracy rises, and whether training and
# validation diverge much — a large gap indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
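As the comment before training suggests, the returned histories let you check the train/valid gap; a minimal sketch, using the rounded accuracies from the output above:

```python
# Accuracies copied from the output above (rounded to 2 decimals)
train_acc_history = [0.31, 0.71, 0.47, 0.65, 0.55]
val_acc_history   = [0.47, 0.66, 0.44, 0.60, 0.49]

# Large positive gaps between train and valid accuracy hint at overfitting
gaps = [round(t - v, 2) for t, v in zip(train_acc_history, val_acc_history)]
print(max(gaps))  # 0.06 — the gap stays small, so overfitting is not the problem here

# The saved checkpoint corresponds to the epoch with the highest validation accuracy
best_epoch = max(range(len(val_acc_history)), key=val_acc_history.__getitem__)
print(best_epoch)  # 1, matching Best val Acc: 0.662592 above
```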
7.3 Training all layers

# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all of the parameters, with a smaller learning rate
optimizer = optim.Adam(model_ft.parameters(), lr = 1e-4)  # all parameters now require gradients
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 7, gamma = 0.1)

# Loss function
criterion = nn.NLLLoss()
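The effect of freezing and unfreezing can be sketched with a toy two-layer network (the layer sizes are illustrative, not the real ResNet's):

```python
import torch.nn as nn

# Toy stand-in: a "backbone" layer plus a 102-class head
model = nn.Sequential(
    nn.Linear(512, 512),  # "backbone"
    nn.ReLU(),
    nn.Linear(512, 102),  # classifier head
)

# Feature extraction: freeze the backbone, train only the head
for p in model[0].parameters():
    p.requires_grad = False
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # 512*102 + 102 = 52326 — just the head

# Section 7.3: unfreeze everything and fine-tune with a smaller learning rate
for p in model.parameters():
    p.requires_grad = True
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # 512*512 + 512 + 52326 = 314982 — the whole network
```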
# Load the saved parameters and continue training on top of them
# (filename points to the checkpoint that just scored best)
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
Start training

Note: training gets noticeably slower from here on; my GPU is a GTX 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Loading the trained model

This amounts to a single simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move the model onto the GPU

# Name of the saved checkpoint file
filename = 'checkpoint.pth'

# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read a test-set image
    img = Image.open(image_path)
    # Resize: unlike resize(), whose size argument sets the exact new size,
    # thumbnail() can only shrink proportionally, modifies the image in place,
    # and returns None — hence the aspect-ratio check below
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Crop the central 224 * 224 region
    left_margin = (img.width - 224) / 2  # take the middle part
    bottom_margin = (img.height - 224) / 2
    right_margin = left_margin + 224  # add the 224-pixel width to get the full extent
    top_margin = bottom_margin + 224

    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))

    # Same preprocessing as in training:
    # scale to [0, 1], then normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Note the color channel position: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax = None, title = None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Put the color channel back last: (C, H, W) -> (H, W, C)
    image = np.array(image).transpose((1, 2, 0))

    # Undo the preprocessing
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # this function can be reused to process any image
imshow(img)
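The resize-versus-thumbnail difference noted in the comments can be seen directly (a blank 800x600 test image, purely illustrative):

```python
from PIL import Image

img = Image.new('RGB', (800, 600))

# resize() returns a new image at exactly the requested size
resized = img.resize((256, 256))
print(resized.size)  # (256, 256)

# thumbnail() shrinks in place, keeps the aspect ratio, and returns None,
# which is why process_image() calls it for its side effect only
result = img.thumbnail((10000, 256))
print(result, img.size)  # None (341, 256)
```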
<AxesSubplot:>
Above we preprocessed a test-set image; we can check its shape to inspect the size and verify that the preprocessing function is correct.

img.shape

(3, 224, 224)

This confirms that the channel axis has been moved to the front and the image is the expected size.
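The channel move is just a NumPy transpose, and imshow() undoes it exactly; a quick sketch:

```python
import numpy as np

chw = np.zeros((3, 224, 224))   # channels-first, what process_image() returns
hwc = chw.transpose((1, 2, 0))  # channels-last, what imshow() needs for display
print(hwc.shape)  # (224, 224, 3)

# The two transposes are inverses of each other
print(np.array_equal(hwc.transpose((2, 0, 1)), chw))  # True
```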
9. Inference

# Get one batch of test data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch

model_ft.eval()

if train_on_gpu:
    # One forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 images; each gets 102 values, one score per class
output.shape
' G1 H) ?; }# o M+ s" o
9 p0 y& r6 b1 ]3 y/ S9 @
1
" K, L. D# c4 d& Q% S
2
3 K) c+ r: _1 C6 E+ s: ~
3
* ~! v" x2 j+ H. O
4
! v" d$ }* N2 Y* k
5
7 G. _% p, _- ^6 @" q
6
6 o- ~! J4 g( N9 K& ?
7
0 S) t& Z7 Q; P% h8 C" ?
8
1 t0 d1 t) l5 \
9
( {9 K& d. Y/ @0 B. Q6 G
10
. R4 j9 d4 l$ s; F
11
% c, H2 O+ I% x8 P0 |
12
( Y1 s* a/ l! N) Q9 i
13
- ?# [* M$ [% J# Z
14
. V l2 ^0 w' l/ |7 e
15
4 g& G- \ A' ]; ~
16
0 s, ~1 u0 `2 e3 b9 P- [/ p" W5 G
torch.Size([8, 102])
9.1 Getting the highest-probability class

_, preds_tensor = torch.max(output, 1)

# Move the predictions to the CPU and squeeze the rank-1 tensor to a 1-D array
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
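How torch.max produces the predictions can be sketched with stand-in model output (random log-probabilities, purely illustrative):

```python
import torch

torch.manual_seed(0)
output = torch.log_softmax(torch.randn(8, 102), dim=1)  # stand-in for model_ft's output

# torch.max along dim=1 returns (max values, argmax indices)
values, preds = torch.max(output, 1)
print(preds.shape)  # torch.Size([8]) — one predicted class per image in the batch

# Exponentiating back to probabilities gives the same argmax, since exp() is monotonic
probs = output.exp()
print(torch.equal(probs.argmax(dim=1), preds))  # True
```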
9.2 Displaying the predictions

fig = plt.figure(figsize = (20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# A green title means the prediction is correct; red means it is wrong
————————————————
Copyright notice: This is an original article by CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940