QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 5682|回复: 0
打印 上一主题 下一主题

深度卷积生成对抗网络DCGAN——生成手写数字图片

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2021-6-28 11:54 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta

    0 j" R% @1 U! M2 x深度卷积生成对抗网络DCGAN——生成手写数字图片. f0 r3 C3 _- ^) ]" Y- a$ l, J- T
    前言- t9 g0 D# n/ K
    本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。" t; T6 t) F5 w6 F& q) g

    - v( {8 K8 R  Z8 b0 b7 Y9 ~
    ; |% r8 @3 P/ @( @/ Q% K
    本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:8 g6 y4 u& V7 D+ i6 u% a5 t, G

    1 P& u" R* |; u9 k
      Q: x" v! y  Z; @) G
    # 用于生成 GIF 图片0 _" o: F; k  P0 Y( [
    pip install -q imageio
    4 B# R3 S: B& V目录
    " S; N9 F1 I' P  T1 R8 K. X: N3 Q

    / R6 b) T5 Y' J. e4 s前言
    ; o) n/ S/ @% l: `' h' }/ L0 p9 S% \* z$ P# \: @, c
    " m4 F, R5 {& F6 T
    一、什么是生成对抗网络?( U1 F4 \' g$ d! m6 ~
    ( B, a1 c. @# D# k% z: w
    % r" N% w2 G; H, P# l
    二、加载数据集% I6 e' K$ A# |- H0 W( z

    3 O8 ?2 \* Q2 ]) L3 `& w  `

    : D! L2 B% Z& v" b' D2 w4 u5 u三、创建模型
    1 Y4 \% c# P3 R2 o. n! Z
    ! b/ A; U! ?! E0 \' |8 _# V* }* }

    6 g  m; I  g) d3.1 生成器
    ( C( S- e# g1 v. ~3 W: ~2 m8 S; u8 J7 }9 G
    ' |: g* K! W& }; ]- D" N0 r) R
    3.1 判别器
    4 Y4 I0 I0 b2 e! F( F5 Q
    6 o. N$ M% V, [: X
    : J5 c2 [) `8 b) E& `
    四、定义损失函数和优化器' `* a& b8 m6 @8 a6 S, b
    7 @3 T4 ]$ A7 |3 h$ {7 \

    / `7 @3 q3 {! S1 _) e4.1 生成器的损失和优化器
    - h# ^- G& ^. V) T; c2 r' u- K. ]- ~  D

    6 b) t. F) ^$ b7 {4.2 判别器的损失和优化器
    1 `9 J3 P7 O& y: V5 G6 Y9 ]  T
    4 s8 o9 F% s! s) {4 ~, a4 ]4 U: }

    + E/ |( Z8 Q# p6 ^' m! A/ t五、训练模型
    * d6 A- m$ ?  l6 T8 F
    4 i& B' o8 Z$ h7 o) Q" r& T( N+ i# O
    + U1 k! ?' D8 E7 X( t; W
    5.1 保存检查点
    % ^! D: r+ I. ^2 b4 b& `) P9 d8 v6 i1 B! n8 k$ c1 A0 i* l

    6 W0 [1 I8 W8 \4 |# T  V- A8 H5.2 定义训练过程  A" ]+ E: @% p5 l: U! m6 Y  x( N( {  f

    ) g  }$ z6 z9 \" N
    8 |0 {1 T6 I  _1 w* R! [4 S+ c4 L
    5.3 训练模型1 q, z% K  j9 T  v9 t/ H9 l( I
    3 H* Q) j5 y$ a+ v/ }
    ! j' d% a3 Q/ v  ?6 y
    六、评估模型
    8 Z; g) O1 z0 B. ]* I
    0 l1 R. |7 d( p3 Z* m" G" Q7 I

    0 ?$ p& V9 W! j一、什么是生成对抗网络?
    ; v* \% Y- S! S4 D生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。* i- K; x4 _" P) _

    & @* f( a! T. Z. A% k8 c: ~+ C
    ( ]' c8 _$ z! p: M* F* I
    生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
    7 \! P; v* O7 ~) |, W! l" B& R; ^# b
    6 B) y- |! M" t" R3 D* V. R

    # b$ r4 u  ~7 _判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
    + X- h9 M  g- r3 _+ F  Z2 S, z- t4 n8 C1 I" c

    4 U% @( ^1 p8 h- q训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
    6 H9 w( u2 x8 Y( _. m4 J! G  N: ?4 u# v

    0 O0 N: h8 x0 O* `8 _) |- J当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
    ; ^5 g! c' b- i2 L6 @; F* {& \/ B; k6 f

    2 Q/ Q$ R9 w5 U$ o本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。. |6 o- w; Z# K" b$ `1 O

    3 K4 Y; n+ ?# ]8 p! z# Z* E" [
    & o' V; v; e3 s  C! [* M
    二、加载数据集
    & O  k* Q# G$ n使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。
    ) F; S1 j9 G( j& f) p* D
    ; ?5 b- E4 d3 V* c9 X8 W2 L# O

    6 P: R  V6 U# m(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
    , ^% e; \2 Z2 z8 R* o- ~ 0 W" E( n4 k; f8 O
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')- h6 A9 ~. x& e* l1 f
    train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内- ]9 L& L; w/ Q5 {

    1 G4 N4 i* n3 h1 X2 i9 }BUFFER_SIZE = 60000
    2 _% F: Z6 f$ S9 s& k$ f: u7 @/ oBATCH_SIZE = 256
      q0 c7 q5 v" P2 a. N" A! \1 e$ J. m
    % r. F1 A, L# v- o4 B# 批量化和打乱数据
    . ?& C& m5 f9 m/ htrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    + P# A( B5 T! {% w三、创建模型
    3 Z: D+ ?0 N' S5 R$ w) w- P7 R主要创建两个模型,一个是生成器,另一个是判别器。
    " o/ d* X- v, K4 B: [9 ]* T
    % e& u+ \. J6 t5 T; z4 W  O: n& ^

    & h+ B& C) k- s  E. y3.1 生成器
    1 x( s& }/ e9 S% h$ v( E* S生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。  V3 K9 D9 s! f/ m* \
    ! N) l. p! n* I8 q: H3 b( X! h9 x: _
    " }6 j7 _$ j1 k7 n* n4 r
    然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。0 E3 n4 g. v; Y0 d$ o( S5 w
    ! o* C/ J. H0 K1 K, i

    " |. Y5 c- ]4 I. T后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。; B( f3 P$ Q! M6 U; d$ A3 R# p
    " D6 O) S+ w2 L% M4 u* d. z; l

    # E' U5 N) O) E6 `- V8 I: wdef make_generator_model():- L% l) \1 s; G
        model = tf.keras.Sequential()
    ) V) u) W, _8 W& G1 _' M, [- N    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))0 \2 y! E0 u% A4 L$ M8 ^$ `
        model.add(layers.BatchNormalization())0 }' l0 m; g) m) q* m' |5 h3 ]
        model.add(layers.LeakyReLU())
    4 i* U# n/ }: L. r* p4 G ; {. o5 m/ d$ r. M' j
        model.add(layers.Reshape((7, 7, 256)))
    $ {6 Z. A4 [& E. q# b8 ^5 z! k5 X    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
    4 ?# ^  ?* [( q0 G+ }
    0 b2 {- p& p% x9 k/ C    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))8 {! E( |; a2 M0 E
        assert model.output_shape == (None, 7, 7, 128)/ w& x# b# ?/ T$ F
        model.add(layers.BatchNormalization()): Z( ]0 O$ m- \0 Q- q
        model.add(layers.LeakyReLU())
    3 B2 r# N# ]' _0 c6 g 0 F4 l' K. e. j
        model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))/ g9 \3 J5 a1 C4 D) F' X
        assert model.output_shape == (None, 14, 14, 64)
    9 D, c: U  `2 B5 s! _* O    model.add(layers.BatchNormalization())
    4 n8 `, G# k: F! @6 P    model.add(layers.LeakyReLU())
    $ S% I; i" {3 o+ ^5 b5 m- ? - ^2 h; D% V% ^$ Z7 O  q
        model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    ; z3 x* J4 y2 U0 v8 o* }* f    assert model.output_shape == (None, 28, 28, 1)4 k0 j- h# J% g
    9 r2 F2 g9 F* W& A( c+ Y
        return model
    # R/ ?: ^$ v7 r3 L8 B用tf.keras.utils.plot_model( ),看一下模型结构: ^5 y, g( }$ a  `7 i4 k: q
    ( Q  H& L! m- {3 T4 C, B/ x
    ! V, D4 ^6 J: ^/ I
    " I2 I4 Q  d1 C

    9 i0 }3 N. v" L  r  [( N: F; I
    # s. Z1 M* g) G4 i6 K
    用summary(),看一下模型结构和参数8 k' T0 C% g  I, {$ `
    ) i8 @1 ?# b4 c  l7 Z( r
    . Y+ J3 ^. M! B, F/ O- `& P, H% g) }

    ! b9 m. ^4 K" y  _* {% F

    0 S/ {) B, R9 [; X* W' R# m+ S3 w  \0 |2 h+ D' S

    . |1 S. A$ t+ p使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
    % h! ]/ \  y1 ~  }
    + [7 G  A6 i  k6 R

    ( @3 O2 w4 ~% _* Mgenerator = make_generator_model()
    . u" H" p% ^3 {4 b
    - X9 R" S0 S# c, U8 v! I$ l' enoise = tf.random.normal([1, 100])# R6 e  G! e, q! G# a; Y: C
    generated_image = generator(noise, training=False)8 _. G. [* \3 q8 l% [6 C; z4 x

    4 i2 E, n" l8 H; _, q$ U; n; Bplt.imshow(generated_image[0, :, :, 0], cmap='gray'). A+ k: I2 L" y: N+ w
    - j4 N7 K. m( W& g  O

    . t5 v$ j1 J2 B- F; e( w5 Z
    * p0 n" P) d$ h" ~! c
    " r, A6 ?7 K" G" x( k8 {
    3.1 判别器) g" @9 _% _1 }" a- ?4 J
    判别器是基于 CNN卷积神经网络 的图片分类器。
    / s2 i  O% _7 D, c6 {' }" F
      z" Z' G9 l- [9 ^
    " |2 t( v1 n1 ]3 ?7 A) X
    def make_discriminator_model():# z  E7 B( y; Y. y% T" m& c
        model = tf.keras.Sequential()2 \2 I$ D8 O. U' T4 K$ L
        model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',+ A1 K9 ~2 L, z/ h8 j; ?
                                         input_shape=[28, 28, 1]))
      K7 Z+ v! r7 m3 c    model.add(layers.LeakyReLU())
    & q) @+ M: l% a  `( G    model.add(layers.Dropout(0.3))$ V( N2 k9 M+ ~% X8 }
    1 ^6 q1 B0 l. K; ]
        model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))* J$ \+ u# t! d3 Q2 h. P
        model.add(layers.LeakyReLU())- a1 h/ l7 g$ j/ |- @2 O
        model.add(layers.Dropout(0.3))# P- W. `$ e+ r5 t9 K* _
    6 F8 u! t5 }4 O3 ~1 ^& h5 [6 g, Q
        model.add(layers.Flatten())
    3 u- z7 s1 e2 B( l1 u# `6 i2 ~6 k    model.add(layers.Dense(1))
    % Y8 i$ z5 G4 c) H; D
    4 m  Y( ~& K5 @  @% ]/ F  s    return model6 T4 o" A) A- P
    用tf.keras.utils.plot_model( ),看一下模型结构0 j) x) l& i; p" U
      _% z# {8 S3 o% q% [6 s
    ' P7 n7 h# ?/ {& M8 P, x
    : u- s* ~: w# o/ M8 c) z

    8 d8 Q2 t1 K3 A4 m- y7 R/ Z% Z0 K% Q/ B* J: J' d1 Q2 j+ J* s7 E% j5 W4 k

    7 I7 c# V- u4 `9 J+ @! i- u用summary(),看一下模型结构和参数" R; c$ i% B* K- H4 ?, S; @
    1 t! E; u) F9 Z$ v# t/ J3 I

    2 G6 C, d- S$ c$ [$ J: O' k- W) m7 {" ^* B8 e( G* h
    $ i" n" s! I7 o) |9 \! l. ^

    " q* U/ ]2 J$ h# \/ B

    7 F# {7 E% e: _5 _: n* H- q四、定义损失函数和优化器% M% C- M( m  K: }
    由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。% ^7 {; P* J7 c, V, E2 }
    5 G0 g& j( ~9 F0 I8 x4 R
    5 s6 q  h$ [4 a8 H2 ^. a% j& b( J
    首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
    $ q" S2 U. @1 ^0 V/ V' T9 V+ M: i3 ?- N/ i4 H" }
    5 \8 o( x) i, q! ?% M4 g9 Y! p6 z
    # 该方法返回计算交叉熵损失的辅助函数
    2 r3 F, G. j6 B$ V# Y' Dcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    ( z4 C- \4 P# w. x6 s4.1 生成器的损失和优化器, d1 J" a1 \; |/ t( s4 f
    1)生成器损失
    * v: b7 t/ H& @% u
    1 Q2 |" u. A3 T. l
    . d+ \, t9 J) G
    生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
    1 m: ?+ X6 ]* b) o+ [. `: c# n6 A5 w
    / n8 J' \/ _7 \1 ?9 \0 ~3 w

    8 h1 L- E% a% e2 Q# X) N$ i这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
    ' n8 D9 X6 D& V! \8 g- V$ x% r- |$ R% S2 N2 n! i

    : v. U3 f& v5 g! V: t1 B- ?5 V7 m* Gdef generator_loss(fake_output):
    ) n9 q% `% v2 j0 k0 i    return cross_entropy(tf.ones_like(fake_output), fake_output)$ c) p7 M) K/ K6 J$ B
    2)生成器优化器
    4 f2 o' A6 P7 o4 D% K2 x
    ; N( n9 G$ x' ]1 v: R* {( \

    9 W. n; h  k  M# G4 H6 Ngenerator_optimizer = tf.keras.optimizers.Adam(1e-4)
    # x- D4 d3 X0 @: A4 @4.2 判别器的损失和优化器' m' ?, j3 S# v' @
    1)判别器损失: Q& `9 N3 |$ m( _8 e# F& G% r! u

    * r& M+ G( {* w; A8 a

    ; K6 m' i. N: j5 I5 M1 N% c判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
    ! w, C( X- [$ a* N1 [
    ! `& h! J% o' O7 b: B' h
    $ B+ X6 B+ e+ X& z* ]
    def discriminator_loss(real_output, fake_output):
    7 n* ?7 s9 L- N) a/ R    real_loss = cross_entropy(tf.ones_like(real_output), real_output), y; P1 G' Z3 n5 H% x7 h! a/ N
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output), O0 W; \7 ?. h7 v; Z% O* y
        total_loss = real_loss + fake_loss' q/ Q# c# `3 b; I% F
        return total_loss9 ]0 `* r. ?) k% H; p  U
    2)判别器优化器6 |" a% ~, q9 Z3 d6 u3 W
    6 p/ X. K/ y( A4 D0 H1 J% W
    4 `" D2 B8 ?, ^6 w- m  H- E# ]
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
    + o( R% ^0 E! C" C3 X五、训练模型: q: k& m# s0 ], V9 j  q; {
    5.1 保存检查点1 |0 f& x5 @2 \' _- n! B
    保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
    2 X8 T$ ]& c0 ?$ f  e! r" I" Y8 c) d
      U+ d4 F2 X2 Y5 m2 F/ c- i  L
    checkpoint_dir = './training_checkpoints'# e3 v; }3 h8 x, E
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")$ w1 S( x; b+ G* o  L( z
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,5 I3 r* G1 ^0 Y% \- C) A( X( c
                                     discriminator_optimizer=discriminator_optimizer,) o7 x; l0 h  N  l; x+ p. c
                                     generator=generator,
      O. z" E9 p$ @9 ~6 k6 r# f                                 discriminator=discriminator)
    ( A1 ^" q" z9 n% M! [. V: W5.2 定义训练过程0 b; B; j: O% W8 ~7 C" m8 `
    EPOCHS = 502 d9 `; Y( ^# i1 o: z
    noise_dim = 100
    , Z/ d' ]4 D  g8 `# Y! E4 j2 wnum_examples_to_generate = 16
    ! D% `3 C% L' ?" _
    + \* n7 {) [7 c3 P ; w  h9 {. ~- g2 E7 ?5 e( t; p
    # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)  g! L! R2 S! A
    seed = tf.random.normal([num_examples_to_generate, noise_dim])2 E4 W' m; r, p( B" K9 g
    训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。1 o! m4 a3 F! j; S' H: \
    # c* t5 `) \" {# O6 B& e

    1 p1 Y- B0 K) Z8 K% o' l判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
    / F' H0 |3 y8 m* G
    2 ]4 {+ x" k# d' \$ Z

    + l! u1 y' M, ]6 z4 t4 H5 ^两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
    5 t( A6 }$ z$ y; U
    ! h0 t' A6 T% |. b: Y4 l! i
    # z# N1 ]' {4 P9 {* h( `
    # 注意 `tf.function` 的使用
    + B/ w: U4 c8 ~: r; l% Z# 该注解使函数被“编译”7 r; k: c0 A- _; C
    @tf.function
    3 {9 u3 b1 _" L: Adef train_step(images):
    ( f4 I7 M* ~/ X2 i    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    1 a! _0 d- c$ W+ y 5 t" P7 M' @- S6 w% b
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:  I1 _3 t. c, S0 R
          generated_images = generator(noise, training=True)
    8 X% M3 y; o- R- U; @; x- w$ L ( ]2 \& e  E7 \# a) s8 Y
          real_output = discriminator(images, training=True); l4 d3 Y+ i: k6 L$ X  ?; b- \
          fake_output = discriminator(generated_images, training=True)$ r; y5 r; ]# ?

    " z- p4 J7 p& {      gen_loss = generator_loss(fake_output)
    * @( d0 `$ W4 K: w      disc_loss = discriminator_loss(real_output, fake_output)
    " y) P- _% v8 `0 Y8 o  n+ \& W- o 8 B' I& ]1 A! \7 y! X3 g7 {) C- Z+ x. Y
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)- {( a% r# h  N$ J6 @* b
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    - W$ G1 W+ b* y/ X8 L
      O: p/ `# C' e1 I    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    ( F; d; F0 Q1 Q5 q    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    3 T3 C/ O7 ^: D3 D7 s, z, s ' I% Q9 J' z' l* @8 {8 e+ U# j
    def train(dataset, epochs):9 R+ K$ h) G# w. _3 l7 d
      for epoch in range(epochs):* e( @' L3 J6 A: L
        start = time.time()
    ' Q1 r, g* c  m1 c* O9 D3 s
    ( P" t) o4 y6 [+ q; M1 S' @) J  n    for image_batch in dataset:* g& I+ I% x' [" X; O9 R# {
          train_step(image_batch)
    " z* _/ W3 F$ `7 }9 S; [9 ` 0 K' T! j- E1 N, C
        # 继续进行时为 GIF 生成图像
    + F, P( e0 M; }/ B% ^    display.clear_output(wait=True)
    4 }) f/ u$ ~  N0 P0 y) A8 g    generate_and_save_images(generator,( j* _) |. n( ~1 z% @3 u% B/ u) W& c
                                 epoch + 1,
    ) ^9 `' `  j/ t6 M7 \' F                             seed)  _; x6 F! m# H/ u# l. j* j
    9 R& M5 o: \; p! H
        # 每 15 个 epoch 保存一次模型4 G! L! F; E- y0 Q
        if (epoch + 1) % 15 == 0:
    8 A" @. k( ~9 s7 |& G# W      checkpoint.save(file_prefix = checkpoint_prefix)
    2 N  v7 J+ a2 D+ W1 i% q 6 y  R+ [1 B# d  B
        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))6 y# [; g" H/ ]5 s) T& l$ E# S) E) j

    ( O$ K8 [8 A  B/ D  # 最后一个 epoch 结束后生成图片
    ! N9 {& n8 D/ Z  display.clear_output(wait=True)
    . q! H! \) i+ D( q- b6 N- O  generate_and_save_images(generator,
    2 c( v4 R9 {  k% W                           epochs,
    3 e  Y3 }0 [# i; R: Q; M                           seed); {* _6 x; X  R5 k& q; W

    . Z; r7 {- X2 l2 `# 生成与保存图片' L0 J5 ~0 o2 ~
    def generate_and_save_images(model, epoch, test_input):
    - U$ i' u2 \9 t  # 注意 training` 设定为 False* D' H7 \  v# |; h: g5 o# R; X/ r
      # 因此,所有层都在推理模式下运行(batchnorm)。" D2 L! K& s. k3 P" m" G
      predictions = model(test_input, training=False)
    & ?! K4 N; ^- v # y( H. \& o8 t" C& R8 A$ `
      fig = plt.figure(figsize=(4,4))
    4 T5 U5 k! W7 `- W/ ~. M. ?+ d9 M3 D
    5 S. U5 G; o/ e$ A' S  for i in range(predictions.shape[0]):! ?$ e7 z/ P  n/ ]8 r6 [( D
          plt.subplot(4, 4, i+1)
    5 B/ d" U* O; l      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    4 `+ W! }8 i; {# c: n2 `      plt.axis('off'). _2 h) o# \, @$ U; q* Y
    * \; T) X) o& M  K* s2 ?: ?: ~# g5 H
      plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    ; v7 m. \+ e* E5 ~  plt.show()
    % @/ J+ Q# H* X+ K5.3 训练模型
    $ O4 {5 {$ |3 {调用上面定义的train()函数,来同时训练生成器和判别器。1 |. x+ X- u2 j  _. b3 e1 F3 p
    0 \2 s' x: F5 T, j

    $ u! U5 W* c4 n( n; t- n注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。& e0 M% k7 q/ _* h7 E% l2 M
    0 z* R' y# t6 [4 h2 ~
    $ X% F7 R: m, A! e, @, D- W
    %%time5 J1 t/ A' D. l8 h9 T
    train(train_dataset, EPOCHS)
    & f+ T% l) i% K- C/ m% Y在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
    + V' B( v, {3 Q# f) x, q7 q! I
    9 y- B% q3 ~5 _8 T9 ~: s

    " ^6 g9 K) W! e' Q  H4 h% u训练了15轮的效果:# @1 |. g  h* q) z* T

    - C4 n) C- X$ `* h
    % s% M0 i. S5 l" E' {1 t* u
    : c. n6 E! y4 ?
    7 L- F8 z* c1 ?7 H

    0 _* \, Y3 K: }  h5 @

    ; [) R5 y3 ]' _, p5 s8 b9 z训练了30轮的效果:
    ! Y( T; _& K2 s4 E! H9 |# d( H5 P# x) C0 ]6 F

    0 m( \, j  y' O; e4 x. u7 d/ a! N7 B) ~- U7 s2 _/ w
    ; I) O' T: Y4 o9 n2 v, ?4 D

    - ?$ r+ S+ S8 C1 E! m1 q4 V

    1 F" ]! B/ z  J- f# ~# k训练过程:' }6 r- H1 a  v
    ; s7 w! I5 W6 B( p9 W6 q( T
    * z) U; ]  h7 v3 p
    4 K% o5 Q8 ^; `2 }( X: b8 K+ g

    ' V* i9 r$ [$ x0 G$ ~- w; m
    5 `8 Z( R: W$ O

    3 w1 C3 d& Y# Z0 N* J6 J' S恢复最新的检查点
    , W$ l; N/ q6 G1 O/ T* x, l7 b
    1 L) j1 r& u+ z, x1 U

    . G7 n" M6 V7 n) ^7 F) U" C! ucheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))8 T6 ~2 f, H) I5 d, _: B2 V
    六、评估模型8 L  ~+ _; C8 }- n) q
    这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
    2 T: y$ Z" @; D  Y4 a6 z, C7 t# h: [2 B
    6 l! T# ?& n7 j: V- W: w6 n4 _
    8 ?4 I- Y$ u# e/ f3 y& Y/ i
    # 使用 epoch 数生成单张图片' g5 z% m. V0 a1 ^9 V
    def display_image(epoch_no):
    ( S# e) X0 Y" O/ M( @; w! q! W; [# b  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
    & W- T  ~' Z1 w3 c# J8 n
    ! z% n$ T  U. M" |* ?+ H7 m! odisplay_image(EPOCHS)
      `$ m- P4 g" w3 b! nanim_file = 'dcgan.gif'. p. k! z+ h: f" M4 T/ u+ \4 O% T

    2 P) ^( L( y: Zwith imageio.get_writer(anim_file, mode='I') as writer:! q1 A8 h3 r5 I6 M
      filenames = glob.glob('image*.png')
    ! Y2 S) z5 P$ U) y  filenames = sorted(filenames)
    3 [/ b& K. q" i% u$ f  last = -1) N7 W! F+ u/ p! F  U; x: r# l, L. ^) t
      for i,filename in enumerate(filenames):( X  j$ L. c  Y" W/ E. k1 o' N* b
        frame = 2*(i**0.5)' k( O3 j1 H; k
        if round(frame) > round(last):
    / o1 g) m4 `; Z$ @2 U      last = frame
    ! ?7 }- A2 L6 C3 L0 n* d8 F$ ^    else:
    & ^* o) Z/ l( r1 ^2 N      continue, c& o7 ?& T# j9 O6 I7 |
        image = imageio.imread(filename)% n! t- n3 _. m+ e
        writer.append_data(image)
    0 q! u+ `& R3 F4 q" a( O3 ]6 R* Y7 }  image = imageio.imread(filename)6 x$ ^* S( T7 ~% S8 d
      writer.append_data(image)
    ! n6 L( n3 c- j. v8 N0 C& F6 w; W 0 Q6 Q) V4 ^/ \* P. Q8 @
    import IPython
    ) `. H8 y' P5 p& A  \1 @if IPython.version_info > (6,2,0,''):
    5 z. }4 ~. c- `# z" n% D  display.Image(filename=anim_file)
    3 Y. A7 ^) L7 H4 [3 \4 T$ n0 u3 u/ K. s/ \
    9 q. G- y1 G1 Z  n- V% d9 P

    * f& I; h( P. k$ s
    6 p2 J& [* v0 f8 l+ e& S0 D9 U
    完整代码:/ y) {+ A4 |( w2 j; f

    8 L6 [* I# X% {2 I

    9 _7 \1 R; ^6 ]: {! N/ R! Dimport tensorflow as tf
    4 l5 A$ ?) @( r  gimport glob3 Z/ p1 Z, \& r
    import imageio
    ; `4 @/ [: L- @# Z' M& Kimport matplotlib.pyplot as plt
    ) d: |; @) @1 c& W( rimport numpy as np  C% z) y+ Q! p7 Q: t+ @# C# w
    import os
    ; s  M9 B  g* D" Pimport PIL6 n% ^  N) j6 j2 K  s
    from tensorflow.keras import layers
    5 D- t4 e( x1 }. @import time  H6 K# ~) p4 o; F% ?

    # f2 m- s" H# C9 w6 s9 \; Kfrom IPython import display+ y$ _9 i0 V* j2 m* P2 A- B

    * Y; {  p3 M, M5 N" {(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
    6 M, J1 B0 u  l( s+ v* p' y ( F# ?: Q; l: M
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    7 `" W+ Q; x5 m, J% Q4 C! ntrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
    + h& ~  O7 h& q$ B4 l   L# Z6 a: {+ S/ I4 U0 x7 u/ T
    BUFFER_SIZE = 600009 B7 `. S7 S) t4 Z3 l) A+ [6 A" b8 m
    BATCH_SIZE = 256
    : q  W9 w( L9 ~7 v+ n" H7 E! t   J6 K- w! @( ~; L+ G$ o9 i8 Q/ D  ~, a! a; f
    # 批量化和打乱数据
    & K3 }, J6 Z, S) x. k( k$ N1 atrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)8 d7 M" S3 X! k
    $ ^% p: Q, e1 D0 A% e/ }
    # 创建模型--生成器
    ! O9 h4 R  J9 A, P6 }2 _; gdef make_generator_model():
    . b+ n5 y& [  F    model = tf.keras.Sequential()
    . ~: g) O/ I7 |' x/ F    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))8 J. i3 A; H7 ~% G) m; d- N; |. Z
        model.add(layers.BatchNormalization())2 F* n% |1 U( f4 \5 D2 g- g
        model.add(layers.LeakyReLU())
    / y# L/ y& g' R0 i9 r+ p6 _- Q - ~( y( z1 [* m: f( E
        model.add(layers.Reshape((7, 7, 256)))( @+ Q. n+ v0 g" j% C) J' ]4 C
        assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
    - x( g8 J% c3 [1 N- h6 D 5 M: _. u6 t) B. y% _
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)): Y* r! p6 z6 U8 e8 t
        assert model.output_shape == (None, 7, 7, 128)
    1 k1 w1 C, @; U5 H/ O0 `, p    model.add(layers.BatchNormalization())/ U: @' Q0 M) O' t7 e% K& H
        model.add(layers.LeakyReLU())
    7 G' Z, s& @/ M; T& B* Z
    1 n( Q9 ~$ z0 {! U3 ~% ]    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)): C+ `8 k8 `# _* i
        assert model.output_shape == (None, 14, 14, 64)
    . m; B( N' N% a* C; Q8 J    model.add(layers.BatchNormalization())3 }2 _% d  o$ W' }6 ?
        model.add(layers.LeakyReLU())9 [/ ~) L- x% V' @

    8 [2 x8 J' [/ T, w2 H% T; m    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))% n: `' V& T3 ^$ D0 j8 d1 n
        assert model.output_shape == (None, 28, 28, 1)  R0 y. n% g" @0 ^" N6 E8 V

    7 E1 q4 P( W! A    return model% N" d6 F" d- S

    . ^  d# a- e- k) k# \+ ~3 F# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
    9 p* A+ @* D" cgenerator = make_generator_model()  O: y  l0 [- ]+ Z% }; l; I
    , j, d: E; j/ w, R" C
    noise = tf.random.normal([1, 100])" Q$ l5 b; l' y& U" E4 N0 z
    generated_image = generator(noise, training=False): Z* }& C0 m. ?* {

    7 }2 C! Q8 ~1 L+ l/ `plt.imshow(generated_image[0, :, :, 0], cmap='gray'), I4 j7 M$ C4 i& h5 o% G+ B
    tf.keras.utils.plot_model(generator)% ~1 L' m5 J% ~6 M$ x

    ) n& z8 u  @7 C& \3 f# 判别器
    # T" R, W+ n8 J9 y" v) P7 sdef make_discriminator_model():
    ) ~9 d" F: r/ m    model = tf.keras.Sequential()' O- U. }5 J% \, ^
        model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',( h& T$ h1 b0 a$ J3 i
                                         input_shape=[28, 28, 1]))
    & S1 V& k& b& ?+ Y# n- F    model.add(layers.LeakyReLU())1 E6 g# h2 i! U* {
        model.add(layers.Dropout(0.3))" O( h7 Y5 K4 l" K" h

      g2 x7 N5 ]! u0 D2 y# Q    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    / o6 I0 d, f2 O    model.add(layers.LeakyReLU())3 H& w+ D/ D% m% o) c$ k0 ?
        model.add(layers.Dropout(0.3))
    , J) H3 a2 X- y5 Y2 h: S, |2 v: T * ^) e1 T+ S$ d1 g4 L% d
        model.add(layers.Flatten())& C+ {) i* q8 T1 w0 w
        model.add(layers.Dense(1))
      ^1 P$ }; n% H5 b- b# z
      @- V, g% g% P    return model
    " k8 i. X2 q( b4 |6 Z- B 7 ^# Z5 v- f; p. y  ?$ l& X/ @
    # 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
    " ]+ o6 J3 D- [$ Q5 M: _discriminator = make_discriminator_model()
    * V% z- n: j6 P( N% X. v7 Edecision = discriminator(generated_image)
    9 x6 A# }" m0 W2 x" \8 Q2 Z  M! d% Sprint (decision)! X0 C/ [; o8 X9 @% O9 d
    ' z1 b  Z2 M3 Z+ q/ x  c& o
    # 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。, T5 `. m7 q6 i. M7 I% Y
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)$ _9 c) z  d' }) Y- M4 t$ n6 i" ?- k
    2 k4 J' e3 _& y$ p0 g
    # 生成器的损失和优化器
    # |0 y. y( [) w2 r" ], [  j- e( R  Ldef generator_loss(fake_output):0 \# X! i. b5 i- [; e
        return cross_entropy(tf.ones_like(fake_output), fake_output)- g! g1 i1 B8 |  `* z- i
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)- D) h* x& o: \, j

    0 R; T# J& ]/ z# 判别器的损失和优化器0 c7 V0 g, h: Z; e9 n
    def discriminator_loss(real_output, fake_output):8 v7 z( f! u# P2 h6 C0 J
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)5 i* M0 |/ q. N9 e
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    . p2 K9 o2 G2 K' v( b    total_loss = real_loss + fake_loss3 Y& a9 s+ T- K8 C; k* u6 g' p! R: I$ p
        return total_loss
    1 u: c5 r) g% H! hdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)- i8 g6 T; z9 ?+ l: K5 r/ Y

    # F6 |) n& X! ?! f9 U# 保存检查点! P  v3 `( ?6 w5 k) g, @
    checkpoint_dir = './training_checkpoints'- h6 z* Z" y4 d! E& {
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    2 p: a# k( r7 V+ Q8 [  [checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,6 H9 y9 M& Q* a. g2 ]- R
                                     discriminator_optimizer=discriminator_optimizer,& c# j5 X/ X" F0 q: Y$ z$ _2 n
                                     generator=generator,
    9 }6 S' \2 ?$ w: N/ E                                 discriminator=discriminator)* Q+ c' V1 R# ~

    * k' i8 |3 e% q7 I' E# 定义训练过程' I9 V* T& l/ _) Q" D$ w6 o& |
    EPOCHS = 50
    2 j; H4 d8 _$ N! S! k, Bnoise_dim = 100
    - c5 h. H/ |3 K: c# j9 m' Qnum_examples_to_generate = 16
    & [( {( p0 X4 w4 ^0 E7 R/ s: _! Y : m6 c! H6 k, x8 K  R
    # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
    ! N0 S; ^/ X( J. a7 Useed = tf.random.normal([num_examples_to_generate, noise_dim])
    " g$ k% E  _% I8 R6 G  s
    / _8 [" _" G; r0 k( e+ \0 C# 注意 `tf.function` 的使用
    : p7 Y7 X" P* z* \' k7 }* u# 该注解使函数被“编译”$ x; C6 T* g) ^6 M" \' S- H/ J! S
    @tf.function0 u: x" c0 W( h
    def train_step(images):6 N, k$ V) @! r' B9 t  `4 [0 z) p
        noise = tf.random.normal([BATCH_SIZE, noise_dim])- N8 \) C7 ~4 v1 q0 c( X

    ' l# A7 o: @9 M; t" f# E- r) u    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:% u/ p/ n. n* T0 G9 R, \' u' T$ z
          generated_images = generator(noise, training=True), z# w/ w! B3 U0 k5 ^7 E9 v% U( F

    # O/ [6 D6 u$ K      real_output = discriminator(images, training=True)
    3 v! E- t& [+ ]- A) {      fake_output = discriminator(generated_images, training=True)# r* t9 t$ |$ ], q8 ^( ]! q, T' ]
    3 l# K0 q; q- }* w4 f0 u5 a5 w
          gen_loss = generator_loss(fake_output)
    0 v0 v6 E4 B/ ?7 f) F; ?9 j8 q      disc_loss = discriminator_loss(real_output, fake_output)8 j. a8 m! _6 S# r2 }% }  f

      G. P: y' h1 U: D* @    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)& r& T; M( y5 {$ G  K1 g
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    ! }7 Z* A- D$ I4 L
    * p9 M5 V! G, H! I9 m4 @    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))+ O) q* u" F$ R( C, A2 X( y
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))# X8 ~" S2 h/ \1 |  t- L3 M
    2 D, f3 `: \4 O. v* S; y
    def train(dataset, epochs):8 g. f; j$ k( Q5 {0 ~- b
      for epoch in range(epochs):: w- E* l. s/ Y& p" p
        start = time.time()
    4 e% [, v7 t4 n  z
    ) O) T. b6 j! l# h/ c    for image_batch in dataset:
    + |  \2 B) v6 N; _8 N# r      train_step(image_batch)5 O( z2 U2 b8 l0 |( q3 [

    # d/ }- m- k1 Q9 d) Z2 t. ~    # 继续进行时为 GIF 生成图像
    3 t2 d- g5 P  ^  M, T  N' ~( ^    display.clear_output(wait=True)- T) B; ~* O( X5 B
        generate_and_save_images(generator,
    ) n- g" L  r( h# x. M& c8 f                             epoch + 1,
    5 J* A: D4 j0 x2 O; s* h: K1 |3 J                             seed)# V4 g0 w  g; m- w+ K2 U

    - G$ v0 A1 A7 k; n; D4 t    # 每 15 个 epoch 保存一次模型
    5 L) k3 r0 w4 N0 e1 |( N" r    if (epoch + 1) % 15 == 0:5 X' r2 b' C1 q2 |/ a- k/ `5 v( U
          checkpoint.save(file_prefix = checkpoint_prefix): H8 o5 `2 s0 s1 d* Z0 D

      W5 ^" g  m$ E2 W8 p! G    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
    6 `# ], J/ \# V2 R: b - Z: w% K9 V# J4 i3 c# ^
      # 最后一个 epoch 结束后生成图片
    & L0 b2 y2 N8 l. L1 t% Y( K  display.clear_output(wait=True)) \6 R! c) p; F  S6 v# G8 k9 Z, M
      generate_and_save_images(generator,
    3 Z! O1 ]. B" D                           epochs,+ K" k. q1 |4 m7 o1 F
                               seed)  C  g/ I+ G: d  S2 |

    & @4 Y, J2 m3 c! S& e0 t' G; T# 生成与保存图片/ M' F8 M+ Z: X- Q8 d
    def generate_and_save_images(model, epoch, test_input):0 y  E9 }, }8 z0 K4 A5 i8 P
      # 注意 training` 设定为 False/ z# z2 A2 v9 h+ ?
      # 因此,所有层都在推理模式下运行(batchnorm)。$ u" w. n. e; z! A. w* E* W* U
      predictions = model(test_input, training=False). @6 a4 O- W3 J3 H* U
    % I9 c+ |* O3 f  ?8 ?
      fig = plt.figure(figsize=(4,4))% F2 I5 y$ O/ m
    $ O" ^0 I3 T7 a  k8 @. Q# k8 v* p
      for i in range(predictions.shape[0]):
    ) X* @; j; j8 n% z9 H3 J0 m3 N- }      plt.subplot(4, 4, i+1)
    & u7 V/ J5 D  B; N) @      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')# E) s  W; u/ [- w
          plt.axis('off')# u5 d  Y; \% v0 f) _! W, j

    ' t+ w: N5 \6 e+ o! w- X7 }  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))) _$ X; T: d/ \6 ]6 r- _$ r
      plt.show()
    % O8 P. U; x! w; R9 i$ ~  M# A$ q  N , @$ c! X! `; G7 b
    # 训练模型1 H9 G( G8 ^8 z% G# M" f
    train(train_dataset, EPOCHS)0 s0 k- h0 M" `# q6 T! Y

    2 s% ]% ~" e- ]) H4 p$ O, R( S: w. E$ D# 恢复最新的检查点
    5 I! m+ Y  Z: H( xcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
      Z' q2 g( A" q* N  [ ' @7 A' f: x) l# W8 Y
    # 评估模型2 l5 r: s4 o+ @, m& `6 n
    # 使用 epoch 数生成单张图片
      W1 I6 ^( I6 I; D. B" t; T, r' R" ~def display_image(epoch_no):
    - v. p5 w. {3 z$ c  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
      V& t6 p' ~+ B- z+ ]. Q4 k( ^3 Q1 ]
    - r/ ]  k$ d6 c5 zdisplay_image(EPOCHS)
    9 u! o% c5 M2 m
    4 M& g8 T2 ^( G' w- E9 ?( z- Danim_file = 'dcgan.gif'% y& p$ h2 N* u
    ( d" U5 g9 b$ f% W4 I/ k4 ?
    with imageio.get_writer(anim_file, mode='I') as writer:
    ( a2 V" Q, d2 \; m# k* g' [  filenames = glob.glob('image*.png')
    / D0 E5 {+ Z* N5 v$ ^  filenames = sorted(filenames)
    ! t& Q7 g. u" {6 K- N5 b  last = -1: \1 Y3 J, c# B3 u9 d8 ^% F
      for i,filename in enumerate(filenames):
    $ w2 }/ S6 D& {    frame = 2*(i**0.5)0 i  M/ n5 h- ]; D. m9 Q
        if round(frame) > round(last):
    4 b* v. S; P7 ^' u$ u8 {      last = frame) z2 t7 w$ I4 I; J
        else:
    + s5 Y$ `! X9 u8 w' K" v7 E. Z      continue1 i- \: ~% w% f; N4 W9 g
        image = imageio.imread(filename)
    , c  B& h2 ^; i2 z) ?- n4 ?! Y& K    writer.append_data(image)" M/ `9 t8 z" @# K" p
      image = imageio.imread(filename)
    * x! ]- l1 p! U! k. F$ q- A  writer.append_data(image)
    # @( f3 o" b$ m 5 `& s/ ]0 i6 E$ R( C9 P* `' G* d
    import IPython) S2 b, m' ?' ~: p
    if IPython.version_info > (6,2,0,''):
    . E1 E' B* D+ ?, q' D9 M  display.Image(filename=anim_file)9 J! ?. |, @# B& N- g
    参考:https://www.tensorflow.org/tutorials/generative/dcgan& G: H8 q; I0 d5 u
    ————————————————3 M* ^- M3 W3 g+ r2 j* P
    版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    + }& G$ E. m- g原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
    - T% j. |% s8 {% x0 {
    & R, |- n5 A' P% q
    & b/ L) G4 j- D! S2 M* ?
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-12 12:23 , Processed in 0.474628 second(s), 51 queries .

    回顶部