数学建模社区-数学中国
标题:
深度卷积生成对抗网络DCGAN——生成手写数字图片
[打印本页]
作者:
杨利霞
时间:
2021-6-28 11:54
标题:
深度卷积生成对抗网络DCGAN——生成手写数字图片
. y' p2 h+ {0 ]" D# {
深度卷积生成对抗网络DCGAN——生成手写数字图片
' \8 h" D% K, _' F* c( m3 e$ V
前言
, v* O3 K; F X. U/ Z9 m/ k
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
3 X/ j1 w4 X0 O3 B1 F9 j3 t) I# O4 H
! j; S8 x' N& r
6 u+ P0 Q0 i l: y
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
! H7 `5 D% }: Z% k z
& u& T7 e# q1 M* ^; l) Q* k
9 j$ j. p c' B2 D- `4 c& C8 a
# 用于生成 GIF 图片
0 q" ^" b& P# C
pip install -q imageio
+ r( t |# m; e2 ~, M3 o
目录
$ W2 D: c$ H$ X
1 h* `: q' k) H# @1 X
- u3 x) `$ H# o& D5 f) |
前言
: a7 }5 B4 Q, n4 v( I; G
: ^! Y# ?& i$ s { g/ K8 `
' t' Y4 F/ v& s+ W% t3 `: ~' R4 z) t7 c
一、什么是生成对抗网络?
* W7 ^- {; F, `) C, G& p$ v6 @3 F
) k; D2 O! K5 R* u/ w
% J# O9 } e9 ]
二、加载数据集
8 y2 V. C1 R3 ` t
" y! Z$ R- B$ k
4 d4 R8 D9 @8 K, Z- }1 {
三、创建模型
2 ]6 Z' {& `( h4 T
" d6 k/ |& H* G2 L+ _3 _- A
+ `5 y/ _9 D+ x- h9 @% N
3.1 生成器
/ ~! `7 z+ S( ?2 L, S N
0 m# {! ^" R; M1 }$ ~
. \; R8 {) V! d+ ]# g
3.1 判别器
- ]( ]0 i/ b, g
3 v" I) N3 E: e% W
7 v, B! ?4 Y- ]' W1 D, N2 ^ \
四、定义损失函数和优化器
' a8 W" g7 s- f8 P2 v8 {% J
8 |0 @. h% c0 _3 H. E( w4 v
; {6 U7 X! A7 H7 X4 X
4.1 生成器的损失和优化器
4 j9 j, ]0 ?, R4 a1 i
. Z& @8 v9 i/ R h
8 C9 ~2 `7 m5 n- K
4.2 判别器的损失和优化器
1 c+ U! p1 k& P+ @" F! G9 Q6 p
4 }% d+ `- j" ]5 \
9 Z* N* r3 h5 M
五、训练模型
( `' {1 T0 G2 t7 v/ q
" }5 d0 P u; W" p* r
- Q' [& \: p* \; h! K5 {4 r# B
5.1 保存检查点
' A4 J2 P6 S, h1 j! p
, J. h5 b7 Z$ O5 y
3 h4 a2 [% U/ M! |8 K
5.2 定义训练过程
. L2 D3 }. Y; L, n, d% x3 y5 |
. `1 M4 A. y( ?! c1 F/ s! h: M
# k! ?" J$ o8 P: E# X$ G& c3 ^
5.3 训练模型
5 g* z# R* x. X, S
7 ~7 e5 d6 V6 J9 Z z7 d7 J
; U+ S) Y+ ^+ e$ ], m+ j* j
六、评估模型
. R( a {( s* @; _" l% g( C
8 S) A3 C7 C% \9 A* n/ o; L4 g
5 ]5 E: ~- g# l, x
一、什么是生成对抗网络?
6 v6 q8 C; \) c3 F
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
; l" A2 g; d$ ~' o. g( r
- I5 e( x- g5 ] [2 S
$ F9 ^8 O' Q+ k6 n, W
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
' A8 n/ @9 r. L' v( ^: r) b0 h
4 q9 k6 t7 x) X* c. d( M- p0 g
X7 l& A/ }5 l! p
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
+ \% i6 N Q+ d" I5 Y7 F
! ^! ]& t8 h7 \
! }- Y7 {/ ^+ J/ \
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
& Q; u' @% D% {0 x$ { K
# }$ f" B7 |( P* w( d& ~% a+ \
4 d. _- @- Z6 N5 E5 e: f" i
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
/ r8 b7 ~+ B# G3 v7 U
5 D3 {# [+ v, Z
2 X) V' l/ z8 x$ W
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
, A. N+ M# k$ h' c. D4 y
4 I, Q( T& C' l2 x
* h! R0 g/ R. u: _7 O# `: p
二、加载数据集
) H6 a! m7 k }+ ^2 r
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
! d5 d( L% m( r3 z" G# S* {+ q: t- D9 P
3 }. m! ~' Z4 z! [& X1 C- z# @
& q* F5 L: ^' V9 H0 ]% k
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
8 y5 q# P/ s0 T
' y$ h. ]; c2 o/ g" ?$ |7 L: w
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
4 F( S5 _/ H: E
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
. F* S) d/ o% N
8 x& ?( O6 b1 Y- F) r h
BUFFER_SIZE = 60000
. a, r( @+ b- X3 n/ F& F
BATCH_SIZE = 256
) P. S: I; b+ i& S; I* b2 c3 S" v
0 f* o4 S2 p8 e4 c9 i8 k
# 批量化和打乱数据
( l: b; |" ]* X r+ ?6 N
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
4 g( E! {" O, ~, M0 }1 {$ X
三、创建模型
* K" P+ O* A* n
主要创建两个模型,一个是生成器,另一个是判别器。
: F5 j! h7 l+ k. t, w: b# P
" Q3 ?# U# W/ r, j3 M
# f' t' Y9 c# v4 f" _1 _
3.1 生成器
O; J7 f) X/ E( r7 l* q9 a0 J
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
# N4 q6 ]% H, o& b; ^$ j
3 H; [3 ~; S; W3 \
* w6 c9 T1 V! I5 L
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
- x5 y- T- w: ]
0 H" \4 {" }* L" L2 X$ u) Q
- S' C8 [) }, q" A
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
; S, ]+ X% V. o' B: y) s
6 b- w* U& _! K9 I+ o
( d" ~) C5 L, M0 x: H" B
def make_generator_model():
, @7 J3 G7 d5 n' ^' A H/ Q
model = tf.keras.Sequential()
7 Q V7 ?+ }4 l5 ^* q) T/ `
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
+ O7 l$ m4 b1 ]
model.add(layers.BatchNormalization())
* S r2 h2 N7 m# g
model.add(layers.LeakyReLU())
0 s0 S. C7 @/ T+ y1 }
/ }# w& K% [4 m' o7 h
model.add(layers.Reshape((7, 7, 256)))
6 U7 Z) M+ e6 `4 W+ ~' H
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
6 w* o# b9 ?7 @3 Y" v+ O
8 i& Z) W% u J- H. I: A a; l
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
( L8 e0 h" z- m4 t" _% @) u
assert model.output_shape == (None, 7, 7, 128)
9 }1 {) ^( x/ |: t3 O1 A3 O
model.add(layers.BatchNormalization())
( J( n; C$ {, @
model.add(layers.LeakyReLU())
$ h% B3 H) s2 E7 X& F/ V6 Q
5 m* V# T" u" f e
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
; ^' M6 j" @8 ~
assert model.output_shape == (None, 14, 14, 64)
! T7 u3 J" v. T, x6 G
model.add(layers.BatchNormalization())
8 c; |, F- @' C0 ^% E
model.add(layers.LeakyReLU())
3 ?( ?2 Z. S v1 c7 a) \
6 v% N/ R5 |6 n1 Z' F) K
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
3 C+ O8 c3 q2 e3 A
assert model.output_shape == (None, 28, 28, 1)
- E' }- |. B: q2 E' C
# |" A$ W. m0 p9 e+ ]% c4 q
return model
" {3 B$ T6 H6 y! S$ h+ ~
用tf.keras.utils.plot_model( ),看一下模型结构
6 o) S6 N) Q: r, X( h) K6 V% X3 Y
9 e& e4 f$ z3 d9 j6 _6 c; L9 L
1 D2 b! c3 H3 Q) P3 A* }: k, C
2 q- j5 D! z$ a! k# J
+ W9 w l2 M# y Y: C% {
1 c; Z; C% {4 f5 u! v0 Z0 X& z
用summary(),看一下模型结构和参数
# O. W. J' L6 o( B2 [( [- H
) h, X- x, ^+ P/ B
6 ]3 e+ s6 L9 [, \" g/ f! E3 c
* S ^ z! F/ e5 O+ \6 c
( w/ N2 C& s6 W' L; q$ ]! ?) U
1 O# d: ? p" y* H3 e
$ r5 S1 }" |8 t7 m% C
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
# e3 ^" ?5 ~4 U8 d C
5 X5 ?* d$ E( f/ }. l- G5 b
5 n" B1 ?+ a# F8 g6 ^7 Q
generator = make_generator_model()
% v! d7 n6 R! Z
, ^- c4 |9 ^& ^3 y* c$ J+ I; k
noise = tf.random.normal([1, 100])
1 s) P: G* x% u! {6 R x+ k1 _
generated_image = generator(noise, training=False)
2 {. y J! M! L0 W9 d
' h4 h$ d7 Y8 Q+ o0 G! O+ P
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
4 C: `2 _( {4 T0 \7 b& @' C
# g, o9 r$ ]) T+ r
& ~1 s: z7 [/ o+ d
9 Y, W: _+ M! B3 }2 M% a: v
" F( K* K% p: Y: i# n
3.1 判别器
" d7 O, E! ~- g2 k' x2 |
判别器是基于 CNN卷积神经网络 的图片分类器。
' N0 c$ O/ Q( \9 Q6 N7 q
! h' f. q; p- P4 z
\- Z& W: V2 K3 C9 I6 _* P
def make_discriminator_model():
2 ]( L8 O1 F/ [
model = tf.keras.Sequential()
4 T$ p& s, \6 l& r9 {* h* R
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
. z, {1 _0 k5 F
input_shape=[28, 28, 1]))
! s/ `* V- O/ v* d! d, ?/ K: @
model.add(layers.LeakyReLU())
" k7 f7 G6 M2 B: R, J2 d+ y" z
model.add(layers.Dropout(0.3))
0 F0 l/ S! ~# H0 D4 [7 N. z
, z' Z! ]. B9 @0 n2 Y% _
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
. L1 [5 Q, g" T7 n1 T
model.add(layers.LeakyReLU())
c4 s6 R4 W% Y) l# f6 @
model.add(layers.Dropout(0.3))
8 I% S9 C( d; s; d; P
6 `7 o! t5 |# t7 r9 A& X- k2 k
model.add(layers.Flatten())
' F6 X6 d5 b, K
model.add(layers.Dense(1))
; T( }% [: t) T3 m1 Y
: t. I' i7 l3 W0 x+ `
return model
5 n% p5 G# n; R" ^! B4 m# }
用tf.keras.utils.plot_model( ),看一下模型结构
* u' N; i: D1 x z) a
5 [+ a8 \2 D; r) o' Y
' t' E: O* l) A
* t" I$ }% s% C$ j7 j, c# Q7 j# U
8 H/ u0 v9 |$ Y! C E, q
2 P7 x, e9 L/ K9 ^) V' F+ r6 t
9 |) W0 T u9 r/ [
用summary(),看一下模型结构和参数
$ S) ^9 z. U( |" {! G& O7 z
. {' o( [& f9 L' i, P
, g- ]* a5 i6 Y5 V% X E8 f
/ l* I0 k: \& r
8 e$ j/ y6 _2 u& S& U
, q5 S! F7 {$ b5 C
8 q6 H: h( ]& m2 X$ _
四、定义损失函数和优化器
1 V8 O5 D: S5 G! z
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。
3 g2 e6 Z; r% ^7 L
! M4 D+ G/ f- q8 w
) x$ J4 X5 h. z# e
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
5 w2 J# |* G& y, S7 e% I* P% T
& H5 P1 a/ S2 k* S
# |' [+ H5 D) ]7 c& _* j! X, z
# 该方法返回计算交叉熵损失的辅助函数
, L7 u7 [9 l7 t
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
: g' }) t9 Z0 K+ g. Z
4.1 生成器的损失和优化器
. E7 A: a4 o# D" |
1)生成器损失
6 D$ ^' e# M1 T7 K& K, T
, y* j1 B1 V6 J) v. q) U
" p+ ]3 x, a0 V7 t3 t
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
) i+ k N7 R* _/ u, s
5 o# `6 R" E4 y3 X6 A
6 v, q& \6 W4 ?' F3 b
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
. t; y: V% J$ A) R0 |. q
. a5 ^" L; c/ T% G9 F' O- f
( ?4 \' c$ D9 ]/ [0 i
def generator_loss(fake_output):
1 Z: c" f( I5 h: a! }9 M
return cross_entropy(tf.ones_like(fake_output), fake_output)
- g$ V) P/ ?9 F7 F# }
2)生成器优化器
4 G; U* m+ d, O2 f2 b- s# G
+ H' Y/ ~8 V7 ?3 o1 q3 A
, B/ g) b6 J) c; L8 }* [
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
; Q) M0 x# g: U5 I
4.2 判别器的损失和优化器
- w/ ^$ @/ \ H: B
1)判别器损失
$ n3 r! `7 I* L$ h! w3 Q/ V8 O5 |+ X& q& a
) I% N3 Q5 q$ F2 Y# h! @+ p
! y6 F) e0 O6 e& s& K9 b2 L% F
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
2 O4 ?7 ^7 u) P9 q: K) |( a
- m! Q+ m1 l! T0 O6 l9 W
2 P7 B( ~# `% I' L
def discriminator_loss(real_output, fake_output):
8 F' W0 z) J4 d" S( o7 C
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
( M* G4 R9 F; N
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
& O( D. D& ~$ Q7 i2 [
total_loss = real_loss + fake_loss
# N# ]# w. B8 I' y3 b/ Z3 U
return total_loss
, q5 O! M/ ^. R4 k* Q0 ~; X9 t
2)判别器优化器
E5 d( i- G S' N3 @6 k* p% c
m* v' A( a0 U" J4 p ^- @7 o
J# a4 [/ K/ H) P. {; z6 D
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
! y2 f) a4 a0 I8 L, v3 ^5 l+ |! A
五、训练模型
+ o7 Z1 l" X! h) O9 `7 f" ?
5.1 保存检查点
* ?* b! Q# g2 L
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
6 Y8 s: x7 a3 k. Z% ^0 g& }
; g$ {4 L/ A* i' C d
- G$ I9 |" _; Z+ \ X) p: P, E
checkpoint_dir = './training_checkpoints'
6 m) |# S( q; D; W' I* f
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
2 J: v$ b' v1 ?
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
' N5 Q2 r) m& r$ B
discriminator_optimizer=discriminator_optimizer,
, m* m, u1 e3 @! @
generator=generator,
- [& X9 l2 |3 F/ S" U9 E4 v
discriminator=discriminator)
# r' m" X: z+ O! k/ I$ g
5.2 定义训练过程
; O- G- c4 R4 g4 c" E) W$ T
EPOCHS = 50
6 p& Z# n5 ^% |& G5 w
noise_dim = 100
! I6 R1 R1 V% o5 P$ D; _# h
num_examples_to_generate = 16
) f& V( J1 O: K8 ^3 |
5 x" C5 e V; q2 R1 l* ^$ c- J+ i
3 I" l( `+ U# G/ k7 X* }) F9 @. i
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
5 }* ^1 p- { W* b6 o2 p
seed = tf.random.normal([num_examples_to_generate, noise_dim])
7 t- Q; y3 @3 x( W- n' @
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
& [ e W/ X2 _# J7 w
' ^+ l6 S" y8 b/ p3 B1 k# s
1 W! C' j+ S; W- X: V- h' F
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
3 R6 M" U( m% i& s5 H+ d
7 B5 M( F- {8 ?3 c# s5 f7 B: w
/ G0 `. u/ @3 b8 g2 {1 P
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
1 [) K6 E. `4 d+ g" \# C; L
- E) L; g8 _, b0 N, R
$ D9 L2 Y1 D J# F, ]$ |. d
# 注意 `tf.function` 的使用
! G2 }4 d" q* G% |" Y& |/ L8 j
# 该注解使函数被“编译”
* ?/ W+ z9 t, [
@tf.function
: u0 X% r9 ~5 |
def train_step(images):
$ b$ h* M3 }( n2 n3 o
noise = tf.random.normal([BATCH_SIZE, noise_dim])
6 P2 W& Z5 V4 [6 U4 [2 y' s
% o$ u3 p8 n8 B, N3 ]3 q. t
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
$ J4 y! a; u! I+ T* b2 _
generated_images = generator(noise, training=True)
+ U/ e' M5 k% H. X# d* k* l( |) }
; U5 @7 ?3 w A7 W5 J# c* d3 u
real_output = discriminator(images, training=True)
" ?8 z5 ]- K' W4 o; ~
fake_output = discriminator(generated_images, training=True)
M3 b! H! J: P2 B- `1 \* ]! y
' D% e: B7 o$ U( S8 P
gen_loss = generator_loss(fake_output)
4 o- E( L) k$ f* a1 G- h
disc_loss = discriminator_loss(real_output, fake_output)
+ a" j0 g" H4 t
1 c4 ]& E, B# V% b9 ]
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
* w( E2 }& K6 F& H
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
+ Y2 A5 \: K$ K; a" F- p
& Q% A6 u- h& g$ i
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
; e& e( G' F6 k) y* L2 D
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
" S8 b, Q1 k; v: O6 g+ e
+ F/ C! x0 n" g
def train(dataset, epochs):
; ^9 \; h7 F4 U- @6 R
for epoch in range(epochs):
3 Q5 G$ @5 t# r+ {* q0 X& g3 I- F
start = time.time()
: K5 `' i- x; J7 s; R( X
6 ?) r1 e$ G! u! d" S
for image_batch in dataset:
; t S K+ i! P i! B) s6 B
train_step(image_batch)
F; K2 H0 I# [9 z i3 P4 t
2 N$ A/ k! Z6 R! b {& b
# 继续进行时为 GIF 生成图像
; w8 O0 S2 q+ {' o$ c2 j1 W3 i) m
display.clear_output(wait=True)
8 Q* t+ T. v1 n2 a
generate_and_save_images(generator,
! L" n0 p8 P) \* i9 o$ v
epoch + 1,
7 W+ ^$ A" R. j2 y' V q+ k
seed)
: C! f2 l: Z9 ]0 [# |9 g
+ t0 ^+ t# F" K2 a( t2 U% _
# 每 15 个 epoch 保存一次模型
0 r, c" {7 F9 E8 o: Z" u G
if (epoch + 1) % 15 == 0:
# A" ^- }8 v( v6 H$ |
checkpoint.save(file_prefix = checkpoint_prefix)
+ [, K+ p8 Y1 H/ _0 s
. D2 o' I {! q6 S9 K9 u5 Z
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
3 N1 |% a8 t: ]. \; R( R/ _5 W
# Z5 c' {' G- ^0 U" D
# 最后一个 epoch 结束后生成图片
6 V) U. `. h8 q- c
display.clear_output(wait=True)
! W* s/ v& T3 Q4 h
generate_and_save_images(generator,
. B- V' D1 J5 i. e
epochs,
: G+ s9 m, y1 f1 Q+ |: E: K
seed)
9 L( w, c* ~6 R
; K8 |" Q3 N( f$ c. F e4 U
# 生成与保存图片
& J1 d# { l# q3 w
def generate_and_save_images(model, epoch, test_input):
5 {6 |0 [! d" _" ^" K0 y
# 注意 training` 设定为 False
1 u( [& h+ Q6 X* ^1 {
# 因此,所有层都在推理模式下运行(batchnorm)。
$ I7 @8 ~0 k% i2 |% b0 c4 q
predictions = model(test_input, training=False)
) y9 J" ?2 G: A4 z V, W7 s4 i1 ^
3 Q* N2 e3 J6 S7 ? [- k1 p
fig = plt.figure(figsize=(4,4))
5 a% _5 r8 ~9 R; w7 h+ [: N+ T
/ Y- U1 o3 u, y( S) ~
for i in range(predictions.shape[0]):
1 x4 u7 B. u* r# p: `) d
plt.subplot(4, 4, i+1)
" x# D. i$ n& w4 b# Q
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
% ]* k) v* Z) A2 {8 x8 f' J, G. {
plt.axis('off')
2 v1 v- v* b" d2 n" Z. i
/ r; a& t# F2 `2 `$ {4 L8 _9 |
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
+ n! \! G4 P w0 B* N
plt.show()
) s) }# ?' R( P3 N7 m
5.3 训练模型
7 \3 V2 R) ?( T1 e, j
调用上面定义的train()函数,来同时训练生成器和判别器。
4 D! k) [# \" a$ s% E, Q3 p2 W; L- {
* H; [- c' V$ O9 x" E
1 y- X9 |! v% Z+ w/ B
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
. H2 x) a: O% A! C F X
! H3 a" M: P8 E
# K6 U2 \1 r& M5 T5 D
%%time
2 f/ y" G! p+ ~6 i
train(train_dataset, EPOCHS)
( q$ C& Y( y0 z/ R
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
9 D$ l8 N; S* A+ }
0 q* J; g. g! m- j R7 u
) A' z) P; h- P4 q S
训练了15轮的效果:
$ ?! E5 e: b% A% W
6 V! t! `3 |% O$ w' j
1 {& E- I% N- v& P& W
4 s$ V O6 ?- w& l
9 S) C, Q5 c, `, f, G( o u
# u9 `7 s c' G4 b( R+ S# l
. W) H! m' C- N1 N0 z2 j
训练了30轮的效果:
& p7 i7 J# c- J1 F
, i" P* e, \$ R# h0 D O7 k9 g+ M
: [' T5 L* S- F3 B3 g
$ ^+ o) q, f2 Y! e& ]7 k+ |
# M6 C3 F. G7 q
9 ~2 O" O# S* |) ^/ b
) b9 h4 r8 L$ W! A( f( Q& o
训练过程:
+ Y9 o9 h5 w1 Q9 h9 Z- Z' D
' c0 D1 C5 f ]( X& k6 l, j
9 D/ c1 s5 r9 h+ u
0 ^" @/ z, @2 I9 B, f. u/ U; T3 O
& m- R1 Z% S' \- x9 a
) {' w; L1 w0 g: T, O
# N' }; T( d+ y2 W
恢复最新的检查点
9 w4 |4 X1 q. W. v" l6 G! r
! H0 q: S6 R0 Y6 P* V
2 u% |8 O4 m. F1 Y3 m4 ]
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
/ C+ B4 l( Y, I& n( w7 Q7 a+ B
六、评估模型
& N( u7 ]+ ] D3 A
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
0 I$ F6 S9 n5 G- U
* u& v+ ?6 i4 M1 v. u
/ w- O E4 \0 n3 g
# 使用 epoch 数生成单张图片
. z/ _$ n) Z, q# I5 V' l _
def display_image(epoch_no):
5 h% q) f9 s# D, U% ~/ c& Q
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
; _) q4 {+ s& \5 R
* L1 ]: s# L) m# z* o- G
display_image(EPOCHS)
6 t9 i2 `. z2 n4 ]. t+ l" e
anim_file = 'dcgan.gif'
- t; u% d, X/ Z) R: O# O ]
7 J6 {+ D) a" O8 k0 C4 u0 I* U
with imageio.get_writer(anim_file, mode='I') as writer:
; d* ?4 U! h& j$ m4 c
filenames = glob.glob('image*.png')
* e9 c( H& }2 G% |" i
filenames = sorted(filenames)
" N( c1 Q5 H3 m* q6 W
last = -1
% }) h N# E1 R8 P: \
for i,filename in enumerate(filenames):
3 i. R7 E) Z: `! H- s/ j( i
frame = 2*(i**0.5)
7 I2 M- v: k5 i3 r% u5 \( Z8 e
if round(frame) > round(last):
; S* o) v& ^5 m+ c6 f, _
last = frame
: T) R5 Q# K B5 \
else:
7 q. G' t9 B, x X% p7 j4 w
continue
* J y0 L% ^- B( W# A: \
image = imageio.imread(filename)
5 Z" F6 `* f: G) o$ M. h0 Q7 N
writer.append_data(image)
4 t J( w' w# s$ O& C, @3 ]& A4 P
image = imageio.imread(filename)
" {8 Y! U, G, a& [ _7 ?. L% g3 d! v
writer.append_data(image)
. l7 H( z- S0 m1 G6 @' N
4 k0 R# L3 r; ~
import IPython
" f8 j, q+ G2 Q! ]" E) L
if IPython.version_info > (6,2,0,''):
! ]5 B( `, `! e+ u
display.Image(filename=anim_file)
4 I1 u$ u# S3 t9 c# j0 X
8 D/ G2 H! t% O4 _+ T' v$ l. \( v
I1 E% p5 C$ q# n8 a1 ^: d+ V
0 I, L( j; i- u, o
$ I& [) B. Y: l' |2 l
完整代码:
" a% U! d6 n, C7 b a( T
( Q9 `1 \, [. K
% |, \6 H% h( u w K5 ?
import tensorflow as tf
& V' A4 d; M+ X" b' v1 ]& B% a* ~$ u
import glob
& D# x z( u# Q, v
import imageio
: K9 ?6 M4 u* X$ ?5 s7 B
import matplotlib.pyplot as plt
# s: n( V5 u# n
import numpy as np
) d/ v: N$ o/ x( _3 M2 c4 z
import os
% R8 @- \: s( u `4 @+ r
import PIL
4 B0 }. W/ B, D4 n
from tensorflow.keras import layers
7 B1 ^& B/ i5 D8 a
import time
7 E$ X, N/ F' F8 e) K* D
$ |! J4 y# \; ]: E) A! Y6 C: n& `
from IPython import display
3 Z @: Y; U9 K5 O
$ Y" c8 b- X5 V ?
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
! h" t+ b0 J2 K/ m; p$ D
3 N% k$ o' u4 g* s; M
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
" t/ u: K- C" P7 N6 j' m
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
' `. K& [& u8 _% k
$ e1 n: a( r" {( B1 G) V- }
BUFFER_SIZE = 60000
% X6 i2 E2 x$ {! \6 U
BATCH_SIZE = 256
, {. a" b! e6 w% G
' ?2 X' R2 C7 I: {; B1 i q, \
# 批量化和打乱数据
9 D, W; n9 u1 C, Z9 A0 x
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
' _' n8 `& }) M& }1 \& b
% h. x1 Y; V$ k1 k' B1 G
# 创建模型--生成器
9 Z3 K* F5 A: V% W: o+ Y
def make_generator_model():
& X' s( F8 k$ [0 {' N7 M. D% X* A
model = tf.keras.Sequential()
' g, w' N; D, s; E; A9 _ a W5 a
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
: g" T. J" }, t$ ?& u! A
model.add(layers.BatchNormalization())
- s" R+ [1 j" n' r& [) \ A
model.add(layers.LeakyReLU())
1 G0 @4 @% r3 y' {. u4 m% m0 f
8 E- k7 _7 e6 e. `2 I. o
model.add(layers.Reshape((7, 7, 256)))
9 ?+ I1 \* v" ?- C) W% [
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
2 k3 L( P$ A: A8 \) w1 \
5 P- W) ^! |0 u) l' n
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
# a$ [: f! g5 H% Y- I$ M3 ]* H% ?
assert model.output_shape == (None, 7, 7, 128)
4 F: Y0 I7 `3 k& a$ f, a
model.add(layers.BatchNormalization())
; K3 A7 [: z4 K. F+ ?( a
model.add(layers.LeakyReLU())
6 [6 b; u: s7 y# r* o3 l, ~- F0 V
- _. R" `8 N* ]1 T8 R2 z
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
. X2 c7 [$ w0 U1 O; |
assert model.output_shape == (None, 14, 14, 64)
( I- v. Y" T' D( Z3 x1 x: g
model.add(layers.BatchNormalization())
. M& c8 g$ L' \+ Z
model.add(layers.LeakyReLU())
1 G; V9 l. J2 I( s }0 I
: n4 N, X5 r7 V M* D
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
$ ]! f& @; d. l1 }
assert model.output_shape == (None, 28, 28, 1)
4 i3 d* f) Z; U7 c: I( {& n6 k9 H- y
+ O! v7 M- G6 o8 I% g$ f) l9 N( n
return model
k, U' J0 M1 D# }
1 q& J a- T" z
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
) h1 h2 b2 Y; t0 X; p) u! o% q
generator = make_generator_model()
8 A! |: M& ~! h4 r+ F
9 T* A( @, C# W5 _% g
noise = tf.random.normal([1, 100])
2 t) v/ [( v/ c+ k2 R- E- S
generated_image = generator(noise, training=False)
# O6 J9 H0 p! [, u
$ C3 O& L: u- n; ] n/ r
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
+ K9 W) Z: |* ^4 b; {! j7 M, O
tf.keras.utils.plot_model(generator)
" N0 E6 _; J. A Y3 @8 a0 Y% g
, `* Q& I: H' G+ u7 I" P
# 判别器
5 e0 D+ [/ E, p7 I* ~
def make_discriminator_model():
4 C- ?, }3 y. ~3 f; C
model = tf.keras.Sequential()
+ P/ s; B6 o* [7 r! w
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
% _* S/ J) s& h/ O! P7 N+ R; g, v
input_shape=[28, 28, 1]))
2 u7 i5 i) ^8 V7 P1 Y$ d: r
model.add(layers.LeakyReLU())
1 q/ P4 z1 Q: W% @. X3 X' a
model.add(layers.Dropout(0.3))
4 w E& [- F, X% t( J3 q- [
7 O: d: V6 ^/ ]/ Z
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
- b [- v. X3 U: a/ _( V9 S! n" W
model.add(layers.LeakyReLU())
, Q% H8 b0 S6 c# Y- B
model.add(layers.Dropout(0.3))
; w9 _, h9 n+ C9 j# C2 u
4 \5 ]# p0 n; i0 n A
model.add(layers.Flatten())
) _8 s0 B. ]3 J+ u3 E* c/ B
model.add(layers.Dense(1))
- b% W& H7 Z# Y O) j$ ~/ x: @
; t Z+ ?+ h: ?
return model
% G L1 `0 q6 I/ v3 a1 F; ^7 H
( f" F( E: D& T$ q& y. b) y
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
! k3 e8 V9 d- z/ @
discriminator = make_discriminator_model()
. [' Q- ?% Q! A% {
decision = discriminator(generated_image)
|$ N& k( L/ G0 x; a# [
print (decision)
2 V4 @0 u* W0 a: Q; ]
* P5 n4 D! b) c, a
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
$ n* \5 n% l5 i8 o
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
0 A/ x6 K3 h0 N+ S$ y- F
# P- f+ e8 A; r6 k# Z0 m
# 生成器的损失和优化器
+ Q8 Y. W0 O, H5 n5 U
def generator_loss(fake_output):
, [# `. m0 A! W& d
return cross_entropy(tf.ones_like(fake_output), fake_output)
, W ]5 w; g9 L( D8 K/ _
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
* X6 C( z( R' T& [$ t9 G5 Q
$ _! i4 W* D- [, {% W
# 判别器的损失和优化器
; ?3 I7 G7 x( D1 ?5 G- K
def discriminator_loss(real_output, fake_output):
! m& o4 [' @ q" [6 E
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
4 g o( w* k; u5 X7 _; @0 k
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
0 N. t7 I. I7 r- C$ R7 y3 i
total_loss = real_loss + fake_loss
& d X. g8 |7 @# H
return total_loss
" W1 \8 {9 H s7 e, @1 G- n: ?
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
: A) p- W* Q0 D8 q* c/ P
( a+ s# r7 z5 G8 \
# 保存检查点
( C" `5 b* j: q( x N. f/ b# K1 Z
checkpoint_dir = './training_checkpoints'
) d) m- B7 o; I
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
" p5 ]( T* q! r
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
" O3 j- e5 l& x; ^3 h
discriminator_optimizer=discriminator_optimizer,
* j( u9 n- F# c; `$ K# @7 i; z) k
generator=generator,
& P7 u4 _3 y: K1 `% q f
discriminator=discriminator)
: h. @! Y1 ^( K, k+ V# z7 A
' L5 B. G% M! o! m! o
# 定义训练过程
% t# k7 z2 u ^' D8 t
EPOCHS = 50
( W$ P1 [: x$ x0 x& P
noise_dim = 100
9 M% D, X5 I- k% P3 w
num_examples_to_generate = 16
+ h5 o; L" c) e: H6 |- R* t/ C
6 ?% W$ e; g% ^7 G/ h# _
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
' P+ o" A/ f; a
seed = tf.random.normal([num_examples_to_generate, noise_dim])
z% f7 G/ j; m9 B& i
) ~3 A" l, J. w4 u8 R$ N! _: l
# 注意 `tf.function` 的使用
* t* W, h( }1 \" c
# 该注解使函数被“编译”
2 S2 |6 R) E0 O, ~+ D
@tf.function
g) w1 R Q% a- W
def train_step(images):
, m7 [. v9 s7 ~; [
noise = tf.random.normal([BATCH_SIZE, noise_dim])
* i: ^( A$ _* Y, D' L6 U% Z: S
7 f+ |* h% O, e, L7 y @( `9 [
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
7 I; t. I7 m0 R1 s# B
generated_images = generator(noise, training=True)
* Z* E/ H3 ]3 o2 ?
0 H' p& d. ?6 {0 r
real_output = discriminator(images, training=True)
! _8 b/ e( Z+ v" @
fake_output = discriminator(generated_images, training=True)
; L- h; O) ?4 ^
) ~1 Z0 c, }/ p; y4 X* |/ v* l# ?: B# h
gen_loss = generator_loss(fake_output)
7 z& `" W* r# v! a
disc_loss = discriminator_loss(real_output, fake_output)
+ V' x( g% ^. p4 B" i! N
" j: b( Y( b* J+ l/ e- B' ^8 F* g7 j
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
( ]6 P. d$ L3 L% S
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
( I8 s3 ~" k, t' [5 k
0 g* Y7 w: C2 _+ S( g/ V
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
2 u1 x x/ l% a) g; u
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
. ?# R I: j) m
8 K' v/ r1 ]' H: Z) t* r {2 j
def train(dataset, epochs):
7 S( W' C# ^# z) X8 l/ j
for epoch in range(epochs):
. {3 |. z& I) H. X
start = time.time()
/ v5 Z6 }! s' q
+ L7 H2 \6 w4 b8 [
for image_batch in dataset:
) h3 q' @* G3 u" N) R+ `. v
train_step(image_batch)
& G( y1 ]1 V0 j+ b- x) W
1 [2 p5 T$ s ~, N" y/ K: |
# 继续进行时为 GIF 生成图像
, o( d% L D+ g& W6 g$ W. O
display.clear_output(wait=True)
$ O* |' }; K3 r9 y, ^4 B8 w K3 T
generate_and_save_images(generator,
# c9 g9 n5 F4 S& u0 M% Q D
epoch + 1,
# q# z, v/ f: Q/ @+ M) k F, W
seed)
; ]1 b* d8 C% y$ X) m
& @, ?4 c& l; Z9 t0 R) f. N2 n% \+ Q
# 每 15 个 epoch 保存一次模型
' Q% x. {2 P% z' ?
if (epoch + 1) % 15 == 0:
9 p% l; H$ o+ c- B$ e+ y! [9 U
checkpoint.save(file_prefix = checkpoint_prefix)
8 L9 E+ a1 W: u( W& _
8 F' M# |0 c! F6 S2 Q) }' F
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
8 U! _/ Z4 f, [9 a- [2 v5 W4 Y
5 O v8 T9 |( B) q' t- W, ^$ @
# 最后一个 epoch 结束后生成图片
9 Z# F' f6 I% y2 e4 u
display.clear_output(wait=True)
! {8 g) d/ s, o7 d. Q" Y
generate_and_save_images(generator,
- r+ V! g. B" w4 x2 M0 j( {1 a
epochs,
, N1 e% `! z1 O) m# s, W
seed)
/ L1 J \' M8 n! x6 p# J
+ N/ G# u5 J/ E& h+ A' Q9 [+ N8 {
# 生成与保存图片
+ T. C/ Z! x, B2 ]
def generate_and_save_images(model, epoch, test_input):
- F g* \8 u* z# Q: x& O& N
# 注意 training` 设定为 False
8 U5 W6 e$ n0 X3 O$ e$ X
# 因此,所有层都在推理模式下运行(batchnorm)。
U9 Q i6 X9 |' m. i: t% m( e+ [% H
predictions = model(test_input, training=False)
4 S8 A7 _/ S c! S) q) o. R: b
" _; C0 K7 T5 |# Z: G8 p! E
fig = plt.figure(figsize=(4,4))
1 M, [- C) b) x; y2 l( W2 ]
' T; x/ d9 q" a( ~
for i in range(predictions.shape[0]):
9 n- R: v) l& v) D8 H- M
plt.subplot(4, 4, i+1)
. [7 i* u# E- i1 n1 V
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
' K \; {) y: `$ b1 q
plt.axis('off')
( _1 |1 {8 M: Z9 A/ M
1 F b; i" T0 X3 \" ~- w% Q1 j7 D8 Q8 p
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
_& V. P b3 ~% v/ f5 f8 ?7 ^8 @; e
plt.show()
, D4 g) S7 o+ y2 [! `* u
* \, z! y, N- I: j4 n
# 训练模型
* J1 P/ W4 q. Q9 _8 [
train(train_dataset, EPOCHS)
# I; s ~# j* l" T6 j3 R. s
- d' x0 K1 K; F# \" S6 K% [: b! E1 n
# 恢复最新的检查点
$ m: R6 E& I! P3 R
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
/ M- H t" o; v* {2 O, O$ O* C9 b; _" u- [
/ M- W& _9 _8 v3 e3 I
# 评估模型
2 e- k9 P% p" Z2 y' F
# 使用 epoch 数生成单张图片
9 ~& n/ E# J# x0 A8 H( c2 I' ~
def display_image(epoch_no):
& l1 H1 X+ ^& l: I* Q
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
) {( q3 x+ K7 r- W) X" i& c: k( F
' |/ q7 ]( Q2 o
display_image(EPOCHS)
6 @- E6 D# w6 F
8 C5 C* t. B f( U5 o0 n5 W
anim_file = 'dcgan.gif'
0 q, m7 a$ P E' n
) }+ n; P& x/ O' k% K7 a; U
with imageio.get_writer(anim_file, mode='I') as writer:
Y5 c( l b! d; v
filenames = glob.glob('image*.png')
/ }1 J( S1 G. d* ]2 I- v
filenames = sorted(filenames)
7 a: w6 M4 N) n8 T4 p0 ?5 L
last = -1
9 \8 y* l/ b& w2 Y5 n9 m: W J0 @
for i,filename in enumerate(filenames):
( `* ~) s2 b6 `. ?8 y
frame = 2*(i**0.5)
8 ^) g$ E( I0 L% g1 |0 H) i
if round(frame) > round(last):
$ o! E( J$ ^9 ^
last = frame
) V: O+ o4 f4 V+ T; Z
else:
1 T& [+ b: ]+ p9 F
continue
5 Y j5 p+ r9 Q
image = imageio.imread(filename)
1 X# |) Q& u6 X/ E) [- F
writer.append_data(image)
0 b6 \/ P0 @. h) A: @& L
image = imageio.imread(filename)
: K1 ?5 x$ C& c8 P
writer.append_data(image)
8 h* E2 x. D, ~$ W& I5 Q
9 u1 a$ D0 @4 r) M+ l0 d7 k; Y! b
import IPython
" T& c5 X8 ?( D* c
if IPython.version_info > (6,2,0,''):
5 Y% _' Z8 c1 e/ R
display.Image(filename=anim_file)
' f& B& _( \) D- Z' Q/ D/ ^, o
参考:https://www.tensorflow.org/tutorials/generative/dcgan
9 h- s4 G' y5 W* B7 l# _; Y s
————————————————
8 N M3 a$ b" A4 U! Z
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
% o9 C- K X& |$ l9 n
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
L- z1 n6 z2 `6 Y2 A0 a
0 K l6 G) G6 n& r$ E' B, H9 w
9 b H1 [$ b- {0 w9 j. U
欢迎光临 数学建模社区-数学中国 (http://www.madio.net/)
Powered by Discuz! X2.5