[Deep Learning] Image Recognition in Practice: 102-Category Flower Classification (Flower 102) Case Study
Contents
Convolutional network in practice: classifying flowers
Data preprocessing
Network module setup
Saving and testing the network model
Data download
1. Import the toolkits
2. Data preprocessing and operations
3. Build the data source
Read the actual names that the labels correspond to
4. Display some of the data
5. Load a model provided by models and initialize it with the pretrained weights
6. Initialize the model architecture
7. Set the parameters to train
7. Training and prediction
7.1 Optimizer settings
7.2 Start training the model
7.3 Train all layers
Start training
8. Load the trained model
9. Inference
9.1 Compute the maximum probability
9.2 Display the prediction results
Closing remarks
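Convolutional network in practice: classifying flowers
This article tackles a classification task on the Oxford flower dataset (Flowers 102). It builds a general-purpose neural network pipeline (implemented mainly with ResNet) and ties together several common PyTorch operations: preprocessing, training, saving a model, and loading a model.

The dataset folder contains 102 kinds of flowers, and the goal is to classify images into those categories.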
Folder structure

flower_data
    train
        1 (category)
        2
            xxx.png / xxx.jpg
    valid
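To make the layout concrete, here is a minimal sketch that walks one split directory and counts the files in each class folder. The helper name count_images_per_class and the file names are made up for illustration; the sketch builds a throwaway miniature of the layout instead of touching the real dataset.

```python
import os
import tempfile


def count_images_per_class(split_dir):
    """Return {class_folder: number_of_files} for a train/ or valid/ directory."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if os.path.isdir(class_dir):
            counts[class_name] = len(os.listdir(class_dir))
    return counts


# Build a throwaway miniature of the layout (hypothetical file names) and inspect it.
with tempfile.TemporaryDirectory() as root:
    for class_name, n_files in [("1", 2), ("2", 3)]:
        class_dir = os.path.join(root, "flower_data", "train", class_name)
        os.makedirs(class_dir)
        for i in range(n_files):
            open(os.path.join(class_dir, f"image_{i}.jpg"), "w").close()

    counts = count_images_per_class(os.path.join(root, "flower_data", "train"))

print(counts)  # {'1': 2, '2': 3}
```

Pointing the same helper at the real flower_data/train folder should list 102 class folders.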
The work falls into the following main modules.

Data preprocessing
    Data augmentation
    Data preprocessing
Network module setup
    Load a pretrained model by calling one of the classic architectures in torchvision directly
    The pretrained network was trained for someone else's task, typically 1000-class classification with different classes, so its output layer has to be changed to fit our own task
Saving and testing the network model
    Model saving can be selective
Data download:
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset

Rename the downloaded folder and place it in the same root directory as the code.

My data root directory is shown below.
[Image: the author's data root directory]
1. Import the toolkits
import os
import matplotlib.pyplot as plt
# Jupyter magic: render figures inline so that plt.show() is not required (notebook only)
%matplotlib inline
import numpy as np
import torch
from torch import nn

import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets

import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
2. Data preprocessing and operations
# Path setup
data_dir = './flower_data/'  # the flower_data directory under the current folder
train_dir = os.path.join(data_dir, 'train')
valid_dir = os.path.join(data_dir, 'valid')
See also: combinations of, and differences between, dots and slashes in Python paths
Note: the linked article explains how the dot and slash forms of a path behave.
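As a quick illustration of the dot and slash behaviour, the sketch below compares plain string concatenation with a proper path join. It uses posixpath (the POSIX flavour of os.path) only so that the printed results are identical on every operating system; that choice is mine, not the article's.

```python
import posixpath  # POSIX flavour of os.path, so the results are the same on every OS

data_dir = "./flower_data/"

# Plain concatenation keeps both slashes.
concatenated = data_dir + "/train"

# join() inserts a separator only when one is needed.
joined = posixpath.join(data_dir, "train")

# "." is the current directory and ".." is its parent; normpath() resolves both.
normalized = posixpath.normpath(concatenated)
parent = posixpath.normpath("./flower_data/../flower_data/valid")

print(concatenated)  # ./flower_data//train
print(joined)        # ./flower_data/train
print(normalized)    # flower_data/train
print(parent)        # flower_data/valid
```

The doubled slash usually still resolves to the same folder, but joining the parts avoids having to think about it.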
3. Build the data source
data_transforms specifies every image preprocessing operation.
ImageFolder assumes the files are organized by folder, with each folder holding the images of one class.
data_transforms = {
    # Two pipelines: one for training, one for validation
    'train': transforms.Compose([
        transforms.RandomRotation(45),  # rotate by a random angle between -45 and 45 degrees
        transforms.CenterCrop(224),  # crop a 224 x 224 patch from the center
        # each flip is applied with probability p (0.5 means a 50/50 chance)
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
        transforms.RandomVerticalFlip(p=0.5),  # random vertical flip
        # arguments: brightness, contrast, saturation, hue
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # convert to grayscale with probability p; the result still has three
        # channels, but R, G and B hold the same values
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel mean, std
    ]),
    # Resize so that the shorter side is 256, take the central 224 x 224 crop,
    # convert to a tensor, then normalize
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # same mean and std as the training set
    ]),
}
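The Normalize step is plain per-channel arithmetic: subtract the mean, then divide by the standard deviation. The sketch below works it through for single pixels in pure Python; normalize_pixel is a hypothetical helper written for this illustration, not a torchvision function.

```python
# Per-channel statistics passed to the Normalize transform in the article.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel whose values lie in [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


# A pixel equal to the channel means maps to zero in every channel.
print(normalize_pixel(MEAN))             # (0.0, 0.0, 0.0)

# Pure white (1.0 after ToTensor) ends up a little above 2 in each channel.
white = normalize_pixel((1.0, 1.0, 1.0))
print([round(v, 3) for v in white])      # [2.249, 2.429, 2.64]
```

After this step the values are no longer confined to [0, 1], which is why an image has to be denormalized before it can be displayed.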
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
# Inspect the datasets
image_datasets
{'train': Dataset ImageFolder
     Number of datapoints: 6552
     Root location: ./flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 818
     Root location: ./flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}
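The printed ColorJitter line shows ranges rather than the single numbers that were passed in. The sketch below reproduces that mapping, assuming the rule is max(0, 1 - v) to 1 + v for brightness, contrast and saturation, and -h to h for hue, which agrees with the output above. jitter_range and hue_range are hypothetical helper names, not torchvision functions.

```python
def jitter_range(value, center=1.0):
    """Range sampled for brightness, contrast or saturation given a single float."""
    return [max(0.0, center - value), center + value]


def hue_range(value):
    """Hue is shifted rather than scaled, so its range is symmetric around 0."""
    return [-value, value]


print(jitter_range(0.2))  # [0.8, 1.2]
print(jitter_range(0.1))  # [0.9, 1.1]
print(hue_range(0.1))     # [-0.1, 0.1]
```

A factor of 1.0 leaves the image unchanged, so brightness=0.2 means each training image is made up to 20% darker or brighter.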
# Check that the data has been processed
dataloaders
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
dataset_sizes
{'train': 6552, 'valid': 818}
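With a batch size of 8, these dataset sizes determine how many batches each loader yields per epoch. A small worked calculation, assuming the DataLoader default of keeping the final partial batch (drop_last is not set in the article's code):

```python
import math

dataset_sizes = {"train": 6552, "valid": 818}  # sizes reported in the article
batch_size = 8

# Number of batches per epoch; the last one may be smaller than batch_size.
batches = {split: math.ceil(n / batch_size) for split, n in dataset_sizes.items()}
last_batch = {split: n - (batches[split] - 1) * batch_size
              for split, n in dataset_sizes.items()}

print(batches)     # {'train': 819, 'valid': 103}
print(last_batch)  # {'train': 8, 'valid': 2}
```

The training set divides evenly, while the last validation batch holds only 2 images.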
Read the actual names that the labels correspond to
The JSON file in the same directory maps each category id back to the name of the flower.

with open('./flower_data/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
cat_to_name
{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}
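One detail is easy to miss: the integer label produced by the data loader is not the folder number. ImageFolder sorts the class folder names as strings, so a label has to go through class_names before it can be looked up in cat_to_name. The sketch below imitates that ordering in plain Python; label_to_name is a hypothetical helper, and the small dictionary holds only three entries copied from the output above.

```python
# ImageFolder names classes after the sub-folders and sorts those names as strings,
# so the integer label the model sees is NOT the folder number.
folder_names = [str(i) for i in range(1, 103)]   # folders "1" .. "102"
class_names = sorted(folder_names)               # same ordering as ImageFolder.classes

print(class_names[:6])  # ['1', '10', '100', '101', '102', '11']

# A few entries copied from cat_to_name in the article.
cat_to_name = {"1": "pink primrose", "10": "globe thistle", "100": "blanket flower"}


def label_to_name(label):
    """Map a label index to a flower name: index -> folder name -> name."""
    return cat_to_name[class_names[label]]


print(label_to_name(1))  # globe thistle
print(label_to_name(2))  # blanket flower
```

Looking up cat_to_name with the raw label (for example "2") would return the wrong flower, which is why the plotting code goes through class_names first.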
4. Display some of the data
def im_convert(tensor):
    """Convert a normalized image tensor into an array that matplotlib can display."""
    image = tensor.to("cpu").clone().detach()
    # squeeze() drops any size-1 dimensions, such as a leading batch dimension of 1
    image = image.numpy().squeeze()
    # transpose reorders the axes: the tensor is stored as (c, h, w),
    # while matplotlib expects (h, w, c)
    image = image.transpose(1, 2, 0)
    # undo the normalization: multiply by the std and add the mean
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    # clip to [0, 1]: values below 0 become 0 and values above 1 become 1
    image = image.clip(0, 1)

    return image
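The two key moves in im_convert are the axis reorder and the inverse of the normalization. The sketch below replays both on a tiny nested list in pure Python so that each value can be followed by hand; chw_to_hwc and denormalize are hypothetical helpers for this illustration, whereas the article's function does the same thing with NumPy.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def chw_to_hwc(chw):
    """Reorder a nested list from [channel][row][col] to [row][col][channel]."""
    channels, height, width = len(chw), len(chw[0]), len(chw[0][0])
    return [[[chw[c][h][w] for c in range(channels)]
             for w in range(width)]
            for h in range(height)]


def denormalize(pixel):
    """Undo Normalize for one pixel (value * std + mean), then clip to [0, 1]."""
    return [min(1.0, max(0.0, v * s + m)) for v, s, m in zip(pixel, STD, MEAN)]


# A 3-channel, 1 x 2 "image": the first pixel is all zeros, the second is far out of range.
chw = [[[0.0, 9.0]], [[0.0, 9.0]], [[0.0, -9.0]]]
hwc = chw_to_hwc(chw)

print(hwc)                      # [[[0.0, 0.0, 0.0], [9.0, 9.0, -9.0]]]
print(denormalize(hwc[0][0]))   # [0.485, 0.456, 0.406]
print(denormalize(hwc[0][1]))   # [1.0, 1.0, 0.0]
```

A normalized value of zero comes back as the channel mean, and anything that lands outside the displayable range is clipped.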
# Plot a batch using the function defined above
fig = plt.figure(figsize=(20, 12))
columns = 4
rows = 2
# Create an iterator over the loader and take one batch to display
dataiter = iter(dataloaders['valid'])
inputs, classes = next(dataiter)  # next(dataiter) works across PyTorch versions; dataiter.next() was removed in newer releases

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    # Use the JSON mapping to print the flower name as the title of each image
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()
5. Load a model provided by models and initialize it with the pretrained weights
model_name = 'resnet'  # several options are supported: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
# ResNet is the main model used for image recognition here
# Whether to reuse the features that someone else has already trained
feature_extract = True
# Whether to train on the GPU
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
CUDA is not available. Training on CPU ...
# Set requires_grad to False on the existing layers so that they are not updated during training
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
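To see what the helper does, the sketch below applies it to a tiny made-up network and counts the parameters that still receive gradients before and after. The network toy and the helper count_trainable are assumptions for illustration; nothing is downloaded and no ResNet is involved.

```python
from torch import nn


def set_parameter_requires_grad(model, feature_extracting):
    """Same helper as in the article: freeze every parameter when feature extracting."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def count_trainable(model):
    """Number of scalar parameters that still receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Hypothetical tiny stand-in for a pretrained backbone (not ResNet).
toy = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

before = count_trainable(toy)          # 4*3 + 3 + 3*2 + 2 = 23
set_parameter_requires_grad(toy, feature_extracting=True)
after = count_trainable(toy)           # every parameter is frozen

print(before, after)  # 23 0
```

With feature_extracting=False the helper does nothing, so all layers stay trainable.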
# Print the model architecture to see, step by step, how the network is built
# Most of it serves to extract features for us

model_ft = models.resnet152()
model_ft
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    ... (much of the output in between is omitted; the first and last stages of the architecture are enough here)
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
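The final layer is a 1000-way classifier: it takes 2048 input features and maps them to 1000 classes.
Our task needs a different output, so the 1000-class layer has to be replaced with one that has 102 outputs.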
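For a sense of scale, a fully connected layer holds in_features x out_features weights plus one bias per output. The worked calculation below compares the original head with its 102-class replacement; linear_param_count is a hypothetical helper written for this illustration.

```python
def linear_param_count(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


imagenet_head = linear_param_count(2048, 1000)  # the fc layer of the pretrained model
flower_head = linear_param_count(2048, 102)     # the replacement for 102 classes

print(imagenet_head)  # 2049000
print(flower_head)    # 208998
```

When the rest of the network is frozen, these roughly 209 thousand values are the only ones that training has to update.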
6. Initialize the model architecture
The steps are as follows:

Take an already trained model and pass pretrained = True to obtain its trained weights
Optionally freeze some layers; the layers to freeze are specified by setting requires_grad to False on their parameters
Whether the task is classification or regression, the final FC layer still has to be changed to match it
Official documentation:
https://pytorch.org/vision/stable/models.html
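The three steps above can be sketched on a tiny stand-in network, which keeps the example fast and avoids downloading any weights. TinyNet is a made-up class that only mimics the one property that matters here: a final layer stored in an attribute named fc, as in ResNet.

```python
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained network that ends in an `fc` layer."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, 1000)      # original head: 1000 classes

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyNet()                          # step 1: take an existing model

for param in model.parameters():           # step 2: freeze what is already there
    param.requires_grad = False

num_features = model.fc.in_features        # step 3: swap in a new head for 102 classes
model.fc = nn.Sequential(nn.Linear(num_features, 102), nn.LogSoftmax(dim=1))

# Freshly created layers have requires_grad=True, so only the new head is trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.0.weight', 'fc.0.bias']
```

The order matters: freezing after replacing the head would freeze the new layer as well.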
# Load a model defined by someone else
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Pick the requested model; each architecture needs its own head replacement.
    # Note: newer torchvision releases prefer the weights= argument over pretrained=.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet-152
        # 1. Load the pretrained network
        model_ft = models.resnet152(pretrained=use_pretrained)
        # 2. Optionally freeze the feature-extraction layers so that only the FC layer is trained
        set_parameter_requires_grad(model_ft, feature_extract)
        # 3. Get the number of input features of the fully connected layer
        num_frts = model_ft.fc.in_features
        # 4. Replace the fully connected layer with one that has num_classes outputs (102 here)
        #    dim=1 normalizes over the classes of each sample (each row), so the
        #    exponentiated outputs of a row sum to 1
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, num_classes),
                                    nn.LogSoftmax(dim=1))
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Replace the last classifier layer, which sits at index 6
        num_frts = model_ft.classifier[6].in_features  # input size of the FC layer
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG16
        model_ft = models.vgg16(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_frts = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_frts, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Inception v3 has an auxiliary classifier as well as the main one
        num_frts = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)

        num_frts = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_frts, num_classes)
        input_size = 299

    else:
        raise ValueError(f"Invalid model name: {model_name}")

    return model_ft, input_size
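The ResNet branch ends in LogSoftmax(dim=1), so the network outputs log-probabilities for each row (each sample). The pure-Python sketch below shows the arithmetic on one row of made-up scores: exponentiating the outputs gives probabilities that sum to 1, and the negative log-probability of the true class is the quantity a loss such as nn.NLLLoss would read off. The log_softmax function here is an illustration, not the PyTorch implementation.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of one row of class scores."""
    top = max(scores)
    log_sum = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_sum for s in scores]


row = [2.0, 1.0, 0.1]                     # hypothetical scores for three classes
log_probs = log_softmax(row)
probs = [math.exp(v) for v in log_probs]

print(round(sum(probs), 6))               # 1.0
print(probs.index(max(probs)))            # 0

# Negative log-likelihood of the true class (class 0 here).
print(round(-log_probs[0], 3))            # 0.417
```

Because the outputs are already log-probabilities, they should be paired with a loss that expects them rather than with one that applies softmax again.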
7. Setting the parameters to train
# Set the model name and the number of output classes
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# Move the model to the GPU (or CPU)
model_ft = model_ft.to(device)

# Checkpoint file: stores the trained model so that it can be loaded directly later
filename = 'checkpoint.pth'

# Whether to train all layers
params_to_update = model_ft.parameters()
# Print the layers that will be trained
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
Params to learn:
     fc.0.weight
     fc.0.bias
7. Training and prediction
7.1 Optimizer setup
# Optimizer
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# Learning-rate decay schedule
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# The learning rate is multiplied by 0.1 every 7 epochs
# The last layer is LogSoftmax(), so pair it with nn.NLLLoss();
# nn.CrossEntropyLoss() expects raw scores because it applies log-softmax itself

criterion = nn.NLLLoss()
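The pairing of LogSoftmax() with nn.NLLLoss() rests on one identity: cross-entropy on raw scores is the negative log-likelihood of the log-softmax of those scores. The sketch below checks that identity in plain Python. The helper names (`log_softmax`, `nll`) and the example scores are made up for illustration and are not part of the tutorial's code.

```python
import math

def log_softmax(scores):
    """Return log-probabilities for a list of raw class scores."""
    m = max(scores)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]

def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

scores = [2.0, 0.5, -1.0]  # hypothetical raw scores for 3 classes
target = 0

log_probs = log_softmax(scores)  # what a LogSoftmax layer would output
loss = nll(log_probs, target)    # what NLLLoss computes for one sample

# Cross-entropy computed directly from the softmax probability
p_target = math.exp(scores[target]) / sum(math.exp(s) for s in scores)
print(loss, -math.log(p_target))  # the two values agree
```

Since the model already produces log-probabilities, the loss only needs to pick out the entry for the target class and negate it.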
# Training function
# is_inception: set to True only when training an Inception network, which also returns an auxiliary output
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False, filename=filename):
    since = time.time()
    # Best validation accuracy so far
    best_acc = 0
    """
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    """
    # Run on the GPU or the CPU
    model.to(device)
    # Histories kept for plotting
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]
    # Keep a copy of the best weights
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Training and validation
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over all the data
            for inputs, labels in dataloaders[phase]:
                # Move inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                # Compute gradients only during training
                with torch.set_grad_enabled(phase == 'train'):
                    # Inception branch; not used for ResNet
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:  # ResNet takes this branch
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # Index of the highest-scoring class
                    _, preds = torch.max(outputs, 1)

                    # Update the weights during training
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Print the statistics for this phase
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the model
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    # state_dict holds the learned weights and biases
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # StepLR takes no metric: passing epoch_loss here would be read as an epoch index
                scheduler.step()
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # After training, use the best weights as the final model
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
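Inside train_model, each batch loss is multiplied by inputs.size(0) before being summed, and the total is divided by the dataset size. Because the criterion returns the mean over a batch (the default for nn.NLLLoss), this recovers the per-sample average even when the last batch is smaller than the others. A small sketch with made-up per-sample losses; all names and numbers here are hypothetical:

```python
def batch_mean(values):
    """Mean of a list, standing in for a loss with mean reduction."""
    return sum(values) / len(values)

# Hypothetical per-sample losses for a dataset of 5 samples, batch size 2
per_sample = [1.0, 3.0, 2.0, 4.0, 10.0]
batches = [per_sample[i:i + 2] for i in range(0, len(per_sample), 2)]

# What the training loop does: batch mean weighted by batch size
running_loss = 0.0
for batch in batches:
    running_loss += batch_mean(batch) * len(batch)
epoch_loss = running_loss / len(per_sample)

# Naive alternative: average the batch means without weighting
naive = batch_mean([batch_mean(b) for b in batches])

print(epoch_loss)  # 4.0, the true per-sample mean
print(naive)       # 5.0, skewed by the final batch of one sample
```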
7.2 Training the model
I trained for only 5 epochs here (training takes a really long time); feel free to increase the number of epochs when you try it yourself.

# If training is too slow, lower the number of epochs; around 50 epochs would likely give better results
# While training, check whether the loss falls and the accuracy rises. Is there a large gap between validation and training? A large gap indicates overfitting
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
Epoch 0/4
----------
Time elapsed 29m 41s
train Loss: 10.4774 Acc: 0.3147
Time elapsed 32m 54s
valid Loss: 8.2902 Acc: 0.4719
Optimizer learning rate : 0.0010000

Epoch 1/4
----------
Time elapsed 60m 11s
train Loss: 2.3126 Acc: 0.7053
Time elapsed 63m 16s
valid Loss: 3.2325 Acc: 0.6626
Optimizer learning rate : 0.0100000

Epoch 2/4
----------
Time elapsed 90m 58s
train Loss: 9.9720 Acc: 0.4734
Time elapsed 94m 4s
valid Loss: 14.0426 Acc: 0.4413
Optimizer learning rate : 0.0001000

Epoch 3/4
----------
Time elapsed 132m 49s
train Loss: 5.4290 Acc: 0.6548
Time elapsed 138m 49s
valid Loss: 6.4208 Acc: 0.6027
Optimizer learning rate : 0.0100000

Epoch 4/4
----------
Time elapsed 195m 56s
train Loss: 8.8911 Acc: 0.5519
Time elapsed 199m 16s
valid Loss: 13.2221 Acc: 0.4914
Optimizer learning rate : 0.0010000

Training complete in 199m 16s
Best val Acc: 0.662592
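The learning rates printed in this log jump between 0.01, 0.001 and 0.0001 rather than decaying steadily. The numbers are consistent with the validation loss having been passed to `scheduler.step(...)`: StepLR has no use for a metric, and when it is handed a value it treats that value as the epoch index and sets lr = base_lr * gamma ** (epoch // step_size). The sketch below reproduces the logged rates under that assumption, using the validation losses from the log. `step_lr` is an illustrative helper, not a PyTorch function.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Closed-form step decay: multiply by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

base_lr, gamma, step_size = 1e-2, 0.1, 7

# Validation losses copied from the training log
valid_losses = [8.2902, 3.2325, 14.0426, 6.4208, 13.2221]

# Treating each loss as if it were the epoch index reproduces the logged rates
from_loss = ['{:.7f}'.format(step_lr(base_lr, gamma, step_size, loss))
             for loss in valid_losses]
print(from_loss)

# With the real epoch index, the rate stays at 0.01 for the first 7 epochs
from_epoch = ['{:.7f}'.format(step_lr(base_lr, gamma, step_size, e))
              for e in range(5)]
print(from_epoch)
```

Calling `scheduler.step()` with no argument, once per epoch, gives the intended decay of a factor of 10 every 7 epochs.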
7.3 Training all layers
# Unfreeze the whole network for training
for param in model_ft.parameters():
    param.requires_grad = True

# Load the saved parameters and continue training from the existing model
# The checkpoint holds the best result of the training run above
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

# Continue training all parameters with a smaller learning rate
# params_to_update still lists only the fc layer, so build the optimizer from model_ft.parameters()
optimizer = optim.Adam(model_ft.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# The optimizer state saved in the checkpoint covers only the fc parameters, so it is not restored here

# Loss function
criterion = nn.NLLLoss()
Start training
Note: training becomes very slow at this point. My GPU is a 1660 Ti, for reference.

model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
Epoch 0/1
----------
Time elapsed 35m 22s
train Loss: 1.7636 Acc: 0.7346
Time elapsed 38m 42s
valid Loss: 3.6377 Acc: 0.6455
Optimizer learning rate : 0.0010000

Epoch 1/1
----------
Time elapsed 82m 59s
train Loss: 1.7543 Acc: 0.7340
Time elapsed 86m 11s
valid Loss: 3.8275 Acc: 0.6137
Optimizer learning rate : 0.0010000

Training complete in 86m 11s
Best val Acc: 0.645477
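One simple way to read a log like this for overfitting is to compute the gap between training and validation accuracy for each epoch. The sketch below uses the accuracies from the two epochs just shown. The 0.05 threshold is an arbitrary illustrative choice, not a rule from this tutorial.

```python
# Accuracies copied from the two epochs in the log
train_acc = [0.7346, 0.7340]
valid_acc = [0.6455, 0.6137]

THRESHOLD = 0.05  # hypothetical cut-off for flagging a large gap

gaps = [round(t - v, 4) for t, v in zip(train_acc, valid_acc)]
flags = [gap > THRESHOLD for gap in gaps]

for epoch, (gap, flagged) in enumerate(zip(gaps, flags)):
    note = '  <- possible overfitting' if flagged else ''
    print('epoch {}: gap {:.4f}{}'.format(epoch, gap, note))
```

A gap that widens from one epoch to the next, as it does here, is usually a stronger signal than the size of any single gap.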
8. Loading the trained model
This amounts to a single forward pass (inference); no parameters are updated.

model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)

# GPU mode
model_ft = model_ft.to(device)  # move the model to the GPU

# Name of the checkpoint file
filename = 'checkpoint.pth'

# Load the model
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
<All keys matched successfully>
def process_image(image_path):
    # Read a test image
    img = Image.open(image_path)
    # Resize. thumbnail() can only shrink an image proportionally, so check which side is shorter first.
    # Unlike resize(), whose size argument sets the new size directly, thumbnail() keeps the aspect ratio.
    # It also modifies the image in place and returns None.
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # Crop the central 224 x 224 region
    # PIL's crop box is (left, upper, right, lower)
    left_margin = (img.width - 224) / 2    # left edge of the central region
    upper_margin = (img.height - 224) / 2  # upper edge of the central region
    right_margin = left_margin + 224       # add the crop width of 224 to get the right edge
    lower_margin = upper_margin + 224

    img = img.crop((left_margin, upper_margin, right_margin, lower_margin))

    # Apply the same preprocessing as in training
    # Normalization
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = (img - mean) / std

    # Move the color channel to the front: (H, W, C) -> (C, H, W)
    img = img.transpose((2, 0, 1))

    return img

def imshow(image, ax=None, title=None):
    """Display an image."""
    if ax is None:
        fig, ax = plt.subplots()

    # Restore the channel order: (C, H, W) -> (H, W, C)
    image = np.array(image).transpose((1, 2, 0))

    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

image_path = r'./flower_data/valid/3/image_06621.jpg'
img = process_image(image_path)  # call this function again to process further images
imshow(img)
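The resize-and-crop arithmetic in process_image can be checked without opening an image. The sketch below mirrors the two steps: shrink the image so that its shorter side is at most 256 while keeping the aspect ratio, then take the central 224 x 224 box. `shrink_shorter_side` and `center_crop_box` are hypothetical helpers, and the rounding is a simplification; Pillow's thumbnail() may round the longer side slightly differently.

```python
def shrink_shorter_side(width, height, target=256):
    """Scale (width, height) so the shorter side equals target; never enlarge."""
    shorter = min(width, height)
    if shorter <= target:
        return width, height
    scale = target / shorter
    return round(width * scale), round(height * scale)

def center_crop_box(width, height, size=224):
    """Return (left, upper, right, lower) of a centered size x size box."""
    left = (width - size) // 2
    upper = (height - size) // 2
    return left, upper, left + size, upper + size

w, h = shrink_shorter_side(500, 400)
box = center_crop_box(w, h)
print((w, h))  # (320, 256)
print(box)     # (48, 16, 272, 240)
```

The box always spans 224 pixels in each direction, which is why the processed image comes out as 224 x 224 whatever the original aspect ratio was, as long as the shorter side is at least 224.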
<AxesSubplot:>
Above is the test image after preprocessing. We use shape to check the image size and confirm that the preprocessing function works correctly.

img.shape
(3, 224, 224)
This confirms that the channel dimension now comes first and that the image has the expected size of 224 x 224.
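The shape (3, 224, 224) comes from the transpose((2, 0, 1)) step, which moves the channel axis from last to first. A minimal pure-Python sketch of the same reordering on a tiny nested list (2 x 3 pixels, 3 channels); the helper names are illustrative only:

```python
def shape(nested):
    """Shape of a regular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)

def hwc_to_chw(image):
    """Reorder image[h][w][c] into out[c][h][w]."""
    height, width, channels = shape(image)
    return [[[image[h][w][c] for w in range(width)]
             for h in range(height)]
            for c in range(channels)]

# A 2 x 3 image whose pixel at (h, w) holds the three values [h, w, 9]
hwc = [[[h, w, 9] for w in range(3)] for h in range(2)]
chw = hwc_to_chw(hwc)

print(shape(hwc))  # (2, 3, 3)
print(shape(chw))  # (3, 2, 3)
print(chw[2])      # [[9, 9, 9], [9, 9, 9]]
```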

9. Inference
img.shape

# Get one batch of validation data
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)

model_ft.eval()

if train_on_gpu:
    # A forward pass produces output
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# The batch holds 8 images; each has 102 output values, one log-probability per class
output.shape
torch.Size([8, 102])
9.1 Finding the most probable class
_, preds_tensor = torch.max(output, 1)

# Convert the predictions to a 1-D NumPy array, moving them to the CPU first if needed
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
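torch.max(output, 1) returns, for each row, the largest value and its index; the index is the predicted class. Because the model ends in LogSoftmax, each row holds log-probabilities, and exponentiating the maximum gives the predicted probability. A plain-Python sketch with made-up numbers for a batch of two samples and three classes (`argmax_rows` is an illustrative helper):

```python
import math

def argmax_rows(rows):
    """Return (max value, index of max) for every row."""
    result = []
    for row in rows:
        best = max(range(len(row)), key=lambda i: row[i])
        result.append((row[best], best))
    return result

# Hypothetical log-probabilities: 2 samples, 3 classes
log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.1), math.log(0.3), math.log(0.6)],
]

results = argmax_rows(log_probs)
for value, index in results:
    print('class {} with probability {:.2f}'.format(index, math.exp(value)))
```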
9.2 Displaying the predictions
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows, columns, idx + 1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # Title format: predicted name (true name)
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()
# Green titles mark correct predictions; red titles mark wrong ones
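One thing worth checking in the title lookup: `preds` and `labels` are class indices, while `cat_to_name` is keyed by folder name. Assuming the datasets were built with torchvision's `ImageFolder` (an assumption, since the loading code is not shown in this section), indices are assigned to folder names sorted as strings, so index 0 is folder '1', index 1 is folder '10', and so on. Under that assumption, an index should be translated to its folder name before the flower name is looked up. A sketch with a stand-in mapping; the flower names here are placeholders, not the real JSON contents:

```python
# Folder names '1' .. '102', sorted the way strings sort
folders = sorted(str(i) for i in range(1, 103))

class_to_idx = {name: idx for idx, name in enumerate(folders)}
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

# Placeholder for the real cat_to_name mapping
cat_to_name = {name: 'flower_' + name for name in folders}

def name_for_index(idx):
    """Translate a class index into a flower name via its folder name."""
    return cat_to_name[idx_to_class[idx]]

print(idx_to_class[0], idx_to_class[1], idx_to_class[4], idx_to_class[5])
print(name_for_index(1))
```

With a real `ImageFolder` dataset, the same mapping is available as the dataset's `class_to_idx` attribute, so it does not need to be rebuilt by hand.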
----------------
Copyright notice: this is an original article by CSDN blogger "FeverTwice", released under the CC 4.0 BY-SA license. When reposting, please include the link to the original source and this notice.
Original link: https://blog.csdn.net/LeungSr/article/details/126747940