在线时间 1630 小时 最后登录 2024-1-29 注册时间 2017-5-16 听众数 82 收听数 1 能力 120 分 体力 563304 点 威望 12 点 阅读权限 255 积分 174214 相册 1 日志 0 记录 0 帖子 5313 主题 5273 精华 3 分享 0 好友 163
TA的每日心情 开心 2021-8-11 17:59
签到天数: 17 天
[LV.4]偶尔看看III
网络挑战赛参赛者
网络挑战赛参赛者
自我介绍 本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。
群组 : 2018美赛大象算法课程
群组 : 2018美赛护航培训课程
群组 : 2019年 数学中国站长建
群组 : 2019年数据分析师课程
群组 : 2018年大象老师国赛优
0 j" R% @1 U! M2 x 深度卷积生成对抗网络DCGAN——生成手写数字图片 . f0 r3 C3 _- ^) ]" Y- a$ l, J- T
前言 - t9 g0 D# n/ K
本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。 " t; T6 t) F5 w6 F& q) g
- v( {8 K8 R Z8 b0 b7 Y9 ~ ; |% r8 @3 P/ @( @/ Q% K
本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下: 8 g6 y4 u& V7 D+ i6 u% a5 t, G
1 P& u" R* |; u9 k Q: x" v! y Z; @) G
# 用于生成 GIF 图片 0 _" o: F; k P0 Y( [
pip install -q imageio
4 B# R3 S: B& V 目录
" S; N9 F1 I' P T1 R8 K. X: N3 Q
/ R6 b) T5 Y' J. e4 s 前言
; o) n/ S/ @% l: `' h' } / L0 p9 S% \* z$ P# \: @, c
" m4 F, R5 {& F6 T
一、什么是生成对抗网络? ( U1 F4 \' g$ d! m6 ~
( B, a1 c. @# D# k% z: w
% r" N% w2 G; H, P# l
二、加载数据集 % I6 e' K$ A# |- H0 W( z
3 O8 ?2 \* Q2 ]) L3 `& w `
: D! L2 B% Z& v" b' D2 w4 u5 u 三、创建模型
1 Y4 \% c# P3 R2 o. n! Z
! b/ A; U! ?! E0 \' |8 _# V* }* }
6 g m; I g) d 3.1 生成器
( C( S- e# g1 v. ~3 W : ~2 m8 S; u8 J7 }9 G
' |: g* K! W& }; ]- D" N0 r) R
3.1 判别器
4 Y4 I0 I0 b2 e! F( F5 Q
6 o. N$ M% V, [: X : J5 c2 [) `8 b) E& `
四、定义损失函数和优化器 ' `* a& b8 m6 @8 a6 S, b
7 @3 T4 ]$ A7 |3 h$ {7 \
/ `7 @3 q3 {! S1 _) e 4.1 生成器的损失和优化器
- h# ^- G& ^. V ) T; c2 r' u- K. ]- ~ D
6 b) t. F) ^$ b7 { 4.2 判别器的损失和优化器
1 `9 J3 P7 O& y: V5 G6 Y9 ] T
4 s8 o9 F% s! s) {4 ~, a4 ]4 U: }
+ E/ |( Z8 Q# p6 ^' m! A/ t 五、训练模型
* d6 A- m$ ? l6 T8 F
4 i& B' o8 Z$ h7 o) Q" r& T( N+ i# O + U1 k! ?' D8 E7 X( t; W
5.1 保存检查点
% ^! D: r+ I. ^2 b4 b& ` ) P9 d8 v6 i1 B! n8 k$ c1 A0 i* l
6 W0 [1 I8 W8 \4 |# T V- A8 H 5.2 定义训练过程 A" ]+ E: @% p5 l: U! m6 Y x( N( { f
) g }$ z6 z9 \" N 8 |0 {1 T6 I _1 w* R! [4 S+ c4 L
5.3 训练模型 1 q, z% K j9 T v9 t/ H9 l( I
3 H* Q) j5 y$ a+ v/ }
! j' d% a3 Q/ v ?6 y
六、评估模型
8 Z; g) O1 z0 B. ]* I
0 l1 R. |7 d( p3 Z* m" G" Q7 I
0 ?$ p& V9 W! j 一、什么是生成对抗网络?
; v* \% Y- S! S4 D 生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。 * i- K; x4 _" P) _
& @* f( a! T. Z. A% k8 c: ~+ C ( ]' c8 _$ z! p: M* F* I
生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
7 \! P; v* O7 ~) |, W! l" B& R; ^# b
6 B) y- |! M" t" R3 D* V. R
# b$ r4 u ~7 _ 判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
+ X- h9 M g- r3 _+ F Z2 S , z- t4 n8 C1 I" c
4 U% @( ^1 p8 h- q 训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
6 H9 w( u2 x8 Y( _. m4 J ! G N: ?4 u# v
0 O0 N: h8 x0 O* `8 _) |- J 当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
; ^5 g! c' b- i2 L6 @ ; F* {& \/ B; k6 f
2 Q/ Q$ R9 w5 U$ o 本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。 . |6 o- w; Z# K" b$ `1 O
3 K4 Y; n+ ?# ]8 p! z# Z* E" [ & o' V; v; e3 s C! [* M
二、加载数据集
& O k* Q# G$ n 使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
) F; S1 j9 G( j& f) p* D
; ?5 b- E4 d3 V* c9 X8 W2 L# O
6 P: R V6 U# m (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
, ^% e; \2 Z2 z8 R* o- ~ 0 W" E( n4 k; f8 O
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') - h6 A9 ~. x& e* l1 f
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内 - ]9 L& L; w/ Q5 {
1 G4 N4 i* n3 h1 X2 i9 } BUFFER_SIZE = 60000
2 _% F: Z6 f$ S9 s& k$ f: u7 @/ o BATCH_SIZE = 256
q0 c7 q5 v" P2 a. N" A! \1 e$ J. m
% r. F1 A, L# v- o4 B # 批量化和打乱数据
. ?& C& m5 f9 m/ h train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
+ P# A( B5 T! {% w 三、创建模型
3 Z: D+ ?0 N' S5 R$ w) w- P7 R 主要创建两个模型,一个是生成器,另一个是判别器。
" o/ d* X- v, K4 B: [9 ]* T
% e& u+ \. J6 t5 T; z4 W O: n& ^
& h+ B& C) k- s E. y 3.1 生成器
1 x( s& }/ e9 S% h$ v( E* S 生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。 V3 K9 D9 s! f/ m* \
! N) l. p! n* I8 q: H3 b( X! h9 x: _
" }6 j7 _$ j1 k7 n* n4 r
然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。 0 E3 n4 g. v; Y0 d$ o( S5 w
! o* C/ J. H0 K1 K, i
" |. Y5 c- ]4 I. T 后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。 ; B( f3 P$ Q! M6 U; d$ A3 R# p
" D6 O) S+ w2 L% M4 u* d. z; l
# E' U5 N) O) E6 `- V8 I: w def make_generator_model(): - L% l) \1 s; G
model = tf.keras.Sequential()
) V) u) W, _8 W& G1 _' M, [- N model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) 0 \2 y! E0 u% A4 L$ M8 ^$ `
model.add(layers.BatchNormalization()) 0 }' l0 m; g) m) q* m' |5 h3 ]
model.add(layers.LeakyReLU())
4 i* U# n/ }: L. r* p4 G ; {. o5 m/ d$ r. M' j
model.add(layers.Reshape((7, 7, 256)))
$ {6 Z. A4 [& E. q# b8 ^5 z! k5 X assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
4 ?# ^ ?* [( q0 G+ }
0 b2 {- p& p% x9 k/ C model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) 8 {! E( |; a2 M0 E
assert model.output_shape == (None, 7, 7, 128) / w& x# b# ?/ T$ F
model.add(layers.BatchNormalization()) : Z( ]0 O$ m- \0 Q- q
model.add(layers.LeakyReLU())
3 B2 r# N# ]' _0 c6 g 0 F4 l' K. e. j
model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) / g9 \3 J5 a1 C4 D) F' X
assert model.output_shape == (None, 14, 14, 64)
9 D, c: U `2 B5 s! _* O model.add(layers.BatchNormalization())
4 n8 `, G# k: F! @6 P model.add(layers.LeakyReLU())
$ S% I; i" {3 o+ ^5 b5 m- ? - ^2 h; D% V% ^$ Z7 O q
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
; z3 x* J4 y2 U0 v8 o* }* f assert model.output_shape == (None, 28, 28, 1) 4 k0 j- h# J% g
9 r2 F2 g9 F* W& A( c+ Y
return model
# R/ ?: ^$ v7 r3 L8 B 用tf.keras.utils.plot_model( ),看一下模型结构 : ^5 y, g( }$ a `7 i4 k: q
( Q H& L! m- {3 T4 C, B/ x
! V, D4 ^6 J: ^/ I
" I2 I4 Q d1 C
9 i0 }3 N. v" L r [( N: F; I # s. Z1 M* g) G4 i6 K
用summary(),看一下模型结构和参数 8 k' T0 C% g I, {$ `
) i8 @1 ?# b4 c l7 Z( r
. Y+ J3 ^. M! B, F/ O- `& P, H% g) }
! b9 m. ^4 K" y _* {% F
0 S/ {) B, R9 [; X* W' R # m+ S3 w \0 |2 h+ D' S
. |1 S. A$ t+ p 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
% h! ]/ \ y1 ~ }
+ [7 G A6 i k6 R
( @3 O2 w4 ~% _* M generator = make_generator_model()
. u" H" p% ^3 {4 b
- X9 R" S0 S# c, U8 v! I$ l' e noise = tf.random.normal([1, 100]) # R6 e G! e, q! G# a; Y: C
generated_image = generator(noise, training=False) 8 _. G. [* \3 q8 l% [6 C; z4 x
4 i2 E, n" l8 H; _, q$ U; n; B plt.imshow(generated_image[0, :, :, 0], cmap='gray') . A+ k: I2 L" y: N+ w
- j4 N7 K. m( W& g O
. t5 v$ j1 J2 B- F; e( w5 Z
* p0 n" P) d$ h" ~! c " r, A6 ?7 K" G" x( k8 {
3.1 判别器 ) g" @9 _% _1 }" a- ?4 J
判别器是基于 CNN卷积神经网络 的图片分类器。
/ s2 i O% _7 D, c6 {' }" F
z" Z' G9 l- [9 ^ " |2 t( v1 n1 ]3 ?7 A) X
def make_discriminator_model(): # z E7 B( y; Y. y% T" m& c
model = tf.keras.Sequential() 2 \2 I$ D8 O. U' T4 K$ L
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', + A1 K9 ~2 L, z/ h8 j; ?
input_shape=[28, 28, 1]))
K7 Z+ v! r7 m3 c model.add(layers.LeakyReLU())
& q) @+ M: l% a `( G model.add(layers.Dropout(0.3)) $ V( N2 k9 M+ ~% X8 }
1 ^6 q1 B0 l. K; ]
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) * J$ \+ u# t! d3 Q2 h. P
model.add(layers.LeakyReLU()) - a1 h/ l7 g$ j/ |- @2 O
model.add(layers.Dropout(0.3)) # P- W. `$ e+ r5 t9 K* _
6 F8 u! t5 }4 O3 ~1 ^& h5 [6 g, Q
model.add(layers.Flatten())
3 u- z7 s1 e2 B( l1 u# `6 i2 ~6 k model.add(layers.Dense(1))
% Y8 i$ z5 G4 c) H; D
4 m Y( ~& K5 @ @% ]/ F s return model 6 T4 o" A) A- P
用tf.keras.utils.plot_model( ),看一下模型结构 0 j) x) l& i; p" U
_% z# {8 S3 o% q% [6 s
' P7 n7 h# ?/ {& M8 P, x
: u- s* ~: w# o/ M8 c) z
8 d8 Q2 t1 K3 A4 m- y7 R/ Z% Z0 K% Q / B* J: J' d1 Q2 j+ J* s7 E% j5 W4 k
7 I7 c# V- u4 `9 J+ @! i- u 用summary(),看一下模型结构和参数 " R; c$ i% B* K- H4 ?, S; @
1 t! E; u) F9 Z$ v# t/ J3 I
2 G6 C, d- S$ c$ [$ J: O ' k- W) m7 {" ^* B8 e( G* h
$ i" n" s! I7 o) |9 \! l. ^
" q* U/ ]2 J$ h# \/ B
7 F# {7 E% e: _5 _: n* H- q 四、定义损失函数和优化器 % M% C- M( m K: }
由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。 % ^7 {; P* J7 c, V, E2 }
5 G0 g& j( ~9 F0 I8 x4 R
5 s6 q h$ [4 a8 H2 ^. a% j& b( J
首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
$ q" S2 U. @1 ^0 V/ V' T9 V + M: i3 ?- N/ i4 H" }
5 \8 o( x) i, q! ?% M4 g9 Y! p6 z
# 该方法返回计算交叉熵损失的辅助函数
2 r3 F, G. j6 B$ V# Y' D cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
( z4 C- \4 P# w. x6 s 4.1 生成器的损失和优化器 , d1 J" a1 \; |/ t( s4 f
1)生成器损失
* v: b7 t/ H& @% u
1 Q2 |" u. A3 T. l . d+ \, t9 J) G
生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
1 m: ?+ X6 ]* b) o+ [. `: c# n6 A5 w
/ n8 J' \/ _7 \1 ?9 \0 ~3 w
8 h1 L- E% a% e2 Q# X) N$ i 这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
' n8 D9 X6 D& V! \8 g- V$ x % r- |$ R% S2 N2 n! i
: v. U3 f& v5 g! V: t1 B- ?5 V7 m* G def generator_loss(fake_output):
) n9 q% `% v2 j0 k0 i return cross_entropy(tf.ones_like(fake_output), fake_output) $ c) p7 M) K/ K6 J$ B
2)生成器优化器
4 f2 o' A6 P7 o4 D% K2 x
; N( n9 G$ x' ]1 v: R* {( \
9 W. n; h k M# G4 H6 N generator_optimizer = tf.keras.optimizers.Adam(1e-4)
# x- D4 d3 X0 @: A4 @ 4.2 判别器的损失和优化器 ' m' ?, j3 S# v' @
1)判别器损失 : Q& `9 N3 |$ m( _8 e# F& G% r! u
* r& M+ G( {* w; A8 a
; K6 m' i. N: j5 I5 M1 N% c 判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
! w, C( X- [$ a* N1 [
! `& h! J% o' O7 b: B' h $ B+ X6 B+ e+ X& z* ]
def discriminator_loss(real_output, fake_output):
7 n* ?7 s9 L- N) a/ R real_loss = cross_entropy(tf.ones_like(real_output), real_output) , y; P1 G' Z3 n5 H% x7 h! a/ N
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) , O0 W; \7 ?. h7 v; Z% O* y
total_loss = real_loss + fake_loss ' q/ Q# c# `3 b; I% F
return total_loss 9 ]0 `* r. ?) k% H; p U
2)判别器优化器 6 |" a% ~, q9 Z3 d6 u3 W
6 p/ X. K/ y( A4 D0 H1 J% W
4 `" D2 B8 ?, ^6 w- m H- E# ]
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
+ o( R% ^0 E! C" C3 X 五、训练模型 : q: k& m# s0 ], V9 j q; {
5.1 保存检查点 1 |0 f& x5 @2 \' _- n! B
保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
2 X8 T$ ]& c0 ?$ f e! r" I" Y8 c) d
U+ d4 F2 X2 Y5 m2 F/ c- i L
checkpoint_dir = './training_checkpoints' # e3 v; }3 h8 x, E
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") $ w1 S( x; b+ G* o L( z
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, 5 I3 r* G1 ^0 Y% \- C) A( X( c
discriminator_optimizer=discriminator_optimizer, ) o7 x; l0 h N l; x+ p. c
generator=generator,
O. z" E9 p$ @9 ~6 k6 r# f discriminator=discriminator)
( A1 ^" q" z9 n% M! [. V: W 5.2 定义训练过程 0 b; B; j: O% W8 ~7 C" m8 `
EPOCHS = 50 2 d9 `; Y( ^# i1 o: z
noise_dim = 100
, Z/ d' ]4 D g8 `# Y! E4 j2 w num_examples_to_generate = 16
! D% `3 C% L' ?" _
+ \* n7 {) [7 c3 P ; w h9 {. ~- g2 E7 ?5 e( t; p
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) g! L! R2 S! A
seed = tf.random.normal([num_examples_to_generate, noise_dim]) 2 E4 W' m; r, p( B" K9 g
训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。 1 o! m4 a3 F! j; S' H: \
# c* t5 `) \" {# O6 B& e
1 p1 Y- B0 K) Z8 K% o' l 判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
/ F' H0 |3 y8 m* G
2 ]4 {+ x" k# d' \$ Z
+ l! u1 y' M, ]6 z4 t4 H5 ^ 两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
5 t( A6 }$ z$ y; U
! h0 t' A6 T% |. b: Y4 l! i # z# N1 ]' {4 P9 {* h( `
# 注意 `tf.function` 的使用
+ B/ w: U4 c8 ~: r; l% Z # 该注解使函数被“编译” 7 r; k: c0 A- _; C
@tf.function
3 {9 u3 b1 _" L: A def train_step(images):
( f4 I7 M* ~/ X2 i noise = tf.random.normal([BATCH_SIZE, noise_dim])
1 a! _0 d- c$ W+ y 5 t" P7 M' @- S6 w% b
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: I1 _3 t. c, S0 R
generated_images = generator(noise, training=True)
8 X% M3 y; o- R- U; @; x- w$ L ( ]2 \& e E7 \# a) s8 Y
real_output = discriminator(images, training=True) ; l4 d3 Y+ i: k6 L$ X ?; b- \
fake_output = discriminator(generated_images, training=True) $ r; y5 r; ]# ?
" z- p4 J7 p& { gen_loss = generator_loss(fake_output)
* @( d0 `$ W4 K: w disc_loss = discriminator_loss(real_output, fake_output)
" y) P- _% v8 `0 Y8 o n+ \& W- o 8 B' I& ]1 A! \7 y! X3 g7 {) C- Z+ x. Y
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) - {( a% r# h N$ J6 @* b
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
- W$ G1 W+ b* y/ X8 L
O: p/ `# C' e1 I generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
( F; d; F0 Q1 Q5 q discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
3 T3 C/ O7 ^: D3 D7 s, z, s ' I% Q9 J' z' l* @8 {8 e+ U# j
def train(dataset, epochs): 9 R+ K$ h) G# w. _3 l7 d
for epoch in range(epochs): * e( @' L3 J6 A: L
start = time.time()
' Q1 r, g* c m1 c* O9 D3 s
( P" t) o4 y6 [+ q; M1 S' @) J n for image_batch in dataset: * g& I+ I% x' [" X; O9 R# {
train_step(image_batch)
" z* _/ W3 F$ `7 }9 S; [9 ` 0 K' T! j- E1 N, C
# 继续进行时为 GIF 生成图像
+ F, P( e0 M; }/ B% ^ display.clear_output(wait=True)
4 }) f/ u$ ~ N0 P0 y) A8 g generate_and_save_images(generator, ( j* _) |. n( ~1 z% @3 u% B/ u) W& c
epoch + 1,
) ^9 `' ` j/ t6 M7 \' F seed) _; x6 F! m# H/ u# l. j* j
9 R& M5 o: \; p! H
# 每 15 个 epoch 保存一次模型 4 G! L! F; E- y0 Q
if (epoch + 1) % 15 == 0:
8 A" @. k( ~9 s7 |& G# W checkpoint.save(file_prefix = checkpoint_prefix)
2 N v7 J+ a2 D+ W1 i% q 6 y R+ [1 B# d B
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) 6 y# [; g" H/ ]5 s) T& l$ E# S) E) j
( O$ K8 [8 A B/ D # 最后一个 epoch 结束后生成图片
! N9 {& n8 D/ Z display.clear_output(wait=True)
. q! H! \) i+ D( q- b6 N- O generate_and_save_images(generator,
2 c( v4 R9 { k% W epochs,
3 e Y3 }0 [# i; R: Q; M seed) ; {* _6 x; X R5 k& q; W
. Z; r7 {- X2 l2 ` # 生成与保存图片 ' L0 J5 ~0 o2 ~
def generate_and_save_images(model, epoch, test_input):
- U$ i' u2 \9 t # 注意 training` 设定为 False * D' H7 \ v# |; h: g5 o# R; X/ r
# 因此,所有层都在推理模式下运行(batchnorm)。 " D2 L! K& s. k3 P" m" G
predictions = model(test_input, training=False)
& ?! K4 N; ^- v # y( H. \& o8 t" C& R8 A$ `
fig = plt.figure(figsize=(4,4))
4 T5 U5 k! W7 `- W/ ~. M. ?+ d9 M3 D
5 S. U5 G; o/ e$ A' S for i in range(predictions.shape[0]): ! ?$ e7 z/ P n/ ]8 r6 [( D
plt.subplot(4, 4, i+1)
5 B/ d" U* O; l plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
4 `+ W! }8 i; {# c: n2 ` plt.axis('off') . _2 h) o# \, @$ U; q* Y
* \; T) X) o& M K* s2 ?: ?: ~# g5 H
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
; v7 m. \+ e* E5 ~ plt.show()
% @/ J+ Q# H* X+ K 5.3 训练模型
$ O4 {5 {$ |3 { 调用上面定义的train()函数,来同时训练生成器和判别器。 1 |. x+ X- u2 j _. b3 e1 F3 p
0 \2 s' x: F5 T, j
$ u! U5 W* c4 n( n; t- n 注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。 & e0 M% k7 q/ _* h7 E% l2 M
0 z* R' y# t6 [4 h2 ~
$ X% F7 R: m, A! e, @, D- W
%%time 5 J1 t/ A' D. l8 h9 T
train(train_dataset, EPOCHS)
& f+ T% l) i% K- C/ m% Y 在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
+ V' B( v, {3 Q# f) x, q7 q! I
9 y- B% q3 ~5 _8 T9 ~: s
" ^6 g9 K) W! e' Q H4 h% u 训练了15轮的效果: # @1 |. g h* q) z* T
- C4 n) C- X$ `* h % s% M0 i. S5 l" E' {1 t* u
: c. n6 E! y4 ?
7 L- F8 z* c1 ?7 H
0 _* \, Y3 K: } h5 @
; [) R5 y3 ]' _, p5 s8 b9 z 训练了30轮的效果:
! Y( T; _& K2 s4 E! H9 |# d ( H5 P# x) C0 ]6 F
0 m( \, j y' O; e4 x . u7 d/ a! N7 B) ~- U7 s2 _/ w
; I) O' T: Y4 o9 n2 v, ?4 D
- ?$ r+ S+ S8 C1 E! m1 q4 V
1 F" ]! B/ z J- f# ~# k 训练过程: ' }6 r- H1 a v
; s7 w! I5 W6 B( p9 W6 q( T
* z) U; ] h7 v3 p
4 K% o5 Q8 ^; `2 }( X: b8 K+ g
' V* i9 r$ [$ x0 G$ ~- w; m
5 `8 Z( R: W$ O
3 w1 C3 d& Y# Z0 N* J6 J' S 恢复最新的检查点
, W$ l; N/ q6 G1 O/ T* x, l7 b
1 L) j1 r& u+ z, x1 U
. G7 n" M6 V7 n) ^7 F) U" C! u checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) 8 T6 ~2 f, H) I5 d, _: B2 V
六、评估模型 8 L ~+ _; C8 }- n) q
这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
2 T: y$ Z" @; D Y4 a6 z, C7 t# h: [2 B
6 l! T# ?& n7 j: V- W: w6 n4 _ 8 ?4 I- Y$ u# e/ f3 y& Y/ i
# 使用 epoch 数生成单张图片 ' g5 z% m. V0 a1 ^9 V
def display_image(epoch_no):
( S# e) X0 Y" O/ M( @; w! q! W; [# b return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
& W- T ~' Z1 w3 c# J8 n
! z% n$ T U. M" |* ?+ H7 m! o display_image(EPOCHS)
`$ m- P4 g" w3 b! n anim_file = 'dcgan.gif' . p. k! z+ h: f" M4 T/ u+ \4 O% T
2 P) ^( L( y: Z with imageio.get_writer(anim_file, mode='I') as writer: ! q1 A8 h3 r5 I6 M
filenames = glob.glob('image*.png')
! Y2 S) z5 P$ U) y filenames = sorted(filenames)
3 [/ b& K. q" i% u$ f last = -1 ) N7 W! F+ u/ p! F U; x: r# l, L. ^) t
for i,filename in enumerate(filenames): ( X j$ L. c Y" W/ E. k1 o' N* b
frame = 2*(i**0.5) ' k( O3 j1 H; k
if round(frame) > round(last):
/ o1 g) m4 `; Z$ @2 U last = frame
! ?7 }- A2 L6 C3 L0 n* d8 F$ ^ else:
& ^* o) Z/ l( r1 ^2 N continue , c& o7 ?& T# j9 O6 I7 |
image = imageio.imread(filename) % n! t- n3 _. m+ e
writer.append_data(image)
0 q! u+ `& R3 F4 q" a( O3 ]6 R* Y7 } image = imageio.imread(filename) 6 x$ ^* S( T7 ~% S8 d
writer.append_data(image)
! n6 L( n3 c- j. v8 N0 C& F6 w; W 0 Q6 Q) V4 ^/ \* P. Q8 @
import IPython
) `. H8 y' P5 p& A \1 @ if IPython.version_info > (6,2,0,''):
5 z. }4 ~. c- `# z" n% D display.Image(filename=anim_file)
3 Y. A7 ^) L7 H4 [ 3 \4 T$ n0 u3 u/ K. s/ \
9 q. G- y1 G1 Z n- V% d9 P
* f& I; h( P. k$ s 6 p2 J& [* v0 f8 l+ e& S0 D9 U
完整代码: / y) {+ A4 |( w2 j; f
8 L6 [* I# X% {2 I
9 _7 \1 R; ^6 ]: {! N/ R! D import tensorflow as tf
4 l5 A$ ?) @( r g import glob 3 Z/ p1 Z, \& r
import imageio
; `4 @/ [: L- @# Z' M& K import matplotlib.pyplot as plt
) d: |; @) @1 c& W( r import numpy as np C% z) y+ Q! p7 Q: t+ @# C# w
import os
; s M9 B g* D" P import PIL 6 n% ^ N) j6 j2 K s
from tensorflow.keras import layers
5 D- t4 e( x1 }. @ import time H6 K# ~) p4 o; F% ?
# f2 m- s" H# C9 w6 s9 \; K from IPython import display + y$ _9 i0 V* j2 m* P2 A- B
* Y; { p3 M, M5 N" { (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
6 M, J1 B0 u l( s+ v* p' y ( F# ?: Q; l: M
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
7 `" W+ Q; x5 m, J% Q4 C! n train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
+ h& ~ O7 h& q$ B4 l L# Z6 a: {+ S/ I4 U0 x7 u/ T
BUFFER_SIZE = 60000 9 B7 `. S7 S) t4 Z3 l) A+ [6 A" b8 m
BATCH_SIZE = 256
: q W9 w( L9 ~7 v+ n" H7 E! t J6 K- w! @( ~; L+ G$ o9 i8 Q/ D ~, a! a; f
# 批量化和打乱数据
& K3 }, J6 Z, S) x. k( k$ N1 a train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 8 d7 M" S3 X! k
$ ^% p: Q, e1 D0 A% e/ }
# 创建模型--生成器
! O9 h4 R J9 A, P6 }2 _; g def make_generator_model():
. b+ n5 y& [ F model = tf.keras.Sequential()
. ~: g) O/ I7 |' x/ F model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) 8 J. i3 A; H7 ~% G) m; d- N; |. Z
model.add(layers.BatchNormalization()) 2 F* n% |1 U( f4 \5 D2 g- g
model.add(layers.LeakyReLU())
/ y# L/ y& g' R0 i9 r+ p6 _- Q - ~( y( z1 [* m: f( E
model.add(layers.Reshape((7, 7, 256))) ( @+ Q. n+ v0 g" j% C) J' ]4 C
assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
- x( g8 J% c3 [1 N- h6 D 5 M: _. u6 t) B. y% _
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) : Y* r! p6 z6 U8 e8 t
assert model.output_shape == (None, 7, 7, 128)
1 k1 w1 C, @; U5 H/ O0 `, p model.add(layers.BatchNormalization()) / U: @' Q0 M) O' t7 e% K& H
model.add(layers.LeakyReLU())
7 G' Z, s& @/ M; T& B* Z
1 n( Q9 ~$ z0 {! U3 ~% ] model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) : C+ `8 k8 `# _* i
assert model.output_shape == (None, 14, 14, 64)
. m; B( N' N% a* C; Q8 J model.add(layers.BatchNormalization()) 3 }2 _% d o$ W' }6 ?
model.add(layers.LeakyReLU()) 9 [/ ~) L- x% V' @
8 [2 x8 J' [/ T, w2 H% T; m model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) % n: `' V& T3 ^$ D0 j8 d1 n
assert model.output_shape == (None, 28, 28, 1) R0 y. n% g" @0 ^" N6 E8 V
7 E1 q4 P( W! A return model % N" d6 F" d- S
. ^ d# a- e- k) k# \+ ~3 F # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
9 p* A+ @* D" c generator = make_generator_model() O: y l0 [- ]+ Z% }; l; I
, j, d: E; j/ w, R" C
noise = tf.random.normal([1, 100]) " Q$ l5 b; l' y& U" E4 N0 z
generated_image = generator(noise, training=False) : Z* }& C0 m. ?* {
7 }2 C! Q8 ~1 L+ l/ ` plt.imshow(generated_image[0, :, :, 0], cmap='gray') , I4 j7 M$ C4 i& h5 o% G+ B
tf.keras.utils.plot_model(generator) % ~1 L' m5 J% ~6 M$ x
) n& z8 u @7 C& \3 f # 判别器
# T" R, W+ n8 J9 y" v) P7 s def make_discriminator_model():
) ~9 d" F: r/ m model = tf.keras.Sequential() ' O- U. }5 J% \, ^
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', ( h& T$ h1 b0 a$ J3 i
input_shape=[28, 28, 1]))
& S1 V& k& b& ?+ Y# n- F model.add(layers.LeakyReLU()) 1 E6 g# h2 i! U* {
model.add(layers.Dropout(0.3)) " O( h7 Y5 K4 l" K" h
g2 x7 N5 ]! u0 D2 y# Q model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
/ o6 I0 d, f2 O model.add(layers.LeakyReLU()) 3 H& w+ D/ D% m% o) c$ k0 ?
model.add(layers.Dropout(0.3))
, J) H3 a2 X- y5 Y2 h: S, |2 v: T * ^) e1 T+ S$ d1 g4 L% d
model.add(layers.Flatten()) & C+ {) i* q8 T1 w0 w
model.add(layers.Dense(1))
^1 P$ }; n% H5 b- b# z
@- V, g% g% P return model
" k8 i. X2 q( b4 |6 Z- B 7 ^# Z5 v- f; p. y ?$ l& X/ @
# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
" ]+ o6 J3 D- [$ Q5 M: _ discriminator = make_discriminator_model()
* V% z- n: j6 P( N% X. v7 E decision = discriminator(generated_image)
9 x6 A# }" m0 W2 x" \8 Q2 Z M! d% S print (decision) ! X0 C/ [; o8 X9 @% O9 d
' z1 b Z2 M3 Z+ q/ x c& o
# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。 , T5 `. m7 q6 i. M7 I% Y
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) $ _9 c) z d' }) Y- M4 t$ n6 i" ?- k
2 k4 J' e3 _& y$ p0 g
# 生成器的损失和优化器
# |0 y. y( [) w2 r" ], [ j- e( R L def generator_loss(fake_output): 0 \# X! i. b5 i- [; e
return cross_entropy(tf.ones_like(fake_output), fake_output) - g! g1 i1 B8 | `* z- i
generator_optimizer = tf.keras.optimizers.Adam(1e-4) - D) h* x& o: \, j
0 R; T# J& ]/ z # 判别器的损失和优化器 0 c7 V0 g, h: Z; e9 n
def discriminator_loss(real_output, fake_output): 8 v7 z( f! u# P2 h6 C0 J
real_loss = cross_entropy(tf.ones_like(real_output), real_output) 5 i* M0 |/ q. N9 e
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
. p2 K9 o2 G2 K' v( b total_loss = real_loss + fake_loss 3 Y& a9 s+ T- K8 C; k* u6 g' p! R: I$ p
return total_loss
1 u: c5 r) g% H! h discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) - i8 g6 T; z9 ?+ l: K5 r/ Y
# F6 |) n& X! ?! f9 U # 保存检查点 ! P v3 `( ?6 w5 k) g, @
checkpoint_dir = './training_checkpoints' - h6 z* Z" y4 d! E& {
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
2 p: a# k( r7 V+ Q8 [ [ checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, 6 H9 y9 M& Q* a. g2 ]- R
discriminator_optimizer=discriminator_optimizer, & c# j5 X/ X" F0 q: Y$ z$ _2 n
generator=generator,
9 }6 S' \2 ?$ w: N/ E discriminator=discriminator) * Q+ c' V1 R# ~
* k' i8 |3 e% q7 I' E # 定义训练过程 ' I9 V* T& l/ _) Q" D$ w6 o& |
EPOCHS = 50
2 j; H4 d8 _$ N! S! k, B noise_dim = 100
- c5 h. H/ |3 K: c# j9 m' Q num_examples_to_generate = 16
& [( {( p0 X4 w4 ^0 E7 R/ s: _! Y : m6 c! H6 k, x8 K R
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
! N0 S; ^/ X( J. a7 U seed = tf.random.normal([num_examples_to_generate, noise_dim])
" g$ k% E _% I8 R6 G s
/ _8 [" _" G; r0 k( e+ \0 C # 注意 `tf.function` 的使用
: p7 Y7 X" P* z* \' k7 }* u # 该注解使函数被“编译” $ x; C6 T* g) ^6 M" \' S- H/ J! S
@tf.function 0 u: x" c0 W( h
def train_step(images): 6 N, k$ V) @! r' B9 t `4 [0 z) p
noise = tf.random.normal([BATCH_SIZE, noise_dim]) - N8 \) C7 ~4 v1 q0 c( X
' l# A7 o: @9 M; t" f# E- r) u with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: % u/ p/ n. n* T0 G9 R, \' u' T$ z
generated_images = generator(noise, training=True) , z# w/ w! B3 U0 k5 ^7 E9 v% U( F
# O/ [6 D6 u$ K real_output = discriminator(images, training=True)
3 v! E- t& [+ ]- A) { fake_output = discriminator(generated_images, training=True) # r* t9 t$ |$ ], q8 ^( ]! q, T' ]
3 l# K0 q; q- }* w4 f0 u5 a5 w
gen_loss = generator_loss(fake_output)
0 v0 v6 E4 B/ ?7 f) F; ?9 j8 q disc_loss = discriminator_loss(real_output, fake_output) 8 j. a8 m! _6 S# r2 }% } f
G. P: y' h1 U: D* @ gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) & r& T; M( y5 {$ G K1 g
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
! }7 Z* A- D$ I4 L
* p9 M5 V! G, H! I9 m4 @ generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) + O) q* u" F$ R( C, A2 X( y
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) # X8 ~" S2 h/ \1 | t- L3 M
2 D, f3 `: \4 O. v* S; y
def train(dataset, epochs): 8 g. f; j$ k( Q5 {0 ~- b
for epoch in range(epochs): : w- E* l. s/ Y& p" p
start = time.time()
4 e% [, v7 t4 n z
) O) T. b6 j! l# h/ c for image_batch in dataset:
+ | \2 B) v6 N; _8 N# r train_step(image_batch) 5 O( z2 U2 b8 l0 |( q3 [
# d/ }- m- k1 Q9 d) Z2 t. ~ # 继续进行时为 GIF 生成图像
3 t2 d- g5 P ^ M, T N' ~( ^ display.clear_output(wait=True) - T) B; ~* O( X5 B
generate_and_save_images(generator,
) n- g" L r( h# x. M& c8 f epoch + 1,
5 J* A: D4 j0 x2 O; s* h: K1 |3 J seed) # V4 g0 w g; m- w+ K2 U
- G$ v0 A1 A7 k; n; D4 t # 每 15 个 epoch 保存一次模型
5 L) k3 r0 w4 N0 e1 |( N" r if (epoch + 1) % 15 == 0: 5 X' r2 b' C1 q2 |/ a- k/ `5 v( U
checkpoint.save(file_prefix = checkpoint_prefix) : H8 o5 `2 s0 s1 d* Z0 D
W5 ^" g m$ E2 W8 p! G print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
6 `# ], J/ \# V2 R: b - Z: w% K9 V# J4 i3 c# ^
# 最后一个 epoch 结束后生成图片
& L0 b2 y2 N8 l. L1 t% Y( K display.clear_output(wait=True) ) \6 R! c) p; F S6 v# G8 k9 Z, M
generate_and_save_images(generator,
3 Z! O1 ]. B" D epochs, + K" k. q1 |4 m7 o1 F
seed) C g/ I+ G: d S2 |
& @4 Y, J2 m3 c! S& e0 t' G; T # 生成与保存图片 / M' F8 M+ Z: X- Q8 d
def generate_and_save_images(model, epoch, test_input): 0 y E9 }, }8 z0 K4 A5 i8 P
# 注意 training` 设定为 False / z# z2 A2 v9 h+ ?
# 因此,所有层都在推理模式下运行(batchnorm)。 $ u" w. n. e; z! A. w* E* W* U
predictions = model(test_input, training=False) . @6 a4 O- W3 J3 H* U
% I9 c+ |* O3 f ?8 ?
fig = plt.figure(figsize=(4,4)) % F2 I5 y$ O/ m
$ O" ^0 I3 T7 a k8 @. Q# k8 v* p
for i in range(predictions.shape[0]):
) X* @; j; j8 n% z9 H3 J0 m3 N- } plt.subplot(4, 4, i+1)
& u7 V/ J5 D B; N) @ plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') # E) s W; u/ [- w
plt.axis('off') # u5 d Y; \% v0 f) _! W, j
' t+ w: N5 \6 e+ o! w- X7 } plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) ) _$ X; T: d/ \6 ]6 r- _$ r
plt.show()
% O8 P. U; x! w; R9 i$ ~ M# A$ q N , @$ c! X! `; G7 b
# 训练模型 1 H9 G( G8 ^8 z% G# M" f
train(train_dataset, EPOCHS) 0 s0 k- h0 M" `# q6 T! Y
2 s% ]% ~" e- ]) H4 p$ O, R( S: w. E$ D # 恢复最新的检查点
5 I! m+ Y Z: H( x checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
Z' q2 g( A" q* N [ ' @7 A' f: x) l# W8 Y
# 评估模型 2 l5 r: s4 o+ @, m& `6 n
# 使用 epoch 数生成单张图片
W1 I6 ^( I6 I; D. B" t; T, r' R" ~ def display_image(epoch_no):
- v. p5 w. {3 z$ c return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
V& t6 p' ~+ B- z+ ]. Q4 k( ^3 Q1 ]
- r/ ] k$ d6 c5 z display_image(EPOCHS)
9 u! o% c5 M2 m
4 M& g8 T2 ^( G' w- E9 ?( z- D anim_file = 'dcgan.gif' % y& p$ h2 N* u
( d" U5 g9 b$ f% W4 I/ k4 ?
with imageio.get_writer(anim_file, mode='I') as writer:
( a2 V" Q, d2 \; m# k* g' [ filenames = glob.glob('image*.png')
/ D0 E5 {+ Z* N5 v$ ^ filenames = sorted(filenames)
! t& Q7 g. u" {6 K- N5 b last = -1 : \1 Y3 J, c# B3 u9 d8 ^% F
for i,filename in enumerate(filenames):
$ w2 }/ S6 D& { frame = 2*(i**0.5) 0 i M/ n5 h- ]; D. m9 Q
if round(frame) > round(last):
4 b* v. S; P7 ^' u$ u8 { last = frame ) z2 t7 w$ I4 I; J
else:
+ s5 Y$ `! X9 u8 w' K" v7 E. Z continue 1 i- \: ~% w% f; N4 W9 g
image = imageio.imread(filename)
, c B& h2 ^; i2 z) ?- n4 ?! Y& K writer.append_data(image) " M/ `9 t8 z" @# K" p
image = imageio.imread(filename)
* x! ]- l1 p! U! k. F$ q- A writer.append_data(image)
# @( f3 o" b$ m 5 `& s/ ]0 i6 E$ R( C9 P* `' G* d
import IPython ) S2 b, m' ?' ~: p
if IPython.version_info > (6,2,0,''):
. E1 E' B* D+ ?, q' D9 M display.Image(filename=anim_file) 9 J! ?. |, @# B& N- g
参考:https://www.tensorflow.org/tutorials/generative/dcgan & G: H8 q; I0 d5 u
———————————————— 3 M* ^- M3 W3 g+ r2 j* P
版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
+ }& G$ E. m- g 原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
- T% j. |% s8 {% x0 {
& R, |- n5 A' P% q
& b/ L) G4 j- D! S2 M* ?
zan