在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 563353 点 威望 12 点 阅读权限 255 积分 174229 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 3 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
0 X* h. ]% E# {* y; |
深度卷积生成对抗网络DCGAN——生成手写数字图片 3 ?" }# z5 s6 r' l& _
前言 F7 R* s' ^8 p: J9 j% \
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。 # r- }' Y+ _/ y* p
6 V% z9 |5 d4 b5 O5 h
5 h( @* p$ K& c# O, G
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
1 F9 X% G; N( S+ a) \5 K( H # p; f' X/ G0 m( L1 W% Y5 s# V* c
0 `$ Q2 L! h* y, G6 L& T # 用于生成 GIF 图片 - ]% w, s/ u4 f, B$ W. X' D2 Y, L x" ~3 l
pip install -q imageio
: x0 b! Z9 E5 `9 T; u& H3 O 目录
, D2 o: w' x8 s/ B# k' n$ ]; ^" L4 i 1 B0 y9 T' j6 y2 R6 @6 }. H% s
/ L8 Q; C. R0 c7 Q
前言 . k( O/ Z5 |7 ?% A+ S7 F
. }5 e) S4 [. h- `
2 T5 s4 h- j F6 a 一、什么是生成对抗网络?
0 x5 h `$ R% @: ^$ b9 @7 _
4 ~+ w- s& ~* J$ r1 S. |% S 1 Z0 g# q% \; B
二、加载数据集 5 q6 Q5 ` l/ S; z9 ?
, z( J% k E! M O( I: p4 n( g5 ^
三、创建模型
9 `6 G* |5 {: k8 x ! O6 t V3 j7 b8 W* K" K
* X& e* _4 G' \( w: U$ A
3.1 生成器
8 ?+ N! R8 P1 b4 k" S7 w$ A " Y8 t7 Q) d+ s3 l
% U% I6 e' x% x0 \" X% v 3.1 判别器 2 v9 ]/ ~) v6 j) H; x6 {1 G
} U9 m1 h9 ]' N
* c, B& R8 D/ q 四、定义损失函数和优化器
5 K7 X3 {. D5 k0 t " L6 t4 s) @! H8 V! \# e
3 v, c; O0 m" X0 p! j4 y3 y9 E2 @ 4.1 生成器的损失和优化器 / [+ ] B* e+ s: x0 Y$ h& u+ \
, u( @: K7 j: N' q- n' o
6 P% J7 ?5 A# z: |$ Z2 S0 H1 x$ s 4.2 判别器的损失和优化器
6 H% W4 K# K' `/ d- {1 H
* H4 s8 }- a+ d
- F/ l' Z& H& T4 j/ _) E w( B7 @ 五、训练模型 ; i7 M0 _( f0 V p3 o
; @; r3 B# @5 Q. w
: Q9 ^7 _2 f; U+ N* ~
5.1 保存检查点
Q# J' H% i, }7 f; Y( I ; `% r1 I8 a3 I1 K5 A, K
1 m. z. d& q2 M0 i# b
5.2 定义训练过程 5 N2 I% S: \) r7 Y8 L
9 e9 l9 _! [, g$ x# y7 Z( X
8 j& M# Y8 n+ O) Q4 H8 f% P+ O% Q
5.3 训练模型
; j3 [) ~8 i ~ ( j+ w# ]1 H, @
. P) j! w- \ v; @3 z% X1 d* n
六、评估模型
' M9 e) R& v, q
0 M: c5 u6 ?; P% F( N9 R# o ( ]# I/ v9 ~+ Q7 ]
一、什么是生成对抗网络?
8 d% K8 h- A0 O3 Q7 `. d 生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
# ]: R+ ]1 X2 e" f % g6 t2 G% p2 q; I6 `2 W
% B7 b* ^4 J5 }0 ]8 u( s
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。 9 A2 G: {% B# ?0 v, @
2 f; {) h7 B" l7 Y( w+ s , T9 X+ C( v0 K% s1 z5 t
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。 Q2 }" A; K7 T
6 M+ M3 @3 n: \$ a0 @1 x* d9 d) q) d
7 [7 b4 b0 ]* Q/ D 训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。 ( o. t/ C a" A5 j) h
" m. u& ~8 S" o7 x' C$ i2 P6 Z7 G ! B+ S% W7 Z3 V, h8 a. n( s4 J
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
; d7 Q# v* l/ y+ u 7 ^/ H2 Z) T8 B! ]8 i
% M3 [ u8 h% p 本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
x* x! u+ k" T8 V# `+ y7 Y
5 t( C O0 u E: C$ s2 D
1 v' l, R/ I" [1 `- F1 h 二、加载数据集 , f% g$ M' o+ t# T
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。 " p0 r9 M S5 ?* H3 v, G
3 S! o7 V4 g/ b: M1 ~8 Z: C- a
, J; \/ [. D) \7 I! i (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
n# I- o. V" i2 V
* O" Q Z* D) I, I( `' V8 ^0 m train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
6 E4 K, V$ n. c+ f ]- n7 } train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
A# ^+ i6 O* y% D9 r9 X ( r3 X+ ], r5 X* k# a7 ?7 R. T. s
BUFFER_SIZE = 60000
; T0 \; B$ B- ?# |6 v k% x2 ], w. c8 D BATCH_SIZE = 256 3 T: w$ A) G0 g, m" W9 p
6 c( b- e* I! q4 k1 Z k9 s- m
# 批量化和打乱数据
/ P4 e' {) F1 Z# r7 M4 \& J! N. ] train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) ) @# k8 \$ n! l( [0 j
三、创建模型 . C+ T4 W8 R; X; }% w; Y
主要创建两个模型,一个是生成器,另一个是判别器。
' a" L/ w$ ]' [9 ^: T6 o$ I4 n
/ O' u t0 b' y, _: U 1 V4 e) P$ Y5 W" Z) ?7 h
3.1 生成器
, n' ^- e$ m; V 生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。 - D2 r1 X( \; Y
_/ y. M2 y. a
_& \3 h* ~1 W 然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
5 T' S% {* r6 } - G6 g4 S1 D/ U) ?" l) L
2 G4 x6 @3 L7 \' w1 s% J
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
s5 x# O: u6 v( @4 V! V# B , [" S$ c. J7 Y T0 l- u
/ s2 S- C! i- E* g/ o
def make_generator_model(): 2 Y& P" v& w- |! y0 \0 Z
model = tf.keras.Sequential()
! u$ Y1 ^' g" ~ model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) " A. d: E' R4 {$ N
model.add(layers.BatchNormalization()) " I; a, n* S# [7 z9 @
model.add(layers.LeakyReLU())
; E3 M( Y _: S% x. U ) j% H2 D8 K0 T7 H' S" k
model.add(layers.Reshape((7, 7, 256))) $ z4 O" X! p; i' B7 I2 X! @" X9 Q
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
& B7 f9 |. k( K) @2 d ) k6 [! f% o) i+ b# I9 r. t6 Q i
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) - Y' Z6 @: Y. a2 w9 {' r1 m
assert model.output_shape == (None, 7, 7, 128)
; z( j) @! ^# m. Z5 P1 _1 O model.add(layers.BatchNormalization())
; n+ D B( @! y# s$ V, }1 |& m$ [ model.add(layers.LeakyReLU()) " v3 F, C" G, F: W9 Z
. J; q a+ o: i& p ^
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) . j2 h, }4 N- e, g
assert model.output_shape == (None, 14, 14, 64) . C( ~; J" ?6 z: y
model.add(layers.BatchNormalization()) # J& @5 R: D7 t
model.add(layers.LeakyReLU())
3 \+ j5 ]$ c! _5 V 3 A4 Y6 S2 x% W# Q# ~+ b8 ?
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) / g, K/ z# L8 y" W3 V
assert model.output_shape == (None, 28, 28, 1) 3 e6 t9 K/ G4 L: A$ b+ q" b
" _1 p5 V; U& V& U/ K# a, b" c. ] return model
' j1 r& Q2 b. ^$ Q5 x, H" q' q; q 用tf.keras.utils.plot_model( ),看一下模型结构 " C0 ^4 Q; J% B, Z" g
0 L( q4 J5 W4 I' A# W2 g$ [ & D' d) b1 W, S: z
& O( B8 j- n) d' g' v! A- y
. H& S; @1 s9 G* U0 K $ V9 d; ^# q* z+ H' s
用summary(),看一下模型结构和参数 / v. \! o$ l9 _. z( Y: M* S2 N
& u+ b0 G3 j r. n w7 t
: x! _* p/ ^9 ]# t' {/ O& [0 l
8 O* _8 a$ [0 Y$ a2 f! |! s
# K0 A2 ]2 D" H5 I7 K; ]1 \+ O
' m4 O3 t3 Q9 ^2 j% B5 q6 b
% W- \& G, o- b+ m* T 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
' v' D' z Y6 a) \6 V S5 S+ H S$ W- h8 h7 X3 ]9 B
4 J2 h+ J" c5 h- u4 k
generator = make_generator_model()
- _% x/ d6 D' F$ q % H& u: H6 d. g& V; S
noise = tf.random.normal([1, 100])
2 l0 Y3 p; X+ V3 g8 B& g. y6 q generated_image = generator(noise, training=False) # C2 K; G! F8 l$ W4 Y: x% Y
6 Q; M& i& D. S' I
plt.imshow(generated_image[0, :, :, 0], cmap='gray') ) a A' g, d( j6 m3 o' u
% q6 g0 z# a; u1 o# L# a6 I
& a6 @% J H( i: n0 Z5 o0 n' v8 e8 Q
% k0 l S. \' | ( C3 t. Q& {% z; }2 E
3.1 判别器 - W* Z$ g0 e% n5 ]+ F7 `2 d
判别器是基于 CNN卷积神经网络 的图片分类器。
" L$ I/ u8 M2 j
. h8 B4 d& v F6 I2 c8 {$ ] ?9 ? & {5 J6 A4 T5 T. v
def make_discriminator_model():
$ L3 v u- o6 [+ D& p0 } model = tf.keras.Sequential() * }# X% x5 R9 W1 g' ?
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', ) f$ R$ D. o9 M5 g/ V
input_shape=[28, 28, 1]))
) ?5 ~4 r: ]: _" q* f model.add(layers.LeakyReLU())
, g8 M I( O7 p8 l model.add(layers.Dropout(0.3))
: p8 P2 @" ^- O % Z2 G4 m9 {) |# `: s7 l
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
5 [' r$ F% w6 U. P! I" | model.add(layers.LeakyReLU())
# v( L8 n' Z; g1 A5 w+ ~ model.add(layers.Dropout(0.3)) @$ I8 Z/ Y7 @
: _" h# B D2 U2 j' n
model.add(layers.Flatten())
7 Y% K5 M) z9 x. @$ \! ~+ { model.add(layers.Dense(1))
, t- o* X2 k" J: Q% w% [7 ~
, c" ^* R- h9 d4 Z" u n* U3 i return model
9 a: j! ^" E( j1 S, ]2 A" b$ e( n9 h$ _ 用tf.keras.utils.plot_model( ),看一下模型结构
J) F5 p% P9 B: S ! Z7 g) M& e9 I0 Q
* ], f3 a1 Q4 r, H
! l7 N8 @; x+ b% H+ l& K: O7 C- n
a, a8 H8 @9 |8 a- f
# Q; ~% N5 M( q& r9 l' H
# U3 Q g ?3 C' R, z. a 用summary(),看一下模型结构和参数 ( }+ R% |8 H# X2 h
0 G. J* K* H; z
7 ?5 ? m3 l$ A 2 _8 U8 m& @/ i3 B1 w& X
& x- g- i0 u" ~$ k' c& M3 q4 n) L9 x2 I' v
1 n, k: n6 W z$ A$ M
/ D! j% a( f$ c! O# m" u; D& G( F 四、定义损失函数和优化器
, P, Q% W3 d2 f9 g& C0 D" \ 由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。 ( \. m; u! z/ E' }& E' I# s& C4 p, A
7 a" D! p: B4 ?1 s8 G. \
7 I0 z, S! Q$ O 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
" {* G: V' B: x* ^. _ & f, x D9 X2 ~# c
% I6 s2 \: [4 _6 d! n9 ] # 该方法返回计算交叉熵损失的辅助函数 , s' r+ J& R4 s% L, O+ [
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) , q# S* c! V) @) E" V) i
4.1 生成器的损失和优化器 : T. q) j1 A: l
1)生成器损失 ; X/ K4 e O1 g1 `3 J" [- n: Y
& m; ~# D# \/ ^0 ^$ N
, V$ |7 M8 H2 p 生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。 ( `7 H: `' ?6 \$ p2 r2 J% P
! M# H( A {! E# m% W' R7 H
# Y8 f5 k; L5 j6 w; y# h% @" n 这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
o9 {6 f& ]8 I. \ ' k7 v: M4 @# I1 g% s" Z
6 x$ C6 K6 o; H* I# X3 i8 i) j
def generator_loss(fake_output): - o% B6 |; R! d! X" ?6 w3 l
return cross_entropy(tf.ones_like(fake_output), fake_output) 4 |1 t! U4 K/ `5 \- }( K
2)生成器优化器 - K1 y6 T$ r4 K* e1 e5 `) ?
) R' c+ e# @9 t
5 y( B! N0 ]- n
generator_optimizer = tf.keras.optimizers.Adam(1e-4) ; M; n! \( H L7 c0 ]; D+ @# a
4.2 判别器的损失和优化器 / _; ?/ Z1 I$ M3 s" Y: |+ ~5 K
1)判别器损失
$ i: @2 r# {: N5 }: r6 O5 n# k9 t% O5 e % F" ^# r* @5 Q" Y+ [. r; e
4 D! B2 C* d4 {! g; }3 L4 \& v, [9 d 判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。 # D3 n. ^: x9 Z) b4 a# k* Y* w
7 E, c/ l3 c9 l$ C4 G
& @" L' D9 a/ v# v) g$ l( {# O) h def discriminator_loss(real_output, fake_output):
: Y, K9 g/ ?$ x% `# F real_loss = cross_entropy(tf.ones_like(real_output), real_output)
B! B; ` J; K* N$ R fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
' W- {$ }- k2 F4 f% d. {$ t8 f total_loss = real_loss + fake_loss t. K2 f* Q1 j. L! A- h7 K
return total_loss & W3 C( {& g. T0 S3 ?4 f
2)判别器优化器 4 F* C1 _4 I& d3 a! n! b6 B: x
) i( Q4 M( L h5 l* k
. a L# e8 O7 j0 X3 p discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# x: N; l# Y+ Q1 K$ _ 五、训练模型
3 C! S& ~. f( P- Y9 s% T. p 5.1 保存检查点
! I6 C! r# L/ G# g4 P# A* B" l2 e 保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
$ Q: v" S8 p3 N
+ U" P: g1 b8 o8 n4 x/ O! X
0 z# k! Q# L/ T+ ]; W checkpoint_dir = './training_checkpoints'
: w4 @+ _" f- O0 ] checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
- r& [# P0 q' R' V4 s checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, 3 L! p, e) O6 F, \, m" J
discriminator_optimizer=discriminator_optimizer,
& ]% X2 O) Z1 S% | u: p( ~: d generator=generator, * d0 F8 s: }! ]8 B' I. r
discriminator=discriminator) , B) D5 d8 P/ _# I) c
5.2 定义训练过程
* V6 j( F$ n& R- s% N EPOCHS = 50
1 M6 t! B$ ?; V1 l noise_dim = 100 ! o1 b4 k. s9 h Z0 e% b/ _9 O
num_examples_to_generate = 16 , g, D+ n" n! V1 c1 S, e8 J
" e/ Y- K9 c5 }0 D ) Z q: Y" Q+ f0 @, v/ `! M
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) 9 ~4 @0 |5 z1 d2 m+ d# @, p
seed = tf.random.normal([num_examples_to_generate, noise_dim]) ?( k$ _; y1 P s* N
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。 : C6 J. `3 A1 \9 K# z/ l; R( x
, T9 d5 \5 C: V0 V" b- c
3 K+ N! [, e x1 ]) I2 m 判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。 & W3 t9 A n; B; Q3 g
4 l% `" P5 T* r! V- S
1 r7 B8 Z; p& e- l 两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。 , i) G* [9 y6 N' e3 r
2 v: Y( \$ f8 }, I7 O4 e9 w
* T3 K, t/ N9 I% s8 j( E& B" N+ l9 G7 E # 注意 `tf.function` 的使用
% k" ^! _- k0 | # 该注解使函数被“编译”
9 b9 e% g) V$ E' z @tf.function ( v8 V G2 c) s: `% @- ?
def train_step(images): / P4 ~; W- M9 M( h* P
noise = tf.random.normal([BATCH_SIZE, noise_dim]) 2 ]1 ` |2 S, T9 G/ s
( N; c& W+ f7 |
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
/ l& E, I" u% \0 l generated_images = generator(noise, training=True)
6 ~+ k4 U2 P. }6 n: z O ! f2 h; b$ p- o6 a
real_output = discriminator(images, training=True) : `! H* H8 }+ i+ W( a" `: T9 _
fake_output = discriminator(generated_images, training=True) 7 ~, J' ~* ^; Z2 }" i
- l* E ` Z5 X, T% t8 a2 w gen_loss = generator_loss(fake_output)
- l' W% p3 V5 S8 ^4 A disc_loss = discriminator_loss(real_output, fake_output)
; n" x8 ?" k3 w6 q0 s& [7 S6 Z# Z9 Z
, H! ^) m# m+ C& P gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 3 m' J! }+ o4 P: e, Q3 S/ y: R
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) 0 S) `6 O0 h: S* s6 O
+ Q0 |+ {4 {, y6 ^! ?- U
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) ) _1 ?( C- T( u! m& h
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
1 i% {" u3 Q" X* a/ W # z3 V0 Y% b- W t
def train(dataset, epochs): 8 W: l# o! V0 L/ Z0 ~
for epoch in range(epochs):
* ~) N1 X( L: {9 B$ T9 E) w start = time.time()
5 A* j# _2 @/ [* R
, p) m* m, i C6 e* d9 w+ C+ S for image_batch in dataset: 9 Y2 C _' X- p$ X4 D
train_step(image_batch) 7 L# O. e9 p* G" B
+ N) e( _3 `/ o* l7 T* G. q # 继续进行时为 GIF 生成图像
# u+ i# f4 q: p8 E% H8 ]1 ?1 Q/ t display.clear_output(wait=True) 1 k' d7 [) k3 p7 _( @% K
generate_and_save_images(generator, 2 U* P# ^2 Q, C# z" L, ~
epoch + 1,
" ^% {/ P+ I0 J5 i seed)
" o& y6 @7 v8 ^! B: W ' A: R& \& M4 V3 P6 c
# 每 15 个 epoch 保存一次模型 2 {; _+ ?. r: O8 k, J/ B
if (epoch + 1) % 15 == 0: ! ^' }6 T" e$ L% G
checkpoint.save(file_prefix = checkpoint_prefix)
; ]' K, M7 ?# j! u& G( W3 n
) m% h% y" j2 m) `" q print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) + k+ u/ h# {" P: f+ B) S$ ?
/ [! H. ~- `6 i' F5 A* x4 r
# 最后一个 epoch 结束后生成图片 $ U4 T6 B* A0 T: o) X
display.clear_output(wait=True)
* l3 X8 x4 V/ a generate_and_save_images(generator, % M% u! t7 ?0 s7 ?
epochs,
. G. I) c( Y+ {8 h7 {! S% i, ` seed) ( }! ~; R2 n$ Y
, H# |1 y( p2 ~
# 生成与保存图片 ) t. p( C; l2 Z: c( b* d
def generate_and_save_images(model, epoch, test_input):
# }. [/ E9 x* W& O& N& v. S # 注意 training` 设定为 False
8 s6 ?+ j: ?" N \0 o& V2 u # 因此,所有层都在推理模式下运行(batchnorm)。
% c0 v% x" P4 b c. ~) j0 K predictions = model(test_input, training=False)
8 ^0 B; }5 Y5 Q. `% h
% t$ k, A4 Q- F% a5 P# K H fig = plt.figure(figsize=(4,4))
1 C. i& k" U4 B ; M$ i8 W, |$ b9 @% ]* a- k
for i in range(predictions.shape[0]): + ^7 L" b1 M' C
plt.subplot(4, 4, i+1)
v) C. U5 d: r9 n' Z# Z% L plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') : u' B) K8 r3 n$ z$ N* I" T
plt.axis('off')
: b- P$ e# [6 Q- r4 W* e
; G2 B' ?* |$ M0 n) B7 V& G' b plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
0 c# w Z, L- B1 }' O: T k plt.show()
3 A. H) \0 M. Q/ _ 5.3 训练模型
& H- |8 J; ^) M7 o( K 调用上面定义的train()函数,来同时训练生成器和判别器。
+ R" M) m/ J _* z! f& b5 n; ^
( f- q5 D! @; e* k) Z/ z- c % C, d" |5 J. D- H" }
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。 " }# w) D& `3 R' M ?, O ^; ^
: e. B/ b% ^' @ k6 w. U " o, `6 K9 T& |$ J9 c
%%time 7 ?& ?$ Q: O8 y, v
train(train_dataset, EPOCHS)
. X7 z! t% x3 Q$ f" u4 } 在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
3 D) r0 T4 Y; B 7 ^9 I6 u1 L2 z( J% k6 r8 L
5 W- Q. Q' ~; x6 d, c3 y! m1 R4 z
训练了15轮的效果: 7 G! u. d4 r; h' n' L+ i4 g1 ^
5 z, X. B! E7 [
, w5 m: a% y2 ^0 X8 `* K# q
! D! R$ u6 m% H2 U
* G7 i" @1 a) t% l( H % D0 Q9 P6 G6 n1 `4 G! V0 Y
1 K1 z* h7 c8 W$ |( J9 J
训练了30轮的效果:
9 K( m& f1 m4 e" k$ W 1 Z, b' C- T/ D: c8 r: C$ ]
5 J# n+ q% C7 |
& ? u n; t: P" T! o/ ]: |
c$ y/ ]) W9 L) w& F
: w) o. k' z1 [+ p2 I
& Y, R F2 n& P" ~6 X 训练过程:
: E( _$ \5 m7 V7 T9 R! @( h" K
9 o1 X) }- g e 8 a* B/ S+ `. W7 W% R9 |
9 d+ _! z# P# X! O7 G \& o6 P1 i J+ W3 ^, F4 K
( w, Q4 b# R# P: D 5 s5 |) l7 u# \1 w h8 Q: ~
恢复最新的检查点
6 a- @0 [9 c" x) S7 d8 W) g
2 n; f$ h \' d$ C; I- Z4 H/ B8 V & `! I7 L; p) E5 |! h7 d. s* h1 f
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
8 S0 R6 A# [5 R: x 六、评估模型
5 Q! K5 k) V+ ? \+ ^& ?5 ~8 U 这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。 : K. f8 j9 J/ ~5 y
% U5 w& K2 p; o- C
5 s# b4 R' E* Z: }: T5 c: ~* D! Z
# 使用 epoch 数生成单张图片
2 i' u- o% V" [ def display_image(epoch_no):
$ v/ s/ ^" s" L9 r1 Q, a$ w return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) 1 o3 z, I, v; Z9 m! y4 r
# X/ H% X4 y! ` V
display_image(EPOCHS)
; A7 P( n2 o" H+ c0 }9 { anim_file = 'dcgan.gif' " e2 q# o) @: t9 R4 T3 T
. j; c6 V3 c5 o- g" Z with imageio.get_writer(anim_file, mode='I') as writer: ( T% Q' r- U+ o' t
filenames = glob.glob('image*.png') ; P: e% i' I' Z3 l+ s9 I
filenames = sorted(filenames) 3 P. ^7 n4 z. H1 ]) d) B
last = -1
2 U4 s* _8 M( F. @/ q6 e, @1 E for i,filename in enumerate(filenames): ; q, v' H9 q4 ~# [/ o7 n
frame = 2*(i**0.5)
: R4 k0 X; P- `6 v0 Z if round(frame) > round(last): 4 j4 C3 D$ y- ^9 n0 k
last = frame
& F, F7 S) M% M7 j4 V else:
% a, C! [6 R2 m! g9 E" G5 L3 Q continue 7 g; Z" e0 a- }+ d9 _4 Z C2 K' i$ g
image = imageio.imread(filename)
7 o# g: H% B/ d3 J( \7 m1 b4 M: s9 b writer.append_data(image) ; D3 }9 f5 Q8 h6 L; s, |3 x
image = imageio.imread(filename)
. O! }- v y! L( k( \& Z7 Z writer.append_data(image) ' E8 R, H% e" L5 \0 G& l1 V
, X( E6 F8 ^/ c0 q4 }+ y3 \0 e import IPython
; A) G- v1 D' m) Y if IPython.version_info > (6,2,0,''):
9 h% b( U/ X; x, N8 p, J display.Image(filename=anim_file)
) b: J5 {: y2 |/ Z7 [4 Q; X 6 b- T: c0 Y- V; q
3 u3 I. X, H5 B9 P: i3 _
9 I Y1 ^1 F+ p3 u e. V : \, ~% h' u/ P X
完整代码:
7 A5 ]$ @& W( a( r, L $ j. ?+ x/ a8 ]$ e: W+ t/ @
1 G; x% ?+ n8 s* n3 t9 ]7 H import tensorflow as tf ; [& c$ s' `* ?1 C+ w
import glob 7 G1 }1 A/ [, Y
import imageio
8 ~; c; K( ^9 i6 w ?: Q5 Q import matplotlib.pyplot as plt
$ `9 a N' k6 j8 T import numpy as np
0 v+ t+ r' y7 w8 y. M import os 6 S: m2 P7 v" z7 F4 p
import PIL
& p1 k1 v* T$ |" F5 Y from tensorflow.keras import layers
% H( V( U$ _; I9 J. [" K import time
; f, {" `8 o' }- o+ _; x1 M! J/ f" e
1 e/ F, x% R( X+ |/ b- E from IPython import display
0 }% ?# A* [5 m1 }# X, ^' l
0 D( I1 L& E& O9 ^ (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() : l7 Q3 V* ^( g4 n7 x0 z: k8 F
8 g6 q; a4 O. Z( n3 |0 `2 o# q
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') - C* s/ \- u- u3 ]
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
; M/ q' k) s! R7 x7 I " u1 E$ y h2 J9 P3 f+ r2 K1 P
BUFFER_SIZE = 60000 ( C2 q a; U, _2 }8 n
BATCH_SIZE = 256 ) b) M5 L; q I! u7 l. \
( G* K% V& b+ i% _+ i
# 批量化和打乱数据
% w* l0 H9 A) P; g- j( l6 a train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) ! f0 Z0 R' r$ E9 h
. ]$ m# }' r4 W) v2 }' P5 A; }
# 创建模型--生成器
* @9 {$ ? c1 A+ z, r def make_generator_model(): - M" o* g$ |6 k( u2 E. I
model = tf.keras.Sequential()
1 Q8 ^. D0 y! {8 I model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
. M# g( m, I4 T: ~ model.add(layers.BatchNormalization())
( E! `! o& O4 M0 ? model.add(layers.LeakyReLU()) ; r. \8 r, Q0 f, Q
# `& f! ]/ n1 T* z3 q model.add(layers.Reshape((7, 7, 256))) 0 x( m( Y1 E1 r( N+ w% j8 Y9 ?
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
l8 G8 v: ^) t9 u) H + W" O# K, v m3 x: j
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) . O1 e- j& [, U* ^1 F
assert model.output_shape == (None, 7, 7, 128) , n- _$ q* ?5 g( \
model.add(layers.BatchNormalization()) 8 M- d# A6 W! V( p, W$ c z: K
model.add(layers.LeakyReLU()) * ~6 z* G6 L* O3 V! t5 c# j
3 F* o6 N0 f9 M6 t" |$ b6 q
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) 9 z- z1 E- \) {5 ] u2 ]* f6 T
assert model.output_shape == (None, 14, 14, 64) . \4 E0 f. `2 o6 u) ~8 R
model.add(layers.BatchNormalization()) $ t! x$ ]2 N! _
model.add(layers.LeakyReLU()) + }! |9 q# q% g5 I, c3 W
4 H+ s% ?9 V9 S% s, R2 x
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
- J. C; i5 J) L' E7 C. M8 ?' f assert model.output_shape == (None, 28, 28, 1)
! `8 H+ O1 A R- `3 F 6 S7 E+ X$ {/ D! ]6 y5 L
return model
: | d. M, P5 n( l* ^; n5 q
" \" ~; ~& G5 @( Z( N # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
7 K- L; x: Q7 C' y( s$ V y2 m: L generator = make_generator_model() # p% z3 r1 j/ B
% q: n3 p$ H, o( a) h1 @! Y
noise = tf.random.normal([1, 100])
; C5 V$ _" N- I ~ generated_image = generator(noise, training=False) * \ k1 z) ^8 V* a2 ?( o& `1 X U
9 _3 A H* g3 W% Z9 X plt.imshow(generated_image[0, :, :, 0], cmap='gray')
3 l/ a$ F0 s9 C' M+ E/ O tf.keras.utils.plot_model(generator) 6 F% j! M& I0 X) s
7 l# J( {; e2 }8 F* i# j: j3 z+ l
# 判别器
" q6 s, G' ?( K p4 y def make_discriminator_model():
( z; Q8 z1 f9 z& @* f8 [* m model = tf.keras.Sequential() 1 L; j6 D2 ^/ k/ u# i4 n2 O4 |
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', ( S$ o5 Y8 H7 ]9 e* k- g: u: ~
input_shape=[28, 28, 1]))
! z" ~) Z$ ? p( V model.add(layers.LeakyReLU()) 0 V- x( c6 o6 s" ~
model.add(layers.Dropout(0.3))
8 E) R# {$ A4 W ' s# P% X3 ]5 q6 E& Z
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) : @! P0 V! X* }& |$ q/ {
model.add(layers.LeakyReLU())
9 b( X" b3 z+ r2 e3 c# J1 K model.add(layers.Dropout(0.3)) 8 C/ ?$ D% X) y" }) U
# E- C# y+ Z5 h+ n, m model.add(layers.Flatten())
1 ^8 _$ c. T2 g3 y( ^( s model.add(layers.Dense(1)) 0 w0 j, q% @) C6 _6 @
! Y8 H$ t; f; S/ S$ ^ R2 m return model 3 ^8 U9 O' j6 R$ F
3 S$ D+ K# b0 S; r" p # 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
1 w& `5 G6 x; ^" \6 N2 k' f discriminator = make_discriminator_model()
3 M: k H* t( \! Y& z2 P3 \1 Q decision = discriminator(generated_image) / B3 u6 ] }8 O. L* a0 I
print (decision) ) A4 c: Q$ x A0 ^# o9 ?
7 l$ Z1 L, u" D+ ^$ A# j# e2 k
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 + W# P7 C" l% B: m0 J
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) ! l [% p8 r* E1 }6 Y6 O: _5 U7 M8 M
) G2 \$ l7 N+ z( t5 b6 U # 生成器的损失和优化器
+ T$ u$ K7 l( v0 K; h. S def generator_loss(fake_output): ) X$ j& F7 A( x) A
return cross_entropy(tf.ones_like(fake_output), fake_output)
. v( x4 i7 ]) c generator_optimizer = tf.keras.optimizers.Adam(1e-4)
" n. Y0 d3 R4 \& v2 o
5 E' N" j! n6 T ^0 q0 N # 判别器的损失和优化器 , ^. k0 m, \7 E/ X `
def discriminator_loss(real_output, fake_output): , q8 w P/ z/ h. b3 T/ k8 E
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
- ^2 ^& M6 W& t: Q0 Y$ E1 U# j fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) 2 k$ n; \% b( @8 h
total_loss = real_loss + fake_loss
9 ?, |' y" t6 k- y4 ^ return total_loss ; p- i/ D% p: r6 G7 L Q
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) ) ?2 N$ T1 g) K9 s% n, H
* v( }" [! S$ T) w" S
# 保存检查点 0 A( K5 Y0 u" B( k: j+ y
checkpoint_dir = './training_checkpoints' + R" S9 M; g% p- L O: R
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
- _- k. c+ y/ M checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
5 m( d3 ]3 h; u discriminator_optimizer=discriminator_optimizer,
) ?) J) a& g% C+ I& ~7 P generator=generator,
# q+ M; y# F( \* w! C" W discriminator=discriminator) ( V: n9 P- P' o3 h4 K/ E
9 v$ n6 `! \+ ~; \' @7 o+ l) @) I0 a6 r
# 定义训练过程
* \5 A7 v7 O* ?/ ~7 |% I EPOCHS = 50
7 _) E% g/ m# ^8 n6 \6 P noise_dim = 100
& C, S% K }; _3 L num_examples_to_generate = 16 2 Q2 D6 ^- l% A% `
8 v# p8 l$ c1 S6 g8 ~6 G: D& C5 h& ]
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) 8 W; Z# s2 x5 }4 A$ i! o
seed = tf.random.normal([num_examples_to_generate, noise_dim])
! i# ^& q1 G1 W9 I1 n( ^! B3 i3 B
# {% A6 e/ t! H0 l$ b # 注意 `tf.function` 的使用 9 _9 e' U, n( s- m9 Z
# 该注解使函数被“编译”
2 b9 {0 Q! }6 J; C( B @tf.function # D6 E" f: h$ r. w' D" b8 v$ V6 E- c
def train_step(images): ( `6 u* v$ |* M M, ^3 w4 b
noise = tf.random.normal([BATCH_SIZE, noise_dim])
0 D: P) w s9 T% }
/ A5 P0 c* ?: Z9 @: t& \8 Y& u' k with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
" n) N+ H* D4 e0 y+ s" p generated_images = generator(noise, training=True)
; y j) g2 u1 p- K/ y* S
) Q7 A" H% Y- Y# ]' l. b/ @. d! p real_output = discriminator(images, training=True) & }! j& F; |7 ?
fake_output = discriminator(generated_images, training=True)
6 o4 n& P) r' l2 v& w$ b * u- z9 H6 B7 w2 X& G1 I
gen_loss = generator_loss(fake_output) # x- u% ~ e7 l# s& d) ^
disc_loss = discriminator_loss(real_output, fake_output)
0 Y: l5 H) R5 r1 O: W8 k
|) n! J( G- j4 j0 c5 j4 l gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) ! B1 T: e+ w, I. f
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) + D# ]0 q$ B, Z
. V1 @1 }( I8 q# E4 n generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
, \2 t& d4 A% ] discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) . f8 m8 R6 g) i8 Y9 q
$ A4 N6 M J8 ]% B
def train(dataset, epochs):
, f! ~$ V4 Y2 O for epoch in range(epochs):
6 B+ a) j) S% H) w* I0 m6 c5 ^ start = time.time() % n7 }8 _' h) @7 n9 k2 S" b
) A8 T& b4 `0 f' }, m3 y- R for image_batch in dataset:
; T& R7 ]9 V1 v& F6 m1 V" \ train_step(image_batch)
! y) B. m Y6 n
. | t* ^5 ?3 B # 继续进行时为 GIF 生成图像 ' ~8 Q3 |7 L+ V( j7 u5 K" C
display.clear_output(wait=True) 9 Y( b; N7 l8 @4 ^+ S8 o* s
generate_and_save_images(generator,
1 z* v- d E5 C+ j8 h# x N/ ? epoch + 1,
3 b+ G% m3 A5 Q% J seed)
9 Y# @: D C; u) J( t & i+ k$ [- x3 ~5 x
# 每 15 个 epoch 保存一次模型 $ \6 x5 Z1 K4 `- g. I' g
if (epoch + 1) % 15 == 0: 7 k8 A, e5 A" t8 J( N/ T
checkpoint.save(file_prefix = checkpoint_prefix)
6 Y7 O2 H/ { e) B" \- f
6 o s8 U! G9 c8 j Q, s4 k print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) ( \4 s7 x: m) h" E
; T6 X2 b3 ?) |8 A& H # 最后一个 epoch 结束后生成图片 2 _" A P2 L8 h3 L! @; T+ R" F5 n) X
display.clear_output(wait=True)
" f+ a0 i$ F W2 z1 A& { generate_and_save_images(generator,
; D: o8 B- u. e% _' w6 O) {. V- s epochs, ( `4 N5 R/ R* \6 ]7 u3 S5 v2 |, n
seed)
) [; K. X( E+ h. A1 u
6 L" F v% R8 H* U1 U # 生成与保存图片
1 `2 S+ c0 G$ P9 G5 [ def generate_and_save_images(model, epoch, test_input): - n) L4 L# l$ B. D0 R
# 注意 training` 设定为 False
4 S8 G) z3 P" t$ T3 f # 因此,所有层都在推理模式下运行(batchnorm)。
( |/ E6 k3 |0 a predictions = model(test_input, training=False)
. Y0 \# W8 J& `, m& r* `
3 d4 v9 ]* w* U8 o8 `( q3 B1 R fig = plt.figure(figsize=(4,4)) 3 X3 w3 \$ F5 R) p) q; U
* g& T- x" c! F1 p for i in range(predictions.shape[0]): + C3 ?0 ?- |' ?* B/ W# q+ O, Z
plt.subplot(4, 4, i+1) % f1 e8 g9 _* [7 W
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') 8 @# T5 s8 h. N" e
plt.axis('off') . f" w" l. i S4 ~" F% t4 Z1 P; M& J% d
. c7 g" H: O: R9 `% q% R
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
G9 k' H: m( M4 {8 M, M& f3 ^ plt.show() ' @ j" s4 y/ U
, T2 i9 [% _1 @2 S2 P4 ^8 }
# 训练模型 $ ]6 F8 {1 ~( ^, k- w
train(train_dataset, EPOCHS)
# K4 i8 @1 v8 V. K
& B1 E! W) E2 D! \$ j # 恢复最新的检查点 9 D3 h" x0 y! q- z- ]6 h& F
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) " R! b* C3 i! G y: |: w8 u
|. Q6 q7 m' _ Y # 评估模型
9 X+ p( n/ W3 j' J* M, O" {; J # 使用 epoch 数生成单张图片
; ?) r( n) y+ W, N A def display_image(epoch_no):
3 x4 \ h2 b' V O+ d4 c1 Z w. i! z6 h return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
7 u4 n" B4 a5 r6 m 0 c9 o/ b6 A( x9 E9 Q- M+ E
display_image(EPOCHS)
; e0 ]: I8 ?! o9 l8 Y
' b$ X! ?$ {2 r2 l anim_file = 'dcgan.gif' : X, {0 ` s. W- Y
' w" o# t9 @) n( C& x
with imageio.get_writer(anim_file, mode='I') as writer: 6 N7 e& E; k K. A8 c) [: N
filenames = glob.glob('image*.png')
1 U( X. a' M* G. N filenames = sorted(filenames) 4 u: D% o) p* R1 j5 O( N" g
last = -1
- d8 {8 b3 s W9 Y w( |) z for i,filename in enumerate(filenames): 2 G' l& h- Q0 u5 a: |
frame = 2*(i**0.5) % Q( I E- P, g% f4 _0 r' e( Y
if round(frame) > round(last): B( O. y- p/ M4 v+ E
last = frame 7 M! g+ A7 h4 y! G' p
else: 6 v' ^1 t; H* C7 S7 S
continue $ A6 F9 U5 M( S7 f8 }$ Z; ?) ^
image = imageio.imread(filename) : M7 r7 W$ H9 x) A6 \" E
writer.append_data(image) % l' j v8 M0 L; U2 J/ u9 @
image = imageio.imread(filename)
- }" V9 o, w; Z) y/ T4 Y1 q/ A p writer.append_data(image)
+ p% |( N* B! m( q' q - f& B) J$ ?! ?, z& K
import IPython
V7 ?" h3 c+ s/ E5 s# T/ t: @ if IPython.version_info > (6,2,0,''):
0 x7 }& J/ K- Y& d9 N8 ` display.Image(filename=anim_file)
3 Z8 e/ A4 `* } 参考:https://www.tensorflow.org/tutorials/generative/dcgan
* {, |0 r4 ~! D; k8 C% d: N ————————————————
: K! i; B! K% i 版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
) B v& D) u$ ?% c T1 } 原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
% o. P: | l6 C. N4 }* Q
; J, q9 U) H' ~ , b; F* R1 z1 J3 ]) a' T3 @
zan