- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 563271 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174204
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
* R# S* I" V" i! X/ _+ k! x8 g深度卷积生成对抗网络DCGAN——生成手写数字图片
, x1 n& y( i6 p; Q# E! Q, I前言
1 e+ @& k/ Y, ^ v本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
( O8 {8 s2 d. }9 ], A) i5 `% O E/ w V$ L
* S' N& T g: B4 @5 C
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
{4 [( {2 K$ ^
$ ^$ {9 `$ Y/ g' ]3 F5 H( ?6 N& u. V2 v/ V+ Y
# 用于生成 GIF 图片
) ~5 }- _- Z' [9 m& U% Y w) Ipip install -q imageio
9 P1 Q- V6 \) H/ V4 @6 Z- J# t: B目录
! J3 `* r I. @8 D4 r' m8 [8 y3 n; X; P" U, I3 j
4 v5 W$ P4 C7 A% w0 ]& ~% ?/ N
前言* o* X. M0 B* e; d+ @% v
$ L7 G# a) f7 i6 g
1 g& T; s2 I1 @$ q/ Y9 m一、什么是生成对抗网络?* x+ z& ?4 A+ R; Q2 j D5 |
" M& K; P2 N: _/ K& x+ ^
- t7 g( H* o. U( T二、加载数据集
% d8 f3 j# M# P. ?# Z4 i
, e- N6 M Q7 f4 X- S
. }" N; n! @# O# ^( X三、创建模型
0 [, C) _+ l5 d8 A; C' V
- z. h! b+ O& A: I
# j* r- l2 C& N! Y5 K G5 B+ K. ~3.1 生成器
0 y, }& G* s3 h2 Y; W4 p4 O
" w1 F- r# o' o" Q2 a; }6 v* n& v w$ h# N) e) H1 f" b& {
3.1 判别器2 C* Y ]! n( R3 S
; L2 O; M. ~- `3 z! `7 V, L3 w* _/ i% J7 V1 U; v: b/ E' g Y' Q; w# w1 d) R
四、定义损失函数和优化器
" O3 m7 c* N" m/ D( l+ }1 k; g" J0 ]4 Y, _ ^
3 j9 X2 t9 D0 N/ }' x& Z4.1 生成器的损失和优化器: `& u( X& L) E# d. A: I9 ` L8 f
# Y( \# d. j1 t* U& H, E
* I4 ?& _; ~! F) o0 q# X4.2 判别器的损失和优化器
, b0 z! N/ z* ?% B) K4 @6 g7 ^, J, p, c1 }/ j
# S. I* V) z- n7 }" L五、训练模型' B2 E3 ~; }5 |
* @+ n4 R) g/ i# `( i$ N
, V+ c& ?" \% ]9 C X5.1 保存检查点
0 Z6 f/ q# [6 a- e- P* B
" Y( S0 j- `+ L- _( R0 v+ O7 G3 O' a0 l/ C
5.2 定义训练过程0 o, V/ B9 I/ \6 [$ V5 v: ?
8 N: W5 ~$ F/ }2 ~
3 @# V1 R% c/ A2 `) ^8 h4 N" k5.3 训练模型
& _7 @( H2 D& C; o; ^ N
. T% H+ g+ N2 w4 X; S, y. F# [4 G# \) j1 k5 u7 O. t& E
六、评估模型1 c9 S# C7 [- j* i* {8 K$ T9 j
4 z6 U4 U- [7 T" h' C7 f: L$ J
8 Y& L2 |1 [8 R2 p; O- o5 J% N$ a一、什么是生成对抗网络?( o3 h4 ]# D9 q( v" G
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。3 p+ a6 H4 O* ?
* o( _: r3 r' w% e8 ~5 }( d& m" [+ @ s" x* e3 S
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
' R8 z+ Q: S& r5 G d4 C/ {0 z6 F7 t! H p; ^
8 i( U8 C( B3 N2 g) m2 m0 P
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。 `8 m( E" q7 m2 `, k5 c: P; F
% W* j* G% Z3 h* L5 `$ X8 e$ N
$ P) i1 i' v; I5 O
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
; M* R& v/ }/ e3 D; v' u( c3 g: Q' P1 x/ Z o0 T
0 }: @. X+ p% C( Q1 h: y4 t当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
* v4 r1 V% H1 o" B1 d+ V! S5 u
# O9 I9 J# `# a
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。% H* h8 P& ^5 K
! ^( S3 n% u9 M. l {& o( [
- L( X) T H! `' | c/ T* s二、加载数据集
/ b. m- L, o. l H" m- r, g使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
) _3 F& n$ A5 n& v( ]- S" Y' v6 L% l d+ P! M; F
7 Y9 L) j; |; c. Y* G
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
* \1 Z. J- L3 x- h& ~: E ) m: y9 X# U* Y7 a
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
" A8 g6 u2 O3 |# i2 B2 _train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内$ q) K7 ?- F- D2 |! W
( D: P& @6 k1 v8 Q8 E' M5 ]2 H
BUFFER_SIZE = 60000
) _. `: u! R# s* a, d$ N: dBATCH_SIZE = 256
9 P. i; X m! T4 i# [
6 \+ M- M$ ]' i! x$ i }# 批量化和打乱数据" I0 i% f8 N' o: V! D, U
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) V/ t1 ~: t+ ^4 b" P% m
三、创建模型) Q+ y0 Z4 \* J
主要创建两个模型,一个是生成器,另一个是判别器。
9 E* m1 `) g% g& A% Q# g+ `* r# k, E4 a; G/ H
' y0 X4 P `; S3.1 生成器
# o+ ^. E3 H( d8 Y2 ?( V" \. E+ s生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
$ X# W! V4 Y4 n ~4 Q: v" E7 e$ O! n8 A- l! c" }% X( T
4 N0 m- }& v! R2 Z9 S
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
4 E% o, y) h8 n& I7 ?( Z- y5 B# O( g) Q( ~5 L- [+ X
4 G5 h7 t' o4 I% o" d- N
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
) d- U# E( }- n' P9 h
( q' E$ V5 B0 e' i0 Z1 K* `1 n
, I/ {8 R" C; x. q* _- P( @5 Cdef make_generator_model():
8 q( y, m Y U& P# h3 b4 m. D# S model = tf.keras.Sequential()
* V% `+ g7 ^2 v& E/ `7 ~! n model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))& J- W8 g2 m0 _; O7 n1 L T
model.add(layers.BatchNormalization())& `; z% ~4 m9 v2 L' n. _. y1 |$ X
model.add(layers.LeakyReLU()). k `3 ?: H6 N5 y4 L% `& Y
' ]5 @- J# c, D4 H model.add(layers.Reshape((7, 7, 256))). f l; p( ] e$ R6 U" g
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制! z8 y* K- q# s- I5 @% H
% K' Y( o3 s4 m# K% N! m/ U model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))4 y! U. s. B; O( V( k
assert model.output_shape == (None, 7, 7, 128)
3 h* [4 L; g; y! Q0 z model.add(layers.BatchNormalization())
5 c2 @; D0 i) W* O- g model.add(layers.LeakyReLU())" P+ [# p& c D( I3 T9 z2 p/ K7 I
, V/ V i9 L! @( r model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
3 |$ J# v6 U8 A& B, N; H+ f. i- S assert model.output_shape == (None, 14, 14, 64)) ~* J% F& }+ W) W( }
model.add(layers.BatchNormalization())$ l r, e6 S' O9 E( C2 u% N. G
model.add(layers.LeakyReLU())- p9 m- T g: q! e: d0 o% U
: h: u/ m- n1 s7 L5 r
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
3 `9 p* p; q. l q, P0 n) D assert model.output_shape == (None, 28, 28, 1)
9 u+ s- S( n6 J7 X# V5 h 6 |: ^9 ]( z4 a$ W
return model" W @" ~8 S7 y/ O5 B
用tf.keras.utils.plot_model( ),看一下模型结构1 q1 X3 s% n H8 A- l7 x6 f; J
( g% ?# q& O6 ?8 H* t- T
0 G. E& k% {: o# L. E u8 Q
: P2 O) Y0 g7 \8 w7 ~
( f& [3 R7 B$ W; I8 C1 g b' ]' c, V/ a' l) ?2 j
用summary(),看一下模型结构和参数8 ~1 A; A. w+ `$ J9 O+ x5 z
: v/ K& B( Z4 y; A+ Z% K! T* E# h/ w' y& T4 c) A
: ^0 \& N# J0 Q1 T: _
; g5 e. g& L9 w# J& N
- f4 F: j9 ]) ~- b6 a4 l' |0 h
( \1 K: U& y# H" T' n9 ?
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。/ b t. d1 R! a! b* n/ P$ d) W
5 Y* T) }6 L: P* ^7 o5 s8 H, A, E7 h" g6 W
generator = make_generator_model()
- R4 Y& c7 X, r
# \, N3 L! Z& b$ f" H, x! bnoise = tf.random.normal([1, 100]): V! [% U0 s8 L; @
generated_image = generator(noise, training=False)( J9 G- M; R; y
5 q. K1 p( l$ ^7 A
plt.imshow(generated_image[0, :, :, 0], cmap='gray')8 f3 P( X% P3 ?9 w1 B) @: @3 `+ m
9 x0 m8 _" `. p& V4 v, ?
% N- C$ I C' {* \9 U
- [0 c1 L& n0 d# o4 ^' x( y8 K+ W( _7 o, U4 O" v0 K
3.1 判别器
) r+ O! M5 q& A* { G8 k f+ z0 \判别器是基于 CNN卷积神经网络 的图片分类器。. j9 l+ f s/ b3 d6 F2 b3 q
0 E. k0 \, G5 R+ u. x0 P( P
% b5 a' K" K* P. }) E( O y: g3 w7 idef make_discriminator_model():4 t1 m" o- v# ~% A& g
model = tf.keras.Sequential()
" T* S+ K2 y- A/ L! a5 c- B& W model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
0 B. E, q7 C! e+ A+ _ input_shape=[28, 28, 1]))4 ^( W, z- ]# E8 | c. W
model.add(layers.LeakyReLU())
6 ~$ ]( v8 y- r1 p' Q model.add(layers.Dropout(0.3))
I9 S1 P1 v* Y" e! ]3 N
B0 B( t& P6 U3 W' a C model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
& J* ^) p8 B- V/ z% N model.add(layers.LeakyReLU())# h* m% ?3 x& r; D& h" `. H1 `9 I
model.add(layers.Dropout(0.3))5 M1 Y% q" m; W& J
1 H8 r @$ [% P& J( i. T& T
model.add(layers.Flatten())
8 y/ L( D/ v5 I2 o1 p1 [ model.add(layers.Dense(1))2 u9 p8 [4 n; n4 A1 h
% I# g0 Z, |" l' y8 \% J
return model: `9 g0 ^2 t$ Q2 ^" U/ Z5 b# q. @
用tf.keras.utils.plot_model( ),看一下模型结构
# ^5 B7 o1 K6 N' u/ m
6 e4 X: W; E4 V8 v M- T7 w
2 k. _8 `6 [+ t4 D1 z$ E
' A4 |& I( @: f0 g( _
3 ?$ _6 ?1 D+ J4 R3 ^8 r; ] r+ f/ K3 j( q$ @
/ Z1 j: i- ^8 N" u. N$ H' h+ H% O
用summary(),看一下模型结构和参数
' q% ^5 a' P6 k7 g8 N3 q1 G! b R. ~$ ]( z% D' K# J
& D* G& U9 n9 O/ y
# A0 U5 {2 t/ p7 e* M5 e. K: [; |9 H
& X0 t* k4 P+ ~( q1 @5 k
6 [4 M m' o$ c! s. U
& L2 O* j( C {6 {1 h四、定义损失函数和优化器
' Q! A T0 M' R- v5 f* l6 }由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。$ @9 q6 ^4 m3 ~% U2 t1 ^
9 A5 P1 z# i1 C) w, p5 X, B
- K: A, a; ]& E% ?" }1 a
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
% K; e s$ O# w6 ^7 D5 R# b0 g1 D8 d
; x% w! |! n# q1 ]3 ~' _ Y$ n# 该方法返回计算交叉熵损失的辅助函数- w/ r( R0 E5 y' d$ n7 {4 Z X
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
* T" G: y# C$ h+ p! ?# I; e1 n# h4.1 生成器的损失和优化器; p# e3 m/ [' X! I/ W
1)生成器损失8 {; @. j* k, F0 `/ A1 [
+ Z" D1 |+ T% J/ e* m! j
6 x+ [, Z8 ]$ ^- i( y; H生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。6 Y, S6 l6 W2 }5 W2 ]9 W
6 V9 L s3 R" J$ C; G, p
2 d8 ]; D1 i7 t2 d9 P这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。: D* ~0 [' n' j
3 P* D9 F# t ^, T
" V8 R. E+ ?1 [. n) Y2 Y+ \def generator_loss(fake_output):
( N* ]7 B& x( f+ _ o7 F return cross_entropy(tf.ones_like(fake_output), fake_output)
' N6 a9 R) [5 g2)生成器优化器
& G4 E- d" i2 `; k+ Y7 j, b) U( \# ~& Q8 ]
$ R0 g1 k9 q8 \' R' P
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
' z1 W" B8 i% U2 _% Z1 Q, |* a4.2 判别器的损失和优化器
% j( D! k5 G+ K. O, Z9 D1)判别器损失
5 V' ?8 s6 |! n1 X+ L3 z# J* j, \
3 D6 x: D D7 U4 a% P
x. n1 z# V) h; J判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
$ ?, ]0 w {6 n b* `, k( X: F& O: G7 q/ E8 }1 u/ t4 Q5 o
# E% i& w. X0 d, S) p, ~
def discriminator_loss(real_output, fake_output):
* Z: a( f9 l2 E5 C. d8 b2 X real_loss = cross_entropy(tf.ones_like(real_output), real_output)
+ k4 K [# Q0 L, ?; M* o) @; i fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
# c, W2 D) p O2 B: ] total_loss = real_loss + fake_loss% g- k. _/ M) C0 a, V- _, M8 o+ U
return total_loss
; j* d% d7 v- O7 r2 H3 _0 z. S2)判别器优化器
+ s8 g6 c& z0 o- c; k2 E2 W- E) P O' S6 h7 L/ L6 z
8 _7 E6 z7 } F$ }. m4 V
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)1 @* W& E+ @5 b) C5 p9 n& X/ `
五、训练模型
- G3 l& X8 w5 J3 q5.1 保存检查点
% g9 a( q( k( A7 e; ]保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
" {7 i& f6 D6 d8 B1 [ V! Z5 V: k: n# \. b
8 T# [2 p8 W9 hcheckpoint_dir = './training_checkpoints'/ o$ s5 v. v! g! C
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")2 g8 W) K; ^0 p: K
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,5 ]& e4 E& s( \
discriminator_optimizer=discriminator_optimizer,
* R7 L7 f# W5 Z5 ~- w generator=generator,
/ M& U4 I% T5 @! g# A) U discriminator=discriminator)# R& H4 M' C. L4 T3 C1 l% h$ E
5.2 定义训练过程2 j4 J, [% i% j7 U6 ~
EPOCHS = 509 ~3 w/ P7 W" a: ]0 B
noise_dim = 100
9 b0 Y, Q; u7 B- Snum_examples_to_generate = 16! s5 Y' f- _$ F5 N/ G& V) U
+ ~7 ]3 o. l* ~( g/ E% x; {
4 x4 H6 y- z( j$ i! c# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)' S/ A- w Z5 Y" o& E
seed = tf.random.normal([num_examples_to_generate, noise_dim]) F) O4 G+ g8 a; V1 ?" r7 }
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。! n1 o3 ?8 U/ M! e: e$ }1 K& h, T+ u
, n9 i; Y2 |* q5 P$ {
' S, {' n" P# R7 p7 f: f6 t+ H判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。, ~% H5 h$ G. I3 G: I. a- z
* f. g2 b/ q9 [" H7 _. O; y
9 X9 L+ [/ o7 f p1 t
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。0 C) u- x7 C- B! k% I& D4 _4 U+ J
) n) h4 M4 s6 z, ^' B q
+ M8 z, V K! t2 y6 Q# 注意 `tf.function` 的使用, V7 D" ?& u$ e8 j
# 该注解使函数被“编译”+ I. n1 \' w" U# W
@tf.function
8 p& F) j; Q' ^( A, {- F; edef train_step(images):
' |0 N; h6 w6 `0 ~' B noise = tf.random.normal([BATCH_SIZE, noise_dim]) Q* Y5 d) T2 F/ _! r7 m) r6 |
0 p' k3 Q- a9 w. U% B. Z with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
3 x& }. V$ M. u1 @ generated_images = generator(noise, training=True)
5 |" D* l! b" m
/ W5 E& Z% Z! J. o2 d9 ~% v real_output = discriminator(images, training=True)4 e* e2 G9 T2 M, H4 T0 b
fake_output = discriminator(generated_images, training=True)
* Z I+ x' f; Z
+ ^7 Q T. D' n3 o% ^ gen_loss = generator_loss(fake_output)
! v( c o. z6 g. _5 v' @ disc_loss = discriminator_loss(real_output, fake_output)& y5 y$ x& q* [, b* v l
4 Y6 r! F! }/ N- Z( e
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
9 g8 c, b( t# q( M3 {9 [7 \% r gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
: o) U6 ]( n( Y5 D) M9 x
+ L3 d2 [1 p. H4 T4 |2 z; ^' ^ generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
( z+ @* y4 O0 t) L+ ?9 ^! ~3 J discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
' v4 H: j) K% H% I/ }
, w1 Y1 h& q, Q- p C6 jdef train(dataset, epochs):
- B: h+ s+ z/ `% ] for epoch in range(epochs):# p5 D9 _3 R/ {- X# ?1 T7 K
start = time.time()
8 Z* K/ x6 R6 j5 V : z6 f5 C; o8 ?9 @5 G$ {! R
for image_batch in dataset:0 u% s6 Y; o/ m; w6 Y
train_step(image_batch)
; Y" t) N6 B" f7 c' o& O3 Q+ R 8 `- `: y9 z! g' i$ q2 I: @
# 继续进行时为 GIF 生成图像
; e; J* ]1 L& }/ }+ _+ ^0 o$ W. O8 T display.clear_output(wait=True)' f6 e: w- j8 Z5 B- `
generate_and_save_images(generator,
! d/ Q3 y+ @ H6 G; {: h epoch + 1,
: J+ `; A# ?: }$ ]* u0 X seed)3 L3 u$ s2 C# b4 O* m$ N5 `
/ B' u2 c" F) z c7 E! Q # 每 15 个 epoch 保存一次模型9 L: c% @% ?" ]" a
if (epoch + 1) % 15 == 0:
6 \+ t' V) O( D6 p: x checkpoint.save(file_prefix = checkpoint_prefix)' B; F6 K# t$ U" w- |$ J- K
% r$ `$ k$ }+ ~6 k print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))* a& D5 c8 L0 D8 B2 u; Y
4 d: \. U1 ~( h% O( ~+ C9 Y # 最后一个 epoch 结束后生成图片
9 c3 z6 O' m9 o: b7 c display.clear_output(wait=True)9 x9 o9 c; M- k |
generate_and_save_images(generator,# W6 }' M9 G" J/ ?# \
epochs,) x* V7 w* }# F# ]$ F! W
seed)! P$ k% b8 X. ^0 `+ R7 V' O& ^7 z
/ D! l, ?7 T. u8 x+ i" s# 生成与保存图片5 B# V9 |7 B! l, Y; f2 J
def generate_and_save_images(model, epoch, test_input):8 ~- ~& h) W0 }% n. y% k
# 注意 training` 设定为 False
* y) T: [9 U9 G. z$ I( _ # 因此,所有层都在推理模式下运行(batchnorm)。" e/ R" }# h( Q* B' ]
predictions = model(test_input, training=False)
) H* K$ m9 f* r0 o& p8 }6 O . P. \% x. n/ B2 x
fig = plt.figure(figsize=(4,4))# e! `8 P4 w# w
+ U o( G3 C: E# h for i in range(predictions.shape[0]):' T6 V9 Q' Y/ T) Y4 \% }% Q# T
plt.subplot(4, 4, i+1), w A4 K. M% |+ e' Y) ^ b, K
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
5 E3 J, S- X2 U( D- P. l6 j, [3 O plt.axis('off'); t o3 [. W0 v; `" E: `
$ i- {1 N3 D: C- Z k4 j& y( \
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
# [3 Y4 M" W( W2 K0 U; {9 ~ a$ | plt.show()+ I# t; l$ ], a1 A Y7 f2 |
5.3 训练模型+ p2 M, Y& `5 L! s' o4 x5 q6 c2 G
调用上面定义的train()函数,来同时训练生成器和判别器。
' k9 N" K. i/ `9 A9 L2 R! ^" J* Q4 [& D$ b" x
1 D8 y% O( M2 r注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。: P9 A* }( I4 y2 [1 t9 |+ ^, J! T" d
% T6 l6 ~! |7 ~; a* g9 W" t$ C1 w" m, @7 z1 J# T- B
%%time
- Q9 y4 k, v6 X% gtrain(train_dataset, EPOCHS)
: W7 Z6 c2 k. n1 f, W4 a% q! C* z在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。+ e) u+ j( m" E' _
+ E: s: h1 _2 ~& M a$ X
9 G$ O N3 F$ Q训练了15轮的效果:9 C! \4 j. F8 M; ^% l7 J7 |
* z/ W: d: |9 K8 S$ [
! [, D6 W" M+ [
: Q9 k9 ]* `% l" v B
) K+ E. y; O3 Y6 q0 l; d4 y( u
9 o: [. V7 }. g4 c训练了30轮的效果:
1 {/ o# k3 \0 K, c7 H, N$ @
1 i; N# B* ^1 J9 S7 \0 n9 w8 C; ]$ k
; Q( }+ M7 g3 ]3 o: u" S) a4 h5 G2 y3 C5 F$ T1 ~: W
" g% D# e+ J( N! r, p9 H1 [* e
2 [& ]+ G/ N( a! d' d( }训练过程:
m4 C& R6 q5 m- }7 h9 g$ ?! R- o+ K. U6 x
" U* q0 S) i( h, l! |
. [/ G N- u$ h3 J" ^/ }* n8 V( `! Z) ^3 Z8 |7 U5 s6 f5 n; `
2 D3 y* ^& u* K6 x" [
" |" q M2 a- x `! j
恢复最新的检查点, o' H X# }# K/ x3 {! Q& b
G, n I1 k% {7 D! u
. `' w% X- ^( dcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
. Q! _5 w( C+ k e; ?( e六、评估模型0 W8 \1 l5 Z, `5 j D
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。; f2 Z+ U) F# m- K) c
& v. f& n3 v( N* j2 ~' ]1 T6 Z9 J6 L9 u# u
# 使用 epoch 数生成单张图片
+ f2 I1 M$ R' j3 kdef display_image(epoch_no):; Q6 N0 S2 F; V6 l' j; t" }/ n
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))( v7 D2 w) _7 U8 k% }: \9 e p
5 m. \1 f, @9 f5 g! ~# l: w* I3 b
display_image(EPOCHS)/ w0 j( w& p6 m5 i
anim_file = 'dcgan.gif'
, |) d' R. j- n$ ^2 r1 l+ N3 h6 d - g8 Q- V6 e. w
with imageio.get_writer(anim_file, mode='I') as writer:8 z5 s0 m0 x1 |& m$ N- h9 p9 _5 i
filenames = glob.glob('image*.png')
; l5 T, j, k7 ] filenames = sorted(filenames)1 v3 ~7 X$ T# V4 f& A) L& `: h( p2 |
last = -1: i; T% _7 `7 R/ V9 U
for i,filename in enumerate(filenames):- f0 ^ I1 V2 ~$ v; p9 Q6 Z
frame = 2*(i**0.5)4 x+ J! H( U- q8 r5 r+ G
if round(frame) > round(last):
v% N( \! i3 G3 N last = frame% w5 p" b+ h7 @5 U
else:
# a7 ^% ?2 X- O& |) q) b3 N continue
( Z7 |. v* g, W$ i5 s* ] f7 o image = imageio.imread(filename)
& n. I8 `' v" ]+ d writer.append_data(image)
# X7 x4 A( Z, ^' }' c image = imageio.imread(filename)) p: {% w8 {0 _
writer.append_data(image)6 ~- q8 U* A$ w0 e) v/ L" T1 M
; V! h# j1 i- z8 O z- Q3 D
import IPython
% H7 a' V# ^6 p; k6 rif IPython.version_info > (6,2,0,''):
5 B7 u9 `, c9 S5 l x! H$ ? display.Image(filename=anim_file)
. g- M& R, J5 k4 x- @* n, v1 h0 ~8 S8 B& b
: Q T0 N* A% Q+ y, \8 t) x0 @
& W- D( [+ V d
1 X# h* Z* x2 K! K3 D7 [5 }完整代码:
! B6 z M: l* Y# \
4 b7 ~8 \; R% o
$ _& C* z; O' W) V5 Kimport tensorflow as tf. {7 B2 B% d* ]+ N) Z4 y2 Q0 R4 o: D
import glob
( f, b9 j& y( A( V& `/ k, cimport imageio4 T3 M V, c c: X8 p$ h
import matplotlib.pyplot as plt
9 S- j2 k8 i- S2 eimport numpy as np
. l5 ^7 _1 K0 h/ Y" jimport os+ Y; z9 C5 ^: _1 _& Q1 B
import PIL
% k$ {1 ?& {5 R7 c7 Ufrom tensorflow.keras import layers) a+ B$ B5 U! V) X: u
import time
# \, N2 V) v9 ?1 a8 q6 Z
! P" }+ ^- Y; P( b( H& Cfrom IPython import display
- v9 ~( Z* [" ~5 W
) I4 o- P$ X+ l% u) \& l* W(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
! C6 P1 b5 a3 b$ U H
- @, X: ^ v3 _2 B* m) X9 n* Gtrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
* w6 d) d/ [- A5 `5 O4 E9 Z; ntrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内/ B" v$ k- E {0 C3 L, _
# {8 P( \) q8 |" L) GBUFFER_SIZE = 60000) ]% Q2 N8 ^; U/ B
BATCH_SIZE = 256
( E7 E# t* Q5 b' N, O, C; N
; ^9 _$ U: F! \+ x9 f0 D/ u# 批量化和打乱数据
1 U, z, W, W1 i# Z. htrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)( @6 c7 h: R1 n+ k' y
4 ^8 ?$ e4 d# P- T
# 创建模型--生成器' S* V7 c7 j2 J T# F5 z z
def make_generator_model():" F' b: n- w( {! b& z
model = tf.keras.Sequential()
$ ~% u" {0 Z$ d" b( p model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
0 a: D+ q2 g% p: x" y! B' l model.add(layers.BatchNormalization())3 \; y( w, H7 _6 g$ r# x! b0 R4 a
model.add(layers.LeakyReLU())& X0 e4 Z8 O T2 h
$ ?/ e1 t1 B, f! q
model.add(layers.Reshape((7, 7, 256))). F. Y* T* s& `1 R
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制0 Z+ h' \3 W/ L$ X9 ?
/ }, D: E/ y: @& B7 ^
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) X8 W4 m8 @8 K$ j) C
assert model.output_shape == (None, 7, 7, 128). u( B! S0 L: u' w% x4 S5 D
model.add(layers.BatchNormalization()), }' p- {' y3 j
model.add(layers.LeakyReLU())
/ ]: `: A( f% _& n. @) e4 B. H t
' p3 [8 v9 O+ `$ r! q. x* ~ model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))% y, l) ?2 l2 O) v1 U* x, S
assert model.output_shape == (None, 14, 14, 64)
3 R7 B' A5 _6 P% O M. Q model.add(layers.BatchNormalization())
' C8 F( C1 W0 | model.add(layers.LeakyReLU())/ D* y; I$ s7 X! C$ F6 {
1 x- D8 X& |9 B H model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
( P+ f) z/ e: ^8 n7 H assert model.output_shape == (None, 28, 28, 1)
3 H( ~3 G( n$ v3 u6 Y
, U* E" n+ z6 n return model0 M: |8 y/ u$ h. e/ a R! v
. w: }) y6 M& w3 _! ^7 K$ y
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。( ^5 w% _: h7 c, A5 o) e
generator = make_generator_model()
0 V# \/ X5 Z) P5 o D9 p/ y
# }9 a' j6 k4 @& d6 z H$ f* Qnoise = tf.random.normal([1, 100]), e; M3 n) C, X! c. d7 V
generated_image = generator(noise, training=False)) `1 V5 Z" J I+ }( d$ {
3 P0 P1 y; Y' @- z; z# J( o
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
e3 ]4 G6 @ m+ w- ttf.keras.utils.plot_model(generator)9 V/ E( Y3 ]/ ?
8 ~( I' b) @+ v# 判别器
' P" |1 }; F) [) l) \def make_discriminator_model():
: u O) ~8 o! b; \9 B* L model = tf.keras.Sequential()6 J/ Q( C! J/ _
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
% s) e6 ^* T4 g. P+ R input_shape=[28, 28, 1]))
+ M+ X8 d9 E } model.add(layers.LeakyReLU())
. Y X4 o, }# _( E v model.add(layers.Dropout(0.3))
- J3 V% ]6 P$ L4 H( y1 i 9 W( Q) n8 I, Q: I/ Z
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))6 J/ R5 x+ W' ~; R! ]8 B
model.add(layers.LeakyReLU())
( P$ B% p* t* k$ C, _ model.add(layers.Dropout(0.3)), w. N+ z) b7 h6 l7 ~
: g* p' E8 o! h' e: l, q# ~ model.add(layers.Flatten())
) X1 O X& p3 M# Y. ~ model.add(layers.Dense(1))/ i% ?3 g, s. q- M. H5 Y$ b
; \2 I4 C- g0 b3 ~; B% [: B0 ~
return model2 j- `+ A/ T- w& @2 K) j
8 A; Q! T+ b. P% w# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。 R4 g4 E8 _# Z4 D- Y9 R* ]
discriminator = make_discriminator_model()- ?2 m! G" {6 k. G4 y
decision = discriminator(generated_image)
; E5 G6 J+ Q$ i* c5 U9 k6 Qprint (decision)6 k1 [1 w% p i
% t8 W/ r, N; b! w# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。) s* |) l+ C2 e' `
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
2 c: F+ m' w& O' u
8 y0 V3 F5 f, L( b; s/ i# 生成器的损失和优化器/ r* ]7 Q* D+ B/ X
def generator_loss(fake_output):) M; v/ h* ~* }/ R4 @+ p
return cross_entropy(tf.ones_like(fake_output), fake_output)& b2 H. R. t1 C, F
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
" c3 A& Y9 B, g; o1 s ! v u7 N9 _& C5 z
# 判别器的损失和优化器; @7 E7 s) e; @' g
def discriminator_loss(real_output, fake_output):8 V" m# n! I8 P n6 m( d
real_loss = cross_entropy(tf.ones_like(real_output), real_output)( ~, s7 W7 C7 G; a/ O+ B) h% p
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
; v- X9 t: w& K& J) p total_loss = real_loss + fake_loss1 L- N" M4 v! I: y
return total_loss# w9 H) v8 L( S+ T
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
/ _( S2 V$ [) C0 \* u, p
7 p/ A! m+ v* u( |" |4 G1 L' `# 保存检查点
8 S& U- Y4 C3 _) K) s/ Fcheckpoint_dir = './training_checkpoints'
+ p6 H" H& g' j6 `6 F$ o+ O8 K% mcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
2 P" B6 s+ c2 x3 x5 ]6 t tcheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,% I/ w" m, Z" D; a+ X; `2 I8 U
discriminator_optimizer=discriminator_optimizer,4 k" Z6 F* B, ?5 z$ R. Q
generator=generator,& V6 f3 n5 z- Y, ^$ u
discriminator=discriminator)
& y. r [. m7 E
# _7 x; Y6 v- `( [9 c& q# 定义训练过程, a2 I' E" d) D k5 l1 L( q
EPOCHS = 50
) `4 ~2 s1 O C! [1 f1 B; T; Y; v4 \noise_dim = 100
( E2 v4 Z6 [" N1 Enum_examples_to_generate = 16
* l- b9 m6 Y' s - G: i/ h; ?2 b5 c- S
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
/ K% Q6 `# g! ?+ M5 v4 Aseed = tf.random.normal([num_examples_to_generate, noise_dim])( x2 u5 Y! E$ Q6 y
: R7 L9 w1 J& N$ C
# 注意 `tf.function` 的使用
7 T7 j: u& `' |, e+ w# 该注解使函数被“编译”
1 s4 b' _5 i! }6 E* t m- o+ z@tf.function
. W3 ]( W5 Q% [1 o. fdef train_step(images):5 R4 r6 d0 Q+ v. v
noise = tf.random.normal([BATCH_SIZE, noise_dim])# _ v+ ]8 X+ b- R8 m0 Q. {
7 I9 S: s w( @( \3 b! b- S* f with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
6 E4 b* Z! ]5 ?2 ] generated_images = generator(noise, training=True)3 |5 e& r% e7 O5 z
_/ \) T/ `1 }7 v( |+ _ real_output = discriminator(images, training=True)
) O! C6 b) ~1 s1 g, |' T fake_output = discriminator(generated_images, training=True)
6 q; W" O% @1 Q1 V' R! P7 F
7 h4 C/ B1 }- L" p+ C+ d gen_loss = generator_loss(fake_output)) T* E" w/ \% e0 h' d$ Z# I/ y
disc_loss = discriminator_loss(real_output, fake_output)
# _: M) M/ P! D8 l+ ~ 1 U+ ~) _/ a; E/ b, ~
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
& E3 n; k, A1 g6 H/ C gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)# w( w, S, }7 j8 _5 Z: Z0 g
5 `( M/ ] P! x+ g/ L5 t- l: l: Q: X* u generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
+ ~+ b9 x: F( {. v; i% e/ v1 Y discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
% o% Q0 B/ z# I: V1 G - d$ g( T- v* S! z' p) d- \
def train(dataset, epochs):
4 [9 D$ ~1 {5 n& q) L Q2 \ for epoch in range(epochs):
6 n3 n3 Q0 w$ f( C$ x$ Y# V start = time.time()7 u% h: @8 A8 e% m1 z7 D. o# j7 [
+ E8 d+ x# x) ]3 U$ g. ~0 N9 l for image_batch in dataset:
, k' \2 T# A# h9 a# G7 M% d train_step(image_batch)
/ q- K w5 Q! s6 U6 {
, v' H' Y' q) W9 I) v- b9 X4 W1 | # 继续进行时为 GIF 生成图像. T0 \( [9 ?9 o9 _/ t: ^3 W
display.clear_output(wait=True)
) g _2 q1 Q! Y- O& f5 V generate_and_save_images(generator,
, B& J7 s$ J( {2 T* z epoch + 1,
! z* g- K! z' M2 d3 I seed)
. K8 h5 j# l: Q0 x9 E5 H. E1 _ " r# s9 P- V1 _ O8 a! {, k! X
# 每 15 个 epoch 保存一次模型: H* F0 h8 ?" q# |! [8 Z7 o
if (epoch + 1) % 15 == 0:& p# ^/ M- R7 ?
checkpoint.save(file_prefix = checkpoint_prefix)$ F7 p1 n9 T4 o/ w
% \" `% E5 m' B# B* i. P' B print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
6 i7 s. i( F3 u: c6 v( q 7 j4 C* j9 j6 }( c% r+ I
# 最后一个 epoch 结束后生成图片. M1 p' L% w2 A/ C2 c- c( Z" H
display.clear_output(wait=True)0 {; t5 K }5 ~1 I% I
generate_and_save_images(generator,- H V: g' {) h
epochs,
% U) [, I2 c( C- u7 U seed)
0 l( g% K) E/ M- A6 J ; h& C, F- d! M c
# 生成与保存图片: R' x( A2 f! c2 q/ `+ Q
def generate_and_save_images(model, epoch, test_input):- T$ _" s9 T \
# 注意 training` 设定为 False; F D5 D: y; f0 `: f+ q
# 因此,所有层都在推理模式下运行(batchnorm)。2 }+ s7 X$ j# p& O7 Z0 @
predictions = model(test_input, training=False)' B* H* S/ E- W% c L# o
^* n$ W& Z- C% `( F' T; ~
fig = plt.figure(figsize=(4,4)); {' {% I( _% Z$ I) \8 t
) @! K" Z8 Z% A- X- e' R5 _
for i in range(predictions.shape[0]):
. p0 z$ K) e- E$ X5 D5 N6 B plt.subplot(4, 4, i+1)+ `$ o* v4 v, A& k) Q- S8 i8 M+ y
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')/ \2 W" B2 q5 M, f
plt.axis('off')9 l# V+ T0 @$ t% n7 l
! T4 K7 h0 S4 D plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))9 q' L& y- M; U& n8 T
plt.show()
" `# I) {3 G m 7 I# K" H, ]) T1 O4 q2 q X2 i
# 训练模型0 ?: U. L+ h; Q/ w2 r/ F6 g
train(train_dataset, EPOCHS)
0 u4 C) `' q& z) X/ c1 d9 X 2 o, h6 V& j2 y( h) a- P
# 恢复最新的检查点
6 d( y5 l$ R& ]9 G: P" ~$ _) Tcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
+ o9 s6 v: N$ U2 w/ T
+ ?' B; F; E0 x( v( `# 评估模型
. s5 J8 P/ Y' F e# 使用 epoch 数生成单张图片
5 ? r/ }: X) ddef display_image(epoch_no):5 t7 Y$ n0 c w" a G, Z
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
2 r" V$ T8 ]/ |2 { t4 f
' {) c* q" W" U" p, \3 d0 N- v( I$ ^display_image(EPOCHS)4 V2 w4 N; q2 H
( p! B. w% c- x0 |
anim_file = 'dcgan.gif'
$ u2 n, J; c* y2 Z4 o6 O
5 a' W1 s0 p9 d: k5 ?& Zwith imageio.get_writer(anim_file, mode='I') as writer:
1 j9 d5 ~: O O0 U6 h8 `# y filenames = glob.glob('image*.png'), T6 T2 ^2 g! h' I, C& Z7 a
filenames = sorted(filenames)
# d! V; R7 j4 X* n0 D! ]6 s last = -10 U+ |% j& b0 e4 ?0 h4 ?- C
for i,filename in enumerate(filenames):
, L( I3 t( r: p7 _ frame = 2*(i**0.5)
/ k5 I/ H& X: `9 o if round(frame) > round(last):
3 P# Z- k5 e) M last = frame
8 d: t2 G) y( L( d5 S else:
. B, @' Z! z4 Q, G1 S; H continue8 M" n5 `6 r) K9 [ ~& e8 N
image = imageio.imread(filename)
; O" O3 K+ g2 T7 \) v writer.append_data(image)
' A9 n* C# P- C9 i image = imageio.imread(filename)! h6 a$ F9 v% e7 P4 W' f. ?. j
writer.append_data(image)+ ~. p/ u( H3 M& M0 d
$ F) {# A+ c" Z. s9 I
import IPython9 p5 d+ |% ~" C% }+ m" f+ l g
if IPython.version_info > (6,2,0,''):! z7 H! K, L8 _8 [+ V8 H; ^. |! T& h
display.Image(filename=anim_file)* J& n6 A2 p# h$ P( k; ~% l7 Q
参考:https://www.tensorflow.org/tutorials/generative/dcgan; `$ c8 \4 g9 ~/ `
————————————————# |- [" C# T4 ?) r
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
0 ~1 c$ `8 ~6 h原文链接:https://blog.csdn.net/qq_41204464/article/details/1182791119 S; x( i8 x9 z! m0 n& s
; L' h7 z2 @2 p
g [, h9 K' }
|
zan
|