% X2 s0 E. o4 n$ S) [2 r/ |! S* s首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 5 \( _8 g9 _! L; J0 X' B 2 l, J7 | k1 f2 y5 c 1 J& P* R/ z1 }% [# 该方法返回计算交叉熵损失的辅助函数" A v4 W- Z" ^5 E- J+ T+ j' E
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) ( N' ]; l6 C3 t4.1 生成器的损失和优化器1 W! G3 i7 ]+ }% h, u& x
1)生成器损失: j$ G$ |' J6 j) n8 Q+ i
( [0 r/ t* |0 y9 K4 X. \+ A
/ b" o( N4 }4 O9 g, |0 s
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。 . k2 ?. m2 I& C. X l- l% S! d! `$ E9 c
6 M4 L! T$ ?# A! T4 D/ q7 Z这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。$ c/ |8 S8 A) N% {% q. _* v
$ H+ I; w" e( W, I$ y9 _4 s" E
7 j' q# Q- h+ ?0 Q4 E9 ?* n% S
def generator_loss(fake_output):; |1 E& S: o5 M* I. L
return cross_entropy(tf.ones_like(fake_output), fake_output). ?: F# ^* f! O( q- X0 k$ Z: R
2)生成器优化器 ) w( H1 y3 {3 ]1 ?5 q( {; |5 F6 V$ j) V% U3 t5 ?
; V% n. z% {8 z& B; `, F
generator_optimizer = tf.keras.optimizers.Adam(1e-4)# {, d" t, z0 s1 d- b. p
4.2 判别器的损失和优化器; @1 ^1 i+ u8 X8 w
1)判别器损失 + a8 O4 l- o8 j : G- V& @5 a) d: f) Q: f( Z& r( b, I0 g0 q" p$ E
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。 0 z" S$ h! _* k1 b) G; A) m+ H6 l! C9 q
9 h( ?" p0 K& V G" m1 l+ P
def discriminator_loss(real_output, fake_output): 8 u' I2 r2 z( R& {7 C real_loss = cross_entropy(tf.ones_like(real_output), real_output) ) ~9 c! u) v! l0 t6 r0 C$ r fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) 6 e# r9 @( k5 p6 M6 `/ g5 @! s$ H total_loss = real_loss + fake_loss 8 r# b' m" B: d) C2 t/ ~8 h! n return total_loss8 M. P; W+ E! i& u
2)判别器优化器 / ?$ Q( G5 {! L1 f' u. C) Z5 E$ l0 g5 M+ o
# k0 `0 J: t( `: p& T- [
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)( X# R2 _2 q4 \( b* u% D, \* I: d5 a
五、训练模型1 B6 u1 [9 P% v
5.1 保存检查点" S) V4 J6 }- ?0 m
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。 g5 y% E e) ^' g) p0 W& v
$ C: Y3 _: I1 |" K6 ~( Y* w" V & ~ d+ @" T% S2 Y; ]4 P8 ^ zcheckpoint_dir = './training_checkpoints' 5 r2 `8 ?- h9 ~( \) I8 M" |" k! lcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")8 {! }1 h7 _% a9 I; y
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,; x- J: ]/ w6 l6 j2 D3 ^4 R
discriminator_optimizer=discriminator_optimizer, % A1 _! _9 V8 i6 y- _3 l3 C0 i( p generator=generator, O$ P) y4 q' P- c4 \
discriminator=discriminator)1 X8 \ K( @! c# m
5.2 定义训练过程6 }3 N. `! s/ z* H
EPOCHS = 50 ! t5 ?# K8 u0 w0 [7 v) B1 h' W8 Lnoise_dim = 100 ' A+ [# Z3 D0 B: ynum_examples_to_generate = 16 ) a; E: C6 }9 u6 f! s# j 8 l; x Z" K9 D C' m+ M ) _ `3 o# Z% g1 b I T$ J) t# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)- }5 F' j5 M5 F/ `/ u( z
seed = tf.random.normal([num_examples_to_generate, noise_dim])$ {1 ~$ \' u6 @5 X
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。 4 B3 i( ~( x5 f. Z$ U9 J3 d9 c7 u; o1 T3 n0 g" a. `
4 d- I) t' l9 Z+ J1 _0 M
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。* M0 h$ A4 G' @) u, @. d; n$ c
" ^' i h5 e) @- [$ e& d3 W1 d2 x) b
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。 , K+ Z& u* r& c W , _+ H; P4 m8 I2 c. T 6 Y+ g, U/ o, f6 `5 W' I# 注意 `tf.function` 的使用 a1 g, |: x. p; v2 y4 I) U# 该注解使函数被“编译” 1 b6 Z& K3 T) o) T3 A W9 _@tf.function$ I% d) B1 }' b+ P8 |) J6 Z& f
def train_step(images):: m# E: q1 a6 `9 Y+ b
noise = tf.random.normal([BATCH_SIZE, noise_dim])% ]- ]* `/ L- G# T4 h0 F/ g' ~5 X& |: \
/ n4 |) m! O. L0 c8 ^# d. D+ s' O
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:/ K6 m _3 l1 q, N% \# j
generated_images = generator(noise, training=True)( [4 B6 g5 }4 Z" R2 J6 b
+ I: h8 ^* P3 `! M# e- z real_output = discriminator(images, training=True)# M- n) M, \- M0 F
fake_output = discriminator(generated_images, training=True): t# k F4 ^( f$ u
8 h7 ~$ I" v/ N" S% Q7 } \/ r gen_loss = generator_loss(fake_output) 8 L, A$ O, W/ l- [( v8 O, `8 r& { disc_loss = discriminator_loss(real_output, fake_output)! w! l4 m X. Q% S, z
6 m1 w0 h5 N4 G$ F7 j( ], l. V
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 4 A$ P; g+ Z2 |- j gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)& J0 L9 {: X. Q Q& G
0 y6 z' B$ e+ r
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) 8 r* w: U5 ^ E discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))* F6 j! a, y; [
0 ]: J' x' Q$ zdef train(dataset, epochs):! r# d( w8 S5 e: t2 J9 A M
for epoch in range(epochs):! q" z& J6 p& q6 p" S
start = time.time()8 O" a) b6 j2 C% @3 T6 @
8 X) [& Q% @5 w9 b- U. O for image_batch in dataset: 3 f8 ]# Q2 v. Y; r R train_step(image_batch)% C7 j( P. a7 _ q- i; ^
8 F: E$ X! X% Z# o1 ?% }$ | # 继续进行时为 GIF 生成图像 " t: [5 c( d8 E6 q; ~ L' \ display.clear_output(wait=True) - C% P( H: U- E2 O7 ?: S generate_and_save_images(generator, 5 f" H @% l! Y& t! F epoch + 1,( [# B1 d9 N9 Z. F, k2 l9 W
seed)8 x8 D) D. x+ }: @4 Q
2 i3 r! {" a+ l
# 每 15 个 epoch 保存一次模型 $ G8 Y: I6 Y# S2 P1 [+ a if (epoch + 1) % 15 == 0: + Q3 }0 |% O5 a8 p) D checkpoint.save(file_prefix = checkpoint_prefix) ! F g$ u) |1 G8 c" h' v 5 S; ]% c R, z X) F print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))3 U: R6 q) ~ S
V8 h% e& R/ M6 e" T ?2 [) x+ I
# 最后一个 epoch 结束后生成图片 - ^2 O7 K! W4 u, | display.clear_output(wait=True)& g4 `! u7 o0 u) P
generate_and_save_images(generator, " Z: \) s. W) O7 H2 { epochs, 0 }: L3 x2 b1 U$ F. d- s C seed) 4 A! r# _( E6 W7 c, g) R 7 c' Y2 P) @3 g$ _# 生成与保存图片 . H, O, z- }. s. Tdef generate_and_save_images(model, epoch, test_input):' J4 y/ R7 [. f# p
# 注意 training` 设定为 False c; ]7 \7 b C8 ]$ y; V # 因此,所有层都在推理模式下运行(batchnorm)。 ' P; n5 Q( c% d" T0 _2 [ predictions = model(test_input, training=False) 7 W& N7 e9 w v/ r$ L7 Y8 O . F) i+ W2 I G, V
fig = plt.figure(figsize=(4,4)) ' [( V1 z1 U. F2 A% f! J( ] 0 d- U2 h/ n: `3 s9 s' {2 _+ i for i in range(predictions.shape[0]): ( O' W: C4 a$ E4 \. e& a plt.subplot(4, 4, i+1)) W6 S3 S. Z4 c1 m3 [, N+ i
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')2 h7 I8 J$ ]- S9 ~5 p, M( x* I! c8 J
plt.axis('off')- f( R0 v/ [' G0 f! {1 L
: C/ l/ T$ O5 I0 K8 m5 `* `
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))9 R- U" A2 Z* [. o
plt.show()0 W1 C# U- O/ }- k
5.3 训练模型( K- k+ n& H9 f( @$ d
调用上面定义的train()函数,来同时训练生成器和判别器。8 n! ?- y9 j- Y' I# F) X% `
. P. w3 p* u% `% d7 b7 e3 D1 H: n) b( P, D# b6 D0 p$ \
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。" Y- }, N/ X1 _& o, c
# J6 f8 w6 W* p3 x( O, L
% A6 d, T; w6 V; g
%%time* e+ y, U/ @6 V
train(train_dataset, EPOCHS) ) b1 u& I. Y* q1 ?- o2 k( s. o在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。) P4 W, R* C) V q: _4 q/ c9 `
' ^& W4 h% o, Y9 l4 p8 N6 H8 ?2 t
( z1 ~" W9 T- M( }& s' r+ _训练了15轮的效果: 8 `/ ]7 J. S# L) \3 L/ m/ ^0 s( N$ J/ Q, s9 [! Z" @
, Q; O% e) F A8 E Y
% ~* R$ `: x2 [, e& u
1 O, m5 W, y! f! g1 p" J: r ! s5 P' x$ B' C1 D* L 6 y2 o) B& f: t* ]; |. X* {训练了30轮的效果: 0 ^6 C v1 |) n! q. f' p9 B* ? 7 U& v5 Q6 w' H # u1 }- q$ l4 Y" n $ ]! [; d1 b+ p' Z o! p+ K ` h! [. z& V c
- D; @/ e4 S+ n4 D" G4 w. u" x! e" k, ~$ m! p
训练过程:$ Y# O, K6 {/ y6 Y% ^
8 m' k8 }: K; M4 ]) y* k3 A) O
4 h4 M! \; C* e7 I9 v( V' x( ?1 R# O
: L! l# u- p0 r& k
' _9 Y. e" m# x0 c! Z2 x8 U D* U% p J- t) a; i
恢复最新的检查点& X* a& m" b8 n) e! l4 L
, l3 E, X4 f S$ M1 m
! n$ L) h- A r* i. g
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))8 {0 H% R p Q i6 {' |* x; `% c, s
六、评估模型 * I- D# Z* |3 l5 ]5 ^这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。- @( }8 V, b, i$ s- @4 o
" i9 b. H7 ^/ w; y9 T) n- g1 M0 F% K, L
# 使用 epoch 数生成单张图片 ) M' n' q8 Q& b+ `5 K- k7 edef display_image(epoch_no): 3 k9 p& b" y; M( s- [! E return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))* Z6 L$ E \; K7 Z
2 I; D% F! r) C# W7 ~$ G
display_image(EPOCHS) # P3 `5 d) I8 @# n. a3 zanim_file = 'dcgan.gif'9 q) f' @# q9 I8 K" m3 ]" ^6 c
# f+ s4 f4 _1 \1 @$ m
with imageio.get_writer(anim_file, mode='I') as writer: * K3 Y# v# N/ t% l: k% o& x filenames = glob.glob('image*.png') " L& I4 @1 m) v$ v9 @. e filenames = sorted(filenames) ! @. ~" F0 o8 |3 N1 m( o2 P last = -13 w8 V8 R3 b* d4 B; r
for i,filename in enumerate(filenames):$ L7 G: z. ] V! p7 j4 S" G
frame = 2*(i**0.5) + ]+ Z" `- K8 h0 i; p& _ i0 G if round(frame) > round(last):2 {3 U0 C, P) M- f. S" s
last = frame 8 g. G% X5 z; ?3 k! b3 k8 [. u else: % H2 D W5 x2 }0 e" r6 J7 s/ A7 X continue: I9 i0 O+ Z9 p: {3 }0 V( \0 j
image = imageio.imread(filename) $ P6 b4 P- B% l0 i writer.append_data(image)3 p+ C" r) \& s" f/ r& x
image = imageio.imread(filename)4 |7 I- H$ K" Z; @) q8 N$ ?2 s9 |
writer.append_data(image) 6 k* v: V$ {" | v' W3 I' ^# k 9 a% Y5 S7 {2 Q, I# X, p/ Vimport IPython$ F6 J) y' L% v3 b5 g5 _
if IPython.version_info > (6,2,0,''):* k/ u+ ]( m2 r; B2 Z( g t0 j( {, w! h
display.Image(filename=anim_file): `7 _- i+ u% m
) z' F7 [: {* A$ X$ Z0 n% T9 k; _3 W. d5 l* Z
3 e$ s) W6 d E$ x