在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 556036 点 威望 12 点 阅读权限 255 积分 172184 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 18 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
) @5 z# ?. @/ |: S* ^
深度卷积生成对抗网络DCGAN——生成手写数字图片
! K1 V8 y4 E! `+ J6 V 前言
- z7 T/ [8 F) n/ F 本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
( N0 m' F H' }, Q2 p; _- X R ) y; e2 t5 n& z7 c* Z
6 q" B% {( O( j5 o% d, i 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
$ c( T, Z8 J* l8 k& d% n- f! W 3 G. R+ [. K& A! F: v
; H1 w4 B' P, U
# 用于生成 GIF 图片 % m( R( e5 |+ T' I/ H0 ]
pip install -q imageio 1 m2 k- M" J8 u9 \
目录 7 W$ \, B4 U6 o7 e. L0 A9 Z
& x D' k/ I9 e" U$ s5 m : Y" {- h3 w1 J
前言 % Y3 b: x+ W( ^; |1 k; ^
" i4 |) W" }" d( K7 w
+ C5 y* x$ I7 n; C 一、什么是生成对抗网络?
& p# m1 [8 w& w7 e$ V4 J. q7 G % V) x+ j) }3 C, q' q* |
( c: i; h% |: `+ x, J 二、加载数据集 0 F( W" P' Z+ a6 h* J( l
2 Y! l) f& H( c9 y, b4 z9 e' {( V$ N
1 X# q* a% s' n) T' o) G
三、创建模型
+ E& L7 A- O4 O& M3 q 1 h( c$ @+ U+ v& L) b
" h9 U) F! e8 r6 D/ B K |- D$ m2 u 3.1 生成器
3 S8 w4 R( Y- l8 W2 r/ H
6 y! ~# d: |: F0 ^9 `3 j 7 ?" K' O/ J9 Z* ~* d0 s0 Y
3.1 判别器 ! u6 f2 L8 z# k, P1 l: w+ D h9 u: F
% X. |1 Z0 q: ]
( s1 ^6 d0 b) ~( C0 a8 f& S$ Q 四、定义损失函数和优化器 & A' Y6 s& S" p6 M/ E
: f1 O. X/ _+ E% P+ Q$ A( l
! |- q. h2 N$ a+ }& j6 ]* ]8 ?
4.1 生成器的损失和优化器
' k1 f/ _; V( R( @ 1 Z9 H+ T$ G8 \( _1 B, G6 d S/ f
# J5 Y, Z5 |) d8 b+ v 4.2 判别器的损失和优化器
& J5 ]: ^% A, K- {7 Q6 n! J
* |7 Z% ]% S& @4 D0 c
+ D/ s5 K+ @0 [; y/ }" Q$ K9 S5 ~ 五、训练模型 @+ a/ J0 U0 @: ]3 P- Q6 ]* X" O
3 |( e9 S8 _: V: j! _# E
1 O4 @ E0 {) p5 H% O
5.1 保存检查点
* Y! @) V6 f$ L3 \ ) N/ _# {( N! M6 K% f+ v. P. L
. N+ p9 \# \& t; W: n% D 5.2 定义训练过程 4 `& @% h* z& g1 Z
* H6 R# V, L" \1 V8 Y) I
' e3 } g- B, r* i1 T 5.3 训练模型
; x. s1 R) B2 k* J4 I: U6 M- _ 6 j. T, p. J) t: s
$ G# L* V5 b9 e/ S0 }! V/ D
六、评估模型 7 A2 G, r% Z' ]( s2 B" @* g$ H
; `; V# `) O3 v6 P
0 K! _0 ^% f" n2 T: f! H1 n; r
一、什么是生成对抗网络?
1 x* v r3 s% P+ J( Y; G' I 生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。 * F/ _ ?4 C" o' F7 M/ `5 k
+ w" I' J% S# t$ G. a
) E' Z" j# ^# n/ |' d P7 r 生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。 / U& G- W' W! V. y& r6 k) v
8 b Z: G2 B: h
0 a" }& s% ]) C3 k, p: K8 P 判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
! D% F1 o* A' K: I* u) ~
8 _. |/ H- @# J9 f8 W $ E" Y( g- A7 L M5 O( F/ l" L
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
: b' Q5 f' ?- o0 v. E) r2 ^ 7 C( [- C& O7 w# o
9 j! M* f$ P( n( J# T: v* V$ P$ | 当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。 $ `5 ]4 D$ @% B* ^3 E. u$ y
2 g6 K; O* a. G8 N V( `+ ]
- d2 d: d! R7 g/ k 本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。 ; ^6 c l% \/ \/ T
/ \: |" I; R- T7 R. E* o
! T2 t2 C4 `& l$ \
二、加载数据集
3 \, P+ P8 D v4 {- ~0 m 使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。 6 e% L3 Q/ r# Z: Q) F% ]
; z! @( k5 C7 Q1 M0 W7 ?
9 |8 S% A0 B G& V7 }
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() % ]- H) K. ~9 T( m! \0 o* f- D4 Z
2 k# V O% z7 P) z k" H train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') 5 W: p& o. W' Q$ |( H
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 4 S* @4 X+ H. t9 ?6 c5 J
3 Q6 ^4 h0 s" ~: h. [( Y
BUFFER_SIZE = 60000 1 N$ e& f9 B2 g _+ N( \
BATCH_SIZE = 256
$ p& ?2 ?6 n+ E; G& \2 n0 B . W+ V! I, k5 e$ [6 \- Z* i
# 批量化和打乱数据 # }, ^) ^1 M* o
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
$ g+ n5 T% E6 e6 C8 y 三、创建模型 3 Y8 F; h6 m+ }( ]5 R; n
主要创建两个模型,一个是生成器,另一个是判别器。
. E! d- L ?! w$ i9 u: j8 h ( ~. `/ ^ {$ S: l& Y' L- y
- [+ \5 C* y' V( N7 m4 R
3.1 生成器 $ ~- U6 m C, F; ]% f1 y
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
1 _. K3 }6 p {: Y% { 7 s( i# M9 w( Y. B/ `
$ |5 w( Q' R, n5 z9 E- o' G 然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
8 E' _5 f6 L3 h( ?7 a - Y' i: _9 U4 S! |/ l
( W* M9 n/ c. x5 M
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
/ p) X! R; [7 ^0 v: t 9 f- E; T# {, u* I3 P( m& m" ?: G
2 Q7 t; y# |% J* Y
def make_generator_model(): ( O1 G! M; _% N* h5 F
model = tf.keras.Sequential() $ N# g6 w# ^) I/ a# M3 l' z
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) 0 Y# A+ @; b6 B+ I& a! q
model.add(layers.BatchNormalization())
0 Z' }9 M! t6 e& Q( Z9 W1 r model.add(layers.LeakyReLU()) $ v! b) X! j$ T
5 Q) {/ l/ G7 l7 q c2 ~! o model.add(layers.Reshape((7, 7, 256)))
9 a4 Y$ ~4 C% V. w$ D4 H Z6 F assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制 # ]: V( u/ N8 R' w
$ S: h/ H A( M8 v3 [
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) ' c- N& I T: @( }
assert model.output_shape == (None, 7, 7, 128)
5 Y4 r0 @( D+ q( i model.add(layers.BatchNormalization())
- s, [' d& e$ A% C0 f6 L% U model.add(layers.LeakyReLU())
! d, f9 ]2 V2 m& j |$ M
& K* L3 S8 I2 f7 }8 R% g! B model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) : p& s f1 D d T9 ~# l
assert model.output_shape == (None, 14, 14, 64) - j# P. D, k' k8 ?/ e r& ]
model.add(layers.BatchNormalization())
$ g% ]) O6 a1 p$ e/ E model.add(layers.LeakyReLU())
7 V4 Z( P+ b7 Z5 }/ l 2 m! G! A% x1 x# {! ]2 k8 G4 x
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
2 k9 R2 O6 _+ ~' z6 ^1 f assert model.output_shape == (None, 28, 28, 1) 0 s+ m% i+ T$ n+ [% P) g3 x) a
# }2 m( y1 J" o7 d7 u. ?
return model
* {( P8 I+ U+ N% g 用tf.keras.utils.plot_model( ),看一下模型结构 9 ~3 x: B0 h4 ^. q
* I" z0 O3 i: `# I/ ~& {2 I
9 U" o3 K" A$ I! b; P ? + c0 K7 N8 [; Y: v6 ?- t/ G1 X3 s$ v
$ Z K# ^, \) _7 t
+ O) h2 I$ k( k6 z' ^ 用summary(),看一下模型结构和参数 & h' t( R: O1 L# Y
1 u# F' _' l/ _9 T2 H! P a- z4 ?$ Z8 ~4 ?" f
$ y a+ Q! U. j& g0 K8 d
/ a# V9 R7 G) Y3 H7 Y: y" P
; @, [ g1 l/ E/ {! E( `% l+ l ; T, Q# P" O2 V" t; ^ R9 N; Z
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
( ?$ f0 v2 d! `0 z8 u
7 @" v5 \) \0 c + y) G. v( y( f$ U/ ?' g1 ~
generator = make_generator_model()
$ E) c; x" L m+ D! _0 ?
8 T% o( E% M t: F noise = tf.random.normal([1, 100])
! K) y! U+ R; K. L0 w1 i$ w( p generated_image = generator(noise, training=False) ' r# y0 k; B" H- f7 m
2 M' C5 p; W# Z
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
$ p! j* N `) y
* u" ~0 c3 [/ E9 X9 w% @. X' i9 R
; ~1 ^6 k& m, f5 |( ~, P# u 3 U) P! p; Q6 M- ?& z
' y4 _# E- U" ^- p' k+ e: S
3.1 判别器 ; M( G: W/ L3 {
判别器是基于 CNN卷积神经网络 的图片分类器。 $ _ J5 _9 K- Q; @
0 h& O% {3 l! K6 V3 @* |! A* D! s
9 D' k0 v- X% e" P3 k; t def make_discriminator_model(): g" I0 S, A0 x1 l' ]) [
model = tf.keras.Sequential()
% O. z! M1 O$ J model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', H/ d4 _: o0 X! o
input_shape=[28, 28, 1])) : r q3 Q% Z& Y* @$ T; }' H; K
model.add(layers.LeakyReLU())
) \$ r. N! [# E$ [' C$ n model.add(layers.Dropout(0.3)) / D" y9 L ?0 Q- i; n2 B M
% P, ]% Y' Z) M) l& @' y$ p/ z model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
7 x4 X; [ [* A" j- }" e! s model.add(layers.LeakyReLU()) . |% x; Q' ?; b+ y
model.add(layers.Dropout(0.3)) 0 c% ~: p0 L- R5 h1 Y* U
) i8 W( N* X5 t0 d
model.add(layers.Flatten()) 5 J0 C2 {. ~7 |8 y. {
model.add(layers.Dense(1))
7 O: T6 G, c/ f) p0 h
8 [9 x6 N9 x, {" k$ J' D1 q return model , o5 \, Q" j. t# ^
用tf.keras.utils.plot_model( ),看一下模型结构
. B- [5 D& L, Y: \; r- J2 ~ 4 B- E5 Z' m4 u$ z
& A [" h" s4 j' O9 o9 Q, Q
, A, ~ d. u$ v+ b 6 Y# V& C! O3 I& L8 b/ H
* d9 ~: \4 a" r4 s
: R4 j. Z$ P# |; F9 J1 @ 用summary(),看一下模型结构和参数 ! J; Z+ U* M$ S/ W) n
, }$ ~8 `5 Y7 u- D) e: |
5 z& j! G. U! ^& r' m
# x( B9 @' G# q7 Q3 e% ^5 y8 L
. l# Y/ b0 ~1 d: X a2 c3 c7 W/ s: b4 Q! Q
+ v6 q3 ~! ~6 W9 j3 j
四、定义损失函数和优化器
% V- Z, Y- o0 ~( w* ~4 Q 由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。 ' W3 o6 i, s& ^1 F7 A
4 K3 \% T* U3 Z) T$ U* J) {3 e6 \
3 M( U9 N" V" x; I/ _9 ^ 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 7 l2 q. V( o8 d; I. Q7 |% T
) z7 e& e9 \4 _: d& a
3 C$ ?! W+ h V+ ^8 n # 该方法返回计算交叉熵损失的辅助函数
8 N y6 l: W6 m7 G cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) $ s, ` z3 m# C$ ^; Z' |& E' A
4.1 生成器的损失和优化器 ) b$ P' N& ]9 `- [6 I E" x
1)生成器损失
( E/ p& e' o4 ~. O( z& M/ x1 C/ x
$ B W& N- u/ S E3 C# w
5 F. m7 V% d( j7 @ 生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
/ p# p! k7 A y: g# F . g3 t. h/ U. p% _5 Y
' f7 v& m4 ]' h 这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
2 J+ N% P/ G2 u0 V( d0 [7 n3 h2 O, O / [$ Z: I% w8 _" p" P- t$ c7 w' ]
" E D6 S) Q: g6 F' H5 _8 ~ def generator_loss(fake_output): 6 j% h& B, I. o3 n, ~
return cross_entropy(tf.ones_like(fake_output), fake_output)
0 [. P1 f9 V. p9 M2 t 2)生成器优化器
! m G, H$ P/ I' P$ w
; F" C# Y7 \' ]( @" I
/ }; k+ b1 K: A; e generator_optimizer = tf.keras.optimizers.Adam(1e-4)
$ [1 H9 k3 F( r1 P, Y 4.2 判别器的损失和优化器 % O. ^7 \% y) p! l5 F$ Q7 G
1)判别器损失 ! h" J7 x1 f0 ~5 `: Z' N
]6 ~0 q8 k& M1 w! b* G
! Y% w8 a7 g2 a- @; m. e6 A 判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
! d* ~! |4 Y6 d! B' y6 v9 Q3 H" { r& S3 B+ A6 I' q h* {' I
4 m7 l8 X4 w. e/ P2 E6 C: `# d0 G def discriminator_loss(real_output, fake_output): $ c7 _- _* Z6 \- a( X
real_loss = cross_entropy(tf.ones_like(real_output), real_output) 3 ]5 S( G% Q% l- L! M. J. M4 b
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) 9 F3 o _+ x2 b3 d) O% b
total_loss = real_loss + fake_loss ' I# G' v4 {* q: h
return total_loss ) H% K, K0 ^9 R* f, T( Q/ O
2)判别器优化器
3 C# W4 m% N. j6 h2 [ 7 `* y& m/ S* Q& }
: B4 w' @. }1 Z9 G/ u3 D6 e" D
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
( T" f/ u8 s$ ~( Z" |9 _# G5 H7 A 五、训练模型
% d* u1 {) s' t6 t7 v1 v 5.1 保存检查点
5 K! e. _: w$ w* ^$ \ 保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
9 t: }# l2 h6 i( A/ s2 j
$ }1 @$ F6 ?$ T
$ v% t, t+ u C# d9 }0 e checkpoint_dir = './training_checkpoints'
5 t! y- n# P. O( U checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") % l; L) F" X2 ~+ a( z0 s9 T2 W
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, 3 c/ e, G5 P3 F3 e
discriminator_optimizer=discriminator_optimizer,
' Z s5 f8 \4 o generator=generator,
; H' _3 Z, [& k w: E: ] discriminator=discriminator)
9 }3 e1 z, i; \$ r 5.2 定义训练过程 - Y8 Z |: q+ b
EPOCHS = 50
6 D+ k+ H) S$ D g& Q noise_dim = 100
3 L9 S! {& u+ _4 `& F num_examples_to_generate = 16 6 g0 {) g5 v3 g8 s+ T0 ?, Y
3 f5 y: L- l& f2 k C) X, e$ N / v4 k j" U8 D# V1 I3 V# `
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
+ Z6 A9 ^ d& S seed = tf.random.normal([num_examples_to_generate, noise_dim])
: Y+ C$ g5 C0 z8 Y' J8 O) W 训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。 " E7 N# o$ f" H; n% B
3 x9 L! e2 E2 A( r( c6 R
5 b; _! F/ U* \$ h( w4 o 判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。 , A8 @0 U( a" f$ U3 ^
3 ]& H {! h7 Z- z: I& h ' H1 W3 P! b" Z
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
, q* q0 W5 u8 t8 q, H7 S0 u . Q- t# b# |6 C+ q
0 t, y4 Z. j9 K: e' G0 P # 注意 `tf.function` 的使用
3 Q, o! E) k! [- C M% A. M8 J # 该注解使函数被“编译” - N! B: c4 C, j
@tf.function ) z3 I2 \# n- ?0 @
def train_step(images): 5 C$ q* _$ a: T* D2 p* {
noise = tf.random.normal([BATCH_SIZE, noise_dim])
" ~% E- p- J* H7 E! ~, D$ n " |) W6 h2 ^" w+ i2 a2 C3 w K" s; m
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: 5 p5 [% z1 [9 v: [+ U$ @# O: ?) T0 S
generated_images = generator(noise, training=True)
+ W9 x# v% F; _; \9 o0 M
' v6 G- v$ P6 A9 O! G- {2 J real_output = discriminator(images, training=True) # m3 P9 i, {3 d* W; P# @! a8 t9 ]; z
fake_output = discriminator(generated_images, training=True)
7 { d2 k7 t7 n$ }8 K # i7 _2 A$ V4 }* @" X( s
gen_loss = generator_loss(fake_output) " x" w1 `& a1 J1 p/ u- b2 U
disc_loss = discriminator_loss(real_output, fake_output) : u6 {! u( H" g. }& E% U7 y
% G" @% u4 l' a. U+ \ gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) + t% |2 N9 w! v. t) s( _
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
- L% O* w: s1 A2 R% T3 R. S
1 q @1 G3 r0 {, v6 s generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) % q, B" \( c" e3 G0 Q% z
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
, `' Q8 k A u9 h% c* c( j& b* z 3 v J. ?0 C2 L: N) Q
def train(dataset, epochs):
: Z. A- l: T; R for epoch in range(epochs):
6 p: d/ i. q& q( ?* I start = time.time()
* E& ]. O T. |' J& b4 f
* y+ {% }* d2 Z. t% J# y: C for image_batch in dataset: , U& K1 G0 u9 s# y5 V
train_step(image_batch)
) i1 X7 x+ r* ^/ | 8 |: O' e/ m8 @5 F! K$ E
# 继续进行时为 GIF 生成图像
2 ~5 }/ ] N* \2 u5 t( Z: S display.clear_output(wait=True) # I) O- h$ ~, F% {9 l
generate_and_save_images(generator, 4 w+ D9 m* v. f, I
epoch + 1,
' g) B; C$ n c seed) - l, W/ v5 E2 R& }
0 N" ~6 j, L. Y- F8 M # 每 15 个 epoch 保存一次模型 x' o, w, `- F: M# }2 ~
if (epoch + 1) % 15 == 0:
* b! L1 u1 D6 O9 N1 Y8 n. D checkpoint.save(file_prefix = checkpoint_prefix) . Q: c" I$ ?6 M
^% q5 Y0 ]) [$ s/ h$ z print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# K7 f% u3 d0 V# u' N* ]+ C) R2 \
# q) \, B6 N: A # 最后一个 epoch 结束后生成图片 3 [/ _: F" I. c2 K! D- G$ \0 ~
display.clear_output(wait=True)
7 ^0 n* D; C, |/ V" O generate_and_save_images(generator, 5 |, @2 ]& U* k2 L; C$ f3 b
epochs, $ ?" C" M7 n2 a8 q, k
seed) ' w2 o. z' U3 g K/ S! Y
$ v7 c% l/ T* Q& f- ?( W, p# g
# 生成与保存图片 , ?; M$ g+ i: W+ a
def generate_and_save_images(model, epoch, test_input):
& M( E: j: C9 b7 n) C # 注意 training` 设定为 False & O0 Q* Q* G: X1 I0 T2 j" A& D
# 因此,所有层都在推理模式下运行(batchnorm)。 f+ m" i' C4 V
predictions = model(test_input, training=False)
( v; A8 H9 ^1 o
, A8 L" A+ [/ d# C fig = plt.figure(figsize=(4,4)) / N7 w2 j& r: i2 d; Q
- s6 H. u, n2 _& m6 M for i in range(predictions.shape[0]):
- g7 Y* t7 g- R5 e plt.subplot(4, 4, i+1)
& Z% Z( ] k& Q9 p6 g& n- V plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') 5 F& w! j' H3 M" O
plt.axis('off') . T: n4 s$ Z) V6 d+ }6 L
( G2 c4 \7 l8 g
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) " ?1 j- L2 z0 X) B- T
plt.show() + C+ P ?/ M+ g0 ?, |
5.3 训练模型
, z4 ^* I0 i7 ^! S( k; b6 S 调用上面定义的train()函数,来同时训练生成器和判别器。 . H9 d- u! C! ? F& u5 M) \! f
, H- B, U+ s1 r$ ?* E
, j8 K2 j1 c1 f6 C7 E Q2 `- r7 G
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。 % }2 v/ B" W) E8 V/ y8 ^
$ E' x( P3 G1 g
9 a" V0 q6 f; K. b9 q %%time
4 |4 S" D$ _4 z6 e* n0 d train(train_dataset, EPOCHS)
# F- }1 \; A, ]1 b* f+ Q/ I 在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。 ; o) t/ Q2 s) V5 }- M7 c
/ C- o! y0 V, }( Z1 U' P ' C+ K% Q4 _2 A! _
训练了15轮的效果: ) t F+ q: F f
, t# n' [) `% a/ `8 [
6 _6 O2 {* A" F3 y7 w5 \/ v
* ]5 c' {$ v* K- h. ]
^5 `2 C/ ~+ g& }
) k1 S/ v! U& G" a
- X! q8 q E; S$ S2 J" i: c' k 训练了30轮的效果:
7 ?2 @/ Z' y6 z / F; c \$ {7 e. ]
# {- X K0 _+ C
+ k) L0 z+ ?- h; W; F6 I' P
% m1 g3 |, M1 M5 y8 H ! ^9 W3 m% v. u" I# |4 x
6 m3 E' _0 R' [ 训练过程: 3 Q( H, R' q' \2 `0 ?# N
9 ~4 D; }% H) M
4 z5 s% }8 {: ^' Y
; u: V% R1 q" n8 }) v$ b: B
, j. W/ j3 h1 @# t1 h ~7 f' Y* Y
# f% D2 P& k1 E7 K
. U7 ~- L6 B" l# @* @
恢复最新的检查点
7 [; b7 K+ t4 v @! { L( p- @# {
9 n, k2 s; ]6 w( U2 X
6 }8 T# u5 N! q9 S8 Q1 S5 R" R8 p checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
' {$ O6 M/ I; b) X! T( o ?6 i 六、评估模型
9 n( Y, s2 l! m" P: F 这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
# R8 ?, d1 r7 x$ g. ]( B
& h0 W# p4 K4 |7 u& y$ W$ A( U
, N$ L+ U, s. l# R7 p3 M # 使用 epoch 数生成单张图片
% N/ c( f( T( ] def display_image(epoch_no): 8 E% Z6 @$ [& x7 d. E
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
4 j! A1 H: j1 }5 I, H/ S6 r
* E4 s+ \# _% X display_image(EPOCHS)
( h+ Z' K) ?# F anim_file = 'dcgan.gif' 7 x3 S7 u/ w0 j |3 e+ r Q L
2 l- D3 v7 {; J* i: W with imageio.get_writer(anim_file, mode='I') as writer:
0 N: \7 e7 z) l* E4 }8 u. G; J( D filenames = glob.glob('image*.png')
B- V3 N8 B9 z, v% p4 d8 Q filenames = sorted(filenames) 9 [8 @0 j. E( v. y, D4 R
last = -1 " m. i6 r, y3 V) g
for i,filename in enumerate(filenames):
6 i0 F/ F! r& w' Y' [ frame = 2*(i**0.5)
/ |$ q* E8 L5 e$ c if round(frame) > round(last):
9 B- E: E+ I- c# m4 ^ last = frame
( J+ i+ O9 ^- H+ c: \0 u else:
+ n. c0 V0 g2 V. U continue : J$ d# X. @3 R8 i$ [4 `/ a
image = imageio.imread(filename)
5 | Z; E7 O1 I8 m9 y writer.append_data(image)
1 w& Q: i: s/ Q image = imageio.imread(filename)
7 ~( p' @6 ~) d. O* x3 i* G5 M writer.append_data(image)
# K x- K4 d0 ]* ?, U1 e. ^ # Z& ~8 H5 M) x4 j2 N% a! [: m6 v* i
import IPython
) n) }' V( W3 @" a9 x3 D i if IPython.version_info > (6,2,0,''): ' z/ I/ W4 {) _
display.Image(filename=anim_file)
. b0 `+ |) S3 Y. p9 m 2 M9 R* k) T4 h7 i, o* G, v
# N; o* a: E1 F( h5 W0 }) F7 e
: G: n a' \3 C/ C% ^' K3 h% P( F
9 H3 _! x8 T1 |8 Z1 m' [0 A3 Z: i 完整代码: % u$ Z. N# K$ n! U# [
$ W5 @! P) _4 t" B+ k! Q4 H9 e
( i3 A( }" `, ~ import tensorflow as tf
& w3 P. z3 K! o9 ?% \/ C3 s import glob $ K0 ?$ t1 @& n" G/ ]
import imageio
4 q G4 U; l$ x import matplotlib.pyplot as plt % i/ a q3 S+ ]& v( I
import numpy as np
! G4 n! G; N/ A. y l import os
5 j1 `6 H$ x$ D# u- C: J o import PIL 2 W- |6 s4 |0 `: \
from tensorflow.keras import layers
# r5 T0 A3 n+ p4 n import time ) {$ S" G( d1 l+ \( P
. K7 }( q- X8 F; p$ J from IPython import display ( ~# T0 e3 i2 F7 Z
/ F( _7 F( ]% f4 ^
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() . }2 ^7 |9 q. D: U1 Z
0 j1 y1 z# T1 o+ K
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') # y4 N1 F7 Z. T* N
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 ' J; _) q% T: H& ] D! D0 M
; f' P( N9 ^4 z2 o BUFFER_SIZE = 60000
+ z! Q1 n# W% }& T% f BATCH_SIZE = 256
. j! K6 r8 o4 a+ ^ " F2 m5 Y$ |- |! ^% Y, o% J, |
# 批量化和打乱数据
, u7 W8 Z, J$ [7 k7 O train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
/ \3 u" Z' |. L3 x5 ^7 P 0 G# r6 A6 f9 j# h$ P
# 创建模型--生成器
+ b* Y8 O& q, H% V- D8 w def make_generator_model(): c) K$ ^2 |/ g* v& o1 v2 w
model = tf.keras.Sequential()
3 Z% R$ N; ]& q+ X model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
/ e8 Z* S+ a U6 R model.add(layers.BatchNormalization()) 2 Z, E' q [& O3 B: |6 u2 G: N
model.add(layers.LeakyReLU())
% p+ J8 ~: F S* M7 S
0 W+ U: S. U0 g& u: K" P2 i model.add(layers.Reshape((7, 7, 256)))
3 k% `+ A, N; Q+ N assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
! T* e( N2 E7 j0 |6 Z% F
; g3 {2 O+ ^5 r+ u# t model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) 7 A+ K% }% O3 f { R2 z5 `9 q. {
assert model.output_shape == (None, 7, 7, 128) $ M4 J, B7 ?7 U0 ^
model.add(layers.BatchNormalization()) : \) d- \5 Z0 m* W% M
model.add(layers.LeakyReLU())
4 C) H1 S% r: A5 s6 w
: q& J/ a( P1 M4 H2 c model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
* j$ P# `5 @ |$ T assert model.output_shape == (None, 14, 14, 64)
. Y. ^. C0 G5 p+ U5 b7 \& l model.add(layers.BatchNormalization())
9 {9 f) r+ G" w2 W" D; Y model.add(layers.LeakyReLU())
0 j- |9 |& z8 ^7 B. D4 k: }
! B+ y' K7 p6 V" H) m+ \ model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) $ ^& e' i' M1 L+ v" ]. \2 w8 `
assert model.output_shape == (None, 28, 28, 1)
! o- F* x$ X; q- U* I3 z + w& W; ~! x* l. S3 d4 Z$ d' J
return model ' H$ ^8 a# N$ L$ f7 E" F, B
. m t3 B2 c$ ~; j # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
# ]8 q. F' C0 Q7 k. p7 [ generator = make_generator_model()
3 G& s/ g; N w6 ?; |9 M $ I4 s, e% P' x1 f
noise = tf.random.normal([1, 100]) 5 G( _: q' _) y+ a$ n7 b& |
generated_image = generator(noise, training=False)
9 P$ w6 g8 T) i. n% z
7 e- r# _* {: ]: M# [3 ~7 g8 d# a$ | plt.imshow(generated_image[0, :, :, 0], cmap='gray')
C8 _" [6 H8 h1 J tf.keras.utils.plot_model(generator)
9 ] y! Z! N# v- P2 ~
1 K! a5 R. T0 v$ U # 判别器 ! j- q7 U/ C# v% f$ j% U
def make_discriminator_model():
! g1 B+ a/ J; f' @' B model = tf.keras.Sequential() 5 Y. O( c" v, V$ B
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', : k3 t0 y: z# R1 v! W
input_shape=[28, 28, 1])) " n! \' D7 Y# a- F6 H3 j
model.add(layers.LeakyReLU()) : c0 ?' p3 K" R4 q# j! Z6 J
model.add(layers.Dropout(0.3))
+ ]) ^$ n# l% D4 d- E
! n6 n+ Q+ n, I/ i0 ] model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
+ y) m- x2 U! ?# A! `8 L# j model.add(layers.LeakyReLU())
4 m- I& ]+ p0 S5 D model.add(layers.Dropout(0.3)) / g1 o0 q( P" f' ^, t9 u- y+ J1 @
0 H0 Z& n! D! u/ A- ] model.add(layers.Flatten())
F( C7 F) r* q; G- M2 m model.add(layers.Dense(1))
( Z p# _" d% e- M' x# H 4 t4 u: A( a, r
return model
! K) |$ R+ z( O' y* [& r* J
, C2 R( ?. P$ V9 [( ^( y # 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
# |4 g/ i& O- w; J: M+ p$ k& y discriminator = make_discriminator_model() + H; i. a* ?# k3 |* ]0 i
decision = discriminator(generated_image)
3 v3 \, P( M7 m( U7 V. Z4 w0 P print (decision)
8 v$ M" [7 j- q- \: a
9 h0 a: a: T/ Q4 D* l* F. { # 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 $ |7 l2 I1 G* e" J' u- D* n0 I* o
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
6 h, q# \2 {3 ` : z$ P. m$ M. {1 T6 u( @' ?$ I
# 生成器的损失和优化器 / o( h& P7 Z3 Q$ `& R; M
def generator_loss(fake_output):
8 @/ b, j# a/ a3 |; M return cross_entropy(tf.ones_like(fake_output), fake_output)
* G' h, Z7 @7 i' Q" x- v' P+ [1 A generator_optimizer = tf.keras.optimizers.Adam(1e-4) 7 e+ }: t$ i! ~# h$ R( s: {! Z
1 u3 [8 y% s7 g) B; I' _" U$ @4 Z
# 判别器的损失和优化器 1 t) y" S& x9 p, e. O' z9 A( E: D8 m
def discriminator_loss(real_output, fake_output): 1 A% f6 x$ a8 _* R
real_loss = cross_entropy(tf.ones_like(real_output), real_output) % N1 P9 S+ J, a5 o: m6 X" e
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) ) y7 e7 k3 I1 g& G) [( ~
total_loss = real_loss + fake_loss
; z1 S4 Q6 @+ T; j* e return total_loss # s/ A- }. {5 U7 \' ~
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) * X; S- O' L) w- \
6 b0 N9 @% w% l1 M9 T # 保存检查点
- U+ m) o# H/ @2 K. u checkpoint_dir = './training_checkpoints'
- v! @) O% W. d. Q7 J checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 8 t' v+ x6 m# q5 E
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
~: o. j# H7 ^9 M" O% n$ L& v) k1 f discriminator_optimizer=discriminator_optimizer, " f$ e: r& A( k: x
generator=generator, 6 ~: j3 l# ]" @0 s
discriminator=discriminator) " Q, ]3 f- L/ P
! c8 n8 S2 ^9 y- b2 G% O/ L9 }
# 定义训练过程
, s/ e+ s: k6 T4 z/ l9 O' j' K3 o, y EPOCHS = 50 5 P7 Q2 R$ W. k- E' k7 @5 K
noise_dim = 100 5 {% g' k$ v$ G' ~2 K
num_examples_to_generate = 16 " m4 x0 q A _0 a
7 X/ N# i+ |) e, E* L # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
- N7 O5 M! U/ X9 l. ^4 z" ] seed = tf.random.normal([num_examples_to_generate, noise_dim])
' k5 B) U' d& |3 M* x 4 \! j+ _& ?2 K0 @6 `
# 注意 `tf.function` 的使用
$ y) V6 `9 \6 h+ p# r # 该注解使函数被“编译”
. {9 p. ~: A/ R @tf.function 6 R* A& Q" v1 Z9 }
def train_step(images):
' C5 ], C; m6 l$ Z# a ^! D noise = tf.random.normal([BATCH_SIZE, noise_dim])
# A! \4 K2 F0 K8 q: ~7 D , z3 X. c x, U. q0 q
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
1 _$ R s3 G% E; d6 X+ G. r generated_images = generator(noise, training=True) # T8 \) c: {% t; P; e. X& B9 [0 k% j
0 s# ]( O' A3 D3 M
real_output = discriminator(images, training=True)
& N$ W/ P! u7 K& G3 Z fake_output = discriminator(generated_images, training=True) 2 Z* D ^% e" _
. U0 w3 G* o! y& i3 b
gen_loss = generator_loss(fake_output) % `3 k2 _( f2 k {: m
disc_loss = discriminator_loss(real_output, fake_output) : U# F" u# T' Z2 E5 d
: D& B; Q+ I Q* a gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 1 j; r" w" X, J$ L5 b" P
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) ! j. l1 Z) x8 z+ J; A
! G/ T& C* _) j a- y" o2 t
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
/ k' i3 ?/ |; \. J- D3 k discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) % v4 U- e9 V$ y% {
: X: ? B2 G. ]; x5 Q def train(dataset, epochs):
9 d& R) D% E x for epoch in range(epochs):
2 o4 u# \9 Z7 k4 h start = time.time() " x; x% n% H. G/ B1 B2 G; A6 F
: E- P' f" D1 p
for image_batch in dataset:
/ u4 J4 X5 s, A. j train_step(image_batch) $ v6 C( ~4 U4 {$ n& O, a
; ]( C5 {6 [4 h. O1 n
# 继续进行时为 GIF 生成图像
* G8 {! X% t4 z( f8 f* c6 n display.clear_output(wait=True) % V S3 P6 i! G9 T: r. u3 l
generate_and_save_images(generator,
" Y" V# F7 n: d5 {( F epoch + 1,
% N9 Q7 E Y+ O9 d0 @ seed) , E ~% `: _% u
; Z9 F: a; B2 M! P # 每 15 个 epoch 保存一次模型 , X" y) }( @8 D) x1 t7 U( o+ ]
if (epoch + 1) % 15 == 0: " S; P0 a2 y$ t% }; X
checkpoint.save(file_prefix = checkpoint_prefix)
7 v! W0 _' g4 z: d! h$ r' s e
1 r/ D. d, \# X# w& u" R print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
* l! Z4 l( ^6 l6 ~9 s
. ]7 F# Z2 X) E- n # 最后一个 epoch 结束后生成图片
2 U% G6 d! y7 P& Y- A4 M display.clear_output(wait=True)
2 T, n- X7 p9 S" v+ _, E& s1 H2 P# y generate_and_save_images(generator, 6 W+ a! h0 s3 P& T( ^6 G0 y. s
epochs,
7 a& s" ]4 ~/ }9 o seed) $ S" r: q' n2 ?) E2 T
$ R& J7 Z3 @2 G8 D( G
# 生成与保存图片 u9 e, f3 u* N
def generate_and_save_images(model, epoch, test_input):
1 u2 R {- ]6 D9 r& a # 注意 training` 设定为 False
! P9 \* R4 M! u, B& h* h0 A # 因此,所有层都在推理模式下运行(batchnorm)。 & U) ~% u" Z; z
predictions = model(test_input, training=False) 1 n: @2 }. L4 G3 R: {. v
$ l4 p. C# F& @& d# r3 a
fig = plt.figure(figsize=(4,4)) 2 ^: e+ d% ~5 n: f' \& c& v; c
8 ^' ?( e# Q* C7 y' y' X, a
for i in range(predictions.shape[0]):
7 n) ]3 b$ Y0 i plt.subplot(4, 4, i+1)
% [. H: V5 a1 ?0 c plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') / z7 g: u' }6 h1 h4 K( x
plt.axis('off') # \% D' P( B+ m; K9 p
6 x8 e, y3 {5 }, s8 s plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
6 H$ T8 |6 u' i m1 A$ L plt.show()
+ {0 X1 c. A+ K0 T4 H, X6 x) O # p* _. F+ q+ a9 r2 z
# 训练模型
1 O1 i6 S3 I4 d% f* [ train(train_dataset, EPOCHS)
8 m' L$ ?+ _3 a7 o0 f
& j5 C, L* S3 y5 f8 ]: B3 { # 恢复最新的检查点
1 ]- m, ]) W! {9 z checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) ( ^+ W- t' M$ Y9 U
. G: }! @5 a- { # 评估模型 m0 I0 ?8 v! `6 s( A
# 使用 epoch 数生成单张图片 9 _1 A1 U+ H' C* I/ F1 j3 a+ ~3 Z
def display_image(epoch_no): y# N# A: k- {5 r
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) + G$ x' u' p7 K% i* [8 u I
+ D/ ^8 j8 P' H' F) h
display_image(EPOCHS)
4 s3 r" n& j$ ?6 E
1 U8 C. H2 q! ?# N- d: _ anim_file = 'dcgan.gif'
4 O6 Y: I! V5 [8 V. D5 B
0 O$ q0 d% L% U4 M5 a* E8 Q with imageio.get_writer(anim_file, mode='I') as writer: $ Q& M6 k8 B) _5 v$ ]0 W
filenames = glob.glob('image*.png')
- Y- S7 W- f4 ?" t( z! S filenames = sorted(filenames)
5 u' E1 K; k0 ~4 q8 {/ { last = -1
3 E/ \+ l( a7 }: \8 _) q( G) A for i,filename in enumerate(filenames):
0 |) X' \# m. b ` frame = 2*(i**0.5) 3 b. |: h3 M( `" h- Z- q# ^
if round(frame) > round(last): - g+ E. p, Z/ p5 k/ i4 {/ G- D
last = frame 4 R9 h& h7 m1 n
else:
. m- ^/ }/ P/ r* y; R" O: B+ R' n continue
F+ Y q* y8 o( d/ ?# \ image = imageio.imread(filename) 6 `" d% F3 E# _: I
writer.append_data(image) ( H7 v+ k9 N6 h
image = imageio.imread(filename)
/ j1 p( E; S; w1 m% r+ }7 K A writer.append_data(image) ; q7 H! P5 f2 g+ l
+ [5 j- u* E; v5 d' z/ z. a import IPython
0 [1 z5 E% l h9 i/ u1 d( | if IPython.version_info > (6,2,0,''): $ ^' A; i" a: w) T# Q( o( L0 g
display.Image(filename=anim_file) , u Z. R& s" |3 `6 Q) F
参考:https://www.tensorflow.org/tutorials/generative/dcgan
0 n Z, `+ o7 ]5 d. I, ?4 q: E* x ————————————————
4 C- t r( d$ e% }2 {5 s6 H# u 版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 9 N2 c7 @ H3 ~* Z
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111 K$ z; ?1 v/ Q8 X9 ~* m5 Z& R
# \, b# g! W* b ?
" w' Z7 u H$ M- Z% g
zan