在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 558984 点 威望 12 点 阅读权限 255 积分 173068 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 18 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
; H. L) [9 R0 R) Q; P: O 深度卷积生成对抗网络DCGAN——生成手写数字图片
- h! F8 V% {3 K0 f 前言 1 n0 Y1 F- j/ s
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
0 H: h8 \: E* W& L- {# E
4 r5 \ _) P8 @- C9 h3 R; T
0 e, z9 A+ N0 D2 Q1 P: M 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
# O/ Q- I G6 X# q/ R
8 S3 d) x7 K* J; a1 z5 p
0 m8 G& N$ \5 A2 i+ K # 用于生成 GIF 图片
& N- n, b" S7 R+ M pip install -q imageio
( r+ R$ y: W. Z8 j! M 目录
, w. |/ z+ D/ g; ^: N1 h9 M
5 L; O3 ~2 ^- ~ 0 X V6 P- L/ v2 S1 [
前言
" S! d: R) b4 o9 \0 k/ o 1 y9 ~' J g8 o3 R. U* _
. @8 f8 ?3 e6 g/ I) L9 U
一、什么是生成对抗网络?
# K9 D$ n1 L; u4 |# J" H& x3 N9 t: G 8 \/ d* b/ z7 W
- l% d& N! L( U" m 二、加载数据集 + T0 E8 [$ A) u( f v4 k9 k$ K: H
) M, D( _1 \& }. | L/ S( s4 y. @' n
三、创建模型
& f" B. `' A+ ^3 Y- f) G( T1 b
6 ~5 K' r9 @+ O) W/ z/ y4 h: F 2 d/ A; t5 |: i
3.1 生成器
/ ?+ T. f3 `) t+ G4 ~& M 1 N1 a+ }0 j! H; A! {2 A( o7 w
8 f. n+ k8 G* j$ |! d# t 3.1 判别器
+ z/ I$ O! k& p, d, j Z' i. f9 G2 J" B, c/ T+ ?& x
* Q. O. Y+ \5 a6 | e- X2 `# T9 L 四、定义损失函数和优化器
. B4 d3 v( R; `! D/ t2 D7 s9 X % ~! ?* p& N r9 g0 S/ |" e: M
+ I* b2 H( L. v T3 a: s( s0 ` 4.1 生成器的损失和优化器 $ B& W F) S/ q" m
) S4 Y! V0 Q5 H0 p- ?- \& ^
: O; j3 \. P/ n" N1 n
4.2 判别器的损失和优化器
8 n# @3 Z' k8 w& v5 O2 i" I # c# E6 w# `2 u d9 ]1 D4 E/ c
+ a6 h- h" U7 G2 a 五、训练模型
2 _. l/ N) t5 j* K+ v + m: f; D' Y3 |5 [# _
& f0 n! M: P. [" f0 Z0 w) D 5.1 保存检查点
8 g% X3 p/ {0 I w' H) p 4 F R/ S* s; {/ k0 x
- |' ~/ c$ x% c. P* U# V | 5.2 定义训练过程
" e, \* j1 J+ s3 e# y 6 e3 W2 P0 @. b& W# e
, F& ~7 f. N( t" `, v
5.3 训练模型
s% U& q0 \8 x( C8 j, f& t( V
% ]4 p# \- T; C5 V
% J: g9 i" w% M 六、评估模型 9 y3 z" N+ T! t8 z9 b U
- I# f) |& \) c
" I u5 Q" `' R0 D8 A1 U* F 一、什么是生成对抗网络? " f* W* `( Y' a: I
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
# v( s, G; K; G9 y1 s0 D1 @& c
9 q6 E- N# l& k( X
0 W( s: ?$ @, H$ f" G 生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。 - r; S R; K# {
- o. X" Q Q& D- j4 i+ H. o% U
+ [1 ?' j9 w d) y, n
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
; d ^. T3 v$ h% U
" b/ O- g& B% Q6 z9 W; Y
: |, Q6 @3 `+ h* O 训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
/ I8 ^' |% \& G N: l# `% s
d4 k* `; g# A7 J+ A
8 {* L; _0 x q( ?: i; d 当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
7 i8 D/ k( H0 o7 } 3 r# k5 G* s% N+ R1 Y m) \- O; ^
V) B4 s5 V' R. h0 E3 W 本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。 6 u+ D! \5 Z0 C
5 f: \% a o h5 z. _- B2 C4 y% N9 n3 c ) d; G( D5 i5 s" _
二、加载数据集
4 P1 @" _8 Z* f 使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。 ! A) l9 `5 b F( ~4 [' S! u7 x6 U
9 r+ b) Z% @- t: Y& X6 R
* a5 l( U1 W! f$ M' A9 f
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() % K) ^2 o8 `, t- j: C& k( [" B
8 f+ W/ e9 W: u/ d/ ^, G
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') 1 O6 E$ [$ P. P& {' w/ n2 \$ B
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
+ P0 ~) j# T* ~3 u
0 L+ a i$ w; [1 Z$ G BUFFER_SIZE = 60000
4 _7 C+ n g/ P- B* K0 Q- R* E( J BATCH_SIZE = 256
5 \ \/ ` s" j( j7 C& n ' z9 e1 \5 V8 M. y% O' P* a) [4 Q
# 批量化和打乱数据
9 Z/ e" \. a2 ^( k" D( t train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) - V8 A9 z0 H- F5 i
三、创建模型 8 z& c! Z; S) ]) N( |" K+ c
主要创建两个模型,一个是生成器,另一个是判别器。 4 z6 c) d' Z6 Q" H. d1 k
/ G5 z5 D; Y/ J- r" v7 r
4 d" I5 M3 { ~0 g4 i) D( T! b 3.1 生成器
" P3 P/ R+ e4 d 生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
$ a6 M C3 g ~5 U2 ?0 G
. a6 {3 V9 ^+ v/ o& t2 J8 L - O) @, f. ~" a& E
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
F# q- m- y; ` s" A
+ L @: P3 e* j1 ?# K* h " Q" F: L R2 |# j/ E R
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
4 g0 {4 v' Q, F2 G8 J! v( O
# T# Q% c3 G% S9 V * Z' Z, G6 ~: x9 h% C. {6 M v
def make_generator_model(): * E. B5 s! J( T' s+ y
model = tf.keras.Sequential() 7 m' s7 v$ B0 p# i- \0 n2 u: Q+ `
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) & w1 Q& i3 p& [: Z2 b
model.add(layers.BatchNormalization()) + r5 i- X* M# N7 T! Y( B
model.add(layers.LeakyReLU())
# U! `! U4 d8 m2 A" a0 u $ g( b; x( R# ]+ Q6 Y
model.add(layers.Reshape((7, 7, 256)))
6 J/ y. U* E8 q9 c# w, D assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制 + i/ i; V3 y1 U: X9 P. B- [
: w) A7 d) A7 [/ I- _* T model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) 1 X/ K5 a" g, s7 e, t$ y) E
assert model.output_shape == (None, 7, 7, 128) 1 ?/ d: W$ I+ w5 h
model.add(layers.BatchNormalization())
: N7 a# Q: j+ k( z/ z$ K: `0 b model.add(layers.LeakyReLU()) - d6 D A' ?: P" @- n# X: n
+ G! k3 P- E* Y t* S2 N model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) * B, M" Q6 w6 S5 i* [
assert model.output_shape == (None, 14, 14, 64) ' n/ o" |9 `& J1 l# V
model.add(layers.BatchNormalization())
9 l4 k: a5 ]" X. ~2 T model.add(layers.LeakyReLU()) 7 x4 y9 E% a. `9 [/ U" Y$ u
6 n' E9 K$ L5 t$ m B model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
* ?! [/ M0 K6 r/ }1 X2 ]/ e2 g assert model.output_shape == (None, 28, 28, 1)
! V# ^: @' @ f) J& J# p9 L
, Q+ U8 Y1 b2 P6 }1 w return model
/ c# X. K7 z7 | 用tf.keras.utils.plot_model( ),看一下模型结构
& b. \) ]- @7 `) W. k
" i1 o* @5 E6 w5 o
0 O+ \; j6 d. ~6 ]9 z2 L) Y$ l5 q
; d. A P3 q) E0 s( i2 [% G! w
" I& u" T& P" Y8 g; O
! n5 t4 }+ u C/ L7 W 用summary(),看一下模型结构和参数 % U! F! m6 c! K6 [$ K+ g! S
" g2 J, N, |+ e) g4 g! \, G; V$ V
7 ^! D" J$ |; s" X
' n7 O/ M. h" V9 l2 x# E 9 `. O5 { D1 m3 \
3 ]; K% z0 U p0 }
5 A9 e/ Z" W4 x. l" R 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。 ( T3 H' k7 d1 a# Y, A) ?( i
. [# B; T# P* z M
$ I% ]+ ?! v7 B# L: y6 p" j generator = make_generator_model()
! `% u7 ^( e$ K* S7 ~
" j! o% G$ ?3 W- H2 E% x/ n5 B noise = tf.random.normal([1, 100])
9 R. i* n' L" p# p- z/ a$ Q generated_image = generator(noise, training=False)
* y. i9 n2 d+ z 1 g" p- E( @) R! V/ B
plt.imshow(generated_image[0, :, :, 0], cmap='gray') , M2 H# v) y/ Q: n' O! f( Z
8 Y9 i. F! \! y# `: {/ e* n$ s @
' x d5 Q; S' T& V% q+ K% @: L8 s* w
6 g* `; |4 A+ L O* A
% l3 O( @9 m9 N. z% f9 L, e 3.1 判别器
( l; k7 h6 L) `' J' h9 n: ? 判别器是基于 CNN卷积神经网络 的图片分类器。 9 {3 q* A" ~! o+ h
- J+ B5 J( H' t; a n
( m j; M- O; i9 x6 C4 u
def make_discriminator_model(): ' B5 @' k& c/ Z0 B* j& _3 {6 i# [
model = tf.keras.Sequential() " k- E1 t) M. F# E
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
- ^& n; s, d+ c, R2 K' N input_shape=[28, 28, 1])) , ?1 F* Z, i$ f0 q2 v
model.add(layers.LeakyReLU())
: n# p) Q: m' ^8 M0 q model.add(layers.Dropout(0.3))
! ?) x* W+ ]$ A7 o' N' c V, ~: K) F% f4 i/ {$ z
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) & t5 n( V$ x: h7 [3 E/ e/ N
model.add(layers.LeakyReLU()) # N+ J3 \) N0 _, T, [$ o
model.add(layers.Dropout(0.3)) . n& p1 O/ b. a# H; W" G) y/ I! E
, p8 L% M1 [/ e model.add(layers.Flatten()) 5 D' O& i0 J) a8 I+ P3 }
model.add(layers.Dense(1))
" J* w( m- I+ c' h/ M* r
) |6 N7 b. k6 w( R1 ? return model / b) i2 p6 ^+ \2 k" `5 V' D
用tf.keras.utils.plot_model( ),看一下模型结构 S9 G; ?. x" `6 O& I
5 i* s* a. o7 ]7 W: U! \; |+ Q3 w : [4 k2 @3 V1 B( Z
- n# [, r2 c, Y6 B) @ S& s" x! d
) q0 l& y8 v1 ^3 V n
" q9 j3 x$ q& n! Y
1 e+ P( z5 b0 w+ @ 用summary(),看一下模型结构和参数
! I0 P4 @% A4 P' B* t; e# j
1 S3 M8 J& A; }% d9 p
; _" c' h! {. G4 x # P' j' w" v& f
6 P: l, t' }8 D" S: o" Y" o) I
% J+ S, f3 D2 U0 c6 E7 f
( u( R4 |, X" A, p 四、定义损失函数和优化器
. c9 I/ v- }5 e( k3 ?& z1 x 由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。 0 i' E4 Z# r1 I9 S* {/ s
! g! W: K/ Y! z1 A 7 t& E+ O! O6 @/ g; W. w
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 5 w" k* j6 U5 Q1 `1 v8 Z, z4 x7 Y- p
2 |( v0 D7 |& R- N! n- q* X4 g
5 P8 c9 c$ ~" B" @* a8 y2 i1 } # 该方法返回计算交叉熵损失的辅助函数
# u# W- e9 N* V. @ cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) , a2 I' q: B5 W
4.1 生成器的损失和优化器
/ X$ T- ?: {6 Z' g; g; M+ o9 f" U 1)生成器损失
. H$ j( }2 p$ a. m2 Y: `
8 M& S; O6 n" k. ?; v( d
( U" w) P* G, v 生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
; x2 `' I9 `$ ^" H# H
: M, _6 _. `! P* F
3 o) I4 o) y* m. C# V+ O 这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。 $ i8 _+ j! k4 M0 f7 ?/ W# d
; Y& M% s* T) h, B' T; m
% p* c! f: L' W. ?6 X5 ~& e def generator_loss(fake_output): ! d# U1 y3 {; M7 p- Z
return cross_entropy(tf.ones_like(fake_output), fake_output) ' W! {; p* z8 `7 S5 H7 G0 h
2)生成器优化器
1 F3 P1 J5 O5 e- J 0 D3 V. F1 R/ ]2 G
% t; f2 _- t3 d6 J' N
generator_optimizer = tf.keras.optimizers.Adam(1e-4) ) _7 W! Q5 g$ W
4.2 判别器的损失和优化器 * z+ G: C& F% X6 H3 R- w
1)判别器损失
' |9 P7 m; g' R, W1 R" W( { ' o3 M! N# t1 \# g0 T7 c
7 O n7 k1 i1 L; i/ r 判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
. S w" g7 m: Q1 p3 T( } ! ]5 Q, w% j+ s, s( a
5 w+ x5 q& q$ r- v
def discriminator_loss(real_output, fake_output): : s5 _' l! w9 Y+ M; {0 l# B
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
% I T( @" x }, i fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
1 e' T- h& Y& g z3 Z2 x total_loss = real_loss + fake_loss
, D5 d8 m1 b1 `( D$ Z5 u9 F6 e return total_loss
3 d) R. ~8 G" \* o3 `. [& A 2)判别器优化器
# _- o0 B9 b5 g6 P0 P3 i7 _
1 M, u% R# R$ w3 C: I. C ; u ~7 D- K- Y0 n. S5 b5 Z
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# C$ X" U% n0 w2 N; n3 a9 t- A4 l 五、训练模型 / B. g% @9 {) s5 `. p" D" a7 G, D
5.1 保存检查点
& r3 W3 Q/ [! H9 v! V) ` 保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。 ( Z/ W& G& j/ t5 K
/ B1 G7 v" }* \/ m # Y# h/ A+ o; v
checkpoint_dir = './training_checkpoints' 7 w `# v9 z' X9 U7 a V
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") + S. [4 T1 S" W8 g6 M* ]
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
# w) A4 y* M0 [ _* N! n. } discriminator_optimizer=discriminator_optimizer,
, p, Z! b: a/ w; q, I+ G generator=generator,
" _% \/ f. e% B2 w, s) z discriminator=discriminator)
0 b2 @3 ?* M8 `: K( ^7 ?5 H 5.2 定义训练过程 1 y+ ]) j0 W: E( G
EPOCHS = 50
5 @- {" j V! c) G noise_dim = 100 % A# |: a2 i1 K
num_examples_to_generate = 16
/ e1 K) m( t2 Q 1 r$ r% T: j5 X# R
- J6 w! x8 `" s6 ^2 \8 E
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) ) i% c- E& o; H' o9 G- x! m
seed = tf.random.normal([num_examples_to_generate, noise_dim]) # Y+ P, L2 Z/ X8 K4 F
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
! [ x' l. ]" F }- S4 i ; m/ Q! K2 | h( |; s9 S
; S" e0 B J7 Z4 O" r( _
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
0 @* s, }2 @$ i
9 h. U5 ? y4 n/ r4 s! C; g 9 W7 [0 _7 r2 R/ B7 U& l, `0 W' z
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
# P' U- O8 d* @
6 j U0 |8 t8 u8 L/ u! V+ n O / D, u# e' ^6 b+ }5 U* ~* Q1 j) F
# 注意 `tf.function` 的使用
* v$ Q! C- k j7 [ # 该注解使函数被“编译”
8 ~, _0 v5 z1 l& p6 j @tf.function / j& c# f. r: b, [% a% H1 e% v
def train_step(images): & w4 l+ n2 @. c8 O; w# g
noise = tf.random.normal([BATCH_SIZE, noise_dim]) # z; \2 F2 B; F2 C: V1 @
: S+ C# c. I+ I$ t' x; L
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
5 X0 {9 q7 a4 k4 } generated_images = generator(noise, training=True) 3 {4 N. _1 d F4 }6 D! l( M
" s g; [: _- J real_output = discriminator(images, training=True)
, E! ~! S. s6 f+ P- @- E Y$ w$ K fake_output = discriminator(generated_images, training=True) , U) e2 ]! S/ A0 j
& T1 e' \4 U g; @- K8 c2 ~ n
gen_loss = generator_loss(fake_output)
1 d S' P. `: a& t+ ^ disc_loss = discriminator_loss(real_output, fake_output)
, w B2 l3 W+ E2 {6 E
' Y/ A1 E' o/ Z+ L1 u, _+ { gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 9 t& C% Q. @1 Z8 F$ k
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) - [ z, F, P0 z ]# }) v
! n6 H! X) j8 w( W3 \% t generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) ( c1 ]7 X( c4 J( c, `# V
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) & e! L9 {' l. F: b7 [
8 u/ u! h0 m6 V$ q5 b- L def train(dataset, epochs): / o; R/ m; g! X& h; o) ~! I
for epoch in range(epochs):
T9 n. b& m8 o# G4 c" G0 t/ j start = time.time()
, a$ x g9 _: S! g3 n. M
5 Z! F! }6 ^1 N) V0 x( n+ P for image_batch in dataset:
9 B9 `! R1 A1 H" n$ q/ t9 N, [0 b train_step(image_batch) ! y/ M. O5 Z$ y, u5 s
" Q b4 ]4 }0 l; a9 S # 继续进行时为 GIF 生成图像
7 _9 y3 ^) A6 Z+ v display.clear_output(wait=True) 2 z6 c( x0 Y# e' \8 Q( z8 b
generate_and_save_images(generator,
6 n& m* j: ]3 u4 y% _ r& ? epoch + 1,
0 z0 c! Q, C7 {$ B seed) , X4 p% C) U: k0 T$ L
: k# R6 @7 c1 M! C# R7 C; _
# 每 15 个 epoch 保存一次模型 0 m; A1 o. ]4 L/ F' a
if (epoch + 1) % 15 == 0:
8 n: w( }' G7 j+ ~0 q; Q6 t& I& M9 h checkpoint.save(file_prefix = checkpoint_prefix)
/ [- L. |# G$ w5 \# Q$ }( m) v 9 x7 w# Z6 I1 w- d$ |
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) 4 C" z/ Y+ G$ i" W$ r5 G6 f
, O# \* |6 l( c( k # 最后一个 epoch 结束后生成图片
; J% \/ b9 u Z& z C% ? display.clear_output(wait=True) o6 L% c0 s: g! T
generate_and_save_images(generator,
. ^. ?0 H. Y" p epochs,
( T& ?; k; l% f- q7 R$ i3 m seed) 4 |1 S& ?9 Y& F3 h% O9 w2 _
0 |# {2 C+ Y1 S/ `/ w
# 生成与保存图片
2 g, I' [$ z3 s) |/ M; ^; o def generate_and_save_images(model, epoch, test_input):
, f; `6 k) T+ W) C9 e # 注意 training` 设定为 False
; _( L' Z& o2 d/ n* G # 因此,所有层都在推理模式下运行(batchnorm)。 8 f) n- ?* p1 v4 {: N
predictions = model(test_input, training=False)
7 I0 y; T J- k0 R
n) h# g! z6 i3 c/ y2 i. D7 i) Y fig = plt.figure(figsize=(4,4))
5 f( F) ]9 j' o' u* V4 Z, g) Q% P
/ s4 E/ v+ U- W1 h9 ?9 [ for i in range(predictions.shape[0]):
! \* k5 f2 l; f; e plt.subplot(4, 4, i+1) ' b A) s4 A! W! O" ]
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') - X( t$ l! Z: i) {$ |+ N$ o% V
plt.axis('off')
) o# y3 D5 T {9 Z g) j
. ^6 g. g" }' b a plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
' k) F+ Z, B! k" i! X9 j plt.show() p1 E4 i# L& A' a1 ]- K
5.3 训练模型 * w" U& T: z9 R& A' s
调用上面定义的train()函数,来同时训练生成器和判别器。
# G% f' S3 W' o; P2 r 3 P2 T) p) K+ M8 x& D8 q
, L9 u: G4 c: ~& H! ^9 v) t
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。 1 H* A( Y9 N9 r! L
- ~) c5 }+ i, x4 r; Q. K
* d0 }$ a5 y) ~5 v8 L %%time " S+ M! S7 j, y. {; d
train(train_dataset, EPOCHS) 4 c% j3 }* S* h' g) h1 u
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
/ O* k( v# ?( Y$ }9 J
- n' [# W+ s. V4 p
* D* j* t J1 L2 u6 t6 h/ B 训练了15轮的效果: : ?. p' q) [0 X5 h, q) v) O5 V
4 `7 t, @1 t/ L2 E 2 n3 ]9 G. p. X, X* f. d1 b
3 x4 Q7 I8 s8 m7 I- F
8 U" N% }, F% `9 q4 Q5 n2 X
8 i, s7 n- z$ t9 _! A* [0 I 3 H. L- W- u9 i6 h H
训练了30轮的效果: + Y2 ^% j! G# R$ m+ E
7 m8 U8 S8 T) @) o$ I/ h2 F! F
& ^# U/ w2 F8 O; b# n+ X ; j/ p4 r( Q! C* ^5 B
% w* G# k& h/ B5 {. r# ]
2 A# P+ r1 H7 r4 l# @# I
~/ ]7 p- z6 ] 训练过程:
6 r4 E* g! P% \6 g' X8 L, F% O 8 B( o# j( o2 v6 s' |# N0 U
% H+ W6 e+ P) Q- h & e- V2 k6 l5 y+ O; Z9 X) p
2 R% B% }( c# s. j/ G# ~
' w4 k. @# H& m* n3 r $ w/ C0 A7 z4 _. U5 q% c! v1 T# I
恢复最新的检查点
. F- m* h0 E. ]/ H% u k; X : B! }. r9 F- c/ E
4 C% b9 D2 U6 `3 @7 k checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
6 ^8 O; }( m( t 六、评估模型 % S2 {8 L- S" P9 F( m5 N3 W
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。 " w; E7 a. T% D( [! m
9 l( D7 @" b7 H2 P. @. o- } C # o! Z4 z" |- n2 o; V
# 使用 epoch 数生成单张图片
7 O p5 E- a7 ?% l4 _2 g( e def display_image(epoch_no):
6 E9 r% V# S6 ~9 e7 [3 K7 Q% L5 e. x return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) + c6 k2 m9 l4 Z
& R7 a, P! j4 k/ c: v& i$ J
display_image(EPOCHS)
/ i! {( |$ _' f2 P/ J$ P anim_file = 'dcgan.gif' % V" h' `5 P$ C% L4 u
. @6 h% M3 c3 r; ^ K; b6 K" R2 ` with imageio.get_writer(anim_file, mode='I') as writer:
6 U, z$ L7 K* t. a filenames = glob.glob('image*.png')
3 F, A! P& {* ~7 a* S; A filenames = sorted(filenames)
9 k4 c4 D$ m# c, j9 ?7 l last = -1 ) E/ U5 r, X# A7 p) K* z! ]
for i,filename in enumerate(filenames): / \& D6 n: W0 P
frame = 2*(i**0.5) 1 t+ L5 h! a& K3 D' Q7 J3 i8 J; \
if round(frame) > round(last): " o- X7 O! a, M m1 J% G! E/ I
last = frame
4 ]) R7 B A" {1 { else:
5 t: [: @+ |, w2 N: ?8 V continue
# d" j4 t5 f; {0 c! m image = imageio.imread(filename)
8 c* D: A7 x2 }- J) M f2 @& b" z writer.append_data(image) # \: I2 |. ^$ { w7 g8 G
image = imageio.imread(filename) # p* v5 P$ Y4 |
writer.append_data(image)
1 e1 M3 ?) I8 r8 u ~5 R
( S0 z8 G' s: ` import IPython
- G) }: r6 R1 R! @* n' N if IPython.version_info > (6,2,0,''): 6 C8 ]3 D& D) V, \2 m
display.Image(filename=anim_file)
8 J g3 Y; z& E 7 N/ x# [: d3 C- T r7 J
8 p4 s* H3 y* ?$ V1 }6 ^7 j a& o
; Q! w: h# P9 h$ q
8 ^8 N5 T+ `- o 完整代码: " l$ r2 a: |" u1 a
( f5 X8 i, x8 `2 Z' [ $ S' D% k3 J. M9 |8 O( X
import tensorflow as tf
# ]( A: U& z, k5 |* i! `' c import glob 6 R6 W/ q( h- E6 F9 O( ~
import imageio ( c9 \( R% D6 j6 G+ A7 {! U+ G0 F
import matplotlib.pyplot as plt
. [* e" N- B+ N D+ {! \, W2 o) p import numpy as np 4 k6 H! ?2 E* K* T
import os
2 v8 v7 s% T$ v& ]& L, _ import PIL
9 q$ Z" J1 [' ?* h& W, q/ N2 d, D8 ]4 a from tensorflow.keras import layers 7 B0 g R# v; Y2 a" g
import time
3 y1 `( I4 s: Q: G% N0 L7 Z$ l 7 ]7 w: A5 G8 k8 ]. }6 G
from IPython import display ! W/ V9 ~" M; F- {& X
# D7 Q3 J1 e" A! [6 I9 I
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
, Q4 b8 d0 m" D; e. K% M
5 G4 M$ i& ?4 k- S/ k0 k0 ~ train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
/ f5 f3 N' Y0 v( y2 m: e train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 ' S7 m/ B; V6 a Y( ?2 K* G
5 l* x* B) A; A }3 e% J
BUFFER_SIZE = 60000
! m1 B% m8 b+ J8 V [9 P BATCH_SIZE = 256 + ~0 N! [8 z9 R: \. c
$ L# W9 E9 g& W7 o9 \; R2 M- p # 批量化和打乱数据
2 h# I, p+ j6 e0 I train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
y8 e; d3 K7 i; |' P' N/ ~
3 }% D- P0 u& x # 创建模型--生成器 $ h {+ w, b, f, K' u
def make_generator_model(): + x- p+ Q( q$ ~- k+ U3 ?
model = tf.keras.Sequential()
0 I1 y2 w# D' Y: [4 n9 c) s: _ model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
1 J% @) ?: q" X2 {6 {; \ model.add(layers.BatchNormalization())
* t" d% Q) e& l5 i* n0 Z$ j model.add(layers.LeakyReLU()) ' z$ j) d. E. G; G
8 f4 {" i- c- j8 H; C5 V
model.add(layers.Reshape((7, 7, 256))) 8 p; A/ M d2 f; w# I, b
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
; P: Q- H1 y* O3 Q
: [; F3 Y" l# ]: {/ z" P: } model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
) Y! ?, g# w9 t. R5 ~( i* ? assert model.output_shape == (None, 7, 7, 128)
- {. K2 {' \( a- Y model.add(layers.BatchNormalization()) - j. L4 U0 t7 f8 q
model.add(layers.LeakyReLU()) 0 \/ t' i$ Z: y& B6 L; j
' r$ d' d; B5 I. y. E: d B
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) ) ]. L0 S3 F, X& Z" V) L
assert model.output_shape == (None, 14, 14, 64) % z N! ] k R8 @$ D
model.add(layers.BatchNormalization()) 2 C- m; L q. i9 u, c
model.add(layers.LeakyReLU())
" M$ `" D3 @* w9 d
; d; D7 ]. L+ K model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) 4 q# g# D7 R3 g- h! m
assert model.output_shape == (None, 28, 28, 1) 9 |: r2 \4 S$ |; d/ B
" z. C$ D! [3 ~: E; ~
return model
9 z! b, Q. g7 ?4 J- h1 u0 @
! c& A3 W! L# {8 l # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
, e+ v7 ^- _* y% C& h+ t generator = make_generator_model() ! ^$ \7 F; e: T" Q
* |$ A# ~ ]% B6 k
noise = tf.random.normal([1, 100])
) V1 r$ F5 O8 r$ q9 E generated_image = generator(noise, training=False) # i0 o; D k8 L3 n5 i" I( j0 q
% Z- q. e* r6 f4 A, i1 ^3 g, } plt.imshow(generated_image[0, :, :, 0], cmap='gray') 6 [7 A- y t. Z/ F/ e
tf.keras.utils.plot_model(generator)
6 V! F4 }3 V5 } 5 m! W5 h0 _- l
# 判别器
2 w9 I1 J- a0 s/ ?2 M def make_discriminator_model(): / B9 @0 z1 V M: r# w Q) y
model = tf.keras.Sequential()
; J4 e- K, F4 {# s model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
/ V" x. o! s3 Y, Y' A$ H input_shape=[28, 28, 1])) $ k6 Z X6 y: ?% h
model.add(layers.LeakyReLU())
7 _+ ?7 G4 \" a) ] model.add(layers.Dropout(0.3))
& E2 @+ g$ u6 L1 n9 }
3 V/ ?% B1 ?8 r) x model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) 6 A! D3 l# c& o! w
model.add(layers.LeakyReLU())
* b! F! \, N9 p model.add(layers.Dropout(0.3)) 3 A) i( e; d7 p2 J( U l* l
/ Y9 {: \# f& W6 |0 z7 e9 Z" H model.add(layers.Flatten()) 7 r. }& H+ S3 F$ J2 n
model.add(layers.Dense(1)) ; ]" b% }" R( b l1 J
, @8 ?% l; V% @9 [ return model ! m' }" E5 H! X8 C" u+ ^; o$ M
, H8 v& K+ T, i8 A5 A& x: g6 w6 q
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
8 c9 s9 F6 C) s4 t discriminator = make_discriminator_model() 1 O' o. g. E& d, y% H- @
decision = discriminator(generated_image)
; O0 z$ s: z/ V, B print (decision) 6 J8 u) V6 t$ @/ H' N1 x7 h
0 Q& w5 {! E; ]/ l) b
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
: l3 E$ ~4 w! {- l6 R" h1 D1 f/ r cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) ( ^% A, A4 w3 D- s, v/ ~5 M1 S
. \1 w7 x' D* X
# 生成器的损失和优化器 $ b2 X: g" w3 j6 u! ~
def generator_loss(fake_output):
( y& K" m& g3 j% {" E9 h/ |2 o( M( r: r return cross_entropy(tf.ones_like(fake_output), fake_output) 1 }; @* l5 o8 N! {0 E/ a7 N. q
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
Q; x/ ]0 C2 H0 x . |% X1 q& V0 R- @+ V& w
# 判别器的损失和优化器
X$ J! ^: O% Q/ }7 W8 n def discriminator_loss(real_output, fake_output): ; K+ q% I6 j8 O+ Y3 u0 K# e
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
6 V* g8 Q* Q) J: _ fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) 6 |+ @8 g2 y. o. M( ^2 p" O
total_loss = real_loss + fake_loss
1 t; X! @4 v2 u4 U return total_loss 0 f9 {! l6 d( E( [
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
6 R/ v3 P9 A% C: ~) J+ s
( p9 X1 d4 J" G& N$ e, W # 保存检查点 - `" I" k) c4 I6 F& h
checkpoint_dir = './training_checkpoints'
' R; a! E* e# s% t( z checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") : o4 J: b6 O. I! s9 u' L$ X
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, & E- X- X: S, v* V* k. P! f1 b# O
discriminator_optimizer=discriminator_optimizer,
! n( B, a* I' Y* {$ P generator=generator, " ?# {) E* p: w. J% r" @' K
discriminator=discriminator)
/ M$ D2 Y) V! G
7 z. |% t9 K1 b% u, X # 定义训练过程
7 }& G0 ?; T1 S. k1 N, u EPOCHS = 50
5 n7 B3 Q$ }) E( [" B6 Q" u# J& V noise_dim = 100
9 e6 v% u* H) M* @8 R- G num_examples_to_generate = 16
( f* ~% u6 q! Z# ^
" G# }+ @- C1 j4 h" s/ ]& P0 t/ ]9 s3 ~7 E # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) 3 t2 X- j b4 M* ?* [6 ^7 l
seed = tf.random.normal([num_examples_to_generate, noise_dim])
$ [, |1 ^; w6 C( ~) `+ ?9 [) x! f . G! h5 M4 [+ W, M. u
# 注意 `tf.function` 的使用 7 T! s, K A. \# Q g
# 该注解使函数被“编译”
+ v1 a; K |3 E' c @tf.function
# P9 e! L3 W% K& @) n& j def train_step(images):
- \* W0 k& J, e) @1 L/ g noise = tf.random.normal([BATCH_SIZE, noise_dim])
" H- U$ U: @2 S5 J& Q9 |3 ~ ; `) r! O3 _2 R7 M% l% ~4 m
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
( r. u3 w9 L) t$ I. J' }/ ~ h generated_images = generator(noise, training=True) ) T8 Y2 s* |. i( D' w0 Y. K
. F) ` f# v4 G7 z- k8 E
real_output = discriminator(images, training=True) : ?" u3 v( |$ S0 d1 s
fake_output = discriminator(generated_images, training=True)
' Q5 j7 ]4 W* @% i6 V) d% a# h 1 ^) E' ^; G g2 [+ [
gen_loss = generator_loss(fake_output) \' e2 B" i8 O1 k1 M: x( i( s' [
disc_loss = discriminator_loss(real_output, fake_output)
: l8 p- E _* C. P6 R% K% ^& Q
, ^/ z* ~6 R+ E& |+ h* Q' I gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
1 k7 r- N* ?( {0 }' u" B4 a gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) " m; U. u' l0 u3 E% a$ @ k
% e0 X* M0 {$ ^. f
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
+ G9 @% E% s% h0 o discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) 5 h7 q5 C* V" S( e5 P
( l1 L0 E* [' O" T1 O' i; M
def train(dataset, epochs): + a5 D5 X ]9 m' D. d9 w
for epoch in range(epochs): / Z" L# C3 Y1 Z' ^) R3 o
start = time.time()
6 I5 \, N, e6 U3 X' q$ c" x5 [ 6 Q( W& X7 T7 O6 o$ y, ~" P! A/ a7 w
for image_batch in dataset: * R( j" d/ {* x
train_step(image_batch)
+ Q- V6 p% K# C! U) @
% ?! M' ~8 Q& |* Z* q& L # 继续进行时为 GIF 生成图像
; V1 `! S" a( |* y display.clear_output(wait=True)
" [9 z! r7 ^ f W% w2 m! |9 p, U generate_and_save_images(generator, ( P, l: u) `' q& h/ o- }
epoch + 1,
' w! M: T1 q' f* D( D2 N seed) # C% t- t% [, \9 s/ h* G; w
0 b+ A$ U, f% w; ~: H/ _
# 每 15 个 epoch 保存一次模型 ; m _8 w4 Z. F, B6 Z) ]
if (epoch + 1) % 15 == 0:
* K+ |2 ]+ ~$ d( R3 W! ]- J% l8 ~ checkpoint.save(file_prefix = checkpoint_prefix)
( d1 E" J+ D+ P1 @. d$ c
/ Y7 \* a6 B Q0 N print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) ; {" d' a2 L4 _ R+ m% S+ {
6 ]4 i7 J1 K' F0 ` # 最后一个 epoch 结束后生成图片
& t3 y$ v% ~: I' ]; R, A display.clear_output(wait=True) 5 N+ s" L* r, s' h) x
generate_and_save_images(generator, " n4 [6 e# K/ @/ |7 X
epochs, ; j/ ^, S0 i! m/ D2 ^' _9 M% ]
seed)
2 ~# P+ x( A0 q9 u$ R
: s. X1 \5 g) o; _ # 生成与保存图片 - M: Q; C4 s/ F' D8 Z
def generate_and_save_images(model, epoch, test_input): 8 U3 h! L* @* O! q9 a8 I; ^
# 注意 training` 设定为 False
' G4 \: z2 w, v8 \" Y # 因此,所有层都在推理模式下运行(batchnorm)。 9 b+ q0 P9 I" P# n
predictions = model(test_input, training=False)
5 H U- D: a& Q; b% E * ]7 b) R* J7 ]- q# u
fig = plt.figure(figsize=(4,4))
& u8 S; Z' W! U0 O1 A8 |& } ( @2 W3 C% Q$ `$ Q
for i in range(predictions.shape[0]):
. v) `# Q/ @! B1 v! A plt.subplot(4, 4, i+1) 6 v& o& N* y. b. Y/ {/ `0 J
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
% g9 d+ y& [9 T/ N plt.axis('off') . Z% h; L' f5 o# i* u: \) p
, e9 m/ O% L& M% C
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) {3 I: x* Q& @0 x6 F
plt.show() 1 N5 C5 T& C! Z1 N0 {* M
' B2 w1 H: f3 |( a4 p, _+ F # 训练模型 9 t6 D/ r5 m# p- Y! M3 I
train(train_dataset, EPOCHS) . k) ]: N* p3 p& X, q
" x: L4 |- G6 b& @* d+ G. `* ^
# 恢复最新的检查点
2 x3 y) Q# Z- p+ u) R checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
9 A8 l" ]# p7 ~2 o6 N% a( G3 T3 S
" d7 |$ U, c& ~( | # 评估模型 # X8 P( W5 b7 p( \+ e) D$ ?( B
# 使用 epoch 数生成单张图片 ! K0 |( Q* _5 l! C I3 ~3 ~& ~- D" C3 S
def display_image(epoch_no): $ e5 o" i+ h! f
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) ' _- o1 H5 A _' |9 R
2 x' S4 L6 c! E6 `" j9 s" [ display_image(EPOCHS)
: f5 s. B& x, c$ a( w
* M5 P% A- |$ [7 \- z- ~' n6 G7 a anim_file = 'dcgan.gif' % f; E) y5 V. b: j
& c! N5 m) V2 X: R with imageio.get_writer(anim_file, mode='I') as writer: 9 U& l) N+ u3 G
filenames = glob.glob('image*.png') P ^4 q/ o/ `# X3 G
filenames = sorted(filenames)
' Y& F5 W; w' @4 E8 Z; q last = -1 , a# }' y1 M9 P* P9 j- v$ Q/ W+ i
for i,filename in enumerate(filenames):
" k9 u- g) b0 {( c, b1 ~& a frame = 2*(i**0.5)
8 L0 O7 q5 x g6 p" o$ I: ` if round(frame) > round(last): 6 A, Q3 x: t8 a' E
last = frame
u& q# ~$ X& i U. v else: - K" I; p1 X! V: m5 ?
continue
6 b9 V0 n: z' `- v$ W* ? image = imageio.imread(filename)
! W4 N+ R" d1 F writer.append_data(image) ; n& L L- @- G! S2 k
image = imageio.imread(filename) ( |# ]5 y( Y; D; t
writer.append_data(image) 8 y! l4 s7 O: k& G& H% x
0 G! X T7 ~& f& x4 i6 b/ K import IPython 7 v: C! `# u/ R0 W" b3 n
if IPython.version_info > (6,2,0,''):
6 O% C! ~. |# h display.Image(filename=anim_file)
9 z Y: J- t" L, A 参考:https://www.tensorflow.org/tutorials/generative/dcgan ) ~4 |( t1 [+ @$ r7 t1 T9 F* r) Z9 ]
———————————————— * r2 @7 l6 s# u' @! E, D+ C
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
* L6 h- C6 X1 V' A# v 原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
: E R% f1 c+ R
2 [! c% O7 u1 ]% c# {
[7 f) C" U8 _; ~* y
zan