Load a pretrained model (pretrained=True) to reuse weights trained by others.
You can choose which layers to freeze (set requires_grad to False for them). Whether the task is classification or regression, replace the final FC layer with one matching your own output size.
Official documentation:
https://pytorch.org/vision/stable/models.html
- g! B1 J; G$ q( Y5 d. o' E# 将他人的模型加载进来 & S8 S' k9 a( x9 m3 v3 wdef initialize_model(model_name, num_classes, feature_extract, use_pretrained = True):; i- r1 a- K9 W3 y7 | G: m
# 选择适合的模型,不同的模型初始化参数不同7 J# i" M9 p' a4 C
model_ft = None E1 a+ u: u$ S
input_size = 0 " ^2 @1 [! Q+ W) f" \$ B6 O3 D/ R' X: S% q/ j2 u1 m
if model_name == "resnet": 3 K! @- y6 ~3 ?9 K% C* t """ # f' L% q5 A3 y3 c* O Resnet152 ]% _" e7 g# N" n
""" . O/ I O4 U. D. a! ]' |7 x1 Y8 F- ~0 _
# 1. 加载与训练网络# \ y, K E3 ], x6 W$ u
model_ft = models.resnet152(pretrained = use_pretrained), `6 m! p+ Y2 }. l" k+ V
# 2. 是否将提取特征的模块冻住,只训练FC层 1 u5 U+ Z+ @6 }; C4 G set_parameter_requires_grad(model_ft, feature_extract) , P2 ~! G2 B4 ]/ |( ] j' l) H # 3. 获得全连接层输入特征 . U- Y3 X. y$ R% G$ n8 {/ h; e F+ q num_frts = model_ft.fc.in_features" i* |4 r2 `6 ` C
# 4. 重新加载全连接层,设置输出102 * K: Q: g O& _$ j% z! { model_ft.fc = nn.Sequential(nn.Linear(num_frts, 102),' Y2 E) Y! |. |2 j4 f, K
nn.LogSoftmax(dim = 1)) # 默认dim = 0(对列运算),我们将其改为对行运算,且元素和为1 + b, u7 B( {- s4 `/ ?; U, M input_size = 224 ! ~; U) r2 a7 ^& w+ u6 z" z4 D0 F4 p! v6 N) M: q
elif model_name == "alexnet":8 o& K8 U# u* r7 r
"""; ]/ G/ u1 p% x% z O* B
Alexnet + l8 {+ y, f4 @! _- F+ c """3 W7 d2 [% b4 K
model_ft = models.alexnet(pretrained = use_pretrained); o. _0 a* Y( Y0 W: n
set_parameter_requires_grad(model_ft, feature_extract)3 i) u% d7 B( w9 n- `
/ ]9 h: h4 x& o4 l: K1 y
# 将最后一个特征输出替换 序号为【6】的分类器. f5 ]1 `( V/ Z: u7 |+ f0 d0 t
num_frts = model_ft.classifier[6].in_features # 获得FC层输入 / G0 S4 ^3 y* _- C H H( b model_ft.classifier[6] = nn.Linear(num_frts, num_classes) 2 u; d* M+ L* p, |) b3 Y1 ] input_size = 224 0 G9 ?7 E$ A2 A; A- E 8 ^) D. E: {# E2 v+ b. r" V$ z, Q elif model_name == "vgg": ; t" S) M7 v* o( g; n0 i9 { """; E, f4 L6 }' e# h
VGG11_bn* b- B4 S/ D5 h' R- [8 I4 ]" I* p$ q
""" 4 r& n' y4 Z/ U: n9 e. D8 r7 v& K model_ft = models.vgg16(pretrained = use_pretrained)2 Q& }( Q9 n. _5 O
set_parameter_requires_grad(model_ft, feature_extract)+ L6 q# E/ \" z
num_frts = model_ft.classifier[6].in_features( n. F+ x) E( S( T
model_ft.classifier[6] = nn.Linear(num_frts, num_classes)" A* \8 g: I5 T f \6 e/ \" n
input_size = 224 ]: O( n+ A$ z1 D" ^: a$ b# @$ S3 {! P9 |
elif model_name == "squeezenet":+ w* z/ `8 C8 m) f
""" + k! q$ ~! Q8 P8 p ?( z# W Squeezenet: t; x/ N/ O; |) e; y* ]* M8 j! @
"""( n! `( p, W7 p* z7 g' f9 k) K6 m7 J
model_ft = models.squeezenet1_0(pretrained = use_pretrained)2 `% C+ x' M- B p4 P: c
set_parameter_requires_grad(model_ft, feature_extract)5 W' ]5 R$ w$ H: `7 v/ p
model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size = (1, 1), stride = (1, 1)) ' {4 r/ q& {" }" R' H3 [) z model_ft.num_classes = num_classes 9 a) X+ J; W% Q input_size = 224( @3 `5 e2 k+ v, S" h7 ~( c3 q
0 ^5 E$ @2 o W" K1 e* ]" h
elif model_name == "densenet":- z- ?2 G+ O1 C' H M9 U
"""9 ~9 r! t1 Q! P
Densenet7 m& v* |5 P, k, e t+ H6 ]
"""1 q* P) [& D& X n8 L& Y5 F
model_ft = models.desenet121(pretrained = use_pretrained) 0 k8 S* x- I& M2 q) I# j! F set_parameter_requires_grad(model_ft, feature_extract)0 h. J- @* I& {$ b: G# V
num_frts = model_ft.classifier.in_features6 R8 ~! g* c1 R7 W
model_ft.classifier = nn.Linear(num_frts, num_classes) + G" g q# W" ^. i7 ^ input_size = 224# \: }# J0 \& I1 P% t# |; r e
2 ?+ G2 M% b, _- F elif model_name == "inception":& `, a& ^/ l2 }) \
"""! k1 X# `" t }' n+ W
Inception V3 + a Q3 U B- y """# h4 K. z0 e0 L8 F" p! O
model_ft = models.inception_V(pretrained = use_pretrained) " G3 U, N, i; w) S* Y9 s; V; q set_parameter_requires_grad(model_ft, feature_extract) - F$ e0 i4 r8 b) |7 Z' g. m( W5 c/ H7 U
num_frts = model_ft.AuxLogits.fc.in_features9 A# i; o; p; w5 B6 |* r
model_ft.AuxLogits.fc = nn.Linear(num_frts, num_classes) 3 s* g: U+ G- e/ {4 }9 e$ I , w1 n; ^$ L& x2 p num_frts = model_ft.fc.in_features/ g4 t7 [# j$ q ?; ?! ^
model_ft.fc = nn.Linear(num_frts, num_classes) 5 H) w( T% v4 C input_size = 299* ]! s0 } t6 p" w" d
, v- I' ~/ T! W2 T# F1 n" s5 V
else:8 A& O- x' J0 D2 W1 @: |2 H2 v
print("Invalid model name, exiting...") 0 q/ K. k3 E* u0 i; f- n exit()0 E9 T" V) e7 I2 N3 V
. _8 C- j) o. L return model_ft, input_size" o& u( r6 B. U; T' ]$ \ Z4 \, q' E
4 x; t8 f0 [! j" l1 , g- k" d% G) t2$ I, { v- I) g: g" @- [8 N9 K( X
3 1 b8 D+ f; Z, }' W4 3 r3 H Y4 C, ^8 u, o& f5 + Y, H. ?; o' }. c1 E64 N/ D/ f, @- m. [$ Y
7 S! l) o: h8 ^4 \& i% f87 \+ x: X; c1 O+ a
93 ^/ ]* T& M' p8 j @- L9 k9 v
10 C+ z- q8 ]) J% V4 L, S
117 b0 c/ W% t ~% p7 W, K$ d1 r3 n
124 a9 l* T& i6 e0 Y; _4 H
13 / \$ f m; |; E# j9 D7 C14 ! p# E4 y# A/ n% O/ Y. S151 S' M% |0 a2 F' c; Z* ?; N% s) j
16 + _) D5 Y* y* }! }17 + b- H# d" ~/ b: ?3 N18 ( L: g+ ?3 D9 ~/ r0 L$ `19 . {5 M) |( \! \4 [! R# ^1 N20) j) X/ g4 z( O/ ]3 e
214 v1 A& z/ B' t
22 - @: P0 p6 p; _, W, ~3 s239 j: ]: [& A, S9 v- F: x
24" ?! }" `- b, q1 [/ v4 d2 U
25: j" B8 ^0 h4 J( H0 Z1 X
26- j" S9 t. t. W
27! O, B2 v; \# i2 e4 `4 O' [
28 6 \7 f7 f9 _3 r* B29. I6 j# F4 M: z# u
30 - q# l, [% V: C, A; T2 [317 T0 z# R" G0 S# `( C! T5 s
325 P' J' [7 p9 o+ `/ A
33 - o# s! h$ P8 W# n4 v346 I7 u1 M# j+ e( M
35 3 j& y, z9 w, U/ o! k$ w _7 q* X36 ' k/ K8 [, i) A; i! w- Y" _375 B# }- s5 l: p8 d7 w) h
38 2 W$ ]( R# b9 O G2 J& Z8 A" B394 O+ h& ?5 N' r8 x
40 . N" l8 Y6 n# O3 l5 Z415 Z0 v% x; i+ o7 _; L# [; u+ d3 y; H/ f2 ^
42; t: ?- x7 S1 o& p. } t
43: x9 ~7 r- M* E7 V* `
44 " _- y! b# X4 V45 1 b" H+ K. i1 f! r46 " w x6 v) D& S: f, F; @/ h& h* e47 . M# j, Q [% w" ~485 D* ]. l( C+ i; k) i6 X
49 ; Q# X3 T6 D6 Z50 * H; \5 C5 `$ [5 g51- q M9 y! V5 `/ _' {. d
52 + B, z) d; @7 X4 T W5 M% X53! [( F/ j1 O, R
54. s0 G7 D2 X) P
55 ; L8 \+ Q0 L8 @1 A$ X" q0 W/ y! j56 + V6 w. d+ a. x0 c: K57 5 B y- ^+ `% W1 Z- g58) w8 v2 I' f9 E. C
59 ! s6 p+ |. l1 U& I; n" t602 V( Z; R& n5 _( A9 Z
61. T! B. t" O/ ?( L, S( [ J+ _
62 ) c' Z7 H. A% Z. o7 i63 , _; f5 r' F4 \' l64 # M+ i, Z; u: J; t65 ( C" e9 J! r% o' P. h, ~66% t, q7 ~1 N4 C* M
672 ]" h7 q! W7 O j; m" @" S
689 J* u) x3 w# ~8 Y& Y
69 / }2 U H# ?% K! m701 c3 z3 ` z; ~3 S
71 7 |' D, W3 Y8 s4 I) W6 v728 z i& D8 I0 I- T+ V5 I: k4 Z
73$ W) z) b8 ^+ F) @! s: x* y& `. \
74 : c5 \5 P+ \9 N% \0 e9 T8 D: x756 o( @" A+ H5 {1 U& P
765 y: |2 v% |/ x6 R8 G% D( I6 A0 ?- C
777 b0 N' a7 {$ z9 F6 r: d- C5 T
78 5 D1 F0 f2 Z% X# o& ^79 + C. o7 M' i% z1 r" m3 e; v80+ d5 [! g4 P! Q8 X- L
816 [% P0 |" U6 q5 W @( s g
82 ' T/ X* O- y# }# I2 H5 N E83& F# W' M/ V+ N8 @5 l8 [; `: G
7. Set up the parameters to train
# Choose the model name and the number of output classes.
model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained = True)) h/ V5 G8 d, _/ J. w
% Y* N% b! s# y
# GPU 计算 : V# @5 H2 M, k, K# wmodel_ft = model_ft.to(device) & R: q9 V0 O" f; m/ H & i% h* r8 |2 a# 模型保存, checkpoints 保存是已经训练好的模型,以后使用可以直接读取' t/ r ^2 O5 D" p2 }
filename = 'checkpoint.pth' 8 q6 h3 v9 K5 R" F' V/ D 1 f$ ]/ i6 m2 X. ^* }# 是否训练所有层) x0 U+ B7 Q0 Z
params_to_update = model_ft.parameters()( S8 r8 r( a2 |, h) V! ?
# 打印出需要训练的层 * b N. D+ k; C0 @! w6 M# P: {4 n2 T7 cprint("Params to learn:"); j+ i4 n f+ K M& V* _: i+ V
if feature_extract:2 H# E5 W8 i* ^8 Z7 @
params_to_update = [] ( `) t8 B5 O% ~" N' q8 q1 g for name, param in model_ft.named_parameters(): 9 D1 j5 L Y! B+ | if param.requires_grad == True: * y1 L# l. y8 A1 X. q3 H params_to_update.append(param) $ V) X9 i, b: {$ H0 `" m8 {) a print("\t", name) - [0 h, T$ R9 C3 w Ielse: , z1 V; }9 m) ?( I# \ for name, param in model_ft.named_parameters(): 1 d# m" ?4 v8 e' s if param.requires_grad ==True:) W2 S( q8 ^$ A0 s/ v8 V, V
print("\t", name) K9 z5 l$ n# b$ b: D / X, k) M) m' c; C1( {$ R( ?$ ^$ z( k* u7 p
2 6 e. \+ y$ G# Q \3) K# d# X! r: W6 ?: p0 P
4& O) d6 [* K' }; D! s1 s; X" g
5" P' Q; @7 M& G V7 [
6: A5 M3 |8 L+ {" {3 m8 q, p, u
7 1 W0 h1 |3 s7 K+ ^: P8 : {# J: _$ U: [/ H" m9 ( q% s; p- ?0 r* L. S* A102 `+ ^" }, | q% i! x. x
11 : o" S7 M4 \2 w( F1 E: f12 5 r: L% m `. I8 k/ b+ l; ~' J13: b% `/ [' f [* ~( Z
14 ! c6 G( `9 W* T$ l' r$ q6 I! [) m15( H$ Q( p$ b M' J; y
160 E, J7 |3 k' `0 M" I
17 7 l; j- T/ Z- Z18 + r3 c/ o! l: p+ b3 G: u" ^19 4 t; S+ V# c0 y+ W" D3 L- e& {: y ^2 w20 w$ f" F+ f1 ?7 u' K1 Z21 d$ H" k% U ?. C' k' \! J" I22* h0 ]# i) k, ^2 i& c5 o" W$ ]% |
238 Z! E, U1 Z" a# N5 J }
Sample output:
Params to learn:
	 fc.0.weight
	 fc.0.bias
7. Training and prediction
7.1 Optimizer setup
# 优化器设置1 U! N' m. [/ W3 d; B2 Z0 ?
optimizer_ft = optim.Adam(params_to_update, lr = 1e-2) % H# f B) |$ H. }4 A" F. N9 V# 学习率衰减策略4 V, D9 h; P1 k6 j, p% @' [
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) " E4 {6 S9 o3 `) z4 t/ v1 H R5 \& [# 学习率每7个epoch衰减为原来的1/10 % J4 W1 {9 X S; @/ T! d5 u# 最后一层使用LogSoftmax(), 故不能使用nn.CrossEntropyLoss()来计算 % j' G# Q p4 R- B: u* Z; D; l0 e# G
criterion = nn.NLLLoss() " g, P G$ K7 G' K1 - ~& v7 U3 h$ V: `1 x, l0 {& G28 K! Q' V4 ?8 ^5 }+ v
30 ^- h7 @8 q$ B- O
4; ]" P( d' ^* F- M$ k
5 # S, u6 w% P/ H# t, x2 H5 u$ ]6 8 `3 b0 J6 c. V9 h7- l6 M/ @& |# f+ e3 \# l9 r+ H
# Define the training loop.
# is_inception: whether the model is Inception v3 (which returns auxiliary outputs in train mode).
def train_model(model, dataloaders, criterion, optimizer, num_epochs=10, is_inception=False,filename=filename): * D/ n, O# J8 k( `1 a- W* C since = time.time()7 ` x Y. z- z/ g
#保存最好的准确率$ d3 \5 R: q/ U* t
best_acc = 0! |. D* F; P; K) T
""" " i5 w% k1 |; H/ b* C/ w7 r checkpoint = torch.load(filename) E( ]% V% l7 l best_acc = checkpoint['best_acc'] 2 c, A2 q' ^% T% d3 ~ model.load_state_dict(checkpoint['state_dict'])# ?% e1 B0 y: \6 }1 g, k- k
optimizer.load_state_dict(checkpoint['optimizer']) / w% r5 d# o2 {0 t: o model.class_to_idx = checkpoint['mapping'] ! ^1 i5 `7 e k E* n """2 |. b W5 H9 T" k0 n4 g7 g
#指定用GPU还是CPU 9 W1 Q1 O9 [4 _4 l# X model.to(device)" q& X. S. W! G
#下面是为展示做的 + R+ O$ {$ ~+ q val_acc_history = [] 1 p7 |, s2 m+ B3 Q train_acc_history = []& a4 M2 i1 ` P/ x) k
train_losses = []# _5 S9 N- N% h; ?8 m
valid_losses = []) B* x' k% e# `, b8 O* n& D+ u
LRs = [optimizer.param_groups[0]['lr']] 3 K' s# {5 a! x* C2 D% _ #最好的一次存下来4 Q3 V1 h& \; G# K1 x4 n9 y, E
best_model_wts = copy.deepcopy(model.state_dict()) t$ v+ r9 `, O- z) J 9 g( P6 E1 t% n# f. D4 L1 v' S! R for epoch in range(num_epochs): 8 }& t! Q8 f* s print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 5 ^8 p2 h0 j+ w4 ^# O9 ?! \ print('-' * 10). E' s8 V4 W/ I
7 O" E) X. b5 x3 L* n # 训练和验证6 l( v" V$ M- [% F$ P4 v1 f* {
for phase in ['train', 'valid']: 0 u5 C( J1 M; m2 c5 j if phase == 'train': % a3 F) n1 f5 C/ t6 I G( d, ]) _ model.train() # 训练 ! B. p: B$ m0 c! {# f else: $ M2 s9 X* ~ x! t, V) g1 s' K model.eval() # 验证 # D. A( M0 x- N& r. V) H3 v+ W+ ?2 ^* P! f
running_loss = 0.0 # F3 U1 N% r: H3 S; ^ running_corrects = 04 [' L( b( y" L+ n
0 k) O- }+ m+ b, l c% t: @' c0 L # 把数据都取个遍9 y6 D8 S* i0 \& |; e. q: e
for inputs, labels in dataloaders[phase]:! L* Z$ q$ F' N6 R
#下面是将inputs,labels传到GPU * H3 s5 T1 Z. @/ T9 u% V inputs = inputs.to(device) ' I, C: u$ M' J4 A labels = labels.to(device)/ b2 b! i% e: h z$ s
5 B6 r2 i0 u% V( A" o4 P # 清零: D9 C1 r" R5 q- C' _
optimizer.zero_grad(), d# C, v" V @3 b9 A
# 只有训练的时候计算和更新梯度 m ~ i9 Z. f1 D
with torch.set_grad_enabled(phase == 'train'):/ y! [: g. e% v
#if这面不需要计算,可忽略 + O1 h/ E( L( o2 i7 K2 K if is_inception and phase == 'train':3 B+ D) B8 Y! N0 ?' ?2 W" {
outputs, aux_outputs = model(inputs) ' O& z; o: b) T1 G- I" w( {& O loss1 = criterion(outputs, labels) & H1 d8 d, k* ~0 V* d3 i loss2 = criterion(aux_outputs, labels) 0 N8 T5 e& h1 h; k loss = loss1 + 0.4*loss2 & a# K4 G, I9 M! `. h else:#resnet执行的是这里" h; ] d; q% I u
outputs = model(inputs) + W* [% Z) m3 Y4 T0 R loss = criterion(outputs, labels) # W ~, \0 S/ R( w* _ , X5 j* L8 _: B* Q/ C #概率最大的返回preds / \4 }- U* p! X8 x- t+ G5 D _, preds = torch.max(outputs, 1). {- _6 L% y( k8 \
3 s9 J4 s! O: d4 G" o! U
# 训练阶段更新权重 6 X- @, M' b& ] if phase == 'train':+ X b# e- r5 c. m. h: k6 m. d4 E! A
loss.backward()8 e! [$ g2 @1 h5 d, l5 F
optimizer.step() 2 q. f' S. X/ `/ ^7 P+ u$ l5 A0 ?3 {* P* t4 s% x! M @
# 计算损失5 y. t3 m- E2 u' P
running_loss += loss.item() * inputs.size(0)# i1 x" h8 a! y; b2 o) n3 D) Y
running_corrects += torch.sum(preds == labels.data)9 E" n* J# C6 H5 t) j" M! w0 N
1 C' Q& P3 }: R" n9 D1 ~. H #打印操作 - G0 @1 g* H* W9 |1 S" n7 Q. G epoch_loss = running_loss / len(dataloaders[phase].dataset) 7 y" G# o2 S P% e7 H epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) 0 d; S3 O# m' R5 m. P * V) y, Y9 T: ]5 p7 C4 P ( w# S2 K* Z6 e8 P& J- U* e2 o+ r. k- e1 } time_elapsed = time.time() - since Z: I k- b$ E% q1 \ [% ^# [
print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) P3 Y! g. @8 |. E! G2 P2 q
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 7 Y+ b/ F& ^3 G* C& x8 { r3 q" }' `9 j% t9 j$ b( |$ g; Q: @6 Q8 q5 |$ f/ h8 l1 w
# 得到最好那次的模型 ) C. G& O0 l0 B0 O3 E! h if phase == 'valid' and epoch_acc > best_acc: * E) N: t4 X, {" p4 x best_acc = epoch_acc1 ^: C0 K; \: n/ X/ r3 t2 q
#模型保存 ( A0 B8 y1 v O& H8 G best_model_wts = copy.deepcopy(model.state_dict()): K, Y4 z0 }: ]+ X
state = {' n" l! D4 z- ?3 w- H
#tate_dict变量存放训练过程中需要学习的权重和偏执系数 B& | Z* m( q+ H5 ?+ F) n9 Z8 j( y+ q 'state_dict': model.state_dict(), " B0 b6 Y: T7 z8 [( W* x 'best_acc': best_acc, & {. P$ k4 c$ k# f# h 'optimizer' : optimizer.state_dict(),3 ?- f( x" i/ P3 c' u1 |
}9 k$ H r. i% [& t( A% y5 e9 R% ?7 e
torch.save(state, filename): U% K; M( }% j; [2 v0 @
if phase == 'valid': H/ C$ W) m& D1 n0 m6 [ val_acc_history.append(epoch_acc)# w. Q( x% V0 e0 J6 Y) L' N
valid_losses.append(epoch_loss)& Z1 K6 I" {* l0 j5 R) w$ M
scheduler.step(epoch_loss), i) U' r. S' D- V
if phase == 'train': 0 x3 V- o' Q5 Z) ?& g) J- N; n train_acc_history.append(epoch_acc)7 C* b5 F8 |6 E7 S9 ~9 u2 R
train_losses.append(epoch_loss)( |4 h' h* Q% R7 R$ E
& a8 R6 _" [8 @; f* J
print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))9 X/ q* M8 y6 @$ R; t$ E' t+ p
LRs.append(optimizer.param_groups[0]['lr'])2 V* [! s9 a0 G3 ^4 @
print() ! h3 z1 J& R6 ~ j8 L5 E, n9 ?+ K. w* h
time_elapsed = time.time() - since8 ]: E5 x- U) V# s
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) ; S+ n) N) i4 U7 b print('Best val Acc: {:4f}'.format(best_acc)) 3 ^- H2 m1 E6 B+ G# l# W ' ?/ e- [. x7 H- F- d # 保存训练完后用最好的一次当做模型最终的结果 2 L2 a- I3 t, }. U) `2 L1 x& |+ N: P model.load_state_dict(best_model_wts)( q( q" j. X( W6 M h5 U
return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs ' U# \) T" @4 o( n3 q% Z
) _" r9 b! j6 F& |, B
* G( x$ b, X5 w: p1 . ^$ t/ y: B5 H( l B. r2 ! r& ^# d6 a: c31 x; z2 x6 _( R0 U. O7 r
4 ) v" z1 B+ ~) }' ^& A5 % b/ o1 t3 d& F& V8 F! T- [6# l! n' U! y' [- @; @
7* b: O% X' P5 o1 d4 i M( j
8 $ ~/ N# T# B+ |7 A" |9 W9 7 e% ]/ H- ~9 ^' c* I10 : S7 i+ u, s" M- O( s# b11 $ Q$ H. t1 `4 w$ S" q- G& ^ E12 8 T$ `3 n% Z- F6 L1 ?* D138 Y( S5 D$ X4 h. d! B
14 0 ]( R- o2 A3 p# K& B" w15 " t' a& \- l. O164 {& W+ d1 H, h4 O/ N6 m
17 B9 `- u. z' X$ Y2 j- O186 c& e$ p$ ^/ ^. R! w; X4 C
19+ Y/ }; q* U) {# `* @
20 ) v$ [( u2 {) g& `( _1 T21 z6 ?2 D- C0 B: a
22* o1 Q1 D0 b/ V4 m; b" A
23- [! G. M* ]0 T, \
24 % B) d }: y" k) Y5 r _6 o25) O [& [8 e5 L% S- X0 _/ _) k
26$ C+ e$ ^, ~- g' Y, P* K- m
27( U" }8 V. T, l( F) o8 Y/ f, M1 U
28! U3 h' M# R N# T' a3 T0 U9 k
29' n$ g- I) z% X3 P" v
302 E' T* s6 b" x. F/ r( `( N3 z
31 & j3 @7 a4 A) @* @6 ^" }3 O32+ S8 p5 `( F+ g
33 1 T9 I; q) v, K* s34% w8 ^, L/ s' Q& m
35 : Z% L" P8 f# J% M: ?; b3 w/ o36 9 x9 ~/ l1 K& Q. [# W37 ( g3 F) T6 R2 ~, ~( g8 T) v/ M38; |' Y7 n/ x2 s, |! }1 X7 s
39 5 x0 ?. o' w4 X% l! ^; h7 z40 & v8 `5 G( M" e1 B, P41 , c+ H: i1 L$ u: @42 9 s- h6 U# d& h+ i2 m/ ]& E/ T+ U43$ ]" o: v$ Q+ H1 L( ]8 v
44 5 j, r; \+ v- q$ n# O# q' v& q45 ) I: n1 R( g8 D; v46, N* T7 J5 {- t% q
475 J2 U$ q2 U9 d- @$ C; H8 z
48 % A5 d- p6 S7 \5 l9 U497 `. M0 h$ n# p
503 `- s7 e) X0 v5 d7 ]; [' c) ]3 e
51' |6 @" m& E2 T/ e0 |$ c$ [& R" D
52 . g( o7 D/ c+ d. y4 o$ Y53 0 Y C" h9 [" [7 N' P541 z" W, }% k( ^6 }
551 c3 z' Q& F( @6 v
56 ; {7 f& ]7 q) J/ U576 s/ L# c& Y4 [" T- _/ s, H4 v
58 b: b9 |' N6 M" G2 b) L" Y
59 2 I& o/ }& O4 N/ W5 k% S60 # C" v* g; e. F/ L61 7 v$ y# A3 I: A: D$ U+ v8 k62 2 \: N' H v$ S4 v- k63 * [9 K+ u2 _4 @64* k- b/ D+ f; Y s$ l0 w/ G
65 / ]# Y5 I: x+ ?; [! ~2 Y( P66 3 B- }- V( G8 {: w6 Y3 e0 a8 P67 5 p& ?% F" C9 a& v' z! m) _# @68& Q; I' R" j/ A- l7 Y; d
69 2 J) @0 `% H( B2 _3 |4 D- [70 0 |$ _3 g" D, t0 y+ t71) y- S* O3 V% J$ R* ?
72' ]" p/ f5 N# R; s9 q6 x% l
73 ) H9 v! d$ G2 S; G6 n9 _74 ! d3 P1 N8 L& G* }6 d* x75/ ~& k6 n# ?4 {2 U
76' |% N& h, ^$ A2 K4 c& |. R
77 # H3 L% w# P4 y$ V( D; E' A78 " B2 f+ H1 W* t+ V# w& L, p' B79' k$ o. M3 d( @' k1 v
80$ A u5 W5 a5 S3 a z# v
81 . u( k# u4 _2 M% X82; O3 r, ^0 [4 n7 g3 S
837 H; d; H7 `2 S2 e- X
84& a9 z9 b/ P7 c
85 $ Z& N3 @6 Q, P6 i86 6 y$ w R, S! m) J, r6 P+ A87! ^6 Q7 U7 S1 k0 N$ J% D
88 / k6 D: g$ k8 i7 r u r5 [89 J1 l: K& x* F
90 9 k" ` c4 Z5 `9 p$ }91! J: m* ~: e2 D- o+ G
92 ) m% k2 C% ], G+ ?$ d& C' L9 d93 ' M* W& m! v* s+ }94 ' t/ d9 w, y+ G$ X9 T" v# t95 * j+ b9 n$ D8 w7 t/ N( V) j96 # ~- S/ h% z, O, O& ~' f) n! T97 - x. n8 m4 U6 _98 0 f7 Q% }6 u6 c' _# W7 a99 ( h3 P5 m! ^+ c$ D$ G100 # B' p9 ]# b2 g% \101 * c7 O' V8 f9 s102! ?/ V1 A8 [- [$ z7 j& T
1030 w8 s/ z/ x4 G( d8 O; W$ I5 l% }+ a
104 ( X" ]; v5 n" y, k. W2 i105& `. N: s6 L+ @9 l
106 # w4 x& C3 m; S( r7 T5 `& k1076 |) J. s: i% }# S5 K. M* I
108- J7 S/ T; W" v
109& E5 ]2 O2 l9 b
110 3 g5 B6 s% S _111 8 b( A* O$ I3 o( q' P$ T/ E112- e$ _3 L8 V) e4 @/ ^ j' F1 i
7.2 Train the model
Only 4 epochs were trained here (training takes a very long time); increase the epoch count when experimenting yourself.