在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 564697 点 威望 12 点 阅读权限 255 积分 174632 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 3 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
) S) T+ r5 o9 T1 s& V
深度卷积生成对抗网络DCGAN——生成手写数字图片 - s- C! }* I7 x6 j8 y( ?, i
前言
3 L: p4 q* M% ?2 X5 U 本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
' `8 Q% j) ?' n; v& R1 X9 b
- s2 p5 [2 T$ A- N3 v# @
_9 d1 S( {' w/ } J" {2 H/ o 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下: 0 x" \5 V$ F( @( {0 @! z/ z
* Z; @6 X( u+ S9 n: G! r
) {9 J& A6 D, k: S # 用于生成 GIF 图片
! f& v. y* x7 w O- Z2 ]9 G pip install -q imageio
0 c) T# g( Q6 @3 e; t+ `9 B 目录 " d( G* t5 M: ~9 ]: t4 b+ Q
, t4 G- `0 o8 K4 d 4 D- [# m. \2 C$ k) D, i: G% C8 Q
前言 0 D$ O. A1 Z* z/ V) j1 _
u+ E. |+ u4 @) V- j ]3 [( y7 e( B0 F8 O2 }
一、什么是生成对抗网络? : @% c+ h- A- P2 f( c! r$ d8 z4 p1 L
4 N: b' v5 V; `" Z; |2 m 3 A: W; M$ q. T2 e
二、加载数据集 % L9 u/ T& O! e" a9 n; k2 p
/ Z, M% z/ x! h. \0 o$ M( _4 t5 E
0 O6 t3 @! I' n8 R 三、创建模型
$ h) @% f/ S2 n & Q. c. ?. ~. }, Z
5 H9 w9 T. Q. w2 y+ c- z. O 3.1 生成器 6 b4 o: J0 } _/ d) y
8 h8 A `1 Z7 b8 e
. ?+ ]0 I7 B: d I
3.1 判别器
) y: G' ~% o; s8 q6 A- N2 |
4 Y% s6 X& V# M& \9 H" i$ h! q
4 n' b! n s6 m! j6 ? 四、定义损失函数和优化器 6 W- x" T4 v) o# p9 l5 x
9 a$ D# E) }* k8 f9 }0 Q
7 @+ t' Q% Z# N1 }6 P 4.1 生成器的损失和优化器
8 T& F |+ E$ K! F9 @0 f2 t 4 v8 N) p) Q* D. l5 Y
3 @' K. w u: k8 D' n5 ^
4.2 判别器的损失和优化器 5 I) d, N+ Q) E# W+ R0 `8 V
7 _8 @" v+ T! t m% V+ b' t7 t
& B0 a L5 T9 y, g 五、训练模型
8 L, D. j3 C* t: U, c/ Q: l : z9 k+ ^6 N$ l- n; ]9 z4 A
1 W1 e8 D. D- P) |5 x# ~, F) P 5.1 保存检查点 9 x8 t) t5 b1 c& J# g8 V
3 R7 S+ E, x! t
4 J8 O/ F* h4 K' L
5.2 定义训练过程 5 i5 B/ k1 S3 u' P. X7 S2 I
# ~% Y4 o. H! G W, y
( l0 ^, x: P- m+ q( i
5.3 训练模型 : V3 M, r8 h; K4 q1 T
3 g! {- E @0 n% C! T1 s ; f1 p4 V. p9 i0 @9 M6 L* F
六、评估模型 . j# I. L$ X* f. c% o7 u( Z
. o8 }; u5 W4 A( u- A h5 S9 L ( u2 T) c& |% O: ?
一、什么是生成对抗网络?
?% B, a f- p 生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。 # ?" G# U) l* i( n
8 a; \# o0 N1 l; }
R6 r1 x- \& G; v5 o4 a 生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
# @; _" c' a; v0 d9 u 5 r3 ]5 A8 k3 n: l6 X5 v+ p% v: f
3 G! L# ?) w7 p: j9 B) q4 P+ A
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
& G0 G9 o7 F3 u
+ w% Q+ |; Q ]- w1 h 8 l$ R* _$ J0 }& Q' h* S1 E
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。 ) }& A; P9 T; {6 T' W
& E9 p" M5 ^& G3 o6 H; P% K- b
% ~ {4 `2 _+ U% n& }) m
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。 Y4 u0 p5 K" ~# F3 L
" M- e0 k% f' Y6 T7 Z
$ I, P. r4 r$ X' x' ~
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
" ?+ n0 E& {4 E- } 1 R* n; D0 a2 A5 s, ^3 r
) P. h2 [+ i+ E6 j6 h 二、加载数据集 1 O, w1 `, K% a
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
6 z' T1 M& s- X' H/ u+ D/ ?
9 M* i' W% }* M+ m
. o3 q3 |! I/ L. \: v8 z) d) G (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
; V- y- I l2 C' G3 X) T
; |; t6 e p/ ?2 R, F train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') @& X T6 T: O/ s) z2 G: E
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
) F1 y- N8 I( z& E ( c) A2 {( e h0 S& W0 Z
BUFFER_SIZE = 60000 1 v, A, M* {- O" o3 n8 e$ e
BATCH_SIZE = 256 3 ?) J8 U8 h6 ^8 t" `
% ~5 B& O' @% k: v/ V # 批量化和打乱数据 5 C; n7 m1 A+ [
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# S7 w9 {! L v2 c1 R 三、创建模型 , Z2 o9 A1 T2 f* f& J
主要创建两个模型,一个是生成器,另一个是判别器。 5 w. Q( a( y9 j: ~$ P3 W/ _
| z: q6 w+ e2 L9 C
, ]3 k8 O7 L0 H 3.1 生成器 n. u( m. P2 g9 B2 l
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。 9 T4 O/ o% \! E4 Q
8 ?: H* X# w% {2 V+ b
u* K- Q' @4 Z$ N& ?
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。 $ ?8 f% Q- K1 _5 D
]/ W+ S- V6 G! P$ C
Y7 ]5 N' W% k: a2 G& M" t 后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
% T- S9 Z( y; _6 p; {# o2 @/ k # p3 |5 q6 R7 i6 G
; l3 e, K9 r- U
def make_generator_model(): / N- K9 K$ c- a) \
model = tf.keras.Sequential()
, v+ d7 O( l7 p$ m6 J model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
1 }, w" s( t) y/ D model.add(layers.BatchNormalization()) 6 F! k/ y7 s/ ^$ J6 D' \& X
model.add(layers.LeakyReLU())
3 a0 Z4 E8 |5 ?/ s0 @! v' o( f$ u
/ N6 P' w; U( c5 V7 ` k model.add(layers.Reshape((7, 7, 256))) ' |% u8 P6 G. g- z7 E
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制 9 H7 \) P* E" ^+ z" X) @. ^; |
1 P3 g2 p1 e8 _! O# R& G c
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
: K0 z1 {5 t/ {! A assert model.output_shape == (None, 7, 7, 128)
; m" s! u8 b4 ~ |- K8 \8 ] model.add(layers.BatchNormalization()) / l6 s* `/ a, S$ h8 O
model.add(layers.LeakyReLU()) . E# q# s0 g9 J# @* x
; R; y+ b# K1 Y9 h( S% v2 [/ r
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
$ C1 ]( Z# t+ f assert model.output_shape == (None, 14, 14, 64)
N. G, |6 y3 G Q' `* F model.add(layers.BatchNormalization())
, L5 z$ W; R8 o0 G& \* x1 H- V model.add(layers.LeakyReLU()) 5 `- P3 V2 q6 _- V% [1 v: w' M
# I! d b3 u* }. a1 e
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) 8 M" `, y$ p( J3 G ^
assert model.output_shape == (None, 28, 28, 1)
) r3 _ V- H+ y% {5 u1 X
% V. k8 r- u* G& l9 o) x2 g return model
( w8 R. j! Q" _/ k m 用tf.keras.utils.plot_model( ),看一下模型结构 1 T. }2 c+ G8 p2 E; [
1 z: H+ A- |9 B0 G; j$ k / x" S& _0 z/ f* ?% L
}" a# y" F* A
2 F s$ j* |: b9 b% u' y% T4 M 7 ?$ ]) k& i* x6 a
用summary(),看一下模型结构和参数
! L4 I1 f9 U: N* { g z+ c# r( e* s8 t. M& l
8 V* I; R, x) A
* m7 c H y2 |; O8 _! \' B
& v( A3 g( \; W; }+ T 8 m2 L" e& w1 ~. z1 x5 h5 I$ m% R
* z% ]- x& z# z
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
0 L t! j4 k0 e! L - u- W9 g1 I. ]: b) Z3 s
3 f6 u' @( D4 v4 O9 o3 d6 `- }
generator = make_generator_model() " g5 m* l' |' \: k s
+ B ^3 M9 K) ?+ ^; q noise = tf.random.normal([1, 100])
# p% F1 X, C8 X* o5 O generated_image = generator(noise, training=False) 0 f- j5 M0 T, W- M
% @; x+ y; J4 A$ j plt.imshow(generated_image[0, :, :, 0], cmap='gray') / S$ @8 W& I5 w' O5 Z6 U! U
; I; c7 C, O+ ]; W! r9 T# _9 p
/ `! b# u& I; F- t5 P% ^) a
- R2 f" K+ Q; y/ E5 J6 k1 r ]) l/ g$ ]$ K2 W
3.1 判别器
7 u9 i, ?& Q; ^9 V. B( J/ M# H 判别器是基于 CNN卷积神经网络 的图片分类器。
+ l# S# b) B9 x( W: v& B3 x- N
- m* a2 C2 X' G
$ q: w9 K: y% \, m' Z+ t$ T" l3 a7 u def make_discriminator_model():
6 Y# s3 G* F* c% G' k model = tf.keras.Sequential() # q% J+ _/ P9 k7 Y! ^: [
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
* _! o# C* z" R/ [: F S, C input_shape=[28, 28, 1])) " P! p2 m1 _# h1 S& @ l& J
model.add(layers.LeakyReLU())
3 z u5 B. a# F( ]5 d model.add(layers.Dropout(0.3)) 7 X; @6 {/ P2 }% A* {" S
/ o, ^+ q& C) }% Y7 q8 M& _) V
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
) O% g Z- c4 w' l model.add(layers.LeakyReLU())
3 [. {- U" d2 K) e# Y$ _* g& E& o model.add(layers.Dropout(0.3))
/ ^4 o1 Q: a4 f & }, [- {" N e' D; k
model.add(layers.Flatten()) ( _% _' q( q1 m! L7 N; `% [
model.add(layers.Dense(1))
) I- ?3 X4 O a# | r2 a
0 M- k/ M: W& w% c' p6 }! w& D return model
5 G# c( |4 C* ]! P r' v 用tf.keras.utils.plot_model( ),看一下模型结构 ( U- P1 F5 U3 r1 P
/ V6 G* F P9 ~8 R+ h * g. g& M/ s1 z- T9 {+ i
B" @6 D) `0 I+ `6 r8 g( A# l1 V 3 I1 J2 m) _8 C
8 N% V0 i4 g( g+ F1 L M
' Q3 J) o& i1 {& V1 V* K
用summary(),看一下模型结构和参数 - p: [ d# c1 M5 Q
# s& i+ V- j. N * t+ {) f/ |0 ^9 M1 h4 l3 b
* m- S1 L+ E' p* g8 k! V
8 m) r* \- F$ S9 o- o( w6 R
. A4 E5 h/ w/ h+ a& W+ x: ~8 j / _# W& @. b, \ C% l3 l$ z+ N
四、定义损失函数和优化器
) @3 y7 y, P* I: p5 f9 a/ J 由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。
( W2 l: I* O! h' u) U8 b : m1 T3 L; G; S5 a+ J0 V
4 P$ I6 M, J! u+ |2 G 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 3 g& c5 v+ t3 c- B( T; V% L
- G: d. ?4 ]3 Q$ b9 {' s 0 y! D" P- X/ T2 D M4 K
# 该方法返回计算交叉熵损失的辅助函数 9 `$ d X2 U4 I0 g$ c5 a0 `
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
% X2 m d; V( l% L; D$ S7 j 4.1 生成器的损失和优化器
0 |& R! i1 L% z; k& W 1)生成器损失 ; ~& I7 ^( J, e- c
5 I; C! h8 s+ y! ?' l+ z
& x& c7 y) ?! P! |# ]( v
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。 # s. q' ?+ ]; I" R2 n6 g
/ }+ e9 y2 X7 Q. B
/ m; m- N+ d' t {, w
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。 * q6 ^. T( U! R' x$ V# v6 J* o/ N# g% ~1 L
! q% Y. ]: |/ J4 R 4 ]) `5 l$ P9 G J) N
def generator_loss(fake_output): + |8 v. N: a) ]9 D
return cross_entropy(tf.ones_like(fake_output), fake_output) : K- |' ^9 s+ W# ?/ [
2)生成器优化器
( N( H0 {) e% `; G B$ i l$ a
& H) @% ?+ M9 ?3 c4 M
: ~/ Y$ Z2 { W0 C. j; C# c. I% p3 |9 x* ^ generator_optimizer = tf.keras.optimizers.Adam(1e-4) 3 f& J# G" d% m7 x; H8 o2 i
4.2 判别器的损失和优化器 1 S7 c) T5 p! V; n) b: Z' V
1)判别器损失 $ z5 n1 k( {, f5 M; y1 ?( Y! c
+ o$ Z" `9 e7 @, a
- s7 }; q6 R" {: R+ Y, x% [ 判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
! @, g8 Q, P6 G9 \9 V; g) ~
* o' V+ n3 I/ A' E% o
3 Q% j( R6 l! x, L6 v9 s+ ] def discriminator_loss(real_output, fake_output): $ J4 F I4 C5 s
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
2 L9 U! D) |* f5 z fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
; V' K* \0 |0 m- B0 @$ ^ total_loss = real_loss + fake_loss 0 q$ U% s1 F F2 N* C4 s
return total_loss
" f/ P, T& l- b4 }" f8 Y 2)判别器优化器 + d0 A4 S/ q: }
' L- V- c/ T: [4 w% f
, k. @; k% ]! k& b( s3 [# ] discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
- m* f/ h+ u0 P# w2 A# O 五、训练模型
. n# Z, e8 q& v9 I- J* N0 a! ? 5.1 保存检查点
$ ~) \' x; o- D) J7 W 保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
9 v8 |8 t6 c- @
5 q- c5 q; T. [' F4 Z; G 1 n) l0 k* ` P# l) M- N
checkpoint_dir = './training_checkpoints' 3 j# }8 o; D, }2 s
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") ' {5 u# P1 A7 x
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
: G. s8 S; Q% C discriminator_optimizer=discriminator_optimizer, 3 B; t$ z, C5 c! j) E* J1 b1 H' e
generator=generator, $ i3 u+ ?. E, q' |! h7 o$ n; r, `- Z
discriminator=discriminator) 6 c! t$ P$ j5 w" G: s. u% s5 V
5.2 定义训练过程 " ~- c5 f1 z! `# O3 n
EPOCHS = 50 0 h4 l3 r2 p1 n5 {+ S6 m
noise_dim = 100 / I; Z8 n$ I. l6 q& \2 T# v
num_examples_to_generate = 16 ( g- K/ V/ {$ w7 {
3 Y# A' f$ C# R0 Q! ~
! f K) J8 F+ |" [- J7 u
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
i. ~% G+ Y0 p6 ] seed = tf.random.normal([num_examples_to_generate, noise_dim])
( o3 K/ @- a: [; m8 N$ J( s" \ 训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。 & \$ H) K: `/ D( I, W
# j0 j, U; T; A K
/ E$ f' S2 j& @/ n& W) p 判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
: w1 k" A4 f' d; j# x 7 Z8 u( U/ r. G3 D1 m6 l
2 s" y% h+ M' O; D& O% a
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
% J5 O% [* q9 ? / `: q4 ~+ ?/ ^2 L9 i; D8 V" u
% w- H- I) `8 ^0 {% i # 注意 `tf.function` 的使用 & J9 l: ~, `& n# E. @. n
# 该注解使函数被“编译”
. O8 t: `& ?& b8 { @tf.function
: A6 v& u' K* U9 _0 k+ Z1 E def train_step(images):
! S# [# V& n7 d noise = tf.random.normal([BATCH_SIZE, noise_dim])
3 s1 }; I/ t: v+ |4 \6 n - A) _. A, L g8 K& [
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
" o: S* z# U+ I" [( w generated_images = generator(noise, training=True) 8 y6 V1 n( n5 _8 m
8 r! ?% u- R6 \ c* S" h: f. |4 y
real_output = discriminator(images, training=True) ( M6 h5 k# e9 {0 N# j7 ~' u; N- a1 Y
fake_output = discriminator(generated_images, training=True)
, A! Z! w! T9 i8 K+ j+ z, j% D# Q& B. a
, ?5 \% {: X% N* _- Y gen_loss = generator_loss(fake_output)
3 p/ }) ~! ^! j0 D5 p$ j5 s disc_loss = discriminator_loss(real_output, fake_output)
% F: x3 Z2 `" n1 T1 i
# ]: Z/ {3 ^' b+ B% u gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) . `6 e, f2 c( b Q5 q5 ~6 j1 x) [2 G
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
1 j+ F1 v- w* S& o' [$ a
; F: e! l0 `- E$ ? generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
5 @& t0 D+ i9 a discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) 4 I2 ^ D& c0 n7 G4 z1 H$ C; U6 T
' X# U) g# r9 e! g3 W+ L( E
def train(dataset, epochs):
- q }- Q/ d; U9 a for epoch in range(epochs):
3 n m& C" H0 O6 ~' Z5 I start = time.time() 1 r }' `: B; _& t4 U
( U) c% g M' e5 N# [0 L
for image_batch in dataset:
% S/ F- v2 S! m- I: Y) m train_step(image_batch) 9 }% s" Z8 p! |, T5 ?, X( z
; g7 K4 _7 G; k c( L4 @ # 继续进行时为 GIF 生成图像
& x- ~& |4 i, ^6 J( S display.clear_output(wait=True)
4 T9 @8 J; U1 \7 R( ]( J" o0 A/ s generate_and_save_images(generator, ) |. z6 @' M1 _
epoch + 1,
# e" m& h2 L9 _! m6 y" w9 W seed) * _3 W% u& u9 h
- P5 N. e: t, O0 X: ^ # 每 15 个 epoch 保存一次模型
8 `' \2 P+ T# n1 v if (epoch + 1) % 15 == 0: ' q% D! N# f. x3 I9 k; u, y
checkpoint.save(file_prefix = checkpoint_prefix) 6 Q0 l* G7 \. S6 h$ f, r
: u2 C, a! C7 D( x2 Z6 k. W: b
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
1 D+ z8 Y% m7 V9 s, t
# i+ F+ @+ Y* p8 x" y# a # 最后一个 epoch 结束后生成图片
( P! D/ y% @) \" [9 P3 S% { display.clear_output(wait=True)
) \* b- C, C; f! ]. e d' [) ` generate_and_save_images(generator, 8 g; b4 m [8 U
epochs,
0 H/ D7 i% |. D1 i# b seed) " d3 R7 m4 T( ~- E
) G \* h% p, y+ C) u/ ~: N. s
# 生成与保存图片 ! w4 z0 W1 i- ]3 ]# @# ~0 x
def generate_and_save_images(model, epoch, test_input):
! e6 M9 a; j, N # 注意 training` 设定为 False : T5 }! P) D8 g/ u- D7 q
# 因此,所有层都在推理模式下运行(batchnorm)。
* u& g6 q: p; `! L) ~ predictions = model(test_input, training=False)
5 Z: _7 [+ w2 y9 P8 c' B+ s; J- Y- y
q+ h% s* S4 Y$ t0 d/ L' t fig = plt.figure(figsize=(4,4))
$ g; J- G% r- {; d, U" d % ?/ i# G" k# e- W7 t: n4 H
for i in range(predictions.shape[0]):
* @' ^+ G3 l2 e8 f" z plt.subplot(4, 4, i+1) : I" }6 F* J' u& E
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
8 H0 \+ y+ E) } f0 e plt.axis('off') $ J c# Z% i9 |( @
+ k7 @! u; s+ D6 f2 J% u plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) 0 N% y! A7 e$ S
plt.show() - [% X7 x4 T5 h0 ?6 P W; e
5.3 训练模型
k5 |1 B" X g- |: f9 {3 F 调用上面定义的train()函数,来同时训练生成器和判别器。
9 x- C, ]0 Y+ F9 k % G, ^( _/ D- m" K: M
$ a( R- L# r1 U2 b2 i( o' Z g4 o 注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。 # C1 R6 D) I# I/ X# W& K Q1 j
1 \7 S* g0 F& s% V2 E " ^" L* U0 `; [* n0 y8 k0 i T
%%time 9 @; W2 a& b; f. n3 I
train(train_dataset, EPOCHS)
' x5 H9 B( c$ | 在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
; ^7 [$ t! U; K( V; Z; ? " w4 c! R) d' B3 c' ~
' Y/ m' h# }4 I, ^6 q$ i: N 训练了15轮的效果:
2 Q- u- C# s' v
" F; a& Y/ v1 ? Y+ O' @) Q3 o& M1 [ * ]# M# B4 X5 w2 `( [2 j! l$ @+ L0 w
2 f5 e$ I1 \. ^- L9 j' I$ r7 v% J # z O* a) H1 }6 s7 I! z
, e t5 |' R& p0 |" f& n
+ s5 p( }' T+ s, R, E6 _( i- Q
训练了30轮的效果: + J# H3 ~6 F; z' I. ?' B) B
5 u# ]* b6 } P/ a+ M
8 ^/ f( |3 ~8 U& ]% a : | A/ Q7 U4 v/ D1 N) u
$ Y0 u) q* m1 _5 U! Q& Z3 _* I
0 A9 F: ]0 H' d
5 o& d& j6 p) b. P
训练过程:
1 b) ]6 M# P ~ j' c" {- m2 o p0 w i7 n! |3 D7 J! I$ O
# C1 t( v* B7 y1 ]; y- w% H
1 u. r4 G" I# E( K& L; T0 S( G% B P . x6 q: J) N. E
2 ^7 a5 b' l' c6 _( g $ r- V, ^$ `- N# }6 I
恢复最新的检查点
$ H2 K/ }! }) Q) x6 J 6 W S w& j0 n" {! E
% M/ x5 |* h5 [+ p2 x% H5 h checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
$ C' ~: x# {2 m6 l. W6 E 六、评估模型
/ _9 {9 z9 p, s" X* u7 w/ ? 这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。 ) z$ h" j H* n7 u! G- c, w9 |
1 U9 g, ~. D7 p/ T* H+ H3 H 7 a% z0 W, z0 q( U& I5 r0 ]
# 使用 epoch 数生成单张图片
8 s$ k+ @. M+ P( T2 M' I/ ` def display_image(epoch_no): / B* K/ m+ z2 f I* m5 Q0 P4 F- |1 T
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
5 A- O" i5 @; j. ~" X % G, T5 f8 l9 _% M# r7 v- @: O
display_image(EPOCHS)
, a3 K/ c( b( [! t) G# t3 j anim_file = 'dcgan.gif'
1 [- u' e9 b% M+ D7 e : Q* V4 t4 |) y. T/ m/ N( m6 X# k
with imageio.get_writer(anim_file, mode='I') as writer:
9 M8 \9 q( t" p; C filenames = glob.glob('image*.png')
& ^2 Y4 d- i3 E filenames = sorted(filenames)
7 k1 S% o! Z% R! f+ F8 P last = -1
' M/ g% w& _! |5 [* i$ c for i,filename in enumerate(filenames):
+ \1 L' s4 a! y" f# X& ~* U frame = 2*(i**0.5) 3 U% s4 u$ `3 x' [
if round(frame) > round(last):
! [8 L; U* [/ _! \# p! \0 t last = frame 8 `: I3 f+ B: l9 R6 Q
else:
K& L4 ]$ k7 E5 e: d continue
8 C4 `* k/ ]' R: N/ y7 l image = imageio.imread(filename)
2 p. ]. K3 J; J& f/ M+ V! ]% u2 k writer.append_data(image) ( U) `" h2 N. ]4 g$ M* U V9 c* U
image = imageio.imread(filename)
: g$ p: U% G6 n/ H$ n) o$ y4 n! M( Q writer.append_data(image)
$ z* ^* V" L6 m6 t {7 F
* H7 Z' T8 J: G" N' K2 N2 R import IPython + U C) v0 R8 M9 b# `( h
if IPython.version_info > (6,2,0,''):
6 H& E8 O# q3 b) I) W/ @ display.Image(filename=anim_file) 7 y5 i4 D% M% b8 q
+ J* f0 V: X$ w/ z# a& n9 ?
: D3 y, \" U1 J2 \* F% ]! Y
% J- f# v& J$ ^0 }$ f7 \5 E
% b! w$ e$ B& r7 A8 j( D 完整代码: 8 ?# G4 p. j4 e2 L! W& d7 E& ~
# R# v8 j2 r! T( T. F
& Q; t% P- L- Q import tensorflow as tf 4 I5 S2 ?$ X% j$ F2 A9 c
import glob
0 K' n9 c: u; u7 P( S import imageio % Z$ I% [% U8 D6 h2 ^. @
import matplotlib.pyplot as plt $ Z p! f6 S% [4 Z# m, B7 `
import numpy as np
! z: Z4 @* l( ^% _+ j import os
6 c+ {& w2 L5 x7 F2 q: C import PIL
. x6 O1 h7 l) i5 S9 T from tensorflow.keras import layers - q% u) v8 _0 B6 e% q
import time * `3 ?; ?" b! ?0 Y: V
! i8 r. O! J Q0 N from IPython import display
* l) e2 b6 P. ?$ [: P. G6 l " O7 T$ h+ r3 d0 l
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() 7 J) I" L. k& o
+ y- I! N1 ~4 I3 h7 t1 T
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') 1 v' ~: D2 @; k9 v6 n
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
$ s8 }8 B. ~" N% z$ i) R2 I , o0 ?8 P; T1 T& \
BUFFER_SIZE = 60000
; n! N3 n- q9 A. ^; m- P1 ? BATCH_SIZE = 256 7 m( \- t1 M9 [% z# m1 r8 h, Q
& N& E0 O$ _. E* `, l # 批量化和打乱数据 6 J" V2 [3 V9 n/ {/ u, E9 N
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
: d. J8 H0 R; o$ L+ H 8 J9 c1 n& w# Z6 f
# 创建模型--生成器 % G! o% W8 J6 g
def make_generator_model():
, e/ J& t# W$ {8 q% G9 g model = tf.keras.Sequential()
% P9 R$ K( F( ?1 w- i9 ]5 o$ u model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
: u% m9 h4 i6 V: n$ W model.add(layers.BatchNormalization()) 8 {: r: F. V- ?: T* V/ E6 O
model.add(layers.LeakyReLU())
% i) |" K% v5 L! v: q5 p 6 y; x& ~# H4 Z
model.add(layers.Reshape((7, 7, 256)))
+ |- z( b9 v# y! i: s3 l assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制 8 P: N$ \/ G) W' W7 I+ m
; t: o" ~4 M/ B7 @
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
" R3 U, {- m G/ l+ N) { assert model.output_shape == (None, 7, 7, 128)
+ `. N# z! [% Q8 |( @$ h model.add(layers.BatchNormalization())
- }8 |! t' X, m; c/ K model.add(layers.LeakyReLU())
- p, x5 m. [/ T2 ~ D& B2 Q) P/ F
9 V( d1 M9 I1 Y" K model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) 1 x/ `2 a" S! R) r* \
assert model.output_shape == (None, 14, 14, 64)
# E; Z& N5 z) a/ C model.add(layers.BatchNormalization())
4 q& \$ t1 T7 \% E* C model.add(layers.LeakyReLU())
# p/ K3 x; l4 l# {$ `3 W% a8 S ' H/ i; h* l2 Q( m) B
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
- D0 e8 {3 e( O2 \' `3 U& ] assert model.output_shape == (None, 28, 28, 1)
, X' W" o$ w) e( n* E
# p* v8 O0 `! |# O% m! z) h1 S return model
1 L/ h' M6 H; r( X6 h& Y7 x; n " |. I: ]' d% s% j
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
. n" m; |( {* [; M X# X- o generator = make_generator_model()
: Q: `3 f" G+ M$ w/ L 4 @% l4 t% {: j5 M/ f! P) }& b. F. [# y
noise = tf.random.normal([1, 100])
8 t$ R2 s3 O' }* p# c, e. E' ? generated_image = generator(noise, training=False) 8 G0 @: H( N% Z- S% s% ?. D3 H, ]
! G! f. T1 ~2 ^& ^4 ^! C plt.imshow(generated_image[0, :, :, 0], cmap='gray') 2 ^0 O8 ~: C6 D/ x6 z
tf.keras.utils.plot_model(generator)
- h+ s7 ?; ^/ N7 ] ! ?! s8 Z/ Z* R% E. H: K- l; \
# 判别器
" Y f! ~' |% W G def make_discriminator_model(): 6 D: N6 Q/ _7 d, a% z4 Y& a
model = tf.keras.Sequential() ( e+ ~- f2 R9 a# ?# A. }. ^
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', 8 \0 i* b6 ?8 r2 t- q
input_shape=[28, 28, 1]))
8 Z. f, j- ?5 _3 M) A model.add(layers.LeakyReLU())
$ {2 c8 o2 s5 S model.add(layers.Dropout(0.3)) . p: T8 s2 e8 ?8 M
g O5 P7 f% M2 Q7 H
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) 0 u! d- E! W! e4 I
model.add(layers.LeakyReLU()) . n5 S. J9 O$ {
model.add(layers.Dropout(0.3)) $ |5 X6 V* |* N4 C
: |3 Z: e$ d2 d model.add(layers.Flatten()) ; L, Z& ?* ^7 m# y
model.add(layers.Dense(1)) $ P3 U/ g6 c }# z
$ W: N( R2 ?* W# \- `8 m/ B return model 3 ?- y6 r, X0 p% K# I4 a
M; s* ~% X+ i. R # 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。 4 [2 r' {+ O: I3 U- u, r" X
discriminator = make_discriminator_model()
, R5 x/ s- M! {% g& V decision = discriminator(generated_image) 6 |* K1 Q- u$ e3 W+ C& p' [
print (decision)
* Z; U' q9 l! d! ~! }
* n% `" q. b3 g# C3 n% ?& j # 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
/ L6 P r/ \7 z+ N cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) / e/ h; s5 A& q5 `
& }# s8 C% D8 @/ m5 X6 z6 e
# 生成器的损失和优化器 , L1 r% c# F8 O \
def generator_loss(fake_output):
: i# u) f, x# d$ M% Y _% Z return cross_entropy(tf.ones_like(fake_output), fake_output)
" y" u" T& S; I generator_optimizer = tf.keras.optimizers.Adam(1e-4)
" V) s# B+ v% L. F# ?
2 x& M; ?, `+ D" U" ` # 判别器的损失和优化器 ) Z; `9 H( q8 Y: ?" S
def discriminator_loss(real_output, fake_output): / V! h/ L! c- b g O0 j: N) t: o
real_loss = cross_entropy(tf.ones_like(real_output), real_output) 2 h3 @( Z1 y7 o! N9 u: |7 c
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) 2 r `8 t( N9 t# f3 W
total_loss = real_loss + fake_loss
3 A1 v+ v" \, J: J8 S6 m+ f return total_loss
! o4 x9 O$ X2 `: x, B5 F discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) 3 O6 K7 r% s- x7 D1 p
. ?+ G' Y0 T: _6 ~' ]: t1 x9 P
# 保存检查点 " [7 f' S% r0 K& K7 B* B4 g( r, F
checkpoint_dir = './training_checkpoints'
8 b; G! N' S& ~# d2 } checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
% \9 z- d3 t/ g. _" {7 X checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
* Y" E, n* K9 K discriminator_optimizer=discriminator_optimizer, ! `% R4 o& n7 @# b1 A
generator=generator,
9 w) v6 W0 Z/ y6 m( o( E8 n discriminator=discriminator) $ ^0 P- v& [% X4 t
" l3 @, T( o# u # 定义训练过程
. d$ `5 R' Z O) ^* p X EPOCHS = 50
$ V) r( T, T% s+ v7 ] noise_dim = 100 # ^1 F! X% Y# R' ^
num_examples_to_generate = 16
* c Z+ e. K$ W+ i) B. s5 E! p 6 E$ f' g% |' h5 Q9 j! h& W
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) ; [2 f" d, v* w0 d" y1 j
seed = tf.random.normal([num_examples_to_generate, noise_dim]) $ _( D5 F1 ] Z& L
; m( P9 s% Q4 E; r5 f
# 注意 `tf.function` 的使用 6 K* [! {$ }8 q% q
# 该注解使函数被“编译” ; D+ B. R/ G: X* p$ W9 g
@tf.function 5 p0 {9 M: V( o+ Q: O
def train_step(images):
3 G8 l4 s" A- u noise = tf.random.normal([BATCH_SIZE, noise_dim])
0 ]$ }- P" y* F4 ~% W $ Q% T- Z; q1 e/ Y" p3 X3 r
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
! o' Q5 A: d: h+ y) S. V. q generated_images = generator(noise, training=True) , m2 X# s0 G4 t+ j
6 Z8 w. ] J0 y& {) o
real_output = discriminator(images, training=True) % i6 M" N7 Y# |
fake_output = discriminator(generated_images, training=True)
) M4 M" \# s; r4 K8 R1 a) N/ n
" m7 e+ ?. w& [# t gen_loss = generator_loss(fake_output)
- g9 ]( ?6 X9 v disc_loss = discriminator_loss(real_output, fake_output)
4 n; c0 X& c( ]5 E 6 u- a, E G& D# J7 G8 n( x1 R1 N
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 6 x# j" J" y' d) S6 u- F
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
& P' m2 ?+ L; G # P0 T* z7 s( U
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) 4 X }. d' J- {4 q8 `
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
: c! V7 `5 m; Y7 K3 D, v I
9 L8 p, O3 R4 `3 C) t% r def train(dataset, epochs): ( U# g: N' g5 @8 C {
for epoch in range(epochs): + C: G9 b& N1 z5 [- Y* _
start = time.time() ! {/ W8 ]2 |0 {+ {7 W7 X
' R) [- S3 q* t
for image_batch in dataset:
3 K! v; A) ]3 V2 u4 u+ G6 o train_step(image_batch)
! }4 o! W5 t _ " m* w' I4 H% U6 O# v8 x
# 继续进行时为 GIF 生成图像
" }" w6 w( E& t5 W. k display.clear_output(wait=True)
% t- o" c. `& ~4 l. g; F" J! i generate_and_save_images(generator, 8 D& [( H% i) q! a: {- w. j. X7 C
epoch + 1, 0 K5 F9 {( j8 }* u9 r; q5 |
seed)
! J& Z4 @- Y; q' C3 ?" ^
$ _# L. s' K+ P7 T! u # 每 15 个 epoch 保存一次模型
7 _4 }' ~. D% r( G/ j9 v: J if (epoch + 1) % 15 == 0: * ^/ J1 O) k9 A
checkpoint.save(file_prefix = checkpoint_prefix) % J0 b1 [) ]4 Q+ w2 \
# q/ S- c$ m9 j. h
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
: J8 @+ L6 s. p/ J
7 D/ c3 B2 `1 V- O # 最后一个 epoch 结束后生成图片 j4 a: ]' P) m% U7 X& c O
display.clear_output(wait=True)
8 J4 H% w% @0 ]$ i% a5 r generate_and_save_images(generator, 9 `9 Z- l) P2 T) o( a
epochs, ' t9 L3 N- I. R! N
seed) U2 t2 u4 T; O |. L
; D; Z4 _, l7 n0 I- @ # 生成与保存图片
* T G8 M5 ]' c# T; U! ~2 Q* e& v def generate_and_save_images(model, epoch, test_input):
6 P* v- E N) g # 注意 training` 设定为 False , z6 |; ? K0 [# r# d% L$ o* W; j" g& L
# 因此,所有层都在推理模式下运行(batchnorm)。 6 b5 {& {/ S* e3 M5 K
predictions = model(test_input, training=False)
+ }/ N9 L; d2 m! a# J1 [) g
+ v5 i9 L. c/ I& d; u0 a fig = plt.figure(figsize=(4,4))
! y c" |' r7 v% |" o! s
7 s- d9 i0 ]6 L' U" z for i in range(predictions.shape[0]): 6 c2 Q5 t, d2 n1 @9 M! ^( S+ d, f
plt.subplot(4, 4, i+1)
* c( ^; i3 J1 M3 c* G) i9 B4 ~' z plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') ; b/ _9 R5 r& `) N
plt.axis('off')
# q% c+ ^; z" c
- U/ C5 a. F# i. S% _( ~ plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
2 s, k. x: V* b3 |+ W/ M! B3 W plt.show() & [; K; ~- ?/ c% {5 k
) g% _3 U( k$ k; B1 m
# 训练模型 9 h4 C) t( y0 z7 q) G
train(train_dataset, EPOCHS) ; i% X- }) P* R' H6 b3 ?
" _& Q% B; _, R
# 恢复最新的检查点
# r3 T. q$ u1 W checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
, p' k5 r) ^+ N4 }
9 |( X9 d% \* N$ c& i # 评估模型
" w& p! V* i, O* _8 E5 e4 E, o # 使用 epoch 数生成单张图片
. m% E I9 G: D+ F& @ def display_image(epoch_no): , z! u! Y/ |! S; l+ w. b$ H' f/ d
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) # u; j- Z4 R$ R, u) ^& R
, V {8 [: Z+ }( \8 k
display_image(EPOCHS) 8 z6 e4 Q& y2 H) E8 u
5 a) @# e: D' R( }$ h anim_file = 'dcgan.gif' $ i/ p$ a) Z! s
4 E, n0 w0 F6 E! X
with imageio.get_writer(anim_file, mode='I') as writer: 3 q( v" z5 J V. G ^1 @! C
filenames = glob.glob('image*.png')
* d; r- l2 C" ^ filenames = sorted(filenames) & ?' O8 {6 v; ` I% z* ?
last = -1 - O1 `3 M/ ]& U( L( f( [
for i,filename in enumerate(filenames): ' p" O+ W' I5 K
frame = 2*(i**0.5)
* a/ \ J T& n7 Z# t8 S: F if round(frame) > round(last): " b/ A$ v! ~! ]" o' }! W
last = frame
( O; {; _; N' s- d% i else:
; \4 ]( v+ I7 O3 b continue ; r8 i: a- b- s7 j( |0 s
image = imageio.imread(filename)
# ?4 A5 `$ \, @4 U& M! @- F6 l writer.append_data(image)
( v* j- k# B, h6 K+ V image = imageio.imread(filename) ) c; p8 C$ S# o D
writer.append_data(image)
6 x& \/ a, m8 z- n, D1 u6 b 2 H0 N- c2 w2 Z3 v3 L9 g( O
import IPython , W" t) ]( I& ^3 b' x1 T
if IPython.version_info > (6,2,0,''): 6 P! m E1 Q: q# T& L3 H2 h
display.Image(filename=anim_file) ) Z# K/ @6 F( R- l( I5 h3 s6 ]0 q
参考:https://www.tensorflow.org/tutorials/generative/dcgan ; u/ R7 @0 m" ^
————————————————
. ]" Q; v) l1 s7 g 版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 / G' f2 [; w6 ]1 h
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111 : i$ o: M8 I1 S
3 {9 t8 t" a& D9 A0 Y( p8 I( Q
' ]$ Z4 x9 x% \# m
zan