在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 563312 点 威望 12 点 阅读权限 255 积分 174216 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 3 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
, |5 }, _5 {, V9 ^# z1 G
深度卷积生成对抗网络DCGAN——生成手写数字图片 4 \- V( S7 u& a4 \, {! T
前言 0 u$ P+ e8 Z% r/ v+ { K% F
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
& L7 |5 o% t1 A) K* z' k+ ?. o ) Y0 p* W3 X& U: ~' Y% h
$ a k5 f* Z) f" r% o# c 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下: ' H# v( j. c) d2 Y0 t
- R" j4 j& d& k. l$ C* I9 E \
9 F. y5 m7 s) Z9 c3 {- W/ @ # 用于生成 GIF 图片
1 N" k; S, J, U7 L pip install -q imageio
6 |% Q5 Y* B0 W7 x 目录 8 l5 N% P* d5 o4 P; ~0 w6 C
) O, X/ ?* f2 I+ h" s4 V
* @- V; j' o1 f8 Q; |3 X 前言
( {8 B* e/ t) w, k* X" Z
; W, Y" k" r$ g4 v* H
# G( n% E6 `$ H6 c" Z s# |% M" X% `3 i* R 一、什么是生成对抗网络?
2 o! R2 r. Z$ Z- P) h
$ m5 R( b) s+ O5 l , n% P2 n' J; f, U# E: c+ x: f
二、加载数据集 ) X( U/ ?% U8 @
: M- C; c9 M5 q: F& P% ?- k1 x
- w5 c9 t( z: W/ j. u7 Z
三、创建模型
9 Z8 v6 s9 L8 A% J2 u ; Q3 y" {. B+ E0 `+ `
/ L. L6 B; v+ i 3.1 生成器 : R. r, ^! h6 O: m6 t6 }# x
- X) @: B6 C. A: [3 A T
& X! A I" q6 p 3.1 判别器
) [7 `; V" I4 z6 ^% w, Z O% R9 o v; a, U7 Q- W5 N
) z* U, ?* { z0 b. n
四、定义损失函数和优化器
8 @4 i$ G1 Y+ X' y + M: Y+ a. @7 D8 Q" u/ z, |1 w
7 c- f: Z, Z6 D+ ~
4.1 生成器的损失和优化器
, j6 D, g! Q m( W+ a* b 8 V- R% l: y6 V$ V0 v+ R$ n
$ _* C0 o, ^0 j# [3 S& z; o9 Z$ z
4.2 判别器的损失和优化器
0 U! m1 }, x% j9 _ ) L& h- W+ G/ Z! L3 A! e
8 Z! Z4 c, K: f% l! n+ F; M: f1 t
五、训练模型 % B; c. c9 j# d
" p" V$ ?5 v* Y& m5 l' _
. V0 {& n* ? R
5.1 保存检查点 * C: k% @ A3 }
+ }( |* G i+ S- X2 i
- t7 k% A' S8 s) t 5.2 定义训练过程 2 x: k! F: s# ]
1 T* o. v. Q% E( G. \' x2 Q8 d # k' _9 T( B( ]) c5 J' T1 e
5.3 训练模型 4 C, M/ F, X3 F! O: H
. B' b# V9 U. {& D* j$ l# f3 J
. @$ f' A& ]1 @ P% N% i 六、评估模型 8 z: u, _% F1 H8 D& }/ Y
7 b, @' k8 g6 i# `
/ V6 f7 _8 d/ t; k9 N/ B: D 一、什么是生成对抗网络? . Z/ x3 C9 [% }
生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
8 T- W& p, g* g 9 ]8 v2 c( o$ T1 B
6 p) o+ f5 P, P) \# U- Q6 C _
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
8 R2 C* x! h/ ^9 p7 ]/ b 7 S1 X9 c* }9 ]5 `3 ~
, o+ K. T2 m' M0 l0 l- ~+ o+ A
判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
2 V5 ~2 x2 Z0 ^* s* f* Z5 g6 m, ?9 n
9 k# `' ]- ~+ _ j) N6 z
& C* k0 X* _. p0 \ 训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
: ]# v$ i- f9 F, {5 Y
( f( b$ [) B% l: e Z. ~: f G8 n7 ]9 m
当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
5 F) W% l f- r
. W( b9 d5 D8 q8 p, J+ [
! B$ f6 u. N3 k/ y# m5 e0 `# h, | 本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。 7 [5 ?; e) r/ p9 \
Q s# N9 P3 |8 V6 q& W$ A+ ~& v J, J7 M3 a+ Z1 g9 N6 i
二、加载数据集
& l1 X/ s6 j+ ~$ `! {1 H7 e3 k 使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。 ( [2 G: I+ Y4 `, l" W- d+ T
/ D; X+ P/ o6 d
) F* |$ U- Y- A3 X$ O (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
) a; H7 v6 _! G$ t( T5 V9 B& X " W" d2 i g. r2 J- w7 q
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') 6 C3 e& [& N6 N5 b5 Q2 Q
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
( u0 e7 b1 z1 Q5 h' }+ \0 j+ w& @ / r3 C, o' |( J v" W, ?) V" s, v4 K
BUFFER_SIZE = 60000
/ M* X( |& ~& ` BATCH_SIZE = 256 8 R' w- l6 S$ r; A$ O
- R* A9 d# M$ ~# O/ S # 批量化和打乱数据
4 \" ^5 A/ {0 H- {/ e/ E# J train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 7 w# c4 Y; z1 H+ Y
三、创建模型
# p- Q$ @& s. |8 S 主要创建两个模型,一个是生成器,另一个是判别器。 & H' V+ C: z+ e( h V. K3 i2 ]
9 Z( g! n2 C! @5 `1 z! a/ |7 P5 {
1 R" W5 C/ O- M @3 U2 u5 m 3.1 生成器
5 S7 o6 v5 F8 N 生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
- O l( c- ^' x9 V. t( G
! X* |( c6 {2 T: v+ I, X* e
0 n# T5 p8 a$ f$ I 然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
% z9 X7 A. K6 B! D# o
5 C) X+ K9 k2 g* }( W2 ` @. t& F
$ h1 B7 P) `7 H7 b, J0 C# _ 后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。 : n3 }* q7 W$ r3 Y W
9 n0 c/ H [: k6 D . {) y' Y- ?! M' ] P. N7 D
def make_generator_model():
# e) Z$ i N) v! K% J. ?; b model = tf.keras.Sequential() . G. M; i( C8 L1 [; V* O8 y
model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
r* D) f& O' k: b7 m model.add(layers.BatchNormalization()) 3 g5 g- ~& j( C) \
model.add(layers.LeakyReLU())
7 h0 t: j# G$ N6 O+ s! z: {3 }: r
: j, [1 C. d- p9 J4 y model.add(layers.Reshape((7, 7, 256)))
( A2 `! }: c* g6 Q7 l Y assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
& S# C& K. P$ v+ H0 S1 h: a
# M! x& d% m& Z" y model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) ' e/ v2 b( A! u! L1 H. V. v
assert model.output_shape == (None, 7, 7, 128)
: P. e2 X( O1 a& ?; z! j model.add(layers.BatchNormalization())
" b$ P; J* l x9 [* J- t model.add(layers.LeakyReLU())
) h- m% A) i3 Y& G3 o. ]8 G0 n3 t% o ; Q/ }+ X! h2 @4 F0 E
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) ) p" R5 C6 p, D# O) `- O
assert model.output_shape == (None, 14, 14, 64) 2 v3 i! H# P. S* N; y8 K
model.add(layers.BatchNormalization())
3 T' ~* P- d' z7 U/ _ model.add(layers.LeakyReLU())
0 \+ J! x4 S3 c$ s 1 L, f% ]0 E- E) t8 c
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) : O; d6 u6 C' G( p! p, n
assert model.output_shape == (None, 28, 28, 1)
% e& E6 h% c; N. V# R
5 E! u0 a& n3 e0 C6 O3 Q+ A return model
, Z2 q7 G. o$ V" Q5 k/ ~/ @( _ [ 用tf.keras.utils.plot_model( ),看一下模型结构
+ Y6 Y( ^9 p+ o/ p) ?0 h( N- U6 G5 I
1 Y- {0 \& k3 e4 p 4 s' V5 |8 }. ^* Z0 F2 ]
- m- H9 f0 D3 ~) ^! y9 H9 I A
8 D1 \& @% X$ G1 A
3 O& c9 Y0 I% T t
用summary(),看一下模型结构和参数
: j* z; c' M) e4 a# [7 q: A# n4 J6 O
& m% [1 V$ N- X9 z! A# M 4 P# t. O/ D3 p+ N
( G* G, F# k" J# U9 {
, A# j& R. x E3 p
7 l& w) O/ L! i8 F* a ' s( y4 W# r. |4 [
使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
* s/ t* I2 x% w! }$ O. i$ X
4 K- j- ~ |" M6 p: f 8 x" r$ A9 m, J" B; i5 q8 K" t! M
generator = make_generator_model()
8 D! E5 Y4 j, S+ a! F " z' B/ W2 Q' H7 t; t
noise = tf.random.normal([1, 100])
- v& V# \* o+ Y% l, {. X generated_image = generator(noise, training=False) 4 U ?2 @. i$ ]/ _; ?2 B- X( @
2 ?% y, b" k8 Q. H4 s" w. Y! p. _ plt.imshow(generated_image[0, :, :, 0], cmap='gray')
: @8 P [7 y3 B$ N5 G6 X9 d
6 Z, l4 ]7 Y$ e5 K( h* {; N
" D, L+ a( r" E5 B4 s4 V! J
+ @5 s$ V0 F; Q; R2 s/ D5 S- ^
$ D. [& Q/ E3 R: l% N 3.1 判别器 6 j$ O* v4 Y3 ?) r; l, [) \) N
判别器是基于 CNN卷积神经网络 的图片分类器。
' t0 Z3 @* h* |- |4 j( i% t # q: V$ M) \ ] P
8 }0 a2 f0 r; [/ j
def make_discriminator_model(): ' \- t3 ^+ L& e6 R1 o
model = tf.keras.Sequential()
t0 y) O& o/ q" g1 \ model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
5 q x. r( h& Z4 x2 x" K& G w0 I input_shape=[28, 28, 1])) 5 W) |+ H4 W( a/ b
model.add(layers.LeakyReLU())
# P6 B6 s; @; v: N3 v1 _; b model.add(layers.Dropout(0.3))
# m) P3 k i; p1 u/ s* |3 n# c 1 r4 X7 b( o% Y# a: r3 n# _ |
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
) o& X, p8 d# U6 G$ w- G; l0 O9 ` model.add(layers.LeakyReLU()) 7 r+ M m" z0 T- E+ w/ N
model.add(layers.Dropout(0.3)) * G" B4 E! a+ J8 Y- s
/ _8 {1 t+ o7 e
model.add(layers.Flatten())
* s: p6 q; }8 R/ F model.add(layers.Dense(1)) ' A1 c0 ]* W9 l+ {
7 ~: |" d+ i) [; [
return model 0 I/ P# F) P3 D0 L& D. n$ F& |
用tf.keras.utils.plot_model( ),看一下模型结构
4 @" E3 F2 \" d( D% x3 J
3 x. K( R2 l4 n" h$ S: p, T6 v& i
, _8 @ X0 f4 T( r- E) X8 J" l
5 e# q2 A/ I4 b
6 l/ `' |+ T `5 z7 m . A: Q3 `7 i: ]. q! {
% f% j6 p; c6 W; a: b 用summary(),看一下模型结构和参数
& J0 ~3 b7 m+ k# |4 f2 D: u- R6 z
, o2 ^! s8 }8 \0 w+ s" S
* l% w' A* F" S( T
9 M0 P' ]6 @; F0 v& l5 P. D0 y + ~# E, Z; v% J( Y" `8 N7 T
; A1 y4 \8 P8 y2 ~& \" k- `
& ?+ r. D. t: u 四、定义损失函数和优化器
( ~0 E+ f- H6 G" e9 |! K 由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。 2 ?, C- Q0 }9 a+ z1 D1 o v
- a2 k/ |+ I) F7 E$ E
) O9 ?6 o, l ^ L" h7 f& }4 m 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
& C2 v" A0 _" ]$ @ 7 R* D/ Z9 s2 a2 g
1 T2 v7 u6 U3 M: g9 l1 Z6 e # 该方法返回计算交叉熵损失的辅助函数 " k" H C4 X z
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) : d" B4 J; b3 A8 B. `1 B
4.1 生成器的损失和优化器 . [+ V& t( G1 }
1)生成器损失
$ V2 z- ?0 m" J1 P n- C# K* i ) m( K0 l# }) _
4 [* ?. U; t+ T
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
0 G3 ?" l, c9 D( C 4 T& c5 @) M. ?1 k
, w$ j! F- [. A 这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。 / n0 _ S' Z( ]; Q7 ~
' Y. g' v( I: A2 `8 b
$ w) `9 f, k8 j. p% ~- R def generator_loss(fake_output): 1 T4 K9 J. a2 m. L
return cross_entropy(tf.ones_like(fake_output), fake_output) , V( b' e" ^* z4 v; n& R3 Q
2)生成器优化器
1 K) j# N4 [0 w3 a j0 @8 i6 C- x1 o( F
. o6 P$ x- D4 E8 N! s& s: k 0 H: N; T7 h4 K, p! L y9 j& \
generator_optimizer = tf.keras.optimizers.Adam(1e-4) " ^3 g5 l, S8 N1 E+ d
4.2 判别器的损失和优化器 2 t& v$ l5 \& }
1)判别器损失
1 w$ w9 f, ^5 B/ @5 r
! B# s( {5 l/ n3 K% R $ P7 m' J( }( u. }1 O5 c4 w& z
判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
/ _; M4 l" Y( R' `/ D
% L2 R# j+ u1 ?) {
* K& ]3 T5 Y8 q4 B! ^5 { def discriminator_loss(real_output, fake_output):
7 L( b) `) ~1 h; D# F2 \ real_loss = cross_entropy(tf.ones_like(real_output), real_output) # l% S" _. G7 B- F) w# z
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
$ |& Q9 X: u* ]4 k# g total_loss = real_loss + fake_loss $ x1 \) t3 x, `& S8 M
return total_loss + ~; f( v7 t% P0 s* B
2)判别器优化器 2 c1 i) [9 i# k
- z! S3 V i4 D% R6 ~! u# E
. L9 W# E* o( m3 ~
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
) u6 p) k6 K# | 五、训练模型
; V. m, w) j4 s) A& d Z, {0 _4 _ 5.1 保存检查点 6 j: F' a; b; k
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。 7 d& }: q3 L' B) b- D
- d2 I" m! [* Y. `0 |
v9 E: j3 r5 ^/ `$ d k checkpoint_dir = './training_checkpoints'
, Q& Z+ r! u# r, b7 P checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 7 d/ q6 ?: y1 @ N: i Y
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, - {0 E, ^& w3 {/ l; `6 C& r
discriminator_optimizer=discriminator_optimizer, ( n) |: a* o1 W$ c
generator=generator, 6 _7 A1 ?% k+ M7 }9 H; |
discriminator=discriminator)
5 ]6 {; W( Y5 T4 m# }- r 5.2 定义训练过程
) N& q& H' i" t EPOCHS = 50 % ?5 J8 K B& E$ R3 w
noise_dim = 100 " c0 W9 I* b! h2 I2 Q* f8 s7 b
num_examples_to_generate = 16 ! T8 Q) H" _$ o
5 A2 b/ v* A6 F! {8 N
9 [+ E' u/ @+ E$ Q # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
; D* J/ w) Q( \/ m, t i F seed = tf.random.normal([num_examples_to_generate, noise_dim]) 2 F* C8 r- ~6 k- @" ~9 q% W
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。 8 z/ V0 e2 R) i9 ^
8 N3 y* S& x, w$ `8 M( n. S 8 N0 F) Q7 j3 n" q9 g! L
判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
: i( z5 u9 E2 c5 q 3 U- d; u5 b3 k, B
2 G1 i, k+ \0 f! m9 {2 C: O& H
两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
% C3 |) V8 C% s8 Z F3 T7 j& X
; t% [2 b0 i4 Y$ J# Q# x
$ T4 n0 W- m9 n: n9 _% z+ ^ # 注意 `tf.function` 的使用 ) U# K1 m' j7 ]! G; t0 v$ S5 A
# 该注解使函数被“编译”
6 ]$ @" D) @3 g V* Z. c @tf.function
3 Z. q) E- h# W. |2 u7 g, E def train_step(images):
+ Q |" |* u1 Z6 v" Y noise = tf.random.normal([BATCH_SIZE, noise_dim])
G; h! t2 t. {
0 u# W& b& @# {" a" z8 Z with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
9 b& e* i B& C, a generated_images = generator(noise, training=True)
% W' c# W- d3 r! ~' c& K6 _ E% U& M1 H ! z& G4 S7 K g6 Y: ?
real_output = discriminator(images, training=True) 1 G0 T3 _' l' q) N; `( ~- V1 A
fake_output = discriminator(generated_images, training=True) # ~( w5 |% K3 T% X/ d4 {
% k: n5 e1 F& { gen_loss = generator_loss(fake_output) * b/ \+ V9 S' ]! a
disc_loss = discriminator_loss(real_output, fake_output)
5 g) [5 V3 c4 z: P9 Z1 @) I & E3 C( I4 J1 O8 j! i) x
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
9 n I$ r6 C6 P gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) I% w3 @. c( i; E4 Q
" v @& U- Z* s7 ], }/ x0 X+ U generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
, P% n' v! P: c- M) M discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) $ o$ V. J1 Q" [# w$ s
' K- {7 b+ m: M4 _" i+ t9 f" Q def train(dataset, epochs): 3 P* b- u i. R" {
for epoch in range(epochs): : e, \5 E' f% K. ^% d/ [- y
start = time.time()
3 [: u0 s8 D% X( \" ~, `
7 y6 r) v: S L8 n v# J, v for image_batch in dataset:
. N, d# j$ B* A% n# F5 j train_step(image_batch)
$ W' I: [8 E; a8 x& c1 U; O8 T9 m 9 Z( t y% r4 w: \
# 继续进行时为 GIF 生成图像 . X" W$ I. m8 p, O
display.clear_output(wait=True) 5 Y; Z- ]2 {" q% [2 V
generate_and_save_images(generator, . V2 |9 \- ]6 l) E4 U' J
epoch + 1, 5 B' z9 [/ I, m
seed) ! G) {+ E6 I+ |
/ R. J& {& O$ r8 o, |' \* D
# 每 15 个 epoch 保存一次模型 , X& J# \7 n9 v3 N' n
if (epoch + 1) % 15 == 0:
: [" H# e. f0 A checkpoint.save(file_prefix = checkpoint_prefix)
% A, ]+ {1 s* d: [5 A% o) G
0 c+ `: [# T- y1 \ print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) . ~. J. B- w) t- f6 G
- }/ Z9 ~3 D( y& s # 最后一个 epoch 结束后生成图片 / Y" a) l9 M# [+ P. t- z
display.clear_output(wait=True) " B! ?- Q9 `$ m% @6 u( p
generate_and_save_images(generator, % |0 k5 ~0 W d" C2 s" d2 v/ F
epochs, 3 i5 j4 X+ p+ B2 E
seed) 4 [' P( ^4 _4 U. ^$ q
: L( _5 L% R* v8 {
# 生成与保存图片 ' y! A* r3 C v* U h& `: l! X
def generate_and_save_images(model, epoch, test_input): & C6 A, `( F8 Y4 G: ?/ X4 o
# 注意 training` 设定为 False
$ i F7 Q' N) j) Y% P3 X # 因此,所有层都在推理模式下运行(batchnorm)。
8 ?1 k. v' `0 ?5 u6 w2 Q, x: r5 V predictions = model(test_input, training=False) 8 W, @3 r% p& O) T, J
; B V: |2 @& o4 X* n fig = plt.figure(figsize=(4,4)) 7 r4 f' y! F. [
9 n( g8 C$ s9 I! [9 P; g for i in range(predictions.shape[0]):
( z/ c$ T- @" K1 q plt.subplot(4, 4, i+1)
: i w' l/ b {1 d/ _2 b1 r+ q plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
0 e- v/ G( b' z, v. J' K. L plt.axis('off') + h8 n4 v. o: j, ]! [2 D! e
4 W+ Y- W$ f( V8 X" z5 L
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
* A' o: f5 d$ N" K3 S5 Y; [* A plt.show() 0 A g$ s. }& Z- D0 V9 m @. f
5.3 训练模型 ' ~* M+ `& C, a4 Y( J
调用上面定义的train()函数,来同时训练生成器和判别器。
' ?2 V* u- S5 N: P! U
0 y$ ~1 g( _. I, [% ?1 Q# ^8 J
# J |, {9 ^# o/ j& E6 P 注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
- Y9 D9 F$ I+ P s
5 R$ W$ M A2 v! _2 v1 L. X ; S+ ^ O; x, Y& A
%%time 2 W2 x+ H! j1 O" P
train(train_dataset, EPOCHS)
% A5 ?, p- x0 w) X' n- ^+ p 在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
7 A$ @+ S7 \2 `6 x/ P% M$ [+ |
" T& L: @, K- v$ n
% m7 o5 B1 C( \$ Q5 ^ 训练了15轮的效果:
; Y/ `4 `/ Y: `" X9 W
3 \! T' m! z+ X7 g! A 5 j4 g* q! [2 U) I
6 S' O. z' u# V/ }* q
4 X: k" v/ d; U* G& `- O
0 F) F- t- q( w' N* [. `
. R+ y+ i1 N$ Z8 v
训练了30轮的效果: + `( A# v! b5 u1 \
4 v* H1 d. j& _' k) |1 q" |* {
$ U* x8 j$ T; `' o" K1 k- v8 ]* R5 I h
# q% f; Y, l5 Q+ N3 f! ~& U; o
8 l( S8 E, k( ]' f/ L4 R, I9 p
3 S1 v9 J( o( D# \: q. u3 a 4 q6 T* R/ o5 i; W' Q
训练过程: & p9 g/ Y1 _+ a" U
1 I9 l0 E z! n& W3 s- Q' \2 ?
0 A- @; u1 X6 ` 1 ]+ Q* w; J; w1 F* `4 @
( R0 u7 \, A( [. ?
6 g' t' N1 e ~
8 k% M+ V( D/ P# i- f, T+ C5 Z# l) @
恢复最新的检查点
4 z4 e' ^! r0 r* j9 T/ D6 l 0 f4 u* r Y8 r3 M9 X7 l* ?
: P; h+ s: l/ X9 x$ M
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
* \3 J7 M1 x( k, m 六、评估模型 $ G( m F" P! J
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
, ^5 k1 w! V# e+ u! R
2 D. R: {7 C* w
) a& { |5 z& k- _- n # 使用 epoch 数生成单张图片 3 c5 Q* B) s/ n" z
def display_image(epoch_no):
* e6 C1 X4 A t% D1 y4 Y return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) 4 i5 ^& Y0 m% B, k, v
/ B* I' Z& K9 z: Q8 T display_image(EPOCHS) + F1 C! F. Y- F: a% k
anim_file = 'dcgan.gif'
' m& d6 ?* r% E* p1 t v5 A7 r
" }5 C' e2 G3 W9 h: w with imageio.get_writer(anim_file, mode='I') as writer: 4 O2 `7 v* Q1 D
filenames = glob.glob('image*.png') " m x# o# n! a
filenames = sorted(filenames)
: n- Z0 C/ i8 c! A t7 v last = -1
+ l2 ^% O4 A1 a: Y8 H, ], s for i,filename in enumerate(filenames):
3 p3 G6 @" @# l/ `9 F, M: A frame = 2*(i**0.5) & E8 y2 a: S( l4 Z6 U' u
if round(frame) > round(last): 4 t" A. P y; b: }+ _3 e
last = frame
) p* \' K+ }9 w5 G else:
2 ?$ ~) i( V! R; f: x continue
: R k/ ]; U i! ~& b image = imageio.imread(filename)
$ X- X. P: J% ~8 p1 x writer.append_data(image)
/ N( C: p9 F; ~/ m8 F image = imageio.imread(filename) ( ~3 j9 c+ O* _+ K; H% {9 h
writer.append_data(image) 8 V" P" d1 N1 B. x
- V8 h. B6 |4 ]% c% A; f- v
import IPython
& m3 i% e4 L y. U9 C1 o if IPython.version_info > (6,2,0,''):
; k) L4 k3 ~( D% i: ?" r display.Image(filename=anim_file) + t, s% t' s7 s% t
1 D9 f/ g5 V4 U4 [% U& @% A3 I 1 |; A5 F- t2 I
7 v. v2 P- b, V' {: c0 E1 a( ]
7 A! N! k2 K. ]3 Y 完整代码:
8 b- n) h/ ]* n/ @& M3 x- _ ; }' I) c% z3 v& r' k) z# ^
2 c6 B- K! E: E0 | import tensorflow as tf
3 }7 k d$ S; ]4 b import glob
& U9 w1 l- Y' ^9 [( v import imageio
# ]. V! k3 D0 e- u3 T9 u import matplotlib.pyplot as plt
! x3 O/ p. l' t import numpy as np
# \5 ^$ G d' O6 n O import os 1 }/ _5 _# r7 x2 q" X3 i
import PIL " j; k% f* C2 S0 L9 C; ^* }
from tensorflow.keras import layers 1 ~3 Y3 _2 Y/ w
import time 1 w. ] c& ]2 R: n
- N+ ?' t, V8 @" ~
from IPython import display
3 x: W! _9 O: I ; G7 G% e% ^3 c3 O1 o4 S7 u
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() + h! a/ M! `7 \4 p& W
3 K4 C6 j$ Q! k' j0 l" w% x0 p train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') ! g$ I, l3 r* j8 H J0 O/ s
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
0 T; S1 M+ F9 f; Y x
8 z0 W: f# @2 x$ ^; c+ g5 \5 f BUFFER_SIZE = 60000 6 Z7 o8 Z* w2 G9 y
BATCH_SIZE = 256
# ^9 x2 Z4 I ?4 P6 c2 g 3 B+ H% a2 t/ |% i6 h8 m
# 批量化和打乱数据 / P$ s/ {! W+ H: B2 j7 l$ q, S" w3 J
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
* _9 e% k' R& j$ w' [0 x% |; ~
$ k0 Q7 s6 x; n( } # 创建模型--生成器 ; e! j& H0 b7 D) f
def make_generator_model():
: w3 \9 w2 D1 b: e model = tf.keras.Sequential()
' h: i5 X3 E5 ^ model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) ; C( G) A3 Z8 B
model.add(layers.BatchNormalization())
9 S( G& Q# Y5 L8 V% O0 l+ r; | model.add(layers.LeakyReLU())
* Y& U! ]5 R4 ^9 \3 N
. M5 _! _3 d: p7 m1 D0 ? model.add(layers.Reshape((7, 7, 256))) / m. D0 I5 w, o* C
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
7 I* A0 L" @3 i6 ~0 e( n
|, a/ ^" ]+ [% S9 a9 v model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) ' g( m3 I% O- m- e* m8 R% G: k/ n4 l
assert model.output_shape == (None, 7, 7, 128)
# q9 E4 x- f* N model.add(layers.BatchNormalization())
6 q' B8 X4 J' l. _ model.add(layers.LeakyReLU()) 7 C. u5 X9 o% T, H* ^
& h! Y' U- ]# \+ O$ O, E model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
2 X( h, Y% U k' A assert model.output_shape == (None, 14, 14, 64)
9 f, @' g* t1 S" ~/ z model.add(layers.BatchNormalization()) / J: g2 Z) `8 W" F% m1 I
model.add(layers.LeakyReLU()) m5 o. q, C; |: r
/ J1 I7 q: t' R1 s
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) 2 t3 t% k' e8 M1 q
assert model.output_shape == (None, 28, 28, 1)
5 ^+ v* \* ]3 q# f 9 S D8 N7 Z# G! y" Z0 Q) s4 t
return model
3 ^, ~8 Y6 [: m- b) j @# ] Y. W
4 @9 Z# s" `* S: Q # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。 + ]8 k4 H7 k! p" O, Q
generator = make_generator_model() - q7 z7 u9 v9 y/ b
* m0 _3 ^ Y9 s& G. k5 \1 z9 u
noise = tf.random.normal([1, 100])
: ~9 O+ {0 C9 L- o$ I; q8 d generated_image = generator(noise, training=False) $ Y% q' ], H' L; r7 O% W
& i3 ^ T$ d2 s
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
: S5 t( b/ Z, j7 G: Y, G tf.keras.utils.plot_model(generator) 8 O7 D0 G" W0 k
( k' `) O- H6 D3 y2 @0 z0 s! w$ x
# 判别器 ! |/ t6 ?; V) c* k3 y
def make_discriminator_model():
) P% B0 t* z% F3 p model = tf.keras.Sequential() 1 P) Y/ ]: m5 d; z1 H; ~4 G( z
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
" B$ k7 x% |) Q# z8 n+ _ input_shape=[28, 28, 1]))
1 c3 D2 `' a" S4 v3 \1 V o& ? model.add(layers.LeakyReLU()) M- P) q2 I/ }- D: }
model.add(layers.Dropout(0.3))
) g' E6 q0 w. j% u: U
9 ^- ]9 K& n9 m. A* y" B& L model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) 4 u0 @( g$ o2 B
model.add(layers.LeakyReLU())
" u! E$ T- `8 _: K) U. X$ N$ M model.add(layers.Dropout(0.3)) 5 A! D* q4 y7 @4 |% g5 B4 \/ h$ L$ B* h
! P5 }+ G4 I3 ]" @2 J9 Z
model.add(layers.Flatten())
, G( o0 q0 u2 U9 t2 ^- o% t# Q! l/ R model.add(layers.Dense(1))
- I' i+ l3 I& Y6 z9 X
* Z% u+ W+ ?. O* I- @" ~2 x c! ?+ ` return model
3 }% \& C: V; x, a2 G. g 1 k% c7 \, T. B2 W3 x" m/ p
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
) x# m2 w# L0 u3 C discriminator = make_discriminator_model()
1 B5 Y$ ]5 G) C, t3 ] Q3 X( s decision = discriminator(generated_image) . Q- m1 M5 M7 U+ J, K+ h# f
print (decision)
$ O* Q; y2 n9 s; P$ q
! G, s: f8 G. S! m # 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
( a; n7 x0 o1 M( u2 k' h/ M cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) - R$ W0 s: s* g8 C' [& S
Y+ h3 n( @5 F% U! ]5 ?# ~/ `
# 生成器的损失和优化器 6 D S7 w* J% W1 q
def generator_loss(fake_output): - w1 s' V& D3 H( Z5 ~
return cross_entropy(tf.ones_like(fake_output), fake_output) % l3 j( r, t3 p0 L1 J3 }/ K$ K
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
5 B$ b3 R. ^& C! j% Q. h # U3 p% j, ^% l: z
# 判别器的损失和优化器 & z5 [' E# Q) I( z
def discriminator_loss(real_output, fake_output): 5 E9 M2 k: g7 a3 G' y" s" A
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
8 N0 M* h' c' ]( E" K1 U8 u1 N( y fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) ' d& k8 S. q( y+ ~3 g! H2 {; F
total_loss = real_loss + fake_loss 4 e* T9 n q+ b9 y8 W9 Y
return total_loss 9 r) l7 I) [$ ]
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) " d# ^4 @! W: ?4 x- _+ a! T
, T# T/ \# w# a ^2 C3 e
# 保存检查点 6 l1 |7 @( y' i
checkpoint_dir = './training_checkpoints'
9 J$ `5 b1 T" K" g1 p- ` checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
4 ?! Q0 ^5 R+ s) ] checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
9 z; p6 k4 W2 C/ ~4 c0 t discriminator_optimizer=discriminator_optimizer, " Y6 i/ q' V5 h) T* j
generator=generator, / ~( O( W# o6 p- q* E+ S7 e
discriminator=discriminator) 6 ~$ [. o$ E7 S$ [3 x/ f
% ^# F0 F2 a; G F6 m ^# A/ e
# 定义训练过程
( m3 y5 Q) ^: w X( Y- P EPOCHS = 50 / q% c9 w& I- T/ ~8 x+ p
noise_dim = 100 ) M, L1 F x! Q+ F
num_examples_to_generate = 16 9 o$ ^9 }- U) {0 U% X
7 a- _# \8 |/ l # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) ) N' ]6 I- P5 \/ v: r
seed = tf.random.normal([num_examples_to_generate, noise_dim])
: F8 I8 [9 d* A; G
4 n7 Q" R0 D" `2 f # 注意 `tf.function` 的使用
) C9 Q: k9 ]4 V: J) s # 该注解使函数被“编译” . z. _( P1 v3 {# O2 `. P# x7 F( ~
@tf.function
5 J# B; z3 g) b5 c: B def train_step(images): ! W0 r5 s" s, s- G
noise = tf.random.normal([BATCH_SIZE, noise_dim]) : R% I( a' Q' q
( i1 o) {6 ?# v0 d
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
n) l# ~- h+ ^4 }4 O. s generated_images = generator(noise, training=True) ; V |/ j' ?; t, _; }$ S& J6 I8 k% j
4 d) _ s7 S1 V/ }
real_output = discriminator(images, training=True) 0 {+ S; `* m: H8 Y, M+ G5 V
fake_output = discriminator(generated_images, training=True) 8 R. j( x2 W' ~4 R
+ o, z5 f( P; _$ m gen_loss = generator_loss(fake_output) + Q4 Q" y% M& m5 ?
disc_loss = discriminator_loss(real_output, fake_output)
9 o5 J- z ?/ D
/ s$ L4 O( L4 Y9 g& I& h gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 5 n' H6 o# h0 ~% ?) ^
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) ; A7 V- [4 [" _
8 m2 \3 f8 m$ Z; Y! N
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) U+ d3 r, m A! e, ^2 q" N* J
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
; R0 }$ t: l# u3 O( Z5 `2 D- Y
2 }6 e8 u2 z2 J7 h9 [7 [ def train(dataset, epochs): Q' A- C2 I( z
for epoch in range(epochs):
! x8 O% n8 d+ F1 Y, y4 t start = time.time() 6 C6 |! [6 e) \: K. q1 S
$ q) A/ d3 v+ P9 `8 O
for image_batch in dataset: $ k( {- W9 F: s' S# L E* \& H
train_step(image_batch) # C; w) P) A' L) Q9 w8 v
$ m" |) P1 z6 V' s# V6 w
# 继续进行时为 GIF 生成图像 1 D9 ^# p' C2 Q/ f {; V
display.clear_output(wait=True) 1 F8 P7 ]6 U& m w* f) ~% {
generate_and_save_images(generator, 6 ~0 O7 n4 r- E7 k! H$ |
epoch + 1,
, o% {' Q& z- F6 J& p" X( {# P seed)
/ j$ d1 G4 w: L 9 d, q# L. s& [5 y0 G5 z& ?+ C0 \
# 每 15 个 epoch 保存一次模型 7 [0 w+ y4 J6 S5 ?% s8 K- {
if (epoch + 1) % 15 == 0: ; n2 O$ f. [' w K
checkpoint.save(file_prefix = checkpoint_prefix) 5 C6 Z$ M, k. J9 @, \# _. n
/ M9 Z% O& t x0 B1 V print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
" H$ d4 _3 Z x H5 A6 q- t4 W# _- q: U8 q
# 最后一个 epoch 结束后生成图片
M, O" P% W3 T7 x display.clear_output(wait=True) * F- C! y$ J6 S& }/ ~, l8 m
generate_and_save_images(generator, ( H. x. _# P! |( p+ J# T$ |; B
epochs, ) _) Q7 M$ M7 _: Z9 N' A; g( [: }
seed) ; b& |1 J" A! |4 C
8 z. @; h: |; o
# 生成与保存图片
/ ?- p5 f( \% @+ d* @ def generate_and_save_images(model, epoch, test_input): ! Y5 E' i, d9 N
# 注意 training` 设定为 False ) D$ H6 Q E( N7 N' {; A
# 因此,所有层都在推理模式下运行(batchnorm)。
2 T9 m# M1 Y9 T" ~( w predictions = model(test_input, training=False)
) Z. J4 M/ d6 N
- z7 q# ^, \, P5 e fig = plt.figure(figsize=(4,4)) I5 A2 C1 e; T6 [5 p+ O8 |3 h
/ ~9 f% i- z' C: h- n2 Q2 w; U9 R( @
for i in range(predictions.shape[0]):
3 f# o. [% y& r( P0 u2 E; l2 a+ D plt.subplot(4, 4, i+1) + p% X/ h+ @ |8 m9 v# E1 E5 Y
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
1 S# x4 [" w; W5 n* [% S/ Z# G plt.axis('off')
! W5 \" w* r3 F4 \1 a" e
! a( Q8 s3 P. p% S. Y* m+ s plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) & y R, X) r+ f& s- I$ {9 [7 g
plt.show()
) K0 Q+ A4 Q& \, Y x# F
4 p; I+ e. `# t # 训练模型 3 B, `3 B+ M1 r7 @, i+ t( y# N
train(train_dataset, EPOCHS)
0 S2 v4 p. A' b5 g) y. e% I
" ~- c K4 H9 Z) B& u+ h # 恢复最新的检查点 6 H# e( w( Q+ g8 V1 l B0 y
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
4 ]. l9 O& h& P ( ?7 S, ~* [- i! v1 J
# 评估模型 ) q. z. k. S, J% b, u. o [6 Y
# 使用 epoch 数生成单张图片 1 T8 _: D P6 b# j" H B8 M
def display_image(epoch_no): . c& \$ D- Q* i5 S
return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no)) $ [5 |9 T, x% [, z7 z$ {0 n& }; o% t
! F! c% f% \9 G" V display_image(EPOCHS) * F1 c2 M b& o" r' _/ ~: j
F) q% C- V' `% [" V$ K4 z- S
anim_file = 'dcgan.gif'
% O8 J) X8 t9 w4 [ . ?' {& ?; E# @6 l1 v/ g
with imageio.get_writer(anim_file, mode='I') as writer:
, M% c3 Z$ m- m# E" G* M filenames = glob.glob('image*.png')
& v8 f& s4 R% l) w5 @ filenames = sorted(filenames)
# \6 A9 w: W, Z( k# o" Z last = -1
4 c3 I, f2 |; a; L6 ^6 F for i,filename in enumerate(filenames):
b8 x- B. Q& D2 {, \ frame = 2*(i**0.5) + m' d' X: K: a& `+ J9 D" `
if round(frame) > round(last):
% j/ h( @ K9 A last = frame & ~# Q6 b4 k6 A$ r7 ~5 ^: i+ y1 x
else: ( q/ D, ?! }5 |1 G: B
continue
( {$ E7 l% S" _0 i. b( `; v' y image = imageio.imread(filename) 9 E- h/ M) C- ], p, p& R0 d
writer.append_data(image) " }) ~; H% d7 A2 Z) I! m- c1 x& U
image = imageio.imread(filename)
* ~# i, l: d+ k4 a# p writer.append_data(image)
* v* t& O" d j2 D+ }$ X + F q- A! s b
import IPython 9 U# o* Y3 m c
if IPython.version_info > (6,2,0,''): " W1 Y- p9 a9 G, M* T
display.Image(filename=anim_file)
# m- [5 [# I( V" q 参考:https://www.tensorflow.org/tutorials/generative/dcgan & i) ~ N) H5 d/ o
———————————————— # L* a) v1 V7 [* Y- d. I; t$ P- W- h7 v% ?
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 ! T- U" R. [/ R0 q
原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111 ! f7 t$ m5 `6 y! u& k; ?- S4 f9 p
( y' y% A4 c, [/ f' \
! l0 I5 r, U/ ?; w. r
zan