- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 563349 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174228
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
; G/ L3 }+ [, |! u5 `, j
深度卷积生成对抗网络DCGAN——生成手写数字图片
! D0 ^& Z! n! e/ h0 O" `前言
/ w7 B' N! r* E/ p" x0 o本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
6 X: O/ Q1 F! V; X0 l2 C
+ }# t2 W( m( f) D9 C' v$ B# k( Z0 b" n# ?* P A3 X' }, c
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:# \: K2 y9 q9 X4 H1 V
" ~3 U# l9 `6 k- r5 m4 }0 O* x/ Y/ W
3 q' ~5 S C1 C7 k/ K# 用于生成 GIF 图片* S1 j, _8 V# \) ~' d) Z4 F B
pip install -q imageio
- E: r, J O. F( s+ R目录5 @6 f/ U- A+ M0 m9 H
7 O. D8 w% W6 \) C
- O. [. M6 I3 s/ `前言% h; Z& `, u) }/ k, N
; @0 o9 F1 C! s. \) i
3 u5 ~' K$ B5 ^8 W( A. B4 q/ k" V
一、什么是生成对抗网络?
2 |4 A& E9 Y/ v/ M/ a- A- t% ?, Z! e, N9 G- Y0 w( w2 `+ t* @
7 t3 U7 E% C8 k+ u4 l
二、加载数据集
2 {9 P. u- v3 Q* j) k* r% g/ ~. b3 L6 ?9 k2 ~
7 H/ ?8 S; n: d/ U# P三、创建模型6 [5 J1 f1 } i4 A
4 \: u" c8 W! z; \- a Q( k) j4 [! G# R3 W. s/ `3 h+ p9 Q, u
3.1 生成器; l/ G v6 w1 f% r! t$ X; g
6 w N0 a- v# d [1 h
9 s, C: O6 c" w/ x. v
3.1 判别器
2 `9 |! O" M V0 W& _
, b$ K0 I8 n/ j2 b/ V f/ b& }: g0 a8 z, Q! F" E* A. `
四、定义损失函数和优化器. y& r4 r! U c; C) q2 D4 C6 U
/ f2 v v' V+ l& B) D0 l1 n; K2 f
2 l1 j# i& Z4 K8 Z! ?' L4.1 生成器的损失和优化器9 l+ L1 H( i' O7 T
( N$ B1 ^/ x% @* g, m
* j0 t% L" J) i- N
4.2 判别器的损失和优化器/ P; K/ F( n0 Z
& U0 j& g: _2 t7 Z- G
' j! z+ t( i- }五、训练模型+ D% b, P# A2 B8 l# I$ v
$ R6 e- y% c! j5 L
7 o P" \1 A% `+ B% Q
5.1 保存检查点
( P' b+ a0 v: g% R/ d. ?! f9 j w
+ f3 |+ ^; S7 F; Q" v) _5.2 定义训练过程
; {3 m# i- ?- x; ]; H8 J
- G3 P; C0 e' X3 p7 z9 q
- v e' B! F. ^6 U. |( i- p5.3 训练模型
1 t' Q% v9 v, p0 G& u- Q7 j/ F5 ~# M& h; _5 G4 ]7 `6 J4 _
4 G/ o2 L8 }& T六、评估模型
# d! N+ N8 \' [) G% Y$ ?4 u
) b! n" O8 I* k9 b, A" ]) x
( ~1 R1 j9 p) F: a9 Z一、什么是生成对抗网络?
. V, f- ?$ F2 b" u: h5 z8 }生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
5 j3 c4 `. H0 P; g3 q) L2 m) \3 f4 C1 h0 _
6 b# l- B* O8 o9 _: g$ E8 N, K- Z8 w
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
3 f) O* U' I! N7 ^
, o* v* c- B- ?" M) m# m
7 G7 }- X) H8 N6 g, b; Z/ ^判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。8 o3 K! r: b% j4 z2 A
0 K p \6 n4 @/ Z: M8 p8 ]# N# f! V% G
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
6 K+ e$ Z% l5 Z, y9 A! P
4 ~ e* x' `% E, c3 E) I# L2 [2 h% Q$ c1 F
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
; Q$ G9 I; k2 U7 y9 E* F+ y2 e' J# B& Z
" b# x$ k! B9 T7 {本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
3 e. p4 G9 d7 d/ A K( N. w. |7 P; o$ a' ?8 u
* G% b+ \% \/ E/ Z* |1 k
二、加载数据集7 w/ q; j+ }$ {% @; s' O6 I
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
6 Q' }3 ~' f5 h! S' [% }- x5 I% x+ `& u- C& m
2 i; k& w/ C6 r: }(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
6 Y2 P4 M9 B3 L6 q ' E( I( G/ {% z
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
) S+ R' ?# i6 L8 Qtrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
5 J. T+ c: U. e: i! _
% r, N6 j( K, \BUFFER_SIZE = 600005 e' D9 N' {' g# |; J
BATCH_SIZE = 256
s1 L% G4 r' L. C" F
. M9 V' p- U! i+ w# 批量化和打乱数据
9 q1 f; x7 _) E% a6 [; L5 s: k$ qtrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE), Q6 O5 I# }0 E" R+ j; J
三、创建模型
+ ]- S* {2 ~+ f) N8 C5 a主要创建两个模型,一个是生成器,另一个是判别器。' {- a* y, G+ J" @: ^9 O8 {
* O( d5 w( d7 u
% B2 [! Q! \& C1 l+ @3 T
3.1 生成器
, @6 j, N9 h+ C0 P, l$ [ Q生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
" q: m4 ?! \. ~" R3 w3 ^2 p: p
$ i* j" U( P5 @# h8 s& r: u) b- F
: J5 v3 l- j G6 }' ^然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
7 L. L0 A! o6 K
# R, Q! ~ a& R+ N1 s: |
3 A" \5 W: O+ d* z/ R/ X$ }后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。( Y! B7 [ t6 ]+ F
. V! ~% N0 \+ M( x+ S; L# h# z3 O2 |
5 j- x4 x2 F5 K; wdef make_generator_model():# ^8 j, x. D$ e2 s2 b/ |
model = tf.keras.Sequential()! A- Y) ~( z* ^3 N3 J: m; d
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))$ O: }( ~! F7 |( Z. q5 }
model.add(layers.BatchNormalization())
$ Q+ Z% C7 K! ] model.add(layers.LeakyReLU())
8 f9 r3 _7 q, H6 d* ]
7 _! Y0 H' n ^" e model.add(layers.Reshape((7, 7, 256)))
" l& @( E$ g& }; l1 }7 W @ assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制- Z" ~8 `% R* N0 g0 p3 }" A. P
2 I. S% P/ u* O& a
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))1 T9 J1 J; {% {) s7 u( W
assert model.output_shape == (None, 7, 7, 128)# o6 c' M& B, d& M+ ?/ m+ W
model.add(layers.BatchNormalization())
, v. K' ^2 g4 Z3 ^6 j" F* X model.add(layers.LeakyReLU())1 C0 y% H! A, ^0 u0 ?" f/ ~/ V
) r6 G# T3 d+ \1 x/ V model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))) C; m% t- j- D4 }3 X1 A
assert model.output_shape == (None, 14, 14, 64)5 X/ P+ l5 [# ^+ h" ]" v( Y
model.add(layers.BatchNormalization())" C* g) z" R% _/ e# W4 `/ A2 b
model.add(layers.LeakyReLU())7 z! M- X9 b2 R/ \7 Y7 _% D1 C
2 G( J, q- _7 V6 d! O
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) }5 S5 O1 S- c$ A
assert model.output_shape == (None, 28, 28, 1)
C' j# g, _2 c) p6 u) L & A3 o5 w9 ~) e# j+ W7 h
return model; m0 S; @: w: Y. T/ ^7 V
用tf.keras.utils.plot_model( ),看一下模型结构3 S W- X. x/ ~- S" a L/ d
& F% J+ k& a% O8 Q
& W$ v x. t- i. S) l, Y% H8 h
4 C4 M3 F) S$ t
$ {% E1 j# x1 R% Q- |* S( ^# {( f! P' B0 y- I* I" e8 S
用summary(),看一下模型结构和参数2 v: C1 o# @( k0 X* b. R
- b P4 u% B" y* {5 w
, Q9 ?+ \' [- V" n4 n* f [/ {7 |4 A- o
2 A2 o# D+ A8 f7 |& r* h( i" d- W1 { j, e1 ?! \& c( C" `
T" X5 o# X ^3 I使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。4 Y# z b! w* d7 @; T
! n: e5 m, x! r6 E# ]5 I) Z3 ~ Y/ x5 c- b" ]& J
generator = make_generator_model()
# W$ J, [$ v5 w- X+ c; b8 j% z' y
! \1 `5 Z/ u l4 r3 p u) A2 r0 n! @noise = tf.random.normal([1, 100]): H) B9 I1 U$ g
generated_image = generator(noise, training=False)" W, q8 O3 U7 i! m
/ P7 `- I$ z: Q' f
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
& Q6 }% S7 @( p7 M& b
6 L5 C/ T# Q4 H( g3 @3 _6 d1 B/ H; M
0 o' f/ a N4 ^" m9 ~! O* A( p3 ?+ ^7 b8 Y* U# E% ^: z
3.1 判别器
7 E" h2 l- z; G- ~1 \; E判别器是基于 CNN卷积神经网络 的图片分类器。
! u2 C+ q5 Q( E
& v/ g2 F3 y$ w! }/ |2 \* r: ^/ a2 O( ~4 ]
def make_discriminator_model():# x- S1 v5 m& A. \
model = tf.keras.Sequential()
" P2 c W- g* l) t, s; d" C9 h model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',9 W0 S( [0 [! B" S$ T- |
input_shape=[28, 28, 1]))% a9 a7 s' n/ }5 g
model.add(layers.LeakyReLU())
$ o( Y3 f! V+ Z# i1 ~4 o model.add(layers.Dropout(0.3))" u- S+ R$ t1 P' w2 d
0 z: d0 u& w4 V model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))0 P4 h9 V' B6 W
model.add(layers.LeakyReLU())9 |' [" [2 X z) t0 H$ c
model.add(layers.Dropout(0.3))6 R; o6 z0 V0 F2 ~% N' X% F0 A
# ~- m* D% v' j
model.add(layers.Flatten())
1 w, q& h& ]6 [8 G( M model.add(layers.Dense(1))9 d3 g( V7 m8 d
0 n6 F" O! i$ [) q
return model' }+ N' B% N0 q# U+ _0 D
用tf.keras.utils.plot_model( ),看一下模型结构
- W$ h6 I2 P* v0 s8 N6 g( D; _5 j/ @: Q
6 |9 m6 M; t# M
0 g% c3 a+ L0 y- q7 a5 H: q
: z4 y9 u% U8 c1 s1 i0 K1 n, x; w* H
9 A/ y+ i8 \* n7 B* J5 l2 ?
用summary(),看一下模型结构和参数
5 ?, j1 D7 k( }1 p6 M4 ?2 a
- M% j: k; r, t# g8 E" _3 R& G+ s6 ?1 y) Q- X) I, [1 w$ D
5 R) G/ b8 d9 T3 s! m4 R+ h+ m9 y. s; t' I( F
3 W& j; D3 P& D; n
9 C# |/ O4 n! n8 a0 U四、定义损失函数和优化器
/ A& e; C( l- F" U, N由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。9 A) F2 {9 H+ l) c! B2 R% U. ]
5 t P) K2 \* t. m' j
|/ F' D3 b. ?- V) Z2 C首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。3 ]- J* w# G$ H: q6 J G
6 s5 ~$ |7 n7 S! ~& [1 \
6 g" H+ ^. a3 L8 S; o0 g# ?' ]
# 该方法返回计算交叉熵损失的辅助函数) l, J3 q! \4 x
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
% J. r' u7 g) _( ~9 v4.1 生成器的损失和优化器7 f' o7 h2 ^( S+ }+ F3 p. ~' |
1)生成器损失5 u$ d r* A* C" `+ T
0 K F# P% }- a8 w, s3 j: M
. A3 s8 K1 R, l8 i* H+ Q5 ]生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
/ D7 D+ z k$ k' U" \; ? |" n0 ?- Z- T8 h# k: h( P! w* ]
& n" k2 }3 C+ D3 p. e这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
3 G: i& e% z1 T; e% [' P! x1 {, Z4 p1 J }: y7 v0 q
# w3 r( ?# Q3 R
def generator_loss(fake_output):
- I V& g) ?" j- T return cross_entropy(tf.ones_like(fake_output), fake_output)
& m5 ?; H5 `8 ^' l2)生成器优化器
3 C+ C# {3 D( F& t, U5 w6 P7 z' p L% Z) I8 B: G. H8 ^, }" @
# ^7 v0 K" o' L' Tgenerator_optimizer = tf.keras.optimizers.Adam(1e-4)
7 U1 ?# Y" g& f3 k4 \: c4.2 判别器的损失和优化器
# g% u) X. i3 U' F: `3 g. r7 ^1)判别器损失
. _6 o! p; v7 f2 g- M& e( _/ \
4 X O: A, L' K) ]+ m6 i1 l+ T& a) v6 m8 P
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
* D/ ]5 s: z; D% c# F, G/ E; a/ X" w$ q1 w% [( _9 v2 q
0 y: l0 D: H( B2 X. s( `def discriminator_loss(real_output, fake_output):
3 O4 a# `3 ^7 R real_loss = cross_entropy(tf.ones_like(real_output), real_output)& S4 B2 t% y( Z( D& t- s' P5 t
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)4 e! ]2 G( ]& B9 p
total_loss = real_loss + fake_loss
- g" W$ S; Q) Z7 z+ x* w& X return total_loss0 p) O. V b7 T/ s( O3 m5 w6 r
2)判别器优化器
5 r' p" T. o2 O" N' F d3 G6 ~! n4 c3 u0 Q+ R+ P
1 Y6 k) h n0 ]$ Z1 ]* p& C6 {
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)1 D7 T) c2 J' u- r8 U
五、训练模型
) J& i6 U& d& y3 N1 P* [5.1 保存检查点
+ }, W. A- |- o7 u) B0 o( @% q保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
* c/ \+ U. q0 B7 t- q' p2 x
" ?# i. W7 Q# L# |* u# \& U: g4 A: W) m Z& t: k
checkpoint_dir = './training_checkpoints'3 p/ v$ O8 c, K: O4 a. U1 G$ E
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")" J! G2 g$ r- L3 P# G& _9 @5 y& F
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
- Q$ g* U2 C0 K discriminator_optimizer=discriminator_optimizer,
0 r% Z; v1 y/ I7 p$ v" K generator=generator,
6 @3 s1 X" c5 R+ s& \+ u discriminator=discriminator)5 }- V1 x5 D: j% ~0 j; u
5.2 定义训练过程
" p/ o! o, W) S* F( K; IEPOCHS = 50
) h# h+ m( S/ [( Snoise_dim = 100. y9 U( d& T p8 Y0 |, `8 Q
num_examples_to_generate = 16
/ ?/ K" h( r8 o1 c# |6 k
_) N5 t4 f/ ~7 P1 X* [/ H $ ? }" d* [ ?& P
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度). ? o) s2 g+ L; N8 a. T
seed = tf.random.normal([num_examples_to_generate, noise_dim])
, F8 `$ C9 E4 I3 |/ g训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。' j0 L2 R; E. h6 @( w
0 H3 }. p* W: \' r3 K g8 n7 k6 W% k# O% @$ F
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
6 ]2 k- H* ^1 s( }2 d, I% [
2 u& v9 `0 d) U8 S) z6 ~7 s( ^0 I1 H: C
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。4 ^( a5 P# P7 g# J5 `
4 k& R/ S2 @) {( [9 h
/ i/ b: ]; c' B8 s# 注意 `tf.function` 的使用
& t$ f+ x5 M" H" A3 W+ V. N8 _3 [# 该注解使函数被“编译”
' b* d9 `# a1 w" E* X@tf.function9 d: y2 v+ i- N- O
def train_step(images):
8 p* L4 N3 I2 H9 o noise = tf.random.normal([BATCH_SIZE, noise_dim])# T. f3 e8 q+ v% o
! ]" z" j; R6 f. X8 y. l$ s with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:" H- ]/ g# F c( ~7 v% u
generated_images = generator(noise, training=True)
5 f" V+ m- N! l$ h 5 ^2 c/ x4 K+ S) A8 p: |
real_output = discriminator(images, training=True)0 R2 X0 B7 v( N8 B z
fake_output = discriminator(generated_images, training=True)
% w' P, `5 M9 _# H1 Q9 K6 _ ( `- @. e" i( J1 ~6 R0 H2 w
gen_loss = generator_loss(fake_output)
' ~2 `( m& D* _' ~6 S* p5 r disc_loss = discriminator_loss(real_output, fake_output)
1 a2 L. A2 v/ B) m' t6 E2 c" T: n
3 H; ^/ j2 @/ \* e gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)6 n, u! Q B& b4 J) y) n/ ?
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)0 y4 |3 K$ L4 @) T; D d
" |% O) D! N$ I+ \) r generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)); H+ F: u: `0 z1 x$ M
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
* r, N: r$ S: D, q# E9 [" e, r B( D* t0 i+ f& b# `7 o z
def train(dataset, epochs):
+ I/ j& O3 h; k) E W: \& f7 N8 E for epoch in range(epochs):
5 a. X. N- r z5 T. ? start = time.time()
R0 }/ Q+ ~- d" ?
( O4 ?/ ^/ @" f! Y for image_batch in dataset:
4 @1 N2 f3 Z' d: X& L train_step(image_batch)
^, ~8 K1 l4 s" o2 s
- P: T! H7 ?: Y # 继续进行时为 GIF 生成图像# d9 t9 h6 [3 d3 X. u# Y
display.clear_output(wait=True)
+ T7 n V8 e+ O generate_and_save_images(generator,
+ h! l7 \* Z, G9 }& @5 q epoch + 1,
E% t+ h; B3 K' J0 C seed)
- |6 e$ b" d9 n" x7 b5 ^2 p4 F : q! [. }" {, [; {8 D5 V
# 每 15 个 epoch 保存一次模型1 U; ~" a# z8 y
if (epoch + 1) % 15 == 0:
2 A6 N; \5 u3 A# j4 { checkpoint.save(file_prefix = checkpoint_prefix)
$ y/ N. b3 g" O0 A# O3 h( m* w
e4 a# n* [) x! z p+ ?. B& [/ ? print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))1 k5 V% B# D& ?: Y
) ~; j Y1 S% T8 h8 G* @
# 最后一个 epoch 结束后生成图片6 z$ Z' S( \/ Q i
display.clear_output(wait=True)) n* Z7 ]5 n% v9 {' c% l! I
generate_and_save_images(generator,9 F3 W1 m Z. F
epochs,
2 g9 N" V) @& x seed)
i2 _( J8 t( d! D! o
% H# X0 ?' S+ @# 生成与保存图片
' [7 E5 ~' ` f& _) b* Edef generate_and_save_images(model, epoch, test_input):2 j O1 T2 r" c0 f$ }. y. H
# 注意 training` 设定为 False" H: ^/ G2 P0 f; ^
# 因此,所有层都在推理模式下运行(batchnorm)。
e3 L% A% p2 |; q" C& e" j predictions = model(test_input, training=False)
4 ^# u% l- o6 x' k N' Z
$ t0 _1 x2 s9 T2 t2 L fig = plt.figure(figsize=(4,4))
2 x$ d" X4 w+ V" f0 r
; V% \1 x, Y# [# s" y for i in range(predictions.shape[0]):
- U v& @8 H; G# G plt.subplot(4, 4, i+1) |6 H G7 R, X+ y1 U0 V
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')5 R2 y" L) F6 f! M( ~, P3 B
plt.axis('off')
- H7 C c/ M- P" D6 O
. a; e) U. H u plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
( N5 G4 \& M9 V plt.show()! a9 @2 o: j- `+ X& ^, b0 y; B
5.3 训练模型1 R7 |. _/ w1 x% D8 d# \
调用上面定义的train()函数,来同时训练生成器和判别器。& b1 n* v% D9 F/ S" j
( d6 X! l2 U. ?7 a. m! B: z
* ^, h2 H3 F5 ?# {) _
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。& |+ e9 L+ }- C
* Z) Y4 a& M; K: i8 f% w1 p
, i6 }7 G4 w: h, ?3 N& j; P%%time
0 Q% N" @+ B8 \8 n$ j' J! F9 ~train(train_dataset, EPOCHS)
3 ?' s- F8 \( P4 @" Y在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。. r/ y9 w% w: N6 @& N& i0 X7 x
% g4 F D4 U! h3 M$ R
9 v: y t4 I5 O' _2 G" d
训练了15轮的效果:
2 O g- F* C4 W; p! k' k6 T i7 ^! |7 @2 N
4 p! ?2 H9 P, n
* l0 y& I8 t. E5 T- m/ _
& a# j) x4 E y' a$ Q% G
' J5 t1 Q. ^( B% H2 \& }& p$ }" W
* P" U Z! v+ R- W6 O' C
训练了30轮的效果:& ^, O* N4 P2 N3 H7 h
y: a4 I: q: @9 g, H* b/ h) ^0 N- X4 w1 t! Q: g
; a/ A+ V( b) W% K0 o
K& }5 ]; |$ [$ A
3 V6 q( `2 u. o$ _4 c
7 F3 |$ |0 Y0 m4 Q: G9 X* J& c训练过程:% |$ `# |. r/ p" L( f
* i7 Y. i6 [" B. x8 f$ ^2 r
9 J6 S: A1 X, p
9 Y! q4 J) e% L( m% b3 x* ] ]* ]8 f( p& H' l' Y8 M
/ u# t# ]2 Y: M4 k# r
* N; s% u) N* _" a. o q& {! p% y恢复最新的检查点
. }& T( i& X# V1 N8 n. f) I
6 I, l e7 a4 w" o x1 j- ]! x+ j: [8 U# K
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
8 q6 N: z, B, o6 m) O六、评估模型! t, @6 M; Y" f0 s( f- ^+ H
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
2 M# t1 ?/ l4 W2 ]" a5 L" i9 g3 h; \6 V
1 \1 B# B1 B: \
# 使用 epoch 数生成单张图片
& Y( I" Z r* d. y% @% i+ o: ]5 ]" Edef display_image(epoch_no):$ C) a3 A+ Y; v1 V$ j
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
. }$ j4 {- ?& x. A) }) e% C
: P. l2 X& k* I* P, v: C) Kdisplay_image(EPOCHS)
5 n6 M. }& E. B: A( ^" Qanim_file = 'dcgan.gif'
" J9 e6 i" m' x) [# _ $ n# ~8 L2 `3 P4 G( x8 u# x
with imageio.get_writer(anim_file, mode='I') as writer:
: D7 r+ Q" w" o* s- n% k filenames = glob.glob('image*.png')
9 F9 H) C$ H" z/ h filenames = sorted(filenames)$ S% t. E) ^/ R% W0 [( Y) R: T
last = -1+ e7 g$ F, R8 u& c# ^( Z2 }
for i,filename in enumerate(filenames):
/ Y- @, S6 a# A frame = 2*(i**0.5), J1 t/ x" V7 v) [! T
if round(frame) > round(last):
2 W0 t2 U1 R8 s7 N8 t0 q last = frame+ f( i- e7 R2 S8 |8 N1 N
else:) F. F+ C6 f' }! w, @! U
continue
2 B/ A. Q2 h) e% x+ d1 _ image = imageio.imread(filename)3 B) N/ E! H7 l$ ]: ]
writer.append_data(image)2 n1 J+ M0 u- s# E; X* I; Q
image = imageio.imread(filename)
4 O4 O T9 d. i- p writer.append_data(image)8 w2 H1 m5 e. L
% Z0 I! f9 l- U- M# t: ^import IPython
2 W# @. U; S: F: v- v- ?if IPython.version_info > (6,2,0,''):
4 k8 h2 S% ] W- Q6 T" {2 K j3 g9 X display.Image(filename=anim_file)
7 i7 U! e' Z4 H
, j% b6 n. ^5 J' l$ v; G% J: O) x- G! t- {) y
0 ~1 d2 R1 Y4 J' _
S, n5 X, p4 ?: v N( U" I完整代码:
7 c q' u' }0 i2 D' H m) a, q ?# r* P; `. Q% f% {4 b5 {
; N) ?1 ]# S6 @' {7 W1 o
import tensorflow as tf" h) r/ r; A; G) u
import glob& I* M8 `* h" Q; E& W! a* L/ O
import imageio
5 |8 W y. q# L, _+ i8 Timport matplotlib.pyplot as plt
" a9 |/ W* k! E* i/ Rimport numpy as np
9 H$ P. u! Q- ^! e2 {import os- j4 \1 g- z! p! j3 b: }8 M: G
import PIL
! T8 }9 k, p4 q: e- D* w7 Sfrom tensorflow.keras import layers7 d6 @6 D8 e$ g6 E: _
import time( X. j; f; J. T% x
! j6 ~5 f" r" \from IPython import display
+ t" I& G8 ]0 P) v- Q6 R: {' C
3 M' E; g) G6 k4 m( f5 ~; ]8 A(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()- V" j" g% q/ ~0 q" e
# T# P& z( B) `1 P Ttrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
3 w* ^) ^& [; T Rtrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内( y" ], b! `8 c; G7 S$ j
1 r: d) C0 f) e s4 mBUFFER_SIZE = 60000
! I" T% L p" P! K, C" R6 ]BATCH_SIZE = 256* j4 ]$ ?) B+ B* X- t2 ^
& h) h, L7 v$ P# 批量化和打乱数据. v6 h& l, b& T% Z6 d$ y6 F% @
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)# l! g! ?& ^- k3 B5 B1 I
$ {- Y. h1 g% |
# 创建模型--生成器; ]. D4 M8 e5 T2 h5 R
def make_generator_model():
, ]) W3 y6 D# e* ~* y! D. J model = tf.keras.Sequential()2 [2 C5 I" a) D8 H
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))* W% [) I! ]; b" g; v* o6 I
model.add(layers.BatchNormalization())
* s4 d. a/ H: ~! [: Z& G6 y5 ~% L model.add(layers.LeakyReLU())
m3 e0 ]9 U3 H( [# N* w 8 b/ q" F) ^! L3 k/ G, C1 P/ t4 v
model.add(layers.Reshape((7, 7, 256)))
2 R+ w, b" q& v: C. q( T1 n assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
; o0 D. f3 G+ j 2 y9 F- |4 Z8 h3 |! z, }/ T+ {
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))$ R- V @" G4 ~/ J& q
assert model.output_shape == (None, 7, 7, 128)( S0 I! U* F( y# c
model.add(layers.BatchNormalization()); n- T$ e* x. R: i5 Y. v5 r! w) R
model.add(layers.LeakyReLU())0 a% D/ ~$ e A, X$ }
% I3 ~& l; I+ n: h4 I0 h model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))6 j0 A7 Q* P; Y" H$ l0 H( y
assert model.output_shape == (None, 14, 14, 64)% b/ d* X& |' V+ d3 |
model.add(layers.BatchNormalization())
6 j% @/ c4 }4 Y p& P1 M$ I model.add(layers.LeakyReLU())
) S1 v/ M- B$ I0 k& X _* ^ 8 |0 K. v' w6 A
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))0 E* X; _; t; S8 c; T
assert model.output_shape == (None, 28, 28, 1)
, p& |6 i0 [6 C7 C3 C/ T) b* q, O# U
. P( |5 c2 a: K# C+ c9 b. H9 r+ Z return model
9 j) ]2 b( _. n, ` 8 z, t% r* w: \2 m, h# I0 u: N
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。, [$ N2 c4 T+ r' |2 Q
generator = make_generator_model()4 {- R; P6 u" s0 I
7 a- t$ w' V% k# Z& h* d9 z( a
noise = tf.random.normal([1, 100])4 S7 h" K: D/ C7 Q) J p; b
generated_image = generator(noise, training=False)
+ }, a7 `4 W9 ?. I9 Q
9 E- c' C9 Q& s% Z/ X8 ]& Rplt.imshow(generated_image[0, :, :, 0], cmap='gray')* s' b9 F- Z9 B% v; w6 P' E- G/ F* z
tf.keras.utils.plot_model(generator)
5 _/ J# |8 w& W( i( ]5 Z4 m 9 s( s9 j* q2 n) W7 K5 B6 Z W
# 判别器3 q! |5 g! C0 g9 [0 t/ r9 D
def make_discriminator_model():9 [- S6 ?6 A& H, M$ P3 U* t) l" h
model = tf.keras.Sequential()% z/ r1 ^3 G$ `2 u7 @5 k/ E: a
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',; m7 k( g6 J0 D! w9 V. ~
input_shape=[28, 28, 1]))) l" X% }0 l0 f, o/ r
model.add(layers.LeakyReLU())
" O6 y. Z* S$ |! q X0 D* G model.add(layers.Dropout(0.3))' Z9 B8 l) v: a. b
/ D8 f1 P1 ?# s$ `
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))2 i: u8 m$ t* Z6 K9 ^1 a
model.add(layers.LeakyReLU())9 S1 r# ?3 o; K: f# L# Q' _
model.add(layers.Dropout(0.3))
" A) }6 ]9 z" S" i" ] 8 h& c) h# \ u, R& T# L: U; S y
model.add(layers.Flatten())" j8 B8 Q0 Z* ?; A. D& ^
model.add(layers.Dense(1))
2 `8 U( Q! A# c* g ) U H5 t, c9 O; F7 j/ n( n
return model
1 l- a: Q5 B* Z( \% ^/ ` + Z% c. g7 K' X8 A
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
3 m$ {$ Q& A; j2 d! g; k9 d& qdiscriminator = make_discriminator_model()( y% `+ c5 [% \7 k2 z, p
decision = discriminator(generated_image)
2 _; W, T+ l, U) uprint (decision)
. q4 }# d1 ^! c, f
! Z! B B( z [, W% v- J5 o# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
# G: M L; _' Z1 @0 M5 O% Hcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)& q" c" r- \* ~' \, I% j7 _; U
2 D& L' S3 c" c- M9 I) o; e
# 生成器的损失和优化器" _( T; M8 l4 e/ U; j
def generator_loss(fake_output):6 k- }% A" G$ J, \0 A* S% D- k
return cross_entropy(tf.ones_like(fake_output), fake_output)
' m; }* U2 p4 k W( D0 X- i3 Fgenerator_optimizer = tf.keras.optimizers.Adam(1e-4)
, ]* P4 q6 j# o* _/ v
c6 j! S% B" u2 u6 o! m5 G# o# 判别器的损失和优化器
( [: o9 d0 W: idef discriminator_loss(real_output, fake_output):
9 M5 A' w& C9 ]: A real_loss = cross_entropy(tf.ones_like(real_output), real_output)
# m& W% c2 m1 S8 l5 I& C2 v* ` fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
4 D9 t/ n" p/ g3 G( b( M total_loss = real_loss + fake_loss
$ e0 E( w7 _1 Q/ ^# h return total_loss
& T8 L& {# l5 h0 i! Y9 t0 Udiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
8 Y' ^6 o Q% V* _+ `' K
$ u) T B$ P3 \# h1 O# W0 x! D) `+ }# 保存检查点
9 P/ F; u% }# i( u- wcheckpoint_dir = './training_checkpoints': W: Z8 U: ~' t; u6 \
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
9 Z6 E, l1 ~& j+ E& Ucheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
# A m+ x& S5 g& w( n/ D" F, W discriminator_optimizer=discriminator_optimizer,
( E2 v1 k E( H$ y* d& g generator=generator, ^! H( a$ s- j9 F! d$ A. M
discriminator=discriminator)' L/ {$ u: r3 V
+ K! c. h$ ?" E8 J8 P) K" ^
# 定义训练过程: |: Q, W7 ]1 B3 n$ ?1 |
EPOCHS = 50. L7 }4 Q$ C. z5 S
noise_dim = 100% V5 w4 l$ L) p3 k2 ]5 b
num_examples_to_generate = 160 B: e. O* _/ a$ X
3 A+ J1 U# Q* f
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)/ p! n2 T" n1 O
seed = tf.random.normal([num_examples_to_generate, noise_dim])
+ [; g1 |8 ]+ J) X1 O , d+ {; m6 z7 Q& S" b) H4 a4 V8 t6 V
# 注意 `tf.function` 的使用
9 p, \0 X* t% V3 O: W+ [% } m# 该注解使函数被“编译”
; b* Z7 ^7 S5 E8 e@tf.function
5 R" M8 m% A* a* Mdef train_step(images):
2 B+ i: a3 {: `" _! g* A# o3 ]3 u noise = tf.random.normal([BATCH_SIZE, noise_dim]) J; e- B+ E! d5 ?" m1 M
" n$ X+ O6 I( U" \
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
+ |# ]; B' a* ^ generated_images = generator(noise, training=True)" F) x4 |/ i3 z5 I: B# r
/ d9 ^1 r: L7 V1 t real_output = discriminator(images, training=True)6 a# _+ B8 \3 S0 c) B! Z7 @
fake_output = discriminator(generated_images, training=True), L' P" q* X2 J# X* t8 Q) w
8 s3 Z( j% V6 w6 `1 h+ e3 U! k gen_loss = generator_loss(fake_output)
/ |( ~( _3 c( c1 @# o; | disc_loss = discriminator_loss(real_output, fake_output)
; X8 u' j3 s* ]+ ]# D9 j; }% B $ i, }, R$ N; Z" F2 Z2 x7 d
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
3 c0 r5 n) n3 D G# F gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)( X; h3 l% t* _/ V2 o* x
: i! `& C& B+ n0 d: D! P+ b: m generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
% p. H; X# c/ G0 M discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
" x1 X1 g& F; j4 t6 [ " M; W4 l2 ^5 w
def train(dataset, epochs):; I5 [9 R, U' J, S# Q7 g
for epoch in range(epochs):( o5 @2 u- g4 w7 [
start = time.time()
. e1 O9 U6 S' C' |$ w, t) D
8 M9 V9 [; a2 h# A" x4 e2 N; N3 ^9 } for image_batch in dataset:
- I4 c# k5 p2 q$ i0 g train_step(image_batch)- Q1 \" L" m/ e$ }* Z3 u
0 Y5 i4 l( M+ y. H2 A # 继续进行时为 GIF 生成图像
) x D! h; ^5 Z0 t( w display.clear_output(wait=True)
# l, \$ ^+ i! B( H generate_and_save_images(generator,* i- F% d) q1 ]5 v
epoch + 1,0 d6 S& H# t; [* L R
seed)
( i2 Q4 H! B7 W# V: f6 w
% D2 F7 l) e' i # 每 15 个 epoch 保存一次模型
1 ~( Z# B$ }2 D9 I& `9 @ if (epoch + 1) % 15 == 0:; [+ i) C; k1 [. s
checkpoint.save(file_prefix = checkpoint_prefix)" W/ G. |( T& G% i$ I
- R( p7 E/ V9 U print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))9 n/ T# L& t2 U, s: |! H5 Q
9 Q( {8 o% x/ x+ w' j # 最后一个 epoch 结束后生成图片& C5 i. [6 |6 Q% K
display.clear_output(wait=True)
6 q* x$ ]) j% w6 u1 g. u( u6 p% B generate_and_save_images(generator,
9 b" H; M6 i) F' Y3 d5 D epochs,
5 M5 Z6 R/ g$ j5 e$ S( Z seed)
. Y/ i* R0 G1 \# H6 y$ {: d9 R; n
$ |% e/ I: O! \! ?8 P/ }3 N# V: ]# 生成与保存图片$ b1 e& k% a$ ^& t' P; L z7 h; G3 R% h
def generate_and_save_images(model, epoch, test_input):
7 |$ t& v4 X/ L8 m2 K # 注意 training` 设定为 False
/ \! r1 M6 M m4 h # 因此,所有层都在推理模式下运行(batchnorm)。+ d0 b. N; }! n# o8 u) a# n* Z0 Q
predictions = model(test_input, training=False)
/ w1 `# \. H* a3 K
5 P; J b- l1 p fig = plt.figure(figsize=(4,4)). q. @1 X* T0 K4 R) L
0 B' f/ C# q( a0 | for i in range(predictions.shape[0]):
7 S. {8 ?. V2 L4 c; {* R# N7 ~; E& u plt.subplot(4, 4, i+1); q" b9 ?" _" B2 v& z
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
" Q8 j; G# Y W9 A# y8 F plt.axis('off')
" Z9 R0 c9 C' m9 r1 M. A $ y- U# ~" w, g u! y) w
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
+ T0 p. U& J7 _$ y plt.show()
& h m( H9 s2 M0 F0 o
, f4 E5 y) T/ V# 训练模型% G9 P5 \% v# O* S& E2 Z
train(train_dataset, EPOCHS)1 M' ?4 T. h. D0 o4 X8 d( j8 p
# B& V! ?: q: X! u# 恢复最新的检查点1 g' ^/ z9 a1 l. U- {
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))1 G2 j% Y" z+ U+ q
3 c* J3 i6 }* {- {; N" ]
# 评估模型" a. W. l" `1 [8 j h/ F; r8 H. D |
# 使用 epoch 数生成单张图片
% s( d6 O& ?. }def display_image(epoch_no):
$ s! B, U1 a/ L1 ~3 M8 p return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))& ^) j( f( W) o. V' c5 N6 ?1 O. u* C
* d' \5 F; e2 v8 J1 P$ n pdisplay_image(EPOCHS)' H- J7 n) w" Y0 l- L
" d. `, n4 L* Lanim_file = 'dcgan.gif'
" p* t$ l: h9 Q/ \7 u1 X , x4 R$ P7 D4 [3 g
with imageio.get_writer(anim_file, mode='I') as writer:9 |3 T- f- a- T6 n8 H
filenames = glob.glob('image*.png')
' ~1 e* ?! |/ Y filenames = sorted(filenames)" @+ J- K. D# t* O$ R4 c
last = -1
0 \6 Q; A, \# }: C2 p. s for i,filename in enumerate(filenames):2 ~/ f- z' J1 q
frame = 2*(i**0.5)
1 U" v- L+ p8 M" }, w- S if round(frame) > round(last):
7 R4 k t5 s; }4 Y/ F last = frame
9 r9 A I) E! ?. E+ g' E6 k else:
" e& w* h8 p4 J, w continue
8 G" R9 F, c8 y- \1 H# p7 w image = imageio.imread(filename)( ]8 `$ n* C$ y7 O$ l/ H% C6 Q
writer.append_data(image)
% n" l# `, [- G$ E8 k0 a) S image = imageio.imread(filename); F+ N1 w: X) n, @+ w
writer.append_data(image)
. ]+ m+ f, \3 s& O1 ~: y
6 b D5 T- e0 k) p7 K; A; y/ Himport IPython
C0 ~; n% G$ e( B! jif IPython.version_info > (6,2,0,''):
" P2 O7 @, Q0 B H8 V+ Y) e display.Image(filename=anim_file)
) C: V$ ~# c- H0 Y参考:https://www.tensorflow.org/tutorials/generative/dcgan4 e9 h) L5 Y/ }; o
————————————————
p* {% f0 x' Q1 M版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。" ?) F2 J( G( {7 y$ X
原文链接:https://blog.csdn.net/qq_41204464/article/details/1182791118 h: y8 K1 o( T3 t' h* ?) M
, X* w' c8 p! C* Z# |- D' p, J3 x' f- w. S2 X6 y8 L
|
zan
|