数学建模社区-数学中国

标题: 深度卷积生成对抗网络DCGAN——生成手写数字图片 [打印本页]

作者: 杨利霞    时间: 2021-6-28 11:54
标题: 深度卷积生成对抗网络DCGAN——生成手写数字图片

. y' p2 h+ {0 ]" D# {深度卷积生成对抗网络DCGAN——生成手写数字图片' \8 h" D% K, _' F* c( m3 e$ V
前言, v* O3 K; F  X. U/ Z9 m/ k
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。3 X/ j1 w4 X0 O3 B1 F9 j3 t) I# O4 H
! j; S8 x' N& r
6 u+ P0 Q0 i  l: y
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
! H7 `5 D% }: Z% k  z& u& T7 e# q1 M* ^; l) Q* k
9 j$ j. p  c' B2 D- `4 c& C8 a
# 用于生成 GIF 图片0 q" ^" b& P# C
pip install -q imageio
+ r( t  |# m; e2 ~, M3 o目录
$ W2 D: c$ H$ X1 h* `: q' k) H# @1 X
- u3 x) `$ H# o& D5 f) |
前言
: a7 }5 B4 Q, n4 v( I; G
: ^! Y# ?& i$ s  {  g/ K8 `
' t' Y4 F/ v& s+ W% t3 `: ~' R4 z) t7 c
一、什么是生成对抗网络?* W7 ^- {; F, `) C, G& p$ v6 @3 F

) k; D2 O! K5 R* u/ w
% J# O9 }  e9 ]
二、加载数据集8 y2 V. C1 R3 `  t

" y! Z$ R- B$ k
4 d4 R8 D9 @8 K, Z- }1 {
三、创建模型2 ]6 Z' {& `( h4 T

" d6 k/ |& H* G2 L+ _3 _- A
+ `5 y/ _9 D+ x- h9 @% N
3.1 生成器/ ~! `7 z+ S( ?2 L, S  N
0 m# {! ^" R; M1 }$ ~
. \; R8 {) V! d+ ]# g
3.1 判别器- ]( ]0 i/ b, g
3 v" I) N3 E: e% W

7 v, B! ?4 Y- ]' W1 D, N2 ^  \四、定义损失函数和优化器
' a8 W" g7 s- f8 P2 v8 {% J8 |0 @. h% c0 _3 H. E( w4 v

; {6 U7 X! A7 H7 X4 X4.1 生成器的损失和优化器4 j9 j, ]0 ?, R4 a1 i
. Z& @8 v9 i/ R  h

8 C9 ~2 `7 m5 n- K4.2 判别器的损失和优化器1 c+ U! p1 k& P+ @" F! G9 Q6 p

4 }% d+ `- j" ]5 \

9 Z* N* r3 h5 M五、训练模型( `' {1 T0 G2 t7 v/ q
" }5 d0 P  u; W" p* r

- Q' [& \: p* \; h! K5 {4 r# B5.1 保存检查点
' A4 J2 P6 S, h1 j! p
, J. h5 b7 Z$ O5 y

3 h4 a2 [% U/ M! |8 K5.2 定义训练过程
. L2 D3 }. Y; L, n, d% x3 y5 |
. `1 M4 A. y( ?! c1 F/ s! h: M
# k! ?" J$ o8 P: E# X$ G& c3 ^
5.3 训练模型5 g* z# R* x. X, S

7 ~7 e5 d6 V6 J9 Z  z7 d7 J
; U+ S) Y+ ^+ e$ ], m+ j* j
六、评估模型
. R( a  {( s* @; _" l% g( C
8 S) A3 C7 C% \9 A* n/ o; L4 g

5 ]5 E: ~- g# l, x一、什么是生成对抗网络?6 v6 q8 C; \) c3 F
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。; l" A2 g; d$ ~' o. g( r

- I5 e( x- g5 ]  [2 S
$ F9 ^8 O' Q+ k6 n, W
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。' A8 n/ @9 r. L' v( ^: r) b0 h
4 q9 k6 t7 x) X* c. d( M- p0 g

  X7 l& A/ }5 l! p判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。+ \% i6 N  Q+ d" I5 Y7 F
! ^! ]& t8 h7 \
! }- Y7 {/ ^+ J/ \
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
& Q; u' @% D% {0 x$ {  K# }$ f" B7 |( P* w( d& ~% a+ \
4 d. _- @- Z6 N5 E5 e: f" i
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。/ r8 b7 ~+ B# G3 v7 U

5 D3 {# [+ v, Z

2 X) V' l/ z8 x$ W本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
, A. N+ M# k$ h' c. D4 y4 I, Q( T& C' l2 x

* h! R0 g/ R. u: _7 O# `: p二、加载数据集
) H6 a! m7 k  }+ ^2 r使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。! d5 d( L% m( r3 z" G# S* {+ q: t- D9 P

3 }. m! ~' Z4 z! [& X1 C- z# @
& q* F5 L: ^' V9 H0 ]% k
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()8 y5 q# P/ s0 T
' y$ h. ]; c2 o/ g" ?$ |7 L: w
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')4 F( S5 _/ H: E
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
. F* S) d/ o% N
8 x& ?( O6 b1 Y- F) r  hBUFFER_SIZE = 60000
. a, r( @+ b- X3 n/ F& FBATCH_SIZE = 256) P. S: I; b+ i& S; I* b2 c3 S" v
0 f* o4 S2 p8 e4 c9 i8 k
# 批量化和打乱数据( l: b; |" ]* X  r+ ?6 N
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
4 g( E! {" O, ~, M0 }1 {$ X三、创建模型* K" P+ O* A* n
主要创建两个模型,一个是生成器,另一个是判别器。: F5 j! h7 l+ k. t, w: b# P

" Q3 ?# U# W/ r, j3 M
# f' t' Y9 c# v4 f" _1 _
3.1 生成器
  O; J7 f) X/ E( r7 l* q9 a0 J生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
# N4 q6 ]% H, o& b; ^$ j3 H; [3 ~; S; W3 \
* w6 c9 T1 V! I5 L
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
- x5 y- T- w: ]0 H" \4 {" }* L" L2 X$ u) Q

