数学建模社区-数学中国

标题: 深度卷积生成对抗网络DCGAN——生成手写数字图片 [打印本页]

作者: 杨利霞    时间: 2021-6-28 11:54
标题: 深度卷积生成对抗网络DCGAN——生成手写数字图片
# L& a$ y9 D4 ?: g
深度卷积生成对抗网络DCGAN——生成手写数字图片! b5 B, N3 f* P: u" l( O" [
前言
) M$ R" i  T# M( l5 i$ N) U本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
6 W' @- E4 R! S0 `4 a7 E$ y9 V- U: k8 r" f" D
3 W: t5 H  o0 I$ n. k
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:! l' i1 A! u7 }  t
" I4 ~. o: {4 U( u

0 d7 x$ {- m4 T8 @# 用于生成 GIF 图片
3 O" Z3 q' q- Q* p: e: ?( M& ypip install -q imageio
! k* Z' ?1 s) [, s$ B. I( ^目录2 ]1 l# h5 I! e6 A* p
, k9 j" V7 J4 M

& ^) |: p9 H) R; v/ T前言
2 K3 ~- S1 i* L, h5 E5 n' G9 E! u0 n  t1 C& y* x. J5 c

# \" f- v+ V$ T5 T, ?一、什么是生成对抗网络?5 _+ X0 n' w& F) b  d" M

# H. u7 b$ I: s' J0 y8 j/ _' H

) y. b3 ]4 X, o1 S+ F0 _二、加载数据集
6 I. y, k; l( q) U; y9 m8 h# G7 O* y; {7 h8 e% w

. N5 c$ L' g" v. |. n三、创建模型' j* b) ]! A0 Q3 }

) |( h" ~) d" E
$ u# O" p9 D$ ^4 ?+ c! i
3.1 生成器
# E4 q) p8 R& v" b6 G
$ s! Y2 o0 ?7 I4 B8 I2 S$ s

5 J, L) Q) t+ P6 n5 K3.1 判别器/ Q" S+ P/ `) y0 w, U5 R

1 q9 s) Q- }0 y4 z- U' K" |
: I$ K& ~( T& B- u) t: o# ^
四、定义损失函数和优化器& A0 y' N$ a) j: R8 Y9 E

) m& s& w# \3 M1 K' A

( U  C/ }2 m( ]& ?) ~4.1 生成器的损失和优化器2 o7 g% t& a- }6 j9 o
+ D# a: X: I* x1 j! J  U
9 q% e1 v7 n3 |) v) T1 y& F9 ]
4.2 判别器的损失和优化器
, b8 Y+ F- g5 P7 K$ z. r: X1 Y7 v3 W- P7 A1 W
# @0 h2 g* l2 U
五、训练模型9 \+ H8 B$ t6 W* g
8 b8 P" Z5 S/ a& I6 Q- U, C

+ P* m" }; c! M+ e/ e7 g5.1 保存检查点
5 {+ ?) B4 Z4 n" Q4 A# b  c; Y- O4 Q! `" A
' K0 t' p# B0 z' _# w1 C# C1 p
5.2 定义训练过程
4 D6 ~/ V. u/ i
3 r; k* ]. Y3 O: P" k
! K# _" v( p$ J& |1 O/ T6 b8 T
5.3 训练模型
( S* W; i. Y: K# `
: O& O4 j% B. X. E$ N

6 u# U7 `4 I" @' G0 R4 y- W6 J六、评估模型) z4 f2 M' U; @# c- _- P
0 {; O. b. w! _/ B! y

) u. B$ R. u8 d8 J一、什么是生成对抗网络?
7 q1 ~( U6 e  u1 J& X, J. [生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
% }  Q2 t+ w3 b' z4 o3 g$ t/ M$ I2 W5 \$ Q2 s, t' i/ f
, ~0 b# [3 s  n7 ~% W& A
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
5 X% p  W! I8 u8 Q
: N4 F* U4 K/ G$ ]+ S

, \0 [  _% y% S判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
5 ~; X' Y: i& u* a, d+ x" I* f$ ~/ }7 Y- R4 J; w

, W7 e8 s7 `3 s! _/ {% R* O9 Y训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。; t+ U8 \  r8 r4 [# g6 D
5 n* \7 V$ Z; r9 R

$ q4 D5 [1 e8 F+ o2 V, C: x当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。! i- J6 ]8 |7 ~+ O0 u% P5 _

' f4 c  m) ?8 a: a+ m
+ z2 T9 N0 A3 S& e. J: N5 A( _
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。/ R; k7 J6 T1 r- D) \& U% ~7 N
# s" ?: a! G  B5 {

- b& t& R1 B4 ~! b5 V/ V二、加载数据集2 U% |3 ^9 M" `2 a. w
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。  v$ p8 N$ d6 r; \4 {, y/ a

/ F; J: X! V9 D4 P! L# A7 R3 ]

  z, [8 i% u) [, {1 D8 m& q6 m(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
$ h  p7 \, O7 x / H% c3 ]5 M. g- d1 v
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
# |5 j6 D3 ?& P: v( ztrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内4 g" Z5 X. K* Q' A+ J1 X( L$ b/ s  f

+ x( e. R; z8 E# o2 x4 S' RBUFFER_SIZE = 60000
& o% f1 }9 U4 IBATCH_SIZE = 2569 n8 D9 A. R. E
/ M# q6 J: R4 C+ j3 k) z
# 批量化和打乱数据$ Q. H4 I) P0 v& R* `7 \& \
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)9 k3 r  A: q! `1 v6 I. }
三、创建模型
: ]5 i6 k& o8 p9 ?0 g2 |# o主要创建两个模型,一个是生成器,另一个是判别器。( |8 A+ a  n0 P

9 p5 p: D4 Z9 n, C& t" ?+ {/ ~
! M3 r9 H( O: C3 r
3.1 生成器0 I4 x, e; u& w, X' x4 ~
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。( G1 Z  p! l# L! q

1 P; w( [% b8 @! X7 X; }2 I

( C# o; p6 w* A然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
/ O! y& v, J3 T5 O8 x$ ]
) ^& R. H% R# W# |

' C% e, p5 k7 U4 @1 R% @6 b后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。& g. j) p  K# V9 T% [

1 ~2 B, X% G: p8 T: M

+ O4 K: ]: D; a! M3 v+ zdef make_generator_model():/ S4 w( Z. E5 A: X9 ?! ~
    model = tf.keras.Sequential()
9 k. N. l8 E+ z! h( j    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
7 Q" }6 W" o  k& o( H    model.add(layers.BatchNormalization())2 z; }+ h  u( T9 }$ U" {' `
    model.add(layers.LeakyReLU())- n5 t7 @( O: V9 n5 {* ^8 h$ |7 y  t) x
3 m7 V' \0 ?& F' D# ]
    model.add(layers.Reshape((7, 7, 256)))* B+ v5 Z- J& d. }
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制( q) Q, x5 [: R7 a6 L
( f. c/ r9 G0 e" l' P6 y4 Y$ v$ G5 O& N
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
" A* g1 a* d1 m$ @8 }" C+ }) j    assert model.output_shape == (None, 7, 7, 128)5 ]7 ?5 Z/ V+ Q6 b; J3 h* S
    model.add(layers.BatchNormalization())
  k* N2 Y$ k" z" E    model.add(layers.LeakyReLU())4 T! F8 q7 W+ v- }8 u+ i% a+ s
" F0 ^" ~, K7 C2 f7 W9 v# l. y
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))- M9 c* z, q% M# x% G2 `$ l- {# c: f
    assert model.output_shape == (None, 14, 14, 64)
