数学建模社区-数学中国

标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例 [打印本页]

作者: 杨利霞    时间: 2022-9-8 10:41
标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
( i9 B% Y& m' _9 y  s
: q9 B0 g+ f+ n% B# _# _$ O! z文章目录
( e0 S, h( v$ P$ `' H* b卷积网络实战 对花进行分类
" _( u4 F: }2 }$ P0 s8 P4 F* N4 q数据预处理部分
/ a$ G& z/ S! V- O# [) g$ R8 T9 h网络模块设置
3 J; Q0 u8 w4 J网络模型的保存与测试9 f# X( q$ m; {/ @2 _
数据下载:( _6 E- i. z: m* e4 j
1. 导入工具包
( l9 k8 z% @, Z; w% M3 I# m2. 数据预处理与操作" f  w4 U' U1 T1 o! S" H) Q4 u( }
3. 制作好数据源- z* Y8 ^, X4 w  X
读取标签对应的实际名字
/ B- I: V# L+ G% E2 \4.展示一下数据
3 o% n! A8 q/ A( `. a* V4 e; B  M- V5. 加载models提供的模型,并直接用训练好的权重做初始化参数
4 x: T2 u0 S  S) [4 ?2 p- f, _. h5 \6.初始化模型架构
! a3 T& R+ O0 v7 @: _7. 设置需要训练的参数
5 {* J# G2 d6 l$ v, Q2 p  B5 |7. 训练与预测
* o# a8 v4 S! S5 ]$ C5 k7.1 优化器设置
2 R+ j. R6 w. T7.2 开始训练模型" M/ G" q) [/ V# I9 Z" f0 U
7.3 训练所有层, W  w' a! P- y" H& j8 }7 s, y
开始训练) j3 @5 T/ d5 D
8. 加载已经训练的模型2 F* F& d! W+ o* [3 i
9. 推理) B; H3 k# M" [6 w; \
9.1 计算得到最大概率& |0 n0 ?% t8 W+ _
9.2 展示预测结果# M6 t) `! |5 v3 E
写在最后
2 ?4 L/ L+ E0 M' ]  T7 `; J卷积网络实战 对花进行分类- G  `9 \2 w0 U) T; I) I! r% E
本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能
4 c% T+ h4 }8 l
1 o) F, P7 n& w5 I8 U在文件夹中有102种花,我们主要要对这些花进行分类任务: G: @  T4 C* H  A5 _- q; i
文件夹结构
1 L3 a  W% b6 K" e/ X* `3 I, Y/ M1 E/ Q3 N2 L# I
flower_data, h$ r  `* N! z8 M7 |! B' x

' h) o0 F7 b# I1 a8 Ztrain+ @# F4 m+ T" N

$ I" N# d8 r9 N' o- K  \6 s- `# L1(类别)
0 F9 t6 ]' K" ^28 h6 X* _1 A* F( o! f2 N% b: G
xxx.png / xxx.jpg
) b: L6 Q2 l* i  e: Tvalid; ]/ N( \  F* S: R4 |# C/ C
1 i# j" {9 @+ ~9 H
主要分为以下几个大模块
* Y' z- \7 K  J) @( L" W6 e# e
, t5 p. [- L/ I+ l3 [数据预处理部分
' w) N* `8 Z4 Q3 o' W数据增强2 a! z: Q: d. i# B8 I
数据预处理- h1 z* j* F: B  W4 {/ q: D
网络模块设置
8 v. O0 u% d; f6 @. l, m6 P加载预训练模型,直接调用torchVision的经典网络架构
! ?) C3 L- e. H因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
" b9 j# N$ h5 V' ^% J0 U( B网络模型的保存与测试) l6 Z# G0 Q0 U! C
模型保存可以带有选择性, J+ N2 [  E0 r! i( v) q# @
数据下载:6 x1 J8 z$ ^# j( \: N$ m9 N  F6 K
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
+ r1 f5 @" g9 }* |2 q4 U2 r+ e+ `( f/ Z* m$ z- p
改一下文件名,然后将它放到同一根目录就可以了- A8 E2 P5 S( @* X

* e4 \, u* n. r9 M9 P下面是我的数据根目录
' |% a. i4 y0 H! z) {+ B) p5 B; q8 N# ~: j# K3 Z
  C( W& g* o7 L: Q& a! [7 ~
1. 导入工具包
3 @7 b8 L/ S  S* h3 o2 z# Rimport os7 E7 F6 ?& }. T7 W& a/ b/ {4 X
import matplotlib.pyplot as plt
( {2 d" Y) s5 t) W4 |# 内嵌入绘图简去show的句柄
/ V5 f) `% n; {2 {& o%matplotlib inline / [2 q: d3 f1 ]  T
import numpy as np
; k5 B( t4 y( V$ |import torch) ?# }' t, p* B/ h8 @3 Q& t
from torch import nn! U8 E7 y1 w) t  ]# H& j

6 _6 z& Z3 v7 u7 k0 A# E! |) {" Mimport torch.optim as optim7 C5 X' J. j5 Y) w8 V( t% T" {
import torchvision+ F: z5 E2 x! M
from torchvision import transforms, models, datasets
6 E9 q# ~. y/ T" e. Z4 U3 F( ]) {, Q, ^4 i  y# K+ Y) Z$ z
import imageio
% W3 Z. ^& n1 S+ K# G6 \7 @import time8 z/ _' ^& l. c+ l
import warnings+ D# `  y# ]3 b) w
import random
5 g+ ]8 i( K1 @5 N; N' q3 dimport sys1 G  b* q" R$ n, g& V0 ~+ c
import copy
; f) D9 X. b- O  |( U8 ^) H6 }" `import json
# }: O5 w+ ^$ r4 O8 f. Nfrom PIL import Image5 l5 ^2 J5 t; L. b3 {8 S" r

' Y3 M' b( B5 u* n8 s* y. t; @, Z% Y$ O- K3 o* x1 `7 W
1$ z! _# U% A5 x% I3 `5 {' H
2% E8 }- r8 F# z# T2 h& u4 v
3
& H0 ^9 T$ R7 g0 C7 T) }4
3 j  f4 k7 F! {/ b5  e" ?$ Q( ?- E( V) k
6
; ^2 c% Y. L, S75 }5 Y& c$ o; a7 I5 C: ^& d
8- g8 `% i1 r+ \% x* H) A3 Z1 r
9  n( r& @( B9 _8 g7 M% v
10
- C# @7 y8 Q9 w& D0 M11" o/ [! E$ b' R/ L) e
12
+ o: i# Q' A6 j1 ~13. K3 E7 B4 t' @9 I: i
14
8 r/ \, z5 D, e6 Q15
. k6 V3 W( C7 O16
" p" a+ y5 e; [9 K172 u& g0 B& t4 I  U0 f
188 D( b1 N0 N+ D; A
19+ _0 m( s: R9 O3 e
20
/ D. [1 z& N  Q21
3 ~5 w. a$ r# N) }3 K2. 数据预处理与操作
# ^1 }4 P  e+ k0 e% B" O#路径设置0 _( R# V& y( S: \+ P
data_dir = './flower_data/' # 当前文件夹下的flowerdata目录' l/ X/ ?( a. m
train_dir = data_dir + '/train'
& U$ F* c7 q6 \, F. @5 `valid_dir = data_dir + '/valid') l8 x! u. M9 e& V' a$ {  ~- S
16 l) e' _& k. A
2( R! o+ f' B( L+ z) \9 L) t
3
4 O& e. d+ [; Z6 q) l4 D1 ]4
( x6 b) z% ]6 p; S- E9 vpython目录点杠的组合与区别
+ g' l: Z0 m7 M注: 里面注明了点杠和斜杠的操作
+ `( ]. |8 j$ n/ E) Y$ m0 p/ H
. Z- w) l& g: m% k- A3. 制作好数据源
5 l: w6 c& O3 N4 \data_transforms中制定了所有图像预处理的操作
9 J! Q8 _7 Y% y8 KImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片, ^7 s! I5 U! a# r/ d  j! a/ N1 n
data_transforms = {, m8 [( ^8 L5 Y0 y
    # 分成两部分,一部分是训练
4 {: E# [. [% q' i/ W    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间8 k8 Q& v; ~9 K1 s0 S- f
                                 transforms.CenterCrop(224), # 从中心处开始裁剪
' ^+ B: @! T. j. y. i                                 # 以某个随机的概率决定是否翻转 55开
" J$ D1 Z7 w" ]5 n: w                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转
  H( g9 c2 K, z* {" Y. e% p' p                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转$ T1 _: L. C) T+ \4 `0 H
                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
$ z7 V7 O6 m* g7 `! f                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),
9 r/ `6 Q: y+ C( q/ G  i                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB6 m0 p& B; G" M, Z5 X4 R
                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的; q. a# U" F8 _# B/ s: f+ ]5 |
                                 transforms.ToTensor(),5 {2 L' f# }" R  P* g6 [0 V* }
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差5 x  [" k2 ?+ g% p- r. o1 X, ]
                                ]),3 t" \& j! n3 s, F
    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化
3 p' C9 r# i/ Y7 E3 o    'valid': transforms.Compose([transforms.Resize(256),
2 s8 t+ Z) e1 T7 Y0 ]) D8 _9 m- V0 n                                 transforms.CenterCrop(224),
8 P# ~7 `( u: ~                                 transforms.ToTensor(),
; B( D, _1 O, m                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
* B0 Q  S4 c# y4 F                                ]),
; b9 D2 d: w% y3 E' C* m5 p}+ X( O) z$ K* H- E3 l

6 l) s- @" X7 x) e) h1) b$ u. P( E/ R
2
/ D0 f* r! a* D7 C$ B  H3  s0 n/ C& D5 w( {6 u
4+ t  P9 \+ h8 y1 g
5
- l/ D; k4 f/ h- s% ?! n% b6- L7 F8 G+ G3 G" U1 C4 W9 B) V9 t
74 |+ J1 W) ?) B9 d' ~+ j% _* ?
8
) F* T! h- l; F7 V9* ^( b  ]2 C! }- @
10
$ f- J0 C- A2 T- \! r11
/ Y# g& E, F1 N. t  q6 S12
* q0 a1 W$ j) j3 N3 U13- g/ P" R+ f2 x: z
144 h. r0 Q1 [5 W! k! {
15& ?4 V, W' H, T* }* _
16* C5 z3 e) `& d* ~, u, }
17
( y3 X5 X, y$ g) l5 ?18
$ [8 x( j& B& J; r! }19
0 h9 W1 M' _& s, E- T20
+ y. r$ }6 @. k; ~) H21- i& [8 }+ B5 }4 l, H! `( L
batch_size = 8, C' Q8 {9 T9 `7 `3 K! m" P1 D
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}. ]+ P- a2 l2 {( C! L) |7 g
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
: L1 ]& D: B# {( E; h( Hdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
+ R% k, _  m$ ^/ Q. Tclass_names = image_datasets['train'].classes+ z0 v7 W% Y* A  p
# u- \  n5 r% p' l
#查看数据集合/ m* B: W2 ]- v* {7 I6 u
image_datasets
% U  v8 h( r: E( U$ V
7 K3 F, P. q7 b9 R& f+ a' I1
) `, v- l0 B( g* ^2 P2
9 R" y- |1 D, H! k) h) r3
- N" F- ^5 X% j* p1 k( [; {# Y4
0 Z& F1 T3 z' I1 ?$ _3 v1 m5
1 ?& D/ ]$ [; j$ w- u6; E& ~6 ~5 |' D) {+ O, m: \5 m+ v
70 j: ~# \! d# w  ^
8
7 h* ^3 [$ ~& L5 ]8 k( z9
0 z+ J6 Q8 O; m* [) V! q" {$ s( W{'train': Dataset ImageFolder
4 N1 W) m3 U+ P# @) ~5 S. o     Number of datapoints: 6552
7 a" d9 G* L+ n0 p" R9 V% z     Root location: ./flower_data/train
$ i. e1 c6 |% y0 J% X1 U6 K+ k     StandardTransform
0 Z" ]; w; s* l( D$ U Transform: Compose(
, \. w) M0 p8 V                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
1 N  k# s# b9 @# _" i9 i                CenterCrop(size=(224, 224))
! W; E& H2 ~* ^. Y  T, J$ B                RandomHorizontalFlip(p=0.5)% u: s& T- i/ q' P0 f: l( V
                RandomVerticalFlip(p=0.5)
; [6 I( R, z% D4 i1 C                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])4 g, R  I1 W, y" K; c
                RandomGrayscale(p=0.025)3 h* o2 ^2 R$ T' B. W- W
                ToTensor()
' ^' F9 C5 K) f                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
5 Y* ?8 V& L; Z. I            ),  b; U4 I$ _% S
'valid': Dataset ImageFolder
9 F( g% l+ U( ?1 t     Number of datapoints: 818
; H5 d( p' l, u" `: U9 A     Root location: ./flower_data/valid
$ s( b" W# r: `! y/ X! C! X- y     StandardTransform
5 P8 _5 D& L) s Transform: Compose(
+ f1 C5 }# [+ o/ x                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
* f/ [1 u" Y1 n$ t' J! U                CenterCrop(size=(224, 224)). G6 W! {" ^, z* k* z  P% O9 g1 x
                ToTensor()8 U5 \/ }. K$ D* d$ |# O
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])4 w" L: @/ P- ~6 `* ~
            )}
