- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 559672 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 173274
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 18
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
l3 ?/ C8 D& r深度卷积生成对抗网络DCGAN——生成手写数字图片
8 s# h3 [1 v. F4 `' M; C2 O1 ~6 t前言
: @* X8 a. E( @* G) b: a本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
& s A$ T# E0 _4 C3 ^
; X7 `! k' }- |. u) Q; ]. k3 S. o9 I" i5 n1 Y3 u7 h( F9 s
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:4 F7 m+ U+ O) S! V0 L8 W
" ?7 F2 k! I0 d! D4 }
: t8 ^/ D# X- a( w# 用于生成 GIF 图片
: y: {3 C8 I, k" apip install -q imageio
% C# W1 U* E/ X目录' W4 n, s: |7 D9 P) x6 D$ k
+ Q8 C" t" r* J4 v/ _* o! d$ B. R. w h) _/ h) G; F
前言6 ~9 c: [! a) G5 P) r4 T
( @/ v6 r7 P9 K
* G, A, d# L$ c. L, S" f一、什么是生成对抗网络?
0 C: E8 D* K6 L3 u
+ ?9 _9 D! l) W- j
c, u' t2 C* e* o2 V: y+ b1 v( k% |二、加载数据集) g# i4 A# B* n6 e0 w6 G/ y' S
# S" V) p) b% y& N4 q2 ~) y
% U) b" c& c2 v6 T, r3 b$ D三、创建模型% \( e3 `" t5 d' e
+ k8 m+ F, R$ \. Z& V2 n6 U& v2 B9 k; V' u9 @- Y1 u
3.1 生成器3 Y) Z) L4 I5 S5 Y
, O0 u( D0 O! T+ U* H# l4 H
- O& `* P: ~, l7 f9 M
3.1 判别器
) k. R+ X K8 l4 i+ {
% M0 J; z4 }" ~, X! s7 U) A* n
# K+ K! c/ _1 @- {7 Q% ^/ R" K四、定义损失函数和优化器7 Q) y0 o9 a+ M
# l: S# {4 m" X( X# K/ n. f
4 c4 \; J i M; l% C4.1 生成器的损失和优化器6 o, h/ A$ Z- T
3 @4 O8 C) Y, J3 ]0 p
1 Q' l, |# v1 z4 x6 M- J" \6 c4.2 判别器的损失和优化器9 \: ?8 ~5 k# j% l6 B+ J4 E. \
7 e5 W/ ~4 k; M/ m4 q$ q% W( L
+ N! b P! \& j( A6 C5 [
五、训练模型
8 g9 N5 F6 ?. ?- I9 e
( \' U1 z q: `1 T; K2 U2 D
4 Z! n8 g* G/ s8 W8 p/ S$ a5.1 保存检查点" k+ t: @/ J4 E
" x) [) A) d3 r
) p) Z" e- Y( B/ n& f/ l c( ?5.2 定义训练过程4 W- s" `8 ?0 m( Z3 S% C
9 I& J( `: V! K
/ n5 x; R1 v& t+ U1 n* p. H) w; Z5.3 训练模型) }' g* e9 U2 X3 ^6 T6 F
1 h# N9 d* \! P0 K# r* _1 |+ C* \0 d; S% x) i0 j0 O* \0 J
六、评估模型
9 C: c' p9 ^' B3 ?* t w; B7 J# p- F; i; d
2 ?, q# X* G @# I, w" T一、什么是生成对抗网络?# r$ U0 S# s" x) }, U _
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
# {, \9 ^6 T5 X* G) C$ e" I+ N2 n; d; p3 R2 r g" R
9 C7 t, T- m |生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
2 ]+ m: i7 W' b8 S
# [$ q- V8 Q/ S4 h8 z
- \; E1 L! J- u/ o4 T- r1 I1 Y/ N, \判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
* v( p5 x5 \# S/ W# f
) g3 v# _' J2 H; v) J' L$ }5 W
" b1 D9 v/ o, j! G/ u训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
# p3 m( e3 L0 h8 y$ }+ A2 H6 c. c/ _8 w- V# r
6 p% y. G& `3 O! J' m7 d( w0 m
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。; E1 M0 p2 t1 m! |! m+ P
% F% Y7 S. l3 C$ S
, g0 }3 p2 a T9 G$ S1 s本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。# i1 x* q) o4 }
+ Z3 _( r# a# E8 ~% R2 H- Y
3 K# D4 }8 J( `& a; r二、加载数据集, ^" o: O+ y; U, u
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
- F2 o# B" ^+ z- i; y) ]% C, r9 b1 s f. U4 F
- y: o7 f' t$ r" w(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()+ H; f: \& ~/ A( r, q! T" y
) s4 g# A- A7 S/ R1 T8 i- k
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')4 ?0 } k' A. i8 K Y4 _
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内4 N) H! V7 `- C% J. R1 J0 Y& U5 P
6 w4 d% V9 B2 JBUFFER_SIZE = 60000
- P, Y: x5 H" M$ XBATCH_SIZE = 2565 D! j3 ^4 `/ l) k
# U3 v, n, K7 @) g l: Q# 批量化和打乱数据5 M3 g( R% f0 _# R# {0 x
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE). h/ N7 k1 Q- N' X) @& ~
三、创建模型
4 g* A: `; i( q+ K主要创建两个模型,一个是生成器,另一个是判别器。$ P; d# J- x' A- e% L1 x
. @7 s: i( g- P9 p1 O; M% i% {
8 u' S; V, M% X$ X e0 I# l8 D4 \& R3.1 生成器* C2 P2 q3 X4 }" M
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。* t+ [. U/ o) \% I: X" y
# ]+ X' G# ]+ J$ X& t3 @ x# Y, H4 N$ y; z/ Q( b7 j1 m' k' Y- c2 X
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
, _: j9 x. p- }9 A
0 d: f7 h' h* z, z# i- {* h0 S2 x) S' B
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。$ F$ O7 C9 l2 n+ U2 ^: R F* N
4 E) w; k+ o" [6 Z* D) N" p6 x
- _0 N/ C% n0 x1 B. m& K( b2 Ndef make_generator_model():
4 M4 q3 e5 D! K; G; i0 w model = tf.keras.Sequential()
" m. G% y6 [0 N- I model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))' Q8 `' c& L) l( G
model.add(layers.BatchNormalization())
4 X7 t5 [+ |. T) L* u& r8 } model.add(layers.LeakyReLU())
: Z) c* X# Z& W% ^4 a1 b) p 9 _5 b1 u" H! K& l% p
model.add(layers.Reshape((7, 7, 256)))
; o; W7 r' p! r: S assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制( W1 p# \$ I- K1 T5 j9 @, v
. {/ ^. H0 t6 D/ A) L model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
3 R, M3 @; F4 R. c( H6 q0 \8 p7 I assert model.output_shape == (None, 7, 7, 128)
" Y7 Z* {/ k; C- }1 n0 ~ ~ model.add(layers.BatchNormalization())
& `+ k% K1 M$ E9 C model.add(layers.LeakyReLU()), V! h% O0 N8 z8 X1 ^8 m+ P
4 n; L, l8 l+ u' I: O model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
# g7 F4 Y" h1 R) e5 E assert model.output_shape == (None, 14, 14, 64)
( U) ~4 M* E( ?1 Q [" F2 `; j model.add(layers.BatchNormalization())
' O) K* ]7 f1 A4 A' s' q5 J model.add(layers.LeakyReLU())5 `# P) E0 d. I5 _
n4 n9 G( T! ]4 g# K8 h model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))6 s5 A3 M* ] c. t# {8 T$ C
assert model.output_shape == (None, 28, 28, 1)) P4 V5 }, L- a M; n1 c* K: t) S
( @8 f: n+ p! @4 p' ]* A return model6 v. R2 p0 u+ w4 a$ v4 a
用tf.keras.utils.plot_model( ),看一下模型结构' J, x! W$ K3 c! B
; n9 k# ~' [3 P* B0 \5 _- x8 ]9 t0 m. W( b7 D$ A( l8 Z2 T
- _+ \( M* j" |* t- f
- N% |- ^* r! N
8 U: H- L0 H i1 `
用summary(),看一下模型结构和参数( f& \0 G7 b2 A- e3 y
1 ~4 m! e# k- G2 W0 S( T/ x3 M* I0 l; Z1 I- J: F. x0 V
|3 l; v" t& t S+ ~# u% J. Z3 ]2 [- `
; ~5 Z5 y) x3 m4 J: H) |- {( a
5 g3 Y+ j5 D w% @" Z
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。4 j# ?7 M8 C* [, d: n
9 N5 y: w" [6 u3 Q4 m
2 Y& b8 {+ o5 H$ w0 o
generator = make_generator_model()
% b5 I5 ] y0 X. [1 j
# c4 D6 C6 o! m& Y/ ^( `noise = tf.random.normal([1, 100])
8 l: X h5 @/ e3 C+ d/ e9 Rgenerated_image = generator(noise, training=False)
7 f1 z2 K6 Y- N% ]
1 W$ W. L! Z6 \6 rplt.imshow(generated_image[0, :, :, 0], cmap='gray')) z6 i( W$ Z5 {/ }% ]- Z
- A5 M) [, `5 b# L2 p" i/ E" a8 O, x
/ c7 ^9 Z9 Y, A* _; x$ X% q9 ? j0 D' ~7 I+ |7 z+ F* o
3.1 判别器
e. L! t+ d3 w' x/ X判别器是基于 CNN卷积神经网络 的图片分类器。+ @6 z0 W1 s& I) K% m: Z
1 q$ Y+ V4 O/ I" H- s9 I( V6 p2 X7 u# y4 L7 g
def make_discriminator_model():
8 K6 b! d# [0 ]- [ model = tf.keras.Sequential()
5 l& @, \ g, w0 ]% X# p7 F model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',+ Y- [; @! }4 C3 b6 v7 g6 U* ~. W4 w: _
input_shape=[28, 28, 1]))
& E( w# }7 |( t- q: b7 u model.add(layers.LeakyReLU())
8 w# A- n! U7 Y" Y: R4 ~ model.add(layers.Dropout(0.3))
* A2 Y% @- Q; _7 q% P1 g- _+ p
$ W6 _7 K: E" e2 f+ I! o; U; q model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))0 P- n1 [7 C- Y3 z% `: s, i& i
model.add(layers.LeakyReLU())
0 ?9 L4 |6 |6 |3 w3 g model.add(layers.Dropout(0.3))* ?2 x& @7 v: |( c
t; V$ [8 `' W model.add(layers.Flatten()): L. f$ k) M- E' o) w: n z
model.add(layers.Dense(1))
$ P: k! Y7 I) {6 ~$ p
$ O4 ^# b8 V! v1 d" B# A return model
o+ E5 u( K1 a; G, E7 A用tf.keras.utils.plot_model( ),看一下模型结构: V. j1 ` m4 m6 g, j2 B
& e1 q }( A8 P. K: t7 J7 H% n
5 V! Z, z, g5 C- c6 ?7 F0 k" l% I( H3 ?& f
9 |- w2 G5 U o7 U) R2 x! Y C
1 p2 Z4 q/ w. g8 s& ?" N2 r$ f% k
; ]8 a! B# |6 h, P用summary(),看一下模型结构和参数
7 t/ @5 s3 I) ?5 m, M$ {! S( s" G5 M1 `* Y; _8 [
0 c$ [) J! p% {' s* t* c& K) l* M: t3 H
$ V& q# U- _* e; ]4 y3 `/ ]' O
, J6 x/ Z) ?! }4 N. e- {" R, M3 {+ Y, {
四、定义损失函数和优化器, f; Z) i7 L! ]4 k: E
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。
2 m$ ~7 x3 K% |6 v. N4 Z: I2 E
& b P( G, \. @) e' k6 B! G V
3 q- s$ [; O/ C& b! j首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
8 f0 z8 r+ \, { ^2 ~# z+ `+ B8 s8 a7 P: O) f. G* y
% A3 R" M5 L# G+ K4 \7 T# 该方法返回计算交叉熵损失的辅助函数
- c1 C* O) Z) w! u1 ]/ r* K! ?1 dcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True). f# z7 G2 k8 k" L* t2 E
4.1 生成器的损失和优化器
/ e S- a# |+ t- C w1)生成器损失6 q" j9 j; T4 u* I# ], K; O* M
4 j/ M: E; Z5 r6 k: p5 A/ ?) {3 L8 s! N
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
" B" I3 S8 }, v% ?8 _ d$ ?5 O7 ?( `5 S R% X. \9 ?! Q
% c0 f* R! e; g/ K" y这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。) b$ s, c5 U, \2 _+ L+ X2 O2 h
! l% A# F+ D' Q, Q ~( Z! C
u: m2 D/ U# L# f( kdef generator_loss(fake_output):" O* t I& U% ?" l l8 n0 j$ Z& K
return cross_entropy(tf.ones_like(fake_output), fake_output)9 Y& D8 [% C" U0 E6 b
2)生成器优化器
; h7 P. V; b/ |' i: p9 [* J
! A, [/ c4 w8 a" u2 x
, z& K9 E3 m) V+ ]* Ygenerator_optimizer = tf.keras.optimizers.Adam(1e-4)- [' N. S% \" @% n
4.2 判别器的损失和优化器% B6 b! b" r! [$ r; W- t, V/ B' o& `
1)判别器损失
# z2 }$ @; ~: w. O1 |" z8 W) I
~9 v @+ `1 J$ }
3 t$ C3 \" U& I, z! N2 T判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。) e3 x5 @9 s/ U( r- n4 {
5 b. C9 K$ ^3 k/ ?: g# \: v( I1 y3 A/ W9 E7 I, M
def discriminator_loss(real_output, fake_output):
: e, \8 f. s8 P/ l real_loss = cross_entropy(tf.ones_like(real_output), real_output)$ T6 X: o' P4 m' V8 d8 \/ M
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
! O% ~; {; `) d2 O8 W" A) ]5 z total_loss = real_loss + fake_loss* y7 W @3 k4 P, _
return total_loss/ v0 O2 C! I, {
2)判别器优化器
) {6 E9 F3 ^+ l; L
: I- E% w8 i/ V5 { T6 o z2 u, r- [$ V0 ?2 J, H( n% s( o) V
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)9 i0 @/ x! z+ H) S+ ?& X
五、训练模型3 q- Y7 @) [7 _: o) J
5.1 保存检查点
1 E/ w- R% E! f- P* @保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。, g3 S7 A/ z8 z! O5 j& ?
; Q q0 |0 N5 q# U ^2 [
# r C( a0 d; { t* W6 Ycheckpoint_dir = './training_checkpoints'
+ L0 K1 x7 f) N, u2 ^. p: V2 Tcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
- u/ Q4 X e* T7 |( Y: t4 Lcheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,* K" b, u$ o. F0 Y8 F
discriminator_optimizer=discriminator_optimizer,
8 t& p( g6 V6 G' u$ h- n$ C generator=generator,6 ]& v2 o" W8 ?3 S8 s- e
discriminator=discriminator)
3 A( {5 \8 V& I5.2 定义训练过程
" ?5 L1 t1 D& L0 mEPOCHS = 50; I* [) }2 L6 ?, T6 G: a
noise_dim = 100
" ]1 F: x0 v1 s3 C% Znum_examples_to_generate = 165 X4 z8 k) ]1 C8 \
# L, Y& |4 d* w3 r3 X& d) H* e! v : D/ D, G' @$ b1 T) b/ K. p9 }
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
. _6 G# h6 T( F5 ^- z8 M- c6 |seed = tf.random.normal([num_examples_to_generate, noise_dim])
) i% t6 O& r) l2 h训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。' b- W7 C7 O7 d- C+ P [7 `
& S9 b1 A4 v* `) ~! F
# z3 ]" `# t. V e
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
2 B9 h$ T' K7 ^
( u2 a5 l* S4 p' H8 u' ]+ x+ t- I7 `
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。8 x* |9 q/ y0 ~ H
# |! T+ z! }. y0 [, k
) P0 U( s K4 t! U3 Q2 [% s# 注意 `tf.function` 的使用/ Z8 o& C M& K5 V0 g
# 该注解使函数被“编译”
& n% \, ?1 y8 Z0 S6 K. R@tf.function1 u- c7 C0 e) F: J
def train_step(images):
+ R m' V6 Z4 ?- \& E0 V. p/ Z noise = tf.random.normal([BATCH_SIZE, noise_dim])
6 D. I3 w3 p6 B$ A1 Q6 `8 p $ r" L }5 Z# A! o9 e) o
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
# S! {" j: `& i# B( f7 H generated_images = generator(noise, training=True)
; s- h* R5 n4 ^
) C. [% @2 d& ~2 O- s real_output = discriminator(images, training=True)
7 p" W8 i0 E( U' C: |* @) M fake_output = discriminator(generated_images, training=True)& r7 p" |& b9 Z) h$ _& }
% @1 \" G; i8 `+ c p
gen_loss = generator_loss(fake_output)( g# K2 A2 N% ]. J- r1 ], P
disc_loss = discriminator_loss(real_output, fake_output)
# n8 U8 m5 T6 _* x) z; M/ D; R
' h9 W {' y0 l' i gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)( j# H; K$ A8 Y/ c5 Y+ z6 Q$ G
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)$ S J4 D9 w9 e" K
+ Z. Y( s+ q* s5 f, c4 @' C generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
: Q8 B3 T. D9 _+ B. D$ w7 V discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)); w, R3 o7 h7 P$ Q; s% P* U% w( s
- Y6 x7 J% p( e
def train(dataset, epochs):
: \1 Q0 u1 w7 C& w* k* T8 F for epoch in range(epochs):" E. E! M1 p/ m# m
start = time.time()
0 I/ l' u- g2 X9 @) \" e
" ]5 N! }0 u$ U7 ]5 o# h+ e for image_batch in dataset:6 n4 m, f- u0 B8 h0 A2 w8 w6 x
train_step(image_batch)! @2 ^9 m/ j; C! ~+ Z) J
1 K# H3 _2 d% a0 D
# 继续进行时为 GIF 生成图像
! y5 h# {. r# X7 X! {# I! Y# \+ l display.clear_output(wait=True)# z1 t' Z' L# l' l+ x1 H
generate_and_save_images(generator,3 n; C) X5 ^0 r9 F+ {$ D
epoch + 1,
: y5 Q" L8 E' H! |) ?' J seed)/ s6 T9 C" J+ P4 Z
# r2 Y7 s# V0 ]' b # 每 15 个 epoch 保存一次模型4 W/ A7 W! y1 W. N4 B( Y4 E! A& ?
if (epoch + 1) % 15 == 0:2 ?+ S% e; L: [; E
checkpoint.save(file_prefix = checkpoint_prefix)7 m, p/ g9 I" X6 k: u) p
+ `& y& l6 ~6 [/ X
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
" V8 @/ q. o+ n! J
. X5 S$ N. [% [+ c W& y # 最后一个 epoch 结束后生成图片: J; O1 j) o. u+ g! k
display.clear_output(wait=True)
+ ~: i! [% e% C C4 I, U+ I# x generate_and_save_images(generator,6 P; E5 V% Z& L) W! Q1 ]" C0 v
epochs,7 o2 a3 p) r/ c
seed)
6 L7 X; Q& b2 B0 g
7 ^2 z2 R* b' r# w- k$ c) [8 |# 生成与保存图片
: R* @$ T2 j8 i) g: [: b- fdef generate_and_save_images(model, epoch, test_input):" L2 L; x1 U# ^. J3 f& f* J; n
# 注意 training` 设定为 False
4 i! g; G a+ C" Z" | # 因此,所有层都在推理模式下运行(batchnorm)。
\+ h6 a% d+ q4 C: p9 m: |& B predictions = model(test_input, training=False)/ q% R; B: ~: \) @
5 M4 W) b+ F4 l8 L fig = plt.figure(figsize=(4,4))
4 K0 p+ |2 j' g9 j 0 u, M7 f. ?! z* v- i
for i in range(predictions.shape[0]):
1 R, `/ I0 B7 n3 S plt.subplot(4, 4, i+1)( k+ V( i+ C. I% t
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')* c4 I0 V. R5 K3 z( ?
plt.axis('off')2 [/ k# Q9 U7 U7 _
, n( e3 r. V) R! I plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)): S0 A" f9 ]) A; |9 M8 ]8 b9 h
plt.show()5 f. ^+ X2 p# B" v
5.3 训练模型 p. x% U/ L( ?- @( |& P
调用上面定义的train()函数,来同时训练生成器和判别器。
1 L3 j4 x/ R! P. _! `
+ m( X. A2 |# x6 v+ g8 |7 q S+ s1 N5 C- p6 L
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。1 S4 W8 q3 S1 A* @+ J& E
, R& }- ~ i; i
B5 {. e. f4 U5 s%%time
* A/ d/ H2 n" @6 l1 M7 Y6 w1 X) Jtrain(train_dataset, EPOCHS)' l( M2 f/ l0 a* i
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。3 v9 }$ \0 S8 s$ s4 T; T3 ]: ~) H1 B
9 N1 n5 O L5 V9 }1 {
5 \8 v5 m |* J' i: ]
训练了15轮的效果:9 z) J! Q6 M5 H3 C/ E& `( b2 p% u
8 J% E* O; x6 {3 |* [/ T2 Q z9 P) o( e7 D- Q: Y+ e/ A7 q% @' N
, x0 R+ r7 h/ R. E: d
: V3 \# y) F- \" j, }
. _' \& f0 Q, i; n- Z5 `
# \' \4 V" W5 F/ E1 H- Q
训练了30轮的效果:; V$ {% R& ]" k- D
7 D, F P+ N# _
# ?! F. X' I) a3 \0 [* f* G
" O1 t9 g# q7 a$ W" h3 l) B: E% B& M! N0 V
. \& d. |' e) m6 D! C
3 d8 U* z+ Z- P, s" h( x6 B2 P训练过程:* _2 \ N! V; T7 _; q: k! f
& w% B1 X- ~( T4 v2 b
7 g4 Z0 c( O$ y3 l0 _% M" ?7 @' U
8 G% V$ ^! b4 D
4 b& v7 v8 h' B' c: W
7 ~9 p9 G% ?7 m7 H, T4 A& u& _' j* w, v
恢复最新的检查点9 x% s, d# o+ T' r& f
0 @% z6 f m# w9 n3 T3 V1 t0 A) k
! ?& T0 b4 Z9 [" ] W. y/ Tcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" W" Q2 w2 S) h2 `: Z! `
六、评估模型$ R& ]4 J; D% u! [% m
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
- b$ X7 b, a4 W0 U. v \; M! [, I* F
8 w$ \1 _9 C7 d$ z' R# 使用 epoch 数生成单张图片; R5 l4 R7 |5 M) p) i3 q
def display_image(epoch_no):( G2 d0 o @ P9 ~& @6 n% e# u) r. i" `
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))/ M$ R: }" ^& R% y/ v, l
; q- {* e4 v7 d( j
display_image(EPOCHS)6 r& N& o- G8 y: H4 X
anim_file = 'dcgan.gif'# n5 l( N7 x" L; v8 y4 s6 h4 d
- v" E' P* D2 Q3 s" h# s
with imageio.get_writer(anim_file, mode='I') as writer:$ g& v* p% o. d, B. Y3 F7 x
filenames = glob.glob('image*.png')
; t, n% ?6 Y- _5 ~4 B1 J3 g* f filenames = sorted(filenames)2 B2 j. ]# ^$ ]5 D& w6 T8 t
last = -1. p! G8 ^$ O: o& {; g4 a {
for i,filename in enumerate(filenames):- x/ B9 ?! p A
frame = 2*(i**0.5)
' D8 x9 ~5 ]2 [) u8 @- l if round(frame) > round(last):
, p7 Z4 p3 \" h$ w last = frame) l/ z( \5 I; Y
else:$ j; x/ z0 N Q# l
continue2 r; q* k9 X4 d2 d, K
image = imageio.imread(filename)
7 J/ {7 o; L& p writer.append_data(image)4 y, [& }( C A" B
image = imageio.imread(filename)
# a) l8 F, {9 t) `) p% e writer.append_data(image)& M# c$ d* P* E/ }/ P
+ b# D# g! ~; C4 @2 q( L R2 @
import IPython
, |& F# E& v( oif IPython.version_info > (6,2,0,''):
! D' f4 c! n3 @% l$ M% @+ T' E display.Image(filename=anim_file)# l6 O- h2 \& c# K
5 M: e4 ^0 q* D- m
, m2 z3 ~6 ~. \2 }$ `
% u9 W) t l, J" a2 [; d6 | E+ t
完整代码:; e" p2 p* M( N- K% Q/ X& {, \
2 Y1 \# g7 u: Q* h" z8 S# j9 \2 _4 ?0 p f# [, ?- F9 B
import tensorflow as tf4 S9 N6 d% E/ `) t8 O9 [5 g
import glob
" |. b8 X5 [: I( O0 f6 ^- \' Bimport imageio% N$ @ P; O6 [6 Q0 N( ~ d
import matplotlib.pyplot as plt8 t1 i6 d) ]3 _0 c
import numpy as np
3 Q- Y% f+ [, |9 kimport os8 Q; M( _7 ?. _9 {
import PIL# L* p1 H: @" ]. G$ ~% J0 a6 ~2 ]
from tensorflow.keras import layers
' |) A6 h1 a' ~2 n$ K5 Y A( ?import time( n) I# a' j: X( K' k
4 t- O0 x, `2 U. x8 o% Z6 ~% Kfrom IPython import display
; O# U; S4 ~% h
% n! N+ k# }! Y# l, |(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
$ {" C# J2 |7 N+ {0 E/ @8 R, \, W 8 a, Y6 X* ]) |- A7 t) u3 G
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
4 T( j/ C# L* L5 C/ w% \train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内" p' E0 {4 R0 U- @9 {
/ N4 f3 S4 b9 D& ^BUFFER_SIZE = 60000" }8 H, ^2 ?. L6 x0 a
BATCH_SIZE = 256
. L/ k6 e. ?2 L- w: ?
5 N2 b" h2 v: o0 N* r* n/ a# 批量化和打乱数据! ^6 O t: n9 ^3 t" V* p: }
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)# y/ w. K2 |2 Y! A
' o! Q; S/ i- ^* I: D
# 创建模型--生成器
" X# U R9 v* k- W" V9 W$ `def make_generator_model():
" I0 T1 X; x: u2 s model = tf.keras.Sequential()
6 @7 @" U4 _6 D: ~9 N- _ model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
E* P) v1 {2 U model.add(layers.BatchNormalization())5 L( p1 G; Y# c! t( R [
model.add(layers.LeakyReLU())
, r1 `! P. U1 s+ M+ j+ q 1 _6 Z7 f# h( {1 M9 U
model.add(layers.Reshape((7, 7, 256))): q6 H3 n9 S4 L0 N. C2 y Q
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制- a( M& P% N2 w9 `4 R; Y
$ v' O& W' n3 p" m, A4 b
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))8 _0 n8 C+ H4 f/ c& U% L1 }
assert model.output_shape == (None, 7, 7, 128)
: r' o* U m3 g# _* J3 S0 l model.add(layers.BatchNormalization())* c5 W, B0 Q8 w' L, i
model.add(layers.LeakyReLU())
& m. L" G1 E+ S* u- v
1 D+ p7 t4 f5 {7 n model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))9 N1 C& q) ~; j4 p! ]: [! _
assert model.output_shape == (None, 14, 14, 64)
, u0 E( R( c7 F1 T2 F model.add(layers.BatchNormalization())
9 D" Q3 m# o) [( q( r: Z model.add(layers.LeakyReLU())1 U/ R! L: P" s
9 f: B: B7 r7 w
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
+ ^" {( g( R$ x0 N2 p5 ]8 a+ ? assert model.output_shape == (None, 28, 28, 1)
- e8 h* j1 A+ s! S$ _ ! a0 i4 V A' T6 E& A7 {0 a
return model
; A B3 o& D/ Y% M# N ; ~4 B+ t8 T1 v% i% G$ F
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。& M+ s# X( S$ X% t. [
generator = make_generator_model()
9 Z( ]3 n) X9 s; A
, S. Y7 m( B* P8 _noise = tf.random.normal([1, 100])% d2 S s5 k' H, I4 w" K: k
generated_image = generator(noise, training=False)( ^5 i' D" `6 h, J* P
+ J" P4 s( l _# U+ F& u7 T fplt.imshow(generated_image[0, :, :, 0], cmap='gray')
) y$ R) ~0 H5 B) p% |tf.keras.utils.plot_model(generator)
3 I, b0 W- E( p: ]' ?' a/ g, P 5 v% j) [+ {5 I7 |& [
# 判别器* x" p2 A' t3 Z& o# ]3 u- C) L5 a( }5 l
def make_discriminator_model():' ^" y4 L, y: `
model = tf.keras.Sequential()# \" }7 N. [8 A& X0 ~9 j+ w: {* g1 i# T# T
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
0 _$ b# E, m; V6 ]- \4 ^ input_shape=[28, 28, 1]))6 e: [" K' X' J1 b# s
model.add(layers.LeakyReLU())
" S$ M, w I# ^7 q! I8 x model.add(layers.Dropout(0.3))
& y. u6 N6 i: _% S- H
& s5 F" n1 A2 [9 t& j4 S- A model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
' Y) r8 J# }; v) c9 {$ Q1 D model.add(layers.LeakyReLU())9 N: w% Y4 m) ?
model.add(layers.Dropout(0.3))/ s& C% n# k) c
6 v( P4 d. s ^! h% l9 V model.add(layers.Flatten())
' U- s5 E5 A' [7 } model.add(layers.Dense(1))3 B @8 s/ s# M2 i
5 R) z& g: b+ L$ ~8 W
return model
' {4 G" A! @8 a9 k! b . L5 M J9 B" D
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
4 L' E( c( P5 k n6 R( P9 B% ^' u7 hdiscriminator = make_discriminator_model()9 M- q. U5 ]% F4 W; Z2 J" M, y3 I
decision = discriminator(generated_image)
( |" ^; S; x( \7 Mprint (decision)
, X% p: |9 p2 q/ a6 R5 L 6 o/ f. {# D- f+ Y, Q- u1 B
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
) \5 `' {7 L% K5 s- ^: |2 fcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)* u. `9 |3 j5 h+ v9 v+ w3 m. l/ T
; N% E1 B1 f8 i4 V/ B# 生成器的损失和优化器) k, w. C4 O9 l* A* l
def generator_loss(fake_output):" s9 E9 ?% ` N& o
return cross_entropy(tf.ones_like(fake_output), fake_output)* Q# D( Z. ` T9 r% W W
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
; S, U3 i/ Q$ q
$ t2 b* v" t ^' Q# 判别器的损失和优化器
$ D; Z% Y5 m/ ]% fdef discriminator_loss(real_output, fake_output):$ f1 p/ e) T9 Y9 Y- L8 j
real_loss = cross_entropy(tf.ones_like(real_output), real_output)( W0 z: [" l( z& {/ R
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)3 E! @8 q9 w( l! k( w2 J ~
total_loss = real_loss + fake_loss
: Q$ K9 e4 I9 x) A; S return total_loss3 i! i2 Y3 y7 F9 b
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)5 Z6 b& b8 y9 n! U
' C5 ^ h8 j1 Q! A4 w$ N/ i# 保存检查点( X) r0 W. N' g, ]
checkpoint_dir = './training_checkpoints'5 Q1 B8 S. ~4 {$ [; ]
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")! {4 B$ m( o% Z$ q, n+ Y' f$ x
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
+ c6 r) x! {7 ?6 c& }% v discriminator_optimizer=discriminator_optimizer,
% |% l: Q+ d- A3 D; ?8 U generator=generator,' x5 b) m( n- B
discriminator=discriminator)
' w3 R8 y0 \+ A
. G* Z3 ^! u' p! P# 定义训练过程- l [8 Y1 t$ n; ^# k
EPOCHS = 50/ c3 W2 a+ i+ {2 t3 [: E+ x$ d& b
noise_dim = 100
/ w# J9 `# s) l( _9 Xnum_examples_to_generate = 16
! Y O( x: C0 P0 F+ ]
t% f2 n- |- f1 q- y7 m! y ?# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)& ?. v! o. K2 H, r) F& \! J
seed = tf.random.normal([num_examples_to_generate, noise_dim])3 `7 l4 k0 {3 G5 G
( b. `: O* r; g0 x0 P
# 注意 `tf.function` 的使用
; J5 @) n6 i% O( i/ z3 t9 P# 该注解使函数被“编译”, w) G5 m1 m# C# r3 T2 p
@tf.function* E! ^2 h- R( W
def train_step(images):5 @! m9 |9 m7 A9 g
noise = tf.random.normal([BATCH_SIZE, noise_dim])% @" u6 I" m" n& y" m4 j
! f6 ~ D; p' o0 E( r/ Z
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: E w! U- ~- r; y+ I. e
generated_images = generator(noise, training=True)2 ~+ v1 H: c( d2 x H) l
: B& _, K$ `% J2 w/ v2 D, C real_output = discriminator(images, training=True)+ ]6 [, j9 }/ }" B9 w0 R* b
fake_output = discriminator(generated_images, training=True)
6 @6 n: C: @* O7 ^* F 1 g; b& c: g$ F
gen_loss = generator_loss(fake_output)
. N2 D- h, E& ]* ~* h disc_loss = discriminator_loss(real_output, fake_output)
- p. d' x) ^. x: b- \) c. q7 L! W ( x! w( [/ P, G$ A. h9 ]' r
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)8 x- [- @" \3 t, `
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)/ U( [4 Q) l3 z; t2 ?' T
0 X: X+ q3 p$ K: J$ E generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
) x0 W3 a- v& Y: o% |) d discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))6 @5 T% A- ~5 N* x" T: c5 p
; v: j$ L, }2 \! ^4 {def train(dataset, epochs):
3 s, X) B/ }- [4 ]# X% {9 |# ^ for epoch in range(epochs):
, V# S1 s& j! T) J. Y+ t2 W- o3 w start = time.time()2 I' \; Z$ {2 A8 v! q
- B; F9 N# s" \8 {/ c, k4 W1 U
for image_batch in dataset:/ i7 t* {$ v+ R; H6 b+ u9 E
train_step(image_batch)4 ^) P: H& h" J# r- k- P
- G, e8 n- K! ~
# 继续进行时为 GIF 生成图像2 U! A6 R b2 j/ e, q
display.clear_output(wait=True)
" f6 L5 _- I3 {* k% e7 [3 B$ t generate_and_save_images(generator,; \; b8 p4 f9 Y( F: w, y2 N
epoch + 1,. u$ |0 l; n4 m( F2 d! b+ d1 h
seed)
. h5 T- G" y) {# b7 Z9 T . _8 F. B8 D9 S" ?# k) \
# 每 15 个 epoch 保存一次模型
& Z* [) f# f: E/ {" { if (epoch + 1) % 15 == 0:! }- j3 O$ Q2 w/ e9 R. g' V: E
checkpoint.save(file_prefix = checkpoint_prefix)
4 U# x1 b$ J& L/ B3 p3 U4 A& c
@+ q, T( C+ b! F5 Q print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
: H. A0 z5 I, S7 ^2 ~! q3 A
! h% ~% ^& _" l6 T% X0 M% N, f # 最后一个 epoch 结束后生成图片
d0 h' C: G6 k, B display.clear_output(wait=True)) I0 e2 j/ O" o) ~6 S4 m
generate_and_save_images(generator,
7 d9 l$ i! l5 m' {. N, G/ m+ D epochs, M1 r7 u1 M: Y. _
seed)
" d% H- \* V# l& z8 S8 ] / h# x: d1 c7 N, X
# 生成与保存图片
5 g' S+ _# W" idef generate_and_save_images(model, epoch, test_input):
) ]0 p$ n0 a8 u O. O # 注意 training` 设定为 False
! Y, }3 {% r+ n5 } # 因此,所有层都在推理模式下运行(batchnorm)。6 X$ z Q. U" ?* S3 r
predictions = model(test_input, training=False)# y+ h; V, ]- C/ p, L9 P
$ [& o: T" Y' E& Y: M/ U0 R
fig = plt.figure(figsize=(4,4))" f2 m' u( C0 @ D" I' Z
5 r, q: N+ ^7 Q+ ~6 F' y
for i in range(predictions.shape[0]):7 ]& H) S1 H% t" f1 l0 \
plt.subplot(4, 4, i+1)
! u$ D( @9 c* {% H$ ] [ plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')( H* `( O! M% ]$ u5 [
plt.axis('off')
& m7 C3 P9 a) O+ U% N
- o. M$ R4 z9 z) L5 [' B) @9 D plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
) J; r- V# Y# u8 Q0 P& y1 o plt.show()% I: I$ [$ u5 _2 O% q
9 U7 m8 z$ P" W. _. e9 ~
# 训练模型/ o8 H) i8 Z0 ^; e& L- d
train(train_dataset, EPOCHS)
6 T' D% b* a, M 3 V* V V4 G) ^8 ?3 K. n9 [! g
# 恢复最新的检查点
5 J/ O# F. C# v4 ?- P: b6 @checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)): A3 i5 n) a2 [, B1 T3 ^
& m. W* F; H) i! Z# 评估模型
7 N( p8 c! }6 U0 G r2 m# N* ^# 使用 epoch 数生成单张图片
0 ?6 w- h1 x: }( ?! Ydef display_image(epoch_no):3 C$ K6 o- m/ [; Y
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
" x% Z! ^) i9 H1 Q: D. k
5 z9 h7 E0 t8 L3 pdisplay_image(EPOCHS)+ Y. \# d1 @5 l) h1 i z* @( E
, w) ]/ @6 ?" V* S0 @0 panim_file = 'dcgan.gif'
& K8 I) f3 g, T4 N# d( k4 u$ w ) t7 s. ?" H0 ~; \! b1 `0 j
with imageio.get_writer(anim_file, mode='I') as writer:2 U0 z6 x; k1 i, L$ K0 b
filenames = glob.glob('image*.png')( _$ U0 w* X ?* _2 a
filenames = sorted(filenames)5 g" v! q! ?( [- t# z8 g& Y1 Y2 s
last = -10 b o+ b( U: J" b9 f
for i,filename in enumerate(filenames):
! J8 i+ l8 O7 y# r frame = 2*(i**0.5)
6 Y. X# [' q1 |% Z; ` if round(frame) > round(last):( ]# l* b5 _1 `. H
last = frame
# @# W4 E8 i# s. u2 Q" P' m else:
3 e/ n9 v) v. V: r$ o continue5 q. P$ { W- l, z
image = imageio.imread(filename)% N+ P6 R9 ?& A/ n! y" L$ r
writer.append_data(image)
3 I {* Z: O0 a image = imageio.imread(filename)& X; g' q, c" A/ D- }' b- j
writer.append_data(image)
; c2 i, ]; n ]/ @( M ) k6 \) t, C( f& v
import IPython
( F- W8 S. ~, N8 P2 \* \% [8 jif IPython.version_info > (6,2,0,''):
1 q% t8 M/ |8 X& _5 }$ M0 n/ ? display.Image(filename=anim_file); W& q( T0 W* B" h& I
参考:https://www.tensorflow.org/tutorials/generative/dcgan: Q* K3 ]. L( S9 h
————————————————
, _" P. r! `7 W/ o& ]2 }版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
7 Y9 n8 G2 n/ D- W8 Z1 Q原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
% e+ n0 r. z9 l9 z* B+ L6 w4 I8 `# R% ^9 N
: T$ s5 X T- w. q( W8 E
|
zan
|