数学建模社区-数学中国

标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例 [打印本页]

作者: 杨利霞    时间: 2022-9-8 10:41
标题: 【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例
【深度学习】 图像识别实战 102鲜花分类(flower 102)实战案例9 b7 l* v- n# H: ?# Z- A, @
( J& e2 I0 s; |* Y+ U- [
文章目录0 g" g" P1 Y2 `! ?
卷积网络实战 对花进行分类
& T; n6 t1 H" r8 u, Z& q6 j0 [, K数据预处理部分/ r  I# t- y9 h, P8 l
网络模块设置$ k3 M$ Z4 j' C
网络模型的保存与测试
7 m- Z6 q% F: u$ v* o: _! v$ ~数据下载:) r( b, \5 D/ f3 [
1. 导入工具包
- O" _3 y% n4 g2 [2. 数据预处理与操作# {: O6 ]- h/ \  j
3. 制作好数据源( a; c7 J( e4 R- E, M! @- k
读取标签对应的实际名字
; a, a: _6 L% ~" B; w2 l5 d( C3 ?4 R3 y4.展示一下数据
5 P5 h* C# G5 m5 Z8 I) I, S5. 加载models提供的模型,并直接用训练好的权重做初始化参数8 I' o9 e8 V- E. l" l& J8 u9 |
6.初始化模型架构& w0 O. i5 E- S0 K; f0 T+ w8 Y
7. 设置需要训练的参数
( k% |" |9 U/ p, k: [8 h7. 训练与预测
% c8 {" ~$ O, d+ v+ t; v7.1 优化器设置
' ~$ a- N9 W( l% L  d) p& k0 t, w7.2 开始训练模型
  L0 A2 S# |) l6 I9 \& e% _7.3 训练所有层
0 k6 `  e# g. s" ~开始训练/ B/ e3 T7 x, p; n8 r- V
8. 加载已经训练的模型
$ n( t+ u1 y0 o9. 推理
: D# q# F1 S$ ^, _9.1 计算得到最大概率
" P0 a* m- i& x5 d4 l9.2 展示预测结果
; l# I  b6 P& G0 M. X写在最后6 r) v; y  L  f+ y6 [
卷积网络实战 对花进行分类
) ^: f( d; E  l. x; o0 f本文主要对牛津大学的花卉数据集flower进行分类任务,写了一个具有普适性的神经网络架构(主要采用ResNet进行实现),结合了pytorch的框架中的一些常用操作,预处理、训练、模型保存、模型加载等功能' H( Y# M% k9 B* X) d
. k. F+ b  ?8 ]% z
在文件夹中有102种花,我们主要要对这些花进行分类任务
* ^7 G8 I0 ~+ m* v文件夹结构
& A" H( l) r3 r- U' R& j% \& i( M0 ]% H+ @! H$ d
flower_data
. ~, I7 c- g+ {2 ]6 d! K/ e6 U6 l# X& T; y
train
& h4 Y- Y0 s  O" ?( e1 P" H4 `$ l- }( P, B) ^; |
1(类别)& @# W& o% R  p; l* K
2+ g( ~6 p! q+ X; q
xxx.png / xxx.jpg
( R4 s1 Z) G4 yvalid
, [% k, s* t# _$ b1 f
% l$ V" h2 {$ k$ k/ D6 J' F! F主要分为以下几个大模块
" b+ ?! O* D4 M: z* S. a- W% {8 q
数据预处理部分) C, N% S* z$ p* L
数据增强/ l% Z0 n' R- P
数据预处理% \; b# J8 h4 g/ e: F) ^4 Y: ]8 T5 m
网络模块设置% |- T: s: w% b% ]# I- {3 Y0 D
加载预训练模型,直接调用torchVision的经典网络架构
* F. I3 w  G0 W& J: ~  P, s因为别人的训练任务有可能是1000分类(不一定分类一样),应该将其改为我们自己的任务
4 u6 \) S- I( g; C; @) C网络模型的保存与测试
) _3 {$ g5 X" [5 t8 h1 f模型保存可以带有选择性
& R1 ?: O/ Z7 J/ |: Z$ x数据下载:9 a5 F/ V$ v* H! x" R( M# u8 ^. W
https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset
% x/ W) `/ H- M$ @0 j! `9 l' V% B9 f+ G5 ^3 n% L1 K0 y4 R! ^0 X
改一下文件名,然后将它放到同一根目录就可以了
. O% b( z7 s; {8 h' g7 l, z9 V( f
. b3 O6 ?9 t) _6 F1 [下面是我的数据根目录. C' @4 p6 l) Q' V+ M$ x

, H3 c* \1 M: k# Q1 g# N" M% b% V! ]8 @" n
1. 导入工具包8 B- i9 K7 j+ L1 |: F9 B
import os, `! B. @" Z1 ~1 x
import matplotlib.pyplot as plt7 k2 `  @  T$ }; Q
# 内嵌入绘图简去show的句柄6 n1 b  s8 S. \- P! O8 ?& F3 ^
%matplotlib inline
4 M6 G' \3 j# _import numpy as np9 a6 \$ n1 E+ N  _* [, h
import torch
3 l0 J. {# D$ F% rfrom torch import nn
9 w3 X  E2 J0 z9 M% p
8 P0 C4 r+ ?; z7 oimport torch.optim as optim( i4 m3 x: f' j
import torchvision
- u! g: Z! }, Tfrom torchvision import transforms, models, datasets
/ c4 v# d0 J8 P# {* W; |* A
4 `! s. [4 j7 i5 o6 d4 l. y3 dimport imageio/ R% x0 d2 s) s6 X1 c
import time
" h& |7 C8 L8 M2 i; Wimport warnings
" h! ~! U0 ?  s! Uimport random7 Z5 m: v. x' d+ j
import sys1 W4 }7 L4 a" o% h" c9 R
import copy
# R( a$ K+ N2 C' y( ^import json
/ w, f' e; X$ S4 l' ]2 vfrom PIL import Image9 g" O" I0 s: ?8 u

