在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 555911 点 威望 12 点 阅读权限 255 积分 172146 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 18 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
' v. R' d' O/ Y) h 深度卷积生成对抗网络DCGAN——生成手写数字图片
2 f, D* ]+ P2 I 前言
9 ?7 i+ ^; N1 S2 Z7 e7 f 本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
, j* N* F. A/ J$ r; W9 y, }: u 5 f- Z- E+ \$ |: I8 M; ~8 b
" `/ S2 e' s: _( r! s. ? l D0 E 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
4 }4 m+ }& z% B' w% z2 L " L( N1 T% b6 ^! h$ d$ ]
i# s$ \$ |& k) b # 用于生成 GIF 图片 $ W9 B, Q& Q7 I% y, w
pip install -q imageio
+ u3 Z' ^& ~/ r' I6 r 目录 : B, ?7 V2 j* Y- g
* n% [% h# P+ y+ @+ ?" y
. R% w. J6 S7 R9 H1 r: x 前言
/ o; X0 Z. d3 ?# X7 G2 }1 b* ~
. [8 m, F9 o5 h/ s2 Y
. j4 }- Q6 A9 x v. Y% d 一、什么是生成对抗网络?
A1 {5 ?$ e* B6 k" E
: ?8 {5 m" c! m" ~
6 ]; r8 ]" R# s/ e; W5 X: } 二、加载数据集
8 ]: ]% ?, d1 }! m' ?$ D ! P+ M* }; T% L" r- g
6 {$ [% C$ [, V3 x" Z& \: L8 c3 v
三、创建模型
7 b" a) `6 J6 n 4 r% ^3 [$ M4 B" _ A
1 m- k N3 y: r
3.1 生成器 ! Z+ t4 |6 D* a b# E. X
. \9 j- I ^& e# W/ w0 k
! |: v- @. U! ^
3.1 判别器 H8 x' i/ y2 g( C: J
' n9 ^7 N' ^, M! \' V4 x
& t! f; D- |* ]: ~# t3 x* O
四、定义损失函数和优化器 % \- e+ G$ |% ?6 I( |$ d
) F; a4 w+ p3 J5 ]
. u8 d- }1 l; N% M+ o* V
4.1 生成器的损失和优化器
. R% P( z( I) s 8 I8 a+ I U1 }4 b; H6 l9 S9 \
" ^6 n( p4 R$ ^* o! U
4.2 判别器的损失和优化器 / e# N7 V7 P" r: @9 _3 f" {8 D9 `5 x
8 W' P |' r5 V2 c0 Q' x' U5 b1 o
5 ]( L# s* P# I- ~9 O 五、训练模型
& g+ I _2 `) l R# N " |: n# @2 R1 y& k7 t
' }% O) _+ U& x5 s 5.1 保存检查点 6 A! r8 b* q& N& u% r
( a$ Z( Q/ |2 ?7 v/ L
" p: k& C6 v1 b0 A- R8 N0 Q 5.2 定义训练过程 . E% `6 z' e( `( ^( {2 I% t6 \
; e9 a# j6 N+ b; I# t9 _
; p. k4 @7 r4 ~( m 5.3 训练模型 & V- b5 Y# A0 o+ C7 g9 W! T
) Z% k1 i: ~- Q: Y& C9 l, |
2 C- ?6 U0 V2 | 六、评估模型 5 x4 t0 @; L# |7 y; o
7 ?! u1 u2 B+ L+ f6 C' ^
/ Z( o. _4 V3 t- Z* [' O
一、什么是生成对抗网络? # N! a: C9 M* q8 K. b, {& \$ V
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
5 R4 u! [7 |/ W
?* R/ \* t5 D0 {4 w, C
! Q6 g9 I: |+ s$ r$ \! ~, I 生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。 1 C) p4 k# _( e- Z. d- s) t
5 z; e( @0 @8 j2 ?3 q
1 F* k0 U1 T" D% V T
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
6 C, l$ G3 b) ?" F; F5 q& T
9 E5 W1 B) @; D# w $ j1 }' e7 t% S8 U! U
训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
& e# l; ]& f W: @7 ? $ [1 z6 u2 w/ M" D; G y. j* C
0 H- n Y6 `6 e- F/ r! c! |; r 当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
3 Z( a: C. ^8 o% }) V! h. z $ A: V1 V. R3 a) r2 j' a: {9 T, l
5 J" Z: ]0 A5 ] U* b7 U' @ 本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
; r# G. \, F4 y2 M% {# e 1 E# a( M; U9 c; l
( I w! W5 J5 }/ P" U
二、加载数据集 . p0 ?% ?7 q2 R a
使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。 ; T+ V I+ z* q5 o( ]; n
) e7 z/ P9 ]& I# P4 O; E/ G2 d V
5 V- u9 m# J6 |& V2 Q1 S (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() 2 V- s/ i7 `* H) ]0 x$ X& Y
- ^6 U4 J" {* o: _( v. @$ O train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
3 C6 H6 q7 W8 c/ i; w* v2 s) @ train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 0 S9 ~3 A$ I+ h0 |+ R
& ]/ e+ k$ |0 K4 b& I% y0 @* ]& N4 Y BUFFER_SIZE = 60000
9 {0 U: o* ?2 R0 F) Z BATCH_SIZE = 256 / b J+ u9 t$ V" E2 l! t q
% k; R9 r* Q4 F # 批量化和打乱数据 : o# Q& K* e6 ]* }1 V9 A
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) d6 z/ q% z% S( I+ h9 U& t' Z
三、创建模型 + E* h5 `# L- O6 |+ i9 p
主要创建两个模型,一个是生成器,另一个是判别器。 # r" h; z( u6 V1 r# p
8 d( [# t/ l d4 e* d: r% l P: `
% _& ^! x) E: o- T 3.1 生成器 + s1 L4 E. _! q$ f# S
生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
5 a2 j; ~9 K3 i* g4 H& K
! o# e4 P" X3 h# N% L. _
/ U& n% w0 b3 b& W! K7 m4 L" M 然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
6 K/ _0 `7 N; _) Q1 w% T 8 t( ~; m: ^" r) \: P, J# k
3 k; n s/ ]1 K! u 后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。 . l$ d; n `. X
0 M3 n8 b: y( u' G) @5 a; y ! T: {* \" F0 A+ B" z
def make_generator_model():
6 K+ d, M/ G4 `/ N7 X8 r E model = tf.keras.Sequential()
4 m* B9 @; b1 Z1 T8 a model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) , {) z } V* E n: Y
model.add(layers.BatchNormalization()) 6 \. G( H/ q( |, @
model.add(layers.LeakyReLU()) . @" C& V3 s2 Q5 Q; H
: d$ q! w8 I. Q$ j model.add(layers.Reshape((7, 7, 256))) + ?6 |8 y; P: w/ w. D. p3 Y
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
. ]3 }( N- \* w% M9 p + F: a- l" A( l `: m: X8 Q" q
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) # B" [% {2 a3 E s* `( q7 K; w$ J
assert model.output_shape == (None, 7, 7, 128) 8 b( ]& |& O8 h6 I! |
model.add(layers.BatchNormalization()) 8 U: Q" [6 a5 }
model.add(layers.LeakyReLU())
. Y; x" F( @6 w" _$ [, p ( E! ^' L7 Q5 H* s5 t4 D
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) 0 s+ R7 I( w* R% c) O; U
assert model.output_shape == (None, 14, 14, 64)
/ V) J9 [+ h. |! s model.add(layers.BatchNormalization())
0 _. m U/ c6 L model.add(layers.LeakyReLU()) 4 M# y( o* X1 {" ^3 F7 O* R, M
- t- Y7 D8 z* _* X" Y
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) 2 X1 j( n- R: g6 U; v1 `" {
assert model.output_shape == (None, 28, 28, 1) * j. X- E/ W3 K/ g2 C# w/ H1 x
) o: E% d( w) M return model 3 ?4 M4 o5 k* C
用tf.keras.utils.plot_model( ),看一下模型结构
" X: k* [; Y) B) ?$ k& C, ]8 b# M
" x# k4 z" I; J, x: [) ~9 C
3 E" y# p3 O8 c3 @0 A6 {+ g
1 J" K. O: I( p9 I9 |
: t( A4 A: n$ Q% v; N: f 3 D1 y& S6 [1 C- s( Y! w7 u9 Z B
用summary(),看一下模型结构和参数
4 `( P! q: n4 g+ l, ?' J0 M
* o* }2 A+ b R t / n8 A# o @$ s% z; m4 D2 }
; F3 b1 g8 h1 c: h; x, Q # a/ c+ V4 @7 z0 n
7 P9 P+ A8 g: }
: ^- i& g! w0 F, \: I. m 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。 / E6 c7 u/ y" K* _5 y6 ^" i
1 v- k9 y' x- M4 ? ! W+ k9 J {/ k7 x
generator = make_generator_model()
7 d) w; ]/ c) B; Z8 _6 g + Q% O% \2 X. e2 a& S v) p @
noise = tf.random.normal([1, 100]) ( P# A: V# y/ c' ^6 H
generated_image = generator(noise, training=False) . }+ Y) G( R; O/ ?# L8 |- I
1 U3 D# ^1 ?0 P$ M) m2 K plt.imshow(generated_image[0, :, :, 0], cmap='gray')
( }7 i( ^+ b# |* k$ b! T, Z8 @
- J; a/ u3 B1 j! `; Q" k4 Z/ e+ @ % T* H' Q5 O, _' N8 A; G
3 [# u: @2 @1 x
( ^6 y: I) Y& x 3.1 判别器 - W& R% ]" ]7 t5 |! h. U X
判别器是基于 CNN卷积神经网络 的图片分类器。
, j- v$ w* Y- N' K6 k# a7 v 7 Z a! h& v6 p4 q
" ~* t2 W' Y W! \( |$ @2 A
def make_discriminator_model(): 6 r& Z: E! N9 B r: Z# T
model = tf.keras.Sequential()
2 P- w1 K9 B# R- b+ d. s+ H! H1 ` model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
5 s" ^0 m' k! g- m ~3 E# J input_shape=[28, 28, 1])) 0 I' z7 F7 {" L' V! Y
model.add(layers.LeakyReLU())
/ \" t+ c0 H! G; J; i2 O model.add(layers.Dropout(0.3))
0 C- Y# o( h6 Y# n; T& B
$ R5 N0 m1 S) {: D6 w7 B model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
- B1 H0 p. I T0 S5 c2 G model.add(layers.LeakyReLU())
' S1 l8 C/ a6 p4 [1 e: @5 A model.add(layers.Dropout(0.3)) ) h: c/ F, V" h) Y+ |
, C+ T: Z {0 c+ ?% B model.add(layers.Flatten()) ; s' {% K& V9 i7 q* V4 i1 ~# m. @
model.add(layers.Dense(1)) 3 w& i" F4 h' n- ?; m9 v: T
- z0 {5 P8 M. V return model / I8 ^: D, P3 r2 c) s
用tf.keras.utils.plot_model( ),看一下模型结构
4 p8 [6 c+ R# q& j: Y
1 ]( T/ N. Y- C
: S/ R9 ~- z" K4 l0 ~7 A8 B6 j" D# F
4 n& V. u" b K7 F: F% P/ G 3 ^( g: U3 g5 b8 U: d7 D3 a
, j" m/ a+ F5 w l& F
1 e( Y7 u" z/ N- B# E 用summary(),看一下模型结构和参数 6 j# _" p5 _( Q g* R3 V
@( b3 {5 g- X) i: R
7 }7 s* i9 v! a( D $ d: Z3 x, {0 z, s1 T
; n; W8 n3 \; @! W5 e& J0 v( e
& M J! q+ `& [& D2 z4 n 1 l7 u5 ^$ N m K1 q
四、定义损失函数和优化器 $ n7 V5 n1 A" i6 R( ^( _. d
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。 2 I1 Y1 d; e0 L6 L# ~/ U5 a% z: \2 k
# j6 T6 K9 Y9 d# R- k : t- k6 h# a- b7 O3 M
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
* w$ b3 i0 {! e5 h$ g* y% x 2 }- Q8 q& V2 n1 h( `/ ~
. P K/ B5 e" j0 Q: j
# 该方法返回计算交叉熵损失的辅助函数
# w' I: ]( i) k! q' b cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) ( R; G' j8 d, Q! w0 S( K
4.1 生成器的损失和优化器 4 q3 n4 L9 U6 v/ i3 W2 i' D/ x
1)生成器损失
8 D% D) T7 i/ ]3 i 3 p6 s( p0 q S
% J) k2 l2 h6 D
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
3 Y2 B7 r' E. h. F# c
7 ~8 e D6 q. z! C3 \8 \ $ ^! f; {6 I) P' k( `9 d
这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。 9 r9 v+ H+ {! ~& D1 [" a Z2 F
& r. e0 f. i1 T {1 w
, [! m" Y$ v; v S* I! z+ F def generator_loss(fake_output):
9 T, Q6 G0 p' k7 _0 j! O$ Q- L return cross_entropy(tf.ones_like(fake_output), fake_output)
3 j; i* U. x9 a0 l1 c; U- A1 i 2)生成器优化器 ( P, W# q) ^3 e1 s0 }8 h7 [# r3 U( b
4 Q6 [( H$ w& E, y; [$ i $ @* |7 B9 Q. Y
generator_optimizer = tf.keras.optimizers.Adam(1e-4) 0 ?8 E+ V- N) E; j$ S" g" i8 K
4.2 判别器的损失和优化器 $ k' p. `3 t. C, i r( h
1)判别器损失
# p4 t9 }0 Q1 g. e0 f / H9 Q p) H6 i: _3 U; K7 l
, R! k! E4 @* p/ Y1 C$ i
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
, b2 ?# B& Z# E: c + Q& t. s( U/ j/ {- C; r
3 A# V' c; Q9 E# C z% N- r
def discriminator_loss(real_output, fake_output):
$ O1 c4 z8 N: @8 u5 T2 J real_loss = cross_entropy(tf.ones_like(real_output), real_output) . Z$ @0 I) s o8 k8 B
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) ( J, h1 Y$ Q8 R) ^3 I3 G4 k) u
total_loss = real_loss + fake_loss f& [6 r* F4 M+ ?* {
return total_loss
' d+ D6 D. _" a" ? 2)判别器优化器
1 F5 b, |* P7 {7 ~
+ q# c5 j; b; Y8 p * f: F9 N; t1 q1 Z
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) 6 ~# ^- U$ J( i) W) N" I1 }
五、训练模型 % _5 R% a, d, M8 m8 @7 L3 W
5.1 保存检查点
) Y* Q+ M, R/ v/ v* v 保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
0 ?1 ^# X* ~+ Y9 W" [- a
5 \5 @9 G! n7 x- S' I! e 2 U. |+ j0 Q. V, O R
checkpoint_dir = './training_checkpoints'
8 R! J& y) H/ S8 {7 m checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
9 Z& b& d8 f( s checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, # S# h# x+ B# v3 @4 }6 a
discriminator_optimizer=discriminator_optimizer, + R% ?( o& m$ ]/ M
generator=generator,
7 j. t7 y5 n3 V o! X0 O# h discriminator=discriminator)
- [' S w; ]% Z 5.2 定义训练过程
0 E9 P% ?: V; h7 D EPOCHS = 50 6 p4 U6 ]5 N+ k2 x
noise_dim = 100
- ~* p0 E7 U( q$ u; {/ \( { num_examples_to_generate = 16 & X6 `) c, j4 k4 G/ ]8 @
& O3 E- `( s. b
+ x1 Q4 \$ r L+ O; Z1 T # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) # W! V- \2 V; q$ A8 e6 g# @: ?
seed = tf.random.normal([num_examples_to_generate, noise_dim])
3 h/ |1 C! x9 y5 w6 n/ W 训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。 2 b' M8 G! `9 T# L6 j
. w2 R8 W$ _: d) Z
+ x3 J5 |- N# f3 G6 y* c$ n 判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
5 p6 p* P3 L- D
" N0 a% [% a$ G1 p0 {
$ L# H& W: U: y% a( x5 ]8 S/ E 两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
9 h! X5 x6 ^5 J: U6 H % `% v L8 P9 R: b! m* r' g" R
( N/ S" m+ J$ d( U' S. C # 注意 `tf.function` 的使用 * t0 W: H+ A4 c, f- z, c% S# m
# 该注解使函数被“编译”
% n* l8 q2 K. j/ v! x* z% ` @tf.function
& N: Z2 W' b' a# H# ?# [' r& ^ def train_step(images):
: T9 U7 n a/ C* j$ q; H1 X noise = tf.random.normal([BATCH_SIZE, noise_dim])
9 E1 c5 L* ^; x& p, U * s k5 u9 P; u. g3 d6 i
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: , @2 D3 T* E k3 B/ o
generated_images = generator(noise, training=True) 6 i9 |& @6 _: M- F9 f
7 z8 a' b- u7 s: T real_output = discriminator(images, training=True)
0 d) H* h' Q: } fake_output = discriminator(generated_images, training=True) 6 Y- N, p, v! B( |; ]% Z; Z
0 D9 f% |! W' @- R; y& Q gen_loss = generator_loss(fake_output) 7 U4 C$ }) b/ w" T* u; _( u
disc_loss = discriminator_loss(real_output, fake_output)
2 q8 F: a. G2 p5 ~8 a% M' }3 S
! P5 l* _: H! R; B% ?* D gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 0 z6 ], i7 U7 h4 t* w1 ~
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
. j7 T( `( k0 X1 ~ X. K 0 @# ^0 l" O+ `7 X
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
( x9 F# X( e+ s% i9 Q2 v discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) # c: _, f U1 R' q- H2 m. T9 a
; B3 s+ h' e$ V- q1 |+ g$ w def train(dataset, epochs): , I' `' |! V8 Q: Q+ }' g, k
for epoch in range(epochs):
( \% B5 p; M' P7 K& \& H" l start = time.time() ! N! \8 D3 i9 c( a+ k
: v# J0 T+ p; S: s8 c for image_batch in dataset:
3 W- h+ v7 s- W/ d" h* b train_step(image_batch)
0 |, k1 Z f, S5 m/ o/ ^2 D$ G 5 L9 Y( h! `/ m, v* Z N5 g5 O$ Q6 i
# 继续进行时为 GIF 生成图像
4 o; ~; D I/ A3 d6 Z display.clear_output(wait=True)
& O' R2 c; h8 L' n generate_and_save_images(generator,
! h& Q h( U2 W. y- f epoch + 1, ) G( ^9 A1 Q3 d( K8 U* `7 T
seed) # [# P7 t( e6 A$ O7 m0 V. P
% u3 A/ q3 W2 u! x( F # 每 15 个 epoch 保存一次模型 3 B6 f. p5 r2 X+ ^
if (epoch + 1) % 15 == 0: ( i/ A7 B8 y7 w- [: t
checkpoint.save(file_prefix = checkpoint_prefix)
* I! ~) Q1 L8 T$ Z
+ {, G T+ \& l" ` print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) # v# f2 c. s$ a( P7 _3 Z9 u; d
8 \; c2 H; X6 C! Q" i( _: r; g
# 最后一个 epoch 结束后生成图片
: L' N) I* L2 Q display.clear_output(wait=True) ) h% z i' k7 u$ q/ ~
generate_and_save_images(generator, Y4 I* n' }. {" d5 p, U
epochs,
& a- P0 ^' P6 c. R seed) . q& D* ]" ~' `! C/ Q
8 \! h0 S3 U! |! K; c8 ?, c8 a
# 生成与保存图片 + @8 c' a: ^9 T
def generate_and_save_images(model, epoch, test_input):
U4 t* H5 \% a' z. X, x+ O # 注意 training` 设定为 False
) i# }4 [" C. s. S4 z! g9 G% ` # 因此,所有层都在推理模式下运行(batchnorm)。
% W0 I9 b, b, V, F- @ predictions = model(test_input, training=False) 1 W9 s- s* F4 ?/ E7 y6 b5 i
. k3 o. T9 \2 B$ g* e
fig = plt.figure(figsize=(4,4))
2 J0 d# D( d1 q* w: W: a- } ! d0 C# Q1 j: N- c# t# ^
for i in range(predictions.shape[0]):
5 o/ }; g; h+ U/ N$ C' G plt.subplot(4, 4, i+1) 4 ?+ t" \2 @# S7 q) P, i) V
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
/ h7 [$ n% E5 j1 a, n plt.axis('off') 4 M$ v* K9 N% k, a4 _" A5 ]9 \
- f: d" \, r# i" Z8 M+ O* g% Q
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) $ D; o$ n7 l- C {! T% k1 a
plt.show() * H5 _7 l1 _1 h! D0 P8 M
5.3 训练模型 0 V% k2 Z. J. l/ U7 \7 v) ~
调用上面定义的train()函数,来同时训练生成器和判别器。 3 @6 y3 b! @* Q' q
8 U7 H9 d) T) o; E: c0 K- M1 K" W
4 m4 v u9 |, A% @: V/ L5 J* H2 W$ z* G 注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
$ w |, J& Z1 g$ I6 x" q $ ?0 B4 b: `! t w- |- v9 `8 ~
/ ?( ?! _ d0 z0 U" b %%time
, P& E9 E `" }2 Q" n1 z1 n! C train(train_dataset, EPOCHS) u) H$ Q& g) l4 @8 G a
在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。 8 Y" L6 N4 x l# H; ~4 _7 z* s
5 R7 Z2 m% J G
. p q3 u5 S; c8 O0 N 训练了15轮的效果: 2 Y1 Z1 _$ a& L7 z7 W) a
' [3 Q6 \4 R* g; f! e3 J" ?2 x5 t& B
/ t$ |* P( |; _) C
# O. ]3 r" o$ Y: H6 T8 C9 ` # b- i6 [1 O! s4 P& h2 \
! Y9 U$ G" o6 R( u7 R" ]
/ D8 E% X* i* k1 o 训练了30轮的效果:
8 K5 i: e* b/ ?) s 6 B$ T0 ]" z/ n [
% z7 ?1 y9 [ |& a' L/ ?
?: ?$ A/ G) j/ U, O, e; n3 \5 u) D
2 f+ `6 z+ t8 r6 L
# @- b/ T; L$ ^" V& @ m( y$ y ) R- P8 C' D) e2 s
训练过程:
( X' y* ]$ y8 t- S3 Z& Y
; R* e0 [. I5 H4 `" ~2 D
! D; a( M2 _ D9 J8 A& j9 n- T$ ]4 V
" h. J7 s$ Z* g : Y3 j% F9 U# Z, Q+ `0 A) R
5 E8 U" |5 ^0 a3 O, z/ q) { % k* h+ q0 D+ c+ H S6 N
恢复最新的检查点 / d6 V; g: i( H" _" s* s
' B8 y7 Q; V- p0 m$ Y' `4 o
8 N, c9 U" a Q4 s O; _: u8 C; i
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) 8 f/ I/ f9 b4 g% N
六、评估模型 " k8 t8 Z# I# {, ?( g# `2 e
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
* W! q6 t$ ^' h: h2 y( c
4 o6 A# P% X9 P$ i9 b* @1 v+ {% d
+ L( f2 t# f, P6 \ c' Q # 使用 epoch 数生成单张图片 + D& U3 ~( `- b
def display_image(epoch_no):
2 C, l6 e$ a; J8 f return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
6 l2 d* s3 [" ^- m9 C
! y+ g" B: s! {+ y* E# O# \) G8 H display_image(EPOCHS) ( l6 c1 a- q# i+ |+ K2 T& T/ Z* d
anim_file = 'dcgan.gif' % e( a/ z2 J: U4 T
, M, V, S; @& y/ N5 g' [* ^ with imageio.get_writer(anim_file, mode='I') as writer: 8 ~; w2 }8 I+ P& \2 Z5 Z) u
filenames = glob.glob('image*.png') . z7 ~! a) w, t5 U* z
filenames = sorted(filenames) " v9 w, ]. t) c) A; e5 f6 {
last = -1 * S0 m/ D% ?3 e1 T2 O4 y1 V9 j
for i,filename in enumerate(filenames): 1 X- \6 c8 c+ \
frame = 2*(i**0.5) , I1 E0 y8 P3 g/ i* V
if round(frame) > round(last):
- l0 @) a1 Q+ \/ `% s last = frame ' `2 f# z4 ]7 M- x' L4 [
else: , I5 }! T& J- {" l7 K, c2 Y d
continue 3 Z; Q9 m9 |3 Y
image = imageio.imread(filename) % C9 P- a j( }5 o9 [% Y$ m: t
writer.append_data(image)
& }1 Z" P5 Q/ g: C6 u6 o image = imageio.imread(filename)
3 u# r! C! j2 | writer.append_data(image) 5 k% E* k0 {* v1 D
) { {+ e8 e3 j; h import IPython ' z+ k, L; t" O' e4 |3 S$ z: j
if IPython.version_info > (6,2,0,''): 0 B! A9 r& N3 s5 B9 ]5 u
display.Image(filename=anim_file)
; F# f4 x/ A5 c! X- e
8 Y5 A8 K; }" ], O 9 D# Q4 `5 |% |4 N
( s; |4 R! @: h5 _
* h: i4 |6 E7 O* p 完整代码:
5 M7 B% a" U- H % w9 N9 k7 G3 p( X' I8 A+ ?: L
! `# Z$ T; h. [( \1 |( o5 C0 R
import tensorflow as tf
$ A" S9 V( E" @( P( p5 ^2 ] import glob 5 b6 @+ o) F# ^3 M2 T" I4 d4 ^
import imageio
Q- O# z. H3 Z. d5 Y" O( [+ t. X0 o1 U' q import matplotlib.pyplot as plt * j" {# o" ^) c6 m* h( ~
import numpy as np F8 v; X Y3 [/ h
import os
* h6 u; {& B9 r+ ^+ i import PIL 0 |! o5 \3 i. ~; i9 X
from tensorflow.keras import layers 6 l) z8 U4 Q2 O: i1 |
import time 9 r& `) C2 K* p$ f% ~8 U
/ w' \1 F. Q$ o& }# C8 \
from IPython import display
$ }' X/ L7 @& Q 5 u* N E7 o( L) s K: {
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
4 w$ i4 S3 N# ?$ D! q2 f
# O1 [3 z' d* q3 f! g train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') ; g" p8 u& Q9 |+ c8 h0 L# z% n/ I
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
4 e; |# W4 L& j+ E$ O2 j 5 ~8 H; ^, f, @& ~9 ~
BUFFER_SIZE = 60000 2 s% E% I5 z$ P) |; r( e6 |, z7 {
BATCH_SIZE = 256 * F2 m# g! n. E |
& T7 m2 v* k* V% `' z* e- P- c9 H) q
# 批量化和打乱数据 $ ^6 [& r) f8 k: Z! X
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
' J1 k# s& x" w
& g" N: k4 V$ c& y # 创建模型--生成器 H/ O6 x* L* x) o
def make_generator_model(): 1 c/ R. w' P' f& M6 l
model = tf.keras.Sequential() * C9 d2 v1 K+ W5 [
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
( O. m" ` |0 z/ n model.add(layers.BatchNormalization())
" ^, u1 c1 g' q5 k! E+ {& q model.add(layers.LeakyReLU()) $ F3 e& s. Y) @" a
$ Q8 c/ }: M9 m4 H
model.add(layers.Reshape((7, 7, 256)))
% q" K# z2 @- u: f, a% k& z assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制 ; G, n% ~& a+ ^: y( K" E
) h9 z) Q% u0 b
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) ' ~5 c- C# j7 w: z3 v/ ^
assert model.output_shape == (None, 7, 7, 128) : s) E. x B" m2 ?/ j, t
model.add(layers.BatchNormalization()) # x' E: |1 }3 S4 g0 A
model.add(layers.LeakyReLU()) 9 @* E1 q1 e# _' j, s. B: h% Y
, j7 s0 e+ t2 @% P+ ]; D
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
% l7 _9 u+ u; x9 u! B( T7 H assert model.output_shape == (None, 14, 14, 64)
" i) p4 r! v; A$ T6 G2 i model.add(layers.BatchNormalization()) ) h0 ] V' Q+ ?5 Q& i% P: s- i. g
model.add(layers.LeakyReLU())
" o# f+ W( D$ C) r5 P6 A7 h1 S/ z ( J% j1 N( _( U; ]
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) ) h( _; F) k& N2 ^6 i+ S
assert model.output_shape == (None, 28, 28, 1) 2 u6 k. w- |. R; v
0 E" r" G2 v N P
return model 5 ?. z( x9 t8 l( e
$ q8 z' E; m X; v* Q2 V
# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
F' p8 n: b. Z, @+ s generator = make_generator_model()
) K G% V6 \/ p# d
0 m3 b. e* A$ x. V( u( ] noise = tf.random.normal([1, 100])
! K4 P7 j, b$ f( w2 ]' g1 m2 g; H generated_image = generator(noise, training=False) C6 y4 w: n1 Y, g5 W" ]
. d- [9 ^5 g: \2 m: o, w9 f plt.imshow(generated_image[0, :, :, 0], cmap='gray') ) V& X$ z0 U' }' L* W g. Q& O0 ]
tf.keras.utils.plot_model(generator)
2 g' y, @+ [( V' W1 E
% G, a' T) X) A: f" K* _ # 判别器
- E3 R' G. b) N def make_discriminator_model():
8 @6 B2 u( C% U$ H5 F- @ model = tf.keras.Sequential() " \9 i4 @. r) t2 R
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
# M2 I7 ~, t0 T3 S! ~8 h input_shape=[28, 28, 1]))
/ h3 @1 v/ |* D+ x4 j# O1 l) a# S model.add(layers.LeakyReLU()) 6 _0 \: ^) f0 z) L" d0 K! Y: Q2 T7 F
model.add(layers.Dropout(0.3))
. ~9 J. d2 D( B; |4 }0 z 4 j* ?6 W b" r
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
& C$ @% @( {4 D0 }' j2 \- f model.add(layers.LeakyReLU())
& {. r% ~8 V( s/ A; a0 U! a model.add(layers.Dropout(0.3)) * J m; {( `% b/ `2 X$ s
) F" E) _' U. G- Y7 X
model.add(layers.Flatten()) 3 C% f6 S/ q; f
model.add(layers.Dense(1))
# Y5 `0 X: ^/ \1 ^* y4 g3 Y' q ' A1 L. P9 v& `4 e7 Q6 n& N8 }! ~* h
return model
; H2 j4 |. H" g6 }8 c , ^2 A1 V, B5 F' U, D3 L. j
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
6 Z9 U# N- @: Z discriminator = make_discriminator_model()
' {0 t J& v( `' G9 L f decision = discriminator(generated_image)
6 ~; o$ H+ u1 @* {+ A print (decision) , s0 p6 Z7 Y0 S1 F n" L
1 [7 Y$ F8 I4 V0 g9 v9 F
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 : F& X: j* i# b/ t v; i2 a
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
" T) X9 d# V# V4 y% [8 j* ~ z 6 J/ a$ U% @5 m* z
# 生成器的损失和优化器
3 L+ z2 }; ?5 v) E7 n0 j1 h/ M0 m def generator_loss(fake_output):
' g" {, X8 T3 F; v0 d5 ? return cross_entropy(tf.ones_like(fake_output), fake_output) Q$ p3 u6 Z, Z8 O( g5 f' f
generator_optimizer = tf.keras.optimizers.Adam(1e-4) 9 j: b( l) @1 V( [* M, R
" e8 d. B1 \9 d/ ^! ^ # 判别器的损失和优化器 % |% e6 W6 ]; A2 a; B
def discriminator_loss(real_output, fake_output): . M5 ~8 W' O5 t
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
& d* V) ~( u0 @4 z8 j) C" h/ b fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) , c+ w& P& T; @: v" e
total_loss = real_loss + fake_loss ! O+ [/ w% r) T: @. n- e
return total_loss
F6 p p! s3 i9 I0 |$ y* I4 p discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
: Z4 C- `2 w; |4 m' o ) e4 L P7 `4 J P9 R5 c+ P" K
# 保存检查点
5 s% _: T6 X0 b* | [2 E! N. O checkpoint_dir = './training_checkpoints' 2 ]1 O: W. F8 o8 F4 g& L
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") , o$ i& [) A$ q1 j' v) @" E L9 k
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, ; h" J4 o' L" t/ R
discriminator_optimizer=discriminator_optimizer, 7 m2 K, o+ j7 E: H
generator=generator,
5 N8 l5 W0 ~. b& m discriminator=discriminator)
" S) H" _- Z8 z! l
% U/ y% Y$ ^2 U+ y) C # 定义训练过程 0 Y2 x( M/ V3 A* }+ s. |
EPOCHS = 50
: L# s& b. J5 a6 D5 k$ \ noise_dim = 100
2 y$ \7 @4 Q8 A* K# n num_examples_to_generate = 16
2 u$ F$ ?5 ?4 P. j; m2 s) ^" ~1 @' U: _ # g; m0 d# n; V! `( j- e* Y
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
$ n( B# X7 R) P8 G5 w' s seed = tf.random.normal([num_examples_to_generate, noise_dim]) 0 h* n+ F2 ^9 b0 d: n; C
9 }9 e# F, H/ u0 S& r! X # 注意 `tf.function` 的使用
" W. W2 Z$ L3 r& w2 ` # 该注解使函数被“编译”
: v% b6 _% h1 p @tf.function * D, |4 Y- `& a F' j8 V9 g
def train_step(images):
: `5 B- j7 ?8 W( W! I noise = tf.random.normal([BATCH_SIZE, noise_dim])
" L- A( `: h6 ]7 Q2 _ _4 ]0 e $ V9 W: T2 k7 n: a- p, S$ ?
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
9 l) x- Z5 t2 A3 }/ `/ Y1 T/ G; O generated_images = generator(noise, training=True) . K) p5 M( j3 P1 ?* L& |
+ D! b" |6 X; Y4 _% B real_output = discriminator(images, training=True) 6 Z% r; n1 g# S4 t& J: G U
fake_output = discriminator(generated_images, training=True) ; h1 T5 ]- ]1 q
" X4 }8 V4 f! g& Q4 x gen_loss = generator_loss(fake_output) 1 t# A/ h, C4 O: y0 k/ I3 Y
disc_loss = discriminator_loss(real_output, fake_output) 1 h. x5 V: B% ], F/ h- O" N7 c3 D
8 E" I/ o$ V* c# `5 V% q
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
; ^; D# z. E' {. E) N# y gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
9 p4 C. S- n7 T- L' w4 F% O7 H5 I
) Y5 D1 W5 Q1 r' n5 R! N/ e7 F generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) 0 F& J2 X$ E; f- J2 T
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
3 r; s; Z. W6 P( I7 m3 f
! Q! g+ m0 h7 ]4 c. j( | def train(dataset, epochs):
b$ y2 K' Z9 v& S" P" R& x for epoch in range(epochs): |4 q+ m0 k) x* a! ~/ S, Q
start = time.time() 9 v! r& n# C; _3 r1 K' ]" Y
: Q% ]1 ~' V& G, b% S1 d
for image_batch in dataset: ! r; T+ o* W+ p$ C1 G8 n5 ~
train_step(image_batch) ! Q; k- A0 S8 X- o2 [3 h
* a) W" Q$ M5 A% l3 ^( X # 继续进行时为 GIF 生成图像
- v* m; q3 Y# L; m+ t' L display.clear_output(wait=True)
- I5 v, v4 Q0 J& m3 t2 o) C; U ]! ? generate_and_save_images(generator,
) Q' U& P2 G; m) w$ Z6 n epoch + 1,
" M- T' v3 n1 L seed)
: a' ?' ^& J/ G" j6 t & n( y. N4 ]9 G. L( T/ g! f7 j
# 每 15 个 epoch 保存一次模型
+ c6 W" L& K& _) o: f: _ if (epoch + 1) % 15 == 0: 8 e7 x- C# a4 K( t
checkpoint.save(file_prefix = checkpoint_prefix) % @; ?) O2 ~0 g) }/ c- J8 h7 y
! b( c! g6 N: }! ?/ l/ o l* D print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
% M3 N; n$ m1 }1 a* }# T& t . c3 P# ^# O1 [! l. B7 a/ \, i
# 最后一个 epoch 结束后生成图片 8 q8 ], F/ x$ q1 d
display.clear_output(wait=True)
& Y# z& K) }7 @( o. Q generate_and_save_images(generator,
/ D8 |7 A/ i+ ^! {' y- x H epochs, , L5 |& o v, Z& {4 P
seed)
* {$ F( X7 }6 @
' B- _" f; W' z( @ # 生成与保存图片 ( M6 o/ y; m1 M( `) G* n
def generate_and_save_images(model, epoch, test_input): * h1 K8 D- V' _+ i! @7 I; y& N) \: J3 W
# 注意 training` 设定为 False
- R5 n% k! C7 `2 B # 因此,所有层都在推理模式下运行(batchnorm)。 ) y) e# c* \4 Q) C0 ^# H: j7 R }
predictions = model(test_input, training=False) ' l( b" w& C9 W- A. r; G
* G+ h8 v) N5 q/ m. q% V% g fig = plt.figure(figsize=(4,4)) 3 }6 c$ A+ n! \2 u: M- t7 p
' u( p+ ~- E, c* P for i in range(predictions.shape[0]): ) y# H' b4 }8 j. c4 c; [: C
plt.subplot(4, 4, i+1) $ r( Y4 w. Q% w& \: Z9 i
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
& V3 ?6 q: q* E6 A plt.axis('off')
& C: @" Q7 ]7 U! v( L& D2 \" j+ D
' j7 m: `$ y! ?% R- m9 C plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
& ]" n! |7 X7 V6 R2 y, A* x# {7 v6 p plt.show() 4 q3 u9 G# E. W% i; m
" l: k+ A$ R! E* P$ `# U # 训练模型 $ a! ?3 ]; M- w, H5 [
train(train_dataset, EPOCHS)
4 q' Q/ r# Q* r' o2 h, O$ E
/ [' }+ \* o. J! s& y # 恢复最新的检查点
. Y( Z- R* _$ y6 x9 W$ s" H checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
4 c! V0 n2 _/ P. C3 n. t5 C 0 l5 t2 q! Q% y8 J7 ]$ z% }
# 评估模型
5 L c* n* q0 ^0 S' D # 使用 epoch 数生成单张图片
5 v; J$ i6 d7 U6 G" `6 X def display_image(epoch_no):
* F/ P1 t+ s& u( i& w return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) ! {' K# J( Y: C
+ q! b" z# S& ]) O! y) m, M
display_image(EPOCHS)
: T' y# A% z, E0 }
3 T/ `* J7 ]* Z+ B9 T7 N8 f) u anim_file = 'dcgan.gif'
' S# d4 y: [% K* M$ S% N; V( p
! s/ _; _0 L9 u3 X% z( r% K: F with imageio.get_writer(anim_file, mode='I') as writer:
8 m* k$ \7 m% t8 k9 C( D8 y filenames = glob.glob('image*.png')
, T3 T' t* k9 I6 \! o! T" s! |- p filenames = sorted(filenames)
# U4 |& D) C7 @3 T9 M6 G4 }% y last = -1 . x5 M9 D# f( s0 X( F9 m2 p0 K& G+ n
for i,filename in enumerate(filenames): 0 X8 O- f, ^* V# K/ u; N% g
frame = 2*(i**0.5)
- g" ]1 B3 \, I9 y" c6 D& D if round(frame) > round(last): 4 ]9 P3 d, H' n. m* ]
last = frame " G! Y* k" [4 R: z
else:
8 v u6 f7 R/ W5 q+ u continue 9 p+ a$ T/ T8 U+ c+ s0 w; ^; C
image = imageio.imread(filename) 9 |: A5 p8 ^2 L! W2 f. W G
writer.append_data(image) 6 u; a: I [, M# |0 w
image = imageio.imread(filename)
2 Z: n: G9 x- n' }* u, R writer.append_data(image) % f0 ^( r6 g2 d# t
; f4 b$ o' R/ Z3 ]3 ~ import IPython * }* Z' ^# k- {
if IPython.version_info > (6,2,0,''): * b- U4 Q0 l. @' ` a
display.Image(filename=anim_file) 8 R5 ^+ [. t( [' W; p' H! q
参考:https://www.tensorflow.org/tutorials/generative/dcgan * i# c N g: _# q1 w% ]/ Z" b
———————————————— + ?2 L) o4 B* i- a
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 * u5 ` b6 f9 ], `8 ^, i, y1 e
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
: ?7 U$ d; C* K/ W2 i & n( _ A4 A: R7 U$ W0 R) |& j
: c* t9 w3 w) _* K( I
zan