- S' C8 [) }, q" A后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
; S, ]+ X% V. o' B: y) s6 b- w* U& _! K9 I+ o

( d" ~) C5 L, M0 x: H" Bdef make_generator_model():
, @7 J3 G7 d5 n' ^' A  H/ Q    model = tf.keras.Sequential()7 Q  V7 ?+ }4 l5 ^* q) T/ `
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
+ O7 l$ m4 b1 ]    model.add(layers.BatchNormalization())
* S  r2 h2 N7 m# g    model.add(layers.LeakyReLU())
0 s0 S. C7 @/ T+ y1 } / }# w& K% [4 m' o7 h
    model.add(layers.Reshape((7, 7, 256)))6 U7 Z) M+ e6 `4 W+ ~' H
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制6 w* o# b9 ?7 @3 Y" v+ O

8 i& Z) W% u  J- H. I: A  a; l    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))( L8 e0 h" z- m4 t" _% @) u
    assert model.output_shape == (None, 7, 7, 128)
9 }1 {) ^( x/ |: t3 O1 A3 O    model.add(layers.BatchNormalization())
( J( n; C$ {, @    model.add(layers.LeakyReLU())$ h% B3 H) s2 E7 X& F/ V6 Q
5 m* V# T" u" f  e
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)); ^' M6 j" @8 ~
    assert model.output_shape == (None, 14, 14, 64)! T7 u3 J" v. T, x6 G
    model.add(layers.BatchNormalization())
8 c; |, F- @' C0 ^% E    model.add(layers.LeakyReLU())
3 ?( ?2 Z. S  v1 c7 a) \ 6 v% N/ R5 |6 n1 Z' F) K
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))3 C+ O8 c3 q2 e3 A
    assert model.output_shape == (None, 28, 28, 1)
- E' }- |. B: q2 E' C # |" A$ W. m0 p9 e+ ]% c4 q
    return model
" {3 B$ T6 H6 y! S$ h+ ~用tf.keras.utils.plot_model( ),看一下模型结构
6 o) S6 N) Q: r, X( h) K6 V% X3 Y
9 e& e4 f$ z3 d9 j6 _6 c; L9 L
1 D2 b! c3 H3 Q) P3 A* }: k, C
2 q- j5 D! z$ a! k# J

+ W9 w  l2 M# y  Y: C% {

