- 在线时间
- 1630 小时
- 最后登录
- 2024-1-29
- 注册时间
- 2017-5-16
- 听众数
- 82
- 收听数
- 1
- 能力
- 120 分
- 体力
- 564691 点
- 威望
- 12 点
- 阅读权限
- 255
- 积分
- 174630
- 相册
- 1
- 日志
- 0
- 记录
- 0
- 帖子
- 5313
- 主题
- 5273
- 精华
- 3
- 分享
- 0
- 好友
- 163
TA的每日心情 | 开心 2021-8-11 17:59 |
|---|
签到天数: 17 天 [LV.4]偶尔看看III 网络挑战赛参赛者 网络挑战赛参赛者 - 自我介绍
- 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
 群组: 2018美赛大象算法课程 群组: 2018美赛护航培训课程 群组: 2019年 数学中国站长建 群组: 2019年数据分析师课程 群组: 2018年大象老师国赛优 |
# O& d3 R$ D' T" y) x深度卷积生成对抗网络DCGAN——生成手写数字图片
) X! A4 x& Y5 h) |8 p+ k. @- N前言
; I* {) k7 ]/ O& i, |9 S本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
0 o4 b8 g7 A' I. @2 k0 ^
/ ~) d9 k) e8 ?/ N; C, U y: J# z8 x5 |# H; |
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:& Q* P ^& T) d9 C
9 \: p$ D0 P3 T6 {, E* Y; R
7 A! u' M$ B* H; ^) q0 j# 用于生成 GIF 图片, j0 D6 G/ X/ B! L, G
pip install -q imageio# D7 p% }+ A. n" y, G- i* O, y
目录
- u# G" P/ X0 k; ]
# E# v# f: x1 h- q. |0 b9 V+ a* b- p5 P8 W/ v x+ o2 ^$ S
前言0 N% V% C: |2 [, R2 h8 W2 w' s
! R, P4 n8 A- f# s
+ A x3 v3 F% f8 n" E: K4 a0 d一、什么是生成对抗网络?5 x# W T: F- F
6 A2 m2 U6 u) T8 W3 K# ~: Y, q. O3 J, t0 w- P
二、加载数据集0 k6 @* H7 f% T9 p
) H) X/ p+ W6 z; Z& b9 S
z2 j! l; p- a" ~- E: o% T. {三、创建模型
1 P( U( A3 T% j ^, c7 `0 D# }# H( _% o' Q9 R& }
- \+ J# i U; n% l3.1 生成器
5 d" l2 y* Y @* g9 z
* b5 H# E+ H+ `2 M! S+ U# p! Q# O" }/ X: {, N' Q7 ~% ~2 M
3.1 判别器, K' w* ?: W; U" M! _
& C: G% w V# L# R1 a1 G
, N& s* Q0 C5 [( |( b/ v四、定义损失函数和优化器
1 F; d! X" s: `2 {# |7 I& @6 M( {) E' D
, P, Q# L7 k! f3 f+ T. h
* y5 M5 K. \0 t- k! j4.1 生成器的损失和优化器
* Y2 i; ]( O: i' N1 v& a$ X( _* ~
% k. B! P z0 B5 o4.2 判别器的损失和优化器
* x+ l# R. d5 a" P+ Y8 s/ k! Y$ G3 ` B4 ~
8 }* b: p! e: d/ `( O) B( A3 p
五、训练模型2 E" B+ d: A% v' B7 n
; g! p$ | G# Y+ P
# w( N0 `4 g& o/ b) K; m; l5.1 保存检查点2 o+ r' y+ _ ?2 M9 R0 x
2 s/ i( h4 _! Q
: _; m. S8 s* p* w! ?7 Y5.2 定义训练过程$ q7 h- l* }2 D, l
8 E; n9 S" F; P+ F% E8 o% n+ {2 `; P( K8 |2 ?) c: N6 ^
5.3 训练模型0 J9 j' ]! `8 T- V+ X
K$ ?. j' x# J4 e+ G' e
! s2 c8 J) B C' M5 b! c- }, q) ^六、评估模型
: F. I, A; o$ u) f
. i. |9 |" K; g+ e0 ]- J+ o7 b6 g( i3 I, K4 g
一、什么是生成对抗网络?( d$ i3 i$ o1 m' W! c+ m1 }) U
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。8 t# X* c" i4 m/ y$ ^6 s
# p4 r* C8 O6 J% X, z# v8 _- n8 m
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。8 N7 W2 I# A( T# G( [
$ x3 z2 ]- g% }* S/ k
4 ?: S# s+ S1 V- D判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。! U; W7 d$ V, u2 a4 a8 k
' l' g) X9 _% Z' g
+ t! T" ~( w* j# }% G训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
# b4 z5 d1 d4 }; _* y1 ?1 N
0 r; ]' m/ }- C* X' z m2 d; v+ |' u* I+ l) R
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
1 H0 T* h: }0 O! F
9 T8 R- M+ V( u. z5 h2 Z1 }: ?4 Z& @! e- _. H
本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
8 Z) \1 @; E; s) L9 L7 M* |7 v; B+ g! m
7 L" s) |1 S7 S& C
2 n/ R+ ?& U: i/ A; e" g7 ^二、加载数据集
; m% b1 m" r5 H5 T1 i, T使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
+ v: _9 o0 f( f2 S( K; H1 F% L5 U) e
1 G2 T% |2 G6 z" K& ]2 a(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()) }( k8 L" B5 [7 M9 q* B8 O$ v4 y5 a
4 d" _. d: o9 c8 q
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')5 w# r5 W- t [1 D' Q- M' U% _
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内/ K9 i9 q( Y9 F/ \
) H# U+ J6 q/ ^# r
BUFFER_SIZE = 60000
2 D* h: w. C1 b# n0 @BATCH_SIZE = 2569 F, O; }+ w+ ^
/ q0 H5 z: Y& n6 \' g L; d
# 批量化和打乱数据 |8 j9 S6 ^& J* A4 }. F. _/ F
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
* X: R" u7 h' N8 u三、创建模型% ^& x& o* E# Y8 S) ]. B- }
主要创建两个模型,一个是生成器,另一个是判别器。
) C9 X' e. l" o7 O. Q% J/ ^0 }7 M
! X: C6 X |- e. X* N+ |9 J+ M' L |3 c
3.1 生成器
3 t% H y3 F' y w5 Y7 H生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
' \# W8 M9 N- \7 d4 S" o3 U: E% S5 L9 x
# \9 h$ l5 F7 T" T: k( r/ W# U然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。! D; J$ y9 a% F, T' b7 C4 g& t
* [0 o) B4 O& m
$ l" u3 U% n( O; h
后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
9 u3 H7 p g) }1 W7 o1 x6 G9 f% R. K7 N2 M' T( d d
( W' P" x# L5 @/ e" H, `* Mdef make_generator_model():. B- e# t8 |& D+ T( v3 L( n; q" ], y$ B
model = tf.keras.Sequential()
, j/ H, B1 H/ ~- ~+ h" k model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
) v% U4 Z k# @& J+ H* m model.add(layers.BatchNormalization()), R- G3 e! N8 O1 t* n$ M2 h
model.add(layers.LeakyReLU())
$ \* R+ }( b7 @3 C2 e3 C2 ~% M
9 w9 a" @- ^ |! I$ { model.add(layers.Reshape((7, 7, 256)))+ Z, C" a5 f8 L4 s* v
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
7 R" E+ z' x7 f2 l, p: }) J& Y H! a: p
7 _! Q" E% {7 O0 v7 ~; o4 x model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)): o4 v! ]3 H- b8 O3 p2 ~. M
assert model.output_shape == (None, 7, 7, 128)
7 y5 w/ c7 {; m5 p% E7 u; s9 O+ M7 V5 f model.add(layers.BatchNormalization()) M( L1 t P5 Q5 ]: Z6 X4 p; A
model.add(layers.LeakyReLU())
* X% Z2 G6 [. b# U; H* @, U
: ~2 m+ ~, F" A- j model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
% n3 F7 |+ x$ A assert model.output_shape == (None, 14, 14, 64)
, b4 C, W: w ^% r model.add(layers.BatchNormalization())
9 a/ L. z X0 X' [8 Q model.add(layers.LeakyReLU())
5 q" m: ^) j# g' H2 Z
0 `2 X! |9 w; g v model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))8 N8 y* m# }8 s1 t" f" w- [
assert model.output_shape == (None, 28, 28, 1)+ _& \) M G. u2 q
0 h- F% K3 Y/ M. U1 R& m return model5 @8 o. c0 w2 W) Q% S
用tf.keras.utils.plot_model( ),看一下模型结构, D! D# a. n) {6 B) }
8 Q" h7 x c; \- L& {. P8 i2 V: K( s
5 [6 v3 u$ Y3 ]
9 U3 q& Z2 |' h7 A5 i, o3 {
3 M0 k1 W9 f+ F3 D3 ]
2 D/ p$ J! H) E& ]: u9 A7 w! j" X用summary(),看一下模型结构和参数
9 v0 b/ s% t) \8 r. w* e
: [) K; l3 f7 z* z2 g6 k. A; V* j' W; k# ^; y
/ S1 L$ Q* I! R2 C# q3 B
9 Q% r5 {: S& t2 P% h& e8 {- t
: r4 z0 q( Q ]6 W9 q/ Z5 N/ ^0 W N6 B. \* e/ l( N8 N0 ^1 y+ r
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
/ \9 W/ g G+ F. V4 Y* n+ I, V: }% X$ g3 p1 m Q" j* j6 }/ ]
. ]2 l5 ]9 r f' J1 }6 Mgenerator = make_generator_model()- \* E4 f! \6 j7 c P3 D' J {7 @
9 k( n2 c& R; w9 i( ?0 Q# ?
noise = tf.random.normal([1, 100])7 U& z' P5 C/ y7 L ?4 _
generated_image = generator(noise, training=False): @0 w$ T* \$ {
" K6 S5 j6 T" O6 E; Nplt.imshow(generated_image[0, :, :, 0], cmap='gray')6 u" p$ P3 E( k b
' e4 F& T% c7 Q6 e
* x( w9 P9 a$ v2 K+ D2 S
6 J# T; E9 {8 `
1 O. C5 E2 e- t8 w I- D1 @1 u3.1 判别器3 x+ {" |& h, G. ]
判别器是基于 CNN卷积神经网络 的图片分类器。3 k& G" O& v1 E) C7 d
) p! I5 }& n6 s- r% s0 M! O+ A: i/ s2 G4 o
def make_discriminator_model():! o- o3 t& H7 e% B( W
model = tf.keras.Sequential()
& R0 x4 |3 Q+ T9 g7 d6 F model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
e1 m- E' k7 W3 w input_shape=[28, 28, 1]))
! @ R g2 i7 m8 E% S3 ?+ V model.add(layers.LeakyReLU())
* m; ?$ A0 a: ?0 O model.add(layers.Dropout(0.3))
- P+ O% ]/ }5 v% y# Z# D
) w' u0 }$ L H# ?$ s- L Y y model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
9 b3 F5 f5 z2 a model.add(layers.LeakyReLU())" b5 M9 u. |2 V, @ f
model.add(layers.Dropout(0.3))
, s+ U+ m p$ O0 T; h 3 @; [8 f( t# U& o3 C
model.add(layers.Flatten())
$ f0 v1 _3 Z9 |7 o model.add(layers.Dense(1))% v A& k6 R# y3 m
" ]3 c; k. S( y0 S( r- {
return model) j+ C1 p& N; i% W9 r$ Q1 U
用tf.keras.utils.plot_model( ),看一下模型结构
5 a9 \* q' `8 Z" c" `& q* [! m( d9 q# _: {3 U
) P/ n% I, d# y8 S# M) B. s, d' i; E+ U
% R: K. L b$ x0 Q5 {* m( I% a
/ L: O* d: ?( P# _8 b0 D7 g9 I: I% [" U) k4 ]
用summary(),看一下模型结构和参数
$ u# A+ D, e) f( a/ g1 c7 s5 C8 F5 @7 ~0 ?
! G: K3 x6 J/ _
" J! P5 I* b1 I. t
& @2 n5 g7 T z7 H0 i3 ]. u2 U1 S
% W1 P2 S& s6 b' ^8 z# s: X: e! \7 @& t1 S% d
四、定义损失函数和优化器, k: `5 S* r0 z9 b {/ ^7 C
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。
5 n; l2 e' N! l1 M/ P( ~% r/ ~ ~$ G( J! Z
# M# q/ C5 }- B% w& O2 B% I% Q- g3 Y首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。- g F, k: ? I8 t
) o9 K3 m* m+ M& x: I; k0 \
" n9 B! s8 n& C3 Z* [" Q# 该方法返回计算交叉熵损失的辅助函数: z/ p$ H; I# V( {* v7 Q: `
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)! |5 K2 m( g4 S8 j. ]3 w
4.1 生成器的损失和优化器5 m" [" V6 q$ ~. c7 K
1)生成器损失6 L/ I, h2 c; x: S6 O+ k: |
! T1 @9 F" d/ C
' `, A. b) Z, Z" v1 V生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
& e( Y5 j v2 E: P }4 u7 Y( [- U( K; J! l' l- Q
% s+ V. W3 @* B) @' y这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
7 ]4 H8 C/ M+ U9 o, J" w& S3 B$ p8 _9 v
' X6 m+ ]/ {; j! H( Fdef generator_loss(fake_output):
' Y% m8 q) z( j/ M. D return cross_entropy(tf.ones_like(fake_output), fake_output)/ {2 Z6 H; ^3 j3 t7 @1 O
2)生成器优化器
X# U" {! R: g! S# W
8 C5 j7 x+ B; d4 d9 ]9 D# z
2 Z' i2 G) s5 }+ dgenerator_optimizer = tf.keras.optimizers.Adam(1e-4)$ ?; k" \% w- y/ \" l8 _7 `+ b
4.2 判别器的损失和优化器
. J- S4 [; K+ R6 I$ W) C( p1)判别器损失& H) M2 X+ d; u+ `# I/ M
" i" ?7 ~1 z; L3 [6 a% v* n7 f
' a) \6 e: C. r2 ~判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
; g% ^- a _& \4 g; G( I- H. q' |; B2 S; u9 f) u1 s H2 W6 K& N
+ Q2 T1 e0 l1 Q# C, P/ _ s
def discriminator_loss(real_output, fake_output):
+ P9 s+ |- n* W% t: b real_loss = cross_entropy(tf.ones_like(real_output), real_output)
: ?7 Q8 S' j5 O. C' W M/ ` ^ fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)7 z. N- _$ w7 k
total_loss = real_loss + fake_loss
3 s4 }- f C; g ?; L return total_loss
]1 z3 @) I! @6 [( r0 f; e( F7 Y2)判别器优化器: ^0 p9 _/ M( u4 D3 ]
?0 P" Q3 |1 k, w" V7 v! a. k
9 f5 o2 r) ~& r5 Z& {3 u3 R
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
: E7 D9 t( o1 ^' D6 m/ T五、训练模型! J U" b* }& _+ G: h- k. ?. T% i5 j
5.1 保存检查点
: _0 \( S n- f! j; o# _" r, A保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
+ W' S/ N8 r$ p- I9 u; ?: K" V R! P; u! d' ]$ l$ @2 ]
4 K, G2 A% g* `0 z/ f
checkpoint_dir = './training_checkpoints'
" Z- D* U. A( Kcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")" E# s, J# H8 v8 c H( ^5 @0 I
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
" C5 ^* I6 P) q* Q discriminator_optimizer=discriminator_optimizer,3 n1 n* K/ k2 V4 I+ E2 q5 t
generator=generator,$ m& F3 \! z8 Z: V6 Q0 G/ P8 E
discriminator=discriminator)
- n2 k5 [9 a7 H! D( b5 g$ c6 M f5.2 定义训练过程
$ R/ j/ o- [0 E6 ]! ~EPOCHS = 501 e$ P9 N7 p. ?$ }* p- w8 L
noise_dim = 100/ [+ b b0 g; n( I. R( h7 p- ?
num_examples_to_generate = 16: T: j7 z/ p- ^* @5 p$ r0 N" t/ k
7 y, c7 ~1 E) U: D2 D8 I
: g/ [( \; i' h4 I3 F' e# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
2 N7 J/ p+ J- \seed = tf.random.normal([num_examples_to_generate, noise_dim])
, z/ l5 _4 G% g: O8 ~, V训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。( K+ \' _, o" ]. M& v$ m4 ^
% G0 Y q: C# V+ G0 c! O4 x+ D
w* S; d) u7 w* Y8 o判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。+ n7 P9 B4 D: T$ b3 m( t+ o$ J3 c: L! o
# i0 k% P* Q, C7 f- F3 r
! f, }9 B/ ~2 j F& q' F4 Z% T( D
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。( @3 d4 L6 T8 ] Z$ B u+ c6 a
% U* ?6 H5 ]4 T; n. C" v$ F) D Q7 ?+ j8 O9 I9 D: `6 t3 j/ c
# 注意 `tf.function` 的使用
. \. h: v# d8 Q& C* p& c; b# 该注解使函数被“编译”- I2 K) W; ~) Z4 F
@tf.function0 s3 }5 {" H2 E4 r7 z
def train_step(images):: ?. E$ X3 t* a' [9 V5 J
noise = tf.random.normal([BATCH_SIZE, noise_dim])
3 c' ~+ X) U3 X/ ]) d
/ h B) ^( l+ C) l9 l& j- s- Y5 C with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
. b& N! \8 l* ?8 L7 d$ M generated_images = generator(noise, training=True)- a3 c& L$ s3 }8 W& |, D8 C
0 d" _. C/ }/ }6 T, d% p) q$ M
real_output = discriminator(images, training=True)1 `3 X3 Y# Z- ~0 N
fake_output = discriminator(generated_images, training=True)2 H5 `) {; X1 _; l% h
K7 T/ [1 q' T6 z1 z$ r gen_loss = generator_loss(fake_output)
. E" E4 z) i7 R3 m% H* \: u disc_loss = discriminator_loss(real_output, fake_output)
! u: ~3 z; a* A. C, @
1 [! i; N @8 z W* Z* C& H [. d gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
9 I% V" ?1 _$ g! V; @ gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
. s) M2 u0 t/ c ( y- K4 \4 Z5 t' m- G/ s
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))+ u; }3 I/ B% `" O1 O: G
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
8 }$ x3 \5 u9 Y
8 M0 ~1 }* f+ d d6 Z! \0 kdef train(dataset, epochs):
& ?, W; B, d5 i( q for epoch in range(epochs):4 V/ Y% k, {; F+ _/ N# W8 i4 j
start = time.time()
: f$ v4 _' r* a$ k' Y
( G% Y6 }! X8 C: z1 P* L2 c$ p! y for image_batch in dataset:* E: f8 O8 p) l& U" ^
train_step(image_batch)) Q! X" B4 D6 E! s+ P. A& ~7 {% n
+ k7 C- W: Y8 E& `) v l
# 继续进行时为 GIF 生成图像( h: d9 x+ G" e
display.clear_output(wait=True)9 q% H. x) }3 z
generate_and_save_images(generator,1 ?; n! I u1 [( b# d6 e
epoch + 1,$ z& ]% C7 l, Y$ j- U
seed)
2 C1 B7 O2 t2 \! `* m3 ~8 `2 ^
% h/ k2 g5 d) W3 t9 |& H # 每 15 个 epoch 保存一次模型1 Q' z, l4 E9 d) I. s
if (epoch + 1) % 15 == 0:, B8 r& E7 O5 T2 _" Z! o! R) A
checkpoint.save(file_prefix = checkpoint_prefix), c; @0 T- j2 P( c7 X: Z" ?7 D$ v
+ n0 Z3 {( K& U+ t
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))4 J' Q; m& c' Y2 s% z2 Z5 h
/ K9 x% W1 S, s5 W. a$ u9 ~3 l
# 最后一个 epoch 结束后生成图片/ k- L6 ]0 s. ~0 k; h6 A( V
display.clear_output(wait=True)2 ~4 A3 e/ ?$ n2 t0 ?* \. Z H
generate_and_save_images(generator,
Y( |. v2 c5 q+ F( | epochs,& Q' E6 _. X. [9 b; c
seed)7 ]2 C9 @5 k, p) f; g4 E9 |
. n9 K, d4 P" z+ ~; ^6 M# 生成与保存图片
+ ?2 V4 C' W9 z1 L/ @def generate_and_save_images(model, epoch, test_input):
/ d! s; |/ a* f$ L: R' _$ E # 注意 training` 设定为 False. ~) j8 q+ Q+ N" ]9 @5 m
# 因此,所有层都在推理模式下运行(batchnorm)。
9 z! g# j* W- W# F2 v predictions = model(test_input, training=False)
0 u B/ ]- q I# ^4 { s8 h# ?+ N$ i
fig = plt.figure(figsize=(4,4))
, H4 @7 H7 x; J2 f
/ V- |+ L# S/ T2 P7 T3 q. ]* S for i in range(predictions.shape[0]):
$ y$ h y4 W) J# a3 R plt.subplot(4, 4, i+1)8 i5 z3 h! ?* ?. f5 O
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
* }2 t* c5 M1 z4 X m plt.axis('off')" p/ P$ i! M* h: F# k
% c7 I x6 j1 x$ q/ o$ J! r
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
% u' ^$ V% ?- R, D, L6 D5 v* n. {) m plt.show()9 B7 R- [2 a+ i4 A7 @8 S
5.3 训练模型8 f3 ~" H3 J V% v
调用上面定义的train()函数,来同时训练生成器和判别器。
2 ]5 }% @# n# t- x4 U2 i( U- o x3 ?
& U* r5 s9 ]$ H a9 f' |7 i注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。8 _* [7 k$ b4 w/ f+ z0 L
4 v2 u7 f* q& X' w C
6 ?$ D0 ~3 h9 v
%%time6 p- c$ |/ R& i% f) c H# @
train(train_dataset, EPOCHS)* D. Z, V2 i, S& K% x8 Z
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
( Y; D, ~# ]' H4 G+ G' W
- `* _! b6 B q8 `' L2 D( S& @
* `" b$ x! x: ]4 W: H$ U训练了15轮的效果:
& O; q% s8 s4 B3 O3 Z7 r" u% \7 J3 d E9 Q: h. F, [% H
9 E: m4 p; |8 V' Y. t5 B
$ [- X0 U$ x" ]
& \; V. Q- c( D9 U
; n. q9 `$ f& z- k4 ]9 S I. y
训练了30轮的效果:
. h4 h. i( t, n+ Q
5 x& C, W- k0 \% E% M! X( Y1 h9 I+ F( a" M. h
4 E& n2 h8 T; R i: \# T! _) n( A+ {' Z) Z& o$ J- X& s
& ~6 i$ c$ o* E2 T) d6 J. u
. f2 x1 i, F! b1 O& k4 [. U! z1 h. `
训练过程:- P* D8 |8 S3 w5 B: @7 a
2 o8 i4 ~2 j4 W
0 o7 i/ h T3 ^. @3 K0 b0 i+ m8 [0 d
$ F x$ z e e8 c' q _
, @( d1 z* I- T5 D9 T
% @3 m6 y( r8 o( f* ^3 n6 _/ o1 H _+ R9 g8 m, x' R
恢复最新的检查点
A- u+ V+ B7 w% i0 o$ M* G' L# D- ~
; |6 k( U( k" _6 @, ^
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))' z ^. S: h6 W Q& @: R
六、评估模型
9 |' t2 M2 w, g1 Q4 A这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。' _5 ]' Q8 Y% Z- u5 i( b* }+ T+ o
: Q) B, ]8 g Y7 r
" Q8 t6 g9 l$ ^, Y3 Q+ V" A# 使用 epoch 数生成单张图片
8 K' g; g# K7 Y+ Cdef display_image(epoch_no):5 U @( G8 U7 f
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)). ^. |7 A4 i9 g: d$ f b
) g$ S1 {$ l: i! q7 @display_image(EPOCHS)- g D6 o9 X; Q9 Y- S
anim_file = 'dcgan.gif'
7 i4 n+ ?- i: H4 L: g
* h; Y/ j2 T+ J0 z4 Jwith imageio.get_writer(anim_file, mode='I') as writer:; E# _! D4 W' Y2 f y# `. g. {
filenames = glob.glob('image*.png')# U `* h2 q# ^
filenames = sorted(filenames)( ]( l/ u/ s# D3 `1 w" _3 ~
last = -1
) k( B+ u9 F) J H for i,filename in enumerate(filenames):
; U3 V5 W8 d' D5 N" y frame = 2*(i**0.5)8 {5 A8 u9 }3 r, @% S) Y, u1 C/ H
if round(frame) > round(last):5 m" ^ Z! Y' h. C7 e- S q+ ^
last = frame
& t+ w. d; I" q6 k else:
F+ M t% b9 q- V continue8 ?4 X' t7 a5 \; A" {. X5 s( z
image = imageio.imread(filename)
( `7 {8 q/ i% V writer.append_data(image)& c+ ~ T1 }$ x; P
image = imageio.imread(filename)
8 S) W( M7 y- q! T writer.append_data(image). A+ E/ G$ ~- Z% W' P
, c: M6 g a# r, F* v0 G5 f/ L" S
import IPython& E" V- a, y" w4 i! a. {3 ~/ o
if IPython.version_info > (6,2,0,''):5 r3 G5 i$ o) ?0 o
display.Image(filename=anim_file); H* ]4 B. a( V- p: ?
2 d& k- C: U; [* F9 N+ f( T4 X' n1 U; ^: R! b
- M4 y* }5 S* N! r* n7 x( @" D
7 r& ]* \7 G4 T+ \9 C: g$ f完整代码:
9 |; U) v. U! @+ o/ }. d% ]
3 `; c# w5 S. ]
8 K6 S( Y5 ~: M2 ?2 `6 ]import tensorflow as tf
2 E$ q) @# [# A' _2 Z4 _7 V5 `import glob
4 \6 n3 [8 @- {" J8 L9 r; |import imageio9 S$ U1 @2 d4 W1 T B
import matplotlib.pyplot as plt
7 h0 o9 R3 o; ]import numpy as np; u9 ?5 D) D0 n0 d w2 I$ Z1 U) M: [
import os
' H+ j* I. e& B' Nimport PIL$ Y2 M0 k: [: J/ v! M9 K+ b
from tensorflow.keras import layers
7 z N, F6 v6 p! s& Q0 Bimport time
J6 N: \, J0 W: \- g * N0 l' |+ ^3 {/ V/ V
from IPython import display. n8 }& M" |5 ]3 P& w, f- p/ j- ?/ u0 Y
) c0 [) `% F# z3 ], a* q/ u% F(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()7 s4 I4 y" K' P' ^) {1 Q% d# D
c' S" z2 B: v, \( _5 vtrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32'); u' X. o) d* H4 B$ X+ ]
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
5 D" D9 _( ]4 u5 q+ x ' @( ]: t" Z J' X* c; A7 [
BUFFER_SIZE = 60000! J5 L0 L `5 Z- Q8 s
BATCH_SIZE = 256
+ B, d1 M9 r6 i+ |
2 @( a; t8 w" b2 B$ E$ y5 Z2 E' \# 批量化和打乱数据
& v6 d P! v* ttrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" G v4 o3 A3 O2 T- u
8 ]/ Z: Y4 N3 `2 g7 a# 创建模型--生成器
0 R, H; R4 `: Gdef make_generator_model():
# S2 T( L: y$ y: h$ X model = tf.keras.Sequential()! m8 S9 [) K R9 R; G
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
& u- d1 F' b' c model.add(layers.BatchNormalization())
$ N1 b2 v) @" E! j# o& d model.add(layers.LeakyReLU()). z% j6 ?: U" K
9 c7 H7 j, M) m7 c" D' |) z! b
model.add(layers.Reshape((7, 7, 256)))& e) D" Z' \0 C, E1 u
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
: g k+ r* F$ A# {$ \) q2 o
$ b5 f2 ~# n5 ~* D5 e+ ]( f model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))' D' Z/ V8 F3 _$ e' p" L
assert model.output_shape == (None, 7, 7, 128)& y4 m& k) R5 t z7 n
model.add(layers.BatchNormalization())7 _3 e' _. Q0 k! Q( |1 z
model.add(layers.LeakyReLU())( c% i1 s+ U0 K5 \ Y7 I- q
7 g" z+ t. O# P/ a1 y
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)). ^, I( q: y+ _
assert model.output_shape == (None, 14, 14, 64)5 W3 ^6 Y5 V: v9 N3 j9 x
model.add(layers.BatchNormalization())/ e p5 F) l' E/ k, r: F( S
model.add(layers.LeakyReLU())2 k/ V' G1 Z6 q: o( J1 _' F
+ ~; u. |' q# M; X! |; `' P6 F( F model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
Z3 o* q6 p% s$ J0 ]( p assert model.output_shape == (None, 28, 28, 1)( z a$ N2 ?+ [8 R2 f5 R3 i
4 T" g" N% [; r2 x/ Q return model$ B& i! d- h1 h* D1 J' @
5 u& Q/ ?' a6 b# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
8 O+ D9 K+ P+ bgenerator = make_generator_model()& K( k( U3 J8 H9 ?5 ]
4 A$ }/ V2 X1 Z- enoise = tf.random.normal([1, 100])
- ], G F6 [' @0 M- P: D9 Jgenerated_image = generator(noise, training=False)
: C& U; k- h/ c7 W3 M) q
: ~0 X7 |3 u3 s! l: e7 yplt.imshow(generated_image[0, :, :, 0], cmap='gray')" L/ G" N% ]" C
tf.keras.utils.plot_model(generator)* o- i% z4 { a3 f) }$ E- ~% x+ ]0 b5 E
. `) y: J3 b0 t0 O X- D# 判别器8 I3 e5 G9 G# @( e
def make_discriminator_model():; f: z! r% i# D/ S4 @1 j5 v0 l0 W
model = tf.keras.Sequential()
1 |3 b& [, i. z7 S& L6 b7 t* G4 q model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
" S! U. y( a' F input_shape=[28, 28, 1]))
2 a1 Q. z, C( v! `! h3 g model.add(layers.LeakyReLU())
: l' N. n1 A# A8 S) s) [0 G9 P model.add(layers.Dropout(0.3))+ j4 X3 T9 {( x; S7 N' L, s
) N3 L% Y& m' E5 I! Q model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
/ `2 N3 h+ X _4 A/ I4 D6 D' A model.add(layers.LeakyReLU())) e2 Q: k5 o5 Y1 {, u2 B
model.add(layers.Dropout(0.3))1 P- x# X G1 l4 p
- u/ C; K# W: d model.add(layers.Flatten()) K' W' e6 ?5 i# K4 _+ [$ f/ }9 j! F& v
model.add(layers.Dense(1))6 x0 X$ R9 E# m; |! p( T+ `4 O/ A
$ C) |7 @( O: c" t7 P, r, | return model
, N4 G0 y: G; y
& T0 H* U5 }/ ?' t8 P# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
! s: Z) n4 U" z9 G r0 R+ sdiscriminator = make_discriminator_model()
# a( U9 _1 Q( Kdecision = discriminator(generated_image)$ V* I+ D. P7 e8 S7 i
print (decision)* ^0 z* {$ Z: s1 S2 L# w( I# o
3 G: T$ _. }- Q0 r' G; @# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。* \8 ] y+ |' a* g
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
2 `% v* ]/ u3 [' t$ ~6 `
% v2 U( R" X5 ?9 j' P& b* g# 生成器的损失和优化器
# B* V* T5 U4 }) P6 X: ?def generator_loss(fake_output):
4 W; u1 p( W# W. V+ I return cross_entropy(tf.ones_like(fake_output), fake_output)
+ n9 O9 k/ H* F7 `% V: ?generator_optimizer = tf.keras.optimizers.Adam(1e-4)/ b* { k0 I1 j% K# D
( U- g1 f2 p# e) G! W+ P9 z
# 判别器的损失和优化器
0 [4 c' j( R4 D. T( ^4 vdef discriminator_loss(real_output, fake_output):2 i ~7 p4 o8 `& D0 j
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
~' Q) l; e! |# `6 a' C fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
- R# l' D0 c1 g total_loss = real_loss + fake_loss
) q b- [* I0 V- I: } return total_loss1 X7 O! }$ a2 H# F+ n8 g
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)1 D% E6 M2 J' e
1 p$ e( J& H, F
# 保存检查点
' B( V5 G1 b& ^! G% Acheckpoint_dir = './training_checkpoints'
0 `% U" J) [+ M6 D& G Tcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
% ?, }% J2 L2 X, |% x) ^. O5 icheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
- P2 y4 P3 [7 O$ ^- V, W discriminator_optimizer=discriminator_optimizer,; P9 M1 X' ^+ v% s6 k/ g {8 U
generator=generator,- B' P ]5 A# m7 a
discriminator=discriminator) a- c B# Q2 p2 r
( K7 @% N- n6 [- v2 I+ F( W1 L( C: O
# 定义训练过程9 u: z1 {3 F& X
EPOCHS = 50
" V/ l+ W6 v, x q2 i. A+ ?6 fnoise_dim = 1007 b: d6 }8 m$ i6 t: I
num_examples_to_generate = 16; N$ {+ N* P! i" @
# J2 z9 ^' P& I
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
/ Z( e% `9 f/ f' bseed = tf.random.normal([num_examples_to_generate, noise_dim])
]4 u2 e9 J' K( l9 ^
2 E3 ?) {- z. I# 注意 `tf.function` 的使用+ n( x: [$ W) j, _; f* }
# 该注解使函数被“编译”
* O4 w; y/ s; O! p3 n. O@tf.function
3 d; U6 w( i( T+ ]$ Zdef train_step(images):
/ o) k5 J. x o& t/ q noise = tf.random.normal([BATCH_SIZE, noise_dim])
8 M I, ~- f; [5 h3 J
/ B( d7 E4 W+ V8 H# ~) L" d) ~ with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
2 h, W( X1 F0 j3 k2 h generated_images = generator(noise, training=True)# j" ^5 ^% M- \, R
' F; c) i. T0 o- p6 ^4 O real_output = discriminator(images, training=True)
9 ~7 \7 ^; o! i2 c( }' ~ a* ^. n7 W fake_output = discriminator(generated_images, training=True)& }0 E1 N5 K* j4 \) s1 D! z1 q0 D" I
# w3 [' f- Q! {" G) } gen_loss = generator_loss(fake_output)5 k' v+ k8 N+ r6 J/ }
disc_loss = discriminator_loss(real_output, fake_output) e: D- T0 n0 {' y5 I
( C2 i8 P" z0 u* \0 @( { gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
# t; O" D6 \1 Q* I gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
. M2 _, ?6 r+ O. F9 z: Y ' p9 m$ Z2 [8 t" s! ]' ?; W
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
) j* ]7 j3 t" H3 h+ j8 \ discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
+ r/ n# w6 b2 _ 8 d( c9 L# Y7 |# f+ }
def train(dataset, epochs):
) Q% T& L8 E% {6 m* C2 t for epoch in range(epochs):) C9 _, B- _6 \0 Z" |# b
start = time.time()# {% e1 O0 V& I n$ `' v# z
6 S1 B. B2 s5 A) R7 u8 v2 ~ for image_batch in dataset:
7 W7 B4 ~8 c/ Q4 J& {" G( W train_step(image_batch), @! |0 z5 _! g6 C) d5 U& n: d
: R; D: ]4 F( c: ?) V
# 继续进行时为 GIF 生成图像3 X' t% R0 q& D' W- k4 _+ t
display.clear_output(wait=True)
8 N+ @2 Y5 S2 M generate_and_save_images(generator,0 }9 K1 V: h# }+ }: o
epoch + 1,( O, g; b Q+ t2 Y- [: {
seed): T5 f0 s+ P* m, \$ ?
0 {% T2 t: t R4 @ # 每 15 个 epoch 保存一次模型6 p3 o# p e8 ~7 v3 m, ?
if (epoch + 1) % 15 == 0:% }) Y& k y' e' Q3 g6 W, @
checkpoint.save(file_prefix = checkpoint_prefix)& y8 G; v- w# s8 F) `2 `
- T, I& U1 F5 k/ _3 @2 V8 t
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
% N6 A8 F7 J7 I6 c , `% g1 Q+ p; G* p- S) v9 H e
# 最后一个 epoch 结束后生成图片
, M$ i0 o8 ~& d% A8 d display.clear_output(wait=True)0 R1 n5 W; q# y, A8 p" S
generate_and_save_images(generator,
! C F$ L3 J; j" q* ~7 | epochs,, H6 ?1 M6 Y" ?* @9 u( `5 n( L5 v
seed)$ L4 _4 e* h3 Q; M
7 g- A. c' Q0 e m7 |# 生成与保存图片
9 m/ N) J! n' _4 K% Mdef generate_and_save_images(model, epoch, test_input):
/ a' b( P: |; M ]3 ~# I # 注意 training` 设定为 False
& Y/ R6 G7 j% i7 L7 E # 因此,所有层都在推理模式下运行(batchnorm)。
8 Z: i% S# @' X. h$ a predictions = model(test_input, training=False)2 w( z' K! i" e! ~
- M# o# ~) r. i. s fig = plt.figure(figsize=(4,4))' ? l6 F$ m* a- x4 L: m* j: F9 \
* I& E' H5 |' e9 G& ~
for i in range(predictions.shape[0]):" _/ b) b3 q) S- u. U1 y
plt.subplot(4, 4, i+1)
^* ^( S* j% |) y: H plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')# o6 G8 _6 g) A5 d2 z9 l
plt.axis('off'), K0 m, W {' t, U8 p! S
" y' ~! h1 O \& m% j/ [* X
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
3 K, M/ X }- V3 o: w! h plt.show(); d Q- M: ~8 e) P
+ p5 R8 H6 M' ?- x- {: o# h0 J1 r
# 训练模型, E% H" u) X P) }/ Y2 c4 b
train(train_dataset, EPOCHS)+ Y a; C) q4 p$ f2 k- z( C
' D; W' S. x8 M" B" m' J# 恢复最新的检查点
c4 r; x5 I+ N) {checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))9 U' e7 h' t' H+ z9 F" a
% `% h2 u; D: Q/ Q- r# 评估模型0 d2 g4 K C* B1 g3 g$ C' Z& b
# 使用 epoch 数生成单张图片
# F& I8 Y) c4 D2 p* a! B. fdef display_image(epoch_no):" x R% H$ B1 ~+ @
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))# D, q/ A0 \6 @( `- f
4 R6 z0 f1 i; u/ Mdisplay_image(EPOCHS)$ q3 d' _& X$ t+ P
& Q$ v) b$ c4 v) nanim_file = 'dcgan.gif'
4 T5 t% Z' {# }+ s+ b ( E0 B; U8 a. z: @. H2 v
with imageio.get_writer(anim_file, mode='I') as writer:
5 B8 [4 J# E& [/ {" K+ q5 X filenames = glob.glob('image*.png')) Z: {1 w0 Z% Q3 ]# n
filenames = sorted(filenames)( ^$ V+ H6 o% y0 b$ {$ _
last = -1
0 h3 \( s' g( c7 E) `8 O' n for i,filename in enumerate(filenames):
8 L+ O# R+ @* L( E/ f frame = 2*(i**0.5)
5 G# B C; T* u% s! _7 c+ R if round(frame) > round(last):2 ]$ T; _2 c; }% [9 e5 v
last = frame
, S% ~$ _: g8 c0 @ else:+ c8 v8 i( E- ?+ r% s% B& W
continue# N* h3 J$ ]% u6 L& A( b8 Z5 [1 |9 a
image = imageio.imread(filename)/ S3 q2 \& d& a& a
writer.append_data(image)
# R- p2 V/ w0 v; }& k image = imageio.imread(filename)
7 x- p- _' ?5 W8 d$ M6 S3 _ writer.append_data(image)5 u( G+ x' q+ f9 f* ~! k
8 }& Q5 T( o% _% |1 \& \
import IPython( i/ m! [% I1 g6 U# y, @6 O
if IPython.version_info > (6,2,0,''):& X ]2 m/ g9 M' j: d7 R' T/ o
display.Image(filename=anim_file)& }$ E1 H6 h; h2 N. N
参考:https://www.tensorflow.org/tutorials/generative/dcgan# R$ b9 W; V7 ~, \1 c* k
————————————————6 b0 C B4 x' @
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。0 s% p+ s: N# J( q6 q3 Q3 R$ e: S
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111$ ~. v$ u6 l3 @
% j" v' G: M4 V! ^/ j/ L
3 y; P" M. f/ D
|
zan
|