- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 564455 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174559
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
) R% z" a( v5 m0 t0 o. _3 Z; |
深度卷积生成对抗网络DCGAN——生成手写数字图片
* r! r. c0 F# o前言$ |% E1 f" D4 s7 _. W* w) O
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
4 w/ z% i- p( f& E Z7 r; {: ]! O, D8 V0 X( a3 M
0 m$ ^8 Z0 `6 E, L 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
2 |$ c8 i; J! b
8 Z& Y% g$ _3 y& ?, U7 t
' z6 i& |- A0 S8 o# 用于生成 GIF 图片
4 u& C: a/ X2 C! ?5 [ ^pip install -q imageio% F: |6 t* E3 _( X& E! W5 F3 [
目录
9 Z/ v7 q' o& {- R6 q) Z. H' F7 D+ H, @9 R* ]
w8 t* a# @1 E) ^& Z0 L" I. W2 Z" q. a
前言
% N5 J5 L3 c3 W* U k% c( e; J0 I
+ t' r" f0 w5 ~/ Y' z- L) B8 G5 b/ I( D# t. ]
一、什么是生成对抗网络?
0 h: L1 P/ ` Z% Y) {* l$ s! t
- C U) L) c) S7 x3 W3 X; r) g0 q% ]+ ]1 w
二、加载数据集2 z; G% z4 h+ O5 C" _; _( k0 F$ z7 c
; _: O' b0 q9 T* Z0 G# t, \' z' l$ I$ e/ q8 U3 x0 A
三、创建模型- @) ^. Q. o0 [8 q7 d; X# ?
# ~$ F2 e. j( |! C2 a$ `
! f0 \* e0 z6 X3.1 生成器
% B! q! D2 b6 V2 y! T6 H4 }: d$ q/ R9 j( W" S" r5 T, Z& s6 p4 }, i
+ d( m Y# b% ~
3.1 判别器
3 o0 f2 s& `' o2 W" P1 a" L! P7 Q" x; |0 O3 m* C0 o+ g! s2 U! a0 C
7 h/ w* @' L5 ~& `
四、定义损失函数和优化器* i, P$ V8 f/ R, v/ u9 p7 Z, @
& b6 `1 e) [% n. B1 s8 g# U( e% _5 Q) S3 e4 n
4.1 生成器的损失和优化器% f/ R/ X0 n+ c$ E m
" A: v5 r5 ?$ J2 {# w8 G( C$ A$ x; B3 u
4.2 判别器的损失和优化器2 [7 B2 g. z2 S1 l+ s
& X4 l1 }' l$ a0 E; K7 G* I4 N3 H
6 H' A6 V" V- _ Q五、训练模型
+ l& O% Z" J4 J! ^$ a, z; @& P* N: m( D) o4 p' H( @
5 H: b7 g4 d; ]; r' k: n+ B' T5.1 保存检查点
/ M8 C6 b9 F" O# R* z) C% ?4 H' a1 g5 p: H- o
4 x' D4 i. D* S& s/ P5.2 定义训练过程
" W2 L3 A; t: Q: `5 A8 J- g: ?4 `/ y; G4 a, ?# G9 ]
& I1 r* J) ^. C! x; q, \8 C# j5.3 训练模型1 p; {& x# ?! j) R1 Y( q
0 S! e+ M% [% ^! H
* u! Y* C+ y3 h! e六、评估模型0 o( {& \" a1 W' x4 z9 q$ w: }* T
: }. d0 {5 g6 `& ?* @% B
2 B6 b/ T+ l2 [" u
一、什么是生成对抗网络?
* i# w4 ~6 k7 n生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。: D1 H k k: t$ n' e
) ~. i$ T- c, T L$ D9 v
% {! k. Y- ~: Z ~5 D生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。( h2 F6 T3 a( x1 f8 v8 v! m6 i
& I4 l. D: h0 J L
# F" H4 a* t& z' ^! j判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
6 h5 A* h1 e# b2 D$ D8 r" }2 W( r3 [; }; D* S) c
5 F& W* b0 V4 M训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
' e5 E. M- i& ~8 @, D' ^
. i# m& R; Z+ q/ P# i) C* w' S9 ~# Y5 w* _& d
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
1 p% Q: H5 N. z
# W3 W# W. t+ Y$ H/ h
0 ^/ _# g% ?9 z4 e, h本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。6 J+ B0 t8 N* v
( v0 U5 j3 l- T E% e% I6 C0 ?
6 }( i- i9 X# ^# V" W, d- u0 R+ {/ |二、加载数据集: p( {3 n4 J* g
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。; I7 a8 O: c4 e! X/ H8 d
0 K1 ]9 J$ i3 @4 S
. E2 i4 w& r8 ^2 X$ E
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
, F2 Z0 w P) j' e 7 z s3 v1 V; d
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')$ |; {- e2 G9 _6 v% n' E
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内1 n. L; m# j% B# |
) Z2 c6 e- A3 H7 v/ R6 M5 V7 IBUFFER_SIZE = 60000
8 P0 Z4 T3 E% {+ e& `4 q5 EBATCH_SIZE = 256& x _0 \3 d7 _) N4 U# j; V
* G% V3 \4 ^5 @/ i+ ]; }
# 批量化和打乱数据
4 D q% P& u3 c2 Utrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)3 H/ ?5 W1 U; b- d
三、创建模型
9 o0 ]5 }- b% o$ Y/ q" ^主要创建两个模型,一个是生成器,另一个是判别器。4 z/ M4 ^ f) e8 h
& W8 ^' M1 P1 C2 H) n
3 r6 N/ @9 S- b3.1 生成器
1 B! C& U) q! O" `& d. t. Z- `1 G生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。8 x2 ^7 ?: g0 x+ X) b* t$ V7 ^
2 j- e" H8 I, ~2 L, V2 H6 [8 |7 _! Q# O- Y
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
; a* ~8 E7 [$ R0 Q. O6 Q7 l- U& _ X$ ?, C8 n+ B' J; u
4 G& _+ ~8 n+ s' {* X( ~后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
7 V0 t9 U2 P- G+ y
& Q* |$ L t/ w$ b2 E4 }' d2 r( P1 e. s
def make_generator_model():
) a/ ^3 H4 Y0 q% H0 e model = tf.keras.Sequential() j2 E/ K, t" z4 `
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
- M1 R/ u; L. U# I" \ |* S/ h. v+ v model.add(layers.BatchNormalization())
& }8 s/ ?/ b$ |2 [- v; Q model.add(layers.LeakyReLU())
+ y: C0 O- |' h: n; H& e8 ^
1 o2 U5 |; s4 @5 E r! {4 C, i model.add(layers.Reshape((7, 7, 256)))! ^$ i1 v# f0 x: t; V- M) u
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
/ t$ p- x8 W3 I7 W6 m: V" P
2 c& i7 }: s1 p: }9 W) B) {2 E model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))' ^" I; k& n- r S w
assert model.output_shape == (None, 7, 7, 128)) k' }) O( ^) u! f
model.add(layers.BatchNormalization())0 X5 N- V7 ~0 q# n7 q
model.add(layers.LeakyReLU())- ~/ ~* b% I# w( ~% n. W- f
( m m+ ?" y& W: T! |+ {: y2 m model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))" u8 i6 o7 j [
assert model.output_shape == (None, 14, 14, 64)
& ^2 _4 u$ ~/ T; `/ ?" q1 p model.add(layers.BatchNormalization())
% \3 {* {7 C% o2 W* u( ^9 w model.add(layers.LeakyReLU())* ]8 ]+ Y! h' t1 D$ z
! Y+ h( \# }' w' f! M
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))$ A) p7 j+ q0 y! D5 N; T( _+ T! b
assert model.output_shape == (None, 28, 28, 1)' x. B# B; T1 F0 \
2 }' x' F( o% M% S4 a
return model: ^( m5 V6 Q" w5 {9 p) ]
用tf.keras.utils.plot_model( ),看一下模型结构* h8 R- {! c7 b% @& M! m
' B" \! T3 ~. }+ n6 e
$ @& M$ W8 u, X" v2 b% e
# q. G9 \; y4 G2 n$ {2 V
, R3 K& M1 z2 ]' A; n9 m
6 n; l) I9 ~3 _ r$ b; E用summary(),看一下模型结构和参数
2 O! N4 J* D" L4 i4 U* L
$ T$ h# h# S' Z3 L- H* m. B. d7 g: G
/ b. _% X& [' I- L7 e7 ?3 }0 n4 u" n% a- w9 _
7 {: p y9 _7 G( g W
/ M/ v; i5 |! l使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。! l+ k6 d. z+ |5 [
/ h8 `8 {+ x: o* V& \4 Q
8 k1 G! O* q' G# ]
generator = make_generator_model()
# T3 r; S1 o' q! y8 @+ L
$ A$ P" z4 j$ |' |# inoise = tf.random.normal([1, 100])
- g/ ~" j6 Y: w& b, q# Lgenerated_image = generator(noise, training=False)! P! p3 Z$ Y( f' u5 |
5 ^: ]# @# @( ^) I2 Y) b9 ^plt.imshow(generated_image[0, :, :, 0], cmap='gray')
7 }) H$ \7 c5 |3 A0 W9 n }+ @+ p! K8 \# g
, k6 }" N7 y5 I) e8 H! [% d
2 v4 F% Y1 @3 M: _0 \
) ]6 Q: G7 K% @" A9 r, y9 n3.1 判别器
1 d0 O' o/ ?5 t, m+ M& @8 X2 [1 w判别器是基于 CNN卷积神经网络 的图片分类器。% s9 ]( u4 N1 g/ Z& F8 Y
T3 F' R, e. v1 |6 } D" y
' u7 z7 a$ @2 Udef make_discriminator_model():
- W p# T# N" e |/ M model = tf.keras.Sequential()
+ b7 t( H) \# p, F# e model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',8 w) _3 f& W. B0 @; o
input_shape=[28, 28, 1])), \" T, u4 o2 n9 m( J9 A
model.add(layers.LeakyReLU()). {, Q3 k0 t" b7 {0 }# D
model.add(layers.Dropout(0.3))
$ ` B$ H" G# B9 _2 I4 I7 I+ p0 w; \
g7 H l' A$ C. z* C+ c7 ` model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
. o+ e8 m5 R! ]' X' a model.add(layers.LeakyReLU()); \ Q" i& Y! a/ E# z: w# @
model.add(layers.Dropout(0.3))
' O/ ]9 Y3 F) U, N2 R 8 q s. X1 r) c a2 j
model.add(layers.Flatten())4 Q8 W+ \, i5 j2 O, `/ ?: X6 S
model.add(layers.Dense(1))1 S; Z. Z# H9 `: Y* k5 w
! d8 f; }, Q- j* C return model
$ p; v( V' Q& d# T: {0 j1 B用tf.keras.utils.plot_model( ),看一下模型结构
: Q+ \1 t+ u2 v, e7 E, ^6 I- _: V" C3 N( ^( }
$ b: F4 F! t( D, W$ j- _: ]* k! N
' n6 u) D7 T; Q. N+ |4 ]0 Y
7 ^& R Z- u& T2 J% a; F2 c/ L8 ^ d6 K
% a1 b0 @2 q4 ^7 y: G' q
/ K Z5 n/ W% p2 g
用summary(),看一下模型结构和参数/ O4 O1 S& P V7 A! ]/ D7 u
: n% `6 R& \, B
6 u& e( s( H5 B F
9 w* m9 ^( H2 o. `& U$ n
- X4 @1 u- q! `
" f1 p) g& m: W5 X- I: d
) C! }6 ?9 u7 `* S' A; D4 p四、定义损失函数和优化器9 A6 H- D7 G s) W) o/ g
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。
' Z- D7 l$ ]0 D1 g" K( f
' r; _/ L2 t# q4 _2 @2 u" Q* F
3 w1 p; f: y5 r4 N' W首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。1 [2 H! N; I4 m, o5 l2 l, f, i$ V, Y
5 _' w4 x3 N% W8 \( ~ Z- }
2 Y/ l% U3 k/ n
# 该方法返回计算交叉熵损失的辅助函数( f0 P* O: s' O1 c
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
0 j8 _2 U5 U4 P$ [0 R4.1 生成器的损失和优化器5 q) { R) F; O4 E2 q
1)生成器损失
* k0 {/ i) X" {9 y
+ c3 \- f H2 O& }
" l: e, h( u# W, B& ^生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
, g" S5 ^, w+ z" r( F e3 r$ J5 b* i; o N( U; Z: j
e* Q4 h! Q- d1 F: t. y
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。3 G. M. a$ t' w# L% E7 H
% G! q) a1 G1 P6 i
) `. j' D, `, {% g Y4 H
def generator_loss(fake_output):" k2 C! z. ]/ m& W
return cross_entropy(tf.ones_like(fake_output), fake_output)
) O q& J9 d# }* P! H! R2)生成器优化器
& W- F6 ]& w- f; F* L' |4 c
1 D7 N: D6 o: t! [
9 J, L T3 K( I4 T0 y! o7 v( [# D: a9 i, Bgenerator_optimizer = tf.keras.optimizers.Adam(1e-4)8 {5 L N1 D4 A2 z% `
4.2 判别器的损失和优化器4 l7 ?7 W! I% X! ]# B
1)判别器损失+ G6 `4 C t' g' f$ b' o2 R
, ]; r, \7 s; ~8 p
/ ?/ t3 o/ r; y9 }
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。+ [: s3 D1 G' w3 s
2 S( q) b6 D; O- ^- h Z" ^4 j6 I
. u5 w8 _( R9 Pdef discriminator_loss(real_output, fake_output):
4 t/ q1 d/ Y6 `1 r real_loss = cross_entropy(tf.ones_like(real_output), real_output)
( d* ^. m0 A* z/ h& O1 H: H fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output), u+ }/ u5 ^; V% O
total_loss = real_loss + fake_loss
- K3 s4 p; ^. S8 X+ i) y4 |5 a6 e/ H return total_loss
7 Z, W* ^+ G- i9 J2)判别器优化器8 D) `1 y# v' x. q
' |5 M* ?$ H" R8 m u3 s4 }5 n' l1 r$ h! e* O6 R8 x
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
+ \2 q( H, C# E* D五、训练模型6 R* H6 w2 F. G+ Q4 x0 V0 V5 @# S6 h8 k
5.1 保存检查点
$ Z' O7 L5 H2 D; q4 x保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。1 C3 q. I; p4 ~; W7 q, s
8 I3 b; M a8 i* L' F. s3 |
, N1 m; ^. K: |* M! fcheckpoint_dir = './training_checkpoints'* B; M( K E9 H' j* V
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")3 \' ?2 J$ v4 M) Q+ q' W
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
4 n, ?4 }! N2 F' L3 D6 J discriminator_optimizer=discriminator_optimizer,/ `, l4 O3 ], Z7 r( @. g. n: T, L
generator=generator,+ n: W" f7 j* ~6 j: a6 M- }: q8 b
discriminator=discriminator)
! Y; ^, `+ r1 s! |# v! y3 b9 X5.2 定义训练过程5 l) V+ [; T1 Q
EPOCHS = 50
1 k2 u0 `* e' q! t5 Unoise_dim = 100* F* M3 O4 [# {& _0 C; M
num_examples_to_generate = 16
$ u& q- i# i6 L/ G0 h ' g/ Z( A( v! w( @8 g
" W W2 U' r6 M& {: W( Z) @
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)$ b, i6 \7 N% k: e) s6 g! A
seed = tf.random.normal([num_examples_to_generate, noise_dim])' |1 z o# J) h) H' \/ Z7 L# S
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。. e _+ x2 e; _1 O- p. Y
: ?% @3 X* v1 j; R
3 x5 ~) X; A- _2 I! U0 o2 |# s判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。4 l+ p# |+ B; \4 y/ ?0 ?; R
2 N8 q7 P3 N' d% T4 E- [% ]! y: t( ?9 i1 S$ T, ?5 ]7 s& L' U1 V
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。9 r2 K1 @ w8 P5 d
+ m( l$ Q7 n, C. X7 E0 T
# r0 i9 U+ k5 P. @$ i5 C# 注意 `tf.function` 的使用% v; K: x: Z t
# 该注解使函数被“编译”& f* M) a3 ~) }% \& k8 s- d
@tf.function
7 k& K! r) D$ A/ b% T" {def train_step(images):+ q6 {$ A. c* s& h1 } P
noise = tf.random.normal([BATCH_SIZE, noise_dim])7 Y. a0 C7 c0 s. \* I4 Y
1 p' w4 d$ f% J9 X" Q' g with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: t; ^' s, W/ Q0 E5 ? N
generated_images = generator(noise, training=True)
' ~7 K; H6 G+ | Z+ E + I& B0 [% e5 }6 \: }7 X
real_output = discriminator(images, training=True)
$ M* j4 G E5 N fake_output = discriminator(generated_images, training=True)) V; D& w i) L' T# E+ n2 z% e# F
. R! N7 \2 \ |: k' I5 I9 R/ z gen_loss = generator_loss(fake_output)
: w3 q- O- S8 J7 x J disc_loss = discriminator_loss(real_output, fake_output)/ w) ~- x) N( _% L- a4 s
+ `( g+ D" u# _# v7 U- ]
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables): _% E; ?& U: d6 \$ z- e1 U
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
Y7 i5 t, b: F" @ T/ G" ` 6 e" i& M% }. C% R4 W/ |
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)), S" r: K5 C1 _ Z Q" V+ o
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
' q8 A8 }) ~" C# j
) I; T0 g7 z0 ?7 Q& @def train(dataset, epochs):2 M8 W2 J9 ^, W8 f6 P5 e5 q
for epoch in range(epochs):: A: I: y! U8 G7 _8 O. O
start = time.time()# O* M, ~* E# k8 N# @
% ^; a3 P3 z! u$ V
for image_batch in dataset:) c) E u4 \% c& j' {7 D! B- d
train_step(image_batch)
, l7 T) [: X& l: S6 p4 e, P # z4 y% V# [9 d
# 继续进行时为 GIF 生成图像
* R2 D# _6 ?3 U. m4 e9 o! f9 t display.clear_output(wait=True)
, D ?- B! I! m generate_and_save_images(generator,. g! b3 e$ F( l i! a! z
epoch + 1,7 {( T; R7 M* F: G- W$ U. F4 R8 J
seed)
4 C$ g4 b( P: ^" }
5 z2 k0 l$ x+ h% l6 n. E* ` # 每 15 个 epoch 保存一次模型+ c! k4 m( {! o3 `% A2 c$ [
if (epoch + 1) % 15 == 0:
; a5 E5 ?) }' M& [ checkpoint.save(file_prefix = checkpoint_prefix)
( Q5 w8 a8 ]" t1 X, X5 _8 l+ ^ 4 G. h) x2 d9 d7 D
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
/ g3 Q, P0 r6 Q / l. ?, s9 a, W
# 最后一个 epoch 结束后生成图片
0 y5 o, j7 s$ `, w5 e& G display.clear_output(wait=True)7 X7 p$ T e" L9 M6 Y1 l" m& o
generate_and_save_images(generator,: v5 n7 E6 J) A$ w
epochs,
5 j* L7 d) k# h$ j9 q! u( u seed)& K6 e6 ], z+ m4 \
+ U/ @% K# Z U3 `! i# q
# 生成与保存图片
* W- ?1 Y( X0 |7 Ydef generate_and_save_images(model, epoch, test_input):
' `+ v. U% p8 i) ^4 r1 a9 I # 注意 training` 设定为 False8 m& O* x d1 }) v8 Z
# 因此,所有层都在推理模式下运行(batchnorm)。
/ Q% j8 m6 g. ~# | predictions = model(test_input, training=False)
$ Z. ^" p j/ \
4 z: w; C" A5 c# `3 x' w fig = plt.figure(figsize=(4,4))) O- H5 F, Z3 i' ^0 H
7 T. l& N% s' J
for i in range(predictions.shape[0]):
; A4 y m2 r2 d7 j plt.subplot(4, 4, i+1)
Z" b4 D. | U% B" `& Q; K" k+ O) _ plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
, a$ K$ p/ p. s: |3 h2 ?7 k plt.axis('off')
' y& E; q# D/ \/ a
# E; ?9 y2 {, Q x5 U; X plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
. R6 m) V$ g0 a7 L plt.show()' K2 j7 m8 I2 o3 G! |+ \) x
5.3 训练模型2 i; W' N8 j2 K% h: F# J
调用上面定义的train()函数,来同时训练生成器和判别器。/ D. Q2 k" q# g: E. W
8 G8 z. ]. Z5 Z- `# g0 C) k6 _ p, }5 C0 d& ?/ `
注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
/ }( o5 @% j" }5 O
1 _8 L% q" Y& U; F7 w1 J7 [, ~0 J) ]/ v* W, H: L2 G) `& a
%%time1 k6 m/ B; l8 a- A! g. U9 z
train(train_dataset, EPOCHS)% H5 b+ |. b! a% i% @7 |* d
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。5 `. K! n1 k! R1 P, M* o7 h& Z1 c
* R% |6 V9 H: F2 m3 n! ~
P! N: d* p; \$ h训练了15轮的效果:
1 r4 u9 f7 p% f/ N1 l1 ~9 D
. G# }+ I4 T" N; N) W- F9 K& U# m1 J+ E5 V/ g+ S" c
. Z3 z3 I) W6 H% e0 C
9 ^" X- X, B5 ^
+ R$ i e, K& v; M" L2 M+ Y
, ^/ |% H' `& Z- y9 T6 S训练了30轮的效果:4 S) ^# h: j) [# ]; w' U
/ ?& {$ l4 y) Q6 Z/ }( r
9 K1 n1 r$ t& l0 U7 I$ N+ ?/ H% ~5 e
2 t7 n. O2 a/ w6 {) D" P0 L
1 m/ O3 p7 X5 D4 W# j9 A) @
, R) }4 n* B0 S8 r
- O# ]4 q. r0 w6 B1 Q; R- `: ?训练过程:
! c6 O6 Q, X( r) j' b
8 x( V9 f1 w3 D5 Z9 v0 z! ?) S E" t/ d$ h( @
9 t8 J3 A9 T9 K- [, | h6 W
4 N6 x# [6 ]' M& l; N
; a& T$ ^6 d3 k7 J+ T5 s7 h
2 o: T/ p( x8 }5 n _" `; U0 h8 {恢复最新的检查点" l g( \2 e+ h9 o$ d/ t
% c6 R; U: a4 X- d' s" b H
4 x7 O7 j1 c2 s, n8 ycheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
3 O+ Z( w" e6 h7 N3 V& c% j六、评估模型
' E( n- H% E6 i3 S1 R A这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
4 m. }5 M( i9 {0 b
$ f) V3 ]& r q' A1 e. a) O6 Y7 i8 z
# 使用 epoch 数生成单张图片5 a, ?$ e1 |9 v/ ]' f" k! s
def display_image(epoch_no):) ^3 i' U9 F9 X
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
: n% b ?' ?* ~* C 0 b* H! P3 P0 X
display_image(EPOCHS)
. Y3 A3 }6 U& P! ] ?6 janim_file = 'dcgan.gif'
0 y, d1 `% y2 b* X1 x9 _# E
7 Y' W: n8 e/ k7 Bwith imageio.get_writer(anim_file, mode='I') as writer:
* X1 R9 n- e. ~/ D2 Q filenames = glob.glob('image*.png')$ w* b- a/ O) i
filenames = sorted(filenames). `" R" j- v+ f! a- ?" q; o) A B
last = -1+ o% w* C% n8 r
for i,filename in enumerate(filenames):1 `. ^# Q1 r) l6 E+ [% T
frame = 2*(i**0.5)
1 ]! s$ y8 j" E% ~$ b if round(frame) > round(last):, |2 Y5 h5 t0 X& L
last = frame
/ S& W- k# C' a2 _* ?. R3 B else:
9 a2 ^& ?6 H& D' L, L continue; l/ a! z* g( @6 e+ o& o5 w
image = imageio.imread(filename)
2 r! L3 F0 @% ?$ R writer.append_data(image)
2 n. Z7 s3 U2 Y& | image = imageio.imread(filename); {4 `) E8 Q1 L% h
writer.append_data(image)
% C- H1 v- M; s7 \- ~, ?- T; Y
& ~/ `6 Y+ d$ P, b$ {import IPython _3 G! `5 m- q1 T, V! Y% O
if IPython.version_info > (6,2,0,''):
$ K* F( j! ^9 t9 U% Z6 i7 R display.Image(filename=anim_file)
Y5 }# H N' u4 I
V- q+ H; n h/ Q( S
( }* o2 Y5 z' R2 _( d# O& ` Q% B! n) g, c4 b, a) D8 y# v5 B
6 i( v' C. l) T2 s5 L" T* A% N
完整代码:) @, U# G; R/ j: h' w
5 j3 @+ M1 R; w L/ P
; r: a) ]+ D ], S( X3 b e; Nimport tensorflow as tf8 {; a+ ? H3 ]9 j) V1 o$ Q0 ]1 D
import glob, D. I" I8 A" |& e8 ]3 b5 k
import imageio
8 Q# d% n! W$ m6 I5 c Simport matplotlib.pyplot as plt
1 l$ [; G! H+ ]' ^6 q" Bimport numpy as np
5 }6 P$ S% v5 b8 O8 W" himport os% ]: ^! i- R: g& D p! B* [$ @
import PIL
% _" @% \ z! w5 y! q7 E, Tfrom tensorflow.keras import layers5 s* C- ^$ j8 @
import time
, Q! v+ {+ i; i+ S2 i/ q
' L% ]. Y1 R4 v9 C2 Qfrom IPython import display- K% A; i" y3 ~% d
: j9 c' `. H; O(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()# x2 _7 w6 X! K5 T1 h6 q" g! U6 b
- ]6 n/ {5 }& ]2 w* c
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')8 m/ ~" E; |' W4 J7 U. J
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
' ]9 ?6 R: n! J( X6 `* I' V3 u1 { : q6 g, y, o5 \6 C' T
BUFFER_SIZE = 60000% O* U/ ^ ]+ k% N
BATCH_SIZE = 256
+ T' N, X! Z4 J! P. z; Q ( u8 H! W% ?: h0 V, d( a
# 批量化和打乱数据# B, v3 U4 }' X0 h1 V
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# }* i3 [& G- J7 ~7 Q# j 4 Y( `& v/ G j6 ^% i
# 创建模型--生成器6 `% l, k4 N6 _4 k7 W# n- e
def make_generator_model():9 i3 Y8 b! v7 e1 ?3 k& ^- O
model = tf.keras.Sequential()% ]4 I) h0 K1 `
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))). K1 x3 b; e0 Q
model.add(layers.BatchNormalization())7 Y' ?8 U9 J& R8 o. W/ X
model.add(layers.LeakyReLU())
8 H$ W" ~/ b) }3 b1 s$ [
9 V4 `5 n8 j( V: x) D5 f3 C model.add(layers.Reshape((7, 7, 256)))
& Y" U& v0 o. G4 T4 | assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
+ O+ T0 z( Z, n; N x% Z/ l) }
5 B4 }# e/ C; A" _4 ]" c model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))1 Y6 c: D1 _1 i! M/ E
assert model.output_shape == (None, 7, 7, 128)4 C5 ~4 ~; Y( P$ k( X& ~( R/ n0 n) r
model.add(layers.BatchNormalization())4 n' z, ]8 e8 i: F& P$ a- I9 b
model.add(layers.LeakyReLU())
! y: R2 C8 k' [0 k5 H 5 G* C/ i' Q) ^; {- V0 C
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))6 y4 Y& {9 r4 j6 \/ M6 [; _( y$ [3 U
assert model.output_shape == (None, 14, 14, 64)4 j8 _: D1 d1 h8 M4 |) [' ]
model.add(layers.BatchNormalization()); X' g- X0 A8 i3 ?
model.add(layers.LeakyReLU())0 H: Q$ \. K2 S- Y) q
3 f( y4 n% k7 @6 q. R5 o1 h4 a
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))1 F( ~, Z; p" W6 `8 ?+ A
assert model.output_shape == (None, 28, 28, 1)
$ _1 E5 U9 V5 w: M3 a) p
( x/ s4 ~! w) v# k+ a0 K" g return model
' R" o5 K" w) m7 E/ {3 T, v 2 a* A, }* S: h, e& |" O
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
- Z& A8 w, f: o! Q) ?& {generator = make_generator_model()0 [2 T5 q8 b5 Z2 I _: E
- P7 {9 y; g, L6 A; j9 m0 e) @noise = tf.random.normal([1, 100])
7 X! J( a7 p% f2 Y. v0 X/ ngenerated_image = generator(noise, training=False): R$ e! o3 A4 ^/ U* u
) H& V4 i7 s. \+ C- D7 a
plt.imshow(generated_image[0, :, :, 0], cmap='gray')% e) E- I5 W S8 [) H7 z2 H, j
tf.keras.utils.plot_model(generator)7 F% v& z2 R3 p- F3 d) A
$ D8 t& ], h% C8 S/ s" W: y# 判别器
4 u* ^! U/ v) Zdef make_discriminator_model():
7 n# V, m/ h3 O7 l6 c model = tf.keras.Sequential()2 }" ]8 s5 o8 i6 O6 F
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
% m1 A6 o. s; l3 u! d input_shape=[28, 28, 1]))
. x7 Y( Z( ]5 B6 G( b0 S5 e model.add(layers.LeakyReLU())
' d; e2 ]: `) @) w model.add(layers.Dropout(0.3))0 P# s- U( r+ ^/ G
2 A) j+ X" l( q, o# O3 m
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) w9 b! P, w7 ?( R% p! [
model.add(layers.LeakyReLU())
& [9 G0 @" P0 `( D: b9 p/ t model.add(layers.Dropout(0.3))7 w; F; X! E$ g' I" B" c% \6 t( M+ g
( K9 Z2 x7 h4 ?- z% `) ~ model.add(layers.Flatten())
) |2 J) d: u# A$ V# \) } model.add(layers.Dense(1))
1 \% z5 z2 S9 |7 z 1 A) n# {7 W5 ?% t. G
return model6 Q& K6 K3 ?0 b6 ^
( Y: D; _, K2 ]7 @0 _0 y4 @
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。9 a8 M% }; y# F6 w( x( V/ u
discriminator = make_discriminator_model()
8 p8 n$ W; V: W. \! e8 N' Hdecision = discriminator(generated_image)
+ o+ ]' B$ A* K& F; Nprint (decision)+ \- ?" N, @2 `9 ?0 Z* }
- y/ `- L7 L+ X6 p8 \: O; V$ P# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
8 P7 S( r' R6 ]: ^/ k) B" Tcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True), r2 v" M6 t; k! q. |' {
0 b# i2 X' n2 a3 a' r# z' c
# 生成器的损失和优化器
& U1 N8 E1 p2 g% pdef generator_loss(fake_output):( h; |9 f* u" @ l
return cross_entropy(tf.ones_like(fake_output), fake_output)3 H$ _; `. i7 t% H3 j& l
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
1 J& ^: b9 z4 z9 ^( n/ Y0 c
+ h. T: m( ?3 u- D. i+ C7 B# 判别器的损失和优化器
0 I. k+ C X: A$ i& R! d4 Pdef discriminator_loss(real_output, fake_output):) r6 D* t) o; y+ M" D
real_loss = cross_entropy(tf.ones_like(real_output), real_output) d; k9 a: `0 v
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output), A6 l" u3 h/ R% e/ }
total_loss = real_loss + fake_loss' M+ M* T, S2 G( s3 n
return total_loss
' L- j1 |+ p2 t# tdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4). u: U% I! q3 } j& ^
6 I- ], I3 V' L4 [; e. B3 T) b
# 保存检查点
2 E9 q. j$ k9 C" i: ]% ^ ~checkpoint_dir = './training_checkpoints'$ R$ I% }" R. C' s0 X
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
. g, B8 ?. Z4 g0 }checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,5 N- N; i1 }2 q. G) c
discriminator_optimizer=discriminator_optimizer,
- _+ ~; [0 R3 n( ]% C% j generator=generator,- u1 P+ {% _( O5 L6 h
discriminator=discriminator)
% S2 u, m& X# U' v% k. r- [ 9 c( J% ?5 p6 J' _
# 定义训练过程" f8 d! Q! ]2 |' ]- _& m! W- _' }
EPOCHS = 50
9 T1 [7 C, v: t5 jnoise_dim = 100
( E9 ^+ J9 f. n' x/ F+ {num_examples_to_generate = 168 c; L5 m2 E5 ~3 W1 w0 m) v
+ {9 G5 K/ V Z# g( R: M# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
: u) ?" y V- }9 @% S6 useed = tf.random.normal([num_examples_to_generate, noise_dim]): c! N& u8 s7 n# V5 R {8 j
0 u, u& F* q+ h; M5 i6 M
# 注意 `tf.function` 的使用
- c4 j R' \% z7 A# i4 }3 e+ z4 a/ ^# 该注解使函数被“编译”3 ^9 V x4 J7 W
@tf.function
) j: X# h1 D h3 Udef train_step(images):% s% U, U' y; Z! f, @
noise = tf.random.normal([BATCH_SIZE, noise_dim])( ?) \* C) Q' I l" f/ X
( w2 M6 J, X6 [) c" ^. Z
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:; _( K* a3 j& k' Q
generated_images = generator(noise, training=True)1 n- D6 R/ I b' j2 i3 z
+ s0 r5 j1 J% A( {* G# h$ x
real_output = discriminator(images, training=True)
@' X1 X3 U4 b fake_output = discriminator(generated_images, training=True)
: l2 }- {$ B0 M. O1 i, D+ _ 6 k7 y3 z6 c5 {% b8 d+ P, Z
gen_loss = generator_loss(fake_output)
8 M: o; J* G; x) B$ y! g disc_loss = discriminator_loss(real_output, fake_output)
' X, E( I* m6 Z1 A5 K! i
, C7 u/ M; `. j( a gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)- V) L6 F* |! c% b9 l' Z0 A
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)- ]8 s( V2 C, A5 ^1 ]
7 {- k+ A# F2 W- [ _
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
% R$ m$ u6 k; {5 d/ }2 ]& a" {/ {2 ] discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
, f' \" \- }% Q, ~( z" |; j
* j3 K Z. w' l( zdef train(dataset, epochs):8 G4 m0 t2 V2 ?; u
for epoch in range(epochs):
r8 J. T+ }& F8 K9 a7 x+ C" w% v T start = time.time()
3 U0 X7 F* Z, [- l& V9 [
. O$ l5 L, I: H; T* z. x% O for image_batch in dataset:
* [* n3 S6 Y- z0 @# @2 T ? train_step(image_batch)
6 p1 B; g( w6 w2 m0 ?$ Q 9 t$ O% T$ N' z* |
# 继续进行时为 GIF 生成图像' U! ? X3 u( X' a& X1 h- w
display.clear_output(wait=True)/ h; G4 ?$ f# P( U
generate_and_save_images(generator,
: G& e( P' B6 ]0 {; ?& v5 I epoch + 1,
$ O9 l# l. j* ~0 ^: p" F seed)8 V: \+ l m8 S8 ^: Q6 X% r
3 S! \& D! [5 G # 每 15 个 epoch 保存一次模型
$ P2 l, v* v. Q! }: @ q1 L if (epoch + 1) % 15 == 0:5 p2 w# Z. A% a9 ^/ Q; L
checkpoint.save(file_prefix = checkpoint_prefix)( a5 D7 A2 F7 ?. I3 y0 V
! ?0 p g& t* q% D print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
9 U1 r# p( T3 l8 H( r
9 y. B n5 K% t- X, N+ u # 最后一个 epoch 结束后生成图片
! P" S5 r! t- Y0 M3 R8 J display.clear_output(wait=True)
" z" I. o: v2 j. x; ]* _5 k generate_and_save_images(generator,
" J* v9 L8 v6 Q' c epochs,
! c: }9 S5 w) `) X# X seed)# z" y! l% L w; Z8 c! g4 K/ z
1 O) {; b" A* s0 q' [+ f* A) v# 生成与保存图片3 ^$ h, R" h3 }
def generate_and_save_images(model, epoch, test_input):
# M- I% t$ y V( _ # 注意 training` 设定为 False" o3 B _( H" M
# 因此,所有层都在推理模式下运行(batchnorm)。
4 `- c1 T' Z& }% [# [4 W7 z& Z2 p predictions = model(test_input, training=False)# D3 T1 h. { V s4 `
9 [ K& m& @4 s) X7 }: Q2 H& M
fig = plt.figure(figsize=(4,4))
9 s) W: V+ I) k: E) z3 H& ~( k 1 q! a. P6 W, @* b3 I" n
for i in range(predictions.shape[0]):
7 N3 Z( l6 ~0 e( Y( ?) ^ plt.subplot(4, 4, i+1)8 E4 D. h: @' @7 f" d* [: t
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')4 X- a, d+ T/ v+ \) m
plt.axis('off')
& K8 j! f/ d# C' o3 k* Z. }' ?* [1 A , Z4 U* _) l$ l- O4 I' s
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
! x( J* B" K& `- e% h plt.show()
7 C# `/ k% k2 a2 q
3 C7 S: I, n, a& ~$ ~ O# 训练模型
; e* |3 Y M. S: k# O* h# O4 utrain(train_dataset, EPOCHS)
! x2 B# L- }! U, X1 n- k, `4 T& b2 W* } 1 Y0 H9 M' I7 t: C( e
# 恢复最新的检查点
3 V3 F9 P4 A! t' a* kcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
, ]& j" x& }7 h# T- v
( p& i/ _# w# l- i7 T# 评估模型! G. t8 A6 H; p+ t7 ^) T8 s
# 使用 epoch 数生成单张图片
& i8 ~0 ]9 p: z/ V6 e5 A2 ]def display_image(epoch_no):
l; T( |; K( R6 y! u( z% m( g return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))2 m9 V6 U% z0 n$ G2 [4 R# d& ?
& s0 O. _' X2 C$ a# h" V. W- G
display_image(EPOCHS)
5 }! ?3 s$ S. o, N5 c# o1 h1 t ( G1 R U, j: [% b% |6 r# `
anim_file = 'dcgan.gif'& P, p( c* q) ~* g k+ a
$ E. I1 y$ A! |" k. Y! s
with imageio.get_writer(anim_file, mode='I') as writer:9 H1 H9 D" W' T4 {4 m
filenames = glob.glob('image*.png')% {+ `# l) N$ Z P
filenames = sorted(filenames)
, ?% n; q$ w) P" g: t- U last = -1
* |2 r2 F! h5 |. U4 S! m1 p4 \" i for i,filename in enumerate(filenames): L/ F5 o9 u! I( o5 z9 X
frame = 2*(i**0.5)7 g2 b' t4 K$ W6 Z: r' R; @/ c B
if round(frame) > round(last):1 L+ h: ]8 q) c; T. k6 b
last = frame
3 r. C7 B9 n8 u else:1 Z1 Q, p5 V2 r7 F
continue
" {* |/ l, e+ i" }0 V. R image = imageio.imread(filename); X7 Q- v, b5 L, b2 v' u
writer.append_data(image)
" F' L, p' Y% b3 i) K) Y image = imageio.imread(filename): ]1 C/ j: O3 j4 ?" Y
writer.append_data(image): C! n4 k* n, e4 U/ K* S& K# d9 [
7 u; {+ @& F0 I5 E2 H+ a
import IPython
1 ~$ x4 x4 _4 d1 ]3 Yif IPython.version_info > (6,2,0,''):2 W2 F. X4 U8 ^- v
display.Image(filename=anim_file)
9 s, w% u' _) K9 K ~% Z" r, @9 ^* F参考:https://www.tensorflow.org/tutorials/generative/dcgan
8 F5 }! i$ R9 y; b————————————————6 z* r4 D" {- }4 A( c
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。, c ^5 x- j) O e
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
. i; Y1 H) p; D# \8 m d4 l( v6 b$ e9 m8 I6 z( _
, g0 {% L2 H c. \) p
|
zan
|