1 c; Z; C% {4 f5 u! v0 Z0 X& z用summary(),看一下模型结构和参数
# O. W. J' L6 o( B2 [( [- H
) h, X- x, ^+ P/ B
6 ]3 e+ s6 L9 [, \" g/ f! E3 c

* S  ^  z! F/ e5 O+ \6 c
( w/ N2 C& s6 W' L; q$ ]! ?) U
1 O# d: ?  p" y* H3 e
$ r5 S1 }" |8 t7 m% C
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。# e3 ^" ?5 ~4 U8 d  C
5 X5 ?* d$ E( f/ }. l- G5 b

5 n" B1 ?+ a# F8 g6 ^7 Qgenerator = make_generator_model()% v! d7 n6 R! Z
, ^- c4 |9 ^& ^3 y* c$ J+ I; k
noise = tf.random.normal([1, 100])1 s) P: G* x% u! {6 R  x+ k1 _
generated_image = generator(noise, training=False)2 {. y  J! M! L0 W9 d

' h4 h$ d7 Y8 Q+ o0 G! O+ Pplt.imshow(generated_image[0, :, :, 0], cmap='gray')
4 C: `2 _( {4 T0 \7 b& @' C# g, o9 r$ ]) T+ r
& ~1 s: z7 [/ o+ d

9 Y, W: _+ M! B3 }2 M% a: v
" F( K* K% p: Y: i# n
3.1 判别器" d7 O, E! ~- g2 k' x2 |
判别器是基于 CNN卷积神经网络 的图片分类器。' N0 c$ O/ Q( \9 Q6 N7 q

! h' f. q; p- P4 z
  \- Z& W: V2 K3 C9 I6 _* P
def make_discriminator_model():
2 ]( L8 O1 F/ [    model = tf.keras.Sequential()
4 T$ p& s, \6 l& r9 {* h* R    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',. z, {1 _0 k5 F
                                     input_shape=[28, 28, 1]))! s/ `* V- O/ v* d! d, ?/ K: @
    model.add(layers.LeakyReLU())
" k7 f7 G6 M2 B: R, J2 d+ y" z    model.add(layers.Dropout(0.3))
0 F0 l/ S! ~# H0 D4 [7 N. z , z' Z! ]. B9 @0 n2 Y% _
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
. L1 [5 Q, g" T7 n1 T    model.add(layers.LeakyReLU())  c4 s6 R4 W% Y) l# f6 @
    model.add(layers.Dropout(0.3))8 I% S9 C( d; s; d; P

6 `7 o! t5 |# t7 r9 A& X- k2 k    model.add(layers.Flatten())
' F6 X6 d5 b, K    model.add(layers.Dense(1)); T( }% [: t) T3 m1 Y

: t. I' i7 l3 W0 x+ `    return model5 n% p5 G# n; R" ^! B4 m# }
用tf.keras.utils.plot_model( ),看一下模型结构
* u' N; i: D1 x  z) a5 [+ a8 \2 D; r) o' Y

' t' E: O* l) A
* t" I$ }% s% C$ j7 j, c# Q7 j# U
8 H/ u0 v9 |$ Y! C  E, q
2 P7 x, e9 L/ K9 ^) V' F+ r6 t
9 |) W0 T  u9 r/ [
用summary(),看一下模型结构和参数$ S) ^9 z. U( |" {! G& O7 z
. {' o( [& f9 L' i, P
, g- ]* a5 i6 Y5 V% X  E8 f
/ l* I0 k: \& r
8 e$ j/ y6 _2 u& S& U

, q5 S! F7 {$ b5 C

8 q6 H: h( ]& m2 X$ _四、定义损失函数和优化器1 V8 O5 D: S5 G! z
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。3 g2 e6 Z; r% ^7 L
! M4 D+ G/ f- q8 w
) x$ J4 X5 h. z# e
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。5 w2 J# |* G& y, S7 e% I* P% T

& H5 P1 a/ S2 k* S

# |' [+ H5 D) ]7 c& _* j! X, z# 该方法返回计算交叉熵损失的辅助函数, L7 u7 [9 l7 t
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True): g' }) t9 Z0 K+ g. Z
4.1 生成器的损失和优化器
. E7 A: a4 o# D" |1)生成器损失
6 D$ ^' e# M1 T7 K& K, T, y* j1 B1 V6 J) v. q) U

" p+ ]3 x, a0 V7 t3 t生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
) i+ k  N7 R* _/ u, s5 o# `6 R" E4 y3 X6 A
6 v, q& \6 W4 ?' F3 b
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
. t; y: V% J$ A) R0 |. q. a5 ^" L; c/ T% G9 F' O- f

( ?4 \' c$ D9 ]/ [0 idef generator_loss(fake_output):1 Z: c" f( I5 h: a! }9 M
    return cross_entropy(tf.ones_like(fake_output), fake_output)- g$ V) P/ ?9 F7 F# }
2)生成器优化器
4 G; U* m+ d, O2 f2 b- s# G+ H' Y/ ~8 V7 ?3 o1 q3 A

, B/ g) b6 J) c; L8 }* [generator_optimizer = tf.keras.optimizers.Adam(1e-4); Q) M0 x# g: U5 I
4.2 判别器的损失和优化器
- w/ ^$ @/ \  H: B1)判别器损失$ n3 r! `7 I* L$ h! w3 Q/ V8 O5 |+ X& q& a

) I% N3 Q5 q$ F2 Y# h! @+ p
! y6 F) e0 O6 e& s& K9 b2 L% F
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
2 O4 ?7 ^7 u) P9 q: K) |( a- m! Q+ m1 l! T0 O6 l9 W
2 P7 B( ~# `% I' L
def discriminator_loss(real_output, fake_output):
8 F' W0 z) J4 d" S( o7 C    real_loss = cross_entropy(tf.ones_like(real_output), real_output)( M* G4 R9 F; N
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)& O( D. D& ~$ Q7 i2 [
    total_loss = real_loss + fake_loss# N# ]# w. B8 I' y3 b/ Z3 U
    return total_loss
, q5 O! M/ ^. R4 k* Q0 ~; X9 t2)判别器优化器  E5 d( i- G  S' N3 @6 k* p% c
  m* v' A( a0 U" J4 p  ^- @7 o
  J# a4 [/ K/ H) P. {; z6 D
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
! y2 f) a4 a0 I8 L, v3 ^5 l+ |! A五、训练模型+ o7 Z1 l" X! h) O9 `7 f" ?
5.1 保存检查点* ?* b! Q# g2 L
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。6 Y8 s: x7 a3 k. Z% ^0 g& }

; g$ {4 L/ A* i' C  d

- G$ I9 |" _; Z+ \  X) p: P, Echeckpoint_dir = './training_checkpoints'6 m) |# S( q; D; W' I* f
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
2 J: v$ b' v1 ?checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,' N5 Q2 r) m& r$ B
                                 discriminator_optimizer=discriminator_optimizer,, m* m, u1 e3 @! @
                                 generator=generator,