3 j+ h. R7 r. x# n2 r: P9 H8 R6 R' ]4 z9 Y' i9 i8 z
1/ J( D+ }* I2 f
2
' O1 J5 ~/ D, ^# T3& y* [2 s$ ~4 S6 @6 ?
4, O/ K0 c/ [% A* f
53 g( r/ U( D: h* r& n" o. [, M0 z
6
$ v8 N* J, ]! Q7
8 f: v: w# m9 F. X& h+ {7 W" ~8$ g! \) U6 b9 Z1 ^
9( a$ y7 o* a" g0 N2 }: Z) r# u
10. K5 k. ?% @  u7 x; |
11
( {8 Q0 W" [8 |" ^2 x12
% ]% o4 ^, I1 `# j* L6 M13# m/ n" [+ j# ]( `
145 X/ M% X/ Z1 H4 P) e9 u, M
15
6 V$ T; A, g- v; H. t6 H; z5 c16
1 d. ^4 ~' }8 w178 [3 Q8 d+ ]9 |4 Z
18
! h0 B5 s7 M& Q4 p. I19
/ C% Q  M0 |9 R/ B20+ I' d2 O" n  X
21
1 l2 c" d4 F8 c. I, X5 U# \; M22% b0 J; v6 B5 p
233 C! Q: X) X. z& A; c( X) h
24
: x) B1 k- r+ `% t0 b! m0 N8 h# 验证一下数据是否已经被处理完毕
2 x& J( _9 z9 rdataloaders
% a) U4 d/ {! u1
) Q. x( R! K, p8 v" m" S$ h2: f% E+ R5 o7 \
{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,# e4 |7 E* r5 S( U0 \
'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
/ I' R" T" l% U7 R' Q" O* ]* s1
# y6 [$ @2 y' d2
; @" Y$ V5 ]! X; a) e, udataset_sizes
0 E7 ^- _  L6 a- D9 Q' l1$ O. v4 t' o9 A. y5 f: Q
{'train': 6552, 'valid': 818}
0 i; Y% d" n0 W6 K$ n! \0 }1- M% k1 B/ {+ V$ Z8 I" ~8 ]
读取标签对应的实际名字) P& g/ _1 J; C! W
使用同一目录下的json文件,反向映射出花对应的名字
3 @2 z& X8 u% L& `6 S
7 ?( L9 ^) I8 s0 Y/ R4 Kwith open('./flower_data/cat_to_name.json', 'r') as f:
; O3 n; V. v5 X- Z* r( g2 V6 \    cat_to_name = json.load(f)' |" h  ]8 c3 J# s* s! ^
1
; e2 i& O/ s3 D; y2* M4 \  o  H! U) e4 p/ e6 J
cat_to_name
0 I' s) D7 y' }2 l0 C# ?1
/ M5 M7 ]+ @; E' v1 T5 g' \: @( ^1 b1 F{'21': 'fire lily',
0 r0 h. Y* o) B! g3 ~1 ]2 ]2 V '3': 'canterbury bells',6 i" u. ~; \. l( w% Q
'45': 'bolero deep blue',
( F; |' `$ R) p# T '1': 'pink primrose',
$ w! {8 F8 Q/ A. ~5 |2 x '34': 'mexican aster',: M3 j! }  j+ i) @$ P2 m9 l
'27': 'prince of wales feathers',3 i+ }9 j* z$ e$ ^0 O- x' P% H  i* N6 @
'7': 'moon orchid',4 U2 F2 t+ }  ?9 f) {' X
'16': 'globe-flower',
, w& D8 N/ Z( _' q# W& w% o3 g% N '25': 'grape hyacinth',
) s0 r. L& T7 K/ L3 o6 S5 [& t '26': 'corn poppy',
# R( t; \2 @& t9 c4 \( J '79': 'toad lily',
& s! T; w0 m4 u '39': 'siam tulip',, e+ E  R' s+ j8 n8 `
'24': 'red ginger',- h9 k/ D. W; r
'67': 'spring crocus',
0 H; `6 ~" Z1 U '35': 'alpine sea holly',
* _" l& U- J: d7 e1 }: O+ q+ K! y '32': 'garden phlox',7 F8 H: ]6 i& z
'10': 'globe thistle',
4 C/ |. @# B! \4 C( x '6': 'tiger lily',
* o( l1 W8 z/ A8 h$ a5 b! x '93': 'ball moss',
: `- T7 I9 R5 o8 a1 r" h; R '33': 'love in the mist'," a# e% H, K+ V' S  B$ `% }1 `
'9': 'monkshood',- p7 p6 d5 }+ D
'102': 'blackberry lily',
5 ~% `6 C4 m5 ]3 l! E+ v8 z; Q; O '14': 'spear thistle',& Q* `! R! a8 z8 {, ?1 y
'19': 'balloon flower',: |; y$ X- E; [+ A
'100': 'blanket flower',
9 X: f" Y- q2 j7 V; p7 i '13': 'king protea',
( K# {1 d" y0 q8 \" V '49': 'oxeye daisy',% k  r7 q8 H. B& x$ O8 ?$ `# w
'15': 'yellow iris',
0 ?1 `! E$ [7 {9 H% [) r% g& k '61': 'cautleya spicata',8 |7 L( Y+ f2 n9 l* m
'31': 'carnation',( v2 i2 o# [  p% @3 ?* o+ }
'64': 'silverbush',- T, _3 K  \) s2 _+ o: u
'68': 'bearded iris',2 h, k, S, D" \! p0 ~! y3 t8 j$ T
'63': 'black-eyed susan',
# z; U  l7 Q. v. V '69': 'windflower',$ S0 ^% o8 X. G) x# X
'62': 'japanese anemone',* y' H5 p) W8 k+ N  N8 Z# U
'20': 'giant white arum lily',
9 B) q; @1 Z" O# E# D- V9 q '38': 'great masterwort',+ G. ?) T& m" f% J$ A1 D9 d: Y
'4': 'sweet pea',
) O5 y; P1 e1 j# y2 p, k '86': 'tree mallow',
# b5 h2 D! Y1 p! Z, D8 h/ K7 G '101': 'trumpet creeper'," N* F6 x: @. K1 [  B. e" l, {
'42': 'daffodil',/ s. r* Z; j, |; f/ ]
'22': 'pincushion flower',6 h0 z. @) T7 g' i6 H
'2': 'hard-leaved pocket orchid',
& \% b6 ^& x9 W" g '54': 'sunflower',
+ r& k: ?6 m7 k  f' o' c '66': 'osteospermum',1 A/ V# l" }/ z2 G: i% L
'70': 'tree poppy',
5 M- ~3 W% \5 S' r, J '85': 'desert-rose',
! x, y$ c3 Q8 b1 T, U9 b* n' L '99': 'bromelia',0 r2 \$ o. p6 E2 w4 v
'87': 'magnolia',
9 ]3 m: d# ~8 J: I- H/ d '5': 'english marigold',
3 a+ [6 ~4 h* h/ n5 Y '92': 'bee balm'," w2 `/ V- h7 s  _
'28': 'stemless gentian',
/ A$ F  d8 `% z  j '97': 'mallow',
4 g6 `# v( l# D+ m2 t0 Y4 H. E '57': 'gaura',6 p' k; V/ J5 \
'40': 'lenten rose',
1 q3 }- x. a) H% K+ |( f. k' G '47': 'marigold',
0 p6 ?8 V+ J, v* Q8 B: P '59': 'orange dahlia',
7 N* p% o( a- x8 v: a7 O% ` '48': 'buttercup',9 D3 k9 T- i. |/ N% l
'55': 'pelargonium',
& {6 |  `- l  J* R8 r: u5 T '36': 'ruby-lipped cattleya',
- H. l/ g" I+ z" \' [ '91': 'hippeastrum'," _% `% D; _3 x! M' ?) ^. u
'29': 'artichoke',
* A& q3 v# Y9 u$ Q- Z1 L# E* d3 C* ^ '71': 'gazania',1 B5 M  U0 C  \) x3 J
'90': 'canna lily',
% X1 @6 ]" d7 g '18': 'peruvian lily',
7 W. \' {2 `  ]4 G% I1 i '98': 'mexican petunia',
  ], n0 u6 x# D* H5 S9 Z. [ '8': 'bird of paradise',$ M1 d# _+ {5 s- R
'30': 'sweet william',# C, Y) R$ c  y. r1 p4 ^0 U
'17': 'purple coneflower',: @- B9 G  `2 a1 [' q2 Z2 C5 ]( I
'52': 'wild pansy',, T" S0 C  `0 A! j1 [- ^
'84': 'columbine',9 M7 u* z7 }+ d7 b" s' X; m; u; k
'12': "colt's foot",
% U7 f* G7 f- o# O5 k$ i '11': 'snapdragon',
0 V. m- [9 V/ U  t. `4 E; ~ '96': 'camellia',
( d5 g: R: {9 |% E/ h6 s% ? '23': 'fritillary',& M5 v5 A. G! ^1 o. W) i, d
'50': 'common dandelion',
$ _3 z0 \3 U3 @6 G) M! Q5 M '44': 'poinsettia',
7 [# Q7 x" `5 o9 G' } '53': 'primula',' T- q: ]% _. M" Z- H
'72': 'azalea',
0 Y, a( h" x4 o' t4 m- { '65': 'californian poppy',0 F$ {2 t  ?; c( ~
'80': 'anthurium',$ P4 _; y# g- F% A% \; c% ^* k
'76': 'morning glory',
7 b+ ~2 d) V. a- a; ~- i: ? '37': 'cape flower',
! [+ y# L" }; k: L4 H9 ]; X '56': 'bishop of llandaff',
) R2 w5 E. l/ ?4 u( L '60': 'pink-yellow dahlia',- ?, e, y, u. g
'82': 'clematis',* i2 u' L4 A3 K
'58': 'geranium',5 D) x: v8 L  D! i& ]: [/ b" F
'75': 'thorn apple',
8 ]1 G1 ], e9 A5 F* J1 t '41': 'barbeton daisy',
4 `9 N+ {1 K* w4 n '95': 'bougainvillea',
$ d/ p; S1 K6 d, A '43': 'sword lily',
( v$ G) t, E4 F$ [  |! C '83': 'hibiscus',! L' ~; ?2 K& _
'78': 'lotus lotus',8 j* B' A# M( H! `6 `+ u& H6 O
'88': 'cyclamen',( w% z! E* C' }4 w
'94': 'foxglove',
* l) y1 C3 g! F  f" I '81': 'frangipani',/ J7 g2 v4 ]% [# H/ l) w5 \
'74': 'rose',% |5 R2 @* X& v! F* |  d  @  l, w8 V* v
'89': 'watercress',$ W9 j7 Y; r! o4 K" u
'73': 'water lily',2 }0 |* Z3 n. G9 G9 z5 v
'46': 'wallflower',
# t0 K: z; y/ ^9 o '77': 'passion flower',
; f5 C6 F0 `- n9 m) g% h '51': 'petunia'}
5 t) _8 N! @( H  o9 O  L* c5 f6 d& M7 n" o
1
( w( A) v6 m! D2
" Z3 {4 `2 _0 C3
1 l+ v1 A0 L$ V% }2 L; {4
4 V- G  `$ `* J6 I$ V5& c9 U' `" ^4 s% }: R. [1 R
6
" T$ p6 h9 q1 i: g8 |% m6 N7" t2 r0 L" n" r* z
8
) [. u* @( t( _$ I: r: p. D3 a0 ?4 k0 A9
# _# i/ v" ?. I# q( H  \" ], G10* ?$ X2 s( Q; L. \# ~
11. i9 ~( Z: U7 t1 M1 }& y3 e
12
2 n0 n+ s, D3 G, u- q1 J% i13/ q- x6 y; B( ]# d1 q3 t
14. J/ u8 _) [) n0 v- H9 r5 V% P
15% a6 h0 M, [: x2 m
16
+ W: E5 X: E  ^# {/ C1 `8 j7 R17
2 c8 P, N* Q' p3 C  J6 I18" Q. r" G  _. Z! o( C
19
7 s+ G5 V2 L& m6 G5 W7 A201 }3 v  d$ l+ V: i5 J# x0 ?9 R4 b
21
/ N7 S( o" n7 t  Q, e22
6 T' w% h, U6 m# L8 D% e4 @0 q234 H# y: f+ x0 c& x8 [- \0 b- Q
242 A4 g0 B0 C9 s8 d- T! q3 a
25
! q2 A1 Q4 q1 G" b265 l/ N3 ~# V" F% e, y6 P
27/ v* x/ z$ ~6 P; Y  p
28
, W" t& j) b# N4 f; [# R298 z, k0 _1 q( }0 z0 V/ x
30/ K5 f# e( ~* G& q
31( P* t& a. O) j- o# I
32
4 j0 O  ~7 v! N4 W33: U. R, d* e  u9 R
34! U' o/ f. b; h! [  V* q
35
9 X% [7 e0 q: y  {7 V/ P- g8 K$ D( X36
9 K! r1 ?+ Y, J: y- ~* \9 l37) ~2 N, E3 b7 f: ]0 z# T/ _' u
383 z3 i+ N* d$ G, x
393 r  @1 L5 E8 ~8 u4 r8 G; P
40
! A9 c' p5 Q/ C+ R% Y7 ~5 _41
& B7 D0 c6 j2 M' N8 V1 [6 w42! W( i: o: Q/ m1 f* R# t3 h: W
43$ \3 @; d+ x  ~5 {! S
44
* `4 o* \; h2 D. v% y- z% }455 ~6 V! h, \- A3 y$ K
46
' E  r- H& _, y7 F  T& Y472 N1 j9 M# S# G$ l8 Z% L
48
3 O9 F6 w6 v% }- U9 @" O+ W49* ?8 x5 L, j9 v# V, g$ J
50) B, r- t# c9 `
51" Y" W. e1 U  K& T" g
52- ~. o1 d" q/ y( Z( g8 X) C/ s
53
% v# [" v# z2 ]7 l1 k/ y' Q/ S* A54
3 B6 s* F0 I1 L8 `" Z55
& ?% A( l2 p# e' e/ d56: v" x; C9 i- k( j/ t  Y5 s/ a
579 H1 c1 B# r7 @6 k
58
4 H* K; P. q4 p59
- V! j" I) d& [' f60
! ~& i4 m* _3 V* s7 q5 P% B( @. o61
+ F' K# Q: J( o% h; M62. ?2 x& L. w$ n0 V& Z0 [! |" d
635 b: U, r$ e" J0 a9 ]
648 ~! a) H5 k! a  O4 X
65) H- t6 K" k% D5 P& j/ s# R' V
66) _8 o/ V) i8 \: [
67& q- x1 x0 b( [( D9 u
68
# O; N  o$ j' b1 \! M: k2 l& e69
8 y# v7 j( T( T5 Q9 g70/ V: g& u/ _: c
716 o7 @, S" n/ F8 k  {0 P1 m
72
5 t" I* m- Y7 [' z4 m9 N73
7 C. P! j% F/ h, c0 l: s6 T1 k74
- _3 j4 H, U" m  r9 N9 H750 _% v) _6 R7 Z/ ]
76
7 r. [1 ?& ?9 V' X" k  I77
: i8 r0 M  g1 Z' h) G# c8 W7 V789 f: Q+ k3 X' M; y
79
4 ]5 ]# F% M0 V$ [80; g  x  }- M8 q4 J) p, t
810 w9 z* C' f  Y) e
827 c/ Q9 L/ U% Q3 M8 d
83
9 D: f$ x- B# o1 C+ k! B847 i8 Y* f! F; P, x
85( {" `) e0 r& |# a1 ~
86( j% P  S0 @5 e5 C8 @
87
! g1 x' M  i* s; O88
! Z: Q  G( `$ d  \' a3 R% v2 K89* t" u/ E1 }- k9 n  h- h) x, f
90
' X* _' Z: x+ p4 W* ^' O/ l91+ w+ C) x& w4 G
92
6 R! Q. r2 v. }/ o; d: ?6 l93
" N: f2 j: o5 J94
9 @1 G- O# L) k% V, j957 N" H+ x, S% q& Y2 J5 I
96
- X2 g# f% ^$ o7 [( s97
- }2 F8 ]$ l  O. u7 V; D98/ i% D. [0 W$ @, q) D, j4 Q
994 Q' r# u# H/ S4 N
100( n( a7 ?% d, ~$ M8 y, p2 L& s. ]
101
3 _. F; `# G  G102# U8 O/ W% X& ~* w/ E
4.展示一下数据' O8 {+ o9 j9 O
def im_convert(tensor):
2 K; m3 u8 b9 W( Z3 F6 d    """数据展示"""
5 Z+ d+ [1 d2 H0 Y' P/ c    image = tensor.to("cpu").clone().detach(): k) ]' W+ w0 o7 O, i  E* \% A
    image = image.numpy().squeeze()/ A* c4 j  K/ {
    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图1 o: h9 D, M1 k+ b% Y
    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
2 Y2 P3 g% ?8 W9 Y+ P5 P5 Q    image = image.transpose(1, 2, 0)
9 _! A/ f- [' C4 {- S5 k0 y    # 反正则化(反标准化)' j# ^1 y$ Z0 L
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
  ]+ H4 N2 M5 j  `2 ]1 T* r
9 u, L0 b+ s( ?8 X6 S    # 将图像中小于0 的都换成0,大于的都变成1
0 O1 d1 X+ }9 L) X' t) K    image = image.clip(0, 1)
' c1 ]$ @+ p6 i* l8 ?& y' ]: I+ y9 w' [5 P# ]* Y
    return image
3 ^8 Q; V7 l9 G% A$ p" H3 V. V16 C$ P6 t7 C5 ]4 q5 k# W
2
* e, F1 A' W$ P3
# T4 L3 y$ @. o  j  u# S  K4( W2 [& M. P& U. P4 k  K8 F
5
3 B2 R9 ?# `9 O# }7 i2 M0 p6- Q$ a' R0 _8 y* y- H. h$ v# u; T
7; e% R5 I7 i3 E, i8 D5 Q$ b4 S, m
8
) l$ H$ l( e( l( D; g9: Q/ P; M& z7 A  f  k
10& Z* w' V, }4 Z4 q# ]/ z, `; I
11
0 t$ T+ `1 ~% u12
% v6 j6 {' l/ Z8 L" }4 g/ n4 k+ X13# K! U2 R4 {# @6 M3 U' {# G
14
! T& s6 V; @1 O, |$ Y' {# 使用上面定义好的类进行画图( F3 N5 `/ x; M% P
fig = plt.figure(figsize = (20, 12)); [$ U' w; ?) r
columns = 4
$ r# Y' w7 }2 f; v' T$ h. Krows = 2% J0 x0 w9 }4 A( p
0 R& c+ {; f8 B+ P6 T+ ?/ B
# iter迭代器
' e3 Q' f# e& r& B. n% C# 随便找一个Batch数据进行展示  h9 y/ ?2 C9 }# F: o2 `* L
dataiter = iter(dataloaders['valid'])3 y4 u5 ]+ j) Y8 I3 D
inputs, classes = dataiter.next()
2 C) c7 _* l6 {' }% \- R3 ^
8 F4 {. _, S1 vfor idx in range(columns * rows):
, c4 O, J' Q0 U: T% F: }. c( n    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])$ m: u4 t, I9 [/ p
    # 利用json文件将其对应花的类型打印在图片中
6 G7 J* I' r! j6 y5 C' q  i    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
2 C9 S; d$ u3 }" e    plt.imshow(im_convert(inputs[idx]))
) n' a5 ~( D/ o3 F7 s- v0 qplt.show()
; L) Y8 r/ G  O4 F" a
9 n" ~) F+ e5 |1
* Y: H7 ?0 a% d8 T, i% W' g2
  r% ]4 I) q7 A9 Z1 }! {3/ N) A) Z6 w% ?+ v, F& P3 a
4$ G0 x9 g$ f. _. j! w
5
/ a6 P4 O5 v" p$ H2 G0 i. p# x60 I; m$ M3 @" @7 y: F# C
7# [5 i0 E5 O1 i$ n
8, J- q0 N2 S) f1 K& |$ j
9
0 a# M% p/ S9 y2 U% o10. D5 d5 D1 O/ G/ B9 Y! W
11
8 S5 T( k/ H7 T& z; z/ v0 }120 z  p/ {4 F& Q. d; S$ \
13
6 }" |& B! T5 i' h- V4 |; I14
3 S% |$ i& [; D. N( Q15
. P; g' b3 [) ]8 l16: [1 ?+ W& Z0 c& W0 ^2 |! i, h

3 G' g0 x0 W1 n1 z9 e
$ b& b" J, B- T7 n' F& @5. 加载models提供的模型,并直接用训练好的权重做初始化参数
+ h& \! {( f8 Z: \; r& Z; cmodel_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
7 C( u* m4 `9 U* c# v$ u+ w$ d# 主要的图像识别用resnet来做) w+ U  p1 |. u7 r0 T* q
# 是否用人家训练好的特征
( b3 L. l3 _. F5 H; |& R9 qfeature_extract = True" n7 M- [* P2 M4 b# S" m! S
1. t/ E# A: Y( k  b
2
, Q7 Z$ _( v% K. u* |3; X: q/ d! T! h8 Z2 w/ S
4
7 a3 l* U0 M' m3 i6 P; r# 是否用GPU进行训练
: d% q% x4 m( m6 S7 U6 P& l1 _: otrain_on_gpu = torch.cuda.is_available()$ W1 i6 n. r: J. @

2 c& H, @$ c2 Z: Y  O& S. v8 Iif not train_on_gpu:' ~. M9 A6 W( {, ~" J
    print('CUDA is not available.   Training on CPU ...')
/ g9 C* C! M! r6 i: K0 gelse:3 H  n+ v: [7 z$ f3 k* C  f0 ]4 ^
    print('CUDA is available! Training on GPU ...')
% f' k: o( D. `/ U7 V7 q+ r- C( b8 \$ I5 |3 x* `
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
* I- @" M1 b2 [7 P, b( J( R5 F1
! L7 r4 x; ]$ Z& e$ I, r" P2
3 F% o5 o) T, R3
/ \, v( `1 ?; D% O3 M8 m+ E4
7 S) l+ ~* |% i0 j! H3 E" Y5. c, |& K& H5 S, f6 G2 k4 p
6$ d, x" A3 Z3 j! l+ S
7& j- v/ a3 U( O+ x
8/ A9 \! z; h8 X3 \6 D
9
  S8 W9 G: i0 {3 T" T0 F4 u8 rCUDA is not available.   Training on CPU ...
3 q8 N2 l" u! @# Q15 G; \* U% ~" ]7 X9 H/ W; \1 t2 s
# 将一些层定义为false,使其不自动更新5 R! d7 _7 i4 A! g. S- C
def set_parameter_requires_grad(model, feature_extracting):; A8 P/ j' L5 G  _
    if feature_extracting:5 R+ Q0 O6 @) Y9 ~
        for param in model.parameters():$ {* N) R+ t6 ^" ]- j
            param.requires_grad = False+ S9 F7 k( n4 C7 G) X
1) e6 u2 m& s' Y7 \  \7 N
2
8 S- S# }* N  ]0 F9 S3
3 M& A6 j  G( C* S. ?$ f) n. O4* `3 d' e3 r9 r* j6 Y  T% w
57 \% i; u/ i( z" t! m% [
# 打印模型架构告知是怎么一步一步去完成的
3 o# \- ]; c: ^6 O$ v# 主要是为我们提取特征的6 N  X7 {! t; M; o
5 o1 l4 S. D* y1 e, h7 U* b
model_ft = models.resnet152()+ [# A/ f7 L) Q  i' m4 N  u& ~
model_ft
8 w$ t9 X0 K, v2 V  {, K$ w/ b1
7 N1 a7 L; F* [; }3 i4 x2, y# r% T5 C# T
3# X* ~$ r7 Z4 v8 F! O! n" @* U. S, `" z; b
4* H2 f3 W' m5 O" V+ |
5
2 W3 f, t/ |+ TResNet(
1 w. w" M' K6 ^' s; S  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)" G; ^2 m* ~, a# I
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
: X, `( _+ e8 D$ W. l0 D2 E  (relu): ReLU(inplace=True)% Z% P( D' Z2 S7 F
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False); g: q5 H: E3 i. v
  (layer1): Sequential(
- |2 ^5 |  Y& c! E8 P2 d+ |    (0): Bottleneck(
: _: U2 G2 G; |: c      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)% ^; v" z# Q1 C+ r" ^" W8 x3 v
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)- D/ B* T/ e* p0 t! N# n7 \3 v4 M" y
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False): d. `6 N! _  k  [7 Y  X% a
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)- j, R8 ]& q. o0 S/ K; _
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
7 e* _# {3 A* ]0 p      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)3 A  j, U2 S4 ^+ b8 b/ k
      (relu): ReLU(inplace=True)) d7 w# E, k& c
      (downsample): Sequential(1 e( W$ p6 \! I, N. a1 @  n1 A$ x
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)) \2 i% Z; I' m0 G+ [
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
; U3 x/ c) s! ~# P  z      )
4 E2 C: J$ T9 `8 Z    ): |* p  P1 s& a8 Q+ J  g5 k
中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。
% k# f2 {5 O9 {' }5 _2 R) N, d    (2): Bottleneck(6 N( U, L# T$ O1 z! {
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)6 y# A3 L5 L7 P$ n* Z3 ^6 C9 g
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
* k$ L( n) W4 T+ J5 S      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
! Q- @2 R+ C! M& K; U# R/ N      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
! C% y, {6 a5 P$ L' v      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
! \- S: e$ b5 q4 T      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)0 h" w& U0 a1 l- U8 h
      (relu): ReLU(inplace=True)8 U( [/ _5 B  y  M6 q
    )9 H/ Q5 R+ M4 ^4 e" }5 P: [
  )
& n$ k) |& B9 b. _: }  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))- X( Q! K1 D5 B
  (fc): Linear(in_features=2048, out_features=1000, bias=True), S8 w- [+ y' K+ m! S
)
" e- Y0 w' O' Y( \3 _1 Y
/ O$ [) q) z) n2 B1; P0 l- v+ {1 r. r6 I! ]( p! P
2
! N$ h+ d1 L! ]/ A! C3
! ^6 R0 Q( b5 j5 @) H8 o4
/ u1 C- _" M  k7 H/ I5% B0 x2 [/ c# T) g
6
3 t( x; o( C: ^# ]) X7
# s: F1 I1 l  c7 S! i- V9 m8
& r, r2 L- k0 H, _! r' ~  A96 P) P& V0 u1 z" y5 t* N' D: n
10
! ^" \4 G5 i- n8 A/ Q& |11; c, C$ [& L4 i+ q0 H  x
12, c* p5 B$ o1 e& v5 a  E) p
13
0 ]( e1 F; u+ Q" P3 s5 M) z14
# Y4 F( ^/ G& G& e157 |! H( T. i, r0 R
167 o5 g, {# B4 K  t3 \
17! `; k  W# b* J6 m/ y, B
18" D9 K( w- H7 a, n
19# i& ?$ Z( t# @4 z% Z
20
/ Y  z7 {3 ~4 o. k2 R4 y21
; l- S! S" J5 q% v# v225 q" c* N" X" ?3 M+ `1 M0 W
23
  @8 y* q4 J+ q3 D% X4 i24
' y$ \; p4 ~3 O' ?5 `8 r7 `250 Y% v, r* _6 [9 A
26
, t5 F$ o4 E9 O. ]( z; |/ O27
6 M3 {& y; g+ W+ O) |# |! g7 X' [28
" ~! ^/ }2 b- n% [, K' W- `# v; h29
# j4 K. N: o+ @6 i, q  \307 Z9 s" u# j7 `/ g; o
31
8 A) ^" f9 G3 X8 Y0 e32
2 M6 V3 z; f) x8 I: F" x  @  P336 K: B0 p8 T, n; v2 S* P
最后是1000分类,2048输入,分为1000个分类
8 i; [9 D/ c6 f, `7 X  E" q, B而我们需要将我们的任务进行调整,将1000分类改为102输出
: j$ D2 z# X* O
/ W2 ^, _  y. l6.初始化模型架构7 s  v5 B1 ]$ |9 y- Y
步骤如下:
8 f2 R2 x% z/ F0 z. W$ Z+ m. ^
7 b: F- x- ?' G' v% m) i/ C将训练好的模型拿过来,并pre_train = True 得到他人的权重参数- b1 t" o. n% [3 z
可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)% b. I. {, a0 R$ z" d0 l$ ~3 _
无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
* Z/ j# B) W* ]4 L; @官方文档链接
* Q. C' W( Y7 ]" x# j/ \0 y! ^# vhttps://pytorch.org/vision/stable/models.html
( h5 D2 V3 [8 s2 o0 N2 v8 t
% w: @! s. |% ^1 T0 w# 将他人的模型加载进来9 U" ?* o+ P3 u2 H& D1 K# b
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
( e2 ]) }8 ~' J6 E& R7 d0 j    # 选择适合的模型,不同的模型初始化参数不同) l) Y; R" M; {2 \. }& r
    model_ft = None% p: Z5 u5 k4 ?2 o8 f2 H: {
    input_size = 0
% h' I' _, k/ `( h1 l
" F! k3 H5 @3 m1 Y( A% V9 |$ d    if model_name == "resnet":! }4 d8 n, A) g1 `
        """
; K6 j( P: T' N, [# u! b        Resnet152: Q6 z  x* x+ b
        """
& n. A, e' J: f( S: {, H  H$ A
) d( y2 E" c6 V7 Q        # 1. 加载与训练网络( N3 A! j2 z( a+ ^- y( d
        model_ft = models.resnet152(pretrained = use_pretrained). J& y, |6 r+ ^( X
        # 2. 是否将提取特征的模块冻住,只训练FC层  z! j; z8 _2 e
        set_parameter_requires_grad(model_ft, feature_extract)
' O+ B8 n( Q% t  @* h$ z        # 3. 获得全连接层输入特征: }5 }4 Y. }* i3 q
        num_frts = model_ft.fc.in_features
: Y/ m7 {3 e1 l4 k  M        # 4. 重新加载全连接层,设置输出102: J8 f# w" S, w, M
        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),# `5 O: p1 [, [0 G( e2 |+ V. ~
                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1
" s4 c6 m' a5 g. v4 ^6 p        input_size = 224
# K2 t5 b6 {) W/ Y) p4 K0 T$ n
9 {* n5 y: x7 }' Y$ `: f    elif model_name == "alexnet":/ c8 ], }' j4 m4 Y: G: }1 b+ d
        """
' `( c2 n: o/ B4 S) P; S7 @2 Y/ u        Alexnet
  }, I. f; T8 ^* ~- N0 v        """
- G; ]- f. `9 b3 y5 u0 S  t5 Y& j        model_ft = models.alexnet(pretrained = use_pretrained)5 o6 f9 M  C6 N
        set_parameter_requires_grad(model_ft, feature_extract)
/ j( k7 e& C/ B$ @- f/ `- d7 o  D" b+ p* _* ]% ^+ a/ o
        # 将最后一个特征输出替换 序号为【6】的分类器) L3 n( a" h9 s2 z
        num_frts = model_ft.classifier[6].in_features # 获得FC层输入
6 D8 u: t0 M3 r3 R( C+ k* d3 Y# i5 |- S        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
/ q7 `; w; y5 v; R* I) k$ J( k9 U3 P        input_size = 2241 o; @6 k* X) u! B4 V3 y& z
0 }8 x- y) ~+ ^9 m6 t
    elif model_name == "vgg":
  x' t& o+ x6 R. W# |* u+ t0 f        """; h9 g6 U0 N3 R) m
        VGG11_bn+ L4 D4 U5 Y. {  E0 M; ?9 n5 h
        """
2 f# W. Q( K- y" u! h; k8 s        model_ft = models.vgg16(pretrained = use_pretrained)9 ~$ \5 j7 h4 b  w" Y) ]6 Q
        set_parameter_requires_grad(model_ft, feature_extract). E" m* [6 z! ]$ X! ]$ w, ~9 P
        num_frts = model_ft.classifier[6].in_features
8 A% @( Z6 w8 h% u9 p6 h! c# G        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
  K. a- P! E% `6 J* I        input_size = 224
: a" Y- Y" k0 l1 N/ X' N" [4 L4 C% m# a) \5 V, R
    elif model_name == "squeezenet":
0 Y6 C7 d1 O% P4 }* H2 ]        """
7 h( Q+ w& ]* G( l$ u4 k: d        Squeezenet' \) w" }# u9 X6 R  _
        """: [0 L# b4 Z5 b  c8 h& s) `( m
        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
& [5 X. O3 b- i        set_parameter_requires_grad(model_ft, feature_extract)5 w. `2 T, O2 |4 `$ Z: y- f
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))/ v  {& |- ?( ~# f" ?& k
        model_ft.num_classes = num_classes
* h+ n$ [0 K6 n/ b' y; z- ]; y        input_size = 224- {1 M  U& {  W: j
5 j! J/ R3 T) i" V
    elif model_name == "densenet":) m$ A6 y" e$ M
        """& g; _2 C4 F" X( f% `
        Densenet# y# N" h+ G: t8 Z+ F
        """' S8 _  m- P9 T
        model_ft = models.desenet121(pretrained = use_pretrained)8 q6 ?( N- x: b! s
        set_parameter_requires_grad(model_ft, feature_extract)1 B2 ]5 t3 e- l+ ]3 N# C: F3 r
        num_frts = model_ft.classifier.in_features
* _0 U! R3 |6 N+ J, a9 n5 g        model_ft.classifier = nn.Linear(num_frts, num_classes)
" I, G! _$ P" T+ ]  f9 R/ S        input_size = 2247 W3 b" W& h( x
8 \8 N8 _7 n  r# V' r2 Q
    elif model_name == "inception":. I9 e9 f: u1 o+ G% A1 O- m' a( h& e
        """
. j1 x& T5 ^/ E        Inception V3
* {$ S4 I& w6 h9 M# F        """
1 m" z/ T# V* @3 d: Q; |( P; J        model_ft = models.inception_V(pretrained = use_pretrained)5 u( x. d4 f3 Y2 ]
        set_parameter_requires_grad(model_ft, feature_extract)
7 O- @# o5 R$ P; Z  Z$ l
9 D; F. N3 R8 \; O+ z% q4 ^8 `        num_frts = model_ft.AuxLogits.fc.in_features
9 r  Z6 K" D2 E4 M% F! Q) s        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)- j2 Z3 O- L3 C
8 ]# }' Z/ M$ k# [/ J
        num_frts = model_ft.fc.in_features
  F! h: [: ~" a+ z. k& o        model_ft.fc = nn.Linear(num_frts, num_classes)& u- N2 m2 ]2 [
        input_size = 2996 a4 d( I0 S. u9 Y% l# R
  q( A, S5 ^& m; _. U1 J
    else:# N. d  ~4 y+ a: l
        print("Invalid model name, exiting...")  W7 }2 I( ^; X
        exit()
3 z4 r8 n5 \+ l* [$ S
! E* |) E3 L$ d. t    return model_ft, input_size6 A- Y3 G' k# n- C) ^

* P& {+ E* S' T2 y  v; ~0 a1, H& t- ^& Q+ R5 @
2
7 j1 A' u  w, S5 H6 [3, t: }* y. V/ G8 X
4  y) F; q8 g! u, F8 A. L
5
; z! c5 |$ C/ I3 Q( D6 N& ^. J6% ]5 Z( g4 h% w' X$ w2 ?
7
- P% R6 Q% ]. N6 u8
) w* p8 ]8 S+ D4 ~; _9. x, H! j# a( m/ q
10
  |1 y. S1 X3 i8 t$ a( [6 z114 Y5 R9 F" h+ X6 J4 |
12
) y& t6 I) k" m$ n13
# C8 p% z) ]& m! K/ y14, x) }  K; A0 F
15& y# A8 [3 N% m8 b/ q
169 P; Z6 K' Z# h8 D  l* X, w2 w
177 H9 B* l: S( \, O
18  u' _6 _5 A. f8 c# Q" P# l, o
19
0 ]8 G, Q5 p, o6 f7 I& V20
0 d1 t" X$ u0 K- |, }21
# \' _3 e9 o: d" q- B( m22
# C' j+ ?5 j9 k, D+ J23( o) z+ x# T3 W- P, E
249 J1 a! ^7 k, }8 Y5 }
25
7 T. Y( ~3 U6 C  z263 m& Y1 T3 a8 _
27: |2 v) C. x( B) w3 a/ I
28
5 N' W7 Y+ C1 M5 `: \# u$ y3 C29! W. j5 B0 m; b- C8 M4 C3 @
30" D" F+ u, P( I  ]
31* G; C+ z6 ]- u7 V' P
32$ p. k5 I6 r! A8 v& s4 J7 E
33
9 t+ u/ c! j1 m3 ]  b7 a  w34
9 p7 w3 H" P+ C( m3 r35
! _9 ]" `) D. g9 ^36
  o# {( A. i2 z- x) H$ ]37
4 v- L% @3 y5 m  `4 h& u381 E( X' }9 Y0 s1 S
39. S* w2 ?9 m% a# B1 |+ b1 T% U" U
40
3 b. M: {, t$ r418 X8 t5 l  W: U1 B( V7 ~
42. W- N7 D# ?' u
438 _* Q7 e" C8 j; W
44( U2 w+ I$ S( u1 L5 ]( G
45+ b8 m2 W4 h* f: x" C; R
46+ @8 _' X! n0 z. {! Z( f# J
474 ?. n; W5 x. z) H' |, s3 B; x
48* z4 I  R# c1 |) I
49/ R/ u! |5 p" [9 ^+ _% X! d
50
0 H/ B2 u) F( S  w- }; c' v$ y$ X' P51
2 I1 X- _! ^! Y1 ]52: Z( l1 H" h: e, D; m, J  z
53
; ]4 Z; @; l1 b6 b+ x! D/ l; x54
* O& M4 s, C" m& D: W55
" h' S3 G% }( X7 }3 J" h: l56- Z. q# P/ U/ x( x8 c( U, J
57
6 d6 W! T) ~5 D: [4 `4 x! i58  V# v' B; @, {6 c* B
59& Y; v1 s, V5 x4 {. f
60
0 ^3 ^2 J: _# ]  I6 s" f61& ?5 \8 f" Y2 s3 Z# O
62% z, Y9 T, f1 s# k1 O
636 n/ u, J% _/ ~2 c0 c$ C% o/ x
64
3 f8 L2 f$ {: ?4 n# q% X65! I, j! v0 P( M4 C' r
662 }8 w0 Q5 _  U/ J2 Y' Y% R% X
67
8 O( R' _& A" V' d68
- ?9 G; u1 J) |5 g( I3 a! T69
2 e: Y; o6 z1 |; F2 n70
3 W6 @/ L/ a2 u0 _2 A1 j71" B( w2 K2 V' |5 b
729 v& e* y5 }$ J
73
9 [1 {0 S0 K& r' [74
+ _% p& Y% c7 X  |2 {  a" g75
% J. e5 o9 E2 u) |0 _0 n762 I7 g: N& |  C6 M3 n6 r: ~& L! O
77- {: z2 `9 E; h
78
6 q5 M& W7 z4 Y79
0 }6 t4 _- R/ G) b- E80
/ J; r) }  I4 b) `5 L: h81& C. M0 u5 v! |% _3 b
822 }+ e5 q$ D: P- s: L* e( y0 ?, E9 h/ `
83
+ X* K- A6 R4 w" S( P/ u* m7. 设置需要训练的参数7 X6 b1 ]5 q7 L. N/ W) E( ?0 m
# 设置模型名字、输出分类数- U6 L/ x2 c; Z' {/ m6 y/ B
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
+ V' E  e- k  E( v; o9 t8 O% m5 r; K, N/ O" S& F7 m& |
# GPU 计算6 t7 ]2 K9 p3 Z; Q" O+ v/ I/ w
model_ft = model_ft.to(device)9 o- Z# V1 P9 z5 F* U
' {5 A7 G7 j& _# q2 @: f
# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取
7 W) b3 A  q! G3 F0 c8 N( Z' Efilename = 'checkpoint.pth'
. w' l& \  ^% ?% x
: K% l2 h; z5 W( {8 Q2 f# 是否训练所有层+ c6 e# [, d7 l! {0 O6 F2 {
params_to_update = model_ft.parameters()* g; m8 G/ Z  E4 J
# 打印出需要训练的层
" u) ?0 e, u) I8 t% k- Rprint("Params to learn:")
+ u4 @6 X6 ^# ~if feature_extract:" ^, k4 o& W  A2 `, q6 V( ~$ b2 u
    params_to_update = []
- }6 d$ Q7 i( H- y$ ?' Q1 l8 D1 @    for name, param in model_ft.named_parameters():0 l9 a- z, W: {% c. x4 u
        if param.requires_grad == True:
  B, m8 Z2 v- u            params_to_update.append(param)
# i# P7 p' b2 |; P$ F! C9 d/ I6 `            print("\t", name)
5 H% b3 j+ e+ velse:# U/ o7 q1 b8 Y* h# w+ \9 Q0 ~
    for name, param in model_ft.named_parameters():
  ]0 B7 h2 Q/ D" s1 N        if param.requires_grad ==True:
, j: E4 ]$ ^. Y# O" m% U2 M            print("\t", name)# ~) k% S4 ^  R' Q  M

/ I+ ~* P- B4 @3 g& F( w4 L1- u7 ^5 E3 b0 q/ y3 S- m# z% W
2
1 m2 c2 ?) F* b; |3 h8 R# g32 |% I1 ^6 T/ \
4# u: y7 Q5 c# _* [, D$ v  e' t0 V
5
, _/ W5 H2 f: z. t6. d6 m6 V9 ]- U0 R
7
( Q# h1 ]) d( t7 J# q87 h: H9 i. d3 \3 G  n* d
9
9 _  Y8 p( [, Q8 M10, q1 F; {9 V5 v0 j& M5 M
113 \# Z$ v/ N7 }$ Z: B3 s
12
/ p; k) L0 B* c) ^+ C13
5 D5 X* A9 Q% w3 ^4 @& l: Z# l" ?14
5 u+ O8 O5 {5 b  J. p5 h15
! [. B/ h2 ^; \2 v2 ?9 c1 w. ~7 _16# }! b+ B. \& m3 s
17" v+ {6 D/ U3 Z  q
184 @/ G; r+ T1 r. v" V$ C5 t0 n
19
- ?- x! g# I$ ]. K' O207 |4 E0 o2 ^+ `2 G  E3 j- U& J, J' z4 W
21
+ ~! p  h; I2 }- i228 ]  C# b, V: H5 |9 ^: W; ~$ t9 k
23
# z" K4 }5 E- @5 G" m& f' {Params to learn:
0 |- n; Y3 H* S8 @# n: u9 u7 n         fc.0.weight; A. `, V3 C" p! l, M& n4 u2 d6 ^
         fc.0.bias6 r7 A& {1 T  j% ~) m" \1 ~
1
0 s* p3 B& ~7 f3 t! r" F% U' O2
$ W: H4 l" ]- }+ J+ {3
- F: w% s1 r7 ^" I8 ~7. 训练与预测
# I6 @  [4 L: l4 H$ |: B6 ~7.1 优化器设置; f4 M1 C( e: T! f3 e  x
# 优化器设置
! _* ~/ u; F5 U& Xoptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)
* y8 j: Q0 R/ @7 A% l# 学习率衰减策略
9 l3 `$ I1 F. C/ \/ A  g# xscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)0 s" j3 t, ]9 p1 ?. p0 [- U
# 学习率每7个epoch衰减为原来的1/10: g  z+ J7 n& \4 d
# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
" F2 ]4 b2 D- w/ q9 @; i
+ L  P  c# d% \1 p7 q, b4 b- xcriterion = nn.NLLLoss()* h8 V9 G  I5 q: R  i
1* Q1 K8 A& w2 |+ D- A! m
2* r/ J7 \% G7 D+ o
3
* \) ]$ P  t9 P+ S5 w4- L- S  m% E) }9 A
5* e9 E7 P5 b, w7 ~1 `
68 u6 f. A" u4 t4 N! r, U4 T! ?
79 q! G4 t: C1 P) ]( n, v; r
8
' u) M8 M6 P0 Q. B0 T  C/ H# 定义训练函数
: e3 Z) k8 v4 w- p2 w7 C1 E#is_inception:要不要用其他的网络/ f1 Z6 a% t% a6 q' D8 u' L
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):
' O- k6 L. s+ z7 M9 @% I    since = time.time()
9 l) k& k- H% L4 Y    #保存最好的准确率
& W: Y0 T0 c/ O4 O3 d; j2 F' E    best_acc = 0' ?- v" P1 N( E: N# T; v
    """9 }; b! C, Y& v" ~- ~
    checkpoint = torch.load(filename)
* u* ?" s! {" ^$ I8 K2 S( U" [    best_acc = checkpoint['best_acc']
+ X  R& C( m7 Z* {0 c5 f    model.load_state_dict(checkpoint['state_dict'])
! j- f9 ]& S8 ?2 I8 X. _    optimizer.load_state_dict(checkpoint['optimizer'])0 [) a& u" m0 k+ I5 v% q
    model.class_to_idx = checkpoint['mapping']7 x' q% d1 U& r
    """
9 N* n- N/ ]  W. Y0 w/ r3 A    #指定用GPU还是CPU
% F8 e7 Y5 P9 J2 A$ s' |    model.to(device)
! P( h0 ]' t3 g* u5 k    #下面是为展示做的4 H# A  W7 P  p3 x' I! X
    val_acc_history = []
