9 f; I* @# a( q2 f$ R; v( f0 `( b1 Y! l
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。4 M4 L. |+ k9 ~; x; u' ^
: ?+ l% ]! s* s U% H
% Z( @" u+ N& D+ _* y二、加载数据集 1 ~2 Q/ A- ?# ^( @! r; U; Y; O使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。 ! x9 n' R) p: g/ D' T( ^- h `5 @5 k) i, D1 A$ x5 Z. j* ~' t
1 b& B* E6 n2 k2 F
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()( \+ P/ h/ A; Y/ V F# T5 h. P
9 O# {8 y% T" x# g. Q& _5 f" Ytrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')4 q% `# I# }' y$ Y
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内( \ m7 y, X! s: L( l! E
a1 l1 S. |" f. X
BUFFER_SIZE = 60000 D" t1 s( K" X7 X6 e# ?& \
BATCH_SIZE = 2566 D$ j6 g/ u W0 r
5 Q' J2 N. W/ v. f% J% _# 批量化和打乱数据3 Z4 \3 d9 `) L* ^5 r2 M
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) $ m2 r+ w6 v1 i* D$ h* B# q* U三、创建模型5 |6 G/ K6 B' P; `, {. j' ]; M1 g
主要创建两个模型,一个是生成器,另一个是判别器。 , N& G1 @" f9 j9 E6 [1 ^ % b0 R) g# p# U& O/ @$ }' U" p) s% q0 d* c) K. ^6 C
3.1 生成器' M* ^$ R; W6 Y" F3 ^# _
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。 0 M, l/ S+ K+ D7 D$ K7 R- D8 \9 g! t ?) ?! W
, S' ?6 |7 C }: R8 z3 A
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。 4 d! D* v* n$ S A" N- }5 \ ) Q' r+ X3 J" i8 i% x% @ 6 `7 J z5 K# q1 A/ B8 f后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。 K- y" [1 ?7 }9 n. F- R. D
: i' v: r% L) j! F5 R4 e$ q$ ?0 ?" S7 F $ _3 Z3 D0 ^& F: O- a% C- L/ b5 H9 Idef make_generator_model():5 y$ T% j5 r3 K" a
model = tf.keras.Sequential() , d6 h% A T# \3 z3 z0 W- a" N! Q model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) 7 x+ j% g3 u8 z/ R6 V model.add(layers.BatchNormalization()) 5 H$ `# ~9 F! H* k; k1 S. H model.add(layers.LeakyReLU())4 B) D6 d( O& j; a, C5 T" i6 U
$ u' S- m! r3 v- g
model.add(layers.Reshape((7, 7, 256))) 7 _- e: c B) x/ \9 _3 i assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制 8 f3 z7 e6 T7 B I0 F3 A; X 5 Q u- W* H6 { model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) 0 L% C' z ^, Q& N assert model.output_shape == (None, 7, 7, 128), h( ?0 {9 r- f' ]; v: Z, C
model.add(layers.BatchNormalization())' C4 o5 c0 K* l- c8 C; j5 K
model.add(layers.LeakyReLU())" ^5 Q& [* G' M% T
; y( ]4 ~. G, p# b' G" ^% b9 v4 B
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)). f8 T2 h/ h4 M' J
assert model.output_shape == (None, 14, 14, 64)7 j; K g; C) @
model.add(layers.BatchNormalization()) . e9 w! H' M5 {8 w+ a9 U model.add(layers.LeakyReLU()) : p5 Y0 r' ]9 W& h5 d* G : r& r1 J* @6 x S9 B% x model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) & V% s6 K3 a& Z- Q3 R$ H. P assert model.output_shape == (None, 28, 28, 1) + n6 d; M* I! X/ x , m: I" O9 v. q- _, O4 {* `$ n return model8 G7 q) f i0 ]7 F9 O& o
用tf.keras.utils.plot_model( ),看一下模型结构0 \5 g& }2 @& Q2 k, W
8 b$ \; q, P4 G5 b: Q. j* Q0 K! `/ n2 Q8 i/ ]
8 |; ?) c5 A/ E, f 4 \7 q% U8 y2 M! H7 |3 A3 v% p4 a& z7 y/ y2 A
用summary(),看一下模型结构和参数 ) _8 _1 C# L, L. h- h5 g* P+ Z% \2 r" b: e8 D
8 x: R2 V! C0 R5 T
' | W9 Q- A) ^$ z
' A% g0 p% v( J1 X3 D. y! Q& K; J% ? K0 Q8 H* q
1 \) ^8 h9 B- U4 _% s/ b使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。8 I+ f/ q/ R9 I' r5 s0 M) Y
8 Z) c* L! R* C7 v$ h2 o0 @1 f" o! M7 o$ \1 e; b/ p
generator = make_generator_model()3 A- O, w- }4 `0 Y; i
% K0 r$ J/ a) D$ _+ pnoise = tf.random.normal([1, 100])- A- q5 s; C) e. y% {9 ~/ v) Q5 ~
generated_image = generator(noise, training=False)! F8 k9 Z) @" x; i8 x/ C8 G
& |2 y( }( @ d# C* c# k
plt.imshow(generated_image[0, :, :, 0], cmap='gray')/ N/ ~: ?5 `4 o
) s" ?& i5 R$ C9 F7 i }
, s- f- G$ d$ H4 ~4 m
9 x0 U/ c* c% z3 v" V5 I/ u) D5 ]9 C
3.1 判别器) f& [/ x6 M2 c' N. N, [$ p
判别器是基于 CNN卷积神经网络 的图片分类器。/ X8 P( }9 {. W8 @. z3 D P- j
6 b7 k) p9 Q4 Z6 y! i% }, {% U- h 6 \! D, J, F" M+ edef make_discriminator_model(): / l6 S \" P8 C2 z model = tf.keras.Sequential() " G$ ]. `5 Z; e model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', 1 f/ K6 F, G% d& y+ a# E input_shape=[28, 28, 1])) 2 n( J7 J& u; b4 C' ] model.add(layers.LeakyReLU()) 8 `1 b0 R" {2 Q# l7 V2 r7 d/ e model.add(layers.Dropout(0.3)), o' z) B: C2 z7 V l/ S
: }& x3 G+ O+ }1 q
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) 4 y- n; X8 ~" H. S p; U model.add(layers.LeakyReLU())% k. S+ J3 t3 N
model.add(layers.Dropout(0.3))6 {! F% c& I! R" h/ i6 I
( j7 V5 m; m8 j. |6 _" w3 K
model.add(layers.Flatten())9 i E! o( H; Y: Q
model.add(layers.Dense(1)) # W6 Z4 L$ g6 ?! V) G% z+ V* ?0 u * l) h- E3 N5 v3 Q& D6 e. ]) R0 v return model1 Y) ~. f( {& V& E5 g
用tf.keras.utils.plot_model( ),看一下模型结构 ( r3 F/ g, ~- S, H0 E: \" E9 C5 \% g4 W4 [* ^
$ Y6 f$ G- }: N$ w2 K
, P% R4 L1 b6 f2 j. {2 O, Q3 U1 D- E, s7 [4 i+ H. G) i
, E' I7 k+ Z: ~+ ~
]8 B1 Y$ H _2 n. s2 y; \ t
用summary(),看一下模型结构和参数4 x- J1 x. I6 o+ M
; q; P7 v7 ?: C1 S, G. |4 J) l4 x1 y. K
- S) b* u: N- K2 |: w
6 ?, s0 Y% A L
; q* O5 f, z6 g' Z$ E6 [
# m! a# n& m& K H" y. A6 C6 h四、定义损失函数和优化器 ! B; W( v* B3 Y/ f由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。& Z) N* m a4 J& }4 f
7 P) V; F! F ` x2 y/ v$ n+ x ' C, Y6 i5 f1 g首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 # s' u! Y/ K9 W( r 3 n) o \+ g/ H * b+ V( o0 Y$ f# 该方法返回计算交叉熵损失的辅助函数 5 Y% N# O2 X4 o9 ^% b- ocross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True); Q( n4 ?$ b- v3 h
4.1 生成器的损失和优化器1 u7 s. g# s. Q/ I. l6 g5 `
1)生成器损失 & i; S) E) `7 W- `4 v2 F% B- n1 y5 }& Q, v* N, j* ], M6 Q0 ]
* ~1 d' p4 `. C6 c( {
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。- Z( G6 f. |; P7 ]* ~
- O; y8 H2 b6 E% X2 p. F , p7 R8 h {1 _0 X/ k7 |- B3 d这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。( t+ ~/ `6 H$ Q7 M4 s
# q/ n$ s5 H* {' `: P( E" Y, ^+ ~7 F3 p; a3 O* T
def generator_loss(fake_output):/ G. X7 k0 i( n, d, m
return cross_entropy(tf.ones_like(fake_output), fake_output) : ^3 l: h: K* {2)生成器优化器 ! |* d+ M r I * N$ U) w" L$ a7 y& L5 m! w$ O$ X6 J/ ]) C9 Y' D
generator_optimizer = tf.keras.optimizers.Adam(1e-4)3 D0 c+ o8 e/ {' ]0 E/ h
4.2 判别器的损失和优化器 # ~ O6 ]9 i( o" b1)判别器损失 ! B& x2 Q) ~! q. O6 n$ U' | ( j; r) u, N$ p# ~0 [$ x2 f& F# p# L J3 Z# B' ~# O7 J
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。5 E& S- O$ J' F. `5 J
8 b) T/ N6 {: M6 L i+ r; V3 V
7 U/ T7 w/ x/ _5 k2 n* j
def discriminator_loss(real_output, fake_output):& F$ [ Y6 g; }; m
real_loss = cross_entropy(tf.ones_like(real_output), real_output)# ~9 E2 e( S/ q; y) @
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)9 g, f* U% r# g
total_loss = real_loss + fake_loss7 O; g, Y2 W3 E3 w$ |
return total_loss* Q" P& ~) P3 s$ W0 o: `
2)判别器优化器! g. a% o6 _- l# M0 Q
$ B% D. `! V4 B
7 J' [6 s$ c7 |" @7 z! [9 wdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)4 F- L4 p5 m1 v4 g- o! I
五、训练模型 * W4 A0 B0 d# K5.1 保存检查点8 W5 Z( R9 X1 i S+ [# s
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。& x. `8 g# R% J1 i& L1 C$ p
# A/ [6 _1 W( T& }
; n8 [8 h# q v+ `
checkpoint_dir = './training_checkpoints'" o% _- M1 ~5 U2 g
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") ) [0 n# i; T9 X7 e4 A. h0 \, Lcheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,0 \! l& O! C0 p. S) J% G- d
discriminator_optimizer=discriminator_optimizer, 8 ~0 Q# e$ O! @4 ~% i generator=generator, 2 c6 h% B/ G: R5 I* @! N1 i discriminator=discriminator) 0 e; R" `: b. O0 {0 M8 W4 o5.2 定义训练过程 / A. E; _9 p( v8 F6 JEPOCHS = 50# \8 B% G0 h; n @
noise_dim = 100 & k7 e/ [2 J5 j* ]9 Anum_examples_to_generate = 16 " u" } Z1 l6 a 5 ]4 u$ X w7 p) n
: u4 B, D4 P9 p; d' d2 Z% g$ j
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)+ r4 h5 c, N2 o
seed = tf.random.normal([num_examples_to_generate, noise_dim])2 E- r2 s5 ^9 |8 G+ V
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。6 G, u, J- B9 M3 s
* y% W2 _, V% Q5 }. t. v4 X ?/ @
3 T, k( ], u; ~. _8 I判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。 * {) }, t/ H4 Y) I' }% m( | 5 H9 u1 T. M1 Q( a, y " @6 Z# v( b% }! O7 g两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。% m; ~3 C( Q5 t, [8 V4 v
/ t ~ N" g9 I. P2 z! A% x3 J; n5 G0 Y
# 注意 `tf.function` 的使用, g& M5 E+ b1 ?% R
# 该注解使函数被“编译” U9 \. I ?5 C, N! y+ p
@tf.function0 Q! X, g' P+ T! ]2 z
def train_step(images): 0 f: g) s# G, W. m! q# Z6 f' Q noise = tf.random.normal([BATCH_SIZE, noise_dim]) " ~+ k. H6 T3 L . z0 y: s& s- r l+ a( O ^ with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:! b4 P" v8 Q" n
generated_images = generator(noise, training=True)% K L# g) g7 u9 C. B u+ N
; {" e% A3 { ?1 c3 l3 [9 S real_output = discriminator(images, training=True), N: l6 i* X6 W6 t8 u# X6 I& T
fake_output = discriminator(generated_images, training=True) 1 w1 J) A2 d; x. d: T 3 b. f% L t# _; F# P: w( v! c
gen_loss = generator_loss(fake_output)% I9 F6 ^# G A. I3 q5 C) C. \! U
disc_loss = discriminator_loss(real_output, fake_output)( p* Z6 B6 g& E. c- f
/ |6 O" E) e q gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 0 @" @" T0 ~- X" X, r0 A7 Z gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)& l8 r- x8 R L Y1 T* L
8 U+ m+ N1 S8 @& Z4 w generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) 3 S/ A" v1 {) S" {" ^ discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) , V3 G& L, g: |, x 7 K/ K. N1 O, Z4 h% ~
def train(dataset, epochs):% P" x; ]4 d; f8 M( @& G
for epoch in range(epochs): + p: m0 _) t- w4 l start = time.time() & O9 Z/ q1 W' b1 K }8 `( ]1 t3 d: ? $ X# ~( G$ P2 S# f' a for image_batch in dataset:3 y' N4 u7 i( z. d( L3 e7 m
train_step(image_batch) : ]. b0 p) T7 D" w' W& m8 O* }1 f1 w 6 r) \. r8 i1 l2 ?; y2 l x5 T, q
# 继续进行时为 GIF 生成图像! l8 ~; m8 m$ D- _' C; q+ C* U; x
display.clear_output(wait=True)! w' n$ d" |+ k% S3 D4 m+ X
generate_and_save_images(generator, : Q" [1 a& h% H0 t( _ epoch + 1,3 L& v2 Z. d; x6 I
seed) & I! E# ]1 W' S: q1 c' W9 Q 6 S! W C" L$ z T+ V
# 每 15 个 epoch 保存一次模型5 y4 J1 s) C. K3 m0 P" W
if (epoch + 1) % 15 == 0: ; q$ S l. K1 J; Q" ` T4 J8 z checkpoint.save(file_prefix = checkpoint_prefix)- w+ t0 o0 J, A8 F5 z
9 A+ z- e4 D. y+ p! w( j' m print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) ( R: L$ h# @8 A; v1 c; u7 G / |5 }- U0 h( N, G4 o; j" [
# 最后一个 epoch 结束后生成图片 . @6 X$ U3 E/ o, d display.clear_output(wait=True) 4 c0 Z5 B/ I4 H6 h* B1 I* g! e5 z generate_and_save_images(generator,7 Q: y# q k/ k! {2 `
epochs, : V! F. @. {6 B9 @, [$ ^; y7 R seed) " Q% j- A0 ?# y( F 8 C2 v1 O3 O; k, w R. f& z" Q# 生成与保存图片 2 o5 F3 V4 c* i/ @def generate_and_save_images(model, epoch, test_input):" e, f& F: h1 y% b) i: @3 ~
# 注意 training` 设定为 False" t, Q4 d, d' F6 j: g$ {
# 因此,所有层都在推理模式下运行(batchnorm)。6 p, ?! z& ]# o" g
predictions = model(test_input, training=False) ) N' [8 D' u9 G4 Z+ ` ; @3 @( g9 q" r3 S9 _3 l6 J
fig = plt.figure(figsize=(4,4)) 3 |, {& p* w. `0 p5 B T 7 |1 V% h$ K% }2 x7 Y n8 v2 j
for i in range(predictions.shape[0]):( A* R( y+ u f( @2 Y8 |# a
plt.subplot(4, 4, i+1) 2 j% h2 E6 j$ a- M3 _ plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') 9 o& W7 a+ a8 [; L% N4 E* N plt.axis('off') 3 e8 A1 `& K5 Y. Z # B2 v4 p' s9 v! u: C* T1 ]; a plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))3 j% g! o, j$ N
plt.show() r1 @+ V+ U: Q- u& h
5.3 训练模型 - R! h. P+ j5 l4 e. z调用上面定义的train()函数,来同时训练生成器和判别器。 9 F5 S1 l' Q, @( u ^* v4 M1 |. S" @ ?. [: f# z2 x$ Y0 h
F) t# q4 l6 [; E
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。1 {7 |5 E4 }0 G9 n5 _1 Z4 Y3 l
+ Z* a. i( c! ]6 _ 7 w" l0 r( d, p$ h9 n%%time+ i1 a" _4 Y4 a6 L3 K# C
train(train_dataset, EPOCHS)2 a: N5 K4 ?7 f9 E. |2 y- S
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。" J* Y) s$ p7 X5 ~: D" H' W
3 G, G* A: w; W* Z% \! V8 ^6 } H! u
2 ^8 N) X% S6 |训练了15轮的效果:4 S( V ]7 \3 r1 w
6 b; D. i- N& l3 K8 p0 d* K
: L: P* y% {- v$ [9 ^0 O( v# a5 _" K# J2 ]: C
4 z9 P7 a, z5 T( Y+ W* L! r