- [& X9 l2 |3 F/ S" U9 E4 v                                 discriminator=discriminator)
# r' m" X: z+ O! k/ I$ g5.2 定义训练过程; O- G- c4 R4 g4 c" E) W$ T
EPOCHS = 506 p& Z# n5 ^% |& G5 w
noise_dim = 100
! I6 R1 R1 V% o5 P$ D; _# hnum_examples_to_generate = 16) f& V( J1 O: K8 ^3 |

5 x" C5 e  V; q2 R1 l* ^$ c- J+ i
3 I" l( `+ U# G/ k7 X* }) F9 @. i# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
5 }* ^1 p- {  W* b6 o2 pseed = tf.random.normal([num_examples_to_generate, noise_dim])7 t- Q; y3 @3 x( W- n' @
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。& [  e  W/ X2 _# J7 w

' ^+ l6 S" y8 b/ p3 B1 k# s
1 W! C' j+ S; W- X: V- h' F
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。3 R6 M" U( m% i& s5 H+ d
7 B5 M( F- {8 ?3 c# s5 f7 B: w
/ G0 `. u/ @3 b8 g2 {1 P
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。1 [) K6 E. `4 d+ g" \# C; L

- E) L; g8 _, b0 N, R
$ D9 L2 Y1 D  J# F, ]$ |. d
# 注意 `tf.function` 的使用
! G2 }4 d" q* G% |" Y& |/ L8 j# 该注解使函数被“编译”
* ?/ W+ z9 t, [@tf.function
: u0 X% r9 ~5 |def train_step(images):$ b$ h* M3 }( n2 n3 o
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
6 P2 W& Z5 V4 [6 U4 [2 y' s % o$ u3 p8 n8 B, N3 ]3 q. t
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:$ J4 y! a; u! I+ T* b2 _
      generated_images = generator(noise, training=True)+ U/ e' M5 k% H. X# d* k* l( |) }
; U5 @7 ?3 w  A7 W5 J# c* d3 u
      real_output = discriminator(images, training=True)
" ?8 z5 ]- K' W4 o; ~      fake_output = discriminator(generated_images, training=True)
  M3 b! H! J: P2 B- `1 \* ]! y
' D% e: B7 o$ U( S8 P      gen_loss = generator_loss(fake_output)4 o- E( L) k$ f* a1 G- h
      disc_loss = discriminator_loss(real_output, fake_output)
+ a" j0 g" H4 t
1 c4 ]& E, B# V% b9 ]    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
* w( E2 }& K6 F& H    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
+ Y2 A5 \: K$ K; a" F- p
& Q% A6 u- h& g$ i    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
; e& e( G' F6 k) y* L2 D    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
" S8 b, Q1 k; v: O6 g+ e + F/ C! x0 n" g
def train(dataset, epochs):; ^9 \; h7 F4 U- @6 R
  for epoch in range(epochs):3 Q5 G$ @5 t# r+ {* q0 X& g3 I- F
    start = time.time(): K5 `' i- x; J7 s; R( X
6 ?) r1 e$ G! u! d" S
    for image_batch in dataset:
; t  S  K+ i! P  i! B) s6 B      train_step(image_batch)  F; K2 H0 I# [9 z  i3 P4 t
2 N$ A/ k! Z6 R! b  {& b
    # 继续进行时为 GIF 生成图像; w8 O0 S2 q+ {' o$ c2 j1 W3 i) m
    display.clear_output(wait=True)
8 Q* t+ T. v1 n2 a    generate_and_save_images(generator,! L" n0 p8 P) \* i9 o$ v
                             epoch + 1,7 W+ ^$ A" R. j2 y' V  q+ k
                             seed)
: C! f2 l: Z9 ]0 [# |9 g + t0 ^+ t# F" K2 a( t2 U% _
    # 每 15 个 epoch 保存一次模型0 r, c" {7 F9 E8 o: Z" u  G
    if (epoch + 1) % 15 == 0:# A" ^- }8 v( v6 H$ |
      checkpoint.save(file_prefix = checkpoint_prefix)+ [, K+ p8 Y1 H/ _0 s
. D2 o' I  {! q6 S9 K9 u5 Z
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))3 N1 |% a8 t: ]. \; R( R/ _5 W
# Z5 c' {' G- ^0 U" D
  # 最后一个 epoch 结束后生成图片
6 V) U. `. h8 q- c  display.clear_output(wait=True)! W* s/ v& T3 Q4 h
  generate_and_save_images(generator,. B- V' D1 J5 i. e
                           epochs,
: G+ s9 m, y1 f1 Q+ |: E: K                           seed)
9 L( w, c* ~6 R ; K8 |" Q3 N( f$ c. F  e4 U
# 生成与保存图片& J1 d# {  l# q3 w
def generate_and_save_images(model, epoch, test_input):5 {6 |0 [! d" _" ^" K0 y
  # 注意 training` 设定为 False1 u( [& h+ Q6 X* ^1 {
  # 因此,所有层都在推理模式下运行(batchnorm)。
$ I7 @8 ~0 k% i2 |% b0 c4 q  predictions = model(test_input, training=False)) y9 J" ?2 G: A4 z  V, W7 s4 i1 ^
3 Q* N2 e3 J6 S7 ?  [- k1 p
  fig = plt.figure(figsize=(4,4))5 a% _5 r8 ~9 R; w7 h+ [: N+ T

/ Y- U1 o3 u, y( S) ~  for i in range(predictions.shape[0]):1 x4 u7 B. u* r# p: `) d
      plt.subplot(4, 4, i+1)
" x# D. i$ n& w4 b# Q      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')% ]* k) v* Z) A2 {8 x8 f' J, G. {
      plt.axis('off')2 v1 v- v* b" d2 n" Z. i

/ r; a& t# F2 `2 `$ {4 L8 _9 |  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
+ n! \! G4 P  w0 B* N  plt.show()) s) }# ?' R( P3 N7 m
5.3 训练模型
7 \3 V2 R) ?( T1 e, j调用上面定义的train()函数,来同时训练生成器和判别器。
4 D! k) [# \" a$ s% E, Q3 p2 W; L- {* H; [- c' V$ O9 x" E

1 y- X9 |! v% Z+ w/ B注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
. H2 x) a: O% A! C  F  X! H3 a" M: P8 E
# K6 U2 \1 r& M5 T5 D
%%time
2 f/ y" G! p+ ~6 itrain(train_dataset, EPOCHS)( q$ C& Y( y0 z/ R
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
9 D$ l8 N; S* A+ }
0 q* J; g. g! m- j  R7 u

) A' z) P; h- P4 q  S训练了15轮的效果:
$ ?! E5 e: b% A% W6 V! t! `3 |% O$ w' j
1 {& E- I% N- v& P& W

4 s$ V  O6 ?- w& l

9 S) C, Q5 c, `, f, G( o  u# u9 `7 s  c' G4 b( R+ S# l

. W) H! m' C- N1 N0 z2 j训练了30轮的效果:
& p7 i7 J# c- J1 F
, i" P* e, \$ R# h0 D  O7 k9 g+ M
: [' T5 L* S- F3 B3 g

$ ^+ o) q, f2 Y! e& ]7 k+ |

