Mathematical Modeling Community - 数学中国
Title:
[Deep Learning] Image Recognition in Practice: 102-Class Flower Classification (flower 102), a Worked Example
Author:
杨利霞
Time:
2022-9-8 10:41
Contents

Hands-on CNN: classifying flowers
Data preprocessing
Network module setup
Saving and testing the model
Data download:
1. Import the packages
2. Data paths and preprocessing
3. Build the data sources
Map labels to their actual names
4. Display some data
5. Load a model from torchvision.models and initialize it with the pretrained weights
6. Initialize the model architecture
7. Choose which parameters to train
7. Training and prediction
7.1 Optimizer setup
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Get the highest-probability class
9.2 Display the prediction results
Final thoughts
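Hands-on CNN: classifying flowers

This post tackles a classification task on the Oxford flower dataset (flower 102). It builds a general-purpose network pipeline, implemented mainly with ResNet, around common PyTorch operations: preprocessing, training, saving the model, and loading it again.

The dataset folder contains 102 kinds of flowers, and the task is to classify images into these 102 classes.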
Folder structure

flower_data
    train
        1 (class)
        2
            xxx.png / xxx.jpg
    valid
The project is split into the following main modules:

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model by calling one of torchvision's classic architectures directly
    The pretrained model was trained for someone else's task (for example 1000-class classification, which may not match ours), so its output must be changed to fit our own task
Saving and testing the model
    Model saving can be selective
Data download:

https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder (the code below expects flower_data) and put it in the same root directory as the code.

Below is my data root directory:
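Before running anything, it can help to confirm that the folder layout matches what the rest of the code expects. The sketch below is a hypothetical helper (`check_flower_layout` is not part of the original code); it assumes the layout described above, a `flower_data` root with `train/` and `valid/` class folders plus the `cat_to_name.json` file read later. The expectation of exactly 102 class folders per split is an assumption based on the dataset description; adjust it if your copy differs.

```python
import json
import os


def check_flower_layout(root, num_classes=102):
    """Hypothetical helper: list what is missing from the expected dataset layout.

    Assumes root/train/<class>/..., root/valid/<class>/... and root/cat_to_name.json.
    Returns a list of problem descriptions; an empty list means the layout looks fine.
    """
    problems = []
    for split in ("train", "valid"):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            problems.append("missing folder: " + split_dir)
            continue
        # Each class is one sub-folder, so count the sub-folders
        classes = [d for d in os.listdir(split_dir)
                   if os.path.isdir(os.path.join(split_dir, d))]
        if len(classes) != num_classes:
            problems.append("{} has {} class folders, expected {}".format(
                split_dir, len(classes), num_classes))
    json_path = os.path.join(root, "cat_to_name.json")
    if not os.path.isfile(json_path):
        problems.append("missing file: " + json_path)
    else:
        with open(json_path, "r") as f:
            names = json.load(f)
        if len(names) != num_classes:
            problems.append("cat_to_name.json has {} entries, expected {}".format(
                len(names), num_classes))
    return problems
```

For example, `print(check_flower_layout('./flower_data'))` should print an empty list once the download has been renamed and placed as described.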
1. Import the packages

import os
import matplotlib.pyplot as plt
# Jupyter/IPython magic: draw plots inline in the notebook so plt.show() is not strictly needed
# (remove this line when running as a plain .py script)
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data paths and preprocessing

# Path settings
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')

Related article: "Combining and distinguishing dot and slash in Python paths"
Note: it explains how the './' and '/' prefixes behave.
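A small illustration of the path issue: plain string concatenation such as `data_dir + '/train'` keeps both slashes and yields `./flower_data//train`, which most systems tolerate but which is easy to get wrong, whereas `os.path.join` or `pathlib` insert a separator only where needed. Both forms stay relative to the current working directory (the meaning of the leading `./`). The commented results assume a POSIX system; on Windows `os.path.join` uses a backslash.

```python
import os
from pathlib import Path

data_dir = './flower_data/'

# Plain concatenation keeps both slashes
concatenated = data_dir + '/train'
# os.path.join adds a separator only if the first part does not already end with one
joined = os.path.join(data_dir, 'train')
# pathlib collapses the redundant './' and trailing slash
as_path = Path(data_dir) / 'train'

print(concatenated)  # ./flower_data//train
print(joined)        # ./flower_data/train on POSIX
print(as_path)       # flower_data/train on POSIX
```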
3. Build the data sources

data_transforms specifies all of the image preprocessing operations.

ImageFolder assumes the files are already sorted into folders, one folder per class, with each folder holding only images of that class.
data_transforms = {
    # Two pipelines; the first one is for training
    'train': transforms.Compose([
        transforms.RandomRotation(45),           # random rotation between -45 and 45 degrees
        transforms.CenterCrop(224),              # crop a 224x224 patch from the center
        # Each flip happens independently with probability 0.5 (a 50/50 chance)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),    # random vertical flip
        # Arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # Convert to grayscale with probability 0.025; the result still has 3 channels,
        # but R, G and B hold identical values
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel mean, std
    ]),
    # Resize the shorter side to 256, center-crop 224x224, convert to a tensor, then normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
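`transforms.Normalize(mean, std)` computes `(x - mean) / std` separately for each channel, after `ToTensor()` has scaled pixel values into [0, 1]; the mean and std used here are the standard ImageNet channel statistics. The plain-Python sketch below (no torch needed; `normalize_pixel` and `denormalize_pixel` are hypothetical names) applies that formula to a single RGB pixel and then inverts it, which is the step the display code later has to undo.

```python
# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Apply (x - mean) / std to each channel of one pixel, like transforms.Normalize."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(values):
    """Invert the normalization: x = value * std + mean."""
    return tuple(v * s + m for v, m, s in zip(values, MEAN, STD))


# ToTensor scales 8-bit values 0..255 to floats in [0, 1]
pixel = tuple(c / 255 for c in (255, 128, 0))
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)

print([round(v, 3) for v in normalized])
print(all(abs(a - b) < 1e-9 for a, b in zip(restored, pixel)))  # True
```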
batch_size = 8
# One ImageFolder per split; labels come from the sub-folder names
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
# class_names[i] is the folder name that corresponds to label index i
class_names = image_datasets['train'].classes

# Inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
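The printout shows how `ColorJitter` turns its scalar arguments into ranges: for brightness, contrast and saturation a strength `v` becomes the factor range `[max(0, 1 - v), 1 + v]`, and for hue it becomes `[-v, v]`; `RandomRotation(45)` is expanded symmetrically to `[-45.0, 45.0]` in the same spirit. The sketch below reproduces that rule in plain Python so the printed ranges can be checked by hand; `jitter_range` is a hypothetical helper, not a torchvision function.

```python
def jitter_range(value, center=1.0, clip_at_zero=True):
    """Hypothetical helper: turn a scalar jitter strength into a [low, high] range."""
    low, high = center - value, center + value
    if clip_at_zero:
        # A brightness/contrast/saturation factor below 0 makes no sense, so clip it
        low = max(low, 0.0)
    return [low, high]


print("brightness", jitter_range(0.2))
print("contrast", jitter_range(0.1))
print("saturation", jitter_range(0.1))
print("hue", jitter_range(0.1, center=0.0, clip_at_zero=False))
```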
# Check that the data loaders have been built
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

{'train': 6552, 'valid': 818}
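With `batch_size = 8`, these dataset sizes determine how many iterations one pass over each loader takes. `DataLoader` keeps the final, smaller batch by default (`drop_last=False`), so the count is a ceiling division. A quick check of the arithmetic:

```python
import math

batch_size = 8
dataset_sizes = {'train': 6552, 'valid': 818}

# drop_last=False (the DataLoader default) keeps the final partial batch
batches = {split: math.ceil(n / batch_size) for split, n in dataset_sizes.items()}
# Number of images in the last batch of each split
last_batch = {split: n - (batches[split] - 1) * batch_size for split, n in dataset_sizes.items()}

print(batches)     # {'train': 819, 'valid': 103}
print(last_batch)  # {'train': 8, 'valid': 2}
```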
Map labels to their actual names

Use the JSON file in the data directory (cat_to_name.json) to map each class label back to the name of the flower.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
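One detail matters when using this mapping: `ImageFolder` numbers the classes by sorting the class folder names as strings, so label 0 is folder `'1'` but label 2 is folder `'100'`, not `'3'`. That is why the display code goes through `class_names[...]` before looking up `cat_to_name`. The sketch below imitates that sorting in plain Python (it does not call torchvision, and it assumes the folders are named `'1'` to `'102'` as in this dataset); `label_to_flower` is a hypothetical helper.

```python
# Class folders in this dataset are named '1' ... '102'
folder_names = [str(i) for i in range(1, 103)]

# Like ImageFolder, sort the folder names as strings and number them in that order
class_names = sorted(folder_names)
class_to_idx = {name: idx for idx, name in enumerate(class_names)}

print(class_names[:5])     # ['1', '10', '100', '101', '102']
print(class_to_idx['3'])   # 25: images from folder '3' get label 25, not 2


def label_to_flower(label, class_names, cat_to_name):
    """Hypothetical helper: label index -> folder name -> flower name."""
    return cat_to_name[class_names[label]]
```

For instance, with the real `cat_to_name`, `label_to_flower(2, class_names, cat_to_name)` looks up folder `'100'`, i.e. 'blanket flower'.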
4. Display some data

def im_convert(tensor):
    """Turn a normalized image tensor back into a NumPy image that matplotlib can display."""
    image = tensor.to("cpu").clone().detach()
    # squeeze() removes any size-1 dimensions (a single (C, H, W) image is left unchanged)
    image = image.numpy().squeeze()
    # ToTensor produced (C, H, W); transpose back to (H, W, C) for plotting
    image = image.transpose(1, 2, 0)
    # Undo the normalization: multiply by std, then add the mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # Clip values below 0 to 0 and values above 1 to 1
    image = image.clip(0, 1)
    return image
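To see what `im_convert` does without a GPU or a real image, the NumPy-only sketch below builds a fake normalized array in (C, H, W) layout and applies the same transpose, de-normalization and clipping steps. `denormalize_chw` is a hypothetical stand-in for the body of `im_convert` after the `.numpy()` call; the 2x2 image size is arbitrary.

```python
import numpy as np

MEAN = np.array((0.485, 0.456, 0.406))
STD = np.array((0.229, 0.224, 0.225))


def denormalize_chw(array_chw):
    """Same steps as im_convert after .numpy(): (C, H, W) -> (H, W, C), undo Normalize, clip."""
    image = array_chw.transpose(1, 2, 0)   # channels last, as matplotlib expects
    image = image * STD + MEAN             # broadcasts over the last (channel) axis
    return image.clip(0, 1)


# A fake RGB image in [0, 1] with shape (H, W, C) = (2, 2, 3)
rng = np.random.default_rng(0)
original_hwc = rng.random((2, 2, 3))

# Forward direction: what ToTensor + Normalize would produce, in (C, H, W) layout
normalized_chw = ((original_hwc - MEAN) / STD).transpose(2, 0, 1)

restored = denormalize_chw(normalized_chw)
print(normalized_chw.shape, restored.shape)  # (3, 2, 2) (2, 2, 3)
print(np.allclose(restored, original_hwc))   # True
```

The clip step only matters when augmentation or rounding pushes values outside [0, 1]; matplotlib expects float images in that range.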
# Plot a few images with the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# Take any one batch from the validation loader for display
dataiter = iter(dataloaders['valid'])
# next(...) works across PyTorch versions; recent DataLoader iterators have no .next() method
inputs, classes = next(dataiter)

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to show the flower name as the title
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model from torchvision.models and initialize it with the pretrained weights

model_name = 'resnet'  # many models to choose from: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is used here for the main image-recognition work

# Whether to keep the pretrained features as they are (freeze them)
feature_extract = True

# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

CUDA is not available. Training on CPU ...

# Set requires_grad = False on the model's parameters so they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Print the model architecture to see how it is built layer by layer
# (these layers are mainly what extracts the features for us)
model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

There is a lot more output in between; we only need to look at these two parts of the architecture, so the rest is omitted...

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
The final layer is a 1000-way classifier: it maps 2048 input features to 1000 classes.
Our task has 102 classes, so this 1000-class output has to be changed to 102 outputs.
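Swapping the head also changes how many parameters get trained. A `Linear(in_features, out_features)` layer holds `in_features * out_features` weights plus `out_features` biases, so with `feature_extract = True` (every other ResNet-152 layer frozen) only the new 2048-to-102 layer is updated. The arithmetic, as a quick sketch (`linear_params` is a hypothetical helper):

```python
def linear_params(in_features, out_features, bias=True):
    """Number of parameters in an nn.Linear layer: weight matrix plus optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)


original_head = linear_params(2048, 1000)  # ImageNet head shown in the printout
new_head = linear_params(2048, 102)        # head for the 102 flower classes

print(original_head)  # 2049000
print(new_head)       # 208998
```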
6. Initialize the model architecture

The steps are:

Take the pretrained model and set pretrained=True to obtain its trained weights.
Optionally choose layers to freeze; for each layer to be frozen, turn off its gradient updates (set requires_grad to False).
Whether the task is classification or regression, replace the final FC layer with one whose output matches the task.

Official documentation:
https://pytorch.org/vision/stable/models.html
# Load a model pretrained by someone else
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Pick the model; each architecture is initialized (and its head replaced) differently
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet-152
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature-extraction layers so that only the FC layer is trained
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer with one that has num_classes (here 102) outputs
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim=1))  # dim=1: over the classes of each sample (row); exp of each row sums to 1
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Replace the last classifier layer (index 6)
        num_frts = model_ft.classifier[6].in_features  # input size of that FC layer
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0: its classifier ends in a 1x1 convolution rather than a Linear layer
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3: expects 299x299 inputs and has an auxiliary classifier
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Auxiliary classifier head
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
        # Main classifier head
        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        raise ValueError("Invalid model name: {}".format(model_name))

    return model_ft, input_size
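The ResNet head ends in `nn.LogSoftmax(dim=1)`, so the model outputs log-probabilities for each row (sample) of the batch. Taking the negative log-probability of the true class (what `nn.NLLLoss` does) then gives the same value as cross-entropy on the raw scores; in PyTorch, `nn.CrossEntropyLoss` on logits is equivalent to `nn.NLLLoss` on LogSoftmax outputs. The plain-Python sketch below (hypothetical helper names, no torch) computes log-softmax over one row with the usual max-subtraction trick for numerical stability.

```python
import math


def log_softmax(row):
    """log(exp(x_i) / sum_j exp(x_j)), computed stably by subtracting the row maximum."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def nll(log_probs, target):
    """Negative log-likelihood of the target class for one sample, as nn.NLLLoss computes it."""
    return -log_probs[target]


scores = [2.0, 1.0, 0.1]           # raw class scores (logits) for one sample
log_probs = log_softmax(scores)
probs = [math.exp(lp) for lp in log_probs]

print(sum(probs))                  # approximately 1.0: the probabilities, not their logs, sum to 1
print(nll(log_probs, target=0))
```

Subtracting the maximum keeps `exp` from overflowing for large scores such as `[1000.0, 1000.0]`, while leaving the result unchanged.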
6 A- Y3 G' k# n- C) ^
* P& {+ E* S' T2 y v; ~0 a
7. Set the parameters to train

# Build the model: model name, number of output classes (102), feature-extraction flag
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Run on the GPU (if available)
model_ft = model_ft.to(device)

# Checkpoint file: the trained model is saved here and can be loaded directly later
filename = 'checkpoint.pth'

# Choose which parameters to train (all layers, or only the unfrozen ones)
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
fc.0.weight
fc.0.bias
7. Training and prediction

7.1 Optimizer setup
# Optimizer: only the parameters in params_to_update are updated
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: StepLR multiplies the learning rate by 0.1 every 7 epochs
# (this assumes scheduler.step() is called once per epoch)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The model's last layer already applies LogSoftmax(), so pair it with NLLLoss;
# nn.CrossEntropyLoss() applies log-softmax internally and would redo that step
criterion = nn.NLLLoss()
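This small NumPy sketch shows why NLLLoss is the natural partner for a LogSoftmax output layer. It uses toy numbers, and the helper names are ours, not PyTorch's. It mirrors what the PyTorch documentation states: nn.CrossEntropyLoss combines LogSoftmax and NLLLoss, so NLLLoss on log-probabilities gives the same value as cross-entropy on raw scores. It also shows that log-softmax leaves log-probabilities unchanged. Putting CrossEntropyLoss after a LogSoftmax layer would therefore only repeat work.

```python
import numpy as np


def log_softmax(z):
    # Numerically stable log-softmax over the class axis (what a LogSoftmax layer outputs)
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def nll_loss(log_probs, targets):
    # Mean negative log-probability of the true class (nn.NLLLoss with reduction='mean')
    return -log_probs[np.arange(len(targets)), targets].mean()


def cross_entropy(logits, targets):
    # Cross-entropy on raw scores = log-softmax followed by NLL (what nn.CrossEntropyLoss computes)
    return nll_loss(log_softmax(logits), targets)


logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
targets = np.array([0, 2])
log_probs = log_softmax(logits)

print(np.isclose(nll_loss(log_probs, targets), cross_entropy(logits, targets)))  # True
# Log-softmax leaves log-probabilities unchanged, so applying it a second time is redundant
print(np.allclose(log_softmax(log_probs), log_probs))  # True
```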
# Training function
# is_inception: set to True only for Inception models, which also return an auxiliary output during training
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy so far
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or the CPU
    model.to(device)
    # Histories kept for plotting later
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over the whole dataset
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the GPU (or CPU)
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Compute and track gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # This branch only applies to Inception; ResNet can ignore it
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # preds: index of the highest-scoring class
                    _, preds = torch.max(outputs, 1)

                    # Update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save a checkpoint
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learnable weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # StepLR is meant to be stepped once per epoch with no argument; a positional
                # argument is treated as the (deprecated) epoch index, so passing epoch_loss
                # would make the learning rate jump around. This uses the global `scheduler`.
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, use the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
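In train_model, loss.item() is the mean loss of one batch, because NLLLoss averages by default. That is why it is multiplied by inputs.size(0) before summing. The last batch can be smaller than the others, and a plain average of batch means would overweight it. This pure-Python sketch shows the same bookkeeping. It uses a hypothetical epoch_metrics helper and made-up batch numbers.

```python
def epoch_metrics(batches):
    """Aggregate per-batch results into epoch loss and accuracy.

    `batches` holds (mean_batch_loss, batch_size, num_correct) tuples, standing in
    for loss.item(), inputs.size(0) and torch.sum(preds == labels.data) in
    train_model (hypothetical helper, for illustration only).
    """
    running_loss = 0.0
    running_corrects = 0
    num_samples = 0
    for mean_loss, batch_size, num_correct in batches:
        # loss.item() is a mean over the batch, so multiply by the batch size to get a sum
        running_loss += mean_loss * batch_size
        running_corrects += num_correct
        num_samples += batch_size
    # train_model divides by len(dataloaders[phase].dataset), i.e. the total number of samples
    return running_loss / num_samples, running_corrects / num_samples


# Two full batches of 8 and a smaller final batch of 2 (made-up numbers)
batches = [(2.0, 8, 4), (1.0, 8, 6), (4.0, 2, 0)]
loss, acc = epoch_metrics(batches)
print(loss, acc)
```

Averaging the three batch means directly would give (2.0 + 1.0 + 4.0) / 3 ≈ 2.33 instead of 32 / 18 ≈ 1.78. The difference is that the two-sample batch would count as much as each eight-sample batch.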
7.2 Start training

I trained for only 5 epochs here (training really takes a long time); feel free to increase the number of epochs when you try it yourself.

# If training is too slow, lower num_epochs; around 50 epochs would probably give better results
# While training, check that the loss goes down and the accuracy goes up, and compare training with validation: a large gap means overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name == "inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
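The learning rates in this log jump between 0.01, 0.001 and 0.0001 instead of decaying steadily. StepLR's decay rule is lr = base_lr * gamma ** (epoch // step_size). The run that produced this log called scheduler.step(epoch_loss). When step() receives a positional argument, PyTorch treats it as the (deprecated) epoch index.

This pure-Python sketch applies that rule to the logged validation losses and reproduces the logged learning rates. For comparison, it also shows the schedule you get by calling scheduler.step() once per epoch. Note that step_lr is our own helper, not a PyTorch function.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    # StepLR's decay rule: multiply by gamma once every step_size epochs
    return base_lr * gamma ** (epoch // step_size)


base_lr = 1e-2  # lr passed to optim.Adam in 7.1

# Intended use: scheduler.step() once per epoch, so the epoch counter runs 1, 2, 3, ...
intended = [step_lr(base_lr, epoch) for epoch in range(1, 16)]

# scheduler.step(epoch_loss): the validation loss (values from the log above) acts as the epoch
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]
observed = [step_lr(base_lr, loss) for loss in valid_losses]
print([round(lr, 7) for lr in observed])  # [0.001, 0.01, 0.0001, 0.01, 0.001]
```

These are exactly the five "Optimizer learning rate" values printed above. With one argument-free step() per epoch, the rate stays at 0.01 for the first 6 epochs and only then drops to 0.001.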
7.3 Train all layers

# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate.
# Pass model_ft.parameters(): when feature_extract is True, params_to_update
# only contains the parameters of the new fc layer.
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# Attach the learning-rate scheduler to this new optimizer
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved parameters and continue training from the model trained above;
# the checkpoint holds the best model from that run
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# The saved optimizer state comes from optimizer_ft, which only covered the fc layer.
# Loading it into an optimizer over all parameters raises a ValueError (parameter-group
# size mismatch), so the new optimizer starts from a fresh state.
Start training

Note: training becomes much slower here, because gradients now have to be computed for every layer. My GPU is a GTX 1660 Ti; the timings below are for reference only.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name == "inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
8. Load the trained model

Using the loaded model amounts to a simple forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU

# Name of the saved checkpoint file
filename = 'checkpoint.pth'

# Load the model; map_location lets a checkpoint saved on the GPU load on a CPU-only machine
checkpoint = torch.load(filename, map_location=device)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read a test image (converted to RGB in case it is grayscale or has an alpha channel)
    img = Image.open(image_path).convert('RGB')
    # Resize so that the shorter side becomes 256 pixels.
    # Unlike resize(), whose size argument sets the exact output size, thumbnail()
    # scales proportionally and can only shrink, so first check which side is shorter.
    # thumbnail() modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Center-crop the image to 224 x 224
    left_margin = (img.width - 224) / 2   # keep the central part
    top_margin = (img.height - 224) / 2
    right_margin = left_margin + 224      # add the crop size 224 to get the far edge
    bottom_margin = top_margin + 224

    # PIL's crop box is (left, upper, right, lower)
    img = img.crop((left_margin, top_margin, right_margin, bottom_margin))

    # Same preprocessing as for the training data: scale to [0, 1], then normalize
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel first: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img


def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Restore the channel order: (C, H, W) -> (H, W, C)
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax


image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # the same function can be reused to preprocess any image
imshow(img)
<AxesSubplot:>
Above is the test image after preprocessing. Let's check its shape to make sure the preprocessing function behaves correctly:

img.shape

(3, 224, 224)

This confirms that the channel dimension has been moved to the front and that the image is 224 × 224 as intended.
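The arithmetic behind process_image and imshow can be checked without any image files. This NumPy sketch uses three hypothetical helpers:

- resize_and_crop_box computes the size after shrinking the shorter side to 256, plus the centered 224 × 224 crop box in PIL's (left, upper, right, lower) order.
- normalize_hwc normalizes a random image the same way process_image does.
- denormalize_chw undoes that normalization the way imshow does, confirming the two steps are inverses.

The sketch assumes the shorter side is at least 256 pixels, since thumbnail() never enlarges an image. PIL may also round the longer side differently by a pixel.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def resize_and_crop_box(width, height, short_side=256, crop=224):
    """Size after scaling the shorter side to `short_side` (aspect ratio kept), plus the
    centered crop box as (left, upper, right, lower). Hypothetical helper for illustration."""
    scale = short_side / min(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    left, upper = (new_w - crop) / 2, (new_h - crop) / 2
    return (new_w, new_h), (left, upper, left + crop, upper + crop)


def normalize_hwc(pixels):
    # As in process_image: scale to [0, 1], standardize each channel, HWC -> CHW
    img = (pixels / 255.0 - MEAN) / STD
    return img.transpose((2, 0, 1))


def denormalize_chw(img):
    # As in imshow: CHW -> HWC, undo the standardization, clip to [0, 1]
    img = img.transpose((1, 2, 0))
    return np.clip(STD * img + MEAN, 0, 1)


size, box = resize_and_crop_box(500, 375)
print(size, box)  # (341, 256) (58.5, 16.0, 282.5, 240.0)

rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(224, 224, 3))
chw = normalize_hwc(pixels)
print(chw.shape)  # (3, 224, 224)
print(np.allclose(denormalize_chw(chw), pixels / 255.0))  # True
```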
9. Inference

# Get one batch of validation data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

# One forward pass produces the output; no gradients are needed for inference
with torch.no_grad():
    if train_on_gpu:
        output = model_ft(images.cuda())
    else:
        output = model_ft(images)

# The batch holds 8 images; each gets 102 scores, one per class
# (log-probabilities, because the last layer is LogSoftmax)
output.shape
torch.Size([8, 102])
9.1 Get the class with the highest probability

_, preds_tensor = torch.max(output, 1)

# Move to the CPU if needed, convert to NumPy, and squeeze away size-1 dimensions to get a 1-D array of class indices
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
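Because the last layer is LogSoftmax, output holds log-probabilities, not probabilities. Exponentiating them gives probabilities that sum to 1 for each image. Since exp is monotonic, the index chosen by torch.max is also the most probable class.

This NumPy sketch extends the idea to the k most likely classes. It uses toy numbers and a hypothetical top_k helper. In PyTorch, torch.exp(output).topk(k, dim=1) should give the same result.

```python
import numpy as np


def top_k(log_probs, k=3):
    """Top-k class indices and probabilities for each row of a
    (batch, num_classes) array of log-probabilities (hypothetical helper)."""
    probs = np.exp(log_probs)                  # undo LogSoftmax: each row sums to 1
    idx = np.argsort(-probs, axis=1)[:, :k]    # class indices, most likely first
    return idx, np.take_along_axis(probs, idx, axis=1)


# Two images, three classes (toy numbers standing in for the (8, 102) output)
log_probs = np.log(np.array([[0.1, 0.7, 0.2],
                             [0.5, 0.3, 0.2]]))
idx, probs = top_k(log_probs, k=2)
print(idx.tolist())                   # [[1, 2], [0, 1]]
print(np.round(probs, 3).tolist())    # [[0.7, 0.2], [0.5, 0.3]]
```

The first column of idx matches what an argmax over the log-probabilities (torch.max in 9.1) returns.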
9.2 Show the predictions

# ImageFolder numbers the class folders in sorted string order ("1", "10", "100", ...),
# so a predicted index is not the folder name; map indices back to folder names before
# looking them up in cat_to_name (assumes the validation set is an ImageFolder)
idx_to_class = {v: k for k, v in dataloaders['valid'].dataset.class_to_idx.items()}

fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    pred_name = cat_to_name[idx_to_class[int(preds[idx])]]
    true_name = cat_to_name[idx_to_class[labels[idx].item()]]
    # Title: predicted name (true name)
    ax.set_title("{} ({})".format(pred_name, true_name),
                 color=("green" if pred_name == true_name else "red"))
plt.show()
# Green titles are correct predictions, red titles are wrong ones
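A note on the class names in 9.2: torchvision's ImageFolder assigns class indices by sorting the folder names as strings. A predicted index i therefore refers to the i-th folder in that order, not to the folder named str(i). This plain-Python sketch mimics that ordering for folders 1 to 102 and does not touch the real dataset. Assuming cat_to_name is keyed by folder name, this is why indices must be mapped back through class_to_idx before looking up a flower name.

```python
# ImageFolder numbers the class folders in sorted (string) order
folder_names = [str(i) for i in range(1, 103)]  # "1" ... "102", as in flower_data/train
classes = sorted(folder_names)
class_to_idx = {name: i for i, name in enumerate(classes)}
idx_to_class = {i: name for name, i in class_to_idx.items()}

print(classes[:5])      # ['1', '10', '100', '101', '102']
print(idx_to_class[1])  # 10  (folder "10", not folder "1")
```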
————————————————
Copyright notice: This is an original article by CSDN blogger 「FeverTwice」, published under the CC 4.0 BY-SA license. When reposting, please include a link to the original article and this notice.
Original article: https://blog.csdn.net/LeungSr/article/details/126747940