- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 563312 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174216
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
2 @' M, G8 {2 N4 z f; z深度卷积生成对抗网络DCGAN——生成手写数字图片( a% J. H# S9 o
前言3 H0 Z) g+ S4 {
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
4 ^4 ^7 `# A( K6 _) r, E7 z9 \( u' T7 V. ^: q) s
' P- j- x6 ?( j' h! j8 _; l
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:0 o/ ] \, }( h" e9 ~" \
* \# U& B0 _& N/ W5 d, z# I6 `' q
" ~% d l) k. L z f# t
# 用于生成 GIF 图片
" W9 e6 E+ L: |pip install -q imageio8 U1 M! o# O/ l( c" |0 v
目录0 B( [7 ^+ v, F' }& P; ^
8 |% T8 ^. Y# O. F- s
* n& P& z9 s# M; G# B. x
前言
) Y; [! V6 u4 R. u3 m
. E5 ]) z$ f* A( H4 X9 s/ f
5 R4 l. a z% P' W一、什么是生成对抗网络?- I# i- r7 Z& n0 P* J
$ b+ u' b; B$ C; C( {$ n
# n) Q3 z3 Q0 e/ t; U4 N5 i$ g# w; l二、加载数据集0 P! m4 s8 r, v
' [5 _# ?# l U
; B: `! i5 U, e4 l8 N
三、创建模型, H w; |/ d! l1 Q
9 ]4 T$ M$ }' |- b/ ^
, w* o2 M# ^0 z' _2 ~! k3.1 生成器
0 X4 a( l0 d$ T3 Z7 M$ c
) P# F0 V/ g) _% \/ D4 t
+ O7 A/ d1 ]/ w3.1 判别器
) U) @) k4 `2 W }- h6 z" J R. s! t
" a8 R; |3 }: P
四、定义损失函数和优化器
: j. W$ `$ s9 G/ D. c' M
" M3 P$ N. a9 n( c4 [1 Y. \
" G8 J' I' a4 S1 c9 I! m! C4.1 生成器的损失和优化器
( v) t% h$ ?3 P' l6 k# u/ T3 K9 @; H: @6 X" e3 `, N
8 Z7 W" P$ q0 c- u4.2 判别器的损失和优化器4 a$ V4 x8 \2 a6 @" f
$ i8 G+ I0 q" G& |7 W; `) q3 `) M7 L6 g: y( X5 o; O7 p
五、训练模型
6 K8 D k( `& A7 a/ J2 u6 R X$ A: g, x
j$ }9 g$ o5 ~2 |2 r5.1 保存检查点
6 ]+ Z9 o6 g2 Y# a
x7 L9 y( [' K0 ?7 N9 u
, a! k! v) N$ ~( v0 Y7 `5.2 定义训练过程& ?# V) r8 `! F7 W
5 C9 \- D9 L2 S4 [& C. S. `
: m( g6 d+ x0 U( O5.3 训练模型
8 p, |$ R; u! P& P; Y. w" O- S o& T
1 e6 h% Y. B! A" @六、评估模型
8 t. |2 w) ~& n0 d( O: P
- c2 }, X9 M# z! I( G8 N$ r1 n
8 J+ N8 W6 f$ q+ O% l一、什么是生成对抗网络?) c) ?0 n, U$ X+ I5 I, R
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
" Z9 r! C! u* v
0 b: u5 W2 w6 z6 @- n m- y4 N/ q8 C) T
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。0 t, J: O9 W% N9 W% l! k
) X# E5 F6 g, D5 p* P: `" J( e( B6 m; X" J
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。; B! z; L1 V& D( @
! ~7 N5 I+ t* B& B, W
! |5 {, j4 B4 p( c8 v
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。7 e' f! f0 E- N( `$ `- ?8 L
- @5 k3 a# b6 {7 m; X. M
4 ]$ r: c T$ S. h) m当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。: m* R, f+ N0 W6 m; o; M
, |3 d3 y0 a8 ?' j& H! l; t' ^0 Q' O# _6 |
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。; `: K) N* C; ?$ P
8 F( Z& n5 }& c! O( I
* z" i3 m# U+ }' }8 ^/ N
二、加载数据集6 `- I9 v$ |3 n7 p, s5 S0 Q
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。4 _, J9 b' z* W2 z; T! C5 l# R
% W+ T+ L& p9 r9 T; p5 g2 K1 w7 m' {
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() o+ e* T& h R& J: i( [1 |$ {
, c( x! z3 ]- f/ y( s
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')* ]6 [, s: y, c+ g y1 R+ j. i
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
: N0 M1 e+ |" J0 w 2 K n1 J, K/ q+ h0 a
BUFFER_SIZE = 60000
% f+ G d) p! P9 q+ yBATCH_SIZE = 256
( N. y# T8 K& ^8 b8 p4 @/ w 4 }9 X6 Q7 J) M1 ]; B$ X+ a
# 批量化和打乱数据
S* c0 ?1 Z& V) \; x. e- J* B, Ntrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)% m. Y/ _6 r; r$ ]8 J! w5 g, d
三、创建模型
5 z7 ~8 c8 l6 S% g* |主要创建两个模型,一个是生成器,另一个是判别器。# d& n) p2 T' _
$ N+ B6 r! Z* A- V: J- {6 r6 v n p7 N: C3 @& J& }; Z4 D7 S g: Q2 S
3.1 生成器5 `4 b0 J) [! j% ?6 N+ s
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
) Z8 A9 E) u" H5 v1 {' u, } z) l: r5 d
+ v7 J$ W. r) m: F; G4 v
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
! i6 |4 r% X; q ^. r+ ?9 O* }. h7 g7 _, B. }: M
5 c. w; U9 Z7 x' X$ X6 R后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。% }0 X: B4 ~: l3 p! u( Z# C7 R
9 N9 w* l' F- g5 Z) O0 n7 U6 X
6 {& x# x- L& ]def make_generator_model():8 w. o; `1 X; Q* h( C+ h8 ?% ^9 B# F
model = tf.keras.Sequential(), E' j7 k. e6 b) s: B* S
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))& X0 t. _+ N* J& R" [6 |
model.add(layers.BatchNormalization())
' E4 a; J2 E: P& [ model.add(layers.LeakyReLU())
$ p/ N/ W: z, C' H6 P) U6 F
( r! Q+ ~" W- T) V5 m, @5 k model.add(layers.Reshape((7, 7, 256)))
/ h! }5 S, ]! q* z7 n+ z3 u7 O assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
. l A8 y4 f8 y7 O
! {) _" k( s1 R) c( Z model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))% w% d- a9 \( V$ S) N+ K
assert model.output_shape == (None, 7, 7, 128)& w$ h$ x6 }. o1 ]" x/ L" N4 G. z( N
model.add(layers.BatchNormalization())9 a) J; I. h, o( B2 y. J9 g
model.add(layers.LeakyReLU())' D- V0 q* {( E+ I( w- F
$ t( f' \ u1 u4 T# H) {
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))- W' S% S" L: s% V
assert model.output_shape == (None, 14, 14, 64)
0 N7 N. N+ y0 m D, R- | model.add(layers.BatchNormalization())
7 f9 {( n/ s+ x+ q model.add(layers.LeakyReLU())/ Y% X" z: x! f
" ^9 d- F) ]; w( T3 `
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))! W5 k1 j3 S+ B& {. ^
assert model.output_shape == (None, 28, 28, 1)
* F" Q ?+ C% K! R1 D: C3 e 9 c5 N+ v! _2 X
return model
, S4 Z* D8 |$ |# A用tf.keras.utils.plot_model( ),看一下模型结构
5 U+ {& n$ G9 x0 |0 }- l5 J) M; k% h/ s9 G$ ?3 }
0 @5 \" ?3 {7 X4 ?
! A$ U6 Y5 j ?
6 f6 F& U8 s% O9 O
: R/ |! b; u2 Y/ }" f7 b, c9 Y用summary(),看一下模型结构和参数
+ V9 U6 k! V H& [9 g0 [7 K+ R% d u2 d' A7 K3 O/ t6 f9 |
9 }4 c# u9 ?' o9 |8 w/ F3 b
8 L5 @$ B1 g% \! B8 @( O3 X2 y& C0 R) e
" M+ u8 P8 O9 h1 r! d& r/ e) h
' a+ S0 Y Q3 u$ y7 s# b
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
9 E) y# g- V0 U% n0 R3 y. ?4 U/ B+ e5 `- h4 L* J
. h) f$ {1 `' x4 mgenerator = make_generator_model()* ~1 l# I6 f! p
0 F" ]3 ?6 X& h1 W) hnoise = tf.random.normal([1, 100])* h2 D+ X% S4 u% M2 Q5 I
generated_image = generator(noise, training=False)
0 W7 j0 o0 E$ Z0 c) `, {. D5 Z6 X* H. y % r) X* c% n( @
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
% T/ }0 G3 o( ~( K" A( q9 Y
- C% b9 ~/ N) m8 w8 Q! i: M6 C0 ~: o0 X" ?! p8 @1 l3 z* }0 z* w) J
0 [" p: _! u# z% a2 }. H) P
4 W! ^% p! e, e6 q3.1 判别器0 m" y; C5 s7 o( i( o+ r
判别器是基于 CNN卷积神经网络 的图片分类器。
M% M( e$ S* e8 E- S# y
- L5 ?: H* C/ a# q! C# _& i3 t, q ^
def make_discriminator_model():7 r5 N) }: M1 L$ G" f$ y0 V, z4 j
model = tf.keras.Sequential()
, B1 ^, [, L" c8 Z4 h [ model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
) E. Q; h% k7 ?4 g K input_shape=[28, 28, 1]))# k9 O) J4 o3 k7 I# k2 F
model.add(layers.LeakyReLU())
$ @5 T$ {" n6 X ?4 F model.add(layers.Dropout(0.3))
2 @ [4 C s8 b: Q& \9 w& a2 o
; W& T: G9 f4 n5 |6 r( {% w7 } model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))0 t" h7 E, a" V6 X! X5 W
model.add(layers.LeakyReLU())
$ ^" e% m4 N3 h" H( q$ H- c3 S model.add(layers.Dropout(0.3)); r9 v; M9 b# _7 P6 ^6 {
7 E! s! ] S: r1 u% }# L model.add(layers.Flatten())
; s# d; R7 x0 F, d5 O, f/ k7 J model.add(layers.Dense(1))
. ]; b. k0 f: ` s8 w
8 f, p, H! Z% f return model
: h5 x& B+ f, E& \用tf.keras.utils.plot_model( ),看一下模型结构
* ~6 H* _( Q) P* M: x- Y
, S; x% u6 \7 A! ^: _
" l6 I/ f0 q8 |2 F$ s. e# L% O
9 k% M; K( p6 n# X2 u/ d* A% i9 X& }! t% f# O
, A, U3 Z7 z1 X5 |1 ~% D
4 D/ u" {0 a. e3 q* V! X用summary(),看一下模型结构和参数
5 Z' U- ^1 P _" `
T; V. v' S* Z# R$ g
, N/ I8 X* L5 h, _6 _
; u; K5 P) x# }+ j: f* w& Z# G6 T9 s. d
- D$ }6 U8 V+ u3 h/ Z4 R) m, B. @4 X; [9 u
四、定义损失函数和优化器! d$ `* ^- F" ^, P) [5 d
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。1 a2 F* S7 q) [" d5 N: M9 s
$ T+ L! d' z- D
9 Y6 I% k8 a) q& ~4 T3 s首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
, K$ h8 P- H, c" W
8 r8 F, m# `3 {9 x# N6 E9 c
9 c% I+ [( r3 E' x# 该方法返回计算交叉熵损失的辅助函数
! v: x j2 Y4 O9 `6 Q% dcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)2 r0 k% d* S% ^
4.1 生成器的损失和优化器
# S7 m3 p2 r- g7 A1)生成器损失( k/ ^3 R6 v1 @" q3 n- A
: q( r- m( R# i8 d/ y6 Q- t
7 N, @5 ]& y+ y5 x6 X( Z9 ]# y
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。) X! O. }3 O) H
9 L# Q; P0 }2 C l# I, b3 y4 q8 ]# L, G9 z- X* Q: F
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
! K+ Q$ b4 B- S! w1 f+ Q0 J4 }" f1 _0 G# S( n; ?5 i. d
7 F( ]8 H6 l6 k7 ?6 Y
def generator_loss(fake_output):
, }& q* `, V. k; B. l7 k3 k2 `4 c return cross_entropy(tf.ones_like(fake_output), fake_output)
& k, q" q& s& D: g1 W/ ]0 l( Q2)生成器优化器- Z6 l3 N2 R4 b L9 `, x5 Q
9 K' f! b) M: E. [/ v+ D$ P) e
# x3 b$ R) G; R% E# a3 Qgenerator_optimizer = tf.keras.optimizers.Adam(1e-4)5 K* X! @, N' E6 G: E) O( N
4.2 判别器的损失和优化器
3 N4 h0 A7 ^# {! ~+ v1)判别器损失
( M* w3 Q% m9 r/ a1 }# H/ r
3 q5 a$ f% k0 W0 H3 W# Q+ M3 i
& X* D( ^$ t0 W# j3 N- ?* c4 f判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。: N4 ?1 ^$ l; q9 V' I
: x K5 C* ]% m# I r* G4 d+ J* x
1 L5 i% k3 N8 V1 R. _def discriminator_loss(real_output, fake_output):
- q. J+ @8 M" I* d. @ real_loss = cross_entropy(tf.ones_like(real_output), real_output)
/ ]' j* n$ H, c9 t! M+ N fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
: |2 d- m/ w5 @5 r: V/ Z2 ~ total_loss = real_loss + fake_loss
% X! u5 i t( B return total_loss& L9 t3 C/ @7 W( {( X Z
2)判别器优化器
/ F& G, R! N5 m% N: _9 ?4 u5 F, K. z4 i9 s
6 A* }0 V+ j" C! T0 Fdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)2 u3 ^$ H2 p) d- ^: R; ^1 x
五、训练模型
6 q9 E% P6 v- q$ m) y# B+ _& t5.1 保存检查点
1 h& a: M: k* L- X保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
( F7 R9 E- \# }$ f; N P% x$ u
5 l! s$ B1 Y% I6 k: E- G* h. M* |( X# ~5 z* i9 t" M
checkpoint_dir = './training_checkpoints'; T0 \9 c; w6 R; \* n
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt"); W( J$ n' w- @3 Y6 a( u# ~
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,( Q: @0 N6 c9 {3 Z
discriminator_optimizer=discriminator_optimizer,
0 E# w- p' y4 A generator=generator,
! s5 o3 A+ e2 g. u H discriminator=discriminator)
n' c) _( f; F* L! O- J5.2 定义训练过程; y; e: x. e; k/ p$ X/ {
EPOCHS = 50( r9 K' ^& a" s2 ^) {
noise_dim = 1009 i( i7 Q( b2 M& X8 I6 H- `
num_examples_to_generate = 16; e' d, P+ Q7 J: e) R; A
1 u l& y# [5 @- V) X
3 j- v Z8 {+ b. Y }* ]
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)) [( ~0 w' z4 Y! g% j
seed = tf.random.normal([num_examples_to_generate, noise_dim])5 g( [% i7 E7 A9 C# p- r
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。1 L& u, `& r2 \" s$ y
2 t5 x2 p0 Y% ~
6 L/ J2 h; `- q5 K- p
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
& ~6 _- ]1 F! k% A1 t3 K- n8 ?/ A
& U8 o, t3 F% ^+ A* _% {* }6 u两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。5 Z, v3 l2 z5 B' z, z
7 F f m1 ~" ^2 F1 e
; Q! n/ ]# D' v- u, }. R: r5 y* O
# 注意 `tf.function` 的使用
# }. {' r' H( F0 `- ~) @# 该注解使函数被“编译”
7 E2 h1 L. D- o# }4 I% b1 o$ O* r@tf.function0 _' E& Z6 Z; d; B" I
def train_step(images):
+ O% V+ G5 C* O3 r" f noise = tf.random.normal([BATCH_SIZE, noise_dim])
! C/ i4 u( y- u1 x/ H* k
8 A' f! T: p+ i9 a" m6 w0 o8 Q) Q6 } with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
8 s& L( ^0 B/ R generated_images = generator(noise, training=True)
9 H2 n' N! W% L0 X2 u
# O. g, \- C& Z8 e' B: V+ P real_output = discriminator(images, training=True)
$ A* h* p. h1 |. T' M6 ` fake_output = discriminator(generated_images, training=True)+ [8 |- ^- U9 u% S4 o
) w3 q7 {* s/ j4 R4 F gen_loss = generator_loss(fake_output)* r& v4 }' O0 O q _/ ]
disc_loss = discriminator_loss(real_output, fake_output)! c" Z T4 G0 M/ G" S4 J
; ~+ D( X5 b2 x8 Q, k gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)* C/ @" t6 a& l
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
$ J/ R; k, Y. J9 w! ~) A
/ m5 h+ h) M2 Q) J: C$ Z2 J3 ] generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
3 @/ Y* P% T* l% p7 }- I7 r discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)). u' O& T/ V1 j
/ ], f' z5 m9 H: K4 g+ Qdef train(dataset, epochs):0 e7 p4 @: y- _) z6 V+ ^$ N+ b( `
for epoch in range(epochs):) s! l: B- A; ]% H
start = time.time()6 W W) r! h7 k5 o
8 ` N1 ]) ?% J# @
for image_batch in dataset:! l! H6 D8 Q4 |- O& M$ F1 _
train_step(image_batch)
& |; ]8 [+ d+ B
9 ?! X6 O1 B/ T& F # 继续进行时为 GIF 生成图像6 K, y5 g- [5 M; v \( R5 Q7 u2 u0 F
display.clear_output(wait=True); A& l; f- @3 [5 a. ]- Y# @8 h
generate_and_save_images(generator,7 L& t% `8 `0 R, q: h
epoch + 1,
; q0 v7 d# D) M: Y7 y: g seed)
8 I$ v+ ^" m& S; ^ $ M; T/ b/ q" h4 T
# 每 15 个 epoch 保存一次模型
) C. d3 Q0 q, V9 m0 ^1 Y if (epoch + 1) % 15 == 0:( Z { [9 n( k* v6 l
checkpoint.save(file_prefix = checkpoint_prefix)
1 _ A9 T* t' L6 f ) E: n( ^* q( I; x2 H; c, m+ n& y1 Q
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))3 s( C3 R8 F! D: Z2 P T9 H6 S
/ B# b+ ~* g5 r' {4 H- l # 最后一个 epoch 结束后生成图片9 X) W( a* E6 `/ L6 Q1 c" P
display.clear_output(wait=True)
2 T: L% p$ u8 U generate_and_save_images(generator,
2 p. G N/ Y$ P Q2 l( [ epochs, a6 W1 V X" Y* U
seed)
# ~# X1 s. v: k: H 3 V7 ?' ~6 @& ?( _
# 生成与保存图片" `( t% K- Q: w+ _. Z3 g
def generate_and_save_images(model, epoch, test_input):- l' T6 o+ E! [7 t8 t1 y
# 注意 training` 设定为 False( B) n( J7 i$ F5 @! I
# 因此,所有层都在推理模式下运行(batchnorm)。
' X! z) X3 V+ l9 ~' \ predictions = model(test_input, training=False)& ^$ a9 _! c+ J* P9 h
+ c3 d5 o* h, q. F
fig = plt.figure(figsize=(4,4))
@9 |! q9 L. {( E: y0 }' I # u o1 j" I& l. S
for i in range(predictions.shape[0]):' y! C/ J7 z5 z; K. ^9 _% Q: W
plt.subplot(4, 4, i+1)8 J/ \: }3 m# _9 V/ d
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
4 R( v" E, h9 m2 B" M5 Y plt.axis('off')
5 F" K/ K' _) m& c3 W0 g 9 ]9 @5 H% I+ N+ \6 {
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
9 C. Z2 i' H/ `- C G plt.show()' w6 m8 ] g& h6 w3 B! d
5.3 训练模型# v# @$ g0 {) _" Z, x% E
调用上面定义的train()函数,来同时训练生成器和判别器。5 K! N X, g) y& ?/ V5 J
' L# C3 b4 o. I3 w1 E
0 c! ^6 i. E! k' `$ `
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。* l8 y6 v# C2 v1 F# i
9 {& e9 b% c" Z+ e: w
8 z: K' V& {" n, s$ ^+ K%%time
" d8 W# P9 L+ i7 } J% ztrain(train_dataset, EPOCHS)
. K0 d0 @! ? J- O! Z) F在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
# f6 w7 k/ z* k. X6 x/ J7 c' k, T" o& Y: K1 |# ?
7 G' F% z# n2 ~; c' C @
训练了15轮的效果:7 N) p+ s& u/ B% C h# I
, P3 {8 N2 e- v" C1 q2 Z) N) b! ]
$ {* h6 C3 ?# ?0 s! i1 R. b' w+ X7 @) m
: D$ \8 S0 c. [3 c! G
# | R& x+ x1 ], E6 m
. h: a( W! K. Z3 A# g训练了30轮的效果:! l8 _1 ]$ k& X. L) X2 _9 ~- `
& w! `1 L `' x$ u: a: u3 U2 ]" B
5 W- S% W# M; x& p5 c1 ]
- e# M) h( `5 W- g7 l( ?) l6 m( A1 n- g+ f4 I
3 {; t& o1 h, @: [$ Q
& ?6 X2 O" D" z3 x$ ?7 i训练过程:
( x/ @6 K/ K" f3 z# u$ y. z* P8 M# {, r% t
9 Y# w! \5 [, a1 N
( Q: k& S E6 l. c2 C t0 {
& q2 V- Q3 T$ e( g* d4 B ]5 O3 J) z0 \9 `# q
' O. Q0 i* t1 M# X" b恢复最新的检查点
/ b# V6 ]# V7 s+ v2 q4 b, o* l' f, }% v
3 ]/ U; G0 B( K9 I' m' Xcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))8 ?/ f* i' f: h' J$ K1 V4 }
六、评估模型
7 S% N6 n1 r; J& ^- N3 W这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。( v* c8 N. h. b0 y- }9 Z3 d4 I* \9 f( U
; q! [7 v& V- c; [) i5 Z+ z: K6 N1 q& F# n) J* S9 Q
# 使用 epoch 数生成单张图片
0 [9 `: F3 [. P3 Ldef display_image(epoch_no):
9 d" {% N3 m, n- i. { m- F return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))5 Z" Y4 {1 }+ _" ]+ N c M* t
7 }- w+ {, F9 }, x( R4 Y: F) ? kdisplay_image(EPOCHS)5 U% W1 I+ w0 B+ e
anim_file = 'dcgan.gif'* t( W* z( p( k" F7 y( I! g! D" N0 z$ t
# R5 g6 ~9 n, i; B
with imageio.get_writer(anim_file, mode='I') as writer:' b( u4 I q4 ]( o! j& { _% f0 Y' \
filenames = glob.glob('image*.png'): S" H8 q% [. N1 Z- i; k) \/ ?
filenames = sorted(filenames)
! R4 \( s) ~) o+ ^7 u! W9 ] last = -1
! k6 z7 A6 v- V: s for i,filename in enumerate(filenames):! X) B: G$ L2 S% i3 H
frame = 2*(i**0.5)
1 s# g# Y% p! ]! f6 k- R, D% E0 | if round(frame) > round(last):
, B6 Q' _% \3 a, W1 {/ s. y$ e5 s last = frame
# p2 V. ~% }( j/ R! a' c3 a2 q# r ^, P8 T else:) h1 z$ O8 S6 |3 {9 b
continue
8 Q/ Q, j- ^& W) R) [8 m8 S image = imageio.imread(filename)
. A5 o" U" U/ Y% A writer.append_data(image)
# {9 g( I: \- A! t image = imageio.imread(filename)
" a+ g% ~: ]) y4 y' q. z1 @ writer.append_data(image)% ?3 ~6 B" z9 p& n: A
% n2 k5 x2 l9 _import IPython
/ Q% G' g! i" gif IPython.version_info > (6,2,0,''):
) {' h N7 t; K8 }6 F+ Y display.Image(filename=anim_file)
2 m( G7 i7 }) F7 h9 F* H4 o# f' p$ G
9 F6 e" [ [* G V# |9 B: G# @) J9 T
! G. k5 J. I9 {1 Z
完整代码:0 e& A* L* r- l7 `( g
/ @- c( {/ i, y" h6 f7 I% a, ]9 U" ]. ~* f
import tensorflow as tf
0 h" z# ^4 S6 K3 z2 x) z; A2 `1 J% `import glob
, q5 A) i: ^: ]" y2 ?$ A% i; Wimport imageio
: y8 _$ P3 c2 O* U- Cimport matplotlib.pyplot as plt4 `$ o5 V7 \, K# y9 t2 v& i
import numpy as np T9 j( v: d5 j/ r8 I4 A
import os
1 r! o( y) {# n' o9 |5 Q$ bimport PIL. U% Q7 ` G& M0 T/ g( i
from tensorflow.keras import layers3 J& C3 y- J; M& s: E
import time
3 r# J+ W( m9 F
- D: I$ M' X# L+ Q# u8 D( Mfrom IPython import display
" J) _- P2 L/ o- U% Q& {; o
' ?1 y6 y7 a; w2 C( c% I(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()+ M2 C- Q5 a- y& x3 L v$ |5 j
6 C- b' w( [. Z9 a; B R' N' e
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
z; F* {8 I% Z3 u9 Q; b0 atrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
( s* M' W- d& A5 U8 p9 x9 w1 t B : T4 U( {+ U5 \9 @# w
BUFFER_SIZE = 60000
: V5 k. U; f6 v* o' @- x' e, ?& lBATCH_SIZE = 256
" q; S% f! E0 w7 {" G+ q $ F- Y( ~; |7 L2 x; ^5 j/ O
# 批量化和打乱数据
|9 m6 i+ w. G7 L) k9 B m y; S+ \train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)# |: y- b7 l8 a
( D5 A% |0 \, a, ?$ }, z: ?% G# 创建模型--生成器9 E% u1 F0 j( m
def make_generator_model():1 v& M1 @% q3 m5 H6 r" y- m% s* x
model = tf.keras.Sequential()
7 R" B! O) L' ]6 j! j b% K model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
* |, h/ Z. C5 k( W$ p; v I5 O model.add(layers.BatchNormalization())
* e, C4 Q) P8 |) L model.add(layers.LeakyReLU())
# X& _4 M% A8 i. n
; V7 C9 e7 u& d0 j: O- B model.add(layers.Reshape((7, 7, 256)))- Y$ y) I$ m) _7 m" @) p5 @& m9 G3 |
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
: B5 Y* l2 O' I. |0 ?- ~ ' M6 ~+ s O% B9 q) E: M2 S- ~
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
( m. G6 `: U0 K @1 k; I assert model.output_shape == (None, 7, 7, 128)- w6 a# j7 r* G# D& F
model.add(layers.BatchNormalization()): _/ D3 U; e9 p7 f8 c$ `* G1 Z# m
model.add(layers.LeakyReLU())8 U8 E$ I* U4 n3 K
# B6 L- N5 a4 u
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))+ _, h; T% b" [+ k7 q6 ^/ @
assert model.output_shape == (None, 14, 14, 64); H; G" \8 m" G" a( N) v3 k
model.add(layers.BatchNormalization())
7 U+ _5 v! O; {& Y6 [ model.add(layers.LeakyReLU())
1 ]1 n2 g5 d [) d
6 x- ]2 X8 |4 b# I$ r model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
" Z( P+ I; `* X assert model.output_shape == (None, 28, 28, 1)' t7 M( U$ x" T( W: b% Y% a3 r3 D
2 c: |" i2 F9 @4 D9 g# M3 G) f
return model
6 f0 R8 _0 y/ r+ B$ t6 x ! i q, ]' D2 E h" g3 R2 `, l4 U
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
" P, J/ ~# m) h% L" w2 o! }generator = make_generator_model()* T5 d+ M& _" E. e. c$ Q
5 T( H" n8 `1 B( I+ F% z) h4 ^
noise = tf.random.normal([1, 100])
. \ M' x2 X% T9 e9 Q, N) n. Hgenerated_image = generator(noise, training=False)$ v- B9 N7 J7 \9 {
7 T1 _ ^ u$ ]3 @4 e7 g: N8 L
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
" r3 p: } x1 U: ^6 n( n7 N# Ztf.keras.utils.plot_model(generator)
! Z$ n2 f. W2 p ( u0 f; {4 h9 d- I5 O( i5 r
# 判别器6 W" ?( C; a) J8 Z! `, z* H. W
def make_discriminator_model():0 o* l4 I; }3 s, a9 @
model = tf.keras.Sequential(); }. u3 ~: A% i3 S
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
% O& Y0 \. M5 W# x! S: K! Y input_shape=[28, 28, 1]))6 R; ~2 `1 `. P* `
model.add(layers.LeakyReLU())( l/ L8 G) w4 o" b5 F- M) a, p
model.add(layers.Dropout(0.3))' w( f# Q% \9 _! `7 n, _. z: s
5 R# W1 \: Z# \ model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
- y# |, y6 e+ b model.add(layers.LeakyReLU())
5 Z$ D& h1 ^# N- V s: g model.add(layers.Dropout(0.3))
" D0 H" s) k8 ?' Y9 W6 t 0 E; z+ B8 u' S9 D5 e
model.add(layers.Flatten())
" S Y) G, T D3 k, d model.add(layers.Dense(1))" m0 e- o3 l$ d" ?5 M! W
' p( ^5 J/ Q! L0 [6 s return model* \- K2 {0 M/ h/ C
% K7 n3 R7 G& ^ K; f6 k$ X; }2 Q
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。- y D3 B; H5 {' h4 @
discriminator = make_discriminator_model()
9 e' z& N/ }3 e1 ?+ t4 ~decision = discriminator(generated_image). b; Y+ C9 ]) u
print (decision)
* f" P( y: y1 ]- g: P 7 G$ Z8 P& U- m! H! m; u
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
0 w6 [' o2 v3 b7 j, H$ \8 Wcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
6 l0 `- W4 z& i3 h- j! _7 D ; C) N, \7 [* F, M2 c5 e
# 生成器的损失和优化器
9 C3 n8 [5 h8 i+ Z& P; Odef generator_loss(fake_output):% p4 Y4 h! D2 q. r. n
return cross_entropy(tf.ones_like(fake_output), fake_output)3 S! i, _+ e" C* _
generator_optimizer = tf.keras.optimizers.Adam(1e-4)6 d2 J( f! c3 ?6 ^1 ^0 {+ y
& y5 A- y [. c4 a- j
# 判别器的损失和优化器
# a9 f0 o/ h. v. V/ fdef discriminator_loss(real_output, fake_output):
8 F7 ^& J& e2 K: j2 G( l real_loss = cross_entropy(tf.ones_like(real_output), real_output)( H) [/ p# W7 b4 d
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)* Z/ z& r% W- G1 Z& l u- f9 @2 A
total_loss = real_loss + fake_loss i! I% x; t D: e6 z1 E: s
return total_loss
3 T) s+ }* A5 p* S8 d& ~discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)+ H9 w) Y3 c x+ L' C$ O0 h
1 s# t- T2 c! t& | Z0 O, ?5 d: p
# 保存检查点
$ o2 I9 {& g& q4 h! Dcheckpoint_dir = './training_checkpoints'6 a; H z5 x' @( `: i6 H; \
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")# ^5 G: c: d+ v; ?
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
- X* Z8 L" h( [ discriminator_optimizer=discriminator_optimizer,
$ m# F+ V6 o4 t) e; W; Q generator=generator,
6 P5 J3 k9 w8 |* q4 \5 _ discriminator=discriminator)
, \3 z+ B# {; ~9 g8 O
; t, Z9 @. o) y& I$ l1 ]1 n# 定义训练过程
- F% e- S/ N) g. T4 KEPOCHS = 507 v& \) ]! i- O1 j* ?
noise_dim = 100
8 `" R# T0 I; X, r; ^' i9 Gnum_examples_to_generate = 168 q2 m; e& A; t( H3 j6 ~
2 q2 Z' |2 P) g3 o7 h# y# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) D) M; V/ k, U5 y
seed = tf.random.normal([num_examples_to_generate, noise_dim])
' }$ R7 e5 m0 b( _. k/ _ + o/ T& O' i# s# p. `* `
# 注意 `tf.function` 的使用' w7 M% e7 y8 L: n6 E: {- \; k
# 该注解使函数被“编译”/ n- L) [3 G: y. u
@tf.function
; i* }$ ]& ~3 B% f; X$ k* bdef train_step(images):7 h+ T6 _; _ P- M3 M( g
noise = tf.random.normal([BATCH_SIZE, noise_dim])& E1 K4 U% C- E& Z; z! n
8 a: v8 A/ \, i6 Z! k5 x' S2 y with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:- R" m2 L4 ?9 F7 x5 m* F" m
generated_images = generator(noise, training=True)0 u- D R1 y. o2 s
1 `4 [# w& ]: `/ _
real_output = discriminator(images, training=True)
+ V* l+ a! U0 Z# h fake_output = discriminator(generated_images, training=True)$ H7 A2 o3 \2 G3 v- a0 x6 j
9 T$ S$ v0 C$ f1 x% f
gen_loss = generator_loss(fake_output)# y$ o8 `- E z6 H$ J( p! f7 b3 a# U
disc_loss = discriminator_loss(real_output, fake_output)
) _# a' W6 h' ]! [
3 {2 K! S1 r* F4 i; g# x+ i gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)6 R5 x* ^1 V( H2 l, \
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables). s. P$ S" N8 e7 I7 R7 ^
' a/ H) i* R& v+ @9 h9 | generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
r9 p# J4 h+ G0 K discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))& Q. a0 Q1 o. x+ ~) L3 A# x v$ K y
* t, @7 |6 g4 O) ^3 L
def train(dataset, epochs):* K5 k) s/ k+ Q M
for epoch in range(epochs):# S9 u" R# g ^4 T* I
start = time.time()! w, @1 \- @' |
1 o' v @8 H3 e8 Q' S* Q! F5 {4 O
for image_batch in dataset:
. S" N+ H, g8 {2 B `, L train_step(image_batch)
9 ]+ B1 C) G1 ?" r: k
& C& v8 G/ R" E p0 [/ i # 继续进行时为 GIF 生成图像
+ y) g7 U4 k$ [' ~* u display.clear_output(wait=True)
1 b* W6 U2 i; R generate_and_save_images(generator,
5 _ q. l8 }3 e epoch + 1,
+ E7 X& u8 M* A7 p v seed)0 y$ v. |0 ~; a6 |$ l {9 ?. R
0 L8 b5 t4 Q7 M& F0 g3 J # 每 15 个 epoch 保存一次模型5 V# V& F( X$ N' {( F `
if (epoch + 1) % 15 == 0:4 y% {( i3 i3 z; K% F
checkpoint.save(file_prefix = checkpoint_prefix)
/ U" x% ~! ~+ x/ w" x% `- ?
1 J3 Z: I* c6 \! K print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
. g. ~4 U3 {6 _ 8 p( j; v5 d! q7 v
# 最后一个 epoch 结束后生成图片
! l# j6 W- Z# k display.clear_output(wait=True)
\( Y" d4 K6 a {! d4 ] generate_and_save_images(generator,
* a% D# S9 g8 w$ N: b& b epochs,
# {7 s& Q& q9 g" [: Z6 v4 \0 ~2 d seed)
- c8 o. p/ o/ ?9 G; f
, L) `' C% }7 B o ]# 生成与保存图片
5 W3 @) D$ I, L: Ydef generate_and_save_images(model, epoch, test_input):
4 d/ f; R7 P3 c7 }: Z5 A; g) X% k # 注意 training` 设定为 False
5 d" T& R& Z5 S; B, }1 y$ c # 因此,所有层都在推理模式下运行(batchnorm)。
* d. N7 H( M1 i, I5 \ S( S5 G: i& B+ E4 Q predictions = model(test_input, training=False)! R, P% m+ u$ Z3 J: h& l$ E
n0 k! s8 N$ C4 R7 N9 H& c( m
fig = plt.figure(figsize=(4,4)), I+ N- o, u4 M7 P! V2 A! J
# w% K) ]- h4 Y' S, ]1 X for i in range(predictions.shape[0]):8 _& `! D: m D( D: Z% \- t
plt.subplot(4, 4, i+1)
* z( n; E) @! o f% D plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
) ]' H- c7 w1 O0 R plt.axis('off')
& d0 K) t: S$ F* o- y6 n9 F
1 X+ b) j# d0 j- c' j5 L plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)), C1 ]6 @/ k- x
plt.show()6 _- A4 x8 P1 B; R8 d" L
O' h/ |$ M4 l# H1 B
# 训练模型2 l; r3 f7 `5 R: T9 M+ Q' c
train(train_dataset, EPOCHS)! _/ g1 l& a/ e- p( Q
$ o+ W0 _5 f$ p; q$ e* `1 M# 恢复最新的检查点
! _5 k; l3 Y/ ^8 h' lcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))( B. U" P9 j, n1 H/ w- s8 Q$ K
2 N4 L6 M9 v2 d! G9 L8 l& A/ z# 评估模型% \& { A& b$ A$ S
# 使用 epoch 数生成单张图片1 T N% u7 _( \, V" U0 P
def display_image(epoch_no):3 m' V0 h4 B4 |& o5 G
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
# Z/ F4 J. I( O' W) j & A) W+ a. t7 Z1 v0 ]
display_image(EPOCHS) f3 z( O3 P7 L0 r% `$ ~0 {+ X
s7 r( _% D6 l8 n; S5 @
anim_file = 'dcgan.gif'+ b( r. H8 o6 H! v0 D8 Z
9 t" z# Z9 s- ^& V. }/ A* hwith imageio.get_writer(anim_file, mode='I') as writer:/ Y/ \) m( U3 _, E$ U
filenames = glob.glob('image*.png')/ `3 ^9 N# e$ D# I
filenames = sorted(filenames)7 g6 H7 d2 q- C9 }, @
last = -1/ G+ p3 R6 V) {
for i,filename in enumerate(filenames): E" A0 \% j) G; o3 Y5 i
frame = 2*(i**0.5)
0 K$ c) l" A% S* j" ~5 _ if round(frame) > round(last):9 x5 C, ~6 [8 E, k; z; |
last = frame
/ P8 E5 c+ M8 E' I else:
- ?- L E( E- R: T continue
/ v+ n2 T0 _+ Y, }7 U image = imageio.imread(filename)( i% V0 h" q) @! ~
writer.append_data(image)
/ d" l$ O; @# f! j+ o image = imageio.imread(filename)/ t2 m: n2 I9 h* J7 F
writer.append_data(image)
0 z' O% I: N! [ B3 R, ~& n $ e1 f6 d% d. H8 V7 z' w! p. l: J
import IPython. E# ^6 p* T+ q+ T* `
if IPython.version_info > (6,2,0,''):$ b. k2 \: g5 N8 H6 Q7 F/ G. P. i
display.Image(filename=anim_file)' W+ L# b' d0 M! t9 N7 E
参考:https://www.tensorflow.org/tutorials/generative/dcgan
: t5 h* J( d G- R* P$ W0 `————————————————5 F: [' d9 K, n8 I( C1 E3 v ?
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
# G9 d' S1 _/ n原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
- l9 O. p" l' o: Z9 T! a3 y5 Y4 {
$ r9 J1 l7 g& X& ], ?/ G1 U
0 W( H6 B7 ^* u" h8 X2 I |
zan
|