. x* h0 A/ S/ t8 ~    train_acc_history = []* H$ q& E* {5 T! ~; n, d
    train_losses = []
0 B' P6 l" Z" l5 F    valid_losses = []* n- \. F; s+ I% d! u6 U7 @
    LRs = [optimizer.param_groups[0]['lr']]
' y& M' X6 M3 _" N& E    #最好的一次存下来) g# r3 _; d- g2 j1 i
    best_model_wts = copy.deepcopy(model.state_dict())
" U5 y; A5 o* B5 s
* l* Y6 ]/ I$ Y* x4 k) H# F% l    for epoch in range(num_epochs):
3 M) Q: i  Y# l& o: ]& d        print('Epoch {}/{}'.format(epoch, num_epochs - 1))) S9 r6 h7 K% f4 D) B2 K
        print('-' * 10)
1 R. o7 ?% q+ N# }2 {8 x4 T) S5 J$ B7 n
        # 训练和验证
2 |, o3 H8 C: _8 h1 }; Q5 l8 V        for phase in ['train', 'valid']:/ r& \5 d% }8 H$ b/ |
            if phase == 'train':
* f3 x* e8 P: C4 H& @0 @& M                model.train()  # 训练2 F* Q/ I' P' a; D% i$ A3 o
            else:
6 t  c  o: n+ p+ ^/ P% O                model.eval()   # 验证* Z. v9 o8 v! w( M: Y6 l4 @: c

