- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 564691 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174630
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
4 Q5 I, y d3 i$ O& ~深度卷积生成对抗网络DCGAN——生成手写数字图片4 y7 J% V; |9 d, a" g/ T& P
前言3 u# k; [: Y% d) @6 c
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。* u/ Y; O# w) {3 V: O
' r0 H* B( d- |, M& Z! \2 `5 J* v a
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
( r0 E& e. f# c, K$ I+ @. Y5 H+ h3 n; ?
; U- H9 y# L# n7 P( K# 用于生成 GIF 图片
8 }& I O; F& [pip install -q imageio
1 y. C) X$ y# g目录# k3 ^4 C* s7 [% d7 r* u# t
0 W! z% w) r7 r' F0 t' T0 c
g8 `* @) ]1 l, U/ G8 r: X
前言. i" d+ R* R" K2 {$ n- \* D9 C
+ X8 H& X7 A; B0 K/ A7 n% A8 O' A" C5 _3 s7 X% o5 E
一、什么是生成对抗网络?
; Y7 ~$ o8 J; D5 ^$ z- _/ E% o/ v3 G, W
7 }! |/ c& }6 b& a6 a& I
二、加载数据集
2 N$ z0 l( V: L8 y
t/ W+ v' V/ a- a, h; C6 S4 f! P* ~9 @' D% ~7 C- |
三、创建模型) V8 V C- ^3 G
6 c$ X3 }2 p; z1 }
, @# g+ L' g4 _6 j+ W4 {3.1 生成器4 c7 _6 q; z+ W6 g
" N' T1 }$ L7 R% F& T$ v
4 F6 ]0 M) l9 B2 f1 ]( Q; n4 p3.1 判别器
/ e& ?# l, W* _. l/ w
: ]7 D; D, @2 \0 h+ U, M
# i$ Z& ]2 A: M四、定义损失函数和优化器
! ]/ V+ b0 Y' t$ K/ z
0 w% Q' Z! H( s; p% b/ }6 K. y: e2 `# X; l7 h$ H; i
4.1 生成器的损失和优化器
' L/ `8 o- m# P7 S! ~- y/ v, y% K2 P8 G4 |8 c- ]# ^
0 V' C' ~1 j" D: U. \% q' i4.2 判别器的损失和优化器7 G) Q$ r2 k% [8 P* P/ Y5 a( Z
5 A$ ?. `' r; i0 a$ i# ?8 a6 h; O; P9 ~3 b, Q
五、训练模型5 r# D8 n% W+ w, I ^+ h" x3 P
. y5 @4 w* U3 D1 j
- a( B5 _6 N# a6 {* ^4 Q& j5.1 保存检查点
; k, j" `1 l% {; k* G2 [4 ~! D9 c8 Z' J; a, e, o# Q! f; G
6 t3 j! z; }/ j2 ^5.2 定义训练过程 o3 k9 P. o. X' l. d
. z# h% P. R* t6 w
/ Z7 N5 t5 u) F/ e5 v8 g* T5.3 训练模型
+ W- K4 p/ ~$ K" g+ c! h
6 w6 G% d& G% W* v" Q3 }$ p
, R U0 _2 c) v& e A0 ]2 T' a六、评估模型
* U- [+ z; \* v' |+ r/ ^9 h, Q
% \% Z$ S1 S: m* j: \( _% |0 p$ @7 c$ ~* g/ b
一、什么是生成对抗网络?0 w: c: u7 C8 ~" a/ c7 B( @- q
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。5 z. K; W7 A! S
9 n8 P* d# D. T) H5 Q% I0 i% c9 ?+ X5 V6 l, i" b
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
, O0 ^6 {9 K$ b2 Z/ l3 f
4 Q+ V( ?9 i5 h9 v6 m
& w$ ?9 h+ J O' ^: g* b判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。2 z- Y& h$ h7 K3 J1 A/ y
; j( T; h9 Q7 D$ o: c& J5 n# q/ ]; R8 g2 h/ ^
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
, `& y0 ~' P! c. Z+ z
a, g; [5 A \" n4 R- F, Y0 U5 A: x3 r& x
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。9 A# C1 X3 ]5 P- Q8 A8 x7 N
4 O( C. N l" G7 c3 w. \' @3 K# Q7 F! X" c3 r2 @! a
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。$ I% F5 N& R0 }% l" P
3 j6 m, a4 o( Y2 x' X6 I! K9 ]! S$ `' w
二、加载数据集
& l# S! A7 j8 W9 I3 O3 y使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。, y! X) ~1 O5 F `! f1 }
3 k/ Q% b* Z p4 @* A1 z
; n+ ^# @) m& v" w(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" x8 x8 z. M, t" Q
) x" A4 W) O2 Q3 u( r, l& G0 ^train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')2 L& X7 H5 z3 N B0 Y/ L; g# [- o
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内- N: d2 M4 p' }
+ r+ ]! i4 l2 w' v
BUFFER_SIZE = 60000
- o, k$ z' Q% {) ?/ MBATCH_SIZE = 2567 O! W/ _, L/ d% _4 j: N; T8 C1 F
8 w }. D' o: e c+ j. k; y
# 批量化和打乱数据% }0 U% G; ^; V1 n2 X! \' |
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)$ r; `) g+ S/ Z' [
三、创建模型; t' m6 @5 d! x z+ P
主要创建两个模型,一个是生成器,另一个是判别器。
" W! v9 c* U( d+ R( F1 a) U% z) y1 f1 e
t& [7 {. ^4 t6 C% D3.1 生成器
; K% }: I6 m7 {' E3 c8 u0 z" A生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。" k8 n) a9 E6 L/ S5 c9 n' h7 A3 E' A
8 r' W0 \; ^$ ^, R. D
2 t0 b+ t6 s5 C* @7 v然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。9 K' x+ ^2 f5 g: I+ g
" s+ _9 O; p& \% H8 U4 c% ]" f' s" E* K5 K& R5 k
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
) M! w0 e) G: N* Y3 @9 N/ L
2 j* }7 P' l3 I8 I! L6 K
3 \+ i- j+ U& c9 V7 ydef make_generator_model():
3 b7 D4 z G, v3 m& u6 K model = tf.keras.Sequential()
9 s: _+ u/ A# p# }# ]- h model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
_6 @) C) A r# S7 z model.add(layers.BatchNormalization())0 ~; b3 u" {4 Z
model.add(layers.LeakyReLU())
0 F7 s/ H" F, V4 s " d1 F, {" d" F
model.add(layers.Reshape((7, 7, 256)))6 _$ I$ o. ]( g/ J- G. D+ D# E
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
3 \0 ?+ v& I4 U5 P( P) D
" @- D( ` l Q) `' @7 \+ V model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))3 u) `1 x2 ~% n4 `6 B
assert model.output_shape == (None, 7, 7, 128)
! E e. c, n2 E0 e6 \ model.add(layers.BatchNormalization())# h/ q, w5 U, S+ \6 r6 m
model.add(layers.LeakyReLU())7 R" a' i3 y5 f6 ]
5 g, f# `, ~2 _, \* n" y; S1 M6 M model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
! V' e, R, U& j4 {+ t assert model.output_shape == (None, 14, 14, 64)
; D* h" p1 n: y$ g5 T y model.add(layers.BatchNormalization())) K2 B) b2 D$ y8 R
model.add(layers.LeakyReLU()), W" M1 b- Q2 U2 D$ b( a: v G
* x2 \- r- o: q5 z! q model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))! T z( q$ {; \
assert model.output_shape == (None, 28, 28, 1)
) P# S: c ?" J3 R/ u
+ h2 F; k5 z0 A return model/ [0 H5 n& \& ^2 O
用tf.keras.utils.plot_model( ),看一下模型结构
2 }* d) L/ t& a& n2 |% o. _, _. q. J+ i5 V' w
( V$ i) ~7 I O, M: g- @2 w, I
% v8 b2 f9 p: a# x
' b a- `* F/ P
5 `: H7 R9 i# a7 x用summary(),看一下模型结构和参数7 M* O9 U; k( W, o9 ^
7 j) d, \1 E7 s2 m
8 p4 t8 p( @: n& M
T3 `4 X7 d% T8 U a; I. y
3 d) |$ h# }4 O
* p" A3 `" H! V4 J5 q! L0 ^
: p9 M/ W" K- u1 q6 R! p使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
^, U k, M6 h8 o8 D: s# y
0 q6 @- m3 R# K, s I5 ~1 E
6 [* n2 ^1 _* ?! E5 v6 \* h* kgenerator = make_generator_model()
1 x7 p% u0 G) @ 4 h5 x# _3 N& C+ D% P5 L9 j; q& z
noise = tf.random.normal([1, 100])0 O. B3 E. A# e0 t: }
generated_image = generator(noise, training=False)
+ }/ N6 Q5 K# a6 s 6 u- b+ O; s' T+ d6 [, |
plt.imshow(generated_image[0, :, :, 0], cmap='gray') N# m- [8 _; u6 G' c
# a# O7 K5 u7 L! S3 g
$ o: v% c6 r. B
6 a; {& m( F' V! X! Y: n+ L
8 F6 V! M7 M9 J3.1 判别器
0 B- p: y4 z9 c判别器是基于 CNN卷积神经网络 的图片分类器。( E9 h% a5 X" Q/ e" ^
, o" {0 R0 q2 v! `
( E; V9 p$ K: mdef make_discriminator_model():. z5 d' w* M9 A. M
model = tf.keras.Sequential()
4 L, s1 f- W. J, l, P9 Z; ? E model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'," ?. y: m, K- t4 d0 d) H# P
input_shape=[28, 28, 1]))5 _! C/ S& a$ u( g! U6 T' d, D
model.add(layers.LeakyReLU())
% q5 z) n) f# V model.add(layers.Dropout(0.3))3 m- S8 O( K( Z1 ~
% b: B4 B. @ _9 G. [) b! ~
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))/ d- p2 s0 _: t: e: K( [
model.add(layers.LeakyReLU())+ a0 p0 p9 B5 W% I- W9 m
model.add(layers.Dropout(0.3))
- {3 `' t* |. A $ O2 u4 `3 v9 J v7 P5 Y
model.add(layers.Flatten())
- Y# [$ V( w& \ P) y3 E7 p+ k model.add(layers.Dense(1))& d) E( L( }; Z( y, P
) g. L1 r; x8 e1 O0 q! p: h return model
) p& L5 d7 R+ _3 @4 N3 I8 j用tf.keras.utils.plot_model( ),看一下模型结构 l4 j/ v9 ^3 i) X1 H( F" z5 m
, e7 p2 g1 X" t/ x- q( V& ~
" c7 o2 q* v8 I" ~
/ F/ [. |: v: V
x# ]4 g* L/ X) A# O4 R8 y
1 I. b& B' b1 p! B) f9 O" ~! s/ ?# i9 n6 g, R5 m$ C+ `9 A
用summary(),看一下模型结构和参数, A9 i2 R/ V" F; a- {: V( v
/ h! p; O1 F" H# v' s2 T+ X8 j7 ^4 A# G Q
& ?# s* k' r& h8 w7 \2 U* [! _! P5 O+ o1 U+ v' e# R; d
, e- }7 I. z5 k0 [& A* Y# K3 _* t! H
) t1 |& X- F4 H! g1 L0 v
四、定义损失函数和优化器5 [8 O2 q3 |/ o- V
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。! s9 D8 O' F; s( V4 v* C* | ^7 V
& W+ r) f; K2 w% R/ ^( A0 p! J3 ^8 C; m7 ]' t# ?! {2 R! ^
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。% H0 [, @: H( M7 y
! b% O8 W: k+ C
. D% l5 ^0 d9 y2 e5 J# 该方法返回计算交叉熵损失的辅助函数
- L5 x# Y1 u+ C i( o7 Gcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
/ o# Z! P0 n4 T6 s! P' R4 y4.1 生成器的损失和优化器
' H8 w! |; F% M) V1)生成器损失
; f% |% u5 d4 B7 W, G& J5 ]. W2 c4 ^
( [* e' z5 J; P" ^. v* u* O9 h( o3 k生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
~; n' Z! S. ?, ]- Y. j7 f- ?6 @; u- s% A, \+ ~
y0 ~+ a. E2 T' B4 N
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
, l* ~+ L0 ]) Z8 U7 d
# F/ Z9 J1 D/ R, E" v( q9 w
& H2 f& c; k, H: o8 N* C( Wdef generator_loss(fake_output):& `1 n% w+ h; T1 j% S7 i
return cross_entropy(tf.ones_like(fake_output), fake_output)1 j' T! J% P! x+ A- s! R, P" T
2)生成器优化器/ P& w, f$ {. f. N7 ]- g
& O; ]- X/ H$ r/ h- T0 t3 \5 L
, y! r7 D$ Z2 R/ C5 Z0 w6 b5 K
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
+ Z0 o! w# B; S( {4.2 判别器的损失和优化器! U# z1 ?/ u# |6 V5 h
1)判别器损失
% `/ ]5 F% R+ d5 E# ` t5 m; |: w7 A2 ~
$ P; R5 i g3 G! B! q! ?判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。 q' {9 D0 m d5 \: M. F7 k
" B" ^% ?5 O6 _& z7 {8 \; L. g1 X* ~' v T
def discriminator_loss(real_output, fake_output):
; ?# x5 f" L9 m; Z! Q+ t8 H real_loss = cross_entropy(tf.ones_like(real_output), real_output): b/ P$ g U7 B. N
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
( w+ h( e+ H! ^9 ?. N5 ^ total_loss = real_loss + fake_loss
; N0 E6 {) T: k return total_loss
: M7 \9 ^$ y& x! W2)判别器优化器" J7 ]6 r# t9 |1 x0 w7 n
0 v$ y5 Q) {2 ~5 b) O x- }
, n1 S( C6 L/ h( b# F8 V) Qdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)3 O2 B" m3 r6 w
五、训练模型9 k3 Y: h6 E" D" c8 c5 `1 s" }
5.1 保存检查点% @7 h$ b4 M: A: a- r0 M4 ?
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
$ f( A/ i% G0 N+ J
/ k8 k. c- T" l+ W" W$ h* S0 b9 x5 z5 m, k
checkpoint_dir = './training_checkpoints'3 Q6 E2 |5 L- z6 J! V
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
( G7 r, T# d' A8 Jcheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
2 _2 n2 U+ P. l discriminator_optimizer=discriminator_optimizer,
( |( j9 \2 _2 d! T. @ generator=generator,
% e3 V& I' v) H F. V discriminator=discriminator)& {9 w% r. V: ?! i
5.2 定义训练过程
* P: z) a5 J5 w- t# x! ?7 ^EPOCHS = 50 G9 E; J- t, @- B
noise_dim = 100
0 M! D) T6 ~& ^num_examples_to_generate = 16
4 `0 t! ?" T' _+ `6 b: h! A 2 c' u. U5 m L( z
# ^8 ^; f; J. S0 g& R# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
' y- v5 j$ g( |% {) @seed = tf.random.normal([num_examples_to_generate, noise_dim])* s, Y5 w% D) Y1 P4 p
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
# }& B/ w/ D) r) k! Z" H4 i
0 _* Y! S- k8 `! f# B! Y
# W% d. `/ P( {; E- f5 m+ k判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。* @- u6 h* {, P* P* ]' n: }0 Q& x
5 P. O1 _: D M
# y9 I# D4 [( P+ a, g
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
- Y8 D- m {7 h: S9 U5 j6 @7 B- i5 a+ G
0 q7 ]" P6 I0 |, |0 _1 n' J
# 注意 `tf.function` 的使用, }( E6 I5 P4 `- S
# 该注解使函数被“编译”% S: _+ d4 b* J! l! Q
@tf.function
9 u% s4 v6 T; F% s9 Hdef train_step(images):( L( O! C; D, }9 s$ y7 Z& f% @
noise = tf.random.normal([BATCH_SIZE, noise_dim])
; H1 |+ r( d4 w, D
7 i+ N \* x, Z. r with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:$ c4 r7 b6 J1 i6 ]1 b9 c- S/ q
generated_images = generator(noise, training=True)
1 O& R' D C0 `2 G % v: q% |& Y( H' C7 F
real_output = discriminator(images, training=True)
, B% W) V9 ?. ~4 K& z& I fake_output = discriminator(generated_images, training=True); ]( ^; d- [% `
4 [8 r; C9 z0 f% U6 k- s; Q gen_loss = generator_loss(fake_output)
! Q% l7 S7 Q% p5 R disc_loss = discriminator_loss(real_output, fake_output)
* z' T- j' n8 a3 n 7 Q) j3 P% E: E' Y
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)) M2 r# I! f, Z/ C( T
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
) T; M, N0 i; m 3 f% \7 l) ]* q; q# q6 L
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))6 Q' y( R1 d6 @) n8 ?3 U
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))5 M8 X" ]' e3 j9 C N% Z" H( c
. j0 _. c8 z1 G' t' W! D5 V
def train(dataset, epochs):7 y! `* M3 j) C0 y& q1 C/ x
for epoch in range(epochs):, T1 M9 V8 h2 d/ W/ v' o
start = time.time()* O+ L, |4 p D
4 Y+ |( _/ U0 ?2 v- i: F9 Y4 G for image_batch in dataset:
# B: C: F0 t- [! q train_step(image_batch)2 W) U* m" @0 g- [( I' q5 R
1 c, _# A: u3 u8 q5 v( g
# 继续进行时为 GIF 生成图像6 D0 |' }; _7 V/ [5 {0 n
display.clear_output(wait=True)
( I- O" W" F' X6 X1 S: G# p% E' S generate_and_save_images(generator,% l. Q, @: z& T7 v
epoch + 1,) Z2 K, p1 D1 H4 M
seed)
& D9 N$ B6 A7 c, w& p
7 F* U! }! d7 T1 ]5 i # 每 15 个 epoch 保存一次模型, F$ F! J. t) N
if (epoch + 1) % 15 == 0:! c8 K( Y1 a* N6 Y9 y$ `
checkpoint.save(file_prefix = checkpoint_prefix)
. w7 c3 q* [3 N( p2 o 5 e) j# x7 h. q$ Q* s: L, K- J7 {
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
2 Q0 G+ K: @7 e- G ; t: G, m: T( b5 N# B. w
# 最后一个 epoch 结束后生成图片+ U. L9 V E, R/ V) |( _
display.clear_output(wait=True)
8 H. m- l4 H: v C generate_and_save_images(generator,. H& w) D8 b- P! S8 ~4 ~+ x
epochs," v, }; Z4 y9 H1 ~
seed)
) ?) b& j6 T/ }9 @' O5 e' ^
' V5 C6 A" a. P! q! w# 生成与保存图片
" ^1 }/ F; i: O5 o) g0 A4 J# {def generate_and_save_images(model, epoch, test_input):
% _2 Y# G) y5 x& ]( B # 注意 training` 设定为 False
' h+ c |$ u% L3 M; K1 M& Q1 L # 因此,所有层都在推理模式下运行(batchnorm)。& I" v. Q* R" S; {+ d% S& E8 j6 M
predictions = model(test_input, training=False)
8 D/ j- f3 ~4 x$ z: t2 U L
! d" E# P2 E6 { u" S- S fig = plt.figure(figsize=(4,4))4 Z& ?4 {# |$ I5 P+ f) A4 M
2 t- \) c0 e- Q for i in range(predictions.shape[0]):: n) H2 m- K6 p- @& _6 k4 F
plt.subplot(4, 4, i+1)8 I" w2 B1 M! I+ Y j3 ?. ? d
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
5 h5 i' y- h0 m7 E- c plt.axis('off')
$ x0 ^3 b2 m1 D7 s ; E* q7 p' O# _* P
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
) i; z8 c; S4 C plt.show()
5 i4 j- S3 ^7 K7 |' b$ ]$ a; q5.3 训练模型
1 D' C" K0 U( Y! T2 u1 n调用上面定义的train()函数,来同时训练生成器和判别器。
! C! E2 }# i: f/ ^ T8 u( W
- i# A% D. T$ b% {
. v; q; X" A+ H- h注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
2 d$ v- E1 l+ r( o( K2 Z
j I; H' R* M6 |
; S7 _! |1 ?9 z6 @%%time- \: T) U, h3 M8 \
train(train_dataset, EPOCHS)
; y& J3 l$ `2 Z7 A在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。7 L0 \; C: E8 N3 r1 ?7 W' m' C J
- C3 F# y- n- I: `. Y
& ?7 f% q; `% E7 S9 x- @训练了15轮的效果:
1 g% A8 X3 } B" h
4 y* _! D- B- [: n) H9 D! I) K# Q: n5 V$ h7 ?; E- d
8 A& ~" O, I* ?* c' n. ]8 h3 m3 A2 |5 s: n0 v1 B% o+ i* Y
* v2 t o, Z: q6 F; v4 T
) F! s, F/ o. W训练了30轮的效果:7 j% c. L# f8 x( y
/ J. F5 r& D# L
L, ^! D2 }9 |' n/ @$ {
M% G* w4 }; L, X0 K) j( j1 v: m7 O6 }" G
# t% Z$ z( M1 U3 Z/ t
7 b( l$ @. B% x: b o p" ~, W
训练过程:
9 E; e- A5 C) T! \& J
3 {0 J3 m3 O3 h6 q
; [) H* i$ C) ^) _; M( D; {3 e$ s8 W
( D( i f% R, c4 X6 E" x8 `/ \
; e- P; J! M( _" Y7 z1 {% w
恢复最新的检查点/ o" I6 ~6 M% F
) J) J3 B* e$ W* v3 d( L v
" ]4 |1 @9 m. s
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))* ], z" r; z9 }; f U* Z
六、评估模型
$ w) w4 C$ V" f! C3 a9 e这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
; Q! N$ ]7 {) {0 @
/ }5 t' l+ z. C. A! W3 Z1 [" S. z2 Z
( z: X9 o' D' W! ~6 L! C# 使用 epoch 数生成单张图片
2 ?( ?6 Q! |( D& c$ Y- ^! g6 Q. f2 bdef display_image(epoch_no):4 o$ q. |* d" d* h: i& i* h8 V( ^6 R2 B
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)): F( r: N: G1 ` N) v3 z
& ~* u8 s8 D% \. `8 |% | Hdisplay_image(EPOCHS)
) G/ l# m- b4 r2 q. h3 x) }; X' a! \anim_file = 'dcgan.gif'
6 ?/ G. P& }* @/ d- g+ {
7 G. _7 K2 P3 J$ `) D' l( J- \with imageio.get_writer(anim_file, mode='I') as writer:
, E7 T+ p% `$ [1 ^' i7 q filenames = glob.glob('image*.png')2 k' b, P, e( t# @: v
filenames = sorted(filenames)& N+ T& `& U! v
last = -1* |) ^1 d! I: r# _; |
for i,filename in enumerate(filenames):% H0 |0 r+ @7 H1 b3 Z. K9 I
frame = 2*(i**0.5)4 M! [" U, {) s
if round(frame) > round(last):
' A6 U& R4 m: \/ {+ { last = frame; w3 G; \5 g& Q! y4 g
else:* I+ v* L. _$ @, g& W1 j1 h) S
continue
' Z6 H: a' n1 c& @1 }5 ? image = imageio.imread(filename)
, s# v% w4 x5 C8 I7 ], U5 j. B7 n writer.append_data(image), G5 o/ Z& T7 f: q' I" r! Z
image = imageio.imread(filename)" \* A' l" U% L* r# y- X' s- @
writer.append_data(image)$ e9 Y/ R- z* A3 n9 ?5 }
. ^% q- l) C* k% ]7 ^import IPython
. d' z: F+ ^" |5 }$ _if IPython.version_info > (6,2,0,''):
" k+ ^/ @* N% Z4 X display.Image(filename=anim_file)2 {1 t! `6 i x5 a+ _
+ W, D* s- g% d8 K5 h
/ Y( l# t8 G# v/ r7 Y) w% }8 x1 d! i8 G% S
% r. Y$ \5 M2 v' |/ @0 T( u
完整代码:
% M. F/ h0 C5 P) p; `
+ c) [2 e$ l. x& j
% L- Q1 f+ K4 w5 Gimport tensorflow as tf% V9 ^2 _( N: t* A# ~
import glob
3 t" R# h5 b3 y( x" J8 |import imageio8 a. E. C/ W9 D) k
import matplotlib.pyplot as plt3 y3 f! h& @$ a# b3 D( w& ~
import numpy as np
, z0 W7 e% h6 I6 l' M4 iimport os
! M7 E2 p) u- O0 f9 U( M) n& {0 c# himport PIL
6 U. x$ B& c3 ^' \+ nfrom tensorflow.keras import layers K9 W, K/ B4 E6 D
import time
: m4 d* \2 [ r7 Y, N+ d! B % o/ \# P' `/ R# V' g
from IPython import display
5 a* e' ^! s6 R9 {) E2 ? - j% V* k0 k [$ Z2 {
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()* n2 f! C y4 o* ? B9 @
3 h; C6 Q2 l- V8 B$ E2 X% f# X
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
6 z, k7 j7 b* Ctrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
3 H8 y8 k1 E8 [# \+ a4 p( a9 d
; v1 r3 m0 b1 @' _& _9 {; IBUFFER_SIZE = 60000
: v$ C* p: d6 o$ `BATCH_SIZE = 256* X1 T d1 G/ o
! q" X3 x) X+ R4 U" a
# 批量化和打乱数据0 r6 ^$ a: h! q$ k3 H
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
3 b1 Q; ^; O" g2 \* N# r & D: { V' p% {/ ~2 s) Y5 `4 f
# 创建模型--生成器
2 A: l; |' {6 U7 ^def make_generator_model(): U8 r1 \6 H$ P: P, u, i( i$ m) g O
model = tf.keras.Sequential()6 j# ~5 [% `) i& w
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
( b/ k; I+ ~0 W( M+ s% _" q model.add(layers.BatchNormalization())3 P- Z1 \: N" T6 Q8 [7 `2 v
model.add(layers.LeakyReLU())
* f3 W9 I- T/ c! A* B
' x/ O5 I+ g6 r& u7 s+ c) q model.add(layers.Reshape((7, 7, 256))). N& Z% a6 v v9 |6 Q) h
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制6 T+ f1 J1 k. @! o
- g: Q) k+ |2 _( o) f model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
( U! m# `0 p2 K/ ]0 H/ Y( w assert model.output_shape == (None, 7, 7, 128): U) w" g5 d& Y5 R9 ]4 m$ Y# n$ u+ y
model.add(layers.BatchNormalization())( y2 v* J/ U( g3 s$ I
model.add(layers.LeakyReLU())4 n6 _& [* o. k) ~4 P9 U8 u0 g
8 n5 x% H5 ^; f model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
4 y* S# ^6 V" p2 G$ l assert model.output_shape == (None, 14, 14, 64)
) E; a7 b) J* ?, ?7 W model.add(layers.BatchNormalization())* o9 I8 I% V; {0 y3 K
model.add(layers.LeakyReLU())3 l7 _6 d: k+ O6 j) c- `! }3 L: M2 G
! f! }+ X- F& f, w7 I model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))- C+ w0 {2 J' w, I/ Q; N* n
assert model.output_shape == (None, 28, 28, 1)
, V( J: d" e; [2 _
: i" H4 ?$ w* ^# y6 m0 d( ^. B return model
4 N1 b. h4 g$ _
# z P1 Z7 `- K6 f$ C6 F# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
- `! S% ^- u' w! C" qgenerator = make_generator_model()
. B' s9 q. f- a( I6 { 9 r0 j. n! o: y
noise = tf.random.normal([1, 100])) e! \* ]1 O& S k7 f5 I; D
generated_image = generator(noise, training=False)) v& S5 W* L( m- j6 f0 D
* p+ Z7 W# {7 Z+ |3 @plt.imshow(generated_image[0, :, :, 0], cmap='gray')
! E3 a3 q; k) C4 v# _( a$ l. w8 ~6 Stf.keras.utils.plot_model(generator): L7 {: y' ^. `# o
B) v& Y: _) ^, Y2 l) z- _# 判别器
4 j+ N% Q* a+ V: Q) o! ]+ qdef make_discriminator_model():) a3 o% `6 Y2 V4 k+ P; B5 R% J
model = tf.keras.Sequential()' q9 r% Y# ^4 ~; J; b$ Z
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
7 H8 \ r8 S# |7 c) \5 f input_shape=[28, 28, 1]))
" B& s4 q/ X4 O6 h& h model.add(layers.LeakyReLU())
C$ B& u. e# E8 v model.add(layers.Dropout(0.3))4 D5 S) Q- L8 C0 q5 t
! b1 m1 D4 @1 `7 \& M, n$ ?0 C model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
0 p& a6 D2 a- J0 \0 O model.add(layers.LeakyReLU())' P) n1 e# |+ O4 x
model.add(layers.Dropout(0.3)), ?4 r+ ?2 c% ?3 x* X L
2 N3 a3 k( w$ l& Q2 {5 { P) A4 \
model.add(layers.Flatten())' O6 |1 P0 a! |' d& l# u
model.add(layers.Dense(1))
7 Z( t" J7 Y( g " u x3 m* J- A5 e$ h! l7 t
return model. I+ _3 g0 c }7 F& Y
% h* [' ^/ a' ^& C/ A# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。. d c9 _1 p; q) N* [0 a
discriminator = make_discriminator_model()
2 ]- R' J) ~. n$ F1 Odecision = discriminator(generated_image)
* _! q; B; @7 P! e% Q/ o* yprint (decision)$ G: S8 [8 j' c" {
8 x7 n9 X$ Q% {8 k: V2 y L
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。9 L4 B$ y6 P, }5 ? g4 t4 e) R% x
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)5 x/ v6 N$ X7 ^ K9 v
' u# O! e3 r# \1 ?4 a3 Y( |/ k# 生成器的损失和优化器
4 e. H+ b/ P5 J r' k: gdef generator_loss(fake_output):( T$ ?* A( [5 m+ W: @( p0 ]
return cross_entropy(tf.ones_like(fake_output), fake_output)- J6 D5 w2 M5 j/ J* G6 H! G
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
' [' S' o- K) k' L# G0 Q$ y6 ] ) |: T1 Q3 d4 h; ?" A/ b& c$ y
# 判别器的损失和优化器
$ u& f; Z: ?+ h/ l5 K) M/ V/ t* {def discriminator_loss(real_output, fake_output):. B& N# | d2 G; o( ]4 c6 D
real_loss = cross_entropy(tf.ones_like(real_output), real_output)9 ~0 N8 M" `, k
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
( w- w" z- Q A3 h7 W total_loss = real_loss + fake_loss
! J3 N e. C2 E; w return total_loss
A% G" A1 T5 R8 B9 odiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
' H0 l* {+ k0 j; |- k, f: O0 u/ P : H ~/ ?' ~7 K9 A A9 @
# 保存检查点
4 {/ j# `: S8 m5 Qcheckpoint_dir = './training_checkpoints'5 R# }6 T) s( ^+ N; \6 f
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
5 }- f# i8 G5 d# [; y2 B, Ccheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
2 Z- r) v8 U' X discriminator_optimizer=discriminator_optimizer,
; D+ c. D7 [2 c8 X( p generator=generator,4 Z! a, z; Q! }1 |$ V: J
discriminator=discriminator)
. J$ M# W ^1 [0 L4 ^& n* c- i
, V% ~& v* R& A( `- E7 [! P# 定义训练过程
2 K$ p2 N9 X$ C2 |* f5 C) GEPOCHS = 501 h# d3 M- d8 Q+ Z# L- s# O
noise_dim = 100+ c3 y6 z; E! a8 k
num_examples_to_generate = 16
) L. v. J8 u1 f* }
* `. m" ~0 A9 c0 w1 W. T# A# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
7 s. U4 O3 Q# T, ?seed = tf.random.normal([num_examples_to_generate, noise_dim])5 v" N0 k+ c. x# C: ]0 q
) m( \1 U+ W7 {" g6 ^" `# 注意 `tf.function` 的使用
9 P x* W6 R' C1 D8 ]# 该注解使函数被“编译”
$ O' E& X: ]. P& {6 C2 T$ {@tf.function
- B [5 g. c- ~1 Wdef train_step(images):
! F. b' \) T; d# ?8 E7 p noise = tf.random.normal([BATCH_SIZE, noise_dim])( d0 k& A# G, |
6 o: [5 J: W, O" O4 I6 S7 w
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
0 ?' f# g) ^0 h/ h3 Z$ a8 U0 G generated_images = generator(noise, training=True). X' X Y' Y7 n' E" ?7 J, n
4 I# @6 u3 V2 @' m real_output = discriminator(images, training=True)
% R, b; s( _7 e/ r/ b5 v% ]% ? fake_output = discriminator(generated_images, training=True)$ J! t/ a. Q0 u C- C6 e+ J
; Q# Y. z/ r% S, U* m gen_loss = generator_loss(fake_output)
' a/ n1 Y2 i `% ]3 @! ]7 W disc_loss = discriminator_loss(real_output, fake_output); ^2 i5 L X( u$ y3 D( E
# D* d* ~2 y+ r, j& X- ~& H gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
" u4 L: r! J1 }9 t$ t' ` gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
* A( K5 q' i7 R. f, \( Q ( ~0 u! t6 y$ b
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
$ `: W$ l6 r2 P* X7 ^' _' x discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))" {% Z6 j. X* ~- V7 q! r, y' ?% K1 n: B
0 n) q7 a# M( z2 w _$ Adef train(dataset, epochs):7 B5 x% _+ Z1 c$ P3 f* g
for epoch in range(epochs):
, F. p- t0 w e! a' d# Y' m5 N start = time.time()
& Q5 J# A$ O8 }4 P f 4 O1 b( [/ J' n+ |3 ?$ ~
for image_batch in dataset:
; u# v, a% p6 b) O8 t( Q' m1 e train_step(image_batch)7 Y6 `; i9 ^7 O" {/ P) [
) u* A+ J4 o7 k, O, s" A # 继续进行时为 GIF 生成图像0 G+ ~ J6 X3 E. t% g* s6 i% t+ _
display.clear_output(wait=True), k" y" y _1 W1 U$ M* m
generate_and_save_images(generator,
! a5 Z+ W/ f: V- F# x& H epoch + 1,! y3 ~8 D" c# r6 J* K( p" v
seed)0 l8 b# d% g. [( f; L% g
2 }+ o# s3 v& q. F R' n # 每 15 个 epoch 保存一次模型# {: `) H+ N" H/ }
if (epoch + 1) % 15 == 0:: i3 }. e" ]2 s" s7 T, p+ i, E2 L
checkpoint.save(file_prefix = checkpoint_prefix)1 t3 {; k4 A, U0 z: L
: G- {+ R N* H+ d
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))9 h3 j: X" ^. @0 z! p& I9 o
' K1 q3 K; w% a1 n* k2 [ # 最后一个 epoch 结束后生成图片/ c, v* h0 k7 ~3 B2 g- e2 \; [
display.clear_output(wait=True)
4 k/ m1 E# p6 m H! F5 }5 ~ generate_and_save_images(generator,$ @' J* f5 h/ f6 X, o% Q7 ] r
epochs,
! P. X# ^8 _9 D9 t. T seed)2 p* q5 p( ?# a$ q6 K
9 r9 Z4 ?% f" a# ?; d, R7 S# 生成与保存图片# {7 x! u4 g% R/ t6 [
def generate_and_save_images(model, epoch, test_input):
& {# a6 t5 m/ s3 h: k& i # 注意 training` 设定为 False# V6 f) H: {( j) F. y
# 因此,所有层都在推理模式下运行(batchnorm)。- |9 B7 ~6 Z1 c6 A
predictions = model(test_input, training=False)+ I* C6 c" w# l4 X1 d- d1 y8 S
; `0 q6 h, C+ e. l
fig = plt.figure(figsize=(4,4)): r6 ]6 ^, I/ h) }% D
1 F. }' m5 A) @' [9 ^/ U; }
for i in range(predictions.shape[0]):! U& D( P& | o# m+ U# F/ ?
plt.subplot(4, 4, i+1)/ F5 I7 j5 p& i. ?% Y5 f2 C# o/ t
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
A: z& r6 S# o% ?' A% A plt.axis('off')2 c7 |' U; i# ~4 |* f6 W
3 [; N& B w j% T0 v
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
% d0 R, H7 ~; K* R, ] plt.show()2 f; g3 L! g$ g1 V
; g: a* W! W2 c! Y3 B [
# 训练模型
1 s/ k; \4 ?- vtrain(train_dataset, EPOCHS)
8 }0 T! R1 P$ [ a- M3 q & l% p" o9 I& D) J
# 恢复最新的检查点
$ d6 `4 V8 _! V9 W- icheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
5 R C) u: y C1 D; v2 P5 t , ]% C" ^& f/ E
# 评估模型 Y8 B. d- R. P! j# _% e
# 使用 epoch 数生成单张图片
$ T5 P B& e; N, \- |2 N6 _def display_image(epoch_no):7 H% S7 O/ s: I. t; O2 B- J
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
9 X! ^, o8 Y ~! B$ o) Z , f0 ], {1 k( X/ J/ ~ O% |
display_image(EPOCHS)! B8 p% h2 m- N# u/ g/ U8 T) H1 [
4 o* U, Y. ~2 i: kanim_file = 'dcgan.gif'
- u2 n( q' h& u " B0 m: O$ P9 g2 {6 a+ ? z1 Q+ e+ }
with imageio.get_writer(anim_file, mode='I') as writer:4 l4 k- l' ? C; L" U$ \
filenames = glob.glob('image*.png')
: I5 f" n/ N+ k& Z* n; J* `! V/ ^ filenames = sorted(filenames)
# ]+ a9 a7 [8 R last = -1
/ w% _3 O/ ?0 w8 K' ^- T for i,filename in enumerate(filenames):/ b! B* e6 k' `$ X/ \. c5 a3 [$ B
frame = 2*(i**0.5)
1 }0 I# W- Q: C, R2 m y if round(frame) > round(last):; N5 `1 |/ b8 M. x# b) V& E
last = frame
# |, z e/ W( p else:$ W3 r+ U$ R* u7 u% L
continue
3 z# P$ R' i1 p( G, Z, n o2 h+ Z+ I image = imageio.imread(filename)
$ |5 b; u( ~! q5 p writer.append_data(image)' E4 z9 B( R' @* c- I" y
image = imageio.imread(filename). B4 l( [* |$ V+ l1 x4 D, D- i0 f
writer.append_data(image)
J& V( z: U& l4 |' j- E3 \7 K/ N & Z: @) R( n# `; E+ _
import IPython
" E/ ?0 P; C) ?6 |; y" Tif IPython.version_info > (6,2,0,''): P% i2 c, q! M* X* |1 l
display.Image(filename=anim_file)
0 Q: a0 |1 B3 n( D参考:https://www.tensorflow.org/tutorials/generative/dcgan
5 q4 X* G+ ?3 T1 \* t$ T& ?————————————————- K9 D! O/ p8 U! h& L" U5 \
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
! M) C; p2 I" t; y4 ^$ u+ A/ X原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
1 G! ]- N" E% U& H. ]. i' j
. L' k% f+ a; d) o7 }, Y1 i3 P( _3 m3 q
|
zan
|