在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 564458 点 威望 12 点 阅读权限 255 积分 174560 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 3 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
9 ]7 Y( z! v5 E+ d! F0 H" H' s 深度卷积生成对抗网络DCGAN——生成手写数字图片 9 o& r2 j& ]; H0 F8 @( ]1 L! p
前言 ( [ b: ^! M6 V
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。 5 t* R5 a) n4 o. L, R
: a. A+ x, M2 ~
6 j8 ?' e8 z d3 _6 K9 M
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
) t" _( |. X1 a( ^& U/ c " w7 r1 Y! F0 l# M6 w
0 w$ N+ x, G; {0 @; w" R& L # 用于生成 GIF 图片 6 T' s6 @& L8 L0 M3 C) K
pip install -q imageio + r" U8 V" w7 ~6 z
目录
. w& P$ z8 V; O; ~) M1 z
; h# W0 k6 ~- V" ?, L) s. ^8 H, p f
5 a1 b$ n$ K: A1 b( { l% w 前言
( T6 G5 f5 n/ t
- _9 k! U5 [6 m% |' Q2 g5 L) e 1 X8 _. V( a) d6 x
一、什么是生成对抗网络?
% }) {& }- o% ~+ r' l
; l* ~2 f/ c0 Q
3 v" Q" _4 N2 C' J 二、加载数据集
% o# e+ l$ w0 A5 U/ v& ^ # G+ u( X% I5 H% d# X! V# }
: \& |6 A0 @# y) G; Z: o 三、创建模型
2 \. Y' L5 |# q
8 _) t2 O+ t4 S. P* k1 Z& _7 f : r8 l$ t% m2 n8 U
3.1 生成器
7 F* D; }3 x! q& ^. `' P) h; q & \) I$ b- V9 }1 e, v3 f% q9 o
2 w' q3 q* l+ @; P 3.1 判别器 : W" O1 r; `( X1 D
6 u2 `5 `* j. p4 ?# i
: _" B: u: g; Y+ n) s. ^6 t 四、定义损失函数和优化器
$ ^- _/ {7 |/ }5 z1 Z 0 s6 L9 F. A n3 l
* M8 J+ O8 }# G- ~ 4.1 生成器的损失和优化器
' ?2 Q) D& ]) s8 R9 j% ~; Q. v; z 9 g: n" w* ~! A
- t& p5 F0 X9 k- p2 F ~5 X6 P$ U- | 4.2 判别器的损失和优化器
" \7 l* m) z; @9 a4 b& d r7 n 9 x( F6 J; C( |! R
4 U2 j7 _: C" |4 {3 m9 r 五、训练模型
% k3 f, e% ~8 N& Z% W 8 g, V2 B- b6 j+ l; I) n8 G
4 R( C ~: R8 p 5.1 保存检查点
- U3 m4 O% r/ R( |$ Q. l ( u, F% ?7 h6 i& K! L) S/ J
6 [: k, o `: ^, X# F; l: W, h 5.2 定义训练过程 % E4 x" u7 `; h
$ A/ Q# y4 O$ w' R {3 z' R3 L T; P* H- O1 l
5.3 训练模型
+ d( e5 U3 x: }+ x! ], P( p # N* E2 l! U/ O5 g, w+ m
, [* |. D" ]/ ~8 H& r+ c
六、评估模型
; z- F) _2 g# E) j4 Z
7 y2 ?, {5 G# E9 u# T$ n ; W, W7 n$ g8 m, u( {5 d L: B% G
一、什么是生成对抗网络?
! a3 [ | M6 }9 ? x 生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
b/ H- f- E$ z/ J
7 X7 a0 _" z, ?( W % J* Q( Z7 C2 [6 {& o8 F6 S
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。 ) G( h4 A, t7 y1 Z1 T# q$ R3 c
, x& ` q e& p* P * F* I% O& P. { [- U5 {; z# ~ g9 Z
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
4 ?' \/ p2 ]# W7 f" e3 N7 H + X( k' J8 @0 p2 H
$ u6 R% M2 U9 g 训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
+ a0 U+ n. A& A( O' `! h7 u
~, S8 P# O% S: b" z: Z' W9 R; y
* J) L2 u0 G$ O, [* u# }9 l 当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
" E4 k6 f' a* ?; N1 ?
$ G4 p% M; e+ k+ {% Y4 v
; g5 `' L4 y3 U$ S' Z& D 本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
7 k* R! j( {9 z2 f" E3 r
% {* G: L0 g# d4 c
/ v! u. _7 `; }! [3 B 二、加载数据集 3 C5 L: C+ x7 p; H r% _
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。 & q% k( ]7 e( ?
3 ?" v4 p/ x, s8 B! `
, E' E/ H0 u! {% x (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
* a* m( G) ]' y$ Y- _( u8 z
) ]9 d% q' S* \3 _. N5 r train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') 3 N6 P+ l6 a( K, g
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 ) D' f. X2 J5 a5 ]
3 z0 \5 m* |' L, x& I2 N BUFFER_SIZE = 60000 , n5 E( y; O# u% H/ f
BATCH_SIZE = 256 - `2 I8 T( {1 a5 ^6 P; S/ v
q5 @' h5 M- y# J8 C: f
# 批量化和打乱数据 . G+ w; U& s) n% ?3 c
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) . j; f% J4 C9 N: w1 v0 o
三、创建模型
# q9 I8 r+ G& o1 m! w$ y3 e$ g 主要创建两个模型,一个是生成器,另一个是判别器。
/ D2 f- q2 M' o7 i0 s/ O
B3 G: Z% \" m3 H+ X' e% H: E1 i # ~# A" G4 j/ p( i
3.1 生成器
5 W, x7 V6 j0 n) S 生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。 9 D0 F0 Q" ?* [: u$ u, m
3 G- Y( Q( h( s* i3 T- U+ y
' @# Q) s: X% {& s
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。 5 Q. i- Y2 M. h5 l
% Q$ Z7 C7 }$ z9 @( {$ M* g" D
8 w, H/ b6 R, P( }; T& R& x* n" u+ x 后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。 8 G1 i g5 Z7 ~2 m) T; [
; C: x- A# D6 J& h
' k6 E# V9 l- R5 u2 j9 d def make_generator_model():
$ V# I' O4 G4 |. g/ d- T( o: [ model = tf.keras.Sequential()
" Q5 @# E0 ^: F7 p m+ P% U6 k3 ~6 ? model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
' w) { N* k* J* \3 t- ~8 r model.add(layers.BatchNormalization()) 4 i( R' \3 @& u% ]
model.add(layers.LeakyReLU()) * {2 k6 X% ^5 g
& W* W' o9 P+ R4 X0 x6 Z( d
model.add(layers.Reshape((7, 7, 256))) f) H9 B+ M5 V% h$ m2 B
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
f, n7 u6 I* B! c
* T" |& p& j) b" w" F' p' N model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) A+ R4 P2 K) B4 F, n
assert model.output_shape == (None, 7, 7, 128)
4 T& _$ S% v9 ?7 T. C) ^ model.add(layers.BatchNormalization())
! X# j V6 Z3 A- @ model.add(layers.LeakyReLU()) 0 V; w7 v/ K, v
' _, m% u" [9 Y9 D model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) & _( Z# S! R; B+ ~
assert model.output_shape == (None, 14, 14, 64)
; O( U- Z; B0 U- h6 b6 | model.add(layers.BatchNormalization()) " t- u* K, C( W0 q$ k) c
model.add(layers.LeakyReLU())
2 U0 x& M! g, X$ g! c i 5 l, v& X0 O: C1 e0 ]
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) " _" k% P3 }- b+ o- c
assert model.output_shape == (None, 28, 28, 1) ) Y" Q# q- Z% O! l
: u, t: p0 U4 Q3 C* G6 c0 E7 Y return model
% g7 @# n4 Y& \ 用tf.keras.utils.plot_model( ),看一下模型结构 ( |: ?5 Q! U# \
& I2 P" Y! D# P1 M- e- L& S
) c, k9 o r: h7 R4 @6 X' s5 e % Z: j8 W) F3 `2 ^
+ M6 P C6 ^, V F' ]4 m
5 o" G& c* C- q0 l 用summary(),看一下模型结构和参数 8 z ]+ c M- L0 |& e
: }% G: }' D {5 g0 W3 x
, o5 q* B" d" Z+ X( J9 p+ k
* o. e' H% |' ~6 F5 y( y : Z9 D* Q% q3 L) a: T, F! h
$ x# {2 x, s5 N0 M " ~1 X7 K0 n/ `8 o: v5 e5 c) a
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
# z: o0 S# f4 { * D8 [ o' y3 z2 b3 B1 Y' u
6 D, c6 H. l5 W$ W6 _
generator = make_generator_model() ( ]2 j( F: T3 y4 d" s' x1 v+ b1 W; [3 a
4 R; b7 b6 A! _5 d V# f8 D noise = tf.random.normal([1, 100])
$ l$ X" c4 X+ S% \+ M generated_image = generator(noise, training=False)
7 Y7 \) S! y. Y7 F* F# @ + t2 [7 l2 D+ N+ O
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
4 u/ B N# q6 P, S3 Y( ?- c. Y% ~ 3 v5 P. P# |- V. I
$ y2 T. y4 o) e# Y 6 ]. z2 N0 c" H* e% \( m
/ @, Y6 l% X$ r3 G! l l$ n/ V
3.1 判别器
5 `; ~) k7 @3 r' g; Q N7 m 判别器是基于 CNN卷积神经网络 的图片分类器。
. h8 n7 I9 j4 Q- v1 E
& Q- D- v. w3 C0 N& A: T ) V) C# A+ J6 L& J' U/ S: O
def make_discriminator_model():
( K- G' z' R2 w" O+ S/ `* S model = tf.keras.Sequential()
- a+ j( P9 _; H3 x" ] model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
0 R9 }$ m' q: K- K" H7 A input_shape=[28, 28, 1])) * h" e, K* A. K$ u9 m
model.add(layers.LeakyReLU()) & j+ I9 N5 |, ?+ _9 U% p( k3 ^
model.add(layers.Dropout(0.3)) $ K5 o! y0 Q$ g! D+ z& d- Y
# U8 F) o2 [3 f* | model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) 7 o0 K1 S( X- \1 ` Z
model.add(layers.LeakyReLU())
2 {& }/ }# u; V8 ^ model.add(layers.Dropout(0.3))
" d/ C# n# g( i* h0 Z * \9 b4 t' [3 v! G
model.add(layers.Flatten())
$ Q; V+ m3 ~: v: N& a model.add(layers.Dense(1))
* ?% o# n! b M' d/ w4 D- I% u
9 R/ k4 {$ Y8 | return model
4 r* o5 N8 v5 k, Q- B1 A 用tf.keras.utils.plot_model( ),看一下模型结构
$ {4 X# Z z+ s% K8 e7 P: ~ j 9 U c7 W3 i9 y
' h* H0 H; O. H w9 ^7 C' h0 s- Z' y
) ^& U; ?7 j2 W6 r4 k ^ ) R* B, T) F+ V# o* @
- E" I9 c: l& x3 b2 L: M) _' t1 X3 W ) g$ p1 W( b; J9 Q- v
用summary(),看一下模型结构和参数
" q. Z. N0 D! i% U5 X& p
7 `* u. { ?5 C2 z G
9 b7 E5 @3 o& t3 s/ G& ?& c ' c& Y$ P0 N7 z& b) v' ]
, Q, {* C7 o- l# l0 I
4 ]2 @: m! L4 R$ n
3 l! {' {" P# V6 V0 M7 @* U 四、定义损失函数和优化器 : O9 v n; W# i$ u- ^; i
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。
& j) B3 Y5 F/ i" K8 x5 Y" t6 a9 ` % R$ r: h9 r8 ~/ X) M) J7 P7 U
( i2 F* N: \3 g; s! e
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
, A% D" e' r# V0 Z( M% ^
& g3 j7 G8 G8 `- x9 s 1 r6 v; d7 U$ N/ O& `/ s
# 该方法返回计算交叉熵损失的辅助函数
7 \+ l& L! `% b- s cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
) s7 }& I! r4 J8 p- Z8 o& u& k 4.1 生成器的损失和优化器 - _3 U. Q5 P, h, F
1)生成器损失 % `. z7 p- P8 D3 b" O. Y( w a
9 ? i3 N3 g% s2 C6 E x+ X3 S' j5 L% @: y4 k& @
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
8 o" C9 x9 p4 h7 _
$ x, ~1 w: H( p/ X7 W- K: [
# N* O z4 Q( u0 \& v 这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。 2 Q& @. M1 }0 M, \
D B8 @! v4 r4 H
# S, h2 j4 A; k2 O5 {: X4 d def generator_loss(fake_output): 7 P; p" ^) a4 L8 H0 B! ^
return cross_entropy(tf.ones_like(fake_output), fake_output) 0 P0 v# _; U# q& s; u
2)生成器优化器
$ n& \8 N( F/ e* |1 v3 a8 q# N : t j) L* Q [" K+ k0 v
* x$ H( D- i" T9 t generator_optimizer = tf.keras.optimizers.Adam(1e-4)
' E" l/ N3 _/ [ 4.2 判别器的损失和优化器
8 J; N+ @" w+ g% [ D 1)判别器损失
/ g# q: w7 w: f# a7 s3 v5 Y
3 J8 F4 ~4 u3 a * z8 ~. p, e5 y' H9 Z: l: Z2 F0 j% A2 j
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。 4 g/ o+ m4 h `- b& A" P3 q6 b; G
6 k- {" e# _9 G1 B4 |' G" e* C
~( y$ j' Z! c# x* `1 |; i def discriminator_loss(real_output, fake_output):
6 V& E9 z0 Y2 ?7 L ?- v8 j real_loss = cross_entropy(tf.ones_like(real_output), real_output) 1 i, [* O: e1 j3 ]( G
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) / o! P+ L! m" c* ^* j
total_loss = real_loss + fake_loss 1 K# W1 ^1 g( U* X/ A
return total_loss / o r- P0 V9 z" D" z; L
2)判别器优化器
0 k% k g6 n, s) k d: d0 z1 F' @; x
5 }3 M5 d$ a8 \, L1 @
" y. ?; Q f9 |$ M- f discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) , _0 b- p; W- k3 s$ _! X
五、训练模型
! {. Y- ]9 e5 W! U2 Q! o 5.1 保存检查点 0 P6 }5 |2 @3 z! p* O0 E2 H
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。 $ n. ~7 j* [5 v2 O4 Y( [
. o5 p. }7 l/ `" K1 n R2 A" Y: U
( r, {( L, Y& t2 ~8 f2 }8 V$ V1 U checkpoint_dir = './training_checkpoints'
& g3 r' @7 r3 i/ } checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
: {1 k8 u! J& G# d: F2 X checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, ! u$ i6 u+ W: Z9 z( g
discriminator_optimizer=discriminator_optimizer,
* P* z/ b6 n: w4 I4 j0 Z generator=generator, , M) x7 D: t9 R% `3 o. S7 z
discriminator=discriminator)
' J" O! |2 k. g3 ^% Z 5.2 定义训练过程 6 ?6 S8 c4 a- ^" J0 l0 y
EPOCHS = 50 * p+ I: H. b$ y# |
noise_dim = 100
2 X H4 ^0 I% s num_examples_to_generate = 16
% R7 R% f) b6 \$ R+ `; O4 j
1 ` S) X3 f' e
% h3 m! [! f; ^& v+ j" x5 { # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
, O& E4 U$ L. ^5 t seed = tf.random.normal([num_examples_to_generate, noise_dim])
) N9 p v. R2 R 训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
4 y3 }/ `! q4 Z2 h
4 f' m! i1 `! n: g# r X+ Z& l2 j3 m
" a* z# O* z/ G) i& @7 e _ 判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
5 o. ?/ L% M! y7 A) O3 R
4 U& A/ s6 U$ y$ _3 R' u! c9 i7 c& H
+ M6 ~/ ?$ F1 c8 |3 |' Z 两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
5 n; X4 C7 Z. p1 a
, c8 y7 \0 Z7 N8 k- [ + C s+ B; U! E
# 注意 `tf.function` 的使用 4 ^; D; F: q/ A$ i# x' N
# 该注解使函数被“编译” ' ]) K6 G6 R4 P) }3 v, a. m; B* d
@tf.function " x1 P# d1 s( H$ t: U2 c4 |: H
def train_step(images):
4 ^) @: I. _ E, N$ f noise = tf.random.normal([BATCH_SIZE, noise_dim])
" Z9 ?" m: D1 ^* H- q$ x0 m / ?) n" a: @8 W9 M. x* d$ G8 _
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
, u2 G$ K0 l) G( p# w4 v/ J generated_images = generator(noise, training=True)
. z; k& x3 |+ R $ j- I8 v! n N" K
real_output = discriminator(images, training=True)
3 J/ U y1 i3 r' M5 }( t/ i7 y fake_output = discriminator(generated_images, training=True)
% S: \8 t2 L1 l, X. d& v$ g, A. j
1 Z' n1 h; K' X- G* u& a gen_loss = generator_loss(fake_output) $ j1 R8 N' x1 n3 H
disc_loss = discriminator_loss(real_output, fake_output) ; n; ?7 G3 V7 v) `6 z: y+ G1 _8 ?9 j
$ O! w( D- G" F- _% G gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) z' U; x0 r9 \% r4 D d
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
% V1 B- g$ @7 ~( |$ Z
) g/ [, S9 Q0 u- s! i generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
/ _. u" a" S) y discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
1 k7 o$ Y& {$ ~( l
5 ^' u5 i# D' V% D6 A def train(dataset, epochs): & ^/ V1 B% y1 z s0 i
for epoch in range(epochs):
# a+ R' B$ Y# e& c start = time.time()
8 R$ S# s( N" V1 R$ q
* X( i3 W3 T, a1 ]9 V8 S2 y8 B for image_batch in dataset:
3 @4 @1 V: L% i2 ]' m p+ A train_step(image_batch)
, T2 [/ B3 ]4 x" w$ {! Z
f3 T E! P! o9 S7 X; m # 继续进行时为 GIF 生成图像
. V( K4 s/ u7 D$ q8 P% m display.clear_output(wait=True)
1 A' ^- z& f- A$ q0 T: h generate_and_save_images(generator, " L N& P1 T6 ~3 v
epoch + 1, 3 y3 I4 a* |& x3 |
seed) & ^* ]1 n" s* r) v
. S2 d" N+ I' T9 w- r5 @# j
# 每 15 个 epoch 保存一次模型
9 D' h, R9 `$ @! u* C- l. U if (epoch + 1) % 15 == 0: : F0 D M- O& ^: Y* X
checkpoint.save(file_prefix = checkpoint_prefix) * r# D+ F' }0 v( p. F4 A
& P0 j* N1 c3 N' z1 V- c print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
3 t% y7 g0 m0 s6 { x / _' E/ x& l; z! p
# 最后一个 epoch 结束后生成图片
1 P' r3 Y, X5 ^! I7 T6 p+ c display.clear_output(wait=True) 2 S4 ~4 ]* n. R
generate_and_save_images(generator, 8 l. G. c5 G' N$ _" o7 C
epochs, # ?6 u& |' h8 W9 X
seed) , m6 Z/ i/ e" T& A6 T1 G
F4 m: G7 L4 n# h" m7 D2 ]! U2 U # 生成与保存图片
! k% t+ j' i5 d1 T2 u; [ def generate_and_save_images(model, epoch, test_input):
* K, y& L, F* L" `) u) L2 A* Y # 注意 training` 设定为 False
' @6 J2 L1 k! @) k # 因此,所有层都在推理模式下运行(batchnorm)。 6 o" Y# m4 s% O3 _& |
predictions = model(test_input, training=False) 6 |) Y, H; T0 p' W! s5 C0 \
- V" m& q: z9 h$ o fig = plt.figure(figsize=(4,4)) + r6 c5 t0 V+ c# y. ]8 L
/ ~# m4 K+ r2 A: t for i in range(predictions.shape[0]): 3 X# X$ ?9 i6 ?+ J' B1 w4 E
plt.subplot(4, 4, i+1) - z2 W/ d, V, c- E4 s
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
s1 Q: j2 _7 ~! p3 z% C& D plt.axis('off') # {' Z% T7 l- x5 e W; d
& _% Z8 g2 E, _6 h. S% u5 Z/ L plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
1 r2 T7 c/ P: T: ]4 V: z plt.show()
X8 O( |' L+ z6 K5 [* X) p 5.3 训练模型 $ S- L8 \3 c. Z6 D2 j" A! e+ k/ |/ w
调用上面定义的train()函数,来同时训练生成器和判别器。
1 e0 P+ z9 u5 J! R" t& f
4 o5 g; P' ?/ a' J' |/ [) q * P+ z+ L% T$ K* w% O' K
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
6 x4 M- |1 z' t6 [! ?& H: O - K, v7 k5 {6 j) `9 j! o
: |2 U- R4 J f1 ? %%time
) R) ~3 z- |3 l% Y* z* I# N& g train(train_dataset, EPOCHS)
+ d! L1 k7 ~* a1 h8 L5 [( V) A( Q 在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
' b( |2 ^% W l/ n5 E3 A
2 \$ s# G4 ~1 h4 e1 ?: H4 L } 6 S( _- z; F* w. E% h2 I
训练了15轮的效果:
" ]' Q; g. c3 ?+ x
2 L- _ q. e; o( g/ h, H2 j 0 x' T+ q: P& x' H" v$ R; S6 x
) _" {& A/ l7 D0 }4 t7 X * J" M" d# I9 N9 J: t7 ^
7 I2 d4 I z9 Z% O
3 K6 w& s" i% K* p! P0 U
训练了30轮的效果: ; T+ m( x/ V' j9 x! I* y
3 \9 n# N1 T* ^1 c5 X+ |$ n8 ?0 ]
7 t; Y$ C* N) | 3 B7 m t' r. [7 }$ u/ w* M3 W+ V
3 E% D" |# x+ j: g: o- N7 V! b
( x% K- x3 U5 {' V- t
6 I1 W3 [' G$ ~ 训练过程: / x( d" C9 R3 P1 K6 q2 E$ _0 U
* _7 d2 [1 j4 L6 H
/ R4 f% ?& E5 R
: B: m* m: T. { X " V7 K% B" w K4 i7 i
( Y( Q) @( }5 R7 E* _! ^- F$ t
& y& S; w" `8 t+ M6 o. [ 恢复最新的检查点 0 s+ z& d' U5 }3 P% b/ _. e
3 l, g, i! p$ l . B! X; M0 p& z3 K$ c
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
, X/ E8 o: k, x( p7 S- I$ T 六、评估模型 ' c. a+ C1 J0 x
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
3 Q9 B( H' ]% W9 b$ Q8 E
& Z' }9 x: t" F
) K/ U1 T( m3 t& t # 使用 epoch 数生成单张图片
& N$ m9 N+ R( `& g1 M def display_image(epoch_no):
7 @! h6 K6 ~( L& ~ return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) 2 m- y1 Y" E# t( L: R$ i! n8 a8 c
7 T/ [* P3 j# v& | display_image(EPOCHS)
( R% B2 r6 g4 ~9 n( }: D, z anim_file = 'dcgan.gif'
2 J/ O8 N% i! c. e% Q $ Q# f! O( |+ f' _5 I- H
with imageio.get_writer(anim_file, mode='I') as writer:
% s2 U6 h1 S3 G; F; ^% ] filenames = glob.glob('image*.png') : Y: t' U8 r+ Q: h4 ?; O$ f8 z0 {3 D
filenames = sorted(filenames) % B) I/ m! O+ A
last = -1 $ w L, r: R8 ^! q6 _3 S
for i,filename in enumerate(filenames): + F6 }! @# M7 r
frame = 2*(i**0.5)
# } |1 Q1 z$ a, o w if round(frame) > round(last): ! d/ d3 {" w' P5 X) M
last = frame
5 l/ `% @7 n3 u! k else:
9 ~4 E8 b3 R4 T; D4 }( m continue / G- Q& R; x' M$ M4 e
image = imageio.imread(filename)
3 T) o. }8 `: ?6 L1 U6 p8 ^9 h8 C writer.append_data(image) 9 Q) h* t0 v/ N7 c0 L
image = imageio.imread(filename)
; {; o1 u9 d& v0 i5 ?7 M+ Z, ?9 g8 i writer.append_data(image) * T% f# f; h* t2 U# s
+ X B8 c1 P* r& U% X! h _; ^) J6 O
import IPython L4 h+ J7 k' P% O$ m
if IPython.version_info > (6,2,0,''): # C7 r8 A' `( q2 B
display.Image(filename=anim_file) - S6 D$ |5 K7 G5 v+ g* `4 ~
- y! v6 g: X; h" W6 ~1 f I" f + V" K- `9 y( Z9 W
t: }+ @4 U% U
$ g! n, y5 D2 l 完整代码: & |, r; b- _9 S
, h! n% T' q" r# J" T# a+ t$ h
$ O! v/ x5 k9 \* f import tensorflow as tf ( [3 @; C; \* {! Q6 w
import glob $ Y' O7 s$ |7 {( Z) G8 f1 w7 ~% o
import imageio @" x4 y5 M8 I3 _3 O3 a1 J
import matplotlib.pyplot as plt 4 L9 v0 W; k! M9 y: P/ o# L
import numpy as np 1 U$ W+ U8 X' g$ J7 k8 A
import os : ~5 I* N) {$ G* N) C
import PIL
' N9 m, z3 u9 r* `) v0 M from tensorflow.keras import layers
; t2 A2 u6 @# S! n import time
: {7 [5 e$ F, K. v- `
- {3 Y, x, A$ ~5 ^. G from IPython import display ; A+ R- K0 j& C- D/ {. i4 [
0 J* n' P$ ~; n1 V$ j0 R' \, t2 s, K
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
7 {/ J. N0 X+ |9 Z1 F9 J! b, c& S
7 j3 j: ]+ z5 E train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') ) L/ B( T% J( B+ C
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
0 N* }0 T! U0 c7 w& w, v! d, c8 A 8 F; @9 s( o. q8 y
BUFFER_SIZE = 60000
# Z6 @% m% u+ [0 K BATCH_SIZE = 256 6 O# l' g. A7 E4 D. h
' g: Y' s" Z5 n
# 批量化和打乱数据 / a( q. j+ u( A0 i
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 1 K0 P7 t2 m' T' A+ w3 M7 ]" g
! R# g0 U D: e2 F # 创建模型--生成器
( y! B. V( }$ F: F def make_generator_model(): ! Y9 `8 Z1 m6 j) q
model = tf.keras.Sequential()
% ~* e" o6 E( C, j# T0 s model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
8 Z" m- {5 y2 ^ model.add(layers.BatchNormalization()) $ R3 o* w% V% L. S. i( p, U
model.add(layers.LeakyReLU()) [! @( E% ?* N, V$ c i' n |) l
0 }0 G: R% H0 C0 W model.add(layers.Reshape((7, 7, 256)))
! T7 O& \7 W1 n* A; F( U/ K assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
' Q1 b( G) F. {. p1 \' i0 h
5 C! l* ]/ M' E5 V. p0 n# H model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) * R/ |3 L& ^& X1 K0 i
assert model.output_shape == (None, 7, 7, 128) ! Y$ p! b8 P! _, U; \
model.add(layers.BatchNormalization())
) V; ~# K6 h; ~$ [, {7 {/ L3 b model.add(layers.LeakyReLU()) ) q8 }% r9 g# G4 q
1 x3 u$ h! ^& ^* @+ ^& a2 u model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
9 m, k) s6 i3 B$ y assert model.output_shape == (None, 14, 14, 64)
# _% v6 ]1 r c& o: j4 g model.add(layers.BatchNormalization())
! {& h6 _ Z# f/ D model.add(layers.LeakyReLU()) ) [) u. D! q) z2 ]9 s
5 b' Y5 x; C" p& Z: \
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) 0 u: k2 E% E% C+ {3 I; P$ s" V
assert model.output_shape == (None, 28, 28, 1)
: `; A8 x; U1 T1 ?, y `# k' x
( y. f Y; U+ v9 b: o% v return model
! J7 C/ n: r' B. t1 M * C5 K* {* d! g! n8 {' c# ^
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
" d0 Y9 t- u9 i generator = make_generator_model()
+ A7 n2 W" x3 Z0 P- [; U' j4 g1 [ + ?! I: _$ J2 Q7 n3 r, j
noise = tf.random.normal([1, 100])
2 B. ^2 B O' e Z' N generated_image = generator(noise, training=False)
* q% L; U5 ~( ?) x) T; [: }8 v5 V0 j( E
, p; t" W% h7 k3 e% Y' Q plt.imshow(generated_image[0, :, :, 0], cmap='gray')
) T; U5 y; n! [" R/ c6 c5 k tf.keras.utils.plot_model(generator)
, U4 {1 m5 _1 d1 w8 f* j! ?3 I# p
7 @: }% \5 d8 `+ q* X # 判别器
' _4 |& A4 i; y3 q2 P def make_discriminator_model(): 6 j* A( j. }0 Q, j5 n. q9 i
model = tf.keras.Sequential() ( c- F; y) b/ W4 I& Y- ^3 A
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
: L [% p) Y0 P% l5 J1 T input_shape=[28, 28, 1])) ( c7 Q5 O' {* X2 E- S5 b
model.add(layers.LeakyReLU())
7 E1 o0 }+ n/ u1 p1 a model.add(layers.Dropout(0.3)) 5 _# R9 u( k: R0 @# F: M
( H2 m0 T& m: x* D/ j+ ~3 E
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
5 e! E& ^9 p* m1 a" i& q* }. V model.add(layers.LeakyReLU()) ; o* f% @4 C0 C- W+ p
model.add(layers.Dropout(0.3)) 1 c" q! ?3 W: D/ F
! e' q& U7 Y/ ^' K4 U7 C& Q
model.add(layers.Flatten())
4 u+ k) C1 _, W [- K model.add(layers.Dense(1))
' s2 J$ g; N6 I, V0 i" r3 v0 O ' u7 L, R* h4 m* G* k8 h
return model / v( r( }( l1 H. l
+ r. T; B& }! b
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
/ A9 n. K) Q2 u$ F/ y, ?' K discriminator = make_discriminator_model()
& G+ B- \. K j: \& i, m decision = discriminator(generated_image) ! y, o: d4 N* ~9 o) V8 L8 X
print (decision)
' W; p3 i6 ?" p8 L, H ) H# P4 l3 A; b. v* {! I9 I
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
: W, V( N0 W( i! [7 g5 Q cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
- E+ U: Y( I) q+ A$ M8 p: M ; R4 {# o) O, k/ X& H
# 生成器的损失和优化器 : m6 ]9 A' \; J% A& H4 j
def generator_loss(fake_output): 5 h/ l3 _0 k: q6 R: a
return cross_entropy(tf.ones_like(fake_output), fake_output) % t, S7 W+ \. M9 o/ g0 t' h
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
$ x& r6 Q ~* {/ u5 V5 i3 q/ F
" K; {* K6 N0 R& Q" n. D9 |2 l # 判别器的损失和优化器 " Z9 A# n z4 ^6 f3 P/ @
def discriminator_loss(real_output, fake_output): ' R+ }6 @9 \# p! e& Z3 l3 W5 O3 L
real_loss = cross_entropy(tf.ones_like(real_output), real_output) k. c4 Z- r) L( ~+ u* ~
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
$ G" A1 U C4 ` total_loss = real_loss + fake_loss : C( Y8 f" i, q9 I6 w; M
return total_loss
- B3 @. N0 ]/ E& O; F$ ? |& o discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
& p5 Y2 p( i2 \. }+ @$ D& Q8 j |1 T; L- L' F2 G! B0 S ~" a
# 保存检查点
! F I `2 o; F0 }2 Y6 v; h" T checkpoint_dir = './training_checkpoints'
$ W' B: G8 h/ p5 A checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") " f# e( X" e. {5 }
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
: k9 v- ] a7 F$ I( U( O5 p4 _ discriminator_optimizer=discriminator_optimizer,
6 c% \5 t: o4 _. _ generator=generator, ; D0 m+ D! ^) [# p
discriminator=discriminator) ~! X! G# I* i y7 ~
6 m! X: U! w; R/ N
# 定义训练过程 * G2 ?% T. n! `
EPOCHS = 50
4 x6 e2 t* M v/ f) l! q- v5 | noise_dim = 100
, e! B3 v7 @; D# }7 y: j1 A num_examples_to_generate = 16 / n+ X2 p2 a9 u1 K1 j! {8 E
1 c0 r% R# A5 R/ [- j
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) 9 M3 Z" _7 j+ q& v3 r m+ z* T
seed = tf.random.normal([num_examples_to_generate, noise_dim])
* Y ~3 s* p3 p6 u3 j ' y# ]% b' u3 |5 U7 Y- I4 n
# 注意 `tf.function` 的使用 , V: P* O1 |9 \- v3 c
# 该注解使函数被“编译” 9 A) J2 h2 \5 E0 I1 L8 G
@tf.function
B" ?1 ]3 D4 V9 G# K* s3 H def train_step(images): ( j6 K# ]! e! J$ x$ r4 `
noise = tf.random.normal([BATCH_SIZE, noise_dim]) . `2 D% o9 k B" W; |+ l' h7 r" W
* C+ a3 @) r$ ]+ m V* h* H with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: + Q/ P3 C0 E# I" g l. T
generated_images = generator(noise, training=True)
. h4 S5 W; Z0 J4 m( `1 X! f) z * q' N( k5 \; z* M8 s' \1 c- q
real_output = discriminator(images, training=True) 3 q/ `- H( e9 b5 d7 j' i
fake_output = discriminator(generated_images, training=True)
5 b" ?6 U# ]3 {& [. P8 D % q2 X; \) y8 T" z
gen_loss = generator_loss(fake_output) 6 _, @) m7 u) x
disc_loss = discriminator_loss(real_output, fake_output) 8 s& L" D/ I: w7 J% A& ]! J
% W% Q( Y7 \! L# Y2 t7 f- v0 e' }
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) R4 H; ~+ K$ r6 b$ E: r5 P
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) C C# e7 i# f! V2 Q! w0 C. Z) {" n
2 {2 a; [7 z# H8 U generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
8 `- h8 T5 K- ]; C% J discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
# n( U/ X/ C& Z" ^ Y $ b7 V. P! b1 y* q/ ?
def train(dataset, epochs): 1 f6 Q! S% R- i' h
for epoch in range(epochs): 0 S6 d, K8 B( m h: m2 y) R3 W
start = time.time()
8 j2 i" i8 t6 Q# m
. h& k3 ~' M- Y5 ^ for image_batch in dataset:
3 y2 A" ^/ ^4 v7 Y2 f7 o train_step(image_batch)
5 q/ [7 d5 m; g 6 g9 L" x& Z7 f
# 继续进行时为 GIF 生成图像
( L+ R3 x" X2 g; N9 ?7 o# C display.clear_output(wait=True) ; z3 @4 y5 ^; J7 U
generate_and_save_images(generator, 7 |8 M' r! h; ]; T* U+ X
epoch + 1,
% W q3 n& e2 A+ _$ K! \" i seed) " J+ F7 B0 M: ?/ e5 J1 Z) w
, l2 ?, f" U0 N2 d% `' B, h # 每 15 个 epoch 保存一次模型
' K& x, ?2 b- r8 u2 D# m- ^ if (epoch + 1) % 15 == 0: ! v0 ]! u6 o* j8 Y" g
checkpoint.save(file_prefix = checkpoint_prefix)
; o, W# A9 ?( [) n2 W# e " \8 }. `2 x) n4 A/ J7 s
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) 5 h6 o+ m* b! v& G
6 Q7 |3 d4 |4 b- Z4 F4 I* v # 最后一个 epoch 结束后生成图片 ; r, j# Q1 y8 E$ r$ r& H8 [
display.clear_output(wait=True)
t5 H _+ n( B- `4 ~ generate_and_save_images(generator, 2 I* p6 ~: T' R% \, ]0 W' Y) T/ m, U
epochs, # z ~4 m* q$ K# Z1 G) b& f) z/ `
seed) 3 \* G% x* i& J% @% k2 I8 I
# `0 P" g2 G) ~$ d+ C# ^' o5 _! y9 y # 生成与保存图片 ) q9 c8 }5 W# X: |3 G
def generate_and_save_images(model, epoch, test_input): 1 ~" X2 [5 b9 ?! X% [
# 注意 training` 设定为 False 1 M1 @8 ~3 s9 M! W- t8 L1 J
# 因此,所有层都在推理模式下运行(batchnorm)。
$ p0 J2 |5 Q! d3 U predictions = model(test_input, training=False) ) G% C" d, v6 _
% k9 V. v8 ^2 E) C+ ] fig = plt.figure(figsize=(4,4)) ; w8 o+ x; n8 d4 y' A% i
* y" a# c, D( {/ O
for i in range(predictions.shape[0]): " ^0 A6 v/ \- H8 k7 v9 n5 _
plt.subplot(4, 4, i+1)
* j$ ?* y/ m' K: e; Y plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
, w! \ q/ P# f$ U: X5 @6 k plt.axis('off')
, E- G# a1 `5 |3 d) |: z
% m* h6 p- o2 U0 f% u ]! q1 y plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
8 ?0 f; o+ n% _; p2 p plt.show() 3 v$ D# `9 {0 N% G/ y
" ?, q; `" }: o( c1 | # 训练模型 7 S4 s) u0 Q1 V2 g+ A
train(train_dataset, EPOCHS) ' ?% d( A7 M7 L
" Q8 l1 }, o- [( X) f$ v # 恢复最新的检查点
: I" o; E8 ~! ^* C) R0 C checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) ; c. W9 ~% y! Y. }
9 C4 z+ A& B; z' _1 L7 F% K# @ # 评估模型 # J. g( @& V; s2 F7 p4 `7 y
# 使用 epoch 数生成单张图片
3 x2 p! Y, t0 ^+ j O6 I; f* @ def display_image(epoch_no):
- v+ x6 N/ f- f/ L; t& c6 O return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) & b9 y4 l5 ?8 G. n4 i3 w- _; g
( b6 v/ m- }9 h- F
display_image(EPOCHS) & {( \+ h! D( _ s8 t
) D% c3 ~; C! F: H& \' ]+ s
anim_file = 'dcgan.gif' 8 b4 b2 L! U1 c e+ q) O
7 |# F0 A1 \# y1 m. |! ?3 o. N
with imageio.get_writer(anim_file, mode='I') as writer: 2 z6 Z% _* I6 F. c P
filenames = glob.glob('image*.png') - w: I& D" Q; R! U* U
filenames = sorted(filenames)
, z( v1 X; @0 w/ G% v4 K last = -1
$ Y& e9 t; z B! Q+ T for i,filename in enumerate(filenames):
3 l2 w* ?7 ]4 @% g! [1 i9 d frame = 2*(i**0.5)
( Y3 n9 L6 V7 ?" E if round(frame) > round(last): : N4 k8 n5 p3 y5 R+ ~9 W T
last = frame ; C1 r8 M: ^6 S, r \+ J
else: N$ |4 A ]- r7 K, L
continue 7 E% x3 j$ H4 _% M0 ^9 j2 @& c# v
image = imageio.imread(filename)
2 u7 |. i m: z% U- w: { writer.append_data(image)
4 ]* w9 [3 P- G image = imageio.imread(filename)
1 l/ G8 n5 l J& p! |) R writer.append_data(image) ' R7 U* Q/ h, x# l
# k1 I( o E( F: d$ S5 G# }! L6 H& r, Q
import IPython + |2 u/ r% n! w. e6 F
if IPython.version_info > (6,2,0,''): # q" M: O8 G+ \* Z+ o# s6 f
display.Image(filename=anim_file) # }! j( G8 o5 _, c& Q& P0 g) L
参考:https://www.tensorflow.org/tutorials/generative/dcgan
! g8 ~. s$ ~# @1 D ————————————————
4 R" }- ~+ ?) s5 d 版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
8 q) n' t: O, K5 [6 D$ @8 y% {* a 原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
% p! |1 a: b% |; a- i; l+ c ( J+ F! W K: b, i1 `7 `1 h
7 K9 l% L( Q% c9 `' q0 [
zan