数学建模社区-数学中国 (Mathematical Modeling Community, Math China)
Title: [Deep Learning] Image Recognition in Practice: 102-Category Flower Classification (Flower 102) Case Study
Author: 杨利霞
Time: 2022-9-8 10:41
Contents
Convolutional networks in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import the toolkits
2. Data preprocessing and operations
3. Build the data source
Read the actual names corresponding to the labels
4. Display some data
5. Load a model provided by models and initialize it with the pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer settings
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the maximum probability
9.2 Display the prediction results
Closing remarks
Convolutional networks in practice: classifying flowers

This article works through a classification task on the University of Oxford flower dataset (Flower 102). It builds a general-purpose neural network pipeline (implemented mainly with ResNet) and combines common PyTorch operations: preprocessing, training, model saving, and model loading.

The data folder contains 102 kinds of flowers, and the task is to classify them.

Folder structure

flower_data
    train
        1 (class)
        2
            xxx.png / xxx.jpg
    valid

The work is divided into the following main modules.

Data preprocessing
    Data augmentation
    Data preprocessing

Network module setup
    Load a pretrained model by directly calling one of torchvision's classic network architectures
    The pretrained model was trained for someone else's task, possibly 1000-class classification, so its output must be changed to fit our own task

Saving and testing the network model
    Model saving can be selective

Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder and place it in the same root directory as the code.

Below is my data root directory.
[image]
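Before building the datasets, it can help to confirm that the folders on disk match the layout described above. The following is a minimal sketch and not part of the original article: the helper name count_images_per_class is hypothetical, and the list of image extensions is an assumption.

```python
import os

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')  # assumed extensions

def count_images_per_class(split_dir):
    """Return {class_folder_name: image_count} for one split (e.g. train)."""
    counts = {}
    for entry in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, entry)
        if not os.path.isdir(class_dir):
            continue  # skip stray files that sit next to the class folders
        counts[entry] = sum(
            1 for name in os.listdir(class_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts

if __name__ == '__main__':
    for split in ('train', 'valid'):
        split_dir = os.path.join('./flower_data', split)
        if os.path.isdir(split_dir):
            counts = count_images_per_class(split_dir)
            print(split, len(counts), 'classes,', sum(counts.values()), 'images')
```

If the layout is right, each split should report 102 class folders.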
1. Import the toolkits

import os
import matplotlib.pyplot as plt
# Render plots inline in the notebook so that plt.show() is not required
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations

# Path settings
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')

See also: "Combinations and differences of dots and slashes in Python paths"
Note: that article explains how the dot and slash notations work.
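As a small illustration of the slash handling mentioned in the note, the sketch below compares plain string concatenation with path joining. It uses posixpath (an assumption made here so the result is the same on every operating system); on Linux and macOS, os.path behaves the same way.

```python
import posixpath  # POSIX-style paths, so the result is the same on every OS

data_dir = './flower_data/'

# Plain string concatenation keeps both slashes
concatenated = data_dir + '/train'
print(concatenated)  # ./flower_data//train

# join() inserts a separator only when one is missing
joined = posixpath.join(data_dir, 'train')
print(joined)  # ./flower_data/train

# normpath() collapses the doubled slash and drops the leading "./"
print(posixpath.normpath(concatenated))  # flower_data/train
```

The doubled slash is usually tolerated by the file system, but joining avoids it in the first place.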
3. Build the data source

data_transforms specifies all of the image preprocessing operations.
ImageFolder assumes the files are organized by folder, with each folder holding the images of one class.

data_transforms = {
    # Two parts; the first is for training
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # random rotation between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop a 224 x 224 region from the center
        # Flip with a given probability (here 50/50)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # Arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # Convert to grayscale with probability 0.025; the result still has
        # three channels, but R, G, and B hold the same values
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # mean, standard deviation
    ]),
    # Resize so the shorter side is 256, take the central 224 x 224,
    # convert to a tensor, and finally normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
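To make the Normalize step concrete: for each channel it computes (value - mean) / std on the [0, 1] values produced by ToTensor. The sketch below reproduces that arithmetic for a single pixel in plain Python; the helper name normalize_pixel is hypothetical and only serves the illustration.

```python
MEAN = (0.485, 0.456, 0.406)  # per-channel means used in the article
STD = (0.229, 0.224, 0.225)   # per-channel standard deviations

def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))

# A pixel equal to the channel means maps to zero in every channel
print(normalize_pixel(MEAN))  # (0.0, 0.0, 0.0)

# Pure white (1.0 after ToTensor) maps to a positive value in every channel
white = normalize_pixel((1.0, 1.0, 1.0))
print([round(v, 3) for v in white])
```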
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

# Inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
# Check that the data loaders have been created
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}

dataset_sizes

{'train': 6552, 'valid': 818}
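With these dataset sizes and batch_size = 8, the number of batches per epoch follows from simple arithmetic. The sketch below assumes the DataLoader default of keeping the last, smaller batch (drop_last=False), which matches the loaders created above.

```python
import math

dataset_sizes = {'train': 6552, 'valid': 818}  # sizes printed above
batch_size = 8

# With drop_last=False the final partial batch is kept,
# so the batch count is the ceiling of size / batch_size
batches_per_epoch = {split: math.ceil(n / batch_size) for split, n in dataset_sizes.items()}
print(batches_per_epoch)  # {'train': 819, 'valid': 103}

# Size of the final validation batch
print(dataset_sizes['valid'] % batch_size)  # 2
```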
Read the actual names corresponding to the labels

Use the JSON file in the data directory to map each class label to the flower's name.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)

cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
4. Display some data

def im_convert(tensor):
    """Convert a normalized image tensor into an array that can be displayed."""
    image = tensor.to("cpu").clone().detach()
    # squeeze() drops any size-1 dimensions (such as a leading batch dimension of 1)
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is stored as (c, h, w),
    # while plotting needs (h, w, c)
    image = image.transpose(1, 2, 0)
    # Undo the normalization: multiply by the std and add the mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    # Clip to [0, 1]: values below 0 become 0 and values above 1 become 1
    image = image.clip(0, 1)

    return image
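The un-normalize and clip steps in im_convert can be checked on a single pixel. This is a plain-Python sketch of the same arithmetic (value * std + mean, then clipping); the helper name denormalize_pixel is hypothetical.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def denormalize_pixel(rgb):
    """Invert (v - mean) / std, then clip to [0, 1] for display."""
    restored = (v * s + m for v, m, s in zip(rgb, MEAN, STD))
    return tuple(min(1.0, max(0.0, v)) for v in restored)

# Zero in normalized space is the channel mean in pixel space
print(denormalize_pixel((0.0, 0.0, 0.0)))  # (0.485, 0.456, 0.406)

# Extreme values fall outside [0, 1] before clipping
print(denormalize_pixel((5.0, -5.0, 0.0)))  # (1.0, 0.0, 0.406)
```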
# Use the function defined above to draw the images
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2

# Create an iterator and take one batch from the validation loader to display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # built-in next(); the .next() method is missing in newer PyTorch releases

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to print the flower name above each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
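The title lookup goes through class_names because the class index returned by the loader is not the folder number. ImageFolder sorts the folder names as strings, so index 1 refers to folder '10', not folder '2'. The sketch below reproduces that ordering without torchvision; the cat_to_name excerpt is copied from the output above, and index_to_name is a hypothetical helper.

```python
# Folder names as they appear on disk: '1' ... '102'
folder_names = [str(i) for i in range(1, 103)]

# ImageFolder sorts the folder names as strings to build its class list
class_names = sorted(folder_names)
print(class_names[:6])  # ['1', '10', '100', '101', '102', '11']

# A small excerpt of cat_to_name from the article
cat_to_name = {'1': 'pink primrose', '10': 'globe thistle', '100': 'blanket flower'}

def index_to_name(class_index):
    """Map a dataset class index to the flower name via the folder name."""
    folder = class_names[class_index]
    return cat_to_name[folder]

print(index_to_name(1))  # globe thistle
```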
5. Load a model provided by models and initialize it with the pretrained weights

model_name = 'resnet'  # many options: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is the main choice for image recognition here
# Whether to use the pretrained features as they are (freeze the feature extractor)
feature_extract = True
# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

CUDA is not available. Training on CPU ...
# Set requires_grad to False on the existing layers so that they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
# Print the model architecture to see how it is built step by step
# Its main job is to extract features for us

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
... (much of the output in between is omitted; the first and last blocks are enough to see the architecture) ...
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
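The final layer is a 1000-class classifier: it takes 2048 input features and produces 1000 outputs.
For our task this layer has to be adjusted so that it produces 102 outputs instead of 1000.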
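To get a feel for what changes when the head is swapped, the parameter count of a fully connected layer is in_features * out_features weights plus out_features biases. The sketch below applies that formula to the 2048-input head shown in the printed architecture; the helper name linear_param_count is hypothetical.

```python
def linear_param_count(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

original_head = linear_param_count(2048, 1000)  # ImageNet head printed above
new_head = linear_param_count(2048, 102)        # head for the 102 flower classes

print(original_head)  # 2049000
print(new_head)       # 208998
```

When the feature extractor is frozen, these roughly 209 thousand head parameters are the only ones that receive gradient updates.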

6. Initialize the model architecture

The steps are as follows:

1. Take an already-trained model and set pretrained = True to obtain the existing weights.
2. Optionally freeze some layers; the layers to freeze are specified by setting requires_grad to False so their gradients are not updated.
3. Whether the task is classification or regression, replace the final FC layer with one whose parameters match the task.

Official documentation:
https://pytorch.org/vision/stable/models.html
# Load an existing model
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Select the model; each architecture needs its own head replacement.
    # Note: newer torchvision releases deprecate pretrained= in favor of weights=.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """
        Resnet152
        """
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature-extraction layers and train only the FC layer
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer so that it outputs num_classes (102 here)
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim=1))  # dim=1: log-softmax over the classes of each sample (each row)
        input_size = 224

    elif model_name == "alexnet":
        """
        Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last layer of the classifier, the one at index 6
        num_frts = model_ft.classifier[6].in_features  # input features of the FC layer
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        """
        VGG16
        """
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        """
        Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        """
        Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        """
        Inception V3
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the auxiliary classifier head
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        # Replace the main classifier head
        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size

7. Setting the parameters to train

# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU (or CPU, depending on `device`)
model_ft = model_ft.to(device)

# Checkpoint file: stores the trained model so it can be loaded directly later
filename = 'checkpoint.pth'

# By default, train all layers
params_to_update = model_ft.parameters()
# Print the parameters that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)

Output:

Params to learn:
	 fc.0.weight
	 fc.0.bias

7. Training and prediction

7.1 Optimizer setup

# Optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay: multiply the learning rate by 0.1 every 7 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The last layer already applies LogSoftmax(), so nn.NLLLoss() is the matching loss
# rather than nn.CrossEntropyLoss(), which applies log-softmax itself
criterion = nn.NLLLoss()
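
To see why `nn.NLLLoss()` pairs with a `LogSoftmax` output layer, the following plain-Python sketch computes both steps by hand for a single sample. The helpers `log_softmax`, `nll_loss` and `cross_entropy` are written for this illustration only (they are not the PyTorch functions), and the scores are made-up numbers.

```python
import math

def log_softmax(logits):
    """Log-probabilities from raw scores (numerically stable form)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

def cross_entropy(logits, target):
    """Cross-entropy on raw scores: log-softmax followed by NLL."""
    return nll_loss(log_softmax(logits), target)

logits = [2.0, 0.5, -1.0]   # hypothetical raw scores for 3 classes
target = 0                  # index of the true class

log_probs = log_softmax(logits)     # what a LogSoftmax layer would output
loss = nll_loss(log_probs, target)  # what NLLLoss computes from that output

# Both routes give the same number
print(abs(loss - cross_entropy(logits, target)) < 1e-12)  # True
```

In other words, a model that ends in LogSoftmax plus NLL loss computes the same quantity as a model that outputs raw scores plus cross-entropy loss.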

# Training function
# is_inception: whether the model is Inception, which also returns an auxiliary output during training
# Note: the function uses the module-level `scheduler` defined above
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy seen so far
    best_acc = 0
    # Optional resume from a checkpoint (disabled). The 'mapping' key is not
    # written by the save step further down, so add it there before enabling this.
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or CPU selected by `device`
    model.to(device)
    # Histories kept for plotting
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all batches
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    # Inception branch; not used for ResNet
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # Predicted class = index of the largest output
                    _, preds = torch.max(outputs, 1)

                    # Update the weights only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss (weighted by batch size) and the correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the model with the best validation accuracy
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # StepLR counts epochs itself, so step() takes no argument;
                # passing the loss would be interpreted as the epoch number
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, load the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
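
The per-epoch statistics in `train_model` weight each batch by its size (`loss.item() * inputs.size(0)`) and divide by the dataset size, so a smaller final batch does not skew the average. A minimal sketch of that bookkeeping in plain Python; `epoch_average` is a helper written for this example and the batch numbers are made up.

```python
def epoch_average(batch_means, batch_sizes):
    """Dataset-level mean from per-batch means, weighting each batch by its size."""
    running = 0.0
    for mean, size in zip(batch_means, batch_sizes):
        running += mean * size           # same role as loss.item() * inputs.size(0)
    return running / sum(batch_sizes)    # same role as len(dataloaders[phase].dataset)

# Hypothetical epoch: two full batches of 8 samples and a final batch of 2
batch_means = [1.0, 2.0, 6.0]
batch_sizes = [8, 8, 2]

weighted = epoch_average(batch_means, batch_sizes)
unweighted = sum(batch_means) / len(batch_means)

print(weighted)    # 2.0
print(unweighted)  # 3.0
```

The unweighted mean lets the two samples of the last batch count as much as eight samples of a full batch, which is why the training loop multiplies by the batch size first.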

7.2 Training the model

I trained for only 5 epochs here because training takes a long time; increase the number of epochs when you try it yourself.

# If training is too slow, lower num_epochs; around 50 epochs may give better results
# While training, check whether the loss is falling and the accuracy is rising.
# A large gap between training and validation results indicates overfitting.
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))

Output:

Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
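
The learning rates printed in this log jump between epochs instead of staying at 0.01 for the first 7 epochs. One explanation that fits every logged value is that the recorded run called `scheduler.step(epoch_loss)`: `StepLR` interprets an argument to `step()` as the epoch number, so the validation loss took the place of the epoch counter. The sketch below assumes the closed-form StepLR rule `lr = base_lr * gamma ** (epoch // step_size)`; `step_lr` is a helper written for this illustration, not a PyTorch function.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Closed-form StepLR: decay by `gamma` once every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)

base_lr = 1e-2

# Intended behaviour: the epoch counter drives the schedule
intended = [step_lr(base_lr, epoch) for epoch in range(5)]

# Validation losses copied from the log, used in place of the epoch number
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]
logged = [step_lr(base_lr, loss) for loss in valid_losses]

print(['{:.7f}'.format(lr) for lr in intended])
# ['0.0100000', '0.0100000', '0.0100000', '0.0100000', '0.0100000']
print(['{:.7f}'.format(lr) for lr in logged])
# ['0.0010000', '0.0100000', '0.0001000', '0.0100000', '0.0010000']
```

The second list matches the five "Optimizer learning rate" lines above. Calling `scheduler.step()` without an argument gives the first list instead.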

7.3 Training all layers

# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Continue training all parameters with a smaller learning rate.
# The optimizer is built from model_ft.parameters(): the params_to_update list
# from section 7 holds only the fc layer when feature_extract is True.
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
# Attach the scheduler to the new optimizer
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Loss function
criterion = nn.NLLLoss()

# Load the saved parameters and continue training from that model.
# The checkpoint holds the best weights from the run above.
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
# The optimizer state saved in 7.2 belongs to an optimizer over the fc parameters only,
# and loading it would also overwrite lr=1e-4, so it is not restored here.

Start training

Note: training becomes very slow at this stage. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))

Output:

Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
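
This run reports a best validation accuracy of 0.645477, lower than the 0.662592 reached in section 7.2. `train_model` starts every call with `best_acc = 0` (its resume block is disabled), so the first validation epoch of a new run replaces the checkpoint even when it scores below the saved one. A minimal sketch of that selection rule with the accuracies from the two logs; `select_best` and its `start_best` parameter are hypothetical names used only here.

```python
def select_best(val_accs, start_best=0.0):
    """Apply the rule from train_model: save whenever the validation
    accuracy beats the best value seen so far.
    Returns (best_acc, list of epochs at which a checkpoint was written)."""
    best = start_best
    saved_at = []
    for epoch, acc in enumerate(val_accs):
        if acc > best:
            best = acc
            saved_at.append(epoch)
    return best, saved_at

second_run = [0.6455, 0.6137]   # validation accuracies from the log above

# Starting from zero: epoch 0 overwrites the earlier checkpoint
fresh_best, fresh_saves = select_best(second_run)

# Starting from the earlier best of 0.6626: nothing is overwritten
resumed_best, resumed_saves = select_best(second_run, start_best=0.6626)

print(fresh_best, fresh_saves)      # 0.6455 [0]
print(resumed_best, resumed_saves)  # 0.6626 []
```

Passing the stored `best_acc` into the training function is one way to keep the better checkpoint; whether that is wanted depends on whether the two training stages should be compared directly.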

8. Loading the trained model

Inference amounts to a single forward pass through the network; no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU (or CPU, depending on `device`)
model_ft = model_ft.to(device)

# Checkpoint file name
filename = 'checkpoint.pth'

# Load the saved weights
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

Output:

<All keys matched successfully>

def process_image(image_path):
    # Read an image from the test set
    img = Image.open(image_path)
    # thumbnail() only shrinks an image and keeps its aspect ratio,
    # so the bound is chosen by orientation to bring the shorter side to 256.
    # Unlike resize(), whose size argument sets the exact output size,
    # thumbnail() modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Center-crop to 224 x 224 (PIL coordinates start at the top-left corner)
    left = (img.width - 224) / 2    # start of the central region
    upper = (img.height - 224) / 2
    right = left + 224              # add the crop size to get the far edge
    lower = upper + 224

    img = img.crop((left, upper, right, lower))

    # Same preprocessing as during training:
    # scale to [0, 1], then normalize each channel
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the first dimension (HWC -> CHW)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Restore the channel order (CHW -> HWC)
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # the function can be called again for any other image
imshow(img)
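
The crop step in `process_image` takes a 224 x 224 window from the middle of the resized image. The arithmetic can be checked on its own with a small helper; `center_crop_box` is a name invented for this sketch, the image sizes are hypothetical, and integer division is used here so the box lands on whole pixels.

```python
def center_crop_box(width, height, size=224):
    """(left, upper, right, lower) box of a size x size crop centered in the image."""
    left = (width - size) // 2
    upper = (height - size) // 2
    return (left, upper, left + size, upper + size)

# Hypothetical image sizes after the thumbnail step (shorter side 256)
print(center_crop_box(256, 341))  # (16, 58, 240, 282)
print(center_crop_box(384, 256))  # (80, 16, 304, 240)
```

The returned tuple is in the order that PIL's `Image.crop` expects, with the origin at the top-left corner of the image.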

Output:

<AxesSubplot:>

The image above is the test-set image after preprocessing. We check its shape to confirm that the preprocessing function works as intended.

img.shape

Output:

(3, 224, 224)

This confirms that the channel dimension has moved to the front and that the image has the expected 224 x 224 size.
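
What `transpose((2, 0, 1))` does to the layout can be shown on a tiny example without NumPy. The sketch below reorders a nested list from height-width-channel to channel-height-width; `hwc_to_chw` and `shape` are helpers written for this illustration.

```python
def hwc_to_chw(image):
    """Reorder a nested list from [height][width][channel] to [channel][height][width]."""
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [[[image[h][w][c] for w in range(width)]
             for h in range(height)]
            for c in range(channels)]

def shape(nested):
    """Shape of a regular nested list, as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)

# A tiny 2 x 3 "image" with 3 channels; each pixel is [R, G, B]
image = [[[h * 10 + w, 100 + h * 10 + w, 200 + h * 10 + w] for w in range(3)]
         for h in range(2)]

chw = hwc_to_chw(image)
print(shape(image))  # (2, 3, 3)
print(shape(chw))    # (3, 2, 3)
print(chw[0])        # red plane: [[0, 1, 2], [10, 11, 12]]
```

The pixel values are untouched; only the order of the axes changes, which is why a 224 x 224 RGB image ends up with shape (3, 224, 224).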

9. Inference

img.shape

# Get one batch of validation data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

if train_on_gpu:
    # One forward pass produces the output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 images; each gets 102 values, one per class
# (log-probabilities, since the last layer is LogSoftmax)
output.shape

Output:

torch.Size([8, 102])

9.1 Finding the most probable class

_, preds_tensor = torch.max(output, 1)

# Move the result to the CPU if needed and convert it to a 1-D NumPy array
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
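
`torch.max(output, 1)` returns, for each row, the largest value and its index; the index is the predicted class. Because the model ends in LogSoftmax, the values are log-probabilities, and taking `exp` recovers the probability. A plain-Python sketch of the same step for one row; `predict` is a helper invented for this example and the numbers are made up.

```python
import math

def predict(log_probs):
    """Return (class_index, probability) for one row of log-probabilities."""
    best = max(range(len(log_probs)), key=lambda i: log_probs[i])
    return best, math.exp(log_probs[best])

# Hypothetical log-probabilities for a 4-class problem
row = [math.log(0.1), math.log(0.6), math.log(0.25), math.log(0.05)]

index, prob = predict(row)
print(index)           # 1
print(round(prob, 2))  # 0.6
```

Since `exp` is monotonic, the argmax of the log-probabilities and the argmax of the probabilities are always the same index.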

9.2 Displaying the predictions

fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # Title format: predicted name (true name)
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions; red titles mark wrong ones
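
The titles above look up `cat_to_name` with the class index directly. If the data was loaded with torchvision's `ImageFolder` (an assumption here, since the loading code is in an earlier section), class indices follow the folder names sorted as strings, so '10' comes before '2' and the index generally differs from the folder name. The sketch below builds that mapping for folders named '1' to '102'; `idx_to_class` is a name chosen for this example.

```python
# Folder names '1' .. '102', as in flower_data/train
folders = [str(n) for n in range(1, 103)]

# ImageFolder-style indexing: sort the names as strings, then enumerate
class_to_idx = {name: idx for idx, name in enumerate(sorted(folders))}
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print([idx_to_class[i] for i in range(6)])  # ['1', '10', '100', '101', '102', '11']
print(class_to_idx['2'])                    # 14
```

With such a mapping, the flower name for a predicted index would be `cat_to_name[idx_to_class[pred]]`, provided `cat_to_name` is keyed by folder name.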

----------------
Copyright notice: this is an original article by CSDN blogger "FeverTwice", licensed under CC 4.0 BY-SA. Reposts must include a link to the original source and this notice.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940