# M6 C3 F. G7 q
9 ~2 O" O# S* |) ^/ b
) b9 h4 r8 L$ W! A( f( Q& o
训练过程:
+ Y9 o9 h5 w1 Q9 h9 Z- Z' D' c0 D1 C5 f  ]( X& k6 l, j
9 D/ c1 s5 r9 h+ u
0 ^" @/ z, @2 I9 B, f. u/ U; T3 O
& m- R1 Z% S' \- x9 a

) {' w; L1 w0 g: T, O
# N' }; T( d+ y2 W
恢复最新的检查点9 w4 |4 X1 q. W. v" l6 G! r

! H0 q: S6 R0 Y6 P* V

2 u% |8 O4 m. F1 Y3 m4 ]checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))/ C+ B4 l( Y, I& n( w7 Q7 a+ B
六、评估模型& N( u7 ]+ ]  D3 A
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。0 I$ F6 S9 n5 G- U

* u& v+ ?6 i4 M1 v. u

/ w- O  E4 \0 n3 g# 使用 epoch 数生成单张图片. z/ _$ n) Z, q# I5 V' l  _
def display_image(epoch_no):
5 h% q) f9 s# D, U% ~/ c& Q  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)); _) q4 {+ s& \5 R
* L1 ]: s# L) m# z* o- G
display_image(EPOCHS)
6 t9 i2 `. z2 n4 ]. t+ l" eanim_file = 'dcgan.gif'
- t; u% d, X/ Z) R: O# O  ] 7 J6 {+ D) a" O8 k0 C4 u0 I* U
with imageio.get_writer(anim_file, mode='I') as writer:
; d* ?4 U! h& j$ m4 c  filenames = glob.glob('image*.png')* e9 c( H& }2 G% |" i
  filenames = sorted(filenames)" N( c1 Q5 H3 m* q6 W
  last = -1% }) h  N# E1 R8 P: \
  for i,filename in enumerate(filenames):
3 i. R7 E) Z: `! H- s/ j( i    frame = 2*(i**0.5)
7 I2 M- v: k5 i3 r% u5 \( Z8 e    if round(frame) > round(last):
; S* o) v& ^5 m+ c6 f, _      last = frame
: T) R5 Q# K  B5 \    else:7 q. G' t9 B, x  X% p7 j4 w
      continue
* J  y0 L% ^- B( W# A: \    image = imageio.imread(filename)
5 Z" F6 `* f: G) o$ M. h0 Q7 N    writer.append_data(image)
4 t  J( w' w# s$ O& C, @3 ]& A4 P  image = imageio.imread(filename)
" {8 Y! U, G, a& [  _7 ?. L% g3 d! v  writer.append_data(image)
. l7 H( z- S0 m1 G6 @' N 4 k0 R# L3 r; ~
import IPython" f8 j, q+ G2 Q! ]" E) L
if IPython.version_info > (6,2,0,''):! ]5 B( `, `! e+ u
  display.Image(filename=anim_file)4 I1 u$ u# S3 t9 c# j0 X
8 D/ G2 H! t% O4 _+ T' v$ l. \( v

  I1 E% p5 C$ q# n8 a1 ^: d+ V0 I, L( j; i- u, o
$ I& [) B. Y: l' |2 l
完整代码:
" a% U! d6 n, C7 b  a( T( Q9 `1 \, [. K
% |, \6 H% h( u  w  K5 ?
import tensorflow as tf& V' A4 d; M+ X" b' v1 ]& B% a* ~$ u
import glob
& D# x  z( u# Q, vimport imageio
: K9 ?6 M4 u* X$ ?5 s7 Bimport matplotlib.pyplot as plt# s: n( V5 u# n
import numpy as np) d/ v: N$ o/ x( _3 M2 c4 z
import os
% R8 @- \: s( u  `4 @+ rimport PIL4 B0 }. W/ B, D4 n
from tensorflow.keras import layers7 B1 ^& B/ i5 D8 a
import time
7 E$ X, N/ F' F8 e) K* D $ |! J4 y# \; ]: E) A! Y6 C: n& `
from IPython import display
3 Z  @: Y; U9 K5 O $ Y" c8 b- X5 V  ?
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()! h" t+ b0 J2 K/ m; p$ D
3 N% k$ o' u4 g* s; M
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
" t/ u: K- C" P7 N6 j' mtrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
' `. K& [& u8 _% k $ e1 n: a( r" {( B1 G) V- }
BUFFER_SIZE = 60000% X6 i2 E2 x$ {! \6 U
BATCH_SIZE = 256
, {. a" b! e6 w% G ' ?2 X' R2 C7 I: {; B1 i  q, \
# 批量化和打乱数据9 D, W; n9 u1 C, Z9 A0 x
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)' _' n8 `& }) M& }1 \& b
% h. x1 Y; V$ k1 k' B1 G
# 创建模型--生成器9 Z3 K* F5 A: V% W: o+ Y
def make_generator_model():& X' s( F8 k$ [0 {' N7 M. D% X* A
    model = tf.keras.Sequential()
' g, w' N; D, s; E; A9 _  a  W5 a    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))): g" T. J" }, t$ ?& u! A
    model.add(layers.BatchNormalization())- s" R+ [1 j" n' r& [) \  A
    model.add(layers.LeakyReLU())
1 G0 @4 @% r3 y' {. u4 m% m0 f 8 E- k7 _7 e6 e. `2 I. o
    model.add(layers.Reshape((7, 7, 256)))
9 ?+ I1 \* v" ?- C) W% [    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
2 k3 L( P$ A: A8 \) w1 \
5 P- W) ^! |0 u) l' n    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))# a$ [: f! g5 H% Y- I$ M3 ]* H% ?
    assert model.output_shape == (None, 7, 7, 128)4 F: Y0 I7 `3 k& a$ f, a
    model.add(layers.BatchNormalization()); K3 A7 [: z4 K. F+ ?( a
    model.add(layers.LeakyReLU())6 [6 b; u: s7 y# r* o3 l, ~- F0 V
- _. R" `8 N* ]1 T8 R2 z
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)). X2 c7 [$ w0 U1 O; |
    assert model.output_shape == (None, 14, 14, 64)( I- v. Y" T' D( Z3 x1 x: g
    model.add(layers.BatchNormalization())
. M& c8 g$ L' \+ Z    model.add(layers.LeakyReLU())
1 G; V9 l. J2 I( s  }0 I : n4 N, X5 r7 V  M* D
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))$ ]! f& @; d. l1 }
    assert model.output_shape == (None, 28, 28, 1)4 i3 d* f) Z; U7 c: I( {& n6 k9 H- y
+ O! v7 M- G6 o8 I% g$ f) l9 N( n
    return model
  k, U' J0 M1 D# }
1 q& J  a- T" z# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。) h1 h2 b2 Y; t0 X; p) u! o% q
generator = make_generator_model()8 A! |: M& ~! h4 r+ F

9 T* A( @, C# W5 _% gnoise = tf.random.normal([1, 100])2 t) v/ [( v/ c+ k2 R- E- S
generated_image = generator(noise, training=False)
# O6 J9 H0 p! [, u $ C3 O& L: u- n; ]  n/ r
plt.imshow(generated_image[0, :, :, 0], cmap='gray')+ K9 W) Z: |* ^4 b; {! j7 M, O
tf.keras.utils.plot_model(generator)
" N0 E6 _; J. A  Y3 @8 a0 Y% g , `* Q& I: H' G+ u7 I" P
# 判别器5 e0 D+ [/ E, p7 I* ~
def make_discriminator_model():4 C- ?, }3 y. ~3 f; C
    model = tf.keras.Sequential()+ P/ s; B6 o* [7 r! w
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
% _* S/ J) s& h/ O! P7 N+ R; g, v                                     input_shape=[28, 28, 1]))2 u7 i5 i) ^8 V7 P1 Y$ d: r
    model.add(layers.LeakyReLU())1 q/ P4 z1 Q: W% @. X3 X' a
    model.add(layers.Dropout(0.3))
4 w  E& [- F, X% t( J3 q- [ 7 O: d: V6 ^/ ]/ Z
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
- b  [- v. X3 U: a/ _( V9 S! n" W    model.add(layers.LeakyReLU())
, Q% H8 b0 S6 c# Y- B    model.add(layers.Dropout(0.3)); w9 _, h9 n+ C9 j# C2 u
4 \5 ]# p0 n; i0 n  A
    model.add(layers.Flatten())) _8 s0 B. ]3 J+ u3 E* c/ B
    model.add(layers.Dense(1))- b% W& H7 Z# Y  O) j$ ~/ x: @
; t  Z+ ?+ h: ?
    return model
% G  L1 `0 q6 I/ v3 a1 F; ^7 H ( f" F( E: D& T$ q& y. b) y
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
! k3 e8 V9 d- z/ @discriminator = make_discriminator_model(). [' Q- ?% Q! A% {
decision = discriminator(generated_image)
  |$ N& k( L/ G0 x; a# [print (decision)2 V4 @0 u* W0 a: Q; ]
* P5 n4 D! b) c, a
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。$ n* \5 n% l5 i8 o
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
0 A/ x6 K3 h0 N+ S$ y- F # P- f+ e8 A; r6 k# Z0 m
# 生成器的损失和优化器+ Q8 Y. W0 O, H5 n5 U
def generator_loss(fake_output):
, [# `. m0 A! W& d    return cross_entropy(tf.ones_like(fake_output), fake_output)
, W  ]5 w; g9 L( D8 K/ _generator_optimizer = tf.keras.optimizers.Adam(1e-4)
* X6 C( z( R' T& [$ t9 G5 Q
$ _! i4 W* D- [, {% W# 判别器的损失和优化器
; ?3 I7 G7 x( D1 ?5 G- Kdef discriminator_loss(real_output, fake_output):
! m& o4 [' @  q" [6 E    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
4 g  o( w* k; u5 X7 _; @0 k    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
0 N. t7 I. I7 r- C$ R7 y3 i    total_loss = real_loss + fake_loss
& d  X. g8 |7 @# H    return total_loss" W1 \8 {9 H  s7 e, @1 G- n: ?
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
: A) p- W* Q0 D8 q* c/ P
( a+ s# r7 z5 G8 \# 保存检查点( C" `5 b* j: q( x  N. f/ b# K1 Z
checkpoint_dir = './training_checkpoints'
) d) m- B7 o; Icheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
" p5 ]( T* q! rcheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
" O3 j- e5 l& x; ^3 h                                 discriminator_optimizer=discriminator_optimizer,* j( u9 n- F# c; `$ K# @7 i; z) k
                                 generator=generator,& P7 u4 _3 y: K1 `% q  f
                                 discriminator=discriminator)
: h. @! Y1 ^( K, k+ V# z7 A ' L5 B. G% M! o! m! o
# 定义训练过程
% t# k7 z2 u  ^' D8 tEPOCHS = 50( W$ P1 [: x$ x0 x& P
noise_dim = 100
9 M% D, X5 I- k% P3 wnum_examples_to_generate = 16+ h5 o; L" c) e: H6 |- R* t/ C
6 ?% W$ e; g% ^7 G/ h# _
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
' P+ o" A/ f; aseed = tf.random.normal([num_examples_to_generate, noise_dim])  z% f7 G/ j; m9 B& i

) ~3 A" l, J. w4 u8 R$ N! _: l# 注意 `tf.function` 的使用
* t* W, h( }1 \" c# 该注解使函数被“编译”
2 S2 |6 R) E0 O, ~+ D@tf.function  g) w1 R  Q% a- W
def train_step(images):, m7 [. v9 s7 ~; [
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
* i: ^( A$ _* Y, D' L6 U% Z: S
7 f+ |* h% O, e, L7 y  @( `9 [    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:7 I; t. I7 m0 R1 s# B
      generated_images = generator(noise, training=True)* Z* E/ H3 ]3 o2 ?

0 H' p& d. ?6 {0 r      real_output = discriminator(images, training=True)! _8 b/ e( Z+ v" @
      fake_output = discriminator(generated_images, training=True)
; L- h; O) ?4 ^
) ~1 Z0 c, }/ p; y4 X* |/ v* l# ?: B# h      gen_loss = generator_loss(fake_output)
7 z& `" W* r# v! a      disc_loss = discriminator_loss(real_output, fake_output)+ V' x( g% ^. p4 B" i! N

" j: b( Y( b* J+ l/ e- B' ^8 F* g7 j    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)( ]6 P. d$ L3 L% S
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)( I8 s3 ~" k, t' [5 k

0 g* Y7 w: C2 _+ S( g/ V    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
2 u1 x  x/ l% a) g; u    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
. ?# R  I: j) m 8 K' v/ r1 ]' H: Z) t* r  {2 j
def train(dataset, epochs):
7 S( W' C# ^# z) X8 l/ j  for epoch in range(epochs):. {3 |. z& I) H. X
    start = time.time()
/ v5 Z6 }! s' q + L7 H2 \6 w4 b8 [
    for image_batch in dataset:
) h3 q' @* G3 u" N) R+ `. v      train_step(image_batch)
& G( y1 ]1 V0 j+ b- x) W 1 [2 p5 T$ s  ~, N" y/ K: |
    # 继续进行时为 GIF 生成图像
, o( d% L  D+ g& W6 g$ W. O    display.clear_output(wait=True)$ O* |' }; K3 r9 y, ^4 B8 w  K3 T
    generate_and_save_images(generator,
# c9 g9 n5 F4 S& u0 M% Q  D                             epoch + 1,# q# z, v/ f: Q/ @+ M) k  F, W
                             seed); ]1 b* d8 C% y$ X) m
& @, ?4 c& l; Z9 t0 R) f. N2 n% \+ Q
    # 每 15 个 epoch 保存一次模型' Q% x. {2 P% z' ?
    if (epoch + 1) % 15 == 0:9 p% l; H$ o+ c- B$ e+ y! [9 U
      checkpoint.save(file_prefix = checkpoint_prefix)
8 L9 E+ a1 W: u( W& _
8 F' M# |0 c! F6 S2 Q) }' F    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))8 U! _/ Z4 f, [9 a- [2 v5 W4 Y
5 O  v8 T9 |( B) q' t- W, ^$ @
  # 最后一个 epoch 结束后生成图片
9 Z# F' f6 I% y2 e4 u  display.clear_output(wait=True)! {8 g) d/ s, o7 d. Q" Y
  generate_and_save_images(generator,- r+ V! g. B" w4 x2 M0 j( {1 a
                           epochs,
, N1 e% `! z1 O) m# s, W                           seed)/ L1 J  \' M8 n! x6 p# J

+ N/ G# u5 J/ E& h+ A' Q9 [+ N8 {# 生成与保存图片+ T. C/ Z! x, B2 ]
def generate_and_save_images(model, epoch, test_input):- F  g* \8 u* z# Q: x& O& N
  # 注意 training` 设定为 False
8 U5 W6 e$ n0 X3 O$ e$ X  # 因此,所有层都在推理模式下运行(batchnorm)。
  U9 Q  i6 X9 |' m. i: t% m( e+ [% H  predictions = model(test_input, training=False)4 S8 A7 _/ S  c! S) q) o. R: b
" _; C0 K7 T5 |# Z: G8 p! E
  fig = plt.figure(figsize=(4,4))
1 M, [- C) b) x; y2 l( W2 ] ' T; x/ d9 q" a( ~
  for i in range(predictions.shape[0]):
9 n- R: v) l& v) D8 H- M      plt.subplot(4, 4, i+1). [7 i* u# E- i1 n1 V
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
' K  \; {) y: `$ b1 q      plt.axis('off')( _1 |1 {8 M: Z9 A/ M
1 F  b; i" T0 X3 \" ~- w% Q1 j7 D8 Q8 p
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  _& V. P  b3 ~% v/ f5 f8 ?7 ^8 @; e  plt.show()
, D4 g) S7 o+ y2 [! `* u * \, z! y, N- I: j4 n
# 训练模型
* J1 P/ W4 q. Q9 _8 [train(train_dataset, EPOCHS)# I; s  ~# j* l" T6 j3 R. s
- d' x0 K1 K; F# \" S6 K% [: b! E1 n
# 恢复最新的检查点$ m: R6 E& I! P3 R
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))/ M- H  t" o; v* {2 O, O$ O* C9 b; _" u- [
/ M- W& _9 _8 v3 e3 I
# 评估模型
2 e- k9 P% p" Z2 y' F# 使用 epoch 数生成单张图片
9 ~& n/ E# J# x0 A8 H( c2 I' ~def display_image(epoch_no):
& l1 H1 X+ ^& l: I* Q  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))) {( q3 x+ K7 r- W) X" i& c: k( F
' |/ q7 ]( Q2 o
display_image(EPOCHS)6 @- E6 D# w6 F

8 C5 C* t. B  f( U5 o0 n5 Wanim_file = 'dcgan.gif'
0 q, m7 a$ P  E' n ) }+ n; P& x/ O' k% K7 a; U
with imageio.get_writer(anim_file, mode='I') as writer:  Y5 c( l  b! d; v
  filenames = glob.glob('image*.png')
/ }1 J( S1 G. d* ]2 I- v  filenames = sorted(filenames)7 a: w6 M4 N) n8 T4 p0 ?5 L
  last = -19 \8 y* l/ b& w2 Y5 n9 m: W  J0 @
  for i,filename in enumerate(filenames):
( `* ~) s2 b6 `. ?8 y    frame = 2*(i**0.5)
8 ^) g$ E( I0 L% g1 |0 H) i    if round(frame) > round(last):$ o! E( J$ ^9 ^
      last = frame) V: O+ o4 f4 V+ T; Z
    else:1 T& [+ b: ]+ p9 F
      continue5 Y  j5 p+ r9 Q
    image = imageio.imread(filename)
1 X# |) Q& u6 X/ E) [- F    writer.append_data(image)
0 b6 \/ P0 @. h) A: @& L  image = imageio.imread(filename): K1 ?5 x$ C& c8 P
  writer.append_data(image)
8 h* E2 x. D, ~$ W& I5 Q 9 u1 a$ D0 @4 r) M+ l0 d7 k; Y! b
import IPython" T& c5 X8 ?( D* c
if IPython.version_info > (6,2,0,''):
5 Y% _' Z8 c1 e/ R  display.Image(filename=anim_file)' f& B& _( \) D- Z' Q/ D/ ^, o
参考:https://www.tensorflow.org/tutorials/generative/dcgan9 h- s4 G' y5 W* B7 l# _; Y  s
————————————————8 N  M3 a$ b" A4 U! Z
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
% o9 C- K  X& |$ l9 n原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111  L- z1 n6 z2 `6 Y2 A0 a

0 K  l6 G) G6 n& r$ E' B, H9 w9 b  H1 [$ b- {0 w9 j. U





欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5