9 z) I3 q3 U) y7 }1 D/ }/ T1 l! X            running_loss = 0.00 T' s* W% D9 M7 h( b
            running_corrects = 0
) J! h: V  Y' T& K& A1 {& o. ]1 Y; e9 i: k2 ], r# Q
            # 把数据都取个遍- V* G$ ~3 E" W
            for inputs, labels in dataloaders[phase]:1 B2 p' M2 x) m1 |- X6 v
                #下面是将inputs,labels传到GPU
+ m/ k6 x+ l# ^  `                inputs = inputs.to(device)! m. D3 q8 `1 w! V1 J
                labels = labels.to(device)
1 Y5 @% ~8 T3 x  ]
; o1 A( {; N# G/ V# y! ^                # 清零% z/ Z. ]! K+ I
                optimizer.zero_grad()
3 e) X2 p; t0 _# ~                # 只有训练的时候计算和更新梯度
$ t, S# Q0 R5 ]2 N: \) b" v0 e- B/ t                with torch.set_grad_enabled(phase == 'train'):
# ^0 ]: t2 ^7 ]                    #if这面不需要计算,可忽略
# ~7 t- n  P: L' F- D! @* X                    if is_inception and phase == 'train':# ^; B5 c; [0 ?0 z9 r3 x
                        outputs, aux_outputs = model(inputs)
6 d! }! @, O0 w2 m                        loss1 = criterion(outputs, labels)# c' p5 B- ]! ]5 a$ Q6 d
                        loss2 = criterion(aux_outputs, labels)3 ?7 ^0 g9 a  w$ _* @
                        loss = loss1 + 0.4*loss2
