# Imports used below (these come from the setup cells earlier in the tutorial)
import glob
import os
import time

import imageio
import matplotlib.pyplot as plt
import PIL
import tensorflow as tf
from tensorflow.keras import layers
from IPython import display

# Build the model -- the generator
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: the batch size is unconstrained

    # Upsample 7x7 -> 7x7 (stride 1), narrowing to 128 channels
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Upsample 7x7 -> 14x14
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Upsample 14x14 -> 28x28; tanh keeps pixel values in [-1, 1]
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model
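The shape assertions above follow from how Conv2DTranspose with padding='same' scales each spatial dimension by its stride: 7x7 stays 7x7 at stride 1, then doubles twice to reach 28x28. As a quick check, here is a minimal sketch (not part of the original tutorial) that prints the layer-by-layer output shapes:

# Illustrative only: inspect the 100 -> 7*7*256 -> 7x7x128 -> 14x14x64 -> 28x28x1
# upsampling path layer by layer.
make_generator_model().summary()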

# Use the (as yet untrained) generator to create an image; at this point it is just random noise.
generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
tf.keras.utils.plot_model(generator)  # requires pydot and graphviz to render the diagram
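The generator's final tanh activation keeps pixel values in [-1, 1]. plt.imshow autoscales, so the call above works as-is, but if you prefer the conventional [0, 255] grayscale range (the same rescaling generate_and_save_images uses later), a small illustrative sketch:

# Illustrative only: map the tanh output from [-1, 1] to [0, 255] before display.
plt.imshow(generated_image[0, :, :, 0] * 127.5 + 127.5, cmap='gray')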

# The discriminator -- a CNN-based image classifier
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                            input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))  # single logit: positive = real, negative = fake

    return model

# Use the (as yet untrained) discriminator to classify the image as real or fake.
# The model will be trained to output positive values for real images and negative values for fakes.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print(decision)
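Because the final Dense(1) layer has no activation, decision is a raw logit rather than a probability. A hedged, illustrative way to read it as P(real) is to pass it through a sigmoid; for an untrained discriminator it should hover near 0.5:

# Illustrative only: interpret the logit as a probability of the image being real.
print(tf.sigmoid(decision))  # roughly 0.5 before any training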

# First, define a helper for computing cross-entropy loss; it is shared by both models.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
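from_logits=True tells the loss to apply a sigmoid internally before computing binary cross-entropy, which is numerically more stable than applying the sigmoid yourself. A small check (illustrative, not from the tutorial) that the two formulations agree:

# Illustrative only: BCE on raw logits matches BCE on sigmoid(logits).
logits = tf.constant([[2.0], [-1.0]])
labels = tf.ones_like(logits)
plain_bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
print(cross_entropy(labels, logits).numpy())          # ~0.72
print(plain_bce(labels, tf.sigmoid(logits)).numpy())  # same value, up to rounding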

# Generator loss and optimizer: the generator succeeds when the discriminator
# scores its fakes as real, i.e. against an all-ones target.
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Discriminator loss and optimizer: compare predictions on real images against
# all-ones targets and predictions on fakes against all-zeros targets.
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
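As a sanity check on the semantics (illustrative numbers, not from the tutorial): a confident, correct discriminator (large positive logits on real images, large negative logits on fakes) drives this loss toward 0, while an undecided one (logits at 0) yields 2*ln(2), about 1.386:

# Illustrative only: loss for a near-perfect vs. an undecided discriminator.
print(discriminator_loss(tf.constant([[10.0]]), tf.constant([[-10.0]])).numpy())  # ~0.0
print(discriminator_loss(tf.constant([[0.0]]), tf.constant([[0.0]])).numpy())     # ~1.386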

# Save checkpoints
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
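tf.train.Checkpoint tracks both networks and both optimizer states, so an interrupted run can resume exactly where it stopped. Once training has saved at least once, a hedged way to confirm what was written:

# Illustrative only: prints the newest checkpoint path,
# e.g. './training_checkpoints/ckpt-1', or None if nothing has been saved yet.
print(tf.train.latest_checkpoint(checkpoint_dir))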

# Define the training loop
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# We will reuse this seed over time (it makes progress easier to visualize in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

# Notice the use of `tf.function`: this annotation causes the function
# to be "compiled" into a TensorFlow graph.
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    # Two tapes: each records the forward pass for one network so that each
    # optimizer can take gradients with respect to its own variables.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
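Note that train_step references BATCH_SIZE, which is defined in the tutorial's earlier data-loading step (256 in the original). If you are running this excerpt standalone, a hedged smoke test that defines it locally and traces one step on a random batch:

# Illustrative only: BATCH_SIZE normally comes from the data pipeline; it is set
# here just so this excerpt runs on its own.
BATCH_SIZE = 256
train_step(tf.random.normal([BATCH_SIZE, 28, 28, 1]))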

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator,
                                 epoch + 1,
                                 seed)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Generate images one last time after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epochs,
                             seed)

# Generate and save images
def generate_and_save_images(model, epoch, test_input):
    # Note that `training` is set to False, so all layers
    # run in inference mode (this matters for batchnorm).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Rescale from the generator's [-1, 1] tanh range back to [0, 255]
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

# Train the model (train_dataset is built in the data-loading section earlier in the tutorial)
train(train_dataset, EPOCHS)

# Restore the latest checkpoint
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# Evaluate the model
# Display a single image using its epoch number
def display_image(epoch_no):
    return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))

display_image(EPOCHS)

anim_file = 'dcgan.gif'
# Stitch the saved per-epoch images into an animated GIF
with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = glob.glob('image*.png')
    filenames = sorted(filenames)
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2 * (i ** 0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    # Append the final frame once more so the GIF lingers on the last image
    image = imageio.imread(filename)
    writer.append_data(image)
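The 2*sqrt(i) schedule keeps early frames densely and later frames sparsely, so the GIF shows the rapid early progress without dragging through many near-identical late images. A small illustrative loop (not from the tutorial) that prints which indices survive for 50 saved files:

# Illustrative only: reproduce the frame-selection schedule on indices 0..49.
kept, last = [], -1
for i in range(50):
    frame = 2 * (i ** 0.5)
    if round(frame) > round(last):
        kept.append(i)
        last = frame
print(kept)  # dense at the start, increasingly sparse later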

import IPython
if IPython.version_info > (6, 2, 0, ''):
    display.Image(filename=anim_file)
Reference: https://www.tensorflow.org/tutorials/generative/dcgan
Copyright notice: this is an original article by the CSDN blogger 一颗小树x, released under the CC 4.0 BY-SA license; please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_41204464/article/details/118279111