QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 5685|回复: 0
打印 上一主题 下一主题

深度卷积生成对抗网络DCGAN——生成手写数字图片

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2021-6-28 11:54 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta
    & e- ~: W0 [+ Y9 r+ {! ?2 z
    深度卷积生成对抗网络DCGAN——生成手写数字图片  [' }5 Z$ r* Y, W) l
    前言
    $ d- B% r( T' h9 z" {本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。! X! A( }; d% l

    4 e/ I/ ^( A! ?" [

    6 R, R6 f$ j0 \9 }5 J/ _, u 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
    5 v. i' E. V6 |+ e2 c( u4 y7 d- w0 p2 n0 z2 O- m+ H
    $ o( }, A% N! j$ U
    # 用于生成 GIF 图片
    $ m1 N3 g- @$ V4 [) t( \: @9 Wpip install -q imageio
    4 s+ v: k# _: G* `目录+ r9 ^( ^$ l: ~

    4 w2 J# p# k: s+ z: U& y
    , }, ?% T. I7 x+ |" u
    前言
    # ^6 f$ O. J* D) m6 R6 u8 `: B) b( U8 o' [( `2 ?* L9 t. P6 w7 d

      z' |( k; F3 M; K; b一、什么是生成对抗网络?
    & H1 t' l" D7 L
    ; R1 E) i  r3 p) X

    ) I# }2 D% N$ \0 u! R& z二、加载数据集
    & t! X) E+ N# o1 w: g; |/ L9 D: x" N# o; d5 X

      p3 y; e3 \; [* E' D三、创建模型3 m' \3 ^# E5 r( r. ]6 r. q' F
    , h" S3 W9 \& Y1 E) i! t( G' ^6 p

    ! r% `# f4 V, _. |- k3.1 生成器
    0 ]* O& U& L6 o7 q# p/ d+ N3 q8 ?* Z% L9 t2 R8 T: H
    2 c# i9 r4 M% o
    3.1 判别器1 v/ x* h7 d$ @9 x
    5 s3 U' W5 t* X' m- |2 A* ~

    ) D6 u3 a! X- n! {9 P# t. o* R四、定义损失函数和优化器
    1 Y3 {7 L) b9 I9 ~5 G1 e$ [9 x% _% O+ A* h
    9 g7 K4 W0 a2 W. |" p, s; k" X
    4.1 生成器的损失和优化器1 a, x, I1 ^5 W

    & f: s( ~$ _5 y& ^- E

    ) H9 s6 b" Q1 U: e4.2 判别器的损失和优化器8 }. N, V8 a+ l4 `

    # x6 v" ?+ P9 q' H' P( D# g9 d. @

    + p  s" m& ^) R( i1 {, K" O* i, _4 @五、训练模型9 ^+ W& b+ s/ s5 y6 H( k

    ) N; d3 `( Y* |8 g$ u, O
    - d- e7 d; \- r
    5.1 保存检查点
    7 _3 \/ D/ I, {$ d1 J5 f" j$ A* D$ I- ~" ~% v8 B
    . g4 G: R- \, z
    5.2 定义训练过程% |5 _8 |: f1 ^9 `* w: u8 O$ O7 I
    & ^* G7 X+ n! F2 h9 U& _& O/ S1 y- @/ Z0 u, ~
    / P/ E: v; L6 j& ]4 q( C( Z
    5.3 训练模型( i+ ?" K$ `2 i# d6 Z: j4 U2 E4 X
    3 K1 x2 Q$ i& V! c' `- w
    & r9 U7 ?& q( w. o8 z4 ]" c) y
    六、评估模型+ b: Y: C% F3 O$ t; N

    ; O  L2 a; ]$ I
    ' B1 n7 y2 O  u1 M4 ~6 Y
    一、什么是生成对抗网络?8 F) E* s( i6 Z4 y% k
    生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。) D' M1 {0 j0 k1 l
    " }; ^+ ~% a% j; {+ W
    1 S8 s4 a8 K) W8 m+ d6 C
    生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。
    5 e  f7 s( t. E& W& B- }3 p( i
    & `& i9 ~+ H3 p7 v# V9 g

    0 U7 }$ N. g- H. R2 S. m; x2 Q' y判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
    ( X- d, l! H- e6 g
    ! _1 q3 N9 p2 ?. s

    5 T% r% S5 e4 I$ ?% ^1 W' w训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。9 w, l9 g* k% F/ _' l3 J

    % j4 W1 V$ m4 ?" T: K
    & m9 m6 a, H- F6 T
    当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
    , S. A; t' {% P- ~. v: j8 y
    : S3 P# f) j2 E
      e/ G8 w% p/ c7 ^" `
    本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
    7 Q& q2 `$ w2 c$ j
    7 s+ c9 I& k  h+ [. V) w5 H
    + |7 l: O( I: {( K" \
    二、加载数据集. _- m* g) S8 u+ v) U- {' X* T
    使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。. m, X+ p- I$ t

    / D2 U8 A. a& t- `
    / Y% B3 u0 L2 b2 P5 E! Q) W
    (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()2 o2 ?1 C& n  f! d0 e; U# [' b

      r, r% ]& |0 B3 W5 y+ Itrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')6 K1 L7 |; f7 S  r& T9 i6 U9 l
    train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
    . V  F- n9 t. G% c1 w   n; ]; h% Y' m* B: \
    BUFFER_SIZE = 60000
    + q9 _4 ~  }3 P7 }$ OBATCH_SIZE = 256# S! o* T' @3 c

    " O+ g( X" @) X4 V) M# 批量化和打乱数据; t5 b2 O8 x" ^/ `9 K0 d) Y: D  y. H
    train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)4 K* D  H0 `! t. i" [
    三、创建模型
    * S8 `( n6 u) w  d! D主要创建两个模型,一个是生成器,另一个是判别器。. b! E* {$ G7 V% V8 o4 O. Y
    3 K+ i- @4 U0 S8 K8 U

    - I8 e: G, N$ W  ]' y( H6 w3.1 生成器
    # G7 P+ ]& y; H! N/ V& E( [7 S生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。+ U) A7 ~  M( [+ q9 z7 V) V
    ) e, N- U3 ]) S

    0 f% d, j0 J  {1 C& ~然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
    * Z% [9 U* j" S8 H+ o& U
    ! r0 |: I4 ^5 y( r
    3 O, y3 w) _  x7 \
    后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。
    * r3 T) b! D) G0 o* j# Z( l# ?( |4 F5 a* M6 t$ v" [$ z- o" l* b7 L  Y

    / ?8 M! |; T, T- K( l) S3 gdef make_generator_model():
    ; n; j7 t! i1 O" G    model = tf.keras.Sequential()/ z) c/ M& g0 K/ d2 f, Q
        model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    7 L, h3 s) u& r2 e/ W    model.add(layers.BatchNormalization())
    * @, ]& a$ ~" v8 h    model.add(layers.LeakyReLU())
    # y0 x% Q' P1 p0 f6 V& j6 p 7 [# V6 ?4 f& f4 m! L6 Y4 }
        model.add(layers.Reshape((7, 7, 256)))
    , c2 n7 Y3 w: S( Z4 o' _) ~    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制/ h+ h0 ?8 f' x% y+ h. u8 L: ~

    , [; X: U3 k' J" E    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    : O$ r4 i# w' d+ j( a) H# p5 W6 r    assert model.output_shape == (None, 7, 7, 128)
    $ u0 K6 }  j. _/ [" H+ x' T    model.add(layers.BatchNormalization())
    4 \, Y. }; f% V9 I2 K    model.add(layers.LeakyReLU())) ^! ?2 I; j; \+ O# j/ s
    4 D, j) W# K; \: z; B
        model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    ; K! D" }0 V: O    assert model.output_shape == (None, 14, 14, 64)% `2 D! L$ u1 o6 \5 |: B
        model.add(layers.BatchNormalization())" R# G- e/ G  D# u, k. e) E
        model.add(layers.LeakyReLU())4 W3 _4 i1 t  D6 X/ o1 n

    0 m$ v7 @7 }6 z( t1 H  b# h    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    $ v$ e5 u! {2 O0 Z    assert model.output_shape == (None, 28, 28, 1)
    * v0 |) R+ c4 `7 ^ 0 J. G% Y7 g# [* K5 _
        return model: d  a9 I, j3 G* `' s$ p+ X3 E% @
    用tf.keras.utils.plot_model( ),看一下模型结构
    $ ^( {# l$ ~# X4 q$ {2 o* b, A: B4 D: d8 n1 a( a6 d; \( C/ U
    1 {* L9 C' G" e9 O- X
    + r& A( m# Z& x; U5 J4 Z9 P8 L
    & i3 R6 s- E( W+ e
    5 C+ ]9 @( w  Q2 _
    用summary(),看一下模型结构和参数9 w- q4 q+ Y) s3 B7 J

    % v, ?+ I  j. h* _% R) `
    % H% h! u! y* {' Q2 M+ s
    ; V# P1 a0 @2 }# c7 U' ]& y6 d
      G/ w' n! B9 e- m$ Q0 J

    8 {/ P2 T5 |0 s* L1 l+ I( s/ x
    9 z) ^( g9 i. }4 S3 F- n, |& S. d+ _
    使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。4 v- h4 `, u* F3 i) Q0 A+ Q
      [3 ]0 C( A. ]' m
    . [% r& {/ p; b3 r* F( T" L( ]
    generator = make_generator_model()% }# y1 ^7 K8 e( H; }% m  T: t

    * |  D* W, s& T1 T" ynoise = tf.random.normal([1, 100])
    # x* j. d3 C+ i* Q) V+ ugenerated_image = generator(noise, training=False)
    + W4 s7 {; G  L: p8 K
    5 p/ y) q( M2 ~. E, O  Vplt.imshow(generated_image[0, :, :, 0], cmap='gray')
    5 Q4 _3 u" H5 o4 C0 L* ~: f5 V% A& n2 ~4 P" i
    * m/ Q3 k9 L$ m9 z1 ~3 N, [

    ( J/ W& C. d: A6 M

    6 y/ R6 u. X, l3.1 判别器. H& x# Q: D" }" ~' ?* }
    判别器是基于 CNN卷积神经网络 的图片分类器。
    , Z7 D* d. A6 C: U% E8 w6 ~* C$ {/ I( k$ Y
    ) r% w- r) o! [+ B) o
    def make_discriminator_model():
    , x# T0 r; _  k. b# Y7 o    model = tf.keras.Sequential()1 T1 A3 U: c3 ?+ k2 A
        model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',4 F+ Z* E" t  C) {+ g2 l" Y
                                         input_shape=[28, 28, 1]))/ N! {8 T% E4 p* _
        model.add(layers.LeakyReLU()), L" N! @- d9 v& k8 E
        model.add(layers.Dropout(0.3))1 [5 {1 v9 T( M% I' x3 P/ g0 ^
    ' V3 u% s8 U5 W# R: _2 ^
        model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    . O; e1 W! @3 {7 J: l" g( h/ C    model.add(layers.LeakyReLU())
    , q2 U6 f9 ^" t" ?. U7 Q    model.add(layers.Dropout(0.3))
    + g( a7 c  E; Q' H8 R3 x7 G 2 m, Q8 y& o% |4 S* z7 {
        model.add(layers.Flatten())
    , n/ o+ `" b7 O! O8 O    model.add(layers.Dense(1))
    + N' y" S; f" f1 _$ K
    * j) C+ S# P0 B    return model
    * v0 W1 Z, A7 A  C( B用tf.keras.utils.plot_model( ),看一下模型结构+ _  y4 k4 [/ W; I2 b

    , S! X- e& n5 P  i8 ^( \
    5 v8 y. \4 K* D( a
    1 `9 q: I: Y# N. ?+ x! [  R& j
    ' Q! N( d8 g3 M/ s
    0 s& A" G% D) O5 h4 G7 e! [. i( F8 ?
    ; f9 f0 T4 @1 T- N
    用summary(),看一下模型结构和参数
    8 D6 l/ r  p' Y9 g: k) q6 s# ^5 a4 D0 c

    7 e  e( W2 M1 @$ l* M
    7 c7 Y9 G2 N8 [5 B

    + E* W' N6 W$ M8 `% V& s8 |! d) `9 P/ J1 }! Z
    - `- r# j5 R/ r1 @
    四、定义损失函数和优化器) L3 v' y% d! R0 ]8 f
    由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。" Q# Y; C7 ~0 b9 j2 @  d
    0 }) W. y" o; |/ }/ }9 @3 w
    4 T( N- Q( D" T
    首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
    + l% O& Z  f0 U; O& w$ v
    ( h, N& t: P/ P# b; V
    9 H; O# B' \$ ]) V' O  }/ |
    # 该方法返回计算交叉熵损失的辅助函数
    4 I& u, t* P. n8 Dcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)0 u8 e. B' ~& d0 b; X
    4.1 生成器的损失和优化器
    2 Q  s. Y" K2 f0 E' g1)生成器损失& M7 m' ^& }, c: \8 r8 u; C

    : b8 D- I9 a+ o& w- E! w

    , ~" Z4 S7 h& a4 K生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。: V4 T7 _- }$ I# p# Z, L, ^2 o
    6 s- h# ?  n3 C: Q7 F) _7 @5 A1 I# i

    + M( a# U1 U2 L& {+ f这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。
    " [; D: T/ I' }( R
    ; M/ `7 M! ?9 M! _
    2 |& ]$ x+ M8 e
    def generator_loss(fake_output):' v* M' k3 m6 G( E) q+ ~+ d8 o. a
        return cross_entropy(tf.ones_like(fake_output), fake_output)/ G# V, t5 A: U! k* K- h0 W
    2)生成器优化器
    0 d1 _. Y  w2 }8 H- `
    / T8 ~5 @( B' M2 ?

    : i* {8 \; H, y; m: Q+ y: Fgenerator_optimizer = tf.keras.optimizers.Adam(1e-4)
    * {5 Q: k2 N0 b  z6 W4.2 判别器的损失和优化器
    6 D0 \, ^' O, ]- `) B- u( U1)判别器损失1 M$ s' @) r" L* f( c# W& u
    0 ^; o6 `7 Y0 C! E  Y- P$ A/ P

    1 s/ ?1 }& J6 `5 o; B6 b判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
    : _9 K( p; N$ a5 N7 F. y, P
    8 `! E: d4 s, ]+ E
    9 J5 S( H! O$ C% ?% Q. u2 o8 S$ z* M
    def discriminator_loss(real_output, fake_output):! ~9 d( m, |( B6 d4 v' |( s: \
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)' `6 v; F  n3 N9 p& w; J( ^
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    . m- \7 a8 F% C% }4 S+ i- a    total_loss = real_loss + fake_loss% i4 N1 E" C8 @
        return total_loss1 {* ?7 X, ~' x& ^; @2 R2 _. b" ?) _
    2)判别器优化器9 O$ x  C: z% O2 M$ ]; O: G: K! k

    ; }0 f9 \" j! i; U. m2 X8 S" n# E- t; M

    * t0 z7 \+ H9 T8 ?" [9 v! Idiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
    0 n9 V/ O9 D. E* ^& m: q, `  ?# }. j五、训练模型  e: U' M5 e1 S3 e2 g' x, k. E6 V
    5.1 保存检查点& J6 {% \6 Y: S9 M
    保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。5 }" m2 K& j5 s9 r5 d
    - W# k  G+ F! y- \: I+ z, b

    ' _5 o: `, X" y# pcheckpoint_dir = './training_checkpoints'
    / R/ |0 b7 r/ M6 r4 B; Bcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    0 [4 q' Q" K$ Dcheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,4 i3 D6 M5 H5 z6 O/ b  v* Q3 G+ ~
                                     discriminator_optimizer=discriminator_optimizer,8 `; g2 |% P4 Q  C1 q" ?6 i
                                     generator=generator,
    + T8 o9 T9 j, a3 h; r% @5 k8 X3 l6 M                                 discriminator=discriminator)# ^6 h$ ~2 {; `/ m( X6 v& g
    5.2 定义训练过程
    - B7 ]$ k/ {! Q/ zEPOCHS = 50& g* \" j- M5 p8 t+ C
    noise_dim = 100& U  q/ y8 h: E; z# ~
    num_examples_to_generate = 16* P$ Y; W1 I6 ^/ A

    / N2 U4 ]# ~8 M- [9 _+ ?+ z + {9 W% X$ x) n9 r+ w; D
    # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)3 k8 f( w9 K$ N- d  L* q
    seed = tf.random.normal([num_examples_to_generate, noise_dim])- @9 ~8 C$ k' Y+ Y) S% L2 ]
    训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。
    ! q' U8 l( D8 m* e% ]% m8 K- y2 U; `; D$ |# _& _
    6 `5 [7 x3 a$ \/ k; W4 |# w" ^. R% X
    判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
    0 X. }; M2 s2 @& p$ N  n# y, u& b( P8 M; H
    ( K) x6 Z) a; K; ]: X7 |$ B
    两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
    , a9 S' [0 ?( m" V# ~7 I- Q/ ]( S
    $ ^; o. F: P9 h# e! Z) L
      t& S, p! G) y- c' S: ]
    # 注意 `tf.function` 的使用
    8 k9 i( U8 Y9 `' c- {) p# 该注解使函数被“编译”
    5 V  V2 i/ Q2 z, J. F& D@tf.function! O# ~! `) l' W) \, Y" D* w% a$ \
    def train_step(images):
    + a3 `# x# L$ z, K1 G8 z, l3 R    noise = tf.random.normal([BATCH_SIZE, noise_dim])8 O0 f' [: L( e) J( G
    / J0 ?: f% B$ T6 Q7 Y2 G9 v
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:* Q- b9 U1 K, u) l( b
          generated_images = generator(noise, training=True)# d3 N) ^0 [& [; w# }- T

    . ?- l' V* H9 M( d      real_output = discriminator(images, training=True)& @: t0 w# ~' w* x6 ]
          fake_output = discriminator(generated_images, training=True)
    7 j2 z6 A; N* T- Y% ?9 X+ ~0 |
    . f' u# o6 ]8 P      gen_loss = generator_loss(fake_output)- }1 q" Q1 k1 t' p$ Y
          disc_loss = discriminator_loss(real_output, fake_output): [8 ?9 m7 m) A0 J

    - C9 D$ ~" A8 [% b4 E    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables): t: z% l9 a+ \
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    4 _! z7 u. \2 ^; H" d  H" w
    6 Z9 ~- x% r0 B/ o5 n" {    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))! @2 n& ~' U8 y2 }+ ~$ V$ I
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    ; I8 z5 z" O; q: f6 F  [
    ! k& p/ `# S/ }3 E  S; L& rdef train(dataset, epochs):; U, E  X# d: @3 A7 X  G1 s: a
      for epoch in range(epochs):- F! T" L4 H% b; _% ]1 ^: _
        start = time.time()9 v/ U5 p6 H/ u5 I6 M+ R& s
    ' c6 g, B) h( B0 _; G+ Q  a3 n
        for image_batch in dataset:8 g7 E2 H+ w* |+ L
          train_step(image_batch); F9 K  X# v$ R+ {8 l2 h% d

    8 _8 T. L# v; l$ N' r* k    # 继续进行时为 GIF 生成图像
    ; t' O. z; Q* b! C* t, K    display.clear_output(wait=True)
    ; v% x5 q& M6 k& T+ f0 a" @% z    generate_and_save_images(generator,/ i8 o$ F( v% `. @. x2 s  m
                                 epoch + 1,/ N# O" y4 e, \8 k4 x, i) n
                                 seed)
    2 r7 a; f+ d, v% A5 p
    % j% H! t$ A( F6 Q7 s9 i    # 每 15 个 epoch 保存一次模型
    # s1 x8 G0 a6 R( H* m    if (epoch + 1) % 15 == 0:3 e) r% y; r4 L, y
          checkpoint.save(file_prefix = checkpoint_prefix)
    ! P" s3 S5 ^+ V# h
    * I* T- K. c" A) ~; f    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
    ( m  r0 Y5 k' q5 u$ B # d- |# S; ^/ _1 I0 r
      # 最后一个 epoch 结束后生成图片
    $ C+ c- {) l9 H4 c' P" v  display.clear_output(wait=True)* {) ]+ Q# k& V' p
      generate_and_save_images(generator,
    4 |: f. O: _, h7 O4 S/ C                           epochs,
    ( ?5 j0 K* @9 g& }/ \- C! l& s                           seed)( n! m1 P+ R: d  {% l5 Q# b
    + P% `/ [) g$ N( _. ?/ k
    # 生成与保存图片6 t5 ]" W$ ?1 c# I  n* s4 v% C
    def generate_and_save_images(model, epoch, test_input):( o& o1 s6 k% E: @( ~! F
      # 注意 training` 设定为 False/ n7 y& ]# e' K
      # 因此,所有层都在推理模式下运行(batchnorm)。
    1 h9 _. s% S( W1 X1 C  predictions = model(test_input, training=False)
    , C! i) s$ c( ~' O& @ & a  y* S/ n" \& D) C
      fig = plt.figure(figsize=(4,4))0 F; e  k8 D; i% f

    ) U% g+ v5 y/ ?. ~" z0 K  for i in range(predictions.shape[0]):
    + I7 x1 t4 f6 H3 b+ |      plt.subplot(4, 4, i+1)2 ]. c- y" k0 D
          plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray'); W5 G* \+ h: q; d' e) g% _
          plt.axis('off')
    3 n# Z! y- O* H) ` / K! |5 m% Y/ M) Z! ]! e( q. X# `
      plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    # s. N$ z: q& x4 ?5 I7 q3 m  plt.show()9 p, m; v( B) e0 d$ }1 H
    5.3 训练模型0 l4 S9 `- e, x% B( V  V: P7 j- P' B) C
    调用上面定义的train()函数,来同时训练生成器和判别器。9 f4 q, K) o3 {0 y
      `: C. _2 ]2 o7 V: P  j

    % X" y' ~8 Z5 K# f2 O  r7 `注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
    . e. P$ {7 G. y; L1 q9 m
    - X- ~1 s/ j0 I  |
    / _+ W: f9 Q; b6 u3 B& |
    %%time" g; x1 ~& B5 M6 t" R1 }8 E0 u
    train(train_dataset, EPOCHS)" C5 Q3 C& M' @5 u& N
    在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。
    0 ]7 A* ?) p9 N9 h( @; |. X/ _7 J7 _" U

    $ w4 _9 q' H, u( a1 Z7 v训练了15轮的效果:, x8 ~/ L  c( e+ t' U9 k

    6 T1 h3 ~5 f9 W4 ]( P: O$ s0 U' {
    ( H/ X( @5 J) J& h8 @# ]1 A. J
    ) d6 g# H3 ]: \% F) j
    : Q! u. Y( M) {- Y

    , b3 q& i5 ]9 P6 k+ @. [

    ! W8 {. ^- i& u4 l7 h  k$ Y训练了30轮的效果:
    % S$ r9 n$ z8 B. E3 J+ ?2 o' z# M/ @/ m. m7 Y) O- @2 D1 G

    , |9 ~( e# ^* c2 f- w
    ' A/ Q0 B) Z; v

    / ]& H( @2 c6 e) f  X/ V0 l" n3 e) U8 b

    4 _/ Q9 b" h* C+ ]& _训练过程:* V! _2 l, v. f/ G! p% o" ?

    3 q% T/ K! u1 f8 L4 W

    4 D! Y. Q  H) r' J% C
    - T" G, w$ r6 J9 v) M* L
    # v3 T$ m+ t& ]4 _; z% c
    4 y1 l& m6 U+ W' |" P/ W5 m/ {: |

    % H* J" |3 {8 J( p( W% G+ _/ N0 H恢复最新的检查点
    - W) d  [' u+ s; Q$ k9 I# g+ b
    : Q. |% C& I; E& m1 ]) X0 M8 L
    . S4 b2 b0 a7 w5 c
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
    & F% e" I: I/ D9 T六、评估模型
    ! V: K8 @' f# M  ~7 X" C2 u( c这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。9 z/ L7 X5 r- r( E/ N: K

    6 \5 ^& A! i% e% U( ^4 a

    ! a3 S4 l$ S: B# R# 使用 epoch 数生成单张图片1 m# d1 R9 k7 M: z$ r" m
    def display_image(epoch_no):" ]; N/ T( D3 T) h% ?! ^+ c+ `
      return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
    7 j- I  G8 W8 U* Y3 b4 E& z8 k 3 n5 u4 G* Y2 ]+ {( x4 [
    display_image(EPOCHS)2 h. ^' g7 O- l+ b6 o  f$ ]5 t
    anim_file = 'dcgan.gif'
    1 w7 B  C6 I: Z8 ?! e+ d: m1 I# m
    5 a/ H$ s! J$ Z7 `3 {& ~with imageio.get_writer(anim_file, mode='I') as writer:
    % Q* T& T1 V: w; ?: A7 O' c& q  filenames = glob.glob('image*.png')
    + P  k2 |( D/ a% _) _  filenames = sorted(filenames)
    6 l5 C3 g% K" E3 d8 `8 V# o  last = -1" \7 F! }5 s  E. ~" S5 q+ d
      for i,filename in enumerate(filenames):- q4 `, |! B' h) |1 M6 c
        frame = 2*(i**0.5)( H* }. D% b8 P3 o1 x  n3 I
        if round(frame) > round(last):/ I" v0 r; @! a  P( O  i7 P# F
          last = frame* g7 g, H% M" ]" H  V& D( Z
        else:, m! @% z3 ?0 S! I6 e1 d% d+ _
          continue
    * ~4 k' h; E9 @% H& n" o9 t    image = imageio.imread(filename)
    ; e, p# `) M8 W9 z    writer.append_data(image): G5 l& \: r% X+ `( G- l4 O
      image = imageio.imread(filename)& O8 P( T6 e, g4 |  Q
      writer.append_data(image)' M) _% l9 k8 a6 c* l1 }+ c4 i
    " _* s5 p  }4 ^2 T* {! x2 ?' T
    import IPython, y2 F' z: q& w2 Z
    if IPython.version_info > (6,2,0,''):
    : H  i5 Y# x& e  display.Image(filename=anim_file)9 Z( E  j1 ^1 q+ U

    ( t" g* k% B" V) ^" |- q8 y

    # W- S1 A% E! i0 a, n  h
      N! s2 ]: r0 J- C1 M- U
    # S- U5 U: G' N7 P
    完整代码:0 K/ C( q4 |* X7 e5 p% H) N6 K) v
    3 r: N! R3 f+ q" m- E
    # \5 y. i8 |  r( Q
    import tensorflow as tf8 m) U& v& z5 W: _( l$ Y
    import glob2 r- S0 ^, N7 A( D, N
    import imageio/ y4 ~& e, j4 j. Q9 U+ W& q
    import matplotlib.pyplot as plt. T) D- f9 T2 {  K' K
    import numpy as np+ Z5 I5 N4 z% q
    import os  j% L0 O0 [! w; U3 ]) V" h: Z
    import PIL
    4 H& T$ W0 F3 A7 j* y5 ]7 |& lfrom tensorflow.keras import layers
    # _+ O" k$ r6 C9 h- E. jimport time' P1 f9 n" Q* X) m

    9 ^3 Q9 F* Q) r. h7 X" Wfrom IPython import display' B" [3 B2 `. [$ v$ |/ P
    ; [2 I  q0 m8 H0 N* D% ]6 U& W
    (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()/ C! D0 d$ `' f" E6 w9 F
    * a8 a  ?; m5 }0 c# }. m. _
    train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    # O, y; K7 [, s" Ltrain_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
      G) X6 O7 V3 P: W9 A0 t& o
    : K: P+ O7 \5 f, \/ _BUFFER_SIZE = 60000
    . e0 ^! t/ r( d) }BATCH_SIZE = 256
    , E: z/ E8 w3 t' B/ t( X" c0 G + ~' `) H) r; ^- J4 }( q& w( j
    # 批量化和打乱数据
    1 ^! V. b/ D% T- Mtrain_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)) c( D: U0 r$ f, }1 @

    9 L8 K, f# t: y% _! J6 t5 f, H# 创建模型--生成器
    / ^, h6 C3 G& e& s2 W: B% k; Ldef make_generator_model():! e5 i# u/ I2 [3 y+ K- j
        model = tf.keras.Sequential(). S, U$ _# T, b; l
        model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))/ i2 j6 n# C" f- [2 J
        model.add(layers.BatchNormalization())
    2 [5 Y& q; p2 M$ r    model.add(layers.LeakyReLU())
    3 f+ B. Q4 f+ b, f) W  ^
    * z7 E1 O- Q1 n, T    model.add(layers.Reshape((7, 7, 256)))
    ' [" n$ X% S7 Y    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
    $ v5 d) B) [& N# u- `: q* I2 f' | : _' b, e, [+ \# O; z, `
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))9 x9 f0 C1 r. M1 X3 E
        assert model.output_shape == (None, 7, 7, 128)$ O6 m& l+ P' o% R3 w
        model.add(layers.BatchNormalization()). t0 z; D& |& ?( V. a0 M
        model.add(layers.LeakyReLU())2 K# m# U& n* {0 J5 n, K
    ( g$ u4 p& U# g' p% j* b
        model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))" s: s3 ]) ]1 j1 U* R! U, O( X7 E
        assert model.output_shape == (None, 14, 14, 64)/ L. ~+ i# o- k' Z$ ^+ M- ]2 i
        model.add(layers.BatchNormalization())
    0 D9 s* P9 h: f6 B% _    model.add(layers.LeakyReLU())
    * |8 F! {1 i& F) Z0 H. L
    / w( e) X- W: `- _$ Z0 N4 b    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))3 }: f$ k% u3 {9 x# s& T3 @
        assert model.output_shape == (None, 28, 28, 1)
    , {# x& G: `) D! c. Y ) V/ [/ m7 {, b5 `6 K
        return model- I0 I; D! w2 {0 E( e( k1 _4 l
    3 e& t& B/ J9 ~  e1 r* N
    # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。7 H  B9 Y7 `- C$ H
    generator = make_generator_model()
      v" x4 E" _5 X! {! P2 ]
    * S8 {4 P0 R. T# B9 _noise = tf.random.normal([1, 100])1 b0 N: q8 m" p- B, |) `0 Z
    generated_image = generator(noise, training=False)- Z. l# V+ x2 ^) ^; G

    " A+ g8 h* X6 I$ Rplt.imshow(generated_image[0, :, :, 0], cmap='gray')
    . D! W0 u! z8 H7 z4 F9 z# y- u  Mtf.keras.utils.plot_model(generator)  e2 h! n1 s5 F9 I( J8 ]8 x- ]
    / M" v6 X( W- [: [3 X0 t! m/ z
    # 判别器
    / O$ h% {% t+ o$ e: O* I5 c7 ndef make_discriminator_model():$ s! K* g7 h6 C1 t8 _& H1 s) L# d+ I
        model = tf.keras.Sequential()
    # N+ p/ w& P" [0 b/ `) y( N* y. P    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
    & w- T( E# b( J6 D                                     input_shape=[28, 28, 1]))
    " F$ }5 }4 U. g& s% U! D    model.add(layers.LeakyReLU())& Y+ h8 d* M+ h! h5 w+ n: {
        model.add(layers.Dropout(0.3))& Z# u  f* Y8 q) }7 m
    ( [8 C7 s$ J. q' B
        model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    - m. N! m* J8 G" d    model.add(layers.LeakyReLU())
    2 A( z+ F$ B$ Y    model.add(layers.Dropout(0.3))( z3 R3 @( ]* K# I& J) R: j* H0 y
    ' v8 t) Q  m7 u
        model.add(layers.Flatten()): l4 C$ ^, D* V6 N; D6 F4 \8 f
        model.add(layers.Dense(1))
    7 j. V  d# {+ B: W& ]7 w
    . u5 a5 e, J! q4 E  `4 o9 P& O    return model6 S% L5 _  j% g5 X9 ~
    3 r+ b" n' q: G) e8 I/ n3 w' C; C& ?
    # 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
    5 K  b* A5 q% F( ~( P1 Jdiscriminator = make_discriminator_model()
    5 S+ t2 _5 @; r& ]9 u3 Sdecision = discriminator(generated_image)# j$ z, Y" M" p- G
    print (decision); f& L5 l# l" g& k

    ! m% q0 V1 A3 k- U# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。* R/ u+ a8 n0 h4 [! \9 M! ]9 k% z4 ^  K
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)* y. S0 g' b. m8 Y; h
    * {: W# N$ ]# I# t: Q$ L
    # 生成器的损失和优化器4 z8 c! W1 i" Y8 r+ A
    def generator_loss(fake_output):0 d, Q  m$ F' p- K2 F
        return cross_entropy(tf.ones_like(fake_output), fake_output)
    ( q6 Y/ c, d- o7 mgenerator_optimizer = tf.keras.optimizers.Adam(1e-4)2 |, F, D5 i9 Y/ ?. Z
    9 \% ?! W1 K4 {% G
    # 判别器的损失和优化器
    2 `; X4 Q: h) n- rdef discriminator_loss(real_output, fake_output):' F5 q  I$ ?) @7 i9 g9 d, J/ P8 w
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)9 C1 M! Z# Q  N
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)- B2 j+ s; X2 S' k  R
        total_loss = real_loss + fake_loss
    ( U* q$ c9 ]* N$ p% |! \* q. h    return total_loss0 Q; Y# @4 B7 r$ w9 @- d7 o; c
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
    . N* G  l' x" K! X7 y# y 1 U1 N, k7 H$ P: i3 p: y
    # 保存检查点
    ! _% u" C9 ^" ncheckpoint_dir = './training_checkpoints'
    9 T# b9 V. p+ B: N0 U, ^( B1 V; L) vcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")2 P6 f, p( g9 U% r
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
    / o& b3 E- g8 u& s% d, d& @                                 discriminator_optimizer=discriminator_optimizer,6 q% g% v; j/ |- e& ]
                                     generator=generator,
    + i/ t) e) U' T7 j/ P$ T1 K                                 discriminator=discriminator)  ~5 j; j, |3 u% k
    ; J) V$ \! \5 `8 Y" @% N5 N; L/ J
    # 定义训练过程) {5 n8 Y/ ~. g
    EPOCHS = 50! t2 W" Z8 Y) b" ?" M3 p- T' Q, `
    noise_dim = 100: ?( @0 S# E1 S+ t
    num_examples_to_generate = 163 q( ^  E6 f/ V1 g
    0 T7 y) H% S, a- I9 _' v; _
    # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)3 ~4 @" C- n7 g: v; U2 j" I8 }: t
    seed = tf.random.normal([num_examples_to_generate, noise_dim])
    4 ^' i; E" P, @5 l1 x2 o# }$ c. b( a
    . V$ c) N5 W4 t* f0 _0 k# 注意 `tf.function` 的使用9 n; S& c$ @( |0 ]$ F- z& j3 K3 |: @7 Q
    # 该注解使函数被“编译”
    8 ]/ G9 z7 j8 v! u* Q@tf.function
    + P. S- R& k: P. Rdef train_step(images):% S! N5 R+ M7 H% y" C2 ~' v
        noise = tf.random.normal([BATCH_SIZE, noise_dim])
    / S) o8 O( T2 [: L8 u+ I$ j 0 m& ~% `8 e; O! e9 S
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:7 H% }; q  ^% f9 _  k6 q
          generated_images = generator(noise, training=True)
    ; g# [. s& i* O! Z' u" b # G( f' S' z3 }/ f0 ]
          real_output = discriminator(images, training=True)2 j* z/ d+ m  x. f* K+ @
          fake_output = discriminator(generated_images, training=True)
    . P: ]4 L/ V5 J$ c1 B' _6 k, B: O
    - ^! k9 {1 E( h' u7 E6 P4 [& J0 l      gen_loss = generator_loss(fake_output)+ o+ ~2 l/ w1 ]. i+ K. e
          disc_loss = discriminator_loss(real_output, fake_output)/ ~; d" J  p" v' V% D; s) T; Z
    $ z, I! @  s4 O& A+ m! c
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    9 a- g7 V  a2 D    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)9 d% E2 ]! q" `

    5 U% c. ^1 B  K1 O, d    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))! F6 K6 S  D# V. b2 ]- f
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    ( _& d# T* h6 c . G: |' f! I2 E5 z
    def train(dataset, epochs):
    ( E: j6 T& m" W. c  for epoch in range(epochs):
    " x9 H: N+ Z& p6 D$ j    start = time.time()( t& P$ H# M. t, i

    " V, u  G7 k+ Y4 H! f    for image_batch in dataset:
    * b( z& u4 m0 a# u5 Q      train_step(image_batch)
    . \3 y1 s- k/ M( `1 S6 Y" X* n
    ' m! Y( z; s3 s    # 继续进行时为 GIF 生成图像
    * D- G7 |  |" p0 ?" T4 x' M    display.clear_output(wait=True)
    0 [+ z4 q9 M' i$ f" Z+ I    generate_and_save_images(generator,
    7 P/ a* h: V( X, G  i                             epoch + 1,
    1 n, e, M' b7 M9 p( ?3 l& v                             seed)$ @7 p$ Q2 y; \+ n
    + P6 |5 T" E( a! Y' n0 }. x) g- v
        # 每 15 个 epoch 保存一次模型3 v! E# L# t$ Q! _# s
        if (epoch + 1) % 15 == 0:
    5 Y# T$ s7 F2 D# o: E      checkpoint.save(file_prefix = checkpoint_prefix)
    + g- ^+ `' x7 B& W5 a, K& A3 t
    % c4 L+ {! v- E2 c5 F: z    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
    # S9 R* D, A3 G2 @& Y
    6 j* o8 w9 t5 R$ B  # 最后一个 epoch 结束后生成图片0 n6 U' C5 b. P& E2 `2 r& v
      display.clear_output(wait=True), x5 j. o9 v! `0 m0 Q; Q
      generate_and_save_images(generator,
    - Q, n$ b& @4 P; k7 |: m- v                           epochs,' g7 ~1 ~8 E( A0 `
                               seed)
    ! N! s/ [6 M' G$ u
    9 N) l3 F+ a: C3 `6 z) T- k# 生成与保存图片
    7 Y0 O' l7 Z3 ?( xdef generate_and_save_images(model, epoch, test_input):
    9 L% v/ h6 m* Y! m  # 注意 training` 设定为 False' C6 b, {* c" @5 B5 M; B- b3 Z; ?) A
      # 因此,所有层都在推理模式下运行(batchnorm)。
    ) Z- L. t) t2 N( \! w4 w  predictions = model(test_input, training=False)
    % E5 G2 R" e0 ]0 b
    ! q4 E, R- R9 K. c6 B6 j, \  fig = plt.figure(figsize=(4,4))
    # g# B. p, I( Z- P0 {4 W" j
      X" |* [/ q; ]! L, F8 y% W) V  for i in range(predictions.shape[0]):
    ( J! ?2 L+ |: ]( d8 L. L7 z, g      plt.subplot(4, 4, i+1)6 U( a& @0 d# n2 ~
          plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    # Q" u8 b0 c1 N4 X" |/ t  C      plt.axis('off')
    - y8 G* ?8 M3 p 5 I! f/ T7 ?5 G' Q- m  D
      plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    * i7 V$ }9 g- q/ i9 }0 V$ m  plt.show()6 h: H+ G: O% t

    0 _5 b8 y7 g" r2 P3 x9 y  A1 G# 训练模型% W4 V8 b# j& U4 P( X8 p, B+ y8 y
    train(train_dataset, EPOCHS)/ N- O0 o: \; e+ n0 D: b1 a, m* m

    & u* e: w9 G: r8 U: ?2 C! q8 e# 恢复最新的检查点6 Q* ?: R9 ]5 ~  F1 M) G1 I! W
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)); q2 r8 L# `/ X3 F3 n- s/ M. ~" k

    6 `2 G% ^$ N, V2 i# 评估模型' w" j& w2 l5 ~; D! r2 c
    # 使用 epoch 数生成单张图片
    8 F* V! D) |7 n: ~2 Gdef display_image(epoch_no):6 f" u! A1 s$ M8 [) `* `& G
      return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
    ) z- B: l( Z) @7 i5 `+ T0 G/ K
    0 n- }$ j) D  R9 A5 h4 tdisplay_image(EPOCHS)
    ; [& l" m! P/ V  h# ]- l0 }
    # I  Z; w* b7 Z* O1 l5 Z" P% uanim_file = 'dcgan.gif'
    & m$ H8 c3 [8 f9 p% B 1 m0 N6 }: r, _1 j: s5 k# v
    with imageio.get_writer(anim_file, mode='I') as writer:
    ' d7 w6 c, `  ^4 I- K  filenames = glob.glob('image*.png')
    4 r" z1 X6 l2 _  t$ T" Q4 y  filenames = sorted(filenames)
    8 h! M' B* P* {+ B8 p  last = -1" L7 F' ~2 k/ o. ?& H0 n; _7 X1 }
      for i,filename in enumerate(filenames):
    8 M3 h  z# H  ?  t$ m, c+ q    frame = 2*(i**0.5)
    ! s  A3 n% `3 |9 j& s2 p    if round(frame) > round(last):9 P- p8 w+ o3 g# L( s& K/ d4 e
          last = frame6 `9 [, k, b9 _& ~' T3 b1 r
        else:
    7 p2 g1 x; @: [; i. Z- F      continue6 }3 C  p* U+ z# @
        image = imageio.imread(filename)* r# ~) w# N" d* u
        writer.append_data(image)
    7 n3 K3 V* Z. B' F  |& p) x# M  image = imageio.imread(filename)
    : M! l/ d1 Z( f' f1 _  writer.append_data(image)) |5 T, w% i* |3 `, K" p
    8 Z. X! R. N! q
    import IPython
    ( t' Z" U$ r  P; \if IPython.version_info > (6,2,0,''):
    6 M5 O9 U% Q! w; O7 k  display.Image(filename=anim_file). h% Q  [) h! ~, R+ s3 t
    参考:https://www.tensorflow.org/tutorials/generative/dcgan# u7 ~; k1 g; ]; n, ^
    ————————————————
    6 x# J2 q8 F% J" F& r' e/ h版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。* d- G$ a3 Q$ s9 K2 T7 ?- {. l. W+ M
    原文链接:https://blog.csdn.net/qq_41204464/article/details/1182791119 F9 f% ~: |/ J1 W9 |

    & H6 g; Q2 @. X  h. ]+ q+ k4 Z8 E/ q3 q9 r+ f4 S
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2026-4-14 12:37 , Processed in 0.547256 second(s), 51 queries .

    回顶部