在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 564692 点 威望 12 点 阅读权限 255 积分 174630 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 3 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
9 m1 s v* ]+ g' Q
深度卷积生成对抗网络DCGAN——生成手写数字图片 , _0 E+ ]" x2 X0 Q; d+ g
前言
: E' @& _: b" i; q; j 本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。 : c8 |% f3 ?+ m2 K; W
- c: i! x3 i l e( _9 P; y$ ], H! O- W
. g5 l( o2 \ f" x" D ?& C, j; k 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下: 0 M! r* h5 j: o+ {
9 S; B0 o0 [4 O* A ) B8 r* J6 D- \4 l1 B9 [+ T
# 用于生成 GIF 图片
1 N- H) K% S8 l% O! U5 }& K pip install -q imageio
# v3 [- N+ ]; j: I 目录
( C/ o' k( u* k' D; h! o# |/ T
2 ?+ `/ A& U. _ ) }+ n+ D. |" d2 w4 {1 t
前言 9 N7 p' x7 l, p' ?
( L. D7 g: \6 o9 b6 ]4 ^/ W( u+ j
: _8 N4 o5 L& \, D* y5 m/ m
一、什么是生成对抗网络?
L/ D/ m8 x8 ]- K' h! p+ ^# w
! q( q# w* a% J* a' {* S ( O5 T) i8 r5 q! u2 F
二、加载数据集
4 _: R9 I1 u2 \& }% t. h, B& [7 d + y$ T( A; M/ @- P; ]8 Y! y
# |7 y; t( Z8 b- j, B 三、创建模型
; P$ S* q5 [7 B4 _8 L; {
: |% P: u: K$ T* ]! P 1 F$ W7 K2 W; y \3 C
3.1 生成器
: y9 p8 G2 P6 i) T2 q7 G
2 r; u2 l0 C0 @: e; `( W; U" v2 \ , R( h7 t7 S4 {1 t
3.1 判别器
* L7 i, G4 A+ Q# y! h$ a& i, L
& [" U. k6 Q- H: t/ w& b- z ( w* i: p2 T' L" e. h
四、定义损失函数和优化器 ) R* `' s/ r% O' U0 @8 {! Q
9 e, d6 Z4 i d+ N0 [, G
, f! Y7 N' Y& \8 N5 i
4.1 生成器的损失和优化器
: U& {5 ]0 x* K/ l 8 ?& }+ C2 N7 V! _) q/ u) [: c8 o
( z0 t! v% ^6 {- o1 N9 Y, s 4.2 判别器的损失和优化器
# T/ T. L5 j/ ]) n( o9 w
* S( I- f2 D1 s( k
; ]: L2 |$ T9 Q$ M# y8 o7 S: s4 j 五、训练模型 % ^4 ~' T9 Z# L2 u8 Z4 }
* ]! R0 [, ^/ j* p' e, t' t 9 Q @5 M5 l& u, {& p/ Z' T
5.1 保存检查点
5 Q) j0 T' |& }4 K ! h+ X7 O5 D; ^8 B
/ h; @& n0 v: x/ p, s
5.2 定义训练过程
( A: c5 K8 H, P" j2 O I# a
" ]5 y* d; O6 [3 v 8 Y9 t4 O% s( t8 I9 p
5.3 训练模型
4 P: h$ i, }; o' k: G
& f5 t! R: O# m* t+ [ 5 N+ C# S4 G2 y' n6 b: g% K2 U
六、评估模型
" N; p7 g* z+ ?2 J 3 d! G! O& }# P8 P1 n5 E: s$ r" d
- s) o; T2 G" ^' T+ `) E 一、什么是生成对抗网络?
4 [1 X# e9 H# b0 B$ i- c2 w 生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
5 Q1 Q, Z3 x# D 9 j: D" d2 q$ K- y0 U2 r
; g% M; `; I3 C2 E! k4 f) S
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。 8 Z; }' _+ O) {/ Q
5 X5 h# Y+ K# p9 T% c3 Q! d5 A$ {
. Z! u3 Q( n2 Y5 G, B8 S. l
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。 1 K8 L ~3 E6 o: y9 H2 m5 Q
S R/ Q3 g j8 M$ ` ! h) I$ `2 e/ ?. X& c5 |
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。 1 @9 u7 J. ]9 M+ M
# |" }- h7 a( W
, ]2 \$ [6 I8 l% L. A0 \ 当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。 / ^4 w/ Z! I9 a( M' O/ d$ K" h3 Z
$ i9 c' |6 e& j$ q: N % K9 V9 h" N8 \+ m* c$ Q
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
8 W+ M3 n4 d; j9 _! E 3 f& u3 E' V9 j* d4 A6 I1 k
; H0 G- T1 U5 Q6 Y/ C! w8 U. r( P0 u, D
二、加载数据集
7 u7 Q& x6 q L2 b0 `2 Z1 ` 使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
4 p' {' _% s6 k! G! T- j: _2 C : q9 ~ u( R! M) g
/ z* ~: q. I/ c% v; ?8 c (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
/ A; u- p* ?7 {* j5 J5 i) z# C
: Q }- [9 A W3 [% X train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') 0 Q! A+ U! N) K* R" q' L5 U
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 * u3 v6 L0 g7 t5 ~( Z& r
| r( Q3 W% m" Z4 I BUFFER_SIZE = 60000 3 E% D/ [' Z: R" C u
BATCH_SIZE = 256 2 c4 D& k0 T3 t- d
9 C8 I0 S' y7 w# C
# 批量化和打乱数据 + e( h% w8 | X! |
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
& ?8 S' L( B) S 三、创建模型
0 l1 t$ I& [+ b$ o+ F* k 主要创建两个模型,一个是生成器,另一个是判别器。 4 l! T, }; T; K6 y6 u/ u
9 R9 R" }$ [$ w* Z2 C4 a4 r# u
, t! d4 E( q" K$ l. X
3.1 生成器
$ w4 V" M+ E6 p/ S* B 生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。 0 ?4 c; B' r3 n4 q
% ^& R4 u# q5 p" z. k/ h
0 U6 C x$ s0 _ 然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
3 R( f2 L3 X+ E) T 6 H( {: q: E. E1 |4 M6 L
9 i, e% p) S C0 O, t4 @
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。 2 b5 Z: G! f' C) F' s- e! C
/ a h% z0 }* ^# R ; U, \/ b) q9 g; y3 K+ Y
def make_generator_model(): ; X) r c& O5 d; S {2 |- o
model = tf.keras.Sequential() 5 \( Q+ p' c9 \
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
# K) Y' j* a; P. E$ a, y" b) W model.add(layers.BatchNormalization()) ) K! E- U: |! N8 X% T
model.add(layers.LeakyReLU())
$ h/ d6 { C& D* Z' A# p$ v/ @
2 F5 A0 F9 L$ r! w model.add(layers.Reshape((7, 7, 256)))
8 ?) ?6 w4 [- o, W( d* W assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
7 n* u% x2 |1 v# |/ v- D+ Y# w0 b: g * R* `' U8 I% f8 _
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
1 `' {0 u: A, x1 ]! u0 N0 r assert model.output_shape == (None, 7, 7, 128)
' K; h' W/ x+ ~ model.add(layers.BatchNormalization())
1 R8 I. R0 [2 [ model.add(layers.LeakyReLU()) 1 P; g, T6 p& k. W* R
% F5 `; T Q$ J9 S' d" W model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) / |1 K5 n1 d/ F+ j* L) K
assert model.output_shape == (None, 14, 14, 64)
, o D& D- }/ X' J; x2 p model.add(layers.BatchNormalization())
' q! }, H$ b5 Z( N$ u model.add(layers.LeakyReLU()) / e5 G2 c) Q0 |! r# {5 R2 Z) o
; R3 ^& V1 {* V7 d( F model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) ) L1 X" g3 d. C% ?" q
assert model.output_shape == (None, 28, 28, 1) % ~4 Z m; J; ~# B/ N
, Q( ^1 Q; o+ r/ q; i- P return model
0 x9 ?1 E2 L& H2 Y8 I 用tf.keras.utils.plot_model( ),看一下模型结构
% X% n) ?+ E6 u0 X C: E0 S* [' h4 [; P7 }, B4 q
0 D1 m# L) [' J) t1 L) d1 u' F
3 R8 E& E' f8 u" O
( b" y9 e7 |1 b9 g. A0 y( { % ?3 ^% T8 [1 {7 y9 T0 e
用summary(),看一下模型结构和参数
% R/ `$ E7 ~( E) A7 d& `# v5 w, F
) [- X1 ?7 P2 F) W5 ^ " L4 g4 M6 w; {& g
8 M6 ]: I8 T8 d3 }& J
! M# \& I3 v% Q/ m; ] 1 F) x( T* S. b2 d2 h) N
/ G& J' o' m, ^- |' K 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。 5 S/ Z4 y j: W, c% o3 P6 f
+ E0 p) j3 z- t' C 6 ~& u1 K4 M$ c0 u% W# J u8 t
generator = make_generator_model()
( i" `* |% L/ X9 q
( |+ E. f. {1 ]1 A+ F4 q8 B noise = tf.random.normal([1, 100])
8 `- W0 s l2 ~+ H3 v0 h8 S generated_image = generator(noise, training=False)
9 j E- f1 r3 C0 a 0 X/ A* P+ Z2 Z/ p! [
plt.imshow(generated_image[0, :, :, 0], cmap='gray') 8 }3 z, a4 x# G
, x* u x' W/ o2 W . z7 Q1 z, t* e4 a: U/ b& o7 ]+ k2 ]
. N/ s( I! I. V7 m, q6 e
3 h! {; q' m9 s; p& t) S" h! T: N
3.1 判别器 : E# b, ]# M6 O' D1 w
判别器是基于 CNN卷积神经网络 的图片分类器。
8 t. ?( ~4 i3 T
8 v V+ W3 H5 a1 T9 G) L 3 R6 c$ V) a5 N- J0 d& R2 ?
def make_discriminator_model(): & F& J& r% f; O+ Y2 H
model = tf.keras.Sequential()
/ O& H H8 Z# \7 t/ j+ m model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', ' J1 H/ [$ Z. l: C) F# c
input_shape=[28, 28, 1]))
5 Q1 M3 t4 o" R4 B, _ model.add(layers.LeakyReLU())
7 h* Q; J& o; L model.add(layers.Dropout(0.3)) % d4 m4 u2 i' Q0 ~# P8 o0 [0 M
v* E* J1 ]- s; W& i6 j, R4 C) A
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) , h8 h! e3 F, k' f. j
model.add(layers.LeakyReLU())
+ ]( t( t7 Q! `, I% \ model.add(layers.Dropout(0.3)) 2 @/ ]5 h, X4 U9 w0 z7 ?" q
7 E2 n r6 y0 C model.add(layers.Flatten()) 2 s' N# W; c% j) Z7 _, m* C
model.add(layers.Dense(1)) # e8 f2 b% M; ~5 v
$ Y) u. \: r- i% f+ \4 D& w) H. O return model 6 u9 S u( ]* z& y8 i* n* M
用tf.keras.utils.plot_model( ),看一下模型结构 0 I' m% [# K# c
0 H4 L6 T4 o o: |
1 P4 w4 H/ `" c% M- x) O & r0 Y$ A) V, ~$ J: f
' j& ^2 x+ h' p3 U5 ]) E
9 w& x6 o( ]) u8 S1 \8 k, j) K1 T! v) B 4 i1 P1 H1 s9 V) E) d# A ~
用summary(),看一下模型结构和参数 4 i9 R- R/ N; R) K
+ H8 Z$ j+ ~6 j- p- X. V Z , a" ^! }4 K" o( @3 L3 W. Q
/ p5 M) j8 Z9 K! z- o; G. |* Z 6 C( I0 f* U, }' g
* X" J3 T$ V5 e% L, X
. Z, C$ @6 J' L, V, ~* u$ y/ u; z0 M E
四、定义损失函数和优化器 2 n X1 T( I! S
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。 ! O$ q6 F8 A4 k Q2 ?
( B B0 w" K. w/ i# s 7 B3 W9 h$ o+ I" k' p8 Q" k H: u
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
4 W1 z: z2 b' i2 H/ g * `" a: N6 L: ~8 H* u2 ~
5 [& F# z& e D; _6 r9 a # 该方法返回计算交叉熵损失的辅助函数
# I! r2 h4 x2 r5 ^ cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) $ [, h% N6 M O) B! F
4.1 生成器的损失和优化器 " H3 s) t) `0 M2 y7 ]5 D
1)生成器损失
% {5 E+ V+ {9 I' @' Z( d0 [ - M" w* I; r2 Z
0 q0 M& s7 V% s4 Q3 f1 g 生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
9 i% G" h4 i+ Y# W$ N- b 6 X/ S& C3 O3 m& \7 X: J1 |
7 w' t& l- p9 S/ G6 C6 h
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。 : L! B2 ]- l9 e% [( s: I
' ^) {2 l# V) _0 x/ w; {; t9 L# J
9 f+ r4 J9 Q& F& `/ P7 R& @7 P( a def generator_loss(fake_output): * m5 Q- w3 H. P. }+ p' w
return cross_entropy(tf.ones_like(fake_output), fake_output)
2 M) G& T/ F" }' k; J 2)生成器优化器
2 {) |& q) b$ A3 K! D0 y) k4 X3 s
7 E+ U. n8 X2 _& e1 N, X
5 |, C/ y- [ S6 J: R& Z generator_optimizer = tf.keras.optimizers.Adam(1e-4)
; |. R2 c7 H( f4 J0 n) d 4.2 判别器的损失和优化器
' d: ~) s' W5 }+ G' x G' u0 b 1)判别器损失
B1 b/ h) E9 ]; c% p5 m' _, m 2 T+ z+ \( V* ?& O$ f) d: E _
' R2 {2 H% d c8 g6 {
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
: V) j2 ]2 e, n2 h& @ 7 p9 U# K7 y$ [9 d' t$ s
- U% O2 v1 G7 U1 Y/ |
def discriminator_loss(real_output, fake_output):
* T1 q K8 b) W" F4 K real_loss = cross_entropy(tf.ones_like(real_output), real_output)
3 y' C B1 h# e, B3 f8 J fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
3 c4 e+ p. U- b: S8 \ p total_loss = real_loss + fake_loss
. }: i. S2 B: i2 R' w return total_loss ) ? X: |: w5 e/ f
2)判别器优化器
D4 f+ Z8 h1 b ) y. j4 u6 X8 O; p+ h3 x& ~' o
3 w2 V5 ]! r9 u3 n4 p- C discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) ! b, `9 f0 Z( C
五、训练模型 5 p2 S" O$ k+ z
5.1 保存检查点 , ?5 e. R( H& }3 @
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。 & a1 c2 k* ]$ L! l: ^5 U4 h+ a* F- j
7 g' n& R, s) Y' Q) f5 t0 R
6 {. `% l0 H6 A) n* U5 t2 Y' G9 P
checkpoint_dir = './training_checkpoints' D8 V+ O6 V6 U: X2 R
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
7 H7 t0 z1 t# ?: { checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
+ }" ?/ i' {$ x discriminator_optimizer=discriminator_optimizer, . m& q8 X% w+ Y9 R0 s: W2 g" E
generator=generator,
9 g6 w: o7 M6 q; Q+ B discriminator=discriminator) 8 H5 }6 U' t& c. G H& ]# ~
5.2 定义训练过程
/ q% @! |( s' ]( }. s6 ^* L$ Z EPOCHS = 50
: n( b1 b8 K. ? noise_dim = 100
" N8 U5 J3 T) L+ q' k$ P num_examples_to_generate = 16
* I( W6 y" m8 O( Z& s& T V3 _
$ J; L5 | h w/ u3 j& @
3 c4 x0 k0 `+ P R- ~+ ] [ # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) 3 ]5 c( |4 N+ T' d- ~6 D- n
seed = tf.random.normal([num_examples_to_generate, noise_dim]) 4 N. b/ k' V, L* ^& l
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
7 _8 `* L X! ~- ?( Q- k + b! }6 k& _) V" q6 F
; B5 ]7 M8 _+ [: J( F7 k 判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
0 _ a Y4 M9 A) }' f/ \# J 1 k+ {# S' S4 w7 [
5 y' L0 M, }7 t: T 两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
s, O- k, @5 P& H1 b ! n' i0 Q Y. ]& @3 P! H) ]
' }; [4 u5 H- Z8 c2 v/ V# ~3 P # 注意 `tf.function` 的使用 + j/ g f% _5 k
# 该注解使函数被“编译”
. v ^4 n3 G# T- o+ p- ^ @tf.function 6 p/ z; q) J) g4 ]6 R
def train_step(images): ) B* `; [& [) {% ]9 m; M
noise = tf.random.normal([BATCH_SIZE, noise_dim]) & X5 ~9 v* \$ Z9 L0 u* r
) l' B W, ]1 d- F# Y with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: 4 {4 b/ V5 I1 j. v. f/ k) I
generated_images = generator(noise, training=True)
6 P' \7 P) H8 G/ z
$ _2 z1 S4 t6 |' F real_output = discriminator(images, training=True) + E, _, A; y0 f! W# ^
fake_output = discriminator(generated_images, training=True)
5 ?% K$ w& m+ q6 w
* h; Q- Q. Y; j# B gen_loss = generator_loss(fake_output)
# J7 v2 c7 k3 E" `& L' b0 ^ disc_loss = discriminator_loss(real_output, fake_output)
- V4 N& Q: Z& P! n5 o! Y
" ^2 Y. {0 I- G$ s% p gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
5 h; W p! x9 Y# z8 t gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) # d& f g2 T( r' |
# H1 z: f" L( `- s generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
! ^7 {* d* j0 t$ e& W! Z ] discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) . c, N* q9 \. V
3 t% t, B ]9 _& a" u2 F) I def train(dataset, epochs): $ g5 y! j; i0 _) G* O- C- J0 b
for epoch in range(epochs):
& e* U2 z+ c5 a& m! X1 h( |% n$ n start = time.time()
/ y% ]+ _& B# `# p ; D" u% C3 e( O& e( j
for image_batch in dataset: " u) p' V5 X4 R8 f' a4 P6 W$ a
train_step(image_batch) ( K7 ^& `+ Y e* | i0 D9 G) X
1 W8 m2 a' E2 {" V/ J" Y5 l9 `+ _ # 继续进行时为 GIF 生成图像
4 O( D% X5 \1 ` P display.clear_output(wait=True) . b4 M. |1 X5 m
generate_and_save_images(generator, o8 C, i$ \2 a% z! }1 h
epoch + 1,
! j( C; j; N# I* \6 f0 s. u- p seed)
& H% j8 q$ ]) @, W* r ( v2 g2 P; B `9 H+ y8 [, k8 X
# 每 15 个 epoch 保存一次模型 6 j* Q) }. }7 E9 ~ P: m
if (epoch + 1) % 15 == 0:
: N% j5 h; l7 y3 e+ c checkpoint.save(file_prefix = checkpoint_prefix)
$ m9 K9 r2 C/ t7 g" d# S k0 P' v0 u0 u# ]% |
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) ' A; S# |+ m K- U$ F/ T
9 k; `3 S, z+ x O# {3 E4 P+ e/ U # 最后一个 epoch 结束后生成图片
8 W$ j& Z! \3 f display.clear_output(wait=True) " k, C# f1 C5 u$ O
generate_and_save_images(generator,
- D: E) q1 f& K! ] epochs,
0 Q% b4 t$ u3 q* f! p seed) ; o, B/ k; g- ]- c: y
' y! C2 \: Y, _' m0 _ # 生成与保存图片 $ E( E: ?2 t4 H. Z; c2 l. y6 I
def generate_and_save_images(model, epoch, test_input):
0 p+ j! A4 j0 K& S8 S # 注意 training` 设定为 False
# k( r1 \: p8 p; X& V, N' I% Y # 因此,所有层都在推理模式下运行(batchnorm)。 . {2 F, v+ y+ |9 ~" `9 M
predictions = model(test_input, training=False) 9 F0 T4 I5 Z+ I0 q& S( r* e0 E
* H0 x: Q b! ?. x
fig = plt.figure(figsize=(4,4))
: n& O- z: B3 z: F! C - Z# a+ Y; c* y( z
for i in range(predictions.shape[0]): ( ~/ r* Y. A: x5 O# P; Y
plt.subplot(4, 4, i+1)
$ u9 e# e4 M1 Y o! M plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
4 M$ q* @ O8 Z4 S0 @ plt.axis('off')
5 }3 N/ A; n( b2 b
8 A/ L- H x) w# L) u, y2 V& m plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) & j2 D1 R- y V; }
plt.show()
% Z" M0 V' b) u2 k; ?9 T7 ~ 5.3 训练模型 : n" ]" l5 o* t5 ?& z y0 Z
调用上面定义的train()函数,来同时训练生成器和判别器。 / G% k, y: f+ h' ~% R
/ @% l, ] Q6 s" y Q. k; u
2 r2 }' d; r$ Y" O* o% C2 S 注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。 2 j e; k0 B3 X4 `/ h
6 x& m& ]0 h; z
8 k+ O: g* x$ W* g+ j1 J %%time
6 _2 N) `* c5 h/ I5 {! C train(train_dataset, EPOCHS)
7 n: a* {; c" C0 ?+ e" h2 y 在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。 0 O% a' S" ~5 k
2 ~2 ~7 W) n* c. m0 i* q, B
1 V) H) J2 w, m 训练了15轮的效果:
9 v3 l5 \1 s* `. A
' T e# r! A1 J Y# A) Q8 d7 w 8 m# ]: B8 S3 b
! @/ I4 r, e; D7 ^ J
; {7 y3 [( H |$ F V# T
$ M; K* e! y5 a) L' L
4 P+ h* M3 q8 ^% }) P6 d. }! x
训练了30轮的效果:
0 `5 c- u7 Q2 T: a+ I# V8 n# I9 T* | * S+ Z* |/ C" m6 X6 h0 d' l5 ~/ s/ o
2 m- D# c( m# g+ [2 ]' \6 P
8 ]4 }- t+ g/ ~8 T
1 L) G9 [' T. W/ n& Z ! O( ^/ S9 b* k) D& o# G: W) b& [
! d5 q+ A/ h) s( v
训练过程:
' P$ `5 c" E7 |3 l 5 t9 j3 b1 O4 q+ x
3 Y& F' U( c' Q& D
5 y' {; ]* ?# w5 P9 e$ J
: P+ m ?! Q. ?) A
/ Q8 a3 I! r |* g/ y
! v+ D- v, U4 B* c3 v 恢复最新的检查点
# D* j5 K/ ~- O4 u/ p( V3 B 0 x, z3 n' D1 V/ Y, o' Z& w
" O3 W( z& I% A/ O$ o* S
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) 4 b6 t' H) h9 V7 r; T% C* p
六、评估模型
6 A$ S8 ?7 M1 H# w8 B5 i0 G 这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。 % G; s9 F5 Z2 B7 k3 ~
- t7 U- J) ^' K 9 i+ q( ~+ H1 I8 A+ z ?9 z$ v
# 使用 epoch 数生成单张图片
' I( Q3 J" E( e9 g1 L3 o0 Q4 f def display_image(epoch_no):
5 _- {3 u; ?! F/ a return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) , O' n; u9 z8 O/ C* L+ t* x! l
$ f) w8 J& S3 @" O) `
display_image(EPOCHS)
. A! B: }& a7 A3 x anim_file = 'dcgan.gif' ; M& _! P0 y" J9 O+ O/ B
9 ]- j: V# e* p% D6 w with imageio.get_writer(anim_file, mode='I') as writer: ' |8 G3 c: Y% f1 k
filenames = glob.glob('image*.png')
% ]/ o' v% G! j( W+ V: B, b! M filenames = sorted(filenames) - D3 @( x9 T$ }- o6 a3 P
last = -1 6 E8 Q/ S J7 I# i( U9 [5 b
for i,filename in enumerate(filenames): , c# ~+ }3 Q- J s3 b4 a- N
frame = 2*(i**0.5) , m* U+ O1 K! E$ S8 b
if round(frame) > round(last): 1 W9 e; J7 v9 F7 E/ X# _: g% Q' k
last = frame
: B7 W: N( h) `) r else: ) m( N5 e! U$ e9 f5 C
continue 7 S0 ]& ^/ o( n8 b" B2 v& q$ K
image = imageio.imread(filename)
% u1 t" j M) c8 \! [ writer.append_data(image) % U; x' o; Z9 Q0 z) t. N! }2 }+ P, O
image = imageio.imread(filename)
( Q( f B( ?2 Q$ ~# z" ` writer.append_data(image) 5 T; ]( l; n: P5 o7 e; Z
; W: q, X% w5 n/ ?5 j& Z3 I import IPython
% H9 P) E) |* {( z if IPython.version_info > (6,2,0,''):
' F8 Z3 r1 X3 z2 M display.Image(filename=anim_file)
* F( Z- m! I2 P: D
8 ?9 A& J0 a; p% T
( U/ X! _" c0 t7 |$ M3 { ; d9 J" Z9 X# k) K, |! }" h
]; i( O& q. G( f7 D
完整代码: 1 H W1 E/ p, n% Q7 T! e8 c, Y
) |7 N# X8 @% d; `1 P
1 F0 w$ S/ o" o9 F import tensorflow as tf
3 `) o5 g# G6 u: B* r( }" p- y+ o import glob : Y) u+ w7 C* r6 j, Y
import imageio
% t. n4 k; H5 p3 ?9 c; n9 N import matplotlib.pyplot as plt
! t, x8 t' o9 b import numpy as np 1 u0 e8 |! v/ n) N+ M8 p
import os + {$ B* \! n0 S6 @
import PIL 3 D \7 ?6 ]2 J, J
from tensorflow.keras import layers
! G9 H2 }' i( _ import time , i( P3 u% v0 i, e4 T. x }/ Z
& f, i" Z9 r/ u# l* @
from IPython import display 1 U3 C( m$ D9 G
8 `& E5 g% g( \% s# y8 o) T _5 p (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
& \( ?( z7 ]1 j1 j& p ( L1 G. x" p$ @& z2 T
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
- R/ l5 B2 H, I, B+ P train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 ! _. s! o; m3 d4 U. `7 }
4 m, e0 f! x1 Q3 z. N4 J9 q! B' }
BUFFER_SIZE = 60000 # K, b9 `1 ^/ |1 c* W
BATCH_SIZE = 256
7 f U6 A* L. D+ ~0 q . t( S% y/ g- ~
# 批量化和打乱数据
0 G$ [; J* L$ }9 A train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) % `) i+ q8 E. y6 R+ S8 w) ~! j
7 Z* N+ A% v) Z
# 创建模型--生成器
0 e1 M9 R$ O5 o def make_generator_model():
, g2 j6 j" E8 U. r% ]; Q model = tf.keras.Sequential() 4 s2 Y" O7 b5 y- K U3 E
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) / K2 G6 K) |/ ^6 E, D
model.add(layers.BatchNormalization()) ' m$ ^% q! W% v! R ` y0 o
model.add(layers.LeakyReLU())
& Z( a' s: G9 q$ @0 @
& ? O" ~6 C9 R) M4 c model.add(layers.Reshape((7, 7, 256))) & I. l. t4 T9 n4 T) F" f2 X( a+ D
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
p; O6 g) G6 l4 G, g 8 c" p, z: v- M' T! w8 B8 m/ ?: C f
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
5 Z* b3 N. z& L3 }, v assert model.output_shape == (None, 7, 7, 128) ; A0 t3 l# C4 m5 D! J; l; ]5 U! T: C# J& [
model.add(layers.BatchNormalization()) 9 m7 L* E& l$ g2 W* e
model.add(layers.LeakyReLU()) * N2 e3 l2 X5 }2 H
7 k5 D5 k1 ?" s model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) # [& q x& L$ R% h$ D( \: ~- u$ Q
assert model.output_shape == (None, 14, 14, 64) # d8 E, y3 h* _; T/ v
model.add(layers.BatchNormalization())
: l8 \# h/ m) B/ A6 o# s model.add(layers.LeakyReLU())
- x( x, o* b+ r4 Y; |+ M# e
" z5 S- Z$ t* p- ?3 t model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) ' d: w2 H" M9 K
assert model.output_shape == (None, 28, 28, 1) 1 ]5 J+ `; A- b" W8 r
, s' K% C/ }. t# h
return model 5 S0 q: ~0 k2 ?& t
9 f3 C+ ]9 _( _% n3 `. a; u% n
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。 ' l v, y" ^9 z! n% `6 l
generator = make_generator_model()
) s8 T7 f2 i+ h2 Z' _- o
' N0 H) t; s- r9 _+ L/ \# S noise = tf.random.normal([1, 100])
# }1 R! S( }2 z. H+ H generated_image = generator(noise, training=False)
. i, h7 i: c* {) p3 _' S. z( c
1 o: j: Y% L3 {. t" u5 D7 f plt.imshow(generated_image[0, :, :, 0], cmap='gray') d, k3 @0 q: \: o
tf.keras.utils.plot_model(generator) ; m& V) V' n$ g" W- j/ b) E
8 x/ Q$ V3 t- p) \, }
# 判别器
. x$ H8 u. c) |9 n( O' i def make_discriminator_model():
* X* e0 V: K5 L. W& d4 [: C model = tf.keras.Sequential()
' C- b. g. s' i. W+ y1 s f0 y model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
* Y3 i7 O+ L f input_shape=[28, 28, 1])) & x; C" A: m: ?0 F
model.add(layers.LeakyReLU()) # v8 J- t# f5 @% G4 G& I% }
model.add(layers.Dropout(0.3)) $ v, J0 }) C* j9 _: }% k% ]9 }
4 @/ _; c/ T+ z2 ^3 U- } model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) 5 @/ g9 A+ }7 Z% w9 r
model.add(layers.LeakyReLU()) ' c+ P4 k2 L5 s N! u+ U
model.add(layers.Dropout(0.3)) 4 \* W. ~0 i A1 |& u7 R7 B( V
1 I( U K$ Z7 c4 @4 Z model.add(layers.Flatten())
; j# G- [' Y0 H6 H- v/ t1 t$ P5 ~ model.add(layers.Dense(1))
# e! z9 P: h! _+ I / s& ~/ X3 I- S- Y8 J$ D( m
return model
8 A$ Y6 X$ {' L# p1 v4 J% v ! x, p4 p# U! |; W! E: X
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
7 T- u& M5 T& D- l3 V discriminator = make_discriminator_model()
6 b7 a+ {# T1 c, n! G9 L1 r decision = discriminator(generated_image) + }6 W) v5 Z& _8 B; u" V! m2 P4 O
print (decision) & X$ v; M" v" O; |3 g2 A, B
% Y* C! c) A% B0 d # 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 ; Z5 o/ @& F |9 {/ I
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
3 ^& i: @' O* Q# M I9 Y 8 U5 ^6 h6 Z, W4 _; t, q
# 生成器的损失和优化器
) q: f v4 P, Q, y- Z' f def generator_loss(fake_output):
; j/ b' N: s( q return cross_entropy(tf.ones_like(fake_output), fake_output)
! D8 Z7 C1 D3 u C0 Q' C generator_optimizer = tf.keras.optimizers.Adam(1e-4) - l( Q$ n% A8 i) b( h' V0 _
3 ?. ?9 O3 V( g* \1 Q" @# }3 C
# 判别器的损失和优化器
" f9 x; r! [7 K7 w7 s+ y- K def discriminator_loss(real_output, fake_output): 0 S U w2 x+ O7 @ T' T
real_loss = cross_entropy(tf.ones_like(real_output), real_output) & J8 ^: F, p$ r
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) 2 C! G& F9 y$ y& n+ H
total_loss = real_loss + fake_loss 7 ^! f7 D- K2 h& f- {5 Y
return total_loss 9 Y9 F; t) z+ L) Q3 ~0 ^# |
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
1 ~$ |6 M, k7 a% o Y7 P
2 U/ V' C: ^" u4 C7 o- C; O # 保存检查点
" l: ~3 s6 l: A checkpoint_dir = './training_checkpoints' 3 U' `% \& [% c" B: o" [: u
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") - i) `6 S6 y s6 A; a
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, 5 q9 q0 d- g6 S9 o( ^
discriminator_optimizer=discriminator_optimizer, 3 {: u& R9 v7 c0 A$ L
generator=generator, . n O5 j0 ]' D& v# I! J+ O2 N
discriminator=discriminator) 7 i0 x; z g' D+ B6 m
( ~1 F. B" b2 D( [. b7 f # 定义训练过程
1 \1 z8 Z T; \3 L- M5 q EPOCHS = 50
3 K' {2 n# V' a: Q noise_dim = 100
. W7 d) E- Z9 I! z) d num_examples_to_generate = 16
+ T4 s# N# p* | 5 P6 x5 i% v" Y4 L
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) 5 p+ x; ]% ]3 ?0 o; S( _, }7 D( Y
seed = tf.random.normal([num_examples_to_generate, noise_dim]) ! x' Z* u9 s2 S3 v/ _/ {
; ?; _8 x4 U6 `4 y. I8 g v( v # 注意 `tf.function` 的使用 + T$ |8 w$ ]# M: I# a8 g
# 该注解使函数被“编译” 1 x1 Q" t& U& \6 l, ^# d
@tf.function
7 Z# B. V- g1 @) g- E def train_step(images):
R& x3 W8 [1 B& u noise = tf.random.normal([BATCH_SIZE, noise_dim])
# W7 F- z# V3 Z5 ]5 \6 X0 d% i5 @* ?+ R 0 Z K9 F8 {) ]4 P- M
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
' D- p, N5 J3 Q. d' |. e generated_images = generator(noise, training=True)
4 @' h2 H- C' j* o. G$ C' b' G) B) \
' j J ~. B; l/ U1 v) o4 { real_output = discriminator(images, training=True) 8 |; I0 z, R6 V& J+ S2 ~3 } I
fake_output = discriminator(generated_images, training=True)
7 L+ ^! h& z, f5 S
# L6 k/ V; a1 {2 _) ?$ U gen_loss = generator_loss(fake_output) % X& P0 @+ F# c0 u! x0 F
disc_loss = discriminator_loss(real_output, fake_output)
# r: t) J' d5 ]8 l; @) L8 [/ ^ 7 G% a& Z1 q: D6 z' L4 |; Q0 T5 m
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 2 j* f, \! L1 k
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
5 G% T' ^+ F* ]$ M! y. B V4 s$ `0 b: V( A9 Z
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) + ^! ]; S5 @0 B/ L
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
9 _: [. H4 B8 ^2 m7 Q
& s# G4 _7 @) _3 H& q: e: Z7 O def train(dataset, epochs):
" L: z% ^; L7 X9 @6 y for epoch in range(epochs): / F! H" F1 t5 R/ i& t/ u
start = time.time() ) Q/ p" p2 ?% t. v, ?- ^1 B/ {
# E' N5 P# b5 U5 y1 M* q! B for image_batch in dataset:
6 W" d6 k, {3 E( b) r' Z4 J# ^ train_step(image_batch) 9 J" K# N8 z2 d1 q
( j3 U" ^4 l" W( v* d # 继续进行时为 GIF 生成图像 ; G# Z8 \) S7 O# }5 z$ t
display.clear_output(wait=True) * y* c. u9 E8 k8 Q! V9 \% [
generate_and_save_images(generator,
6 U; k; B; x$ ~( M; G2 D: |$ V+ ~ epoch + 1,
5 e# k- Y2 g* `: a# u seed)
8 B) Q- L, F ^1 K
$ b1 N S s8 i) ?" y # 每 15 个 epoch 保存一次模型
8 o* \% }8 t: V* K/ e if (epoch + 1) % 15 == 0: 7 M: O' ?4 h! v4 U& L( j$ @' B6 z6 C4 W
checkpoint.save(file_prefix = checkpoint_prefix) 0 y6 n6 }' R. S4 T8 n* E+ U3 V
4 B1 F: R; `' v( r: u" B print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
+ H) T, X$ k5 z2 u
# e9 {- B) h! I F+ {9 d8 Y # 最后一个 epoch 结束后生成图片
9 H2 I' N& ]9 z# W display.clear_output(wait=True)
! |, l5 o. w) D3 s: ^8 B generate_and_save_images(generator, * y# X* N( m# y' t' d
epochs,
8 p' B( ]5 F8 G7 B1 n, ^ seed) $ S, F4 J$ C6 T# r
; Y2 M/ u9 w* s8 p' U # 生成与保存图片
1 J- b! H7 x7 S9 B: P def generate_and_save_images(model, epoch, test_input):
) n j6 m4 Q5 ?3 }/ B8 ~ # 注意 training` 设定为 False " I k) z- G: T( ?
# 因此,所有层都在推理模式下运行(batchnorm)。
( H# q+ b$ H6 R" a# A predictions = model(test_input, training=False)
! m" _# H4 c L4 ` b9 G
+ A* D7 q; z m" t+ d/ e6 \+ M! z fig = plt.figure(figsize=(4,4)) $ y. U1 Y8 T5 f/ a( P& L) K
. L2 H& Z7 R% q$ S for i in range(predictions.shape[0]): & Y1 r8 n: r& O- c! Y
plt.subplot(4, 4, i+1)
7 N6 }: i1 |( U4 @! [9 A/ r plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
; [ C9 i- l( p" R1 e# q plt.axis('off') 1 U9 g9 _ N1 Z8 B1 X0 R+ W( f
: g& N/ Z: |. x3 ?# I. \ plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) 6 s; ~4 i; T- x# _2 ]
plt.show() 9 Y2 t+ A h5 f' `; T7 Q( v- F; ?8 _
) \1 |# q$ }5 H1 E # 训练模型 3 H. ]+ e; f$ Q$ r6 U; _: w
train(train_dataset, EPOCHS) ) j' p( Q) v+ }: P
! Y+ [* b2 Q) B6 [ # 恢复最新的检查点
/ D% G4 S, `1 D1 g1 N checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
$ I4 w' r% \" c ?. V2 z 3 r5 b% @ i# O) J1 |0 [, o
# 评估模型
6 J2 Y8 P! k9 h# v # 使用 epoch 数生成单张图片 4 b) p' ]4 k+ C1 b, B9 K) v) Y
def display_image(epoch_no):
" B- z! z) K& H- R7 X return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) $ k% W( j0 S0 D6 Y0 ?% T( Y
Q1 p6 p6 |& ]5 @& N" [% u
display_image(EPOCHS) 5 s: A; f% |* V; n0 J8 b3 y; J
( p1 Z3 J/ N1 ~5 F: x: I( j anim_file = 'dcgan.gif'
* V3 C; ^$ B' R, Q4 X) p6 | - F* Y8 [: k2 e* E% \5 U
with imageio.get_writer(anim_file, mode='I') as writer: ( k# {+ W) a0 n& v7 m
filenames = glob.glob('image*.png') 7 V' F5 |& S$ O6 z5 d# f
filenames = sorted(filenames)
7 s+ g6 Y- ^. U/ E' q7 d# e last = -1
1 h: L! k4 Q$ e s) a. k( g s for i,filename in enumerate(filenames): ) |6 d/ D3 N5 ]% _
frame = 2*(i**0.5) 2 Y7 e. I* {' A6 I3 v# F' j
if round(frame) > round(last): , p/ Z) L1 j9 k/ y
last = frame
" U5 `6 X6 k5 |# s3 S" P' R3 j2 t else:
" l0 l& h3 E) N5 A2 c continue / V; K1 P: A) P- G# J7 k
image = imageio.imread(filename)
' P# Q3 c: D8 i writer.append_data(image) 6 h& e+ i/ w0 ~0 D" c. a
image = imageio.imread(filename) 5 K# b- F0 E9 i# }
writer.append_data(image) , n8 U4 F, V7 R/ ?' l( h* ^
( k0 n( q$ L. H" ^; e. A import IPython
7 E4 W2 \' r# m Y% ?1 b3 I+ G0 \ if IPython.version_info > (6,2,0,''): " F& X" E5 ^* Y$ R& U
display.Image(filename=anim_file)
/ W8 Z7 B4 S2 `" `1 s* C 参考:https://www.tensorflow.org/tutorials/generative/dcgan
+ i- l$ e' ^, p, V/ Q ———————————————— 2 ~2 N' Q8 ]% c/ f% W! i
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
. g! D" Y0 r4 @- \1 F+ I! v 原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
4 X7 W8 X" l+ T# T 3 C% L2 I/ R x; k) T7 A/ x2 {; X
- ]6 q4 M, O- \) Q$ `% e
zan