! Q1 ^! Y# S" g( q: U; Z* C6 \# `4 r
1
4 F2 S* e% j) c, B) l26 m+ r- Q9 E& N+ t4 g3 {+ r
3
) [3 O3 E+ S/ Q: b$ b# h% l42 u/ h* z9 X* l3 \0 A
5
8 o2 c, f6 r, y. X' L. o6
7 o) }. I/ Q) P( l8 X7
9 g2 Q, C$ U, M  u  g1 C5 ^8% X! O6 F" T8 F9 v2 L
9
( E. i' ?: _  x% o% U10
: B+ N- F+ B6 f6 z2 C! f) @, r11
. O! A1 E( `9 t4 t' s' d$ q129 R% m" N/ O1 `4 K6 x3 o' G+ G. T3 [
13
% L& Q9 m  s1 A8 Z144 E* f" Y& T, K/ B
15
* x$ B2 @! x% x+ V# a: {; d16
6 i2 Y% w4 [* I7 Q17* z8 K& s& A" o
18
9 N! h: V+ m9 T) n& c) j) e19* ~- a9 z1 b0 h& q3 G
208 ~4 ?$ R0 ]5 j
21$ \; ?3 Q8 Y  i% S9 c: ^3 w; @
2. 数据预处理与操作$ E; R8 l$ e! s% U! M" V! _
#路径设置
, q, l! Y2 [+ u! E) q% `data_dir = './flower_data/' # 当前文件夹下的flowerdata目录/ }4 c. A* R- W$ s7 Y
train_dir = data_dir + '/train'
1 N# d9 E% |5 ~% M" gvalid_dir = data_dir + '/valid'
( j4 J, Q2 i$ ?/ g7 j1
; k& x: I* }% ?, f: ?" z2
) H/ @. Q! u6 q. T) p3. t. S1 B. h! e* z9 a/ w
4
# t! i% [0 T. {& i* u# e! i; `python目录点杠的组合与区别
9 ]% Y( e4 i8 O9 M  [/ p注: 里面注明了点杠和斜杠的操作5 _; j$ t2 S; s  G; J% n
: X: I/ `) v+ g; j
3. 制作好数据源
$ V. G0 O; A3 qdata_transforms中制定了所有图像预处理的操作4 C5 s" x! ^: A8 z. `+ q
ImageFolder假设所有文件按文件夹保存好,每个文件夹下存储同一类图片
* f9 |. [! l! O, K- L$ o" bdata_transforms = {
: l/ ?6 D' P  [' A8 x% d    # 分成两部分,一部分是训练% W3 F  Y; |* I% ?# o
    'train': transforms.Compose([transforms.RandomRotation(45), # 随机旋转 -45度到45度之间; s4 O) u* }3 P) ]  n
                                 transforms.CenterCrop(224), # 从中心处开始裁剪  y: @- d% B/ [0 Q* v' b
                                 # 以某个随机的概率决定是否翻转 55开7 }) ]0 ~% i1 E2 G
                                 transforms.RandomHorizontalFlip(p = 0.5), # 随机水平翻转) X0 ^4 k" q4 L; c% f+ L
                                 transforms.RandomVerticalFlip(p = 0.5), # 随机垂直翻转; k7 n  j  [! C, l$ c, m
                                 # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相: u, r2 t& J2 {# ?: t, X5 J
                                 transforms.ColorJitter(brightness = 0.2, contrast = 0.1, saturation = 0.1, hue = 0.1),8 r0 o2 M# ^. E: Q* M- F. i
                                 transforms.RandomGrayscale(p = 0.025), # 概率转换为灰度图,三通道RGB
0 M* L0 B  n9 h( c$ r                                 # 灰度图转换以后也是三个通道,但是只是RGB是一样的
0 F& T8 P. f4 S2 @                                 transforms.ToTensor(),
. {4 r' j& b6 b0 q: J                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值,标准差4 G. M& Y; Z/ y1 R2 t+ V
                                ]),
, h0 X4 L2 O# w1 `% }7 ?5 l" s    # resize成256 * 256 再选取 中心 224 * 224,然后转化为向量,最后正则化* \, Q0 o! u( |% C4 t
    'valid': transforms.Compose([transforms.Resize(256),
' O- t! O8 u) C# K+ d                                 transforms.CenterCrop(224),# z; X) I2 _( F* b$ M; D
                                 transforms.ToTensor(),
- [5 h" x- j! s8 H! F+ ^                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 均值和标准差和训练集相同
4 h6 O9 ?: V) H$ Y8 L6 f  D# {                                ]),' p6 V6 D) U4 n) V7 J! x
}
( k' U9 Y' @/ {/ a8 a: d
9 C  B$ H3 B  O- W- M% Z1
+ @0 b$ u% e% d6 s2
9 u9 a, v) w+ U8 `& e3
" ~% u4 _3 u7 u9 A9 [( d) O7 t44 Q9 t, l0 D3 ], M
5& m8 `- ~4 O/ |; o) O5 y! f
6
* A: ~8 @/ y) k3 Q* F5 ]7
% t; k- n+ `2 h! u6 U8
( I' n4 m5 {3 y! u9% @% N' I& Y( a( z; G
10
) s& N, r/ L, _) ]8 b) h2 P& u11& K: W( L' Z5 j* K
12
7 x( A+ i3 e. P, L; z13
4 @) c% E$ L8 B% N% v! Q' T14
% D. v) C  Y# F) c' N15
: t+ X) A3 I  D( g16
( `: \" V" [- e- Z8 \9 u176 n; _4 X& M+ B! h5 |, D6 k% q
184 Y/ t4 V# M$ h$ C4 L4 L1 H& b
199 Q1 o+ v; R/ H9 @3 L8 N
20
. U" H4 d$ F" |& W- Q- e21/ L9 ]$ a" j# {  G7 {8 }& K
batch_size = 8' Y6 k; }' I1 {5 R
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms[x]) for x in ['train', 'valid']}5 D- d  F. P  D9 H8 j- A; Q
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
! d0 H7 L" A# U; }: \, b) |3 L  Vdataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']} + h( ~2 g) s5 U/ d$ U
class_names = image_datasets['train'].classes. n/ t- b' z2 N% ^: z% ]% |

  x9 T) G- K1 Z" L% d#查看数据集合: u/ Y. j# l- ^) f+ u) j
image_datasets7 W; s9 w9 k/ @" ?

) \) T8 A# [* ?2 n' v( V6 k) I# ~1
! K% J1 c" V# O; R& o2# ^: a, K& x6 ]9 m
3" U/ W0 Z- V( f' t) g
4
, l# |; j1 Y/ }# Q5 B7 X5
4 c! V0 t# Z  f- n, |6# h1 G8 B% }- _: y# X/ b
7$ z& t) g/ c7 z7 [9 W
8
$ D3 M' x* J- R' B9
" n  [. w! y/ A% n{'train': Dataset ImageFolder' j3 p- }2 \# Q: y0 h  l! [
     Number of datapoints: 6552
- {" p* Q4 A) J: r/ Q     Root location: ./flower_data/train3 z- [# h: [; D; X8 }6 k" J
     StandardTransform
6 b3 s2 b9 V' {3 y. s% ]+ O Transform: Compose(
: `: W2 `6 i7 N8 L; X; B                RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
* Z; v  N; A, |+ `2 U3 u; g                CenterCrop(size=(224, 224))
' x! C3 D  l6 Z                RandomHorizontalFlip(p=0.5): M, d; K4 \! Y  Q
                RandomVerticalFlip(p=0.5)
5 z5 J; d% @1 n( O; O                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
3 l! ^8 ]4 b5 V# }                RandomGrayscale(p=0.025)* @* F2 \3 M) k- l3 R/ t+ q7 S
                ToTensor()
# \* u# Z/ @0 d, ]                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  l# N1 y7 a6 `            ),1 u* f; w. j4 O5 c) Q
'valid': Dataset ImageFolder
, S5 G1 r3 g8 o, z- N8 a     Number of datapoints: 818
! M7 S8 U+ p" R0 E3 L8 y7 u! M9 O     Root location: ./flower_data/valid
: p+ }; f! V8 B7 w; P     StandardTransform
, B1 u. ~5 _. H# E Transform: Compose(
0 e, x8 M5 E5 _: u1 F! ?) T                Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)  a% O8 [* [' W
                CenterCrop(size=(224, 224))
6 h. I: s1 V5 Y                ToTensor()
( f& Y! a- J, V5 ?                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) d# F3 ]: x' K' }4 V" L' V
            )}& G( i$ S- p/ h, O# ?
0 @& b! v( W' D) W
1; y7 u9 i. r( u& ]( G) [9 o0 n
28 C) B5 u& y+ o  P
3) V( v) M- o0 t9 R
45 |4 d+ x& |) f! M) d
5" F" r. b. |2 F+ J9 w6 t0 T
60 M, |0 w2 I" g! Y3 q
7
# I1 V0 I) d2 W* f8
( r0 T; @- b8 G: _9) N4 _: T) O; B+ U
10% {. S; q! c; g* k2 X1 z
11
3 p, B- [9 [  [( C3 e126 m" r+ L9 ?; C* A7 V9 z4 z
13
6 M% @7 B' e" m' v! i14- u; l* p9 N8 g' `
15
9 {2 R. s. v" B8 S4 w9 M% X1 b167 Q; @; g& l0 O
17& k) c$ u1 s/ o
18
8 k/ w, i& B, q7 M19
* o8 N) \% c& J+ t5 e4 V; M. A205 R7 V( a! A+ H6 S
21, ?- W' x0 K! r! ]3 c
22
) ?* j( J: X- \! O23
  x! }. M) P2 e' B24
( F; J- K" x4 L) T  r# 验证一下数据是否已经被处理完毕
9 W3 w9 E' q- L# Xdataloaders
: \$ S: M- ?/ e1
+ k3 R6 `1 T- `' `2
% E" l4 R6 B3 r! _: e{'train': <torch.utils.data.dataloader.DataLoader at 0x2796a9c0940>,
. K- p* H' s& e) C4 l: B 'valid': <torch.utils.data.dataloader.DataLoader at 0x2796aaca6d8>}
( N: F4 U: P7 }6 X5 {1
! U' p/ }/ T8 }, D5 G  X2+ W" }' \& t2 K) @1 ]0 u
dataset_sizes
+ e0 l& `; `# O1# P  a: R6 G6 L/ x2 B" [
{'train': 6552, 'valid': 818}& l9 P8 I2 A7 O: @, t, Y
10 J# H/ d3 F) v4 d8 I
读取标签对应的实际名字) O5 Y8 \3 W6 \! f
使用同一目录下的json文件,反向映射出花对应的名字( s5 S( z+ O8 o* q. P

; M8 ^) b3 w2 G5 i8 W& Y  S- x, Fwith open('./flower_data/cat_to_name.json', 'r') as f:& v0 f3 A( k* ]: D4 e
    cat_to_name = json.load(f)
( Z4 M2 H' |- F. A* U* P1- Z( G' A6 r- x+ P) f
2. W$ E& o2 A, w/ f% h* ^8 g8 y
cat_to_name
, b4 v- {9 q" e/ C& N1
, }* X4 ?7 _; K/ W* o! b9 L/ b7 B{'21': 'fire lily',
0 r+ S, R! b2 E: Y '3': 'canterbury bells'," X( K. x/ M) q/ M
'45': 'bolero deep blue',
0 e, i% X2 s1 ]/ W6 l& O '1': 'pink primrose',
* ~7 R$ R0 ?) P '34': 'mexican aster',
: ~( x5 ~5 [( _ '27': 'prince of wales feathers',* z1 x* f! c' b0 l
'7': 'moon orchid',
, E4 g$ A$ F1 j  Y" ~" z; n '16': 'globe-flower',
+ Q5 K9 P" I' d '25': 'grape hyacinth',
0 J0 U! B; p  e% M* n% @* R* L '26': 'corn poppy',
+ I7 l$ ~! `* A" ~# x" r '79': 'toad lily',$ q2 @. H  G3 d! q
'39': 'siam tulip',6 n/ K3 B- ^2 ?
'24': 'red ginger',: S; c& U/ y" b
'67': 'spring crocus',
& W* g$ d2 i" [* e, u6 K# w' | '35': 'alpine sea holly',& l, M5 _$ q9 D: Q( d9 j( U
'32': 'garden phlox',
" {, e/ ~1 W; o8 U3 P '10': 'globe thistle',
+ H  E! h6 u) q1 |3 `4 ` '6': 'tiger lily',
. Q5 Q/ _. a/ O# b9 e0 u4 O; C" } '93': 'ball moss',( Y7 D/ Y! G: {; A* T
'33': 'love in the mist',
0 c' f' [/ G$ d* I, `6 r '9': 'monkshood',: D* B) m; Q& p1 Y3 W* ?+ l
'102': 'blackberry lily',- y$ K. J: [' E) o, G7 k" o, P
'14': 'spear thistle',
. c1 f4 r4 O$ s7 V2 S4 | '19': 'balloon flower',
( T5 e$ I7 B' }; X7 g; l '100': 'blanket flower',
% d4 \8 c) x$ ]+ x/ J0 a5 v7 }8 T '13': 'king protea',% F/ F% i" O1 j6 b. {
'49': 'oxeye daisy',
/ m, R; N& Y9 u. P1 }! A* h '15': 'yellow iris',: S3 l2 s1 v: p0 \
'61': 'cautleya spicata',  D! L6 l1 y# b& ?9 H3 M
'31': 'carnation',. V1 Q$ C) R/ ]3 I/ D
'64': 'silverbush',
! i1 W. m" o; m  w. |- N# E; o( q+ ? '68': 'bearded iris',
. A2 H$ y* ?4 a9 g& b1 X* l* W& v '63': 'black-eyed susan',
( r# W# d: v% [  R6 ^ '69': 'windflower',
6 D" M) }; M' m5 m1 w+ x# R '62': 'japanese anemone'," g6 Q2 x: _) W; U3 J- y0 F
'20': 'giant white arum lily',2 {# ^9 g. `- C3 u
'38': 'great masterwort',
# S; O- Z( I. e9 e '4': 'sweet pea',
5 ?$ q7 x, M- X2 q' _ '86': 'tree mallow',
, F+ b4 a! g9 k6 u '101': 'trumpet creeper',
# b; m* j( L$ ~4 E7 e, I '42': 'daffodil',
# H& a. F+ g: j& m) j0 w& C '22': 'pincushion flower',& n; W( u5 S- q! a9 T  J
'2': 'hard-leaved pocket orchid',
2 Q$ Y2 `. @1 y3 {0 Y$ k; B9 [3 ? '54': 'sunflower',
6 P5 @; i4 M) I: x '66': 'osteospermum',
2 l8 [" A% P- Z, I '70': 'tree poppy',8 z" ?  f) Q: k9 e+ ^
'85': 'desert-rose',
6 }/ A. l) o* S2 R  u7 J2 C '99': 'bromelia',5 n8 ~+ u) G- E% ?1 P8 Q/ h; }
'87': 'magnolia',
0 U5 w6 ]4 {9 _% Z, r( z" Z '5': 'english marigold',& c9 a) t# `" _* ]. B& v. g/ V
'92': 'bee balm',
% J" j! J- u* a/ H '28': 'stemless gentian',
+ n5 f# e9 ?6 j '97': 'mallow',; m1 r* V' [3 _. G0 _# |
'57': 'gaura',. C/ W5 M, D) G* P( R8 p
'40': 'lenten rose',
, @8 f/ T& B' o9 {( Q" Q1 q: X '47': 'marigold'," h/ J; O3 ?3 ^- _& q9 |2 S* i9 ]% ]
'59': 'orange dahlia',
* C- V2 \; i$ |; l '48': 'buttercup',8 ~6 z* i3 Q# B0 A
'55': 'pelargonium',+ E" C& e7 q. Q* h
'36': 'ruby-lipped cattleya',0 \, h9 J- r3 s( z& ?1 V+ J8 {
'91': 'hippeastrum',3 C" B: E' i/ u1 w$ ?
'29': 'artichoke',* h( c: c: @8 h
'71': 'gazania',1 i; [- M$ v' R( f: [! ^& k5 e6 [2 L
'90': 'canna lily',3 g+ K% a7 u' }4 Q0 ]& S
'18': 'peruvian lily',0 x7 E. C, C! O7 }2 X! U. [0 U$ W
'98': 'mexican petunia',& S% p  c7 v+ \& h9 o
'8': 'bird of paradise',
7 C( |# B1 k  _) ~9 N9 d( k '30': 'sweet william',% G6 n- R- g% o$ P
'17': 'purple coneflower',0 R! n: @# ~+ \
'52': 'wild pansy',
  q8 j! u# B7 p5 R# t '84': 'columbine',/ {; I- ^& S2 m$ V4 ~' f/ k" o9 e
'12': "colt's foot",
1 J( F: j2 a& H+ s& t/ U, z& G8 O '11': 'snapdragon',2 Z+ Q1 a8 C* k5 B# a/ N
'96': 'camellia',9 O( p3 R/ x1 k/ w- |2 v
'23': 'fritillary',9 j. \: l* ^" s0 J! V
'50': 'common dandelion',
) I* m, F* M* y$ B9 e. d$ ] '44': 'poinsettia',4 p4 O" @( N) L( J! D, @
'53': 'primula',4 U' j8 Y* W5 \1 q* b
'72': 'azalea',+ z, |! H' a' I+ S  S. A& Z0 W( {
'65': 'californian poppy',
$ ~. l) a- n$ [% S  R+ O' c '80': 'anthurium',3 Y# h. t0 v) y) A+ ~" Z6 o8 M
'76': 'morning glory',
4 z7 ]5 ^+ V; @) T1 |8 ] '37': 'cape flower',
7 ]7 M3 ^2 T% n3 Y# c '56': 'bishop of llandaff',7 m. d; A8 j; J8 |% Z+ O* K7 h
'60': 'pink-yellow dahlia',+ {3 f0 {" k. Q2 M& i
'82': 'clematis',: t8 d! f' p4 p  P6 l: {
'58': 'geranium',+ R$ C) [2 q9 j- Q: o0 l% ^
'75': 'thorn apple',
; O" n) s: h' B '41': 'barbeton daisy',% M! k7 y5 s& K% E3 q
'95': 'bougainvillea',
4 X& U7 K' y* G$ ~% \8 V '43': 'sword lily',& Z# D: \2 l' L+ o5 p0 c; X
'83': 'hibiscus',
5 t" W- R; _& c1 z" a" q- y '78': 'lotus lotus',
" b) c9 d: T7 x5 `8 u '88': 'cyclamen',
0 S7 I2 w4 X0 Y4 B2 _ '94': 'foxglove',
0 `" T, q# }7 f2 @, h '81': 'frangipani',9 C7 L* V2 ^6 i4 O5 ?2 k
'74': 'rose',
" s' V) M( e# s; X8 [1 s0 S  R '89': 'watercress',
5 V8 N4 z7 ]$ B8 ?; ~. c1 f* l  Y '73': 'water lily',! b5 H9 P! W" ^2 j
'46': 'wallflower',
1 r  {" @: x2 `$ K. V5 [ '77': 'passion flower',
. f+ @0 Q+ c, ]/ x1 u '51': 'petunia'}
7 m, S" m. Q7 L8 Y0 \6 l1 Y% y8 N' y6 `; [
12 x( I2 z) p5 K3 R
2% j& Y- r4 y: x+ r3 W
3
; x( k0 p; N- i) z" H( |; U# ]41 v; F; G0 X" q: K9 t, J" f  |0 K
5! i5 H  e  W" k; S% \/ _  [
6
% ?1 e2 D. d; M5 z% Z' H  E78 Z1 I" D8 M" i! {2 _
8. U5 E4 S( U) W& f2 \( e% W6 ?
9
" r* F* x4 @  W1 b10( O4 O' K7 Q# g. F+ r
119 V' z4 }: ~. ?& A6 ?) G* ^8 u
12
* N3 }& ~3 g. c0 ^- b13
' v+ T7 M! i. u, v8 Z14
( u8 m6 R  f  }+ F15
) `3 n9 M1 h6 B16# J: p2 `3 H: w
17! Q' k7 B, b& r1 `
18+ s* V* a: N/ ], o* I9 d2 x. e3 ?
19
1 I% d0 _4 f6 z' ?! S8 r20
1 H! ^1 Q7 ~  e21% c# f7 ^4 ^1 p" R2 R( @7 Y' P
22' x, e. c' w' v+ x/ j
23
3 ~, @. B0 F: ]/ \  `) F2 g, }24& c5 G, x5 [; K+ N
25
' z" Q- D3 e" x26# L6 s# f7 F1 l% k8 Z  X6 g$ t
27( U& F9 B  ]  }: l* r' R
28- b0 [" v- K$ C1 W' D
29; p. w8 _; B1 `) D1 q
30
9 ?5 n: ^$ m$ |! `5 `313 d% s, S$ S- r3 {4 }& _) C! j: S
325 X/ k0 N6 Z  \0 K" l7 f. K" o- Z. L# n) [
33
* r) v; V& @6 ^% F9 ~34) f  N) i2 n! y! u# t
35
* \! w% U8 u2 ~4 l( V/ z36
8 s% N4 k6 z, z5 c& w378 u, f: o' [/ c6 i6 ?
38: i( u( Z1 A% M( G, U7 v
39  _, N5 {- j$ D2 M3 W& U
40+ z. N4 t+ S. L, j* G# L
41
; B( s( Z; J6 r' D, H42
- i. g7 w0 w" L43
- ]- l9 F. m4 Y1 D! s* \" F44
# d$ b# G/ i/ b& n# d4 L45
" z! j  j0 m7 r4 p46: a1 K9 F; S1 z3 Y/ r: V, @
47' s1 C9 S  x* V- b! o7 Y
48$ x3 G" b" G8 S9 L
49' B) f- x& W7 _3 b; h
50% n; X5 L) N$ ?* Y+ D
51
- w: j: ]& j% Z( _7 w8 J3 S2 l' ?522 p0 P- t: X+ Q* X
53
! @  O. [2 q6 M) P7 }9 j54! }$ ]9 i, d1 D% \3 h# k
557 l5 e8 a! |! v% s) J
56( n8 i* [' F: C& C! E5 G( p4 E% k5 z
57
. {/ D; o3 b1 o/ P. q- P2 R58& c/ ~1 c" I& u4 N: \/ M
59& I( H  m$ Y# \4 i" a/ @: B
60
' m7 Z9 N& t. S1 j61
% B; M& p& U# B) H- \' d626 [3 d8 f5 b; k* W4 H4 _
63
% s! [9 U' m" p% u) ~. t% {6 o64' t- n) y( S& G, r
65
% D; q: C& m1 ]. e+ [' n66
% p0 V3 {$ D) W670 [% F' m' X+ Y( w
68
- E; y% B  T: I69# E; A" O# y$ N, U
700 k: a  \( Q; r- L
71
- ?+ f# h5 I' M1 E) j) {2 J& a72
/ i2 l* Z  J  l- k6 E73
& M$ j5 \% i& M- J3 U74
7 S/ A: m, T, K% a75
* S( x6 j7 F( q$ X. U" \3 d5 s/ S4 c3 ]6 L76
% j: O/ O" b3 P8 U( J& c77# Q% P4 q8 O1 W( `3 y3 w8 S3 B
78
0 c1 L* ]1 D6 N4 ?2 Q/ N5 m79
, Z  H8 j; f0 a, _( Y0 i  g80
% C) ^5 |0 N2 y0 T81  k& D9 H( o/ K+ Z" z% C
82
4 S4 m3 Q9 R- r" E834 D* J, r3 D' M6 C2 Q
84' U7 w0 p- z" O! [& U
85' ?; Z) v6 ~  ^; L/ O0 D% u0 E
86
. ]0 i: m/ |8 |* e' x$ N" ^4 c! h87
1 x0 i& v( A- C' g; y0 ^88( R4 ^* l; u8 t3 h  k
89; ~% \. \& @: D3 \- E7 L2 y9 {+ D* w
904 N% N7 S+ I% A8 b3 q2 D/ K/ w9 @
913 H8 D# b( _" l  f' c% q. Y
92
5 h5 w( Z. w7 h9 ^93: \6 E# k3 J& V
944 U  P* e, q0 S% |# b, A
95- \& |# B/ e4 m8 C, S4 W, A8 c
96* x# d  e% J" w! `! J- l9 k) d
97
! {; b. u+ D6 P+ P) i# m98
8 I- l8 b8 W/ n3 B/ Q3 F99$ i2 `3 Z8 O) v+ M; U" E7 Y
100
4 @8 u: |8 ]" K101
" L( E+ L( o. q; S102
% I3 V2 K2 E5 l% |4.展示一下数据
2 R; z( y6 R+ A2 Q; ]+ fdef im_convert(tensor):
7 ~$ t$ H5 F6 F3 d( Q" b/ g    """数据展示"""
) H) I. ?" J9 G1 s8 E    image = tensor.to("cpu").clone().detach()! y7 v/ @% `( i8 K3 M, j
    image = image.numpy().squeeze()6 x7 k- q( n6 x3 }0 p; d+ o
    # 下面将图像还原,使用squeeze,将函数标识的向量转换为1维度的向量,便于绘图
4 G8 J: m. S* e% `    # transpose是调换位置,之前是换成了(c, h, w),需要重新还原为(h, w, c)
2 `  O$ @' u7 S- k$ A( T9 D9 _% v) J! c- R5 H    image = image.transpose(1, 2, 0)
3 w. ]% j" C+ H3 J    # 反正则化(反标准化)
; c" C/ I5 f5 z    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
# ~7 S8 V9 }9 M# {3 D' m
* D1 @0 ?  L9 V% L    # 将图像中小于0 的都换成0,大于的都变成1
: u1 ~: k( b/ Q' k    image = image.clip(0, 1)
4 l: T5 Z/ |8 f/ O3 |
' L6 w1 L9 ~! x, A& M  c4 w, c    return image5 q3 o8 J6 ^/ O/ D, y0 g0 z
1
! B8 i! H( D" B* @" U" k4 I2
" i" J* q8 t( Y4 o" X* T3
  Q! c8 E. [  b8 z; Q. m" n  o4
! q- U1 H" l* {% {9 A" g9 j5
4 N  S5 g6 O, l3 @8 j. l6+ ~  ~4 |0 B8 t, h% D; Z
7
$ W% [, D3 b9 o. i+ F1 |8
7 T& S' h/ g! ?/ a( M9 j9
  G& e. y+ D0 K* P6 G+ N1 H103 I& D, {9 T# O; Y$ X
11: ~1 A, Q: w0 u6 r7 w; n
12! \( ]1 G9 V) M8 o5 X/ w
13# E- E- z- w  A6 E2 m
14
# \3 O) j7 s$ x5 c& \+ w# 使用上面定义好的类进行画图# J0 p9 j% W: {: t! J
fig = plt.figure(figsize = (20, 12))
5 |% a4 {; ?; vcolumns = 4, Y3 M5 V  m, G  v$ R' I
rows = 2) Q7 B  b! y' r

6 Z4 u& S- S$ d. u# {+ J8 ~# iter迭代器8 U8 S9 G3 Y1 u
# 随便找一个Batch数据进行展示
( R9 }! t5 u6 c9 A. w2 wdataiter = iter(dataloaders['valid'])7 J; l8 \5 u& s: Z4 `
inputs, classes = dataiter.next()! J4 I$ T! M; V4 e
  i1 s/ E1 D* {  \/ p
for idx in range(columns * rows):
6 L( t% e5 u% `! W8 C( F1 x2 s' P; N    ax = fig.add_subplot(rows, columns, idx + 1, xticks = [], yticks = [])
) p8 a3 T: q6 f6 I8 a    # 利用json文件将其对应花的类型打印在图片中
8 ]) r/ }( v% n8 |) P+ b    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])4 V0 ~. A, @) C% @4 a4 L
    plt.imshow(im_convert(inputs[idx]))
0 k' ]: k7 W8 c8 f% B: gplt.show()1 }0 @3 o8 ?1 h0 b- A: J. e: _

( ~- x1 ~3 U) v/ c& z2 T& D1. O* ~+ H/ I- b9 Z' X
2% O5 ]) Y$ X% p% Y
3+ U4 r1 f6 g( h. F  R
4! e( f: b7 U0 o+ x) v9 ?
5  g9 W4 ^5 z) M6 m
6" C+ }3 T" C7 t- N
7
9 Q" ^5 B7 @, ^85 c7 ?: }# A% L! m9 @
9" D: [: ?( T1 M& k/ T  A
109 t/ u: }2 O" ]4 Y3 d
11* O" I" F+ Q& U1 }4 `4 t$ g/ |
12, \# z+ _% ^7 W7 n! b
13
/ O7 a) o8 n! \  ^9 m$ ]7 }9 s4 ]% p0 K144 k5 y  W2 _" L0 l
15
  g) R. I1 s0 k) l: K16
% p1 ?; O; n' @& @! W& }- c- V# |8 k
8 f0 v( n! n) N2 g2 w
5. 加载models提供的模型,并直接用训练好的权重做初始化参数9 O/ b) r4 G( M# ?  q& G/ U! K+ p: ?
model_name = 'resnet' # 可选的模型比较多['resnet', 'alexnet', 'vgg', 'squeezenet', 'densent', 'inception']
& f% v! E: G  D4 G+ U7 j* e' M# 主要的图像识别用resnet来做
- W1 \4 u3 `4 R  T& `: _0 }# 是否用人家训练好的特征1 D9 X  i+ \* L
feature_extract = True
- g: \- {, l/ m/ O0 F. a1; J6 q) T' N! k
26 C* A! {+ d/ s4 C* Q' g
3
  W  B7 y; D4 W3 E4
2 v  y7 C/ T; J) B( s' ~7 W# 是否用GPU进行训练2 \- V: N3 P2 ]6 L& Z, w0 w, h' c
train_on_gpu = torch.cuda.is_available()8 o$ |1 S& Y, o. [* U( S+ r) }
. q. i. E* _8 R1 P
if not train_on_gpu:, o# _, R/ A) p0 m8 y
    print('CUDA is not available.   Training on CPU ...')
$ a0 v. x9 n1 W9 y1 ~) Qelse:4 g" f( @/ p# ~, v1 Y( p$ o
    print('CUDA is available! Training on GPU ...')2 _( C$ H& G8 c4 G' G% K' d8 c
5 k! [! m+ d( r  v
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')8 Z- n2 v* }' j8 k/ d/ P: h
1. |4 [! v3 o7 s; Y% O
2
* u$ N( q3 s; S2 u# D) A3
. Y3 z) O+ X5 l- n7 V. \/ @4
# }5 w3 a) q5 e8 i# Z0 W' e5! {) n; B- \: w7 Y
6, Y2 r3 t7 a; H: i& y( h) g
7+ c2 d9 j- l; [
8
- e/ \5 E4 t% ?2 F9
  [  m* s' ]. J, b2 g5 FCUDA is not available.   Training on CPU ...: z4 w2 x' i$ D: S  y
1- S* M# S6 t- z  H0 S( n0 ]9 r6 u
# 将一些层定义为false,使其不自动更新
7 m# z9 l0 r6 y$ E% W3 ?1 Vdef set_parameter_requires_grad(model, feature_extracting):: T0 g4 X8 X1 O3 g8 _
    if feature_extracting:
0 p! x3 k' ?: A8 U$ }% g5 [        for param in model.parameters():
7 `5 n+ F" M# x8 S: S$ d            param.requires_grad = False' v# k' D$ D9 D9 O8 z9 s  K
1
) Y" p* h7 f: l2; n9 Y% H3 D$ Q5 l6 q- C
36 l! Y' [8 [. v: b( C* y9 d) p
4/ |5 n/ V) h' v
5! z% g0 N1 t6 o3 G6 U
# 打印模型架构告知是怎么一步一步去完成的: _3 C, ^/ d/ y
# 主要是为我们提取特征的9 g8 Z) H( u+ t7 z( n7 _7 V3 k
' s7 K4 A# \, o7 x
model_ft = models.resnet152()
9 z  r" c1 d6 ?" v8 X/ K/ E  Tmodel_ft4 Z# [! G0 P  g( l2 W6 k( h) o  K  x
1
. w4 `+ p$ E9 ]# U: i6 l2; c2 w) D& y0 O, f4 o7 `+ ^0 D
3
" [% i+ _! z3 l7 e) Q6 l4$ R" M& p/ p; f- z: y  u2 o
5
0 W( n0 K' X0 z, D' ~ResNet(5 w9 O6 s6 F/ q
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
7 U  Q" n3 n& d, M. u* Y( K8 m  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
7 P( y, B- z0 Q8 Y9 J  (relu): ReLU(inplace=True)' L: h2 h+ r3 |1 x$ k0 ?; X% J( W
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
1 D2 q6 u9 n- `& l  (layer1): Sequential(
0 @6 m3 o5 n/ }  n8 a2 J0 F    (0): Bottleneck(2 r" w/ A. Y* G) N5 O; a1 d" a
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)4 G) u/ a5 [2 @: Q. g+ @. T
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)3 {1 L- V! G3 W& ?6 m% F- c, K
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
4 L. V; L7 M. ^4 P      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)# `% v1 s  s7 ]: s
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)8 p0 ^' a1 K3 k# v  h8 t
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): a3 `, [: [5 ?# \) _, E) Z, D$ M. P
      (relu): ReLU(inplace=True)
7 ~( R( H5 i  u" e      (downsample): Sequential(! h" J3 j0 D, T$ ?# c
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
$ d% J" D% T3 O& c/ o        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)0 g0 c' x8 p  ?; [
      )
  H3 s: {8 t, G. Z    )
& v9 Z- v8 X3 ?( Q  C) @5 g$ l中间还有很多输出结果,我们着重看模型架构的两个层级就完了,缩略。。。5 P2 X; Q% b- E7 P5 |! \1 W
    (2): Bottleneck(* w" c6 O( V1 E# p8 g/ l6 a# Z
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)# @% t7 V5 ^, e3 M* y
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)# S: Z% A& q# e0 W1 ~& T( d
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)  N+ s/ i. i* c5 n; W
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
! \, v+ ]$ I# M9 X+ e* p      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False); w9 v% t6 d/ ~& T/ f
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)* y% R+ J! r$ P4 N& |8 p$ n
      (relu): ReLU(inplace=True)9 b1 _) b1 \3 |: ]; B
    )7 B; B3 \" Z' Y/ {
  )
# S0 }3 v9 `4 Z. ]+ @% Z  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
5 b9 ~3 T2 {0 y  _# w  (fc): Linear(in_features=2048, out_features=1000, bias=True)! N" n/ \) N7 S3 u6 a# \8 B
)8 v$ q/ I4 C4 [' ~3 X- n0 `" Y
) j& _3 {$ A. N6 k" E
1
' r/ s8 S$ w0 o# k* g2
5 |2 `1 h+ x1 U) k9 I( H$ f# m36 {7 |  U6 n2 C* D; I' ~
4" B5 L0 P" y! ]
5; p. ?1 x; K7 A: r
6
- M; I0 C2 G3 S+ [! X: H3 R  @0 g" X- }7& h' G# h6 a* x) q
8
- S1 ?, g4 \3 e) }1 y9; g5 K0 J3 \* {' R* d/ v
10
0 x/ j: ]2 `( i11
- O, Y2 s8 z( t& M) r7 \12
- Z( V4 _( u, U13. j! {! O3 {7 n
14
. i8 p# I, \) P# o% m& x15
+ ~7 i1 Z; S- @. F  g$ j16
5 K, _: ^8 I. r7 W( V17& A+ r- @  k" S& }0 @3 j# j* ?
18
$ s' y) y9 B: S: T6 f' X0 `19$ M5 F' z- f' `  G0 O8 f$ @+ H3 B
20( w6 P8 A, a% [% }
21! }; x, z* k; v2 h
22  d( {" `: T2 K6 G2 _$ [
235 [6 s. A' _$ A# n: s$ J! U% Y3 @8 ]
24/ ^  a2 i, y) }& w7 m' q( o
25# u. U" g; c8 w8 M2 m5 |
26
! u, a7 F+ Q+ F. x8 G. M7 d27
; ~# q4 l) w  \' }, j  N28
6 h+ F8 b& M  g( Q29
; u2 {1 w9 Q2 N. t5 }7 d+ G30
8 `$ g$ L. \3 w) j+ L. |+ R319 s7 [2 H+ X' {' v  V8 d; H
32
" P: r$ F( |  c339 q- a: U+ F0 C+ u/ W% z
最后是1000分类,2048输入,分为1000个分类
  W% _) t: p( k( A- m: A而我们需要将我们的任务进行调整,将1000分类改为102输出
; k/ }! W0 I- ]2 q
# R- f& @$ E/ P' w' a5 u6.初始化模型架构% x$ _1 {* F5 N0 I9 w* j
步骤如下:0 @2 R  Y# K4 [$ z; `; X, O3 h% B

/ R: O4 ^0 A: I$ {8 k. z将训练好的模型拿过来,并pre_train = True 得到他人的权重参数5 r( L' k& e0 u$ {
可以自己指定一下要不要把某些层给冻住,要冻住的可以指定(将梯度更新改为False)
+ b* ~1 A! |# f# `无论是分类任务还是回归任务,还是将最后的FC层改为相应的参数
7 F* }/ d' t) t; S! ?" @) @官方文档链接( w( q+ s. P7 F; x2 o' u, b5 }
https://pytorch.org/vision/stable/models.html
' j6 D' b3 B- |7 q  o! w8 o
0 F% y/ j& Q* N& D4 ]3 \# 将他人的模型加载进来+ H1 ?: W; K( H- C. k
def initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):
) N6 f! P9 ?0 Z& y: d    # 选择适合的模型,不同的模型初始化参数不同
  E3 ]$ J: {5 p    model_ft = None; k  s4 b9 X4 [# N% ?1 b$ L
    input_size = 0
5 i  n: N. \& c$ _- {5 I3 I6 T) P9 U. C, W7 E) d
    if model_name == "resnet":
9 m0 z' t( w7 j! d5 ]* z        """
) Z& I$ Q7 ^% T4 l0 c        Resnet152
  D2 Y5 o9 @6 D; R) O8 F% L7 B$ `        """( l$ `3 W+ C0 l- v
0 [4 k% n: ]6 P6 N8 h
        # 1. 加载与训练网络
" {3 ^6 E) T% c7 U        model_ft = models.resnet152(pretrained = use_pretrained)3 R9 z$ F! b8 G9 V+ Q
        # 2. 是否将提取特征的模块冻住,只训练FC层9 u& u8 Q; \. P% k
        set_parameter_requires_grad(model_ft, feature_extract); F0 o! [* l2 ~4 W/ z" X. m) z
        # 3. 获得全连接层输入特征
6 F: h* ~5 V2 _' l! a        num_frts = model_ft.fc.in_features# z; }/ W; l: z2 W
        # 4. 重新加载全连接层,设置输出102
" j& J! `1 d8 p& [        model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),# J5 S. d' I! ~% o: Q
                                   nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为11 i- ^2 o3 w& P& ~4 V2 U7 x
        input_size = 224* D: c3 p: z5 n- g, S1 f6 h7 H
* ]* w! v6 S) C* p! g% ~% o
    elif model_name == "alexnet":
  y, m; N4 `# M+ H6 n        """: ^" T0 b. l6 G0 t
        Alexnet# N9 c; E" m+ F7 {: v
        """+ Y8 y8 o/ K7 @9 w1 T" h) O
        model_ft = models.alexnet(pretrained = use_pretrained)
* X1 g$ Q' t7 _  X        set_parameter_requires_grad(model_ft, feature_extract)
# v$ I' C4 f- x) Q1 h& S" |" ^' F8 g8 y9 ~
        # 将最后一个特征输出替换 序号为【6】的分类器$ z9 q7 h- v: s$ A8 l
        num_frts = model_ft.classifier[6].in_features # 获得FC层输入9 L9 ^& [+ S, I% P0 I
        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
! y4 p: l! b$ V1 ~; n( X5 s4 A        input_size = 2243 }" `, m8 H! R: j/ l" F. M
8 w2 u+ w9 ?+ E. @4 X9 u4 ]
    elif model_name == "vgg":
% V: @! n1 j6 _' A$ b        """0 e* p8 v/ D, b, W' r
        VGG11_bn0 L$ a& n  K. B  F! K4 }
        """; A/ B/ E7 ]* M# @- f4 i
        model_ft = models.vgg16(pretrained = use_pretrained)
& U- @" U) ~- s" G- I8 x        set_parameter_requires_grad(model_ft, feature_extract)
8 u0 |0 A% e; w8 V) L' M  u        num_frts = model_ft.classifier[6].in_features
9 N, |' J, s& _; X, `" W        model_ft.classifier[6] = nn.Linear(num_frts, num_classes)
" L* K# r" b- X  i$ L4 B3 |3 O        input_size = 224
% u& [+ h( w, c; [  d3 I: M% I
8 L: Z3 F6 M  E5 `6 x8 r    elif model_name == "squeezenet":5 D( G( }& `% S! R3 x6 n* ^, o6 f
        """
- p1 L9 B) J) j        Squeezenet; Y& @; U* O2 F1 ~) j# i" Q/ `
        """
2 C- |. L6 L/ t6 T$ ?8 n        model_ft = models.squeezenet1_0(pretrained = use_pretrained)
& i6 C& T" N: G) f8 D3 ?        set_parameter_requires_grad(model_ft, feature_extract)/ q# a- d1 G4 r0 K# t  d) ~* k
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1))& E4 p# c* G! Q. b$ c* Y/ y
        model_ft.num_classes = num_classes( r, r/ @$ W8 p: t9 z* H9 J
        input_size = 2249 ^( v% {6 z' z! _" d7 ^( Q7 L& J+ {
! Y5 F. O6 J, Y
    elif model_name == "densenet":
" v8 t1 T9 ]# o2 C  G( _        """% W: B4 ?- @& j7 f7 {# T4 n
        Densenet7 G+ C/ M# o5 ~9 ]* P! k& p( W
        """/ |1 Y# H$ M1 [' K
        model_ft = models.desenet121(pretrained = use_pretrained)
9 m+ ?0 T- b. f/ Q" P8 o        set_parameter_requires_grad(model_ft, feature_extract)
# S8 O& ^8 i$ @        num_frts = model_ft.classifier.in_features
1 b; v7 _! i0 h4 w* p5 C        model_ft.classifier = nn.Linear(num_frts, num_classes)
3 p5 g9 f5 h4 X. H        input_size = 224
& u8 D! \: ?' K
+ F; c8 w" D8 g, X  F    elif model_name == "inception":9 a2 A3 ]0 y+ f; o( l, X+ u7 j9 Z
        """- S+ H- B0 l$ `' V7 r) q
        Inception V31 u/ K0 b, S6 j* F3 s
        """1 L+ A, P8 Z+ D9 `
        model_ft = models.inception_V(pretrained = use_pretrained); F$ I* n; R+ \  S4 G: I, a& t
        set_parameter_requires_grad(model_ft, feature_extract)
- _* O3 u* y* ?7 O$ ~8 c4 @
5 a5 X" a" X9 q. \! R        num_frts = model_ft.AuxLogits.fc.in_features
2 v3 B% u9 m) O        model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes)
8 {: C1 k0 H8 y# e6 ~
! _! }$ I; _- _1 `        num_frts = model_ft.fc.in_features7 V; d% R1 ]9 ]
        model_ft.fc = nn.Linear(num_frts, num_classes)0 c0 N3 J/ o+ I* {- y
        input_size = 2995 ^9 B4 R3 `1 J2 D! X6 C: V# o
: z& P: k: I, T# H& g* u6 D. J2 j
    else:
6 m- U( s, X  s& E2 N        print("Invalid model name, exiting...")
; M9 U  K9 s1 j% k5 r" P        exit()
5 {! \+ U# i) m3 M0 P- |  A2 T" q. `' E, |% {0 z7 Z. }" e
    return model_ft, input_size
0 Y$ B8 m; y7 u' d* I* m" }1 p' D) Z
# z3 `: k! I( a" e% }/ v1
+ w: c9 B3 H  i2% i, `1 L$ c5 ^3 D9 T+ k1 u: \
3' d) U% u+ [: S& n4 J* P* `
4' [4 f# D6 Y; h8 R
5, t$ [3 g7 Y. y3 B0 w. q. P
64 H8 R: ^- U+ J7 f& M% ?  q
7
. f5 r0 l1 T, v/ y* k% a2 k8
( I+ G+ {# t0 W/ [0 r) `9
/ y6 d: a5 M+ I2 O10' n3 W  A5 v% W6 |  S7 R
11
/ c  z/ ?2 b$ q( M$ Y125 K) N* q0 ~  w1 I7 a( o) n! E+ w
13
7 Z4 q' K& g8 y' ]" h9 `( F7 U14: V; S. o" j; i+ ?
159 M8 I6 o5 `9 d; U2 J
16' L5 m3 J! h3 B8 `! ]! T/ h6 u
172 g' @& U0 r" B* b  U
18; x7 y7 h. Z. x1 {
19
/ u! d4 ~' L6 R20
% ~' Z' f: ~; j: F/ P" x! [- Y0 P3 G21
" N8 |2 \; g2 r) w; M% c3 V- g22
( y! ?" b8 e0 N! R2 \+ T236 F3 l% [' G5 \) @" f
24
4 j% Z8 [& K: a+ d0 W% U# H( T25. M* t# B, Z& S  G, P; ]  R3 z
26
' s3 ?  B: z+ p. j2 v1 m. ^27/ |9 B! Y1 P  m: f
288 n3 g. M) s9 T9 N
29
0 O  X8 k7 h. v305 c8 V  t9 _; N
316 ]) u  u5 q* [& ]! N# ~- y
322 Y) B- j" e, Q$ @) v; N8 J  M
33, I5 w- Q! n7 l6 v4 W2 e: Y4 I
34
* N! ~& Z' I/ U; K8 K' O' u' H' R35
+ v3 r9 T% @' }; y* h2 e9 I$ W36
, u" }( ]& }% q+ K37, X4 ~! C" i6 w$ p, |6 l$ d
384 U# |7 O5 a) Q1 K9 @$ s
39
5 y. p, d- A8 K. r( D$ F. x40
. ]3 k+ }+ W4 ^41/ r  y6 u7 U3 k. ^2 j) N
421 C3 j& w3 }2 |- u- I; b
43
/ S' y/ V: M5 a9 P. O" C44
$ B" @5 M  @( t4 N5 X9 N) d45
$ y& U8 K8 [# W$ R7 z, @: E46
. C5 i. g" K, H47
" Z0 Y( J3 b# E# G2 W4 _2 C- ~) N: Y48
* j, _  s$ ^/ V" r' ~# h49  E2 Q: C& E1 ]4 r2 B4 R- I: }' I
50
1 ?0 T. Z8 v4 Y0 ?7 {51/ H) q/ G+ m5 X$ r
52
$ Q* y8 K8 P9 R) W  L/ q6 [, w53# n6 W7 S) O, X8 b% V6 B
54
+ ~+ V& {' c* X! M55
9 e. P3 l- y( B; c' W56
$ c$ o: z" E: L5 S: A57
+ t0 d* L" C/ }$ ]4 R9 r/ i: s/ G& |58) @2 `+ R, P' p' h" Y
59
" A* O; G. E7 x# t! v' u# C4 g60
7 q& d# l$ d* R+ L+ X61
* [. U, k; x" j( R& O+ e. x, _62
$ M- D( Q0 o5 ~  U( @0 i63
  I$ K4 `. l! z3 @64
/ J) X  ^% A% ^: W* H( v656 `6 J- e+ Q  w6 |& p8 e7 P
66
: b4 ?" H( p3 y# y& m673 I. M& @# B$ c# f3 o
68
0 Y/ S0 O" N8 U& t696 ?% Y  a7 u" N+ a
70/ e* b* J9 g/ B, |  O
71
# X' v9 X+ `) p# \, K4 t725 }0 e( H: A6 G; H9 n
73. G$ i9 l9 x9 B+ Q
74( Z4 H3 l. @. A3 U0 w$ G9 c6 A
75- m0 M( \% d. T/ f: q- f# A$ M4 g
76
: X% A8 W: G- g3 {, G- s77
: p+ A. l, ?' p0 o* Z5 b78. y2 G1 F; ^/ _1 p
79
9 ]3 w2 g! q  R4 H80# \$ F5 x1 {$ z3 h
81
0 Q/ Q: E+ ^2 ~. ^' d; i& n82
& f% ?% C. O) k( `0 [6 s834 u6 @( J  T& |  O0 K' _- H. A
7. 设置需要训练的参数
$ Q" [; q& u! `* m$ `8 v+ q! q/ G# 设置模型名字、输出分类数
7 f0 `# t3 c3 e; |* Y$ emodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)
$ q) J, M. Y) ^8 t) y- [, u
' C* d7 j4 U& S' O# GPU 计算
5 ]: {: r8 ?' H* @8 _: T& k  {3 rmodel_ft = model_ft.to(device)
  l7 U; n- G4 \. @; M4 B4 I
0 J  T7 E/ d1 d+ T# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取& L, w2 c- ?5 f, M% u" S
filename = 'checkpoint.pth'
1 a* v% F8 c" z6 C4 J" P# u
* }! \# J" L5 I- ~# 是否训练所有层- y- M& E& ^8 _2 x& H5 z7 n
params_to_update = model_ft.parameters(), C% O( [1 R, h: R/ r7 E& a+ X/ `& w
# 打印出需要训练的层2 J0 H$ Y) r: ?! ?: R9 g- G
print("Params to learn:")
7 t# E2 {+ G' r+ oif feature_extract:
1 V- x3 x2 Q% }    params_to_update = []
3 S! }! l/ P- x' J& w    for name, param in model_ft.named_parameters():5 U. N8 [( W" l" b0 V! l
        if param.requires_grad == True:/ p8 c9 L. M; C' c# ~& [7 \
            params_to_update.append(param)
/ T: U  L7 r( d6 `1 b# K            print("\t", name)% d1 k; U6 M0 J9 p
else:# ], b" t: a9 ]  r# S' x+ t: Z
    for name, param in model_ft.named_parameters():
* Y  `- U- r4 N: j        if param.requires_grad ==True:
& Q6 P  V) C4 q% U' f. |/ F            print("\t", name)
/ J! V1 k# g7 W, T$ Z; g
$ L: h& @5 f1 B( [) h$ u" H1, L' x, `8 h( D7 c, [/ h5 f& x& ?
2' E4 c; ]: s# W8 Y) b( k
3
) U1 W7 z* e$ }* j* n+ Q+ A! H& |4+ o9 q1 J$ g- r$ O: r
5: a1 l% c9 n" s* T  C( x
6
# A& P  J2 Q( x# D7
, M# b* a) F# u% C; T8, d& t( s" L, A& s9 O/ S
9
, E8 ~7 h0 m% |& ~1 I10
$ e6 K; Z; M3 a# `5 I* |115 r- Q) J' Y8 D$ ?5 }7 E: h, b3 w0 i
12
! J  p4 u' J. O5 ^$ B% w136 i# F' ~9 N' r/ ?
14+ l4 {. S4 `/ z: a
15
7 k2 s: Q8 R' d1 p+ E163 L- V/ N+ w. t5 f) P# g( @
17
  L- ~$ F4 U( R9 x. B: M18. g/ z: C, s: ?) e- v1 h/ f: q* }0 X1 \
19! b0 |$ |  I( ?+ k! ?
20
% W: ], f! m8 j  f) o21
* h6 m) r4 `, x& Q8 d22  R4 z8 N3 M4 g5 N) }7 Q
235 |) W* O& Z- C' C/ y
Params to learn:7 K. N: Z/ ~+ e6 U  f0 `
         fc.0.weight% i2 |3 J& W/ B/ D- g, c8 ~3 L/ w
         fc.0.bias& t& b* B1 H' X# Y' S5 M
19 r% b/ b0 i& R+ I
2. d; y. |$ w7 _8 g$ N  g
36 M) d" C* R9 i# o  F8 v6 o
7. 训练与预测
3 W( {/ o% r% P7.1 优化器设置
7 r9 @0 J, x3 M: q- `# 优化器设置
3 e- H: b$ o& G9 V1 qoptimizer_ft  = optim.Adam(params_to_update, lr = 1e-2)3 p! w+ h  U6 b, `0 a! _, n
# 学习率衰减策略# m+ H2 v7 L, J' w: C) }
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
" P' O, X" i3 q; \$ b" @2 _3 Z" p# 学习率每7个epoch衰减为原来的1/10
  h" M* P' C7 z4 ~3 [# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算
/ `% x' f. }1 \/ F
1 Q: @* t4 [: F) a0 W: icriterion = nn.NLLLoss()
5 h# A9 s( A- m" n% l9 _. b19 U" f) q; k! F# Q$ V: _/ @
2' L. I4 Y  [1 j7 L! d$ N
3* q8 E( H4 A5 f9 N: A$ g6 F5 L) I
40 P' G3 }2 k# g8 a# {  i8 Q$ @
5
% n+ H- a* X" q' a% d6' X& a& [1 n- D1 t+ B! H
7
0 i4 N; w. I- n) p& w83 U/ y0 |6 i5 |9 O7 M% }" j
# 定义训练函数$ S! G6 q1 Z. a. J7 {
#is_inception:要不要用其他的网络" ?; f* N# ~( i2 n/ q
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename):9 O9 U" q) G% q" Z7 o0 I
    since = time.time()& H6 e* Y8 R# m; G5 X
    #保存最好的准确率
: W0 p, ]( a2 v8 |) G" ~    best_acc = 02 B- T5 w& _& m% v4 M
    """
  \/ o4 f2 `7 x4 F& F4 l( Q    checkpoint = torch.load(filename)& ]9 V3 f0 Z- N
    best_acc = checkpoint['best_acc']8 |# r9 K0 a+ H2 `7 T$ H
    model.load_state_dict(checkpoint['state_dict'])
3 f3 ~  o) F8 s0 r    optimizer.load_state_dict(checkpoint['optimizer'])
* W, n5 d% g" ~. S9 g( u    model.class_to_idx = checkpoint['mapping']
: F  L( d* J) [% M6 S2 K    """2 U- S/ a# [& J, E3 s% x
    #指定用GPU还是CPU
  ^% b7 p6 r# m    model.to(device)( |0 d8 r. N3 j4 j7 L
    #下面是为展示做的! g8 ?' o! o) p) K
    val_acc_history = []
( h- L" [. f! S) N/ A$ `6 u    train_acc_history = []
% e$ U- J' K5 C, E5 Y    train_losses = []4 H; P1 T1 ~" c4 P" I# K
    valid_losses = []' X- w2 u* T% Z3 F
    LRs = [optimizer.param_groups[0]['lr']]
8 k2 J. }# E, B8 \    #最好的一次存下来
; j- ~+ H7 T4 X* m. _* u    best_model_wts = copy.deepcopy(model.state_dict())% }5 J! b' H( V$ x( S

) c- w- k; H3 i5 h: D( B9 O+ m    for epoch in range(num_epochs):3 [# C2 F, g' Y% O- Q, B5 C, f9 Q
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
7 g- [: B( V1 |        print('-' * 10)1 K. s8 G. W+ [% b) b
( d  ?8 b3 _+ r6 C  g' v2 v
        # 训练和验证
: _7 F1 X! E6 S* l  f! E        for phase in ['train', 'valid']:# q8 a6 j7 o/ m# s. q
            if phase == 'train':8 h8 G/ h/ g4 k7 ]! F- O
                model.train()  # 训练
% i; ]+ ^$ w/ i# ~; H7 v            else:2 f" Q4 q9 w% B$ j
                model.eval()   # 验证. x3 C5 q7 a' f7 j) b, U: S& `

2 t; g- _; o; ^0 z: r% f) x. @            running_loss = 0.0
( H' g1 W  p7 |7 o: c8 A) K. I& v( [            running_corrects = 0
/ e, @  `3 {7 Z$ z8 y
9 r+ x' `# o4 M2 r* d  N: k) P$ |            # 把数据都取个遍
: z+ x6 p4 \$ @: D( D; h            for inputs, labels in dataloaders[phase]:
* g' F2 }1 d7 D& I+ Q+ H& f                #下面是将inputs,labels传到GPU
; \! v6 d2 G, v; ?" c5 ~                inputs = inputs.to(device)2 L- V6 N- f: \2 x: I; U0 Y
                labels = labels.to(device)4 K1 m6 c6 O6 o( T! _* b

( d) O. u( b' q+ F                # 清零
5 @+ `' e* x; n: q2 A) h2 E7 Y7 s# s3 N                optimizer.zero_grad()  Z+ J; |9 @' u  M0 e
                # 只有训练的时候计算和更新梯度
6 i( E" x" e5 v1 R                with torch.set_grad_enabled(phase == 'train'):" ~; `1 L) u/ S' F& D% A2 Y
                    #if这面不需要计算,可忽略
, {8 H- U+ f* E" T) J                    if is_inception and phase == 'train':
1 V" d% B0 ?! s+ j6 a3 I. E                        outputs, aux_outputs = model(inputs)
9 ]% c- `! T8 ^# G5 b; W                        loss1 = criterion(outputs, labels)
5 L6 y8 D5 P/ K& U5 K/ ?                        loss2 = criterion(aux_outputs, labels)/ o" f' c0 F1 J) y: B3 }+ K, o
                        loss = loss1 + 0.4*loss2: ~3 J& w) _% w0 L
                    else:#resnet执行的是这里- }' w2 X: Z  J6 z6 ^
                        outputs = model(inputs)
* ]* T% B6 e) e' p" B                        loss = criterion(outputs, labels)
0 ?* _7 G' {0 X) ^1 I/ y2 E9 {
! W/ |4 l% x' ?+ f) l                        #概率最大的返回preds
" y% y5 ]( w4 }/ p; E4 K                    _, preds = torch.max(outputs, 1); b3 r) q6 ], q' x: b, U

; U5 |% n0 J2 f0 x, J* z                    # 训练阶段更新权重
3 i1 `7 r7 P0 r% u$ j% D6 \                    if phase == 'train':% j' |1 I3 h# O+ t7 c5 y
                        loss.backward()8 ^/ C1 j7 ~+ e$ S/ r$ G
                        optimizer.step()
/ d  I: ~1 C" O- m' r6 K% {' N! J% V3 ?7 X& O. \  z
                # 计算损失% H% S8 {/ K: Y- g' R
                running_loss += loss.item() * inputs.size(0)
6 I5 t# V1 F- D# t6 ~0 Q' |                running_corrects += torch.sum(preds == labels.data)
- B: L4 \+ c/ e- J- i: z- t* Z- L6 F% c# u- K. I7 l# d
            #打印操作
# o( ^0 S3 V, U            epoch_loss = running_loss / len(dataloaders[phase].dataset)
/ Y; I& G  V' L' V. K, M            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
; D3 S  J9 j2 h8 G9 o& A
; D1 z/ G8 C5 q. L
) H* C0 y" _) w4 V1 d; [& m            time_elapsed = time.time() - since8 p0 o( V- S* p: \+ A* }$ a: ~! E
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
) q- G* U7 E3 W- {; r$ I* O4 Q            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
, t2 ^% A4 \( h6 D9 x' L/ v; y6 q4 p8 V+ A0 _7 j
1 ]4 F" U/ L& u6 U' M1 w/ i
            # 得到最好那次的模型1 l. x& k8 A* V2 j, a' P9 s# a# E
            if phase == 'valid' and epoch_acc > best_acc:8 a# Q& V4 O$ Z5 d6 }: h/ v
                best_acc = epoch_acc, N# I: e: i8 k: K1 J! v; T( W
                #模型保存
% }( J5 Z; j; x5 g7 k* B  r                best_model_wts = copy.deepcopy(model.state_dict()): w1 o6 E0 c& Z2 F9 Q: h
                state = {
( I! \1 c& J- U9 f                    #tate_dict变量存放训练过程中需要学习的权重和偏执系数9 x3 D( ]0 I7 F: ^% |' j7 e
                  'state_dict': model.state_dict(),
  R9 B; v, A! ?                  'best_acc': best_acc,
' L3 b/ Y/ N* R7 `9 \2 \3 j                  'optimizer' : optimizer.state_dict(),1 h% ~- f( I+ A- a0 g: x; W
                }
. f* {& u! {1 t0 D+ D3 X6 K  ~                torch.save(state, filename)% r8 S$ H8 t. e! T; E! e9 F
            if phase == 'valid':4 @+ ]& R! p4 k& W. n
                val_acc_history.append(epoch_acc)
' h0 y( b' j& Z/ P& u! v) T+ A/ Y                valid_losses.append(epoch_loss)
/ X' U$ M/ b/ ?& ^" i                scheduler.step(epoch_loss), l- {! v0 N( u# J: N
            if phase == 'train':
% G2 F7 I) P2 c2 g                train_acc_history.append(epoch_acc): b4 b; w9 @5 P2 Z
                train_losses.append(epoch_loss)4 E$ C: e, q( V) k6 C( [* v% X
2 j" }4 D( o5 {, k6 f2 w
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))* Q; Z- c# x" m
        LRs.append(optimizer.param_groups[0]['lr'])( H- B- _, `, S# M  [
        print()/ T9 G9 {) [( T9 ~  k7 i' M

/ l! l9 O* ~1 R    time_elapsed = time.time() - since" F% j5 x; c  D! \' G: W
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))' D/ ]0 U% t0 M0 \  Z0 d0 G
    print('Best val Acc: {:4f}'.format(best_acc))' x2 l* @  `1 g: H( ?; w

& L+ s% s( ^% T    # 保存训练完后用最好的一次当做模型最终的结果
8 E/ _2 W6 z+ V4 E2 k4 U  R9 N( B    model.load_state_dict(best_model_wts)! ~# _1 A- O) F- D3 _
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
/ ?/ Y& l0 c- q8 w* _, H& F6 \% k$ u
( k9 z0 X2 }3 `; z: S. M
: E! V/ e. _5 s5 _+ T* k$ ?1
/ L- R: Z" O1 G3 J* G2
  K! ?$ M+ E$ V5 B: C36 d+ f# ~7 R+ D
41 N% V5 v) X, T" w) i9 I( i
5- C' X- [+ Y1 J  W# a6 K
6( x/ T9 z9 M$ i' G( _$ M0 m8 L
7: |& a5 t% `) u
8. F) {" S% W: l
9
: Y; u# L8 @; t0 s% H0 d106 f, T* F$ n0 Q) F# b
114 Z/ Z9 l* O; P6 q2 G
12
) o: s, s! B, Y4 i13
6 ^+ j; x7 w8 Y8 M. w# [4 w( q) Q14
7 k* V/ \9 @& L15
; z: k2 y8 {. h; @+ ]6 T; c16
# i) N9 ?0 c5 |& C4 }17
' N9 u2 |4 g" V9 ?$ M# a5 q! C18
5 e7 H- d: g+ \; p$ s5 w" E197 p, G, T( Y6 S3 E" s  X
20
5 Y7 V: C$ j/ I. s0 G21, }0 w) q9 g/ ~5 O
22
" ?/ C( @) B6 d234 s$ C2 S5 ]; b& U0 \  s
245 x% q# s- a+ \1 Y, ]  ]! ~: P, ^1 ]
25
) p* N  ]! ^5 L& p26
2 `' _0 _, z. S% k27
. z  X# s" S9 r' w% F28
4 F# O% q& Z7 Z% S' Y$ x7 \290 Z( }# ~( D) z# X% f8 C
30
+ Y6 N1 V; Y" J% Z31
" H) }, l0 ]8 G7 ~' g" S% L329 O, s, i- p8 N; E: m/ @1 ~! X, M
33( T6 @+ A! j1 q$ D# m$ f4 a2 Z
34% f8 R; L1 ?8 Q
35* N9 y9 v3 u" \. C! f+ }0 ?
36; f$ C0 \9 d% A; T) F- V' J
37
8 j' u( h1 v: g5 r: l) l38) D- K4 q. f( y! u. }# a4 \3 @" z7 H
398 |6 n9 h- _, c% g, I# u
40
: l- q, K$ L- d415 ^8 J. L2 f& o+ r( U: I
424 ~1 J8 [) z# w# _% |
43
! `" z7 L/ V- g+ ^44
7 r7 F; S9 j4 s8 H45. j1 h* }7 M" L1 F1 l2 w" J
46( e9 u# V$ Q8 h4 C/ y
47- s& e% ^, b3 x6 V
48
. L% k0 D2 K4 Q7 X49
  ~9 b4 V: S: M* ]5 j50
: n5 [. S% s& t: i' o, n51
% C) N9 u& G$ `52
7 N; m' h7 c' e0 |" l3 K: `: R53
$ q# \; Y0 u6 ]( W& t8 R54- a2 d: S6 A! I1 P
55( m  A1 a+ T* s: {$ m& w
56
! ^& x. v. s  d- b9 K7 x57- p& R. N9 V: t% v: J
58' ^# ~4 W2 F/ \; f" n8 L
592 r% Q; d2 a. q5 h9 g
60$ O7 m& V2 h( p8 t2 G% E/ Q
61
2 n& V7 |# y7 _  M/ U4 R, q. ^62
1 _8 h& b, x- O" Q0 g& ?, L63
- T. d$ W+ E; q* w. S! m64
: N; O- O( M0 z2 J' u+ w) [65
5 {1 ?1 G. Y$ B1 U; ]7 ^0 |66
6 k  N8 w' f) u2 q) ~, c7 ]4 B67) @6 y+ q  L3 J' q" @3 {5 ?' }
68: |  Q; Y( S+ c/ v. \4 i- z
69
) Z8 x# h( I; A( |0 F" a  Y/ z70
0 |3 {, v" F, u) ]6 o71; ~7 s$ `! g6 `7 ]% i" e
72) D- O1 d: c  z3 b7 {3 a
73. d% b6 ^# F* E, M! D0 n5 E
74" e7 y1 P/ O$ l4 x
75
9 o0 e8 J' c6 r9 d0 P9 r769 }# C# i6 f) S
77
4 D' }) B# K) E1 k* B/ s  W78; q. j; T! k, Y) Y# t
79) [! a! i. _- E/ \
80
2 f9 ^! n# I5 s0 r" @81
0 p0 [5 K3 E4 r5 M( T5 A82
2 N. J3 E4 c- u$ N) V8 ?  o. [3 i83
: ~6 J3 O" i* H$ A7 u8 g84
( @$ _9 Q" |4 ^7 s0 Z" |852 E; {7 b$ G+ }5 A6 ^& T1 \
862 b, D+ ]2 ?1 ~$ j
873 a4 M' T6 T5 W- d4 @1 L% H
884 e; q7 m8 G% I* S5 G
893 K  T( o, D# |- y; B
90
+ a" G2 l5 q0 ~% @! s8 |! j91
9 K. N, f' J0 m  D! C5 H( h* w92- y& F) p4 b0 H+ T1 M$ F
93& v: r% o1 E9 B3 _8 I
94% ~- k5 w" Z- C6 q. U& ~  m
952 d+ H6 r: R- a: X3 g
96, C' |& T  Z" f9 w
97
$ U# {! b& |! H( n8 M% ?98( j9 X9 ~( B: f. O, v
99
( W) p( x- X7 o' J1 x100( E, X! _* W. e) O
101
- Y" J  m* q+ I( }5 `7 g6 {* j1021 K3 \5 ]; j5 h& K
103  J2 A. Q; q4 X8 ~
104
: X9 x3 i4 a" I2 {9 C5 x5 A105, s$ B: e+ F$ g# Q8 B) C" k
1069 K5 N. ~+ m0 e$ B9 f/ `/ ~
1077 C! V1 G$ `+ T/ y
108
' Q/ Q8 G& a5 c3 |) t4 m% G109
: E* L4 V; F, @+ t$ c110
0 \: Z; h" L, _) ~111" q0 [5 h# }0 ^* N$ r  h
1127 U7 z# G) v5 ~4 V2 S4 q
7.2 开始训练模型
! n, [  E& b' M$ J: H# V( c我这里只训练了4轮(因为训练真的太长了),大家自己玩的时候可以调大训练轮次
+ O1 i: K/ ^. Y7 T' x$ k
6 C4 F: C/ O# a* Y$ n" W#若太慢,把epoch调低,迭代50次可能好些+ e, ~& w' Y/ V4 B) p1 p5 l$ v2 V
#训练时,损失是否下降,准确是否有上升;验证与训练差距大吗?若差距大,就是过拟合; \- D* c/ P" B: U* a9 {
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=5, is_inception=(model_name=="inception"))
3 p+ }: t, B1 u1 t5 Z
9 ]% D) _, M+ ~' F* D& y1
- ^9 b/ E- P" n! k0 G3 b2
& R8 m# E3 E3 H$ j! m$ O6 N2 N3
& ]1 ~4 }; a+ K3 z/ P" a) b4
' C+ n$ ~, A4 D( `- I: M2 v* FEpoch 0/43 x# ^1 d% I1 k& t" v( D
----------/ v) Z- O8 c) N7 u* v8 @
Time elapsed 29m 41s
! u# p5 U* F  ]train Loss: 10.4774 Acc: 0.3147
: U7 h5 X' K/ zTime elapsed 32m 54s- E+ U& V; L' H9 k2 E
valid Loss: 8.2902 Acc: 0.4719
0 N$ J% E8 z) I) }( E, DOptimizer learning rate : 0.0010000
1 h( l; ?. L6 ?) c% T8 b
; Y! t1 Q& j4 D# q) SEpoch 1/4
2 L! }! V) a  U# H& {* n8 c. x----------: w. i  V+ L; g
Time elapsed 60m 11s1 R! J1 E% J! v
train Loss: 2.3126 Acc: 0.7053" m5 n: q' N3 F
Time elapsed 63m 16s) L  m$ [+ v& h; c/ G
valid Loss: 3.2325 Acc: 0.6626& D0 l/ N: l" n1 [9 {- K
Optimizer learning rate : 0.01000006 U1 g  V% U* t/ @% e2 u

, P( D5 v9 U0 F* r5 @4 GEpoch 2/4$ K/ j# j% `6 q
----------. H! H0 _: d2 o( f* J5 K
Time elapsed 90m 58s& Y8 P* a$ _5 O. |
train Loss: 9.9720 Acc: 0.4734& v/ N2 {% O1 A9 D  G
Time elapsed 94m 4s2 [7 D" }9 ]2 Y2 m
valid Loss: 14.0426 Acc: 0.44133 {$ ?4 {+ Q0 `% s& m2 W! @
Optimizer learning rate : 0.00010007 U0 n7 p7 k1 k: t5 S! a
. W5 _0 C: @7 H" X4 `& k
Epoch 3/4
' y7 X" t; C/ _1 O----------; Z5 @) B; @+ Z2 W
Time elapsed 132m 49s
  I3 c. S) ^2 r3 ltrain Loss: 5.4290 Acc: 0.6548
' S1 g) R+ U4 w3 n, STime elapsed 138m 49s# _4 F2 x! P- N& |. p! Q( L
valid Loss: 6.4208 Acc: 0.6027
$ t# S/ v' F9 H" r: gOptimizer learning rate : 0.0100000
$ A& R2 S( R( d% U. |3 P6 H: A2 w+ [+ C! {
Epoch 4/4. W0 @# A8 ~) M, i0 _# I+ V
----------
% e/ S$ S, N; H9 i8 Z7 FTime elapsed 195m 56s
6 N7 X1 j; W4 n3 {( s, ntrain Loss: 8.8911 Acc: 0.5519$ Q- M* J" u. Y5 A) N$ E" E4 ~
Time elapsed 199m 16s
- e" R- J6 z/ M3 C! w$ r6 R2 ^5 j* gvalid Loss: 13.2221 Acc: 0.49142 e7 q8 ]5 f4 Q0 P7 m& K/ j4 X5 _
Optimizer learning rate : 0.0010000' B" y+ N* {: @! S9 W+ p

" P5 ?! k* a5 Y3 r' C, CTraining complete in 199m 16s
. S* [  _0 ]5 N8 BBest val Acc: 0.662592
7 Z9 a4 L0 n9 \9 T6 R' T. K. X) w, Q  Q2 u3 }. N9 t! u9 g
1* k# E+ Y4 D0 A/ s
25 @/ [7 T8 J) X
3# x0 I" `: ^  N/ f  L1 X
4" Z/ }6 i1 ?+ ^4 N' k2 |  Y9 X! ~
5: ~* z" H9 l# C) x. K
6& v; e/ y( y# u5 n
7
. @) k" L2 G; s$ H5 @84 M3 Q5 U  q$ n+ o" [& E
9+ t! v+ Q  p' d/ T
10
+ }, D0 r* R6 V7 ~4 a11
. w1 i7 n; c* u+ b4 E12' p2 I" w. l3 `' n; C; |- M
131 @3 N0 {* c2 w! m
14
, F4 Q2 e  E; E, v; c15
9 s# Z$ z' h* f: Y, e! ]16+ l, `, R3 t, f
17
- n% V2 }4 I! c! V/ ?* @9 L' j18# [- E' R  s- k  L/ A  t1 R' }6 C
190 P" ^7 }5 ^8 e& ]
20
/ c1 K6 K: i9 D6 G" Q21
8 D9 s: B# A: u* Y' j/ P$ j; E22
( V* h2 [( G% h) A! T" Q% S23
, f+ k$ c' z# m/ p247 ~3 N( ?0 g- N4 g6 h+ A5 l
25
9 o+ K+ E& z  ^; M' X26+ X  c/ a; D8 h2 m3 t0 v
276 u( ^5 J* H: V' h) `
28
: N$ Z) y' ^- j, n29$ d: Q3 v3 o0 C+ e2 L/ Y& @; t6 H0 @
30
% U& O2 e9 g, ]6 d  b31
9 M9 P0 K! z: \: [% w: L8 B0 g+ K32) a  R& w+ q% D! \- M! b. Q
33
1 }4 Z! Z; J3 k: R' B- U7 [34% W" c8 E# ?: t7 y4 y+ Z) S7 v0 R
35
5 Y2 C0 l  X: S+ O36
: ~7 |8 @# E; m, L) P37
" i+ a- c& q: {! e2 y$ v38
9 f% a5 F" u- X6 {) U39+ d+ W+ O$ R" s- m1 L9 b! O" U2 h
40$ ~1 q/ {  [! [. V# X+ n* d
41
4 q% ^& c. y+ K; C+ |9 X1 X. W42
0 q; z; `9 D' ~- a$ N7.3 训练所有层
1 B  Z0 ]6 ~; r6 [9 o8 c# 将全部网络解锁进行训练+ a5 o* O+ [& m
for param in model_ft.parameters():: ]  [9 m& ^) A; C; u% e
    param.requires_grad = True
! q7 D1 \9 w3 i) B  [  q8 Z0 y9 L: a0 ?% }
# 再继续训练所有的参数,学习率调小一点\
4 |, w( R; c$ h6 S6 @( z* Aoptimizer = optim.Adam(params_to_update, lr = 1e-4)
; T: ^5 k* f$ }1 }$ R2 v, o% s0 Mscheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size = 7, gamma = 0.1)
9 F' M' S2 O1 z  q9 z. C* X4 G5 d: u: E2 T1 A7 Z; Q
# 损失函数
: T% [3 i7 j. c- E) u  Zcriterion = nn.NLLLoss()! G% A  L7 W  M
1
" Y( a- C9 ^" i. X2
4 S! H2 q5 S  N) s! {# S! v* q( j* r3
) E5 r1 M* Y  X% a1 _* @! \2 c+ j4
9 q3 \' z$ N! B1 R; A0 @5 R5
( S2 `/ J  {1 j' L6 i) B( x1 v6! R' l# \4 V" ?- h8 ^, ^' Y( x
73 |5 c' x7 D3 m" L; p" D
8# r) h- O; i8 h: E6 i) K& U
9
3 @# h; o8 Z0 y10$ d1 h0 h) O2 I2 V
# 加载保存的参数
8 E) b# ?: j" M# 并在原有的模型基础上继续训练& f: m& q/ G( W* q; m* ]: p: R: h
# 下面保存的是刚刚训练效果较好的路径
: l' [/ o) Y5 _( Ocheckpoint = torch.load(filename); ]$ t2 f# x7 B1 g% i" K
best_acc = checkpoint['best_acc']
  F3 ]' G3 [; z! ]( c2 s  Umodel_ft.load_state_dict(checkpoint['state_dict'])  T# E( o% e6 W, t
optimizer.load_state_dict(checkpoint['optimizer'])
% N; O/ H( j' \! i3 M, z12 f, d  C, P  Z: M/ ]4 @* h
2  {% ~4 O0 w4 n. K
3
- y4 k& G/ Y, l' H% D$ D4# b$ ^/ K6 W8 t
5
; M. U8 W3 V9 A8 q! ^7 Z7 j5 b6
6 z' z7 g+ C; z76 X* ]' z, G- F8 p) b& ]
开始训练! J3 z& \; y# E# \5 f
注:这里训练时长会变得别慢:我的显卡是1660ti,仅供各位参考) y/ c0 P( h$ n& z

1 g% J7 v4 J5 F9 |9 amodel_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=2, is_inception=(model_name=="inception"))
# c( v' r3 L' o  z4 j( s, c# Y13 v' r4 ~/ X7 P
Epoch 0/1" {8 G6 z9 I- P. Y& A
----------2 e; D- C. G$ S7 E
Time elapsed 35m 22s
% Y9 H+ O, k2 r* Atrain Loss: 1.7636 Acc: 0.7346
1 y* C+ Y. N; f7 E4 GTime elapsed 38m 42s
3 L9 }& k) q' A  T  \7 vvalid Loss: 3.6377 Acc: 0.6455
) j6 I$ o+ X6 a3 ?Optimizer learning rate : 0.0010000: l9 q( ]. U5 G2 b$ |+ Z9 h/ R
+ z0 k7 C, x8 R" H
Epoch 1/18 l$ r0 e& k% U
----------
, Y" |! K, V7 R" n, Z& eTime elapsed 82m 59s
3 }9 {( J* X& h: [7 i/ p8 jtrain Loss: 1.7543 Acc: 0.7340
& ^% O& t4 H8 ?Time elapsed 86m 11s
4 ^& V8 B, n- v: H! ~: |* Kvalid Loss: 3.8275 Acc: 0.6137! Y+ z) H! X& E8 [
Optimizer learning rate : 0.0010000
3 L. r6 l' P  h" c7 `  J/ {3 d1 y+ M5 j$ T7 G2 \; W) ?
Training complete in 86m 11s
' `  Q9 f& B7 C( W$ h8 U  F& ]Best val Acc: 0.6454770 P+ Z: S3 s1 L8 V- L4 c

# w; I1 {2 ~7 o; P) E1
- n% H% j8 k) o% D# A3 y2: y/ O6 S: l. \1 a" _
3
$ h, y, s& M2 j' h6 c& Q) i4  w. k4 L4 Q: v! J/ w
55 R* P5 G$ r3 N9 L: n3 l8 ~
6  \: N2 n4 {( r" e9 ^7 }
75 i* e; X' o+ H7 V3 t* ]2 _- `3 ^% @8 Q* F
8+ J% L+ R4 ~7 o7 f. G( O
9
& t, |2 S- b, }5 d/ Y2 D109 H% i' r- F' S4 f$ W: U
11/ |- g: d& a9 x1 P- o
126 O& F$ Q/ }0 |- h; S' G
13
% A: |5 R+ `/ F; X" [14
* z. }2 ^3 [/ T- q% t# D1 Y' |15
* K+ _8 G0 |+ w16
# X: }& G. o( F: X17
( H/ X3 p5 Y- A, i. `18
& v2 T; k, A7 D9 Q2 J5 K8. 加载已经训练的模型
( Y# v- R0 q  q相当于做一次简单的前向传播(逻辑推理),不用更新参数
1 V; m# Z) C( V( Q9 Z* t) c
$ p0 X2 ~+ H4 s5 k, I, emodel_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
, k$ {. a8 T4 `! h  L3 d  c6 F4 e% p/ ~9 \. a) ^! X
# GPU 模式
$ S7 T; a4 q6 smodel_ft = model_ft.to(device) # 扔到GPU中# v; [* Y: W# B4 M2 z3 v/ _2 j" g
8 v# j- e, q% S8 g: O0 ?( i6 U
# 保存文件的名字
8 n- j6 W, }3 ^filename='checkpoint.pth'- p- J( j" b& B& p' e3 l

) `; h) F/ z+ C$ X/ a$ u# 加载模型
' L* u% t& u, v' ?, H8 }  Z$ tcheckpoint = torch.load(filename)8 v; l& @" o# l4 B& E! ^* m
best_acc = checkpoint['best_acc']
* e4 Z! Q# R: z' Z# omodel_ft.load_state_dict(checkpoint['state_dict'])$ z2 y) V& `+ o+ U: @6 e) n" o# e+ I* o
1$ j& @6 y7 i: H: X3 a, k7 H% \% |
2! H) H2 s* N" b; A9 d, z
3! X. C$ f5 N1 U9 ~
45 B0 L* {9 a) G2 b% a
5
8 [5 P4 b) `! V4 ^) C9 s69 _9 {  O. M/ M; A; X  g
7
6 ?2 {' X% S; v8; V! d" y" _7 H$ Q
92 A% f( m3 A% P" v% Y) s
10
  i+ Y- G& `9 E- M; l11
; \2 \, Y! [: y! @2 R12
, H, h. t0 w1 K& a<All keys matched successfully>7 U% W" \8 k0 p, c0 g4 P7 ], n1 D
1
; N2 o& L4 Y/ t8 Ddef process_image(image_path):6 ?- c9 ?3 n6 A9 H
    # 读取测试集数据
' E& q+ g  C9 g, z# ?9 f1 O! E    img = Image.open(image_path)
# D  F% c) ]) j; I8 Y  f    # Resize, thumbnail方法只能进行比例缩小,所以进行判断
9 l5 M* _* b5 U    # 与Resize不同; v  T5 y& ?# X9 g5 @
    # resize()方法中的size参数直接规定了修改后的大小,而thumbnail()方法按比例缩小
( P1 x$ r7 N: g' D& Z; A  v    # 而且对象调用方法会直接改变其大小,返回None
0 K$ ~3 C) ?# i1 I0 C5 C1 m    if img.size[0] > img.size[1]:
3 @) g+ x- Z9 O- b' n  x; G        img.thumbnail((10000, 256))0 g1 y7 j& l. s! [8 \) F) G
    else:8 X3 h/ p2 V) s' q9 F5 A% T
        img.thumbnail((256, 10000))
" a4 s( K4 `- v1 Z, j" F# e: |6 Z0 l* ~& y2 Z+ ?" e, a9 h
    # crop操作, 将图像再次裁剪为 224 * 2244 p) C  ?* k+ k7 F. N
    left_margin = (img.width - 224) / 2 # 取中间的部分0 t9 c: ]# |+ J6 J5 x2 t0 h
    bottom_margin = (img.height - 224) / 2 & E# {4 D4 S- ~3 J# R0 q. N
    right_margin = left_margin + 224 # 加上图片的长度224,得到全部长度( \7 F% ]& G  P( a4 j
    top_margin = bottom_margin + 224. n# y5 [' w3 T* B9 e

8 d) D! z6 n# x. h$ [' g( o7 l+ s    img = img.crop((left_margin, bottom_margin, right_margin, top_margin)); u- t: q2 `- m& C6 C

( h6 Q8 G$ @( O3 E    # 相同预处理的方法
' q6 T! L& Q' U    # 归一化# E$ t% i, k4 Y0 ]
    img = np.array(img) / 255
: U9 Y" c9 ^% J/ G# n0 \; j! ]    mean = np.array([0.485, 0.456, 0.406])3 o2 G8 H# V/ T
    std = np.array([0.229, 0.224, 0.225])1 H1 O+ z2 V2 l8 A! b: d* p
    img = (img - mean) / std1 B# o5 ~  O: Q
4 ^7 K. _0 g' s5 ^; }9 s$ ~
    # 注意颜色通道和位置
( U: v8 r. [0 N( q3 d    img = img.transpose((2, 0, 1))( g: b( T7 P4 Z
4 @: U: \- D, L( U7 D2 X
    return img! J5 Q) ^: c: N
& C' l. w$ [3 i+ r2 |
def imshow(image, ax = None, title = None):
; [- Y, u- p1 z# R% \    """展示数据"""* W, N, ]& [, {
    if ax is None:
, B% s. ~! o% {8 @1 ^        fig, ax = plt.subplots()
3 N7 e. u1 W# G' t) r( g1 T9 z
  [, ^" B) V' i' M, U& O. J0 o1 ]    # 颜色通道进行还原
# h" e0 I9 t' p  s/ k1 f, d    image = np.array(image).transpose((1, 2, 0))& _8 z+ i% C1 G: H# k0 I  n2 O

  a9 Z8 L6 w( m5 S    # 预处理还原& b- y$ K$ H6 f$ W4 R! o9 {4 o; v) K
    mean = np.array([0.485, 0.456, 0.406])
: l- V9 z1 _" z  N& ~    std = np.array([0.229, 0.224, 0.225])) I' s% z4 S  s& \$ W
    image = std * image + mean
) P" U' ?8 ?9 q% G5 [$ j, ^    image = np.clip(image, 0, 1)
% @" \2 u4 ^8 e: ~' B
+ A2 \' E$ d2 @7 J: u    ax.imshow(image)8 S1 W" z6 p  r, H. x2 U
    ax.set_title(title)
3 T+ E' Z  H4 U; o1 S# _  }2 t+ p( ^- V
    return ax
0 e6 {8 S  Q) l3 @+ z# o% @, h# h+ t: u0 g+ u! n7 U
image_path = r'./flower_data/valid/3/image_06621.jpg'3 i* Y: S" z; P3 z7 d3 @, }
img = process_image(image_path) # 我们可以通过多次使用该函数对图片完成处理
/ f- q% _* n* X' rimshow(img)
8 n6 x$ M8 k  @" s1 X5 w* k" Q( v( Z+ w1 u  a
1
( I5 T4 o+ H: G5 E; l3 T) A8 U" i0 [2, x9 [# ^! T: y/ @
3
- u9 B7 }' S+ z( e4
" t5 F( a# Y; b* j* y53 q  S5 p. {9 [) p; Y( w* ?% J
6: q9 d. a" Z$ J# n$ |" i
7" x/ ]/ I. N6 [: s1 ^) P' B5 ?
80 g0 d" M% F2 S3 d
9: g* J7 S3 r( X8 w& E4 W
100 e5 N: k% l* i( g! \3 h* i- Y
11
  w# v* i1 z- E( P% s) Q12
9 u& I: C3 g% y  ]! G$ e; i6 K: V" M13( ?) C; p+ r. H& v) `5 P
14
* r  v% P& P; `& W% e0 U$ C15! U: X# i% Q5 s3 o
16
2 f0 L9 H8 {3 o* s9 m  L) N3 b17! _9 b$ d5 h2 z( Y' g
18# a! F; L; X$ o, |. G
197 _+ _2 Q4 z. h/ B4 a7 c: B: I8 j% o
206 V; A4 D$ i+ Z) D3 H# u! k
21: J+ w# }* t$ m- d' f
22" W5 h# G0 V* L1 d0 \
23
2 ?9 O& ?1 Y/ Z24
+ n; z. h8 ?0 ?( u/ `2 K% l' V25
0 m8 N) o& L0 ~26. B) r5 `  r) A
27
3 j& F2 _) T% {; Q& m4 H! n28: Q2 N4 q3 c% i* Y! e' d4 B5 n
29
2 P9 [  m: U) i30  g4 w3 M. Z( q% I, a1 y- |
31' [$ S; U: K" g7 \/ v  ?
32. H* y+ Z6 ?$ c3 |  t2 u
33
: f( s! q2 m8 x( z% z) {* j2 _/ `/ u34- {3 E& Z; g/ g9 L5 ~7 G# @- @
358 V0 ~6 j/ S& ?6 ~% o  ~! X
36; e2 ?. o0 f! S4 P( U) ?
37
& D$ C* V: Z# C2 f3 f38: K: w$ l) X& O5 a# [
39
5 g% P5 |; Y6 ^8 h7 x40
! F1 i8 d. \# c, h1 U9 U1 J418 D* c+ s3 m% `0 D& ?1 x
420 t! L( [- ?8 m' }
439 j- F8 @9 G- K
44' H3 k8 b3 y# m/ ~0 S& }
45
% B+ ?0 P* `  f& H0 H46: y8 i2 i1 G2 r/ l9 W6 c( s/ g2 m
47" y6 @) H- U. v4 Y- y
485 D: D6 t8 F* e
49
4 B+ }0 p; p  [+ c+ N. q/ h7 k1 ^50. r) ?7 T# V5 c/ ~+ M  v) q/ {
51
( Y9 b7 H' f" Z* B$ i52+ M  t& c/ u6 ?% o. ^
535 i. ]# |' U( S1 D; v) l$ s1 y
54
2 v( G/ @* a% J3 F: V3 L4 R<AxesSubplot:>
: z# }  M4 m) m/ C6 s1 C# ^8 L17 \8 f7 \  g7 x* z
# W+ j7 G; \5 r# I7 Z% j3 V
上面是我们对测试集图片进行预处理之后的操作,我们使用shape来查看图片大小,预处理函数是否正确$ k: Q' y" a0 Y+ M

. K" h; i/ [# V3 Mimg.shape
, ]. U; H3 H( z* s  S0 @9 w+ V1
( }9 y$ q$ O: N" ?! q1 X7 e(3, 224, 224)4 m6 p# n/ Y: u6 r" X- O* Z$ S8 v
11 J% S3 I: |' i- }- U( K3 I) K
证明了通道提前了,而且大小没改变* T' i# v$ l" C' U

2 }" f8 o: M# [8 j9. 推理
. D* }% I0 ^' m8 _4 b2 ^img.shape
5 J5 f: F5 Q% `: l) L7 K! e; I6 k* d& l/ F% I$ L- H/ U) k
# 得到一个batch的测试数据
1 i/ T+ U  F5 V. j7 p6 Wdataiter = iter(dataloaders['valid'])' |4 v7 f5 o  p" Y! A+ L
images, labels = dataiter.next()4 `4 m! k  Q2 I& u- I% L6 W/ y
8 R" N0 n- V1 K) t
model_ft.eval()
+ }' w8 v/ v! d2 @$ [0 @7 X8 U5 V, O% }5 F. O, X
if train_on_gpu:9 ^# l  B$ Z% t% g& B9 r% F  |
    # 前向传播跑一次会得到output
0 t9 a) H; A; \# d9 G6 f: r# {. J    output = model_ft(images.cuda())' U! C7 _+ M! n6 _9 ~- x
else:
- E: I6 D3 u  t. s    output = model_ft(images)9 [! u8 M" z2 F+ L7 ]( b

3 Y; O/ f5 h9 I$ \$ K" U* o# batch 中有8 个数据,每个数据分为102个结果值, 每个结果是当前的一个概率值" Y* n8 H% k1 G" q) q
output.shape' G1 H) ?; }# o  M+ s" o

9 p0 y& r6 b1 ]3 y/ S9 @1" K, L. D# c4 d& Q% S
23 K) c+ r: _1 C6 E+ s: ~
3* ~! v" x2 j+ H. O
4
! v" d$ }* N2 Y* k57 G. _% p, _- ^6 @" q
66 o- ~! J4 g( N9 K& ?
7
0 S) t& Z7 Q; P% h8 C" ?8
1 t0 d1 t) l5 \9
( {9 K& d. Y/ @0 B. Q6 G10. R4 j9 d4 l$ s; F
11
% c, H2 O+ I% x8 P0 |12
( Y1 s* a/ l! N) Q9 i13
- ?# [* M$ [% J# Z14
. V  l2 ^0 w' l/ |7 e15
4 g& G- \  A' ]; ~16
0 s, ~1 u0 `2 e3 b9 P- [/ p" W5 Gtorch.Size([8, 102])- ^- a0 z6 B4 g1 W0 A) A0 d- O' u
11 w, G: b7 H- r* y" m$ b' f
9.1 计算得到最大概率
; q. E: p6 ~' a4 w4 g_, preds_tensor = torch.max(output, 1)4 `4 l7 q! b$ g. \3 ~

, v7 t( Y0 e7 Epreds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())# 将秩为1的数组转为 1 维张量: R# R% m: w; |7 h  {+ ^
1
/ M8 H/ T0 G0 l/ Q" u24 ~! p5 K- U: ~4 n  T. L+ q. W
3* I0 u9 G% d! d& N7 z
9.2 展示预测结果" f9 ]+ p! {' q, X/ B. _3 u
fig = plt.figure(figsize = (20, 20))% c6 [7 `' D- [. w6 U3 R9 ^
columns = 4
8 |6 x4 R3 e) p4 b1 Q" b! Z1 B7 t% ^rows = 24 _! y, t6 r4 u

# k; D+ {6 {& [- dfor idx in range(columns * rows):  A. e! E; h# e- [
    ax = fig.add_subplot(rows, columns, idx + 1, xticks =[], yticks =[])
3 I  O( X! z8 E/ X. e1 e0 g    plt.imshow(im_convert(images[idx]))  P: H0 q5 q6 v1 H. I
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]), ( P, c; R% k2 E) U3 T* e% e
                color = ("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))  l( z) K: D! C- t& R
plt.show(): f2 V4 T: M1 {$ {7 D; }9 B
# 绿色的表示预测是对的,红色表示预测错了0 H0 j. n0 X' q$ ]: [# b6 u8 ^
1
$ I( Z) o0 p2 S, \5 H6 |2
1 {. l" z! H, i; y* b$ ?* q3) ~4 H5 Y+ x  X" A% H' I: s
4$ l8 b' D- P8 K2 a4 K3 ?' F$ n/ O: v
5- G/ C) v( G7 R
69 A1 X; f- y! r+ ~8 ^' |6 T
7
7 D- i2 }. ^9 A6 B; J2 f8% d- i: k5 i9 k, U4 y1 X
9
% x, n  {: c9 c* i6 \' Q107 @4 B7 t2 J/ x3 y- J
11: `9 i. C/ S" S3 n5 `1 I6 ~

0 E' |2 J# q5 I& a
% R* ~5 T6 F* O% a/ g5 O+ g. O3 Y( T' _1 y& i; B) P/ @8 s
————————————————# a$ Y  g1 z& @$ o2 G* ?
版权声明:本文为CSDN博主「FeverTwice」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。  `8 [; d: D; g
原文链接:https://blog.csdn.net/LeungSr/article/details/126747940: g, S( s1 x! g* g7 I0 f5 \

0 {7 M+ H4 m) R# G0 _; S, H7 C
: W2 ~2 B5 N- l. @) F/ S5 ?




欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5