- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 563274 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174205
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
* e1 E5 m* I" I2 z; |$ s8 B" c" Z" O
深度卷积生成对抗网络DCGAN——生成手写数字图片
" S! K5 f" O2 Y0 B6 D9 ?+ i前言( s( B( B3 o/ U+ e( _
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
0 m- d6 H9 ?! _( v k& A) f& I" D" M0 F" b1 a/ L7 n3 c
$ H1 i6 L: m# m8 }; x' Q o6 I 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
8 Z L+ F1 j! n/ \* k3 ]7 ^6 c
2 J1 e; l8 X- d H: m1 {3 _% M! b* t* ~
# 用于生成 GIF 图片2 O2 M* b# a' E' T
pip install -q imageio
4 H! e; ?# V9 B. x- Y1 L* j- k目录0 l$ o: D+ C& M. n& a
$ ^/ k* [- S' X- ^- d: Y
( N/ F9 u5 F+ O前言
2 F Y# W- k' [: i4 A& H
8 G' N; o1 u) I+ @+ }5 J4 Y: A5 ?" Z2 N7 C
一、什么是生成对抗网络?
4 N; u6 s- ?+ J7 d$ `5 ?9 _4 H) z+ \; {2 L
" _5 k6 }' c# g
二、加载数据集/ ]: ~, b( C! X7 {8 f( N
4 M" p3 ]8 U4 Q/ G& z3 K
' h4 s6 p) w" x1 h3 t) J
三、创建模型
# I% l# \/ q% O/ F% h: a! ?9 [7 J! ], G
% S" V m$ A5 m( y; F0 Y+ N
3.1 生成器
; V9 z# P0 d8 q/ L' g$ H1 M. L( y+ e# q3 V
1 x6 V$ G7 r0 w! D3.1 判别器
, T) p9 M* q; R/ w9 ~4 X8 d$ ?' Y& }/ q6 a* \
8 l* t* o9 G, X四、定义损失函数和优化器$ B8 ]5 T1 Y; a! A
7 \+ R! v8 A+ ]2 u- t
& U0 L% O) K3 g( D0 G6 }
4.1 生成器的损失和优化器9 @. ?; H L+ A7 Z" ]
9 U. G8 t; A; C, d
$ D4 ~' x4 V5 h5 U1 e, K1 C; T4.2 判别器的损失和优化器
$ O# O G9 g8 |5 @4 j9 u$ {- ]) r9 T0 ~" v
- _- {! F* a9 D3 Y6 C6 ]- u4 @4 _
五、训练模型 ~; a6 E" G1 `
( m I" p6 M* g0 T$ A0 a0 s% M- O7 o# i8 }9 U) Z( S
5.1 保存检查点
* m, D U+ r# O' J8 T8 V% d1 }0 \' u/ `8 U: v( s
1 ~ v2 d) d' d1 s9 Q' j
5.2 定义训练过程1 ]- H6 v1 n6 v* B+ Z: `
! a7 m) K3 s$ b
0 o( k' |; Q/ J$ A5 `5.3 训练模型
- g! m9 Z! a7 G) O6 s7 U( k) _9 J* V3 e# c, ^1 e/ e
* m0 _% ]( `- i9 l* B* Q* ?! t+ u六、评估模型. ]7 K/ k* R. o5 F
* L/ O8 f3 g- v z8 _- j+ B0 @- g
" ? N! z5 }6 |7 p! m1 x% I7 \一、什么是生成对抗网络?
7 M! i3 H6 B8 r9 z生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
+ O: Y* [* d* }' P
- J: R: V1 u- D+ h: d. j/ C! K: k5 q& _+ X5 H0 U
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。- e: T/ F; X9 {9 `( T
% [* Z$ U ~* ~' @) P5 v+ ] g2 @
1 x" U* Y1 x: O) Y) q/ N, l
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。0 n8 T* U$ E- F
1 K* [2 `# f3 N2 p5 S6 U9 x9 f
& J( v: |# i: R$ G7 _1 F( }
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。 J$ a: N5 ^, W# E% D- e5 W* l
* o: ?& ]% A; ^: q& f" \3 g; @) M: G' E7 e0 X
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。5 B0 f- L; C0 K- C1 y! S
$ Y' a- _! k" j* L1 u' G/ O7 X) c4 D1 k1 U: E8 s4 c
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
9 t) p; q3 I; }% U' g5 ]4 }: g0 f3 _4 w5 b# l- |0 l+ s3 P8 @
2 n/ A2 ?' l) r二、加载数据集
5 ~1 V. p9 p. ^2 R使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
. y! p9 b5 Z+ K. e% b( E
4 m x* b8 n6 J8 O9 d. K, O6 \% a' C
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()9 x5 E8 k, f7 C8 D; C
- H* F, E* u) }/ K P7 ?
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')0 _0 S4 Y% S; V. `6 a
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
) y/ `3 R b& n. j1 }4 y
) p2 a i! g( T$ p% k/ d5 QBUFFER_SIZE = 600005 _% y9 g6 f" h% S$ e
BATCH_SIZE = 256) T! t6 i- g9 L) W) `7 O
0 o6 b" Q" C4 T l& u8 F8 }) h" D# K1 s" J# 批量化和打乱数据" Q5 L2 J2 x8 W" j( Y( L
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)5 w- B. j& y/ [) Q2 ?8 N1 j
三、创建模型$ D5 a4 ^* U1 I6 a, {
主要创建两个模型,一个是生成器,另一个是判别器。
9 F/ W/ F4 F0 z5 N* B& P& J% y
( u, X I8 U* e3.1 生成器" M* A4 _7 d( |/ \6 P- D; T
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
6 A% ]' ]: M3 G- G8 L8 a+ x. z$ `+ R- L9 e7 }
- k( p) i8 f1 h6 O; @- ?. Q
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
- }# p6 f( z4 p
. \" ]! Z0 }3 m/ y
- v4 K9 h5 y/ p( H' B) j: S后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
" X; O$ a+ F/ z/ W) W, ~
: ~4 f" g5 |; f) A; v7 X, t; Y$ w! E0 M' G. i; I
def make_generator_model():" V4 Q6 b& O$ ]' v# F; [0 |- M; U; C9 u
model = tf.keras.Sequential()
' w% z4 A' i2 ?3 L3 i. d model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))' G3 `0 T5 g+ B1 n. ~: t. c# x
model.add(layers.BatchNormalization()), U7 z4 _4 y: y: y7 ]
model.add(layers.LeakyReLU())
: i) h8 t: f' k* K$ I: e# p, L- q * B8 j& @& s+ J+ H k% J/ L
model.add(layers.Reshape((7, 7, 256)))4 S3 Z9 }! B6 c
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制8 o" Y: P' h; X/ \/ U0 L. W
M: b) |6 R l# @
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
" T: j) I1 p. U+ W' E0 d assert model.output_shape == (None, 7, 7, 128)! H- {6 a V9 t' M! X8 x
model.add(layers.BatchNormalization())
7 r- a% I8 J) r model.add(layers.LeakyReLU())
! L9 S: m" U4 ~* w/ n) |
, Y$ W0 L; v3 D model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))1 g. @/ |% S2 U* a" k# M) R
assert model.output_shape == (None, 14, 14, 64)
X# P' m* I9 \/ C model.add(layers.BatchNormalization())
+ u. |; I/ @" {4 a6 p model.add(layers.LeakyReLU())( P; q) M& }( ~: R( ~$ U! T8 z
- [% i9 Y( n' C; {5 ~! c; g model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))( [$ A6 s7 T$ V/ N
assert model.output_shape == (None, 28, 28, 1)
1 m9 B- v' A5 M9 H/ A% ?# j i 1 q5 x+ i5 B. K& ~
return model
' F, s9 R/ J: N用tf.keras.utils.plot_model( ),看一下模型结构+ H. L% B" ?# i
# E5 P, \8 U2 L
+ e5 q* \- l5 {8 f$ |# E/ N
* d- Z4 e* w4 v1 `& @+ ^ z. B8 d. q- H5 O; n
2 a y) _1 v3 @! W4 }* C9 i0 w6 y
用summary(),看一下模型结构和参数
! r0 A# i" p$ ~; r* t O6 j: u( c+ J* ^$ x2 W0 ], z
5 Z i1 I$ f5 p+ [+ j9 m
& n, p0 U& ]$ ^2 M2 c1 S1 B0 w# y/ L8 V$ N, p; C8 S% J; n
" ~6 A: n1 f8 f( ]5 ^1 l' M
. ]+ g" N U6 \0 B) l使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
1 n0 N5 n c3 `
' Q! d. J* ?; O( z. {1 Z" Y! P+ b3 L5 u* A( w; f5 `
generator = make_generator_model()! o3 k7 F* t# S8 F/ [- p9 Y
, `3 P- s! P* j/ H$ u
noise = tf.random.normal([1, 100])
; P1 e+ z. [3 Xgenerated_image = generator(noise, training=False)
! `6 m+ \: m& L" j6 T& g& R
( Q; N3 T) |$ X0 bplt.imshow(generated_image[0, :, :, 0], cmap='gray')5 j( v. ?3 C. `) w/ s. D
8 i y# w5 c0 ?$ ^6 M+ [- F
$ J% Y3 c5 g- y7 L9 m
& D3 W- s" Q/ }4 _! O6 w$ b
( R8 y* s! l/ K, w3 H3 Q+ r3.1 判别器
' @, _5 y' r* k/ `9 G. r; ~判别器是基于 CNN卷积神经网络 的图片分类器。
7 q; K7 U. l) S* T
. _/ e! j9 H, {( i8 I O4 V! C X$ I0 O/ E9 L# L+ y8 F O7 e7 `* \
def make_discriminator_model():
; W% e [. Q& T, x+ L4 [ model = tf.keras.Sequential()! ]6 V2 [8 }; @
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', [5 o& D0 v2 K. \2 ~" C% H
input_shape=[28, 28, 1]))
8 O1 o8 Z: ~1 ?! Y- n model.add(layers.LeakyReLU())6 O% y$ E, U* Y
model.add(layers.Dropout(0.3))5 I q: L9 e; K! f) t3 J
# l. E6 p! w2 D8 _
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
i3 t. `" T; M$ y0 N) b1 s- E7 f model.add(layers.LeakyReLU()). R$ G( c+ j1 [) t; b
model.add(layers.Dropout(0.3)). z% e; C& v& c+ V1 X
' b+ \5 p, G! A) F- T4 a: ~ S
model.add(layers.Flatten())
4 Z" i4 q; d6 R) L+ i9 ~ model.add(layers.Dense(1)). D. L* E- a7 z% ^" J0 h
- T5 L5 s( b# a1 s( z; ^ e! F/ E
return model
J4 O7 X. u1 b; k* g3 d用tf.keras.utils.plot_model( ),看一下模型结构+ L6 ^# T* s: U* ?$ w
, z- l( o) Z. u3 N) }9 Z
/ o$ ]# p7 N# c* [/ J# U8 m1 b
$ O$ k( n& F$ L' [
2 y+ R( k1 J0 p7 n) V* s( E" ?
" n. R. p/ w- b: c/ ?9 r$ p+ t1 G- g, z, C: x' ]' T1 ~
用summary(),看一下模型结构和参数
2 Y* m; y& |* A) n: w
. {2 I7 R$ u9 F! ~8 G; q. n7 ~) }: M* P- Q) ?8 o. A
+ Q5 P' T; w" W
0 I: s6 S/ _0 p5 |2 i# e
9 D l( e1 { t ]1 U
% Z$ A* t5 T4 _) ]% }) e四、定义损失函数和优化器
x3 p4 o" N8 h3 |( R& ~) Y) Z由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。1 @9 Z. E( ~% I* Q, \
9 H/ O6 m- @7 ]- t
/ k2 j& o: L' ^# m首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。& c4 e' c0 O9 `8 p4 X' s0 G
( f+ d2 @9 \- Y
, O( {8 g) F: i; x i3 m# 该方法返回计算交叉熵损失的辅助函数2 }+ s0 h% C7 L+ g) _" }3 s) `
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
3 ?- `& e3 f' a# k( J' J4.1 生成器的损失和优化器0 j+ h+ o$ R3 S
1)生成器损失
2 R/ ~- i( d0 r; _7 f6 i- J
1 y4 G3 u+ u5 y, ^
+ _/ e3 |# S+ B9 B) r4 @7 I生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
- g1 l. B* [0 R& t8 V4 R4 G1 h: U& x; \- R3 Q/ k$ \
, `' T0 ~" T; E0 q, }这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。1 j0 S1 e& Y: Q7 w1 z. L% Y
+ u/ j: }+ ~( w" ~8 z: X
& Q& @, p& \) Xdef generator_loss(fake_output):% `% N. r/ F* v2 y3 D; H
return cross_entropy(tf.ones_like(fake_output), fake_output)
4 g" l) {6 c# g0 `3 a0 L2)生成器优化器5 c) O) ?8 B0 c2 P/ T! k. ~
. r2 s( d; t, k/ c4 a- I' ~
; e3 M. X6 w3 ^* r0 u! r
generator_optimizer = tf.keras.optimizers.Adam(1e-4)) a6 h1 a) L4 P+ N
4.2 判别器的损失和优化器
# ]7 |) Q9 u: ^% d- |1)判别器损失3 t; R1 p0 L' C1 ~2 y
) I$ C S1 o7 Y9 [; |
* _$ ^ P; w: d( ?& x
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
4 u& k' S! H1 T1 N/ x) D4 X5 ^& ~+ m+ g$ u; g/ E/ Z
0 p2 |) W4 c/ x B4 @# {8 v
def discriminator_loss(real_output, fake_output):( k$ L* A. S! T
real_loss = cross_entropy(tf.ones_like(real_output), real_output). C& X8 U) B7 _5 R! n* G
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)5 B0 O# G3 U; W r" Z! A
total_loss = real_loss + fake_loss
/ S, I# Z6 N+ \( [; A return total_loss
3 C8 x; F( y$ y2)判别器优化器7 I* Y% ]$ K/ L$ c: }
" o( D* P( m0 A' g! S' Z/ L
/ n, R: J, W# Ydiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)0 |, K/ b$ t7 ~# r6 s7 _8 P
五、训练模型
* [# X4 j' M5 F# j2 W2 k+ w% s5.1 保存检查点
! j" y5 k. w( a8 q+ t- K) u保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
4 h: G& P5 ]9 [- I5 v) H
+ M0 `8 @. I z5 ?
2 b4 l6 G o& x x- [( kcheckpoint_dir = './training_checkpoints'9 J8 f5 K' B( N" C, w) Y
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")5 m/ S* N- u- D- I8 |
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
3 Y0 l" E& Q. c discriminator_optimizer=discriminator_optimizer,2 N8 d9 X5 ~4 `9 V+ Z
generator=generator,
$ S2 E" _! L" `3 n5 k. g; P discriminator=discriminator)- P/ o- s& S2 U; T/ e) W1 L
5.2 定义训练过程
' \& d9 R6 d; X- e9 A- zEPOCHS = 50& e+ N. @2 {) @1 k! E# ]! V
noise_dim = 100
' j1 a Z3 d7 j% v- k$ fnum_examples_to_generate = 16
9 _* E( i8 n+ P3 Y; |7 E& w 4 ~2 ]6 }$ z9 i- B( v" q
. x7 l* D! a3 @6 k
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
# M" p; D# M* h& rseed = tf.random.normal([num_examples_to_generate, noise_dim])
q: C5 m. U1 V* B y+ K: z训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
; T2 C0 v/ U; M+ T4 s7 e$ A3 T( h( _4 v: t' }* b
3 [. |' l& j( q判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。( L5 V& L( b# d
1 i& y4 M4 p2 i0 I
6 K4 w' Q% g8 B% Q! J8 _. l6 L7 C& f
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
/ D7 L2 q: q5 `' J& N# Y( e- [$ [! i% D, L- a
* {$ T3 Y/ s2 L' v6 ` R1 V
# 注意 `tf.function` 的使用
( q3 V* N* j( {5 H$ |# 该注解使函数被“编译”4 e9 N6 f% V" t4 Y% \
@tf.function
5 g V8 `& Q& X; B( {/ \5 L* O' ~def train_step(images):
7 ]1 U( w" W x noise = tf.random.normal([BATCH_SIZE, noise_dim])
) p* \! J' i6 X0 p& T 9 o" D v+ R5 l# S
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:" X1 [% E/ M; c5 {" a- A
generated_images = generator(noise, training=True)
% v& |9 b9 @2 I; u$ _' V, l& N- F " t; I7 j0 o) A# ]
real_output = discriminator(images, training=True)
2 s$ l2 B# I% f$ i fake_output = discriminator(generated_images, training=True)
8 {; q( A6 Q' ]9 u
/ p' Q1 m. [ C7 g7 A) {: Z7 L gen_loss = generator_loss(fake_output)
6 @% K- u" x) |, G! a/ P disc_loss = discriminator_loss(real_output, fake_output)) m3 a# K& I8 y8 ?# V- i3 k
( }4 k1 }0 H& G7 `( v
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
: U; K% k1 }. M$ `/ C* h gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)8 _ E/ }9 \4 S
8 D1 `0 Y Q# i# \2 r generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))5 f% P) W' t( G
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
! F6 I8 V, q2 w
~; t" h" Y1 N6 Q, b: odef train(dataset, epochs):
! x4 ~! E; Y1 Z0 u' \ for epoch in range(epochs):) \3 T1 R9 D% V- |0 B. V
start = time.time(): o' f& n! `1 e P" E' v2 e/ e
8 p6 g6 z/ J K" U- V for image_batch in dataset:, ]( ^, E6 O% x; y* |" Z: c7 u$ Z
train_step(image_batch)3 }5 I" z2 k: e% o% x
+ O7 a$ a: j9 C @$ P3 Y # 继续进行时为 GIF 生成图像
1 j1 A! a3 x4 m* Y1 V5 V display.clear_output(wait=True)9 y$ z" F, m* }7 c
generate_and_save_images(generator,. X1 D5 t+ ^3 D; g, v2 t. O
epoch + 1,
/ H0 y' L# l# p5 J3 n, i- s seed)
9 {+ b% W) {6 ^ + F8 U4 h7 z+ r3 ?1 f5 l$ L
# 每 15 个 epoch 保存一次模型, l* y+ ]' v. j8 K
if (epoch + 1) % 15 == 0:
% p) @/ `7 q& p( j0 O checkpoint.save(file_prefix = checkpoint_prefix)) S, G% R0 x5 _% x& y3 P
% B, ^6 o+ h$ z9 s& y% k; B
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))" o9 V3 ~2 A' [) u5 o0 F
: l4 o/ Z" Z2 p& E6 u# {
# 最后一个 epoch 结束后生成图片0 c6 {* U/ @$ I$ N+ h5 I
display.clear_output(wait=True)9 C' K7 h' s4 [! T
generate_and_save_images(generator,3 N; w& ^: C/ C* a2 R) g
epochs,+ v& [+ O4 L! v# k" X h
seed)
1 V+ B+ F/ w) A9 ~7 d2 j* S $ A/ l! p; b7 I' L( m8 y
# 生成与保存图片
* ~+ @, M. K, v1 `- s- Jdef generate_and_save_images(model, epoch, test_input):
- v1 P% R& z2 i+ }) D. v # 注意 training` 设定为 False
$ {9 J O7 {9 }' a+ f8 x. K # 因此,所有层都在推理模式下运行(batchnorm)。! L7 \) S# Q6 F0 m" l. x+ h
predictions = model(test_input, training=False)
+ H2 d3 s: s& o % g4 Y, H+ Z1 L7 y3 H2 u" @8 N
fig = plt.figure(figsize=(4,4))
: |5 ~1 x3 {: x2 ]+ _. w" ^( J 3 `, K9 F* l6 H" e
for i in range(predictions.shape[0]):
" C9 J4 z- M1 w/ f; `1 i; V plt.subplot(4, 4, i+1)
/ R1 ^/ `$ I/ N0 p3 J$ ]; { plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')0 p# H% }1 h: h0 p4 b) K
plt.axis('off')
N) n: d' Y) t [) _ + j/ V7 z T' @& a
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
+ T1 b B) l" `, z) W3 b$ B plt.show()0 e! j6 j, C' I5 v
5.3 训练模型. c* d, D8 e& N2 Q1 t
调用上面定义的train()函数,来同时训练生成器和判别器。2 k6 Y9 D# E+ r. A
5 G' [3 j n u3 g* K& r( u+ u
7 I! K5 }7 X- n# A$ C. r
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
2 O: w0 O: f( V" B# m# |% W1 m& V7 L
- O" |$ b1 R/ a( J ]# `
; V- b; S+ B; _: i) v3 o%%time
- u. v5 v: k' r% N7 t2 |. L4 ^' F: u3 N# Ntrain(train_dataset, EPOCHS)1 h7 a. F; W$ I" ?
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
) f" W+ X9 z* X1 h" c5 Q7 S
* q' c% {( Q i& B+ [, n* |8 C2 k1 u: V' Z5 N+ f( ^
训练了15轮的效果:
/ L0 c" F7 w% I& i( `7 o8 \- T9 f- E- R0 T" I+ [
* x7 f% I+ X5 e/ A: Q4 u; z# t# [( S) E+ `3 R+ x
0 G, {6 i5 l; S
; ]8 Y6 U; y! G# {
# A1 e5 f6 z+ o( W* W. s* ~/ t) p训练了30轮的效果:# ~# W0 O9 i( K; ]; _+ l
6 M2 P! R! l* O* Z, m5 {# N p, [. ~6 K+ T
6 y& G, Y' o( l# m$ q! F @* X) F- R* V9 d
4 |5 L8 `7 Y- r. {5 k7 |$ q% O2 T5 Z! W
训练过程:
8 Z0 L# [5 o* l; p9 ?0 R2 _7 x8 c9 @& o* e5 [* a* J
$ Q- O- B) _. x# ^# o) t* {4 v
1 w- B; C$ t5 D
( { E% x, D6 O' G3 j, _( [3 o; H% B
. S# j% i/ O: d. m9 Z) S
! n- |7 ?7 V$ v& G; U
恢复最新的检查点
* l/ { n* n1 G; S$ G2 F/ R9 B) j+ |5 \8 Y2 N& E* ^
2 s( l; l6 e4 z7 x6 G
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
/ X# `3 {) W8 z5 c六、评估模型
! X+ h& W2 X' y' P, d5 N" ~这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
+ P! t4 g% y0 K* o/ Z0 s7 V# {( t" m4 l' P* |) y8 |5 [! f. D' W! {
: P3 O- B8 F5 D. A: X
# 使用 epoch 数生成单张图片4 R, C6 I0 C. c# c( w1 L5 B
def display_image(epoch_no):1 Q: a2 ]/ r- K- G
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" P9 X3 Y. B3 C. \, U% Y
& R$ d% y3 h$ x4 o+ G, J
display_image(EPOCHS)
. [; \3 k3 F; T2 a/ D/ ^% hanim_file = 'dcgan.gif'
: d- f" U6 A; `+ i ; I3 F. k: o3 _. N& |7 V
with imageio.get_writer(anim_file, mode='I') as writer:
# R: S' s H7 k2 ^% f4 E filenames = glob.glob('image*.png')
" r5 _2 ^% F5 v; w( ~3 r filenames = sorted(filenames)
* L% |& h6 k) m! F# n( [" r% y) w last = -1
- n: |3 {' Y; h, G% p" S/ f for i,filename in enumerate(filenames):
# L0 E/ f, V4 U f: `6 l+ Q: o frame = 2*(i**0.5)
9 [. s) Z o6 o, _0 @/ x if round(frame) > round(last):
, t, d1 O2 ^6 g7 I last = frame+ i. N# R* r6 \3 {
else:
' g, x, \+ \+ h1 g8 H1 F/ Q$ n continue
2 G3 O8 M9 q X' {/ j' Z! b4 ~ image = imageio.imread(filename)( W N% |- D4 c
writer.append_data(image)
D, E Z0 I; ~' X image = imageio.imread(filename)2 I( p0 t: V9 r. B7 D. P
writer.append_data(image)
$ x" `4 R* s9 \& A) x5 x+ I
! p7 d" }9 R# E0 dimport IPython
6 F- V1 t4 _* `& p0 Y! J9 M! `if IPython.version_info > (6,2,0,''):
# E0 w3 G' S3 r q display.Image(filename=anim_file)
; ^4 b C# m* c n' H4 `- {5 E* w5 J: V* {
0 o2 n) u! s& t: p
) O9 Q& Q3 Z. j4 v6 g3 c. C
1 v: c6 q7 `$ m7 U8 P; Q4 Z& o# Z完整代码:- Z% X- ~3 f3 `
' W- t+ O" z6 k0 D# o& t
0 w% j1 g; w) b; `/ N3 |6 @! C0 }! z
import tensorflow as tf' f, E- \5 Z0 N2 c- j) f! c5 ^
import glob
+ p$ `. B5 ?* w8 Rimport imageio3 l* b( X, }- i. v# W* E& D0 \* q
import matplotlib.pyplot as plt3 f* w9 a! P) _: {' f# W' a
import numpy as np7 _) t& s! W2 @6 ]4 [. \
import os9 Z8 B+ x$ R8 y
import PIL
7 C; r4 T' M; v3 lfrom tensorflow.keras import layers
6 Y/ y, R# X) ~1 b8 n: `" J Eimport time
9 A" g% D/ L1 \3 O2 {: V7 o 4 s/ i0 r: m2 @$ K |& }
from IPython import display
: L1 }1 k' f! ]; a ) s n! t. i- L0 s8 W
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
i( j: \& R# K3 n- g: { ) o7 e* I g0 b
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')7 a/ Y K. l2 y
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
5 K6 q( Y+ h2 p4 P, M' A' G 5 \$ T9 [. ~) S b
BUFFER_SIZE = 600005 p" k- Y2 F, a8 q
BATCH_SIZE = 256' q' B) j- h- [8 P
8 T9 ^5 `. @% w) S' a+ {* b# 批量化和打乱数据
2 i' {% }# }7 G, v, M: D6 ?! }; Gtrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
: Q; |! v$ @4 b* K
5 Q- j/ F8 i: t$ L# 创建模型--生成器. q f, `$ J0 R( ?0 o2 A
def make_generator_model():
8 m, H: D/ I- `5 u/ v model = tf.keras.Sequential()4 o4 r1 p$ C# @/ k5 |1 s
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))1 h: i2 |$ F u* |3 d3 u
model.add(layers.BatchNormalization())3 d) ^" V- x8 u0 w+ {/ f
model.add(layers.LeakyReLU())
f- d" `- K2 }1 b & X8 I9 R. d+ I0 ^
model.add(layers.Reshape((7, 7, 256))), Z/ X! ]; Q: H" q5 o
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制! b& M7 {! V+ F& x- P
' k+ q* B9 S e" t* q. |& c; T model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
1 l( Y% l+ e5 u' l# v% x; s; P assert model.output_shape == (None, 7, 7, 128)8 o; o) b3 V0 M% @* O% l, z. ~
model.add(layers.BatchNormalization())
! {5 P# E# K$ M# G9 z% | model.add(layers.LeakyReLU())
; \% x+ g8 k5 y. e" c" h$ ` 5 H9 g2 X) ^, w @% K2 [# H
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
1 W5 d9 i4 _3 n/ \" [1 Z9 _4 ^ assert model.output_shape == (None, 14, 14, 64)8 D# m7 q* ?# u8 a3 Y I
model.add(layers.BatchNormalization()) G j, T8 I$ F% F
model.add(layers.LeakyReLU())
% E2 y" |- f' a& T$ \ $ x [# ]! I2 \* c# {
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))! K, m. ]- e" O
assert model.output_shape == (None, 28, 28, 1)
+ ~. j- `( i$ Y k2 S/ l# { 3 p$ h0 T" v# \0 ?# k5 F
return model
1 t O' S- a$ W/ B; `2 q/ E$ j$ _1 v 5 @$ Q: F1 _5 }$ \; [ [. M
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。! W+ u/ [! f' H( \
generator = make_generator_model(), }) ?7 f+ c2 O% s
) |- Y; n% j( v6 g b# Q) v
noise = tf.random.normal([1, 100])
+ Z: w; E3 C: O+ S* n" p* y; e! X$ Sgenerated_image = generator(noise, training=False)
8 W0 L5 Z( Y; `% O) s 0 z4 F' n+ F- E; l
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
: c% U( M+ C! ^: B Z' Itf.keras.utils.plot_model(generator)
' K; c0 Q/ P& k( P' d7 F6 C V6 I6 r- I2 P
8 k5 `9 M( }7 R1 O" R5 g# 判别器: C5 g8 l: t5 T }! {
def make_discriminator_model():$ ^/ b1 z1 s7 r
model = tf.keras.Sequential()# r) f" p8 f$ }! U2 I
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',5 ?- ~9 f. r0 T
input_shape=[28, 28, 1]))
6 |9 j4 U8 n/ d- K' q4 x; {) A model.add(layers.LeakyReLU())! K; l# S0 d1 a+ p; h1 i4 g/ v4 Q$ d$ L
model.add(layers.Dropout(0.3)): A$ ?- @- I; O, r) S2 q1 W
2 r- _. D9 ^- _6 }3 c# Q2 p model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
s+ F# l) |- X2 @8 f4 z! c model.add(layers.LeakyReLU())
8 k6 C6 y' Y5 I5 l/ ` model.add(layers.Dropout(0.3)) V8 }) R8 r+ E1 f( M
$ c+ |( _& X# t _/ M% G3 q
model.add(layers.Flatten())) z! S: f" `9 c8 d5 d+ ^3 \
model.add(layers.Dense(1))
" i `; O* V: M* E
4 B9 H% o+ f4 Y" [' _! V return model3 t& D: M5 j5 n* q
. I; ]' R9 }% U" Q6 a+ N# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。, \& z) W& I. r9 Z- K8 R( T: l/ K
discriminator = make_discriminator_model()
- `/ I0 L) N+ i/ O/ q4 z( Jdecision = discriminator(generated_image)
+ K' ~7 V/ f& V2 Z3 uprint (decision)8 W0 }0 ^+ m' s, z( z
. ?# G! Y- b1 T* ]5 o6 l# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
$ a/ {7 k2 } X+ x" P6 Icross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True). ^2 O: S% A% C+ ` a
, A* B) [& \' p# d$ G3 n# V- |$ x
# 生成器的损失和优化器
$ @" V K- r! `0 a0 rdef generator_loss(fake_output):9 N1 H: Z5 l5 @& [4 M4 ]9 f
return cross_entropy(tf.ones_like(fake_output), fake_output)
! q6 I' ?" b) `* a% e- a. ^1 ggenerator_optimizer = tf.keras.optimizers.Adam(1e-4)
" r1 T$ _* c h% m7 w
$ K& i$ R w; L5 }* w% H8 X. s# 判别器的损失和优化器. M" b6 z# m2 K! f
def discriminator_loss(real_output, fake_output):/ }" T3 f6 R, t" N2 u% C
real_loss = cross_entropy(tf.ones_like(real_output), real_output)$ P7 u+ _; C2 T# q8 Y
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)( Y" Q+ i& P9 H! r. J4 o# Q4 Z
total_loss = real_loss + fake_loss
) L( V v; d/ o, q' u( _' U return total_loss# ]* i }4 D( e0 Z( a5 j! d* M
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
. d) n& G. y/ ]3 v7 k3 |
7 ]8 P7 X! k! f4 Z& q& b% v Z' X3 N0 c# 保存检查点
/ S P% S) P$ h4 K( Ycheckpoint_dir = './training_checkpoints') R' v# V) Q% q4 N- L
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
: h/ j; z4 u/ A5 M" O6 J2 F( `" Ycheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
! W& ~1 l, ]8 U/ {% ` g7 j discriminator_optimizer=discriminator_optimizer,
- ?, y {8 ~' Z% u5 F/ v$ b0 }* A5 x2 U generator=generator,
w* {, h% h8 b" @! s2 O. s2 e discriminator=discriminator)
. ~ _& Q, Q8 V& ?4 \' ^: l4 ^ ! L8 b2 W. M' _1 J7 @
# 定义训练过程9 y( V H1 _& l8 b) T
EPOCHS = 50! ?: A& ^+ k9 A7 L, Y/ w: X& Z
noise_dim = 100
2 o7 r1 b5 Q- `- E* \5 \# Enum_examples_to_generate = 16, b( {) L$ B: d3 Y r* H6 H# X
* o0 N S1 g2 E( d# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)/ F7 d; u2 |2 z/ h
seed = tf.random.normal([num_examples_to_generate, noise_dim])
. i' @0 Z- ] e
V2 w* ]5 `* d) ]1 R& P# 注意 `tf.function` 的使用: L6 a3 c* M7 {3 C9 h% }2 u' C
# 该注解使函数被“编译”
9 v+ W$ k1 B8 C@tf.function- @, o! T( m |7 I* s! s u) t
def train_step(images):
! O! B" ~ f6 H noise = tf.random.normal([BATCH_SIZE, noise_dim])+ X1 [9 ]; r" a) t i# t
* b3 g O+ H4 `4 t9 Y- D/ J1 l with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:7 a' P6 Z5 B) `9 i) F5 e6 Y0 i5 C+ {
generated_images = generator(noise, training=True)
' I# j7 }6 K7 N 6 l7 W8 p, ^3 c% h$ i
real_output = discriminator(images, training=True)
% K, D# |' ^1 ` fake_output = discriminator(generated_images, training=True)
0 o4 K: [( W. G* |! z
8 F( [' K. ]& G- w" B gen_loss = generator_loss(fake_output)/ A- W( p6 a8 |! p
disc_loss = discriminator_loss(real_output, fake_output)" m! X# x( B& {- k5 `" ?8 x( Y
8 X+ z4 L$ j u- H: s gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)( p( t# R- c( d4 [5 p
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)! ]% y- Z( g( t
. f7 z3 q" k' ?0 t- B
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))9 [- v; r0 c) O2 [" b6 c& C: B
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))# x$ ?! S) r& T# K: @
/ B4 h Q7 _0 R! f( n- ?
def train(dataset, epochs):$ B- U- W0 s6 L. z1 {6 X2 @
for epoch in range(epochs):9 _3 i# |. v: {7 B( r* v. L
start = time.time()
% B- z" C7 O4 M5 ?" ?
" S3 N7 u) K: L( l" ?6 s5 K for image_batch in dataset:
6 T- T$ k* F% E* v# H2 O' N; B1 z train_step(image_batch)
( _* t A/ u: O: N" t' |7 ] . d8 U! u' M$ A, q+ U
# 继续进行时为 GIF 生成图像) c/ j3 x) ?0 h, `
display.clear_output(wait=True)
3 R# E3 u8 x8 ?; l# C generate_and_save_images(generator,) @6 j) R0 k8 \5 ]6 V
epoch + 1,
3 u% y. S+ n n0 O seed)
0 k9 m: I. C( r& G+ a: u
' r y* o7 Y8 z2 {. [8 m # 每 15 个 epoch 保存一次模型! _3 e, m w; s% d9 r6 Y
if (epoch + 1) % 15 == 0:
0 b9 A' L/ b. p. Y- ?+ A \- W checkpoint.save(file_prefix = checkpoint_prefix)
. b- F* S5 B8 {2 x; J) ^
5 @1 g1 O, j$ Q* @/ A' | print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
% O/ |! ]4 _+ C
/ a& Z: i# Y/ F # 最后一个 epoch 结束后生成图片1 U7 n. Q2 E K& C8 q9 r6 ~
display.clear_output(wait=True)9 K8 n$ c6 O4 F0 s% n: U
generate_and_save_images(generator,
+ M* c; d+ T' `% I" c8 k! O epochs,7 L2 ]- \. c0 U; \' f9 i
seed)
; c4 O- w% y: n
1 z% N& v3 b4 Z& E. P* s9 x# 生成与保存图片% Y) x) L$ U* N. E& K$ N
def generate_and_save_images(model, epoch, test_input):
8 Q4 d, h: c: ~# s! V # 注意 training` 设定为 False" B+ _" Z p+ a* G |& Q4 K t9 J
# 因此,所有层都在推理模式下运行(batchnorm)。
: [/ w0 _0 [' v. K2 X" {5 H predictions = model(test_input, training=False) \6 S, `% ?# a5 r% [ f- G' x
2 D0 X; N2 m5 ]# E; Z
fig = plt.figure(figsize=(4,4)); I* o3 ~+ \# c+ m, r3 y: `
: h/ h( x. b4 R; D1 R; H for i in range(predictions.shape[0]):
& l! K$ Z2 ]7 _7 L plt.subplot(4, 4, i+1)
+ W) g5 q n( ^: E6 l plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
& O9 P0 S0 V# p F' I" D# [- ` plt.axis('off')* v9 w; N+ i$ ?+ H
7 z9 }3 Q/ N3 I8 K+ H" y# B plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))( V) P- g6 F( r" a- R
plt.show()
9 h+ o4 r$ a2 L4 B( c. u7 o1 m ! S2 @$ W% |' }9 O& D1 {
# 训练模型! `4 V8 ]6 }/ N. d
train(train_dataset, EPOCHS), ^( R# E2 n$ h4 a4 A1 O
. y/ @, d" r9 T9 p7 q. D" K6 j
# 恢复最新的检查点
& R& {! ~5 O2 \; r& `$ ocheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
: X. ?1 X6 s" i' p/ P: C3 C
2 B8 g& `+ `1 R# 评估模型
: k4 p9 O3 C k7 W- g/ K+ M$ e# 使用 epoch 数生成单张图片& U2 |, x% U. ~ E
def display_image(epoch_no):; C' a2 Y+ ?1 x& ]
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
- ^3 n. {0 U) x4 r+ | ! r6 y; Z# K( k
display_image(EPOCHS)6 z" J5 U' a, D5 o
+ w/ _& J- h% n, m* A1 l5 @
anim_file = 'dcgan.gif'
+ s) U& y5 j& t) d# w # l- L1 D& M% B2 \
with imageio.get_writer(anim_file, mode='I') as writer:
; N1 K1 j% U7 q& M filenames = glob.glob('image*.png')
: H; m7 J0 M1 q T) A9 R) H filenames = sorted(filenames)# L, Z; ] s6 g7 U. h
last = -1
V& q$ A, Q; A for i,filename in enumerate(filenames):
. B" s) {0 v' I/ M% { frame = 2*(i**0.5). E. d3 { L( d) R6 `
if round(frame) > round(last):
, N" f2 C% @2 i last = frame c6 z2 D0 ]! L$ _* C/ N
else:1 K! }$ J o" ^$ `6 u: Q- s3 I4 u/ J- l
continue
$ d. }' F- |) \! H) Z& u image = imageio.imread(filename)0 m( S) m+ A5 P: V+ Y! K O- @
writer.append_data(image)( ^6 i* e" u6 [1 @4 W5 C$ |$ B8 {9 H
image = imageio.imread(filename)1 ]( E6 G+ G/ j0 H
writer.append_data(image)
: M5 C. o4 h5 [+ {
$ s# A, z8 D, K7 ?import IPython9 [ L V: D l. `( f& E
if IPython.version_info > (6,2,0,''):8 D7 _* \9 [- p$ |! l8 Z5 a* y
display.Image(filename=anim_file)% C& o1 K- o2 y8 E" S
参考:https://www.tensorflow.org/tutorials/generative/dcgan
& L H; N/ R; i* M————————————————. r9 J# a, U, N4 ] w) \
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
/ u+ q& |8 @1 @6 \0 z8 b( e原文链接:https://blog.csdn.net/qq_41204464/article/details/1182791111 p; c! i6 o/ Y5 C3 c% [$ R% R- }
& ~$ I+ K+ u3 K! ^, M0 B8 i& {0 ?. @: \# t
|
zan
|