6 m! x; I1 Y  r% E+ ]0 n4 s                    else:#resnet执行的是这里7 [; p6 C# C5 a& X6 I; w
                        outputs = model(inputs)1 F' s3 {* D$ p! U/ C  n
                        loss = criterion(outputs, labels)( _3 j; G) _! l0 ^4 Y

2 \1 e8 |) A' Q$ F( d                        #概率最大的返回preds. _, U# h* L* z( W% r4 I2 T( O: C4 X
                    _, preds = torch.max(outputs, 1)! q( l7 a. a6 @
) A" z% g/ s" ^9 B% a" T
                    # 训练阶段更新权重
0 c8 ]8 d) Y+ P$ ?" f                    if phase == 'train':
: d3 l* M. S: [! T. b1 |                        loss.backward(): r" T. Q  p; @
                        optimizer.step()0 v) ^; T# U1 }' V) ]
6 m" l: ]4 f" P- v
                # 计算损失: z6 T1 f9 V7 s  T% t: j9 V
                running_loss += loss.item() * inputs.size(0)9 y( \5 V1 ~2 W0 a
                running_corrects += torch.sum(preds == labels.data)
3 b& j2 V+ z) e0 v6 J8 U3 X8 r
/ k2 Y' @$ Q2 |$ e: Z            #打印操作) Y6 d9 _  Z- R: G: ^6 T7 ^
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
/ z/ Y+ c7 n0 ^, o5 i5 k+ K            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
- K% j  X& J0 R0 K- }
8 T, f/ Z$ E! k% Q2 s- d" j3 h9 u7 a$ d$ ^2 R+ B. h8 M; D* n
            time_elapsed = time.time() - since
) k& B8 z' F- H; a            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))7 V+ M. W% m; h3 P! B7 a0 v
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)); t3 R( e" h1 \' {% z6 ]
0 Q6 H. Y8 c2 c

: \3 c" i, `: p            # 得到最好那次的模型5 J2 V- @9 X9 ?- J6 `/ d
            if phase == 'valid' and epoch_acc > best_acc:% Q- F7 D; s% W- e6 n3 Y1 a
                best_acc = epoch_acc4 A. z0 e9 r2 f. L8 m8 q( R% h
                #模型保存
6 \$ H8 {) _1 X! s! t                best_model_wts = copy.deepcopy(model.state_dict())) ]7 ]' h3 @- K/ t. c
                state = {
# M) k0 C" l1 X( K5 h# c# l! t                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数
) {1 C  \! v" Y( Y                  'state_dict': model.state_dict(),5 z* ~# l3 h- _$ j8 O
                  'best_acc': best_acc," ^5 D! O& _- S+ q' f* a* ~
                  'optimizer' : optimizer.state_dict(),
1 s* y0 u" ~+ C$ I                }
; }- F+ k3 X* P1 L* R                torch.save(state, filename)7 F# C7 B) e% I
            if phase == 'valid':& w! V  o2 c0 I2 {9 r* i
                val_acc_history.append(epoch_acc)! b2 _1 w/ J4 ?$ z3 r- K6 A" _* h
                valid_losses.append(epoch_loss)
2 k4 \) X, X0 X5 b  Z5 ~  c+ M4 q                scheduler.step(epoch_loss)
% h" S: D  X0 C/ E/ ~            if phase == 'train':) X3 ?& x5 S, L! u- M; s9 Y# ]
                train_acc_history.append(epoch_acc)
/ n& L' I6 A9 G+ f+ W                train_losses.append(epoch_loss)
5 {. K# I8 x# i3 Q% b" `
, o# N5 t9 j9 C4 z        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
: ~: B+ M+ u: F9 k        LRs.append(optimizer.param_groups[0]['lr'])+ p1 h! T$ L  H% \! D
        print()+ y3 A# @; p5 k/ R

- r8 h9 x$ w" F2 g    time_elapsed = time.time() - since
% H6 {2 T5 [* E, p% |7 |+ T    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))6 t/ T$ j+ m6 x
    print('Best val Acc: {:4f}'.format(best_acc))( q. `3 ~  [; p# l0 I4 Y/ b5 g( H" n

+ w% p4 z; `3 P- M$ C5 V    # 保存训练完后用最好的一次当做模型最终的结果9 H' c9 m4 }0 `. _
    model.load_state_dict(best_model_wts)/ ?) G1 d) u4 m9 n& p3 s+ l" P$ }2 S
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs # O7 E7 d% N" M

* L; [4 [$ q* c  v3 k1 `" k% x! \' k+ g. h; z" e$ V5 {/ U2 ^# s
1& c5 l5 n5 B3 T; b: g( |2 d( ]$ G  y
26 ?1 b& `3 Q4 L8 x
34 X5 U) R/ |* H1 ]7 R0 N) H
4
# ]$ ~3 f3 X! [. }5
, `2 {3 p' j, v' P! w/ M: o! J! \6# R, U" `1 }8 C. w( b3 Z: {
7! Y2 r9 u/ g/ U
8: f. n) g& n  ^* _* O2 M: Y
9
9 N  x$ _8 `* F4 y0 N102 Y6 m5 |1 w8 n
111 ]( ~5 p& j4 v+ V6 _1 }
12, V3 J( D3 n2 ]* T& i4 ~- x
13
; v: D1 A, w. d14
, l, h2 E5 t+ n; [9 c# M) d15! ^$ ]% t; i8 a& r, m1 ~3 f
16
0 f; l" q% J1 I( \179 B; c0 e! f; h, w
18
9 H  d! L' Q5 R7 `8 R1 K# {6 \& R19
" T$ ?; H! |1 ?2 j20
9 ^9 _8 b3 S6 g) J" f215 x! S5 W$ p# x/ r" c& x
22$ C4 V1 Q( u3 a5 c# j
23# t: Y6 U1 G- Q+ f0 O2 F& T
24
) V/ u- v" ]4 e$ l+ h- o25& w! w1 D6 y7 a& ~, L
26
) z9 ?3 [0 s6 z' u3 h27
: H! D+ ^- G8 O28* ]- \5 `7 H/ V) g
29
3 T+ _/ @) ^" ]7 f' H30
+ X! w: v+ C! W9 ]! Q; o( [. |4 h319 \" {7 P* S. \+ z% j$ j! o
32; i0 ~* e, ^1 r# z
33* Q- s, I- M1 [  W9 q) j/ e
346 n0 @! X7 L. T8 y0 z6 o% Q
35
5 }# ~; X" z" l# V. P8 l366 b. d8 C! O1 b/ n. T$ h+ }% x
37
" G4 Z& m+ ~; |) c8 W1 O38# `$ |+ Z0 n: q+ x
39, [' Y& p, Y; W( O. O/ V" C/ X
40
( O# z! a6 j7 P4 v41% V8 Q6 P8 R6 z0 v) q/ g7 w
42! W3 L: J3 K3 w, q5 m
43
+ b6 m  X8 S/ c# b, Y44  V) ?. s3 z' u$ J# C
45
4 ?- o- j" l) x5 z46
8 K# |5 ^. K8 U& P* m3 j47
& a6 O& Y; j+ p: W* C484 X+ q$ b+ C; _
49
: s4 Q! _1 V; {: b2 `# \/ p6 O4 Q+ [50: B+ [9 L9 U# P3 D
51  `3 V9 ]5 l; U* W: i
52
$ f6 J0 |, j* [1 I) ]" B53. v8 Q+ J. |( C  B' ]' O  x
54
# u* T7 W0 p4 S4 a550 x* }9 v4 U; u# Y+ |+ R6 J
569 T* W; G- U2 D+ v5 G5 E& M- z6 g
57
  U2 E' M) {! K; V' X58/ T* S$ \/ t! ~9 Y; j
59
4 i6 d7 N0 S$ a, O60- X0 [3 m/ g7 O( k/ Z8 O( f
61) E( {$ M, s8 }- g+ a
62
' ~& F9 Q* {% H% k% F8 L63
4 y) n1 u, R! b: b64
) j  W+ ~& Y7 Y1 i9 N3 \6 J, Z  |8 j+ Q65) |* F# [" b( x' c, b; p" e
668 m2 d' U7 _* |) }, n
67
1 c/ V3 c( |% c& J/ b  X8 m% r68
* h/ `2 M$ _/ B% n69# u% H2 s6 o" Z  o$ o1 E' K
70
1 K" W8 y  r: d$ V1 ^+ b& O71+ K7 B& r( s" Y4 A- P/ w
72
) K6 [4 i3 M. e735 h, w0 b5 j/ b" w1 f3 l
74* ?( \% c7 p( I) z8 x2 P
75
7 G+ {  }1 o$ z5 {  O+ W  L2 a76
2 l8 Z" a+ `* E8 G8 u( {775 e; l, h( w" ?1 j9 m7 h! `$ ^8 `' L
78
8 v5 {& Q3 B/ ?9 s79; I9 w' l, f- k4 s
80; b) M) m- ]* a6 x
81; \$ g1 D& s1 ]* M3 K7 G1 ?% _- E
824 L5 P: }  w7 I0 g" f) r1 y3 C& C
83
' ~+ A/ z1 v- p) ^, Q9 l/ G84
% e1 |* R( M, V7 j7 T3 G  s85
0 x& _; x" z/ R' D" F" Z86  {& Y# _4 v9 X2 x$ w* P6 N
87
5 ~  P9 @9 u3 w% d+ D; b88( k% q0 v$ }8 }
89
) j5 ]  M" `- y# `; Z90& A% \# |/ E% v7 u# o  p
91, s6 E' u+ p5 j0 X( D' }* H4 s
92
% {  {% Z2 g6 q+ T( t$ r7 P932 v% v* A3 P' n) w7 [" W4 ?7 @
948 O3 {# P" a* A
95
+ M2 }# U" V2 Z7 K96
, S, z$ K; B* Y$ z+ j; b97
( |. d5 h5 R; G985 X9 s5 K  Q" `' q0 L7 K% a) L
99- a: u" N- H+ s% L) B
100
# {6 F# q5 F, ^' a101
: S2 I* W& P1 \/ t4 h102
8 s  ?8 E) }- \103% @9 C- B! r" |- Z# D$ i) c
104
& [: H: L9 `$ f. Z, D' P+ a105
6 d0 b7 K% O0 Y5 b4 q106
& R1 u. k/ r: @4 q' s' @# |107; \! @; F* k# s' d; e0 v
108
+ k% l) E: @( u! F. I3 ^) G- k109
/ S' o; b- h3 S9 I- |1101 h  [3 H/ V$ K0 s! h5 m
1114 |% D0 _$ j, t
112+ w. v+ ~3 p( U- R+ j1 \
7.2 开始训练模型
' \. [- e) v" J# q% H# g我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
5 f2 {- v) F& C4 k0 [  f: K
" L, l. z, ]  h#若太慢,把epoch调低,迭代50次可能好些, N2 q% `: ~$ o% }9 H
#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合
. J6 m# X$ C* q/ c3 ~7 Jmodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))1 B4 l( M# E# e1 D; `
8 A" ^$ Y- y* P# n7 {" |0 k
1
7 {% T' Q% y4 k6 L6 y$ w& L24 A5 U$ J8 r) L. {7 G) D7 F3 t
3
. _4 ?3 A4 [2 P0 B5 i5 h4
+ c5 @- l0 m0 {8 d( H) i4 F, o0 x7 }Epoch 0/4
% I7 l: h" L8 S----------- @9 Y& t+ W& ]& k/ ]- I
Time elapsed 29m 41s
5 A1 Y$ J/ \5 `train Loss: 10.4774 Acc: 0.3147
+ `$ a- T& q$ r4 [. q; q, rTime elapsed 32m 54s, P0 n, t3 l+ o3 O6 ?, l
valid Loss: 8.2902 Acc: 0.4719
3 i( l5 ^  Q2 Z+ lOptimizer learning rate : 0.0010000; G2 j  h* C% q6 w/ `
3 a- U, x3 l8 k1 q" `; O$ H) @6 M
Epoch 1/4
  ]( o8 c# Y/ P: l2 b$ x, D----------
6 I7 f7 r- o8 J, CTime elapsed 60m 11s2 a. k4 A: W- Y
train Loss: 2.3126 Acc: 0.7053; ?. g# ]7 u* l) ?1 R
Time elapsed 63m 16s
' E: k/ `/ {) h0 u! ^3 _valid Loss: 3.2325 Acc: 0.66262 J: a9 \0 Z( w3 ^+ n  T7 E8 G
Optimizer learning rate : 0.0100000/ ?* b% U) t& p+ B9 W% z
& l7 Y, Q2 [; R0 P
Epoch 2/4
. F0 t$ l" {. N# W- m+ C4 e----------1 I9 H8 R& M% Z2 P
Time elapsed 90m 58s( w- V. l! P3 |/ G
train Loss: 9.9720 Acc: 0.4734! \& T: `# I) Y, l1 Y" p3 i* i
Time elapsed 94m 4s; ~" b. r2 @$ O. V1 z  i
valid Loss: 14.0426 Acc: 0.4413! g- `- H( z- `2 Z- l
Optimizer learning rate : 0.0001000$ ^% e' n3 y; |+ c( o
: ]: G( a6 z1 s5 O) }. @- w. Z2 `
Epoch 3/4
! A  ^. _& d& I& u( b4 H7 ~----------
! h& \( B7 m0 }. v: e  vTime elapsed 132m 49s
3 C% z2 \1 s4 c  P5 qtrain Loss: 5.4290 Acc: 0.65488 `! ~! @4 A. V& g
Time elapsed 138m 49s' D- {/ `& ^3 Y
valid Loss: 6.4208 Acc: 0.6027/ |& ?3 ]+ _" x" ?  X
Optimizer learning rate : 0.01000007 S4 M* U: c/ m. U2 d
% X  r- z( O# w" R8 u  L5 N
Epoch 4/4
1 v- M5 P- h4 p! p. q; P+ f* Y----------
1 J3 g2 {6 y: p: rTime elapsed 195m 56s& P  M/ {: f: Z2 D% k4 c; y
train Loss: 8.8911 Acc: 0.5519
# |# N- V. ^9 {: m0 I0 J% ATime elapsed 199m 16s
8 Z* d( w+ B/ @. m6 ]6 a$ x" `valid Loss: 13.2221 Acc: 0.4914
" q* Y* @6 \; y0 n# G; [0 i# L0 DOptimizer learning rate : 0.0010000# U7 u) F, }) @0 k9 T1 j4 F
8 f: k9 g5 h, _+ G( B. U: Z/ Z7 ^
Training complete in 199m 16s
6 f6 u& i( P* J0 Z& _! W( Q' z  C7 bBest val Acc: 0.662592
2 U" G) U9 r: c4 |$ E
- f9 a& T$ B6 Y8 \1. z6 g& p% T7 U+ I5 \
2
9 J8 `! b& H9 w+ F. _3/ L- r. c1 |' n1 i$ W2 t( c9 k
4( I2 q& }9 S4 k" I! Y
5
2 I# O4 b6 c) T9 X! t3 f/ d5 U6" X+ A' ]( l) o# g; [0 `" P3 L3 H
7
* ~  x4 l9 u5 J! c& I7 n/ {8
+ ^- l5 c  \" x% |4 ?9
9 P5 v* F  b# m- i  V4 ]# o! t/ k10
. Y1 ?) X' U% b# s6 t9 t11& M, B9 P" n) q& R! Z; B
12
" U) m# ]0 x) W% z7 B( W13' O( o! q6 p8 n0 W6 F5 a
14
0 {- I2 V$ K, {15& q; u) r4 n9 m, g  Y8 x
16
0 X: C- ]( c1 ?17
6 y! t1 n6 D" j& [" c' t* j( l. I18- P4 H: q9 U+ `
196 W6 C8 Z( p& Z  [! ]7 D8 H$ u. m1 P
20
) o, A, G0 ^  f2 T! r& M4 l21
: H) ?3 d' M& h1 Y& B  [22( R0 i: n) }' F
23
, M/ K3 K0 z/ J* Y: q2 B2 o24
& V" `& [: x  R2 E+ l6 L255 a; o# G' K4 r& H) ]# l- w4 ]
26
* `& D) ?" W0 _27
! y# y% W4 F. |4 c8 s' ~28
6 J) h# Z! Z! F+ _9 L3 _29) O4 e4 d' |$ b6 H: ^4 y
30, \, f5 n3 S* A4 J6 L# R2 W- d7 i
311 r- m0 v. p/ X; p% S" n0 P9 Y
32
6 {! Z1 x2 D3 C( q1 G/ x33
( A$ F3 Y8 I& G9 C/ J3 H/ |34  a. ~. |5 e2 w1 k, s
35
/ J$ T( j: o8 l36
3 c& z  m" ]/ Y37
. ]6 V' k7 Y3 }4 ^9 M! r; d380 K* m5 u: ~& r7 e
39
" p1 V8 W  m8 L; I- m, K40
7 i/ V2 R6 x( b3 J41
/ q7 p: j) I" I- G6 a# L! p  Z8 I42
: g8 D* s/ m/ X0 I: h" V2 l7 x7.3 训练所有层0 ~8 S$ F8 q' f2 J
# 将全部网络解锁进行训练
3 a* }/ R$ D8 [7 J, Q( R- [for param in model_ft.parameters():
- N' \1 {+ Z4 }$ S* o8 T7 W    param.requires_grad = True/ q: C: E; V, j) I" n) x; z( C$ Y
" `* I% S! d, p/ h: Z8 a1 A
# 再继续训练所有的参数,学习率调小一点\
4 e- X( F- O' q% I0 ioptimizer = optim.Adam(params_to_update, lr = 1e-4)
% G( _) R# i% T4 W; H, wscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)- p6 A$ r8 i) y6 W5 ^( o
' i8 n4 F, a0 S! I( F& ~( S* i
# 损失函数
( A; i: v0 d2 K, Ocriterion = nn.NLLLoss()* i( c, T( d3 V. ^% c
1
/ s* B6 C+ b6 D+ Z; I2 b2
+ V* L0 j+ R' S0 N" b9 \) N3
5 o# f3 D! L$ ]1 t# k4
# E6 v( Y8 V9 _! @- ?- k0 J5
4 a0 |! D0 i( J1 l; A0 H2 M0 p  @6 P6
8 H$ }! X6 ^( B6 i9 S  v7  T  S2 k7 k+ D! B: D- {; n9 _
8
1 b, s! S; B3 x5 g% @- }0 \9
, P! X& ]! f$ H! v( R10# g" c0 A* x# M& B0 R; ^; T
# 加载保存的参数
$ Z2 d0 Z( J5 ^; X" t& l% ?# 并在原有的模型基础上继续训练
# n8 O3 b% c1 D( f# 下面保存的是刚刚训练效果较好的路径9 w" j/ R* d' p1 h4 ^& i
checkpoint = torch.load(filename)% k. Q$ |3 u/ R& I, x. Q
best_acc = checkpoint['best_acc']
, I& s& E! |: ymodel_ft.load_state_dict(checkpoint['state_dict'])2 P4 b4 @4 M; o# P* Z+ K' M  ^
optimizer.load_state_dict(checkpoint['optimizer'])" W$ z: q, a$ a  Q
1
  c% [7 X& o% Q- E$ G  H2
1 N' \' {9 {' s  G, ~; m3 i' A3
" N5 ~% H7 P: |' c# D4/ n! |  N+ g' s' ?+ \
59 R! T' S0 N+ I+ J  q! |$ i, l) X
69 ~8 n2 z3 \, b; D
7
% F# f! |* `0 q7 A开始训练
9 s1 Q' N& V- S  i注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考
0 r2 I8 d* Z: M8 V' M7 M! z4 N6 O3 @
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))6 C4 T$ s$ f) r# v! @8 N% V1 H, V) ~
1
* i: [9 z- z. g% ]* cEpoch 0/10 }* D& i* _$ m! j  B+ Q$ f2 Y
----------
- g7 q  V' o: n+ D  DTime elapsed 35m 22s, B% c! d% f1 U. }7 V7 N. A* ?
train Loss: 1.7636 Acc: 0.7346( C; ]2 n1 A5 i" ?" Z! O" U
Time elapsed 38m 42s9 d" T% N: v- e9 q+ l, @
valid Loss: 3.6377 Acc: 0.6455
, e/ j% x# R2 P! n, IOptimizer learning rate : 0.0010000
5 j' _! U% ~+ S6 G) b( \# q3 U8 V2 G) P- a# G6 w3 A8 z1 Z
Epoch 1/1
/ y- A" ]/ K, R- k0 b6 F3 M----------  M2 M8 z0 W2 @% H/ G# l2 g: r
Time elapsed 82m 59s! k8 D1 {0 w! q8 K
train Loss: 1.7543 Acc: 0.7340  E+ v8 i8 d" Z
Time elapsed 86m 11s, v2 |5 i' P& ^% x' \  z2 z
valid Loss: 3.8275 Acc: 0.6137. P9 a8 z8 ^+ a. H* J" N! r  ~4 c
Optimizer learning rate : 0.0010000& @! z) k1 c8 I3 X% N
' ^" f& Q9 A" T/ ], X
Training complete in 86m 11s* l# p' w+ V2 c4 U# Q$ G) h
Best val Acc: 0.645477  V5 P2 E  ]5 m* V, J( ]
8 f* ^; o  w& |$ t& o6 S5 J2 h% r0 H
1
0 j6 P% E! s8 h4 G2
: G- r( U- z# S" d: H  f31 _) `+ [0 y) Q6 A. e: h
45 C, V7 G6 w) Y% ]% @9 B
5
$ v$ n, ?, W" G* @% _6
5 b! Y! r. p8 D4 I4 N7' o/ h0 }3 \8 R4 S! S
87 ?( u: \3 ]6 R: n( s
96 ?# T% [8 C5 s
10
$ j: A: S, ], w9 C11
; c1 w% X" _" g* w12/ \* |$ `$ n1 Y+ i) a' w- U) X
13
$ F% g. f1 O3 i14; |7 ?. L/ f5 K9 J4 a
15. K1 T' U7 b! J
16% P( O1 ^$ n) d* E8 K
17" p: z3 |4 D( I3 H
18
/ V* |& {$ y9 f9 j# A' Z8. 加载已经训练的模型$ u1 O! I4 L6 I1 R" {( P
相当于做一次简单的前向传播(逻辑推理),不用更新参数" J. ^; j% O  ]; ?0 N9 T( I
$ J' R% o7 l6 w' S, x1 P! W( p
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)! z( v( ]! ^8 A# }
$ K. ^3 n! X  }: k0 v
# GPU 模式  D' P" j8 f2 g8 }% B4 ?3 H
model_ft = model_ft.to(device) # 扔到GPU中+ E  ^: w) T. [4 v* f0 l( @2 X( n
# A; Z1 |+ S2 u& |/ s
# 保存文件的名字
5 }1 m. E7 _, m# F! o8 G# tfilename='checkpoint.pth'2 ]; V' \) {' d3 T2 a: y
: V6 [% _; f- }; ?9 ?
# 加载模型  P. o7 x! j/ f1 {
checkpoint = torch.load(filename)
% ~8 _; B, k3 D2 \/ [: hbest_acc = checkpoint['best_acc'], Z+ @) c3 J! h0 T3 \& B/ \
model_ft.load_state_dict(checkpoint['state_dict']): f% \# {$ Y! i1 P& v+ X7 B. Z7 ]
1
+ D- [# c) k* D: x2
6 n; m; o5 K( ~. R& O7 `9 I35 s4 u# Q( \/ l
4) e7 O, v: O4 y  z
5
. R6 [% s5 @: y8 f6
1 f* `: c' o1 \. C7- W' `( H' l6 ^8 H8 y8 M, R
8
- s& Q/ ~8 C0 }# y/ B1 }% l9/ V4 d3 R% N9 C0 w% T" m8 c
103 v- o4 ]. p3 j) H1 [) E
111 Y$ _) j/ h* f: ~5 x
128 A+ x. _! @9 y* G
<All keys matched successfully>
- Y, Z* q' u6 ]# z9 ~2 @17 Z- v( l- N( N; Z7 w( Z  C
def process_image(image_path):
) H1 ^+ r! Y# ]$ V8 x+ |    # 读取测试集数据& R# F& ~/ ~! Y  H; n) h
    img = Image.open(image_path)- a3 g9 s1 t& L1 k; m
    # Resize, thumbnail方法只能进行比例缩小,所以进行判断; e; d( X5 k9 b/ V9 C) g& [
    # 与Resize不同
% V# N( H" J$ k) I9 O    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
1 S, u& E' y% D/ m7 l7 V0 d    # 而且对象调用方法会直接改变其大小,返回None
/ U) O$ C% b% m" y; P( o    if img.size[0] > img.size[1]:
! W, p4 G' J3 j, Z" S2 Z& x$ s; Y+ `        img.thumbnail((10000, 256))# q7 j! z1 y) ^6 A, c- V' W
    else:: U2 C* @7 P8 {, V- k( d" |/ v
        img.thumbnail((256, 10000))/ _2 Y- X* [0 k+ {  i
- j, e! ]# t; z! X4 Y6 J5 ^
    # crop操作, 将图像再次裁剪为 224 * 224: ?+ W! I: v, m& o* P
    left_margin = (img.width - 224) / 2 # 取中间的部分
2 P9 W4 s& V; Y% Y0 `' l    bottom_margin = (img.height - 224) / 2 . q6 t; ~+ U  V: A8 f4 _8 ^
    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度$ r; s1 w2 A+ D
    top_margin = bottom_margin + 2242 x+ b! x$ ], c" V; G; w, m- t2 V8 B

) N( p# e; S# j" p; v    img = img.crop((left_margin, bottom_margin, right_margin, top_margin))' u! ~: b2 ~/ m9 D# v1 R. N

! J; X- u$ j. E) o# ]    # 相同预处理的方法4 L( Q* s' Y/ y3 M6 a6 h
    # 归一化& r( J. H1 d% W; D' |) G1 g3 X
    img = np.array(img) / 2556 ]7 G" \, u% g  f
    mean = np.array([0.485, 0.456, 0.406])9 e  h0 k! K3 B4 b4 I
    std = np.array([0.229, 0.224, 0.225])% _# Q) \" V: ~/ O* ^( ?
    img = (img - mean) / std% I( ~) o% P7 o+ B
( ^- b/ A# S3 p, C6 d( R* ?
    # 注意颜色通道和位置
: u: d, @( v; i1 Q    img = img.transpose((2, 0, 1))9 Q* r9 M2 b+ ?" E8 b% Y" e! Q
# O" L9 F4 Q/ Q, z( }% j% e8 V) j
    return img4 C% X7 r5 Q/ _6 w/ [3 m

1 {9 d5 d0 T$ S; x7 c$ _3 Xdef imshow(image, ax = None, title = None):
; _4 ^  j. v7 H: N2 E) i    """展示数据""". C! C  b4 j0 n
    if ax is None:9 M; M3 p$ ^# P4 A( D1 y0 J
        fig, ax = plt.subplots()
- r$ a, U5 s# f! O7 X6 h/ x* s1 H% L
    # 颜色通道进行还原# [$ V, L0 M( v' L8 _/ g
    image = np.array(image).transpose((1, 2, 0))  ?; D! P3 V. S

% k9 a: F) o/ W" n6 R! v# E    # 预处理还原
& b+ m; B0 |9 @4 X0 z2 O    mean = np.array([0.485, 0.456, 0.406])
. H5 x$ M1 H- Y5 d( T    std = np.array([0.229, 0.224, 0.225])- Y* F. z/ i1 j- z+ k( C' F
    image = std * image + mean  W% A6 R+ u! J! i# b7 |2 i
    image = np.clip(image, 0, 1)
, d! ], n2 G1 f2 |3 m) l  o' I' Q  ]* p5 `: d
    ax.imshow(image)$ U0 E  N( Y! k1 c- g! T2 _; y( m
    ax.set_title(title)
" n% ~9 w1 _7 j3 V* n
( d; G+ O+ U6 @8 r    return ax, A. M4 P( j8 H. l  I

% B: Y. R6 v; l# ]: nimage_path = r'./flower_data/valid/3/image_06621.jpg'
! V7 a# Y' M4 B* B2 L' ^# E% kimg = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理6 s! d6 k, b. o
imshow(img)  r4 y* N! h" ^  ~- I# e8 e5 @

: c% J1 x  G' T5 b1; K! @: ?' K# H% K: F
2- t$ C$ B" k' u7 o9 O" f
3
0 ], C" {0 z9 z42 \1 J+ I" [' L
5
9 W  W- e. F9 e% ]  w3 N6
* s  P: _( @* o. V* F9 T; w7
, O' w- U! }: ?8
6 ?7 B0 {! b! E/ q/ @! _9
' g  P' s/ r% d% P, \) b. m; j8 L105 @  r2 y6 ]) M  K  w
11
9 z4 |. T2 I5 X9 v12
! c  U( N- x& f6 M9 g, c137 J) s, M2 v( t, r$ r2 N8 o3 R) I
14" ~  C5 u& ?2 T$ e4 S: Q
154 D# x# j3 a0 m5 V. D# w
16  L; a# ?7 Z6 V# g# F4 E
171 K' p; d  z9 P8 P! j: |
18
2 P' H: k6 X( K$ S& G19
; _$ C3 C# f7 B. c5 a. x20
; C- ^& f9 m  ~6 v$ a. _219 s8 H9 Q/ n0 f! j; f
22; X7 ^) p1 M" ]
23
% v  j5 F0 B3 h245 [( H$ |2 E* q3 [+ f9 ^+ N
257 n" J, ~, [/ @% @
26
- h# T) O; h* h! B9 K% Q* D0 R; V27
# L) v3 B1 |/ o' M! F0 O28: v  y) l( ?1 i9 f2 T  {( g( e
29
( M1 _: F% h% D2 S9 u' V30, H! Z+ R: d$ S0 L4 V9 J3 V
31
% [- }2 q( C; h2 _2 {2 i32
; h% H- c. f% w6 y1 i5 t338 V9 L7 f' \: C( p* i% I, m6 `
34
+ e# l' Z2 ?( `" J. o& I, n+ T8 q35# @* o& W4 d9 S- ^( I; c8 \* \
36
) i6 a% m0 o5 J- s  `& t% v37
% c. i$ }0 V4 k. ]* t7 o38
: U* \! i2 C7 [39  l  g' y8 x+ f  Y2 _2 T
40/ ?0 B" W8 F3 }+ q/ ]2 o' ^( b  O# w
41. ?) R7 \0 w0 y# x, e! R- z) z3 e2 `3 Y
423 m0 [$ h: a) E( i. g. g% z
43
. w) }( B' F/ p4 K1 h44$ B8 ~& q8 r# M* z* b4 r' c& w
45
0 {! S9 V" H+ b' G. m! Y464 r& N1 M! o( x' k! D
47% Z2 v( Y  H% j8 o- J
48
! B& z  v1 J' K4 [7 Y49) h( x' ^% C: [3 f! r
50- r* H* M; K" k: C' a3 |2 g
51
$ P. T+ [7 _9 X7 O7 N' q! i' v" B52) D5 ?  A) ^! F; B8 l& a
539 n+ v1 I: N; t- a# ?0 B, U. @
54% o3 w# j. R. C5 U' u' r8 I
<AxesSubplot:>
  u- F! @' W; v7 K! ]1$ p$ S! i, O/ M: i

  T* O2 n0 Q  i: t9 R+ x9 ?' X上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确; c% N: R1 w0 T7 u$ |
  K, V! Y9 }/ q9 ~, u
img.shape
% j% V6 ]6 }2 |! ^$ ~8 m1 \% [% `1/ U- v& V. N; ?" S6 Y* a2 [
(3, 224, 224)
& u; d# ]3 M9 y+ B1" V; K* T* l' g! t
证明了通道提前了,而且大小没改变# j2 R7 w$ C+ d* c# }% ^
' q% M8 u) i  @6 |
9. 推理
  N, W( U% y! ]) d- \' t9 Uimg.shape
9 a$ H' c; J4 ?/ t6 j- R4 R+ {9 n. e7 p3 }( L1 H# u: g6 x! {
# 得到一个batch的测试数据9 |) b! ?4 u, M- l+ F
dataiter = iter(dataloaders['valid'])0 R& }# P4 h: @) o) ]# k
images, labels = dataiter.next()
% g+ ^5 h+ M. X8 H- s! h
4 F4 c7 ?5 t3 g5 r. |* f8 i2 g% Gmodel_ft.eval()5 ]+ ~/ x" b7 D. S' n$ D) H% m1 ~

# R4 q' v# I5 \/ _5 B0 Cif train_on_gpu:% E" _' I/ B3 W5 S
    # 前向传播跑一次会得到output1 y& f- P# t- q4 s; c7 @# E4 N
    output = model_ft(images.cuda())
& S/ J+ W8 r* @) F* q/ H5 N( celse:
- U  z, f+ Q0 Q8 ~    output = model_ft(images). J8 ~# _% H# ~. X
% n9 E8 O. v- D
# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值
& M1 B9 X9 S" h/ Zoutput.shape: \- ~3 x0 j5 D: r

, K! F1 u+ @) P: B1 I; `+ e  n: E/ X1
& A  ~& V8 |& q: [' X2" x8 U7 r& y+ L2 E5 f, |
3
. r4 B) s9 U" P- y8 d7 I4+ X0 i3 o# v8 @3 {! v; M4 _
5
3 s% @- u- r/ T6) R. p* q) @' I! H6 Z/ y& ?
7) K: g! d& e/ [' U
84 |9 j* g: d8 y/ c0 q% S
95 o- H7 I) }- i  j
10
& S* _0 H5 v- {+ j11$ |: \" [/ y# c
12
. Z# Q( C4 l  D, |" K131 h% \) {/ N) }' F! @, \
14
8 n9 s% z0 e/ b& \% {0 d15
6 c# M. ]: M& g( l+ a% `16
4 E' A! ~/ k: p: l. x( Otorch.Size([8, 102])6 M/ }5 L: ]; W) U" j: @
1
' u9 i2 p8 W3 t& v9.1 计算得到最大概率& z. L+ ~5 Z) n- r/ h/ w. k: @  y
_, preds_tensor = torch.max(output, 1)
# w- e; O# M4 ]4 m, J
5 F2 z1 c8 [3 L+ g; M$ m) p- }preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量" U* y! w. T9 ?; K, t$ ?9 s
1+ Y* e" h5 c9 H; d- i
2
. w2 m% m# L) F) x2 I7 g3
1 Q8 g; v! K3 ?0 R! e* L- }9.2 展示预测结果
+ ~# ?3 W4 i4 f: i# k3 Q" Hfig = plt.figure(figsize = (20, 20))
9 F1 H# d0 t' N5 X- c- Jcolumns = 4
% G2 ^7 o8 g, Irows = 25 }. K% _% m1 O* A
; f" t% g! n4 Z- ~- L6 u6 Q! A6 O
for idx in range(columns * rows):
4 s' I9 {4 H; q    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
: C0 n; j) D, a0 [    plt.imshow(im_convert(images[idx]))% H3 v/ n  y5 a" A0 e5 X, q3 J
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), 1 D0 {' g3 m$ [+ B
                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
+ P! ~6 i; ?: {& j" Z, Mplt.show()
& q# n! J. j' O# w; X$ g7 b& }# 绿色的表示预测是对的,红色表示预测错了" `/ Y5 j7 t& j- r2 u( R! ]9 R5 X
1
% @" F8 H" u# F+ e& D27 s- L/ q! K/ p0 c9 p
3
% n% r/ n: J4 u2 H5 \/ Q% H4
, N  \0 x. @# `  d7 d; u5
$ H/ A; `+ C' [6
$ [* k9 w# B: D77 _9 g  o5 ]& S4 l0 m+ ^* D
8
+ X: y; `2 G; M+ `( q+ C9
, |3 p1 G7 x) W5 t10
# g5 `$ M" n+ c118 n  ?: d3 R! ?) K8 Y: O: O' z# k1 i
! d* l  J  D9 _4 W0 m

& f) z, @, t, {& g/ S. f# A9 C
* p$ F9 Q4 c, w9 p! ]% i" e————————————————
3 B) O" [6 z/ J* T# F$ W# j版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。5 X4 D8 S' @% i; p
原文链接:https://blog.csdn.net/LeungSr/article/details/126747940
0 U- K8 q* ]5 z' y0 b) T6 _2 O# h. K4 z  E' ^! f, D+ j3 ?
- u/ S; Z# w, @/ y* x/ r/ g8 _





欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5