, H5 [. \6 J8 z# u  \# W    model.add(layers.BatchNormalization())' c' c( H/ }3 M0 K- r
    model.add(layers.LeakyReLU())" g- ?2 f/ e# ?% Q& i3 F/ \- S5 p2 }

, [: f* ]: |5 F) E1 r2 W    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
) T0 d9 Q( i) E# _5 l7 L6 G* a    assert model.output_shape == (None, 28, 28, 1)
' q* ]- s0 \7 S, G$ [. B6 Z : H  C. g: V! ~8 X1 x3 g
    return model
7 [8 X. S/ o+ E) D/ ]" H7 u用tf.keras.utils.plot_model( ),看一下模型结构
& d5 C6 n: Z+ O& O, S' Y4 T. {. b! i/ n; Y

; Y- M+ ]; J2 c+ T! `2 u
( p! U4 P; R! ~4 Y, O3 a4 |6 c# u+ l$ |9 ]. d$ G+ r. z8 k2 N, K

: B( H4 W4 L; q6 E% f' y7 D$ F用summary(),看一下模型结构和参数
5 l$ d- D! W# u4 J( p4 u- o# _( K' s) ~* m  n$ T/ q

- @9 A) {$ u% |% B4 c! Z! V2 ~7 y# @; s
* J8 j$ L6 \$ a! E

" |% S0 ]7 f# k1 p! h3 x- L# O. B

& M& v% |+ p  v使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。6 }1 K" ?1 E& P7 ?/ |
3 P* ?4 h& A- _* i
& P0 K/ G# J( D, p( I) W: e
generator = make_generator_model()
$ n" k% K, d! O; E# W( D$ F' p9 S
# k* [5 M* }+ b) o- H' w& n0 c- Snoise = tf.random.normal([1, 100])
5 S0 A5 I# e! W& ygenerated_image = generator(noise, training=False)
; ]! n" t  X( K- u: ?! T * ^4 j! k* |) }: k5 I1 p
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
4 V5 G: k2 W4 w& u, i- v+ Q
6 a0 W1 O  a, N- P# c% p4 E

9 t# k6 ^% F8 R! `/ q& [# `
( y  c7 [) V, C0 B, H" z

6 \0 q( F/ y/ a$ N3 M0 ]6 s3.1 判别器
* \6 e* E/ i( \  `! j% d* P判别器是基于 CNN卷积神经网络 的图片分类器。9 P/ m; m- ^$ |0 |, w- ?6 _

7 g: F, h% X5 F4 p" ~9 A8 ]

" a1 I7 c& G6 Y: xdef make_discriminator_model():9 |% T% h  b( a7 F0 f7 X) n0 [+ S
    model = tf.keras.Sequential()- e+ J8 v( k- m) `( U
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
% ~. B5 U' v) B8 T  U1 e                                     input_shape=[28, 28, 1]))
: U; ~$ k3 i: u* |, F    model.add(layers.LeakyReLU())
, `4 S, R  |5 N, H    model.add(layers.Dropout(0.3))
6 v  N+ ^' f' s$ d ; d5 K; w+ N! R8 o2 X+ u- _
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
1 Q) x" w( M: e5 A    model.add(layers.LeakyReLU())
/ E( ^0 L& y# ], [- I0 V    model.add(layers.Dropout(0.3))
- f# h5 g- [  N# O 0 u# V9 p9 [, N3 D' {& |( r: t, y
    model.add(layers.Flatten())
0 F' h7 s% t& n' N    model.add(layers.Dense(1))
2 @5 n3 f! Z" {7 [
/ L5 l; f0 ]& ~! o# K8 B    return model" C' }7 y2 w. g" x1 {
用tf.keras.utils.plot_model( ),看一下模型结构
: P  D" ]0 [/ z5 U( @8 n
5 \* ]9 l' l* d, G$ ]$ g4 I6 z
4 d$ d' x: x: b% f
) T1 B- \/ r0 h( E: F
, M/ k* o0 p$ @' u% E6 i' {/ W
7 Y5 z4 ~# Y3 g8 x; r( Q& N
0 g/ m8 ^* u- C9 O
用summary(),看一下模型结构和参数
% z! V$ ^2 |  k/ n7 H9 }! U* A6 ]7 Q4 |1 U; P, x
) S+ b6 u4 x, J* e2 S  J

* @; ]0 J% Y) C4 Y( U" N( s

! x) e6 J. F# A/ z
. t$ B5 i; I5 H6 a/ F4 i
7 g$ n8 F" X" x( l) G: G' x
四、定义损失函数和优化器
5 b) t& d1 h6 ]# E5 U! n由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。
8 ]; U' u6 a% \. z  K8 g& ?: K3 D: e0 W/ k3 J
8 |  I( g0 X0 A3 _( h: }0 Z
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。6 n& l$ A$ Q0 ?  a, ]
8 a8 }) k: ]( P" c$ H6 R
% c5 d, V! N# p; M* Q9 e, f
# 该方法返回计算交叉熵损失的辅助函数
: x3 Y- D& e; n9 k. ucross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)3 y" Z6 T& u4 G8 l
4.1 生成器的损失和优化器/ S) w! o; ?5 h, M( }8 s# p/ r" G
1)生成器损失
  V: v% P2 A6 H5 w1 F
# ~0 c4 z+ V, w! Y! ]* {$ n# B

' l5 c+ G% A9 a3 G0 m+ W+ X; O6 j生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
, T" q4 \7 _6 n2 i4 i8 x3 U
, z1 ]/ I5 a5 a# ]0 i6 x6 d

! }1 I( N8 p* }$ `这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。* l- ]) S  m0 ^* L
/ R, J( p4 ]+ K% D8 D4 t+ L

, [2 g+ @4 n8 _8 Gdef generator_loss(fake_output):
# ~7 c2 b) Z$ w+ j3 R; A    return cross_entropy(tf.ones_like(fake_output), fake_output)
, x6 m1 W4 D7 l# y+ q% Q- ?2)生成器优化器
$ X% s$ m. t$ c  V+ G& v$ t
$ J1 S7 x, B2 k  o
. M4 f0 D) i' a6 s6 y" x* W
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
3 [; _$ n# d# u( d4.2 判别器的损失和优化器
: d. v: F7 @8 D1)判别器损失
6 b- ?/ {. c) T$ t+ J+ [9 @! u7 R
8 _% N3 w0 c$ N5 y/ `4 e7 V. u: k
) X  A, e; G) ?! ]% h( u
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。0 F" M+ x3 x( R$ n1 ?
3 f' N. Z* q9 F8 d+ n/ g
: z. i/ C; V% V+ W
def discriminator_loss(real_output, fake_output):
% W& G: r( k/ J% f    real_loss = cross_entropy(tf.ones_like(real_output), real_output), u/ P+ G) p; [9 `8 _
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
% @) M6 O, Y4 c5 c: ]    total_loss = real_loss + fake_loss
$ W7 }% V4 b* g9 Q) \6 K/ T    return total_loss, j. N7 k9 l* r2 H$ O
2)判别器优化器$ s$ J! l' ~: s8 p

- b) z- Z% s& l
5 ?. u1 r" |2 l, F
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)4 z7 F$ ]6 g. N# H$ b
五、训练模型
! T+ s- s6 s5 u/ [5 S" s5 F- T5.1 保存检查点
2 a% V1 ^6 Y  R3 g3 ^9 z- q% c保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
1 s- w7 m6 ^- @/ N) |3 B8 T6 q- V9 B6 F5 Q) V9 u5 O- ~" W2 k
% K8 T7 u: \: F3 M9 X+ g$ X* y
checkpoint_dir = './training_checkpoints'& b9 O$ L6 q6 {% F8 K
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt"); j' D6 s1 h1 n) ~$ _
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
1 J- V& S/ F7 ~  }% F                                 discriminator_optimizer=discriminator_optimizer,
( k4 v9 g. N+ o" a2 Z3 G                                 generator=generator,2 K; y0 |! D* k+ k1 `# K2 k
                                 discriminator=discriminator)9 v9 M* [5 V$ v1 f$ {5 j
5.2 定义训练过程  z2 M3 o) O9 T# j0 w+ W' W# i
EPOCHS = 50
+ D2 d! x% C3 i- G* n* N+ enoise_dim = 1009 p7 Z8 `* U! G3 Y: }/ _
num_examples_to_generate = 16
' H  Y' ?- |( Y6 I$ o" o + j: U) d4 Y1 O

; P1 n$ G/ k# k1 E* G! R# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)- B5 B# F% ~6 l0 e1 g1 D
seed = tf.random.normal([num_examples_to_generate, noise_dim])* T, W3 E7 T0 Y, i% R% b
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
, U% d6 r& J) b% Q/ S( V9 w8 ?; S! L3 F

* a1 m% w3 R8 d: K& N判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。) _" Q4 Y4 i- f+ B
: i& ^0 V6 R& U! T, e- T; ?' x" }+ A
: u' z1 {, X1 i9 M( Q
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
: M7 H# b7 A5 S; K0 ~/ z9 f4 t1 o& b- t

7 z) _7 E+ w9 q5 v* ?% D# n6 g5 @# 注意 `tf.function` 的使用) w' \9 M# x$ ~8 ], i' \/ C
# 该注解使函数被“编译”, K, h0 v3 p' ?5 V
@tf.function0 m* W9 |0 w; Q" E) J* t7 y' Y
def train_step(images):
& ]7 P6 U! S1 m; o    noise = tf.random.normal([BATCH_SIZE, noise_dim])
/ b& ]8 }  c: c! c- ?! u" p1 l% p9 ? , R. M4 p) J. ?, h
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:6 X5 c1 e! _: ^1 p. r, X
      generated_images = generator(noise, training=True)( _) U/ y7 ^0 B- e7 w& q
% u& N2 q9 ?" C  M% |2 \- L
      real_output = discriminator(images, training=True)8 |& `; r2 U& d; ~) I* h+ e
      fake_output = discriminator(generated_images, training=True). |: z( g0 p5 {6 S: j
4 F$ n% E) n7 ?/ ~: g7 X2 F' b
      gen_loss = generator_loss(fake_output)/ E; O) |; a! Q7 U- v( ^+ A1 a6 Z
      disc_loss = discriminator_loss(real_output, fake_output): i+ Z7 v- I( o! o# M4 B
: L2 a) o( }! J7 C
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)7 |9 J! K& W4 ?( H* ?) V: G2 @0 K* k
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)" o) I. ]1 z) w0 f) m
# L* C) D: F/ ]1 F) v$ S7 G9 m
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))0 E/ v( x7 y" n' {0 p: v1 B5 M8 q
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))0 u' Z8 L- \  M# ?; C, ^( c+ N8 l

. h2 f) M, w8 x0 B" V9 F, @, zdef train(dataset, epochs):
! W: r9 _" {; A9 A  for epoch in range(epochs):
! ]/ x- \0 {3 r# _7 [' b    start = time.time()( Y8 v5 \. F: y# E! q6 h) [

: M$ {5 E! m! z* y/ {    for image_batch in dataset:# |$ O3 Y) q% w- O8 E) n& A$ A. ]
      train_step(image_batch)* z# q! s! F/ J% I8 F
' x$ X$ }, m0 g
    # 继续进行时为 GIF 生成图像
! z  n# n6 E# ]5 u4 N: L    display.clear_output(wait=True)# a% Y4 I2 q% [, d& t
    generate_and_save_images(generator,5 }# @% Z" L$ d1 ]/ O+ s
                             epoch + 1,  U0 a7 Q. z* p8 B/ q9 Y$ B0 D, T
                             seed)8 c9 |2 `( {  L: S
, b+ g* z: I9 D, t
    # 每 15 个 epoch 保存一次模型
4 L/ v( K* g$ i5 R2 W    if (epoch + 1) % 15 == 0:
# H' F3 ?0 W3 l) v4 A4 w      checkpoint.save(file_prefix = checkpoint_prefix)! v3 Y$ g( b* P/ ^& a9 q- V

5 Z0 x0 m: S& d5 D3 s    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))  x# U1 K4 o' s- f  m( ~% x

) X7 y. x1 H0 {3 X1 p$ p  # 最后一个 epoch 结束后生成图片. N- N1 V9 K& j/ W. E- a9 ~
  display.clear_output(wait=True)2 j+ v& k; {& \9 o6 J; k( a' T9 l3 n) L
  generate_and_save_images(generator,0 r7 c5 [1 H$ b& r, M* s6 T: p  X4 @
                           epochs,
  S$ ?6 s! ~8 p# P* C6 |                           seed)
* c3 t6 u/ U' r: ^" h- _- G
9 s" N  y3 P! t* }- m* I# 生成与保存图片" n7 b' Z, o3 ?7 |* _) r
def generate_and_save_images(model, epoch, test_input):
& r3 D9 }. ]7 H" y  # 注意 training` 设定为 False9 t. s5 W4 ~% i
  # 因此,所有层都在推理模式下运行(batchnorm)。# ]* r2 p& D9 \4 ~
  predictions = model(test_input, training=False)- J, b( `6 \" v+ ?1 [- O: e

1 y  N; D' i8 Z" [% G  fig = plt.figure(figsize=(4,4))- v* n5 w/ }, c; W+ A: |

6 G4 a" V$ b9 i7 F* ~( m/ i  for i in range(predictions.shape[0]):
0 h, E  |. H" E$ ^9 X      plt.subplot(4, 4, i+1)
) c0 [( n, X' _+ ~2 H      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')' M7 o) G) }( E0 n; M
      plt.axis('off')/ e7 ~5 \* I% E7 X
$ b2 \  c" ~1 i) A9 f1 \" Q% m
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))" b5 ]# j8 ^" E: m2 W5 s
  plt.show()7 C3 X/ [* C' p- s
5.3 训练模型4 h( z! P' ~5 M" ?1 j8 H
调用上面定义的train()函数,来同时训练生成器和判别器。% m3 l# ^- [. [* C
2 ?% K* `# S' y0 b7 D9 }
( @1 H) u% F! j" Q/ s! _
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。& l1 ]3 k7 h5 X+ [

* `, }2 A% r6 c
: h7 u+ ?! L5 c) G+ a. ^) r
%%time. o+ r, d9 d. s/ Q* `: w" I( c0 }
train(train_dataset, EPOCHS)8 T  |& B  w( `5 f
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
8 I; q2 j7 E+ s. `2 s, q
2 p/ w- e3 i* X

: ^+ n6 I6 s$ N6 C  |+ a; a* k0 K训练了15轮的效果:
, G" `0 {1 ~, X4 @% O5 R- {5 B& H. `) |7 O' S/ w6 d1 j: W
5 K; C- C+ L# O) d5 B/ B, @- I/ _

5 A! i+ S( g: Z8 |$ i- b* B, o0 Q8 c

% J8 n  O, R% P5 k$ n/ N; x% |, G% {* U/ J  E. I
$ w- f  k, _7 y! Z  w8 N
训练了30轮的效果:$ z- c  s  |) j/ |# [/ j# K/ s
3 \' A- S/ ]" k! Y
3 k# e4 a/ d( M, D) M  r7 O

; P1 {8 U/ B3 d/ H! D# K/ C7 V/ V

$ m: R8 l* V: j1 W8 Q" @. m% y, o! _% h: X3 @1 Z2 @

! K  e$ T6 z" n+ [训练过程:
! K; X- Y/ U6 ?3 j1 s, [. j! d- w) \" p% p
7 q  q; B  A: h/ i$ P- d9 x

' y- ]( N$ H8 A( C7 K4 G7 Q0 T3 Z* b

+ f6 O. e2 r& T' n, M
; d- J$ V+ Z. ?
0 l6 c8 d! j1 S9 c! O
恢复最新的检查点
( h7 m, D" @9 }9 x
% j; b2 v# f( |1 }4 x% g

8 D' Z( J, G/ ^checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
2 ^2 e: G4 h. E% T六、评估模型1 W! X0 W4 }4 i' M1 ?+ l8 S
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。! F6 ~5 {6 V0 G0 }- l* q

$ u8 x5 h9 C* L5 T9 ^6 F0 y5 V

8 V8 V0 g3 S& C% b% r, }' Q& s9 w# 使用 epoch 数生成单张图片
3 [' Y% s$ M* l0 U# p) sdef display_image(epoch_no):
' G9 n. s: R# m1 v7 X0 z; O! Z( M! m  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
8 w3 u7 _' t$ ?3 `/ B - ?9 Q: w6 ~$ J6 \% X# |8 v* v8 V
display_image(EPOCHS)+ z$ g' M- a" O: u
anim_file = 'dcgan.gif'6 t3 l6 Z: h5 W9 h

$ H6 U/ m1 \4 S/ j9 C' V0 R- zwith imageio.get_writer(anim_file, mode='I') as writer:9 y: Z( U1 ]; x
  filenames = glob.glob('image*.png')
2 H- `% P9 a& i4 c/ w- ?; ?  filenames = sorted(filenames)
8 J9 L, G9 I* W: |; e% @  last = -1
. o9 z- L+ _4 [. s: O$ `4 `% U& Q  for i,filename in enumerate(filenames):/ p- I- C/ _  D) L% a' Q3 I, ]2 N
    frame = 2*(i**0.5)
3 e% H" Q. V, f    if round(frame) > round(last):
0 i0 g# U- F0 H5 F/ a% ?( w0 e      last = frame+ V, ?- b6 A4 m: m+ ^/ h/ r
    else:
( V) k3 m; v7 s& w: o& t      continue  f* P" e3 z0 R/ l
    image = imageio.imread(filename)+ p: i# P$ d6 P* S" R8 a, D
    writer.append_data(image)( _) h6 o7 ~. l  a* F
  image = imageio.imread(filename)
" ?3 N  ~' Z7 p  writer.append_data(image)
" ?; W5 e) y8 y4 f( O, J& _" S' L
' x/ G) |; v* b, A8 F# Q; Yimport IPython
$ |; o2 A# {$ m& Tif IPython.version_info > (6,2,0,''):4 L* W. Z8 Z+ x) p. n1 K
  display.Image(filename=anim_file)
8 f" N+ B1 O# ]5 v! [1 g
, r2 [. P! c. f8 J1 Z* G. r" u

  y8 {* m- Q1 @& C5 Q$ M: D$ Y
- L; e: h% y! K% [

. L3 t2 `$ H2 p& N0 O9 @3 p, e完整代码:
8 T4 @; d) r/ @" x$ E
, v, R. N8 Q6 v. d4 Z3 K

7 V1 P. s  [* Q9 \1 Jimport tensorflow as tf
2 d  D) H3 L: U* s# Zimport glob
9 Q9 ?. l6 s' c, O/ R$ _import imageio! ]. Y% S( K% Z* }- |: [
import matplotlib.pyplot as plt! _$ Q- d. S& x$ ^7 M! b9 V9 f
import numpy as np: E9 ]1 Q( u! w0 K6 K1 Y0 Z
import os5 N' N% S$ C  X& q) P
import PIL3 J0 y) `& b" L" \" {- O, ~+ b
from tensorflow.keras import layers
: k4 I1 W0 M$ n4 h& c6 fimport time
% T) R, h1 Y# d0 I ) l) T+ s: s; e$ y$ W
from IPython import display2 g  u+ `$ R2 V: E; L) R
2 B( t" z7 P( {: ~+ i$ a
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data(); L* r! n) |5 _4 z) X( q: ^

/ H4 J/ o4 b" S4 ]* utrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
& p) \1 O6 q* T6 O2 ]train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内- y0 n6 _  t" V  f0 ?

7 z0 L( `2 |; J8 IBUFFER_SIZE = 600008 s) k: z0 Q! \$ J( y! \& F- U
BATCH_SIZE = 2569 V  y: V% I+ u1 b; e% Z
6 p8 L% c( @1 a" {4 d) _: o( u1 G! `
# 批量化和打乱数据9 g$ m. K* M  w0 f
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)( s8 l$ O4 u7 [6 L
' j) z+ }5 N+ O$ P
# 创建模型--生成器! E$ b2 D1 z  I/ D
def make_generator_model():; r# _; m) {; P. b* x
    model = tf.keras.Sequential()
& K5 d. {" P) m1 a' G    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))9 C. x& L7 I( o3 T- Z  q
    model.add(layers.BatchNormalization())
4 ^- C! {/ R  w( o( }$ A    model.add(layers.LeakyReLU())3 H" `9 q8 Y4 _8 _4 E# \

. n$ q+ P+ Z( A4 t" T    model.add(layers.Reshape((7, 7, 256)))
4 U4 b: \+ B2 J    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制! r/ Z1 w2 c' b$ K* u8 H
* u' Q8 t- M# v8 Z; d6 I
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
% W6 @6 u& v$ r    assert model.output_shape == (None, 7, 7, 128)( F" ]2 w' Q* L( E3 T1 V3 E8 b
    model.add(layers.BatchNormalization())6 M  H$ W' ~% D1 q% O. V( u
    model.add(layers.LeakyReLU())" t8 u" C" U+ T$ s  h

, c+ |* _/ d9 [: ^2 n, F    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
, S' E  x! D7 u! Y1 c2 A* K. [    assert model.output_shape == (None, 14, 14, 64)
3 s4 w! T" d- u0 K    model.add(layers.BatchNormalization())( s) r( [1 x' B$ S8 y. `
    model.add(layers.LeakyReLU())2 ?6 E2 h5 `# N0 g, [6 D9 \# f2 ?

' G. S% r9 ?) r/ i5 k- C    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))* J/ U* x0 z% Y6 f5 W, H
    assert model.output_shape == (None, 28, 28, 1)
$ H/ p: T9 C- m8 \) F' S" k
7 t. ^. |4 R, D' i+ V    return model
$ H/ C) W2 @* `0 c
0 V6 v2 p& K- P- \: {+ d; B* ~- e# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。- u5 G8 W+ D, n! R# o, ^6 u6 ]7 G
generator = make_generator_model()
( n: o: P+ I$ h + m7 F6 R' d7 E( w6 _* d5 P1 G3 V
noise = tf.random.normal([1, 100])/ N) _. d# U8 n! C; F0 ~
generated_image = generator(noise, training=False)5 Y; d7 I5 @6 s9 [4 m

# e5 @0 f, N4 W- x; J2 i; J  gplt.imshow(generated_image[0, :, :, 0], cmap='gray')4 n, z! o4 H. {. b
tf.keras.utils.plot_model(generator)
) L; Z; m. ?5 B; l- Y6 {% u' R2 X1 n - H' f' _; E) R9 F' ^
# 判别器0 @2 o. {* |5 T' A  E. ]3 @
def make_discriminator_model():' Z0 O: S- ]) p0 U
    model = tf.keras.Sequential()
; B6 h$ {" H% i( B' R% C' |6 E  h- i) Q    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
4 @( V. p% y1 R3 s6 p3 T5 @                                     input_shape=[28, 28, 1]))
( ~  M$ h7 k! s    model.add(layers.LeakyReLU())2 k3 w" y. M4 h. R/ I
    model.add(layers.Dropout(0.3))6 u6 _, Y! l, _( g* q2 W2 V

; I: T: M( y- _) ^    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))  _9 v! d* g  V- P2 d+ ]- l
    model.add(layers.LeakyReLU())
- f3 {/ X8 s7 c, _* ?7 ~    model.add(layers.Dropout(0.3))2 j6 O1 l" {) Q! K; k6 I

. }6 m& k7 x+ d3 ^5 W( h- x/ }( r    model.add(layers.Flatten())" t  q' m* }7 v
    model.add(layers.Dense(1))
, F8 I. P) o2 H. j0 ]# l1 h + |" a& r3 E' A: Y; e3 J6 @& l
    return model' b5 e2 H& c& R  d
3 `% {2 A' a4 J  b0 k8 i
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。3 S2 J3 M* ~; K0 \2 c, j4 S# B+ k# T
discriminator = make_discriminator_model()* @; g! B5 p& V; C
decision = discriminator(generated_image)
5 N& T  j2 S4 Kprint (decision)
1 a% h- t  ]  ~! o2 o& r ; |4 m4 d" E! d* f8 m0 x  |, g
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。9 l3 p6 x, m' N: y$ O
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)! B* r6 m. g" b" H
0 E/ _# D  l- u
# 生成器的损失和优化器
' K; ~* Q) }" L; ldef generator_loss(fake_output):: K9 r, X( ~! G$ R6 u! _! {" z
    return cross_entropy(tf.ones_like(fake_output), fake_output): g* C2 J0 j! M, L+ F# l& a
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
  {) j, s8 m# B& y/ r' }6 g; q ' ]+ L1 H) z- U% |
# 判别器的损失和优化器
2 w% `& a: t: y4 Odef discriminator_loss(real_output, fake_output):
& D4 `9 ~" X# z, C; p! R( P/ Y, j    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
( J: C/ |7 `; a, d2 S. k  j    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
+ M& b3 h9 m5 u- ~$ [; y2 ]    total_loss = real_loss + fake_loss  ^# L* G' S% v% Q0 D  X
    return total_loss
8 k8 Q$ D4 c7 C0 r+ ^1 Bdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
5 {% h1 D) `7 X- \ 1 V* v6 B  G* l" S6 ]( e
# 保存检查点
1 z4 L2 O3 k4 [  q% q4 hcheckpoint_dir = './training_checkpoints': ^3 J$ f# l+ l/ c* F
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")3 w* B# G0 \( j  b& M1 N; C
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,+ T3 Q2 z% {# ~8 G/ d- ^8 B, x' ^
                                 discriminator_optimizer=discriminator_optimizer,
$ Y* S5 Z4 Y) O" Y0 H7 u* c                                 generator=generator,4 U5 S( e5 X9 J$ y& Y0 Z  E8 ~6 b# V/ E
                                 discriminator=discriminator)( Q+ ^1 H! Y, E& a' x! W! a' \

+ Z8 n! \  s6 X! H' j# 定义训练过程- _  y! g" I4 m8 x
EPOCHS = 505 B; D' a$ h' B
noise_dim = 100
1 v3 h4 m% @4 x/ c* vnum_examples_to_generate = 16
. H- d, s# }* N- E) [' s 3 ?% m; c3 g* D5 U% [
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
; h' {8 B$ M$ U4 z6 Fseed = tf.random.normal([num_examples_to_generate, noise_dim])
4 u/ b4 u* a: W
5 t- E7 ]* _. c" s3 }# 注意 `tf.function` 的使用
: B1 Q& s/ }8 P' g: w# 该注解使函数被“编译”" H. n9 V: l9 p# t
@tf.function
7 t& z$ J  G" ?4 i" ddef train_step(images):
' z% J3 Y+ S0 ^: Y    noise = tf.random.normal([BATCH_SIZE, noise_dim])
/ Q- @2 k+ Q+ S9 ^+ d/ q* I; y, g, V & ~: z/ @' Q$ m! O
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
# c" {. j0 g% ~      generated_images = generator(noise, training=True)
8 G# I5 c6 m7 D0 q) ] - h  u. S: b8 d, X7 g
      real_output = discriminator(images, training=True)2 D! D2 H2 a/ w7 F+ g# f
      fake_output = discriminator(generated_images, training=True)
$ M, n  a& W# U
/ y8 D/ G6 z" L5 f0 |      gen_loss = generator_loss(fake_output)
& {- p' @* j- L$ B. b" H# y+ e      disc_loss = discriminator_loss(real_output, fake_output)
# H7 m/ z: g- e4 D5 W/ d " o% E, d, v( m# z/ r, T
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
7 @1 D# E- m' E    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
) n! r8 Y2 X' ~5 Q
7 F5 p7 v0 c! P* `7 q: g0 `8 }    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))  J- h2 _& h" _' y) Y- P
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
& K# P2 j! M4 p) _3 F% s9 t& P ) I8 f8 R/ y- l  Z; Z
def train(dataset, epochs):
: D( U- \2 V1 j' t0 O( T5 F9 p  for epoch in range(epochs):
3 s1 L: m$ \* u+ \, M1 {    start = time.time()
) a! ~  U. O9 g7 C- Q$ k$ f0 l
/ a' o# X8 {  R: u    for image_batch in dataset:
9 I6 D" ~: @0 \) b2 l$ v      train_step(image_batch)$ [  P5 c/ B1 }
7 X2 v  F; ]( t! e5 r" S: V
    # 继续进行时为 GIF 生成图像
" q/ R2 ~& {' }1 [  ^    display.clear_output(wait=True)
2 X1 E' L. u. W* Q& a$ {    generate_and_save_images(generator,
$ @# {! ?  s* \) d                             epoch + 1,
( f( t  ?$ v: B6 E( o# @8 S                             seed)2 l, T6 u8 A' ^: ?4 _) S# r
6 ~- k/ {8 X6 F. ^/ ~2 ^
    # 每 15 个 epoch 保存一次模型
) E; G  r" i1 }# D    if (epoch + 1) % 15 == 0:
+ i# Y1 |( X( M: e8 S' z      checkpoint.save(file_prefix = checkpoint_prefix)
  E" Y  d$ n/ z6 s
3 v, q5 z9 h! R# G    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))! q4 }& u! S+ _7 m" N
7 }. x+ a" n& X& l9 x$ X2 R. F( [
  # 最后一个 epoch 结束后生成图片
, ]+ q0 E- q. O* _# ]: ^  display.clear_output(wait=True)- F# i; Q5 }' w1 e; M8 B
  generate_and_save_images(generator,  r. P1 L6 m2 ?; V
                           epochs,
# W2 G# C* h* ?+ P; |- j7 X                           seed)
2 Z$ `, |& G% C0 a# W. ]) c* [
! \' k. V1 W' C& u# 生成与保存图片; e5 n1 x% b; b3 f1 j. }& O8 v
def generate_and_save_images(model, epoch, test_input):9 j5 N+ @1 z4 l5 W' |; O2 O3 c
  # 注意 training` 设定为 False
7 z( I( R8 T$ F& X  # 因此,所有层都在推理模式下运行(batchnorm)。, z6 K, J3 t6 g) c& j6 \
  predictions = model(test_input, training=False)
9 q% g& h! T8 k- t  y
; N; s$ @* N  I3 C5 z/ @1 P  fig = plt.figure(figsize=(4,4))% a0 C5 B( n2 D

6 c+ \, Q3 J. P1 t& E  for i in range(predictions.shape[0]):
# A# [! L- Z: @) X$ F0 E      plt.subplot(4, 4, i+1)
+ \) Z5 V& c" S- B9 o      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray'), H0 R# `( h# ?3 J2 M4 M
      plt.axis('off')
9 [" c- h! S) r0 k$ L1 C; N
/ B- f# t! V3 D' |6 |) l4 e  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))5 m' Y$ `6 D' h4 u
  plt.show()) J+ {" P% d5 U$ @6 X

! R/ }( r) ]' i  H1 y0 m% b! G# 训练模型
1 _! ^$ Z0 [, Q' b; ktrain(train_dataset, EPOCHS); y1 s8 {- n! B: R% @( S, M; t* v' H

) l7 F$ m& d& J# 恢复最新的检查点
7 i: B' ~4 C0 p' n4 n/ G/ w8 Q# V* w7 ocheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" `  d) p  Z' [

! r/ i# C( u( p1 P" X0 a, r# 评估模型  ?7 \* c5 S. t: o6 S9 L
# 使用 epoch 数生成单张图片% M4 `1 b8 K, R3 ~6 |
def display_image(epoch_no):
  B# C$ J; Q8 I% G' |  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))# @6 y6 O, l; ^7 V

; E& L/ g5 G# u4 c7 Tdisplay_image(EPOCHS)/ k" t1 ~5 _/ }  k" J* j
7 I* |, m# r# h
anim_file = 'dcgan.gif'% Q( M, d% a& K
. `2 E4 P7 @6 U- {/ ^( S
with imageio.get_writer(anim_file, mode='I') as writer:/ y1 l. B: `! _! [! J+ m/ A
  filenames = glob.glob('image*.png')7 r, h- _" G5 I6 o5 J6 O- ]
  filenames = sorted(filenames). M# K! Z$ C3 k( j' R% P6 U
  last = -1% p# Y; }3 ~8 I6 F  J
  for i,filename in enumerate(filenames):
) `1 i$ a7 n7 i* x    frame = 2*(i**0.5)
9 Z$ O5 E7 v1 |. i8 H+ ]    if round(frame) > round(last):
# J. \, [) h  ^' ^6 N8 p2 Y      last = frame9 m7 k& L& z# y" _* i& }
    else:; D7 y1 k$ z" D
      continue. t* ]7 |" y9 j3 j( A7 V. C
    image = imageio.imread(filename)1 K. O( T0 Q, \4 k6 ]5 ~) j
    writer.append_data(image)  M  [$ ?/ O2 \8 U; x% w8 d
  image = imageio.imread(filename)# c1 T& S; ^* t( ]
  writer.append_data(image)
1 m$ i: ~5 D- y
4 P- ~, ^# U: s! X$ S" J/ Timport IPython3 Z4 b; n5 U2 \5 t1 i0 S! O
if IPython.version_info > (6,2,0,''):2 ?1 ~6 f' w( ^0 J" l* `6 A: P1 H
  display.Image(filename=anim_file); N$ R: I- C0 ]; B& K
参考:https://www.tensorflow.org/tutorials/generative/dcgan
6 Z4 \, C6 l) M! K; c, t3 k, B8 H————————————————
( X1 z" J& h- X1 e+ E$ X版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。) i7 Q3 s9 j# ?. Z* N5 G, ]
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111! q8 ?# w, {& [; V9 O4 B

) j* s+ D! G1 i7 L& F1 P( `+ v
0 F& {1 o1 ^' A7 N




欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5