QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 5220|回复: 0
打印 上一主题 下一主题

深度卷积生成对抗网络DCGAN——生成手写数字图片

[复制链接]
字体大小: 正常 放大
杨利霞        

5273

主题

82

听众

17万

积分

  • TA的每日心情
    开心
    2021-8-11 17:59
  • 签到天数: 17 天

    [LV.4]偶尔看看III

    网络挑战赛参赛者

    网络挑战赛参赛者

    自我介绍
    本人女,毕业于内蒙古科技大学,担任文职专业,毕业专业英语。

    群组2018美赛大象算法课程

    群组2018美赛护航培训课程

    群组2019年 数学中国站长建

    群组2019年数据分析师课程

    群组2018年大象老师国赛优

    跳转到指定楼层
    1#
    发表于 2021-6-28 11:54 |只看该作者 |倒序浏览
    |招呼Ta 关注Ta

    ' v. R' d' O/ Y) h深度卷积生成对抗网络DCGAN——生成手写数字图片
    2 f, D* ]+ P2 I前言
    9 ?7 i+ ^; N1 S2 Z7 e7 f本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。
    , j* N* F. A/ J$ r; W9 y, }: u5 f- Z- E+ \$ |: I8 M; ~8 b

    " `/ S2 e' s: _( r! s. ?  l  D0 E 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:
    4 }4 m+ }& z% B' w% z2 L" L( N1 T% b6 ^! h$ d$ ]

      i# s$ \$ |& k) b# 用于生成 GIF 图片$ W9 B, Q& Q7 I% y, w
    pip install -q imageio
    + u3 Z' ^& ~/ r' I6 r目录: B, ?7 V2 j* Y- g
    * n% [% h# P+ y+ @+ ?" y

    . R% w. J6 S7 R9 H1 r: x前言
    / o; X0 Z. d3 ?# X7 G2 }1 b* ~
    . [8 m, F9 o5 h/ s2 Y

    . j4 }- Q6 A9 x  v. Y% d一、什么是生成对抗网络?
      A1 {5 ?$ e* B6 k" E
    : ?8 {5 m" c! m" ~

    6 ]; r8 ]" R# s/ e; W5 X: }二、加载数据集
    8 ]: ]% ?, d1 }! m' ?$ D! P+ M* }; T% L" r- g
    6 {$ [% C$ [, V3 x" Z& \: L8 c3 v
    三、创建模型
    7 b" a) `6 J6 n4 r% ^3 [$ M4 B" _  A
    1 m- k  N3 y: r
    3.1 生成器! Z+ t4 |6 D* a  b# E. X
    . \9 j- I  ^& e# W/ w0 k
    ! |: v- @. U! ^
    3.1 判别器  H8 x' i/ y2 g( C: J
    ' n9 ^7 N' ^, M! \' V4 x
    & t! f; D- |* ]: ~# t3 x* O
    四、定义损失函数和优化器% \- e+ G$ |% ?6 I( |$ d
    ) F; a4 w+ p3 J5 ]
    . u8 d- }1 l; N% M+ o* V
    4.1 生成器的损失和优化器
    . R% P( z( I) s8 I8 a+ I  U1 }4 b; H6 l9 S9 \
    " ^6 n( p4 R$ ^* o! U
    4.2 判别器的损失和优化器/ e# N7 V7 P" r: @9 _3 f" {8 D9 `5 x
    8 W' P  |' r5 V2 c0 Q' x' U5 b1 o

    5 ]( L# s* P# I- ~9 O五、训练模型
    & g+ I  _2 `) l  R# N" |: n# @2 R1 y& k7 t

    ' }% O) _+ U& x5 s5.1 保存检查点6 A! r8 b* q& N& u% r

    ( a$ Z( Q/ |2 ?7 v/ L

    " p: k& C6 v1 b0 A- R8 N0 Q5.2 定义训练过程. E% `6 z' e( `( ^( {2 I% t6 \

    ; e9 a# j6 N+ b; I# t9 _

    ; p. k4 @7 r4 ~( m5.3 训练模型& V- b5 Y# A0 o+ C7 g9 W! T
    ) Z% k1 i: ~- Q: Y& C9 l, |

    2 C- ?6 U0 V2 |六、评估模型5 x4 t0 @; L# |7 y; o
    7 ?! u1 u2 B+ L+ f6 C' ^
    / Z( o. _4 V3 t- Z* [' O
    一、什么是生成对抗网络?# N! a: C9 M* q8 K. b, {& \$ V
    生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。
    5 R4 u! [7 |/ W
      ?* R/ \* t5 D0 {4 w, C

    ! Q6 g9 I: |+ s$ r$ \! ~, I生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。1 C) p4 k# _( e- Z. d- s) t
    5 z; e( @0 @8 j2 ?3 q
    1 F* k0 U1 T" D% V  T
    判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。
    6 C, l$ G3 b) ?" F; F5 q& T
    9 E5 W1 B) @; D# w
    $ j1 }' e7 t% S8 U! U
    训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。
    & e# l; ]& f  W: @7 ?$ [1 z6 u2 w/ M" D; G  y. j* C

    0 H- n  Y6 `6 e- F/ r! c! |; r当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。
    3 Z( a: C. ^8 o% }) V! h. z$ A: V1 V. R3 a) r2 j' a: {9 T, l

    5 J" Z: ]0 A5 ]  U* b7 U' @本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。
    ; r# G. \, F4 y2 M% {# e1 E# a( M; U9 c; l
    ( I  w! W5 J5 }/ P" U
    二、加载数据集. p0 ?% ?7 q2 R  a
    使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。; T+ V  I+ z* q5 o( ]; n

    ) e7 z/ P9 ]& I# P4 O; E/ G2 d  V

    5 V- u9 m# J6 |& V2 Q1 S(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()2 V- s/ i7 `* H) ]0 x$ X& Y

    - ^6 U4 J" {* o: _( v. @$ Otrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
    3 C6 H6 q7 W8 c/ i; w* v2 s) @train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内0 S9 ~3 A$ I+ h0 |+ R

    & ]/ e+ k$ |0 K4 b& I% y0 @* ]& N4 YBUFFER_SIZE = 60000
    9 {0 U: o* ?2 R0 F) ZBATCH_SIZE = 256/ b  J+ u9 t$ V" E2 l! t  q

    % k; R9 r* Q4 F# 批量化和打乱数据: o# Q& K* e6 ]* }1 V9 A
    train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)  d6 z/ q% z% S( I+ h9 U& t' Z
    三、创建模型+ E* h5 `# L- O6 |+ i9 p
    主要创建两个模型,一个是生成器,另一个是判别器。# r" h; z( u6 V1 r# p

    8 d( [# t/ l  d4 e* d: r% l  P: `

    % _& ^! x) E: o- T3.1 生成器+ s1 L4 E. _! q$ f# S
    生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。
    5 a2 j; ~9 K3 i* g4 H& K
    ! o# e4 P" X3 h# N% L. _

    / U& n% w0 b3 b& W! K7 m4 L" M然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。
    6 K/ _0 `7 N; _) Q1 w% T8 t( ~; m: ^" r) \: P, J# k

    3 k; n  s/ ]1 K! u后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。. l$ d; n  `. X

    0 M3 n8 b: y( u' G) @5 a; y
    ! T: {* \" F0 A+ B" z
    def make_generator_model():
    6 K+ d, M/ G4 `/ N7 X8 r  E    model = tf.keras.Sequential()
    4 m* B9 @; b1 Z1 T8 a    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))), {) z  }  V* E  n: Y
        model.add(layers.BatchNormalization())6 \. G( H/ q( |, @
        model.add(layers.LeakyReLU()). @" C& V3 s2 Q5 Q; H

    : d$ q! w8 I. Q$ j    model.add(layers.Reshape((7, 7, 256)))+ ?6 |8 y; P: w/ w. D. p3 Y
        assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
    . ]3 }( N- \* w% M9 p + F: a- l" A( l  `: m: X8 Q" q
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))# B" [% {2 a3 E  s* `( q7 K; w$ J
        assert model.output_shape == (None, 7, 7, 128)8 b( ]& |& O8 h6 I! |
        model.add(layers.BatchNormalization())8 U: Q" [6 a5 }
        model.add(layers.LeakyReLU())
    . Y; x" F( @6 w" _$ [, p ( E! ^' L7 Q5 H* s5 t4 D
        model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))0 s+ R7 I( w* R% c) O; U
        assert model.output_shape == (None, 14, 14, 64)
    / V) J9 [+ h. |! s    model.add(layers.BatchNormalization())
    0 _. m  U/ c6 L    model.add(layers.LeakyReLU())4 M# y( o* X1 {" ^3 F7 O* R, M
    - t- Y7 D8 z* _* X" Y
        model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))2 X1 j( n- R: g6 U; v1 `" {
        assert model.output_shape == (None, 28, 28, 1)* j. X- E/ W3 K/ g2 C# w/ H1 x

    ) o: E% d( w) M    return model3 ?4 M4 o5 k* C
    用tf.keras.utils.plot_model( ),看一下模型结构
    " X: k* [; Y) B) ?$ k& C, ]8 b# M
    " x# k4 z" I; J, x: [) ~9 C

    3 E" y# p3 O8 c3 @0 A6 {+ g
    1 J" K. O: I( p9 I9 |
    : t( A4 A: n$ Q% v; N: f
    3 D1 y& S6 [1 C- s( Y! w7 u9 Z  B
    用summary(),看一下模型结构和参数
    4 `( P! q: n4 g+ l, ?' J0 M
    * o* }2 A+ b  R  t
    / n8 A# o  @$ s% z; m4 D2 }

    ; F3 b1 g8 h1 c: h; x, Q
    # a/ c+ V4 @7 z0 n

    7 P9 P+ A8 g: }

    : ^- i& g! w0 F, \: I. m使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。/ E6 c7 u/ y" K* _5 y6 ^" i

    1 v- k9 y' x- M4 ?
    ! W+ k9 J  {/ k7 x
    generator = make_generator_model()
    7 d) w; ]/ c) B; Z8 _6 g + Q% O% \2 X. e2 a& S  v) p  @
    noise = tf.random.normal([1, 100])( P# A: V# y/ c' ^6 H
    generated_image = generator(noise, training=False). }+ Y) G( R; O/ ?# L8 |- I

    1 U3 D# ^1 ?0 P$ M) m2 Kplt.imshow(generated_image[0, :, :, 0], cmap='gray')
    ( }7 i( ^+ b# |* k$ b! T, Z8 @
    - J; a/ u3 B1 j! `; Q" k4 Z/ e+ @
    % T* H' Q5 O, _' N8 A; G

    3 [# u: @2 @1 x

    ( ^6 y: I) Y& x3.1 判别器- W& R% ]" ]7 t5 |! h. U  X
    判别器是基于 CNN卷积神经网络 的图片分类器。
    , j- v$ w* Y- N' K6 k# a7 v7 Z  a! h& v6 p4 q
    " ~* t2 W' Y  W! \( |$ @2 A
    def make_discriminator_model():6 r& Z: E! N9 B  r: Z# T
        model = tf.keras.Sequential()
    2 P- w1 K9 B# R- b+ d. s+ H! H1 `    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
    5 s" ^0 m' k! g- m  ~3 E# J                                     input_shape=[28, 28, 1]))0 I' z7 F7 {" L' V! Y
        model.add(layers.LeakyReLU())
    / \" t+ c0 H! G; J; i2 O    model.add(layers.Dropout(0.3))
    0 C- Y# o( h6 Y# n; T& B
    $ R5 N0 m1 S) {: D6 w7 B    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    - B1 H0 p. I  T0 S5 c2 G    model.add(layers.LeakyReLU())
    ' S1 l8 C/ a6 p4 [1 e: @5 A    model.add(layers.Dropout(0.3))) h: c/ F, V" h) Y+ |

    , C+ T: Z  {0 c+ ?% B    model.add(layers.Flatten()); s' {% K& V9 i7 q* V4 i1 ~# m. @
        model.add(layers.Dense(1))3 w& i" F4 h' n- ?; m9 v: T

    - z0 {5 P8 M. V    return model/ I8 ^: D, P3 r2 c) s
    用tf.keras.utils.plot_model( ),看一下模型结构
    4 p8 [6 c+ R# q& j: Y
    1 ]( T/ N. Y- C

    : S/ R9 ~- z" K4 l0 ~7 A8 B6 j" D# F
    4 n& V. u" b  K7 F: F% P/ G
    3 ^( g: U3 g5 b8 U: d7 D3 a
    , j" m/ a+ F5 w  l& F

    1 e( Y7 u" z/ N- B# E用summary(),看一下模型结构和参数6 j# _" p5 _( Q  g* R3 V
      @( b3 {5 g- X) i: R

    7 }7 s* i9 v! a( D$ d: Z3 x, {0 z, s1 T
    ; n; W8 n3 \; @! W5 e& J0 v( e

    & M  J! q+ `& [& D2 z4 n
    1 l7 u5 ^$ N  m  K1 q
    四、定义损失函数和优化器$ n7 V5 n1 A" i6 R( ^( _. d
    由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。2 I1 Y1 d; e0 L6 L# ~/ U5 a% z: \2 k

    # j6 T6 K9 Y9 d# R- k
    : t- k6 h# a- b7 O3 M
    首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
    * w$ b3 i0 {! e5 h$ g* y% x2 }- Q8 q& V2 n1 h( `/ ~
    . P  K/ B5 e" j0 Q: j
    # 该方法返回计算交叉熵损失的辅助函数
    # w' I: ]( i) k! q' bcross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)( R; G' j8 d, Q! w0 S( K
    4.1 生成器的损失和优化器4 q3 n4 L9 U6 v/ i3 W2 i' D/ x
    1)生成器损失
    8 D% D) T7 i/ ]3 i3 p6 s( p0 q  S
    % J) k2 l2 h6 D
    生成器损失,是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。
    3 Y2 B7 r' E. h. F# c
    7 ~8 e  D6 q. z! C3 \8 \
    $ ^! f; {6 I) P' k( `9 d
    这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。9 r9 v+ H+ {! ~& D1 [" a  Z2 F

    & r. e0 f. i1 T  {1 w

    , [! m" Y$ v; v  S* I! z+ Fdef generator_loss(fake_output):
    9 T, Q6 G0 p' k7 _0 j! O$ Q- L    return cross_entropy(tf.ones_like(fake_output), fake_output)
    3 j; i* U. x9 a0 l1 c; U- A1 i2)生成器优化器( P, W# q) ^3 e1 s0 }8 h7 [# r3 U( b

    4 Q6 [( H$ w& E, y; [$ i
    $ @* |7 B9 Q. Y
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)0 ?8 E+ V- N) E; j$ S" g" i8 K
    4.2 判别器的损失和优化器$ k' p. `3 t. C, i  r( h
    1)判别器损失
    # p4 t9 }0 Q1 g. e0 f/ H9 Q  p) H6 i: _3 U; K7 l
    , R! k! E4 @* p/ Y1 C$ i
    判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。
    , b2 ?# B& Z# E: c+ Q& t. s( U/ j/ {- C; r
    3 A# V' c; Q9 E# C  z% N- r
    def discriminator_loss(real_output, fake_output):
    $ O1 c4 z8 N: @8 u5 T2 J    real_loss = cross_entropy(tf.ones_like(real_output), real_output). Z$ @0 I) s  o8 k8 B
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)( J, h1 Y$ Q8 R) ^3 I3 G4 k) u
        total_loss = real_loss + fake_loss  f& [6 r* F4 M+ ?* {
        return total_loss
    ' d+ D6 D. _" a" ?2)判别器优化器
    1 F5 b, |* P7 {7 ~
    + q# c5 j; b; Y8 p
    * f: F9 N; t1 q1 Z
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)6 ~# ^- U$ J( i) W) N" I1 }
    五、训练模型% _5 R% a, d, M8 m8 @7 L3 W
    5.1 保存检查点
    ) Y* Q+ M, R/ v/ v* v保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。
    0 ?1 ^# X* ~+ Y9 W" [- a
    5 \5 @9 G! n7 x- S' I! e
    2 U. |+ j0 Q. V, O  R
    checkpoint_dir = './training_checkpoints'
    8 R! J& y) H/ S8 {7 mcheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    9 Z& b& d8 f( scheckpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,# S# h# x+ B# v3 @4 }6 a
                                     discriminator_optimizer=discriminator_optimizer,+ R% ?( o& m$ ]/ M
                                     generator=generator,
    7 j. t7 y5 n3 V  o! X0 O# h                                 discriminator=discriminator)
    - [' S  w; ]% Z5.2 定义训练过程
    0 E9 P% ?: V; h7 DEPOCHS = 506 p4 U6 ]5 N+ k2 x
    noise_dim = 100
    - ~* p0 E7 U( q$ u; {/ \( {num_examples_to_generate = 16& X6 `) c, j4 k4 G/ ]8 @

    & O3 E- `( s. b
    + x1 Q4 \$ r  L+ O; Z1 T# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)# W! V- \2 V; q$ A8 e6 g# @: ?
    seed = tf.random.normal([num_examples_to_generate, noise_dim])
    3 h/ |1 C! x9 y5 w6 n/ W训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。2 b' M8 G! `9 T# L6 j

    . w2 R8 W$ _: d) Z

    + x3 J5 |- N# f3 G6 y* c$ n判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。
    5 p6 p* P3 L- D
    " N0 a% [% a$ G1 p0 {

    $ L# H& W: U: y% a( x5 ]8 S/ E两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。
    9 h! X5 x6 ^5 J: U6 H% `% v  L8 P9 R: b! m* r' g" R

    ( N/ S" m+ J$ d( U' S. C# 注意 `tf.function` 的使用* t0 W: H+ A4 c, f- z, c% S# m
    # 该注解使函数被“编译”
    % n* l8 q2 K. j/ v! x* z% `@tf.function
    & N: Z2 W' b' a# H# ?# [' r& ^def train_step(images):
    : T9 U7 n  a/ C* j$ q; H1 X    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    9 E1 c5 L* ^; x& p, U * s  k5 u9 P; u. g3 d6 i
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:, @2 D3 T* E  k3 B/ o
          generated_images = generator(noise, training=True)6 i9 |& @6 _: M- F9 f

    7 z8 a' b- u7 s: T      real_output = discriminator(images, training=True)
    0 d) H* h' Q: }      fake_output = discriminator(generated_images, training=True)6 Y- N, p, v! B( |; ]% Z; Z

    0 D9 f% |! W' @- R; y& Q      gen_loss = generator_loss(fake_output)7 U4 C$ }) b/ w" T* u; _( u
          disc_loss = discriminator_loss(real_output, fake_output)
    2 q8 F: a. G2 p5 ~8 a% M' }3 S
    ! P5 l* _: H! R; B% ?* D    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)0 z6 ], i7 U7 h4 t* w1 ~
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    . j7 T( `( k0 X1 ~  X. K 0 @# ^0 l" O+ `7 X
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    ( x9 F# X( e+ s% i9 Q2 v    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))# c: _, f  U1 R' q- H2 m. T9 a

    ; B3 s+ h' e$ V- q1 |+ g$ wdef train(dataset, epochs):, I' `' |! V8 Q: Q+ }' g, k
      for epoch in range(epochs):
    ( \% B5 p; M' P7 K& \& H" l    start = time.time()! N! \8 D3 i9 c( a+ k

    : v# J0 T+ p; S: s8 c    for image_batch in dataset:
    3 W- h+ v7 s- W/ d" h* b      train_step(image_batch)
    0 |, k1 Z  f, S5 m/ o/ ^2 D$ G 5 L9 Y( h! `/ m, v* Z  N5 g5 O$ Q6 i
        # 继续进行时为 GIF 生成图像
    4 o; ~; D  I/ A3 d6 Z    display.clear_output(wait=True)
    & O' R2 c; h8 L' n    generate_and_save_images(generator,
    ! h& Q  h( U2 W. y- f                             epoch + 1,) G( ^9 A1 Q3 d( K8 U* `7 T
                                 seed)# [# P7 t( e6 A$ O7 m0 V. P

    % u3 A/ q3 W2 u! x( F    # 每 15 个 epoch 保存一次模型3 B6 f. p5 r2 X+ ^
        if (epoch + 1) % 15 == 0:( i/ A7 B8 y7 w- [: t
          checkpoint.save(file_prefix = checkpoint_prefix)
    * I! ~) Q1 L8 T$ Z
    + {, G  T+ \& l" `    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))# v# f2 c. s$ a( P7 _3 Z9 u; d
    8 \; c2 H; X6 C! Q" i( _: r; g
      # 最后一个 epoch 结束后生成图片
    : L' N) I* L2 Q  display.clear_output(wait=True)) h% z  i' k7 u$ q/ ~
      generate_and_save_images(generator,  Y4 I* n' }. {" d5 p, U
                               epochs,
    & a- P0 ^' P6 c. R                           seed). q& D* ]" ~' `! C/ Q
    8 \! h0 S3 U! |! K; c8 ?, c8 a
    # 生成与保存图片+ @8 c' a: ^9 T
    def generate_and_save_images(model, epoch, test_input):
      U4 t* H5 \% a' z. X, x+ O  # 注意 training` 设定为 False
    ) i# }4 [" C. s. S4 z! g9 G% `  # 因此,所有层都在推理模式下运行(batchnorm)。
    % W0 I9 b, b, V, F- @  predictions = model(test_input, training=False)1 W9 s- s* F4 ?/ E7 y6 b5 i
    . k3 o. T9 \2 B$ g* e
      fig = plt.figure(figsize=(4,4))
    2 J0 d# D( d1 q* w: W: a- } ! d0 C# Q1 j: N- c# t# ^
      for i in range(predictions.shape[0]):
    5 o/ }; g; h+ U/ N$ C' G      plt.subplot(4, 4, i+1)4 ?+ t" \2 @# S7 q) P, i) V
          plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    / h7 [$ n% E5 j1 a, n      plt.axis('off')4 M$ v* K9 N% k, a4 _" A5 ]9 \
    - f: d" \, r# i" Z8 M+ O* g% Q
      plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))$ D; o$ n7 l- C  {! T% k1 a
      plt.show()* H5 _7 l1 _1 h! D0 P8 M
    5.3 训练模型0 V% k2 Z. J. l/ U7 \7 v) ~
    调用上面定义的train()函数,来同时训练生成器和判别器。3 @6 y3 b! @* Q' q

    8 U7 H9 d) T) o; E: c0 K- M1 K" W

    4 m4 v  u9 |, A% @: V/ L5 J* H2 W$ z* G注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。
    $ w  |, J& Z1 g$ I6 x" q$ ?0 B4 b: `! t  w- |- v9 `8 ~

    / ?( ?! _  d0 z0 U" b%%time
    , P& E9 E  `" }2 Q" n1 z1 n! Ctrain(train_dataset, EPOCHS)  u) H$ Q& g) l4 @8 G  a
    在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。8 Y" L6 N4 x  l# H; ~4 _7 z* s

    5 R7 Z2 m% J  G

    . p  q3 u5 S; c8 O0 N训练了15轮的效果:2 Y1 Z1 _$ a& L7 z7 W) a

    ' [3 Q6 \4 R* g; f! e3 J" ?2 x5 t& B

    / t$ |* P( |; _) C
    # O. ]3 r" o$ Y: H6 T8 C9 `
    # b- i6 [1 O! s4 P& h2 \

    ! Y9 U$ G" o6 R( u7 R" ]

    / D8 E% X* i* k1 o训练了30轮的效果:
    8 K5 i: e* b/ ?) s6 B$ T0 ]" z/ n  [
    % z7 ?1 y9 [  |& a' L/ ?
      ?: ?$ A/ G) j/ U, O, e; n3 \5 u) D
    2 f+ `6 z+ t8 r6 L

    # @- b/ T; L$ ^" V& @  m( y$ y
    ) R- P8 C' D) e2 s
    训练过程:
    ( X' y* ]$ y8 t- S3 Z& Y
    ; R* e0 [. I5 H4 `" ~2 D

    ! D; a( M2 _  D9 J8 A& j9 n- T$ ]4 V
    " h. J7 s$ Z* g
    : Y3 j% F9 U# Z, Q+ `0 A) R

    5 E8 U" |5 ^0 a3 O, z/ q) {
    % k* h+ q0 D+ c+ H  S6 N
    恢复最新的检查点/ d6 V; g: i( H" _" s* s
    ' B8 y7 Q; V- p0 m$ Y' `4 o
    8 N, c9 U" a  Q4 s  O; _: u8 C; i
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))8 f/ I/ f9 b4 g% N
    六、评估模型" k8 t8 Z# I# {, ?( g# `2 e
    这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。
    * W! q6 t$ ^' h: h2 y( c
    4 o6 A# P% X9 P$ i9 b* @1 v+ {% d

    + L( f2 t# f, P6 \  c' Q# 使用 epoch 数生成单张图片+ D& U3 ~( `- b
    def display_image(epoch_no):
    2 C, l6 e$ a; J8 f  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
    6 l2 d* s3 [" ^- m9 C
    ! y+ g" B: s! {+ y* E# O# \) G8 Hdisplay_image(EPOCHS)( l6 c1 a- q# i+ |+ K2 T& T/ Z* d
    anim_file = 'dcgan.gif'% e( a/ z2 J: U4 T

    , M, V, S; @& y/ N5 g' [* ^with imageio.get_writer(anim_file, mode='I') as writer:8 ~; w2 }8 I+ P& \2 Z5 Z) u
      filenames = glob.glob('image*.png'). z7 ~! a) w, t5 U* z
      filenames = sorted(filenames)" v9 w, ]. t) c) A; e5 f6 {
      last = -1* S0 m/ D% ?3 e1 T2 O4 y1 V9 j
      for i,filename in enumerate(filenames):1 X- \6 c8 c+ \
        frame = 2*(i**0.5), I1 E0 y8 P3 g/ i* V
        if round(frame) > round(last):
    - l0 @) a1 Q+ \/ `% s      last = frame' `2 f# z4 ]7 M- x' L4 [
        else:, I5 }! T& J- {" l7 K, c2 Y  d
          continue3 Z; Q9 m9 |3 Y
        image = imageio.imread(filename)% C9 P- a  j( }5 o9 [% Y$ m: t
        writer.append_data(image)
    & }1 Z" P5 Q/ g: C6 u6 o  image = imageio.imread(filename)
    3 u# r! C! j2 |  writer.append_data(image)5 k% E* k0 {* v1 D

    ) {  {+ e8 e3 j; himport IPython' z+ k, L; t" O' e4 |3 S$ z: j
    if IPython.version_info > (6,2,0,''):0 B! A9 r& N3 s5 B9 ]5 u
      display.Image(filename=anim_file)
    ; F# f4 x/ A5 c! X- e
    8 Y5 A8 K; }" ], O
    9 D# Q4 `5 |% |4 N

    ( s; |4 R! @: h5 _

    * h: i4 |6 E7 O* p完整代码:
    5 M7 B% a" U- H% w9 N9 k7 G3 p( X' I8 A+ ?: L
    ! `# Z$ T; h. [( \1 |( o5 C0 R
    import tensorflow as tf
    $ A" S9 V( E" @( P( p5 ^2 ]import glob5 b6 @+ o) F# ^3 M2 T" I4 d4 ^
    import imageio
      Q- O# z. H3 Z. d5 Y" O( [+ t. X0 o1 U' qimport matplotlib.pyplot as plt* j" {# o" ^) c6 m* h( ~
    import numpy as np  F8 v; X  Y3 [/ h
    import os
    * h6 u; {& B9 r+ ^+ iimport PIL0 |! o5 \3 i. ~; i9 X
    from tensorflow.keras import layers6 l) z8 U4 Q2 O: i1 |
    import time9 r& `) C2 K* p$ f% ~8 U
    / w' \1 F. Q$ o& }# C8 \
    from IPython import display
    $ }' X/ L7 @& Q 5 u* N  E7 o( L) s  K: {
    (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
    4 w$ i4 S3 N# ?$ D! q2 f
    # O1 [3 z' d* q3 f! gtrain_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32'); g" p8 u& Q9 |+ c8 h0 L# z% n/ I
    train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
    4 e; |# W4 L& j+ E$ O2 j 5 ~8 H; ^, f, @& ~9 ~
    BUFFER_SIZE = 600002 s% E% I5 z$ P) |; r( e6 |, z7 {
    BATCH_SIZE = 256* F2 m# g! n. E  |
    & T7 m2 v* k* V% `' z* e- P- c9 H) q
    # 批量化和打乱数据$ ^6 [& r) f8 k: Z! X
    train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    ' J1 k# s& x" w
    & g" N: k4 V$ c& y# 创建模型--生成器  H/ O6 x* L* x) o
    def make_generator_model():1 c/ R. w' P' f& M6 l
        model = tf.keras.Sequential()* C9 d2 v1 K+ W5 [
        model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    ( O. m" `  |0 z/ n    model.add(layers.BatchNormalization())
    " ^, u1 c1 g' q5 k! E+ {& q    model.add(layers.LeakyReLU())$ F3 e& s. Y) @" a
    $ Q8 c/ }: M9 m4 H
        model.add(layers.Reshape((7, 7, 256)))
    % q" K# z2 @- u: f, a% k& z    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制; G, n% ~& a+ ^: y( K" E
    ) h9 z) Q% u0 b
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))' ~5 c- C# j7 w: z3 v/ ^
        assert model.output_shape == (None, 7, 7, 128): s) E. x  B" m2 ?/ j, t
        model.add(layers.BatchNormalization())# x' E: |1 }3 S4 g0 A
        model.add(layers.LeakyReLU())9 @* E1 q1 e# _' j, s. B: h% Y
    , j7 s0 e+ t2 @% P+ ]; D
        model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    % l7 _9 u+ u; x9 u! B( T7 H    assert model.output_shape == (None, 14, 14, 64)
    " i) p4 r! v; A$ T6 G2 i    model.add(layers.BatchNormalization())) h0 ]  V' Q+ ?5 Q& i% P: s- i. g
        model.add(layers.LeakyReLU())
    " o# f+ W( D$ C) r5 P6 A7 h1 S/ z ( J% j1 N( _( U; ]
        model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))) h( _; F) k& N2 ^6 i+ S
        assert model.output_shape == (None, 28, 28, 1)2 u6 k. w- |. R; v
    0 E" r" G2 v  N  P
        return model5 ?. z( x9 t8 l( e
    $ q8 z' E; m  X; v* Q2 V
    # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
      F' p8 n: b. Z, @+ sgenerator = make_generator_model()
    ) K  G% V6 \/ p# d
    0 m3 b. e* A$ x. V( u( ]noise = tf.random.normal([1, 100])
    ! K4 P7 j, b$ f( w2 ]' g1 m2 g; Hgenerated_image = generator(noise, training=False)  C6 y4 w: n1 Y, g5 W" ]

    . d- [9 ^5 g: \2 m: o, w9 fplt.imshow(generated_image[0, :, :, 0], cmap='gray')) V& X$ z0 U' }' L* W  g. Q& O0 ]
    tf.keras.utils.plot_model(generator)
    2 g' y, @+ [( V' W1 E
    % G, a' T) X) A: f" K* _# 判别器
    - E3 R' G. b) Ndef make_discriminator_model():
    8 @6 B2 u( C% U$ H5 F- @    model = tf.keras.Sequential()" \9 i4 @. r) t2 R
        model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
    # M2 I7 ~, t0 T3 S! ~8 h                                     input_shape=[28, 28, 1]))
    / h3 @1 v/ |* D+ x4 j# O1 l) a# S    model.add(layers.LeakyReLU())6 _0 \: ^) f0 z) L" d0 K! Y: Q2 T7 F
        model.add(layers.Dropout(0.3))
    . ~9 J. d2 D( B; |4 }0 z 4 j* ?6 W  b" r
        model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    & C$ @% @( {4 D0 }' j2 \- f    model.add(layers.LeakyReLU())
    & {. r% ~8 V( s/ A; a0 U! a    model.add(layers.Dropout(0.3))* J  m; {( `% b/ `2 X$ s
    ) F" E) _' U. G- Y7 X
        model.add(layers.Flatten())3 C% f6 S/ q; f
        model.add(layers.Dense(1))
    # Y5 `0 X: ^/ \1 ^* y4 g3 Y' q ' A1 L. P9 v& `4 e7 Q6 n& N8 }! ~* h
        return model
    ; H2 j4 |. H" g6 }8 c , ^2 A1 V, B5 F' U, D3 L. j
    # 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
    6 Z9 U# N- @: Zdiscriminator = make_discriminator_model()
    ' {0 t  J& v( `' G9 L  fdecision = discriminator(generated_image)
    6 ~; o$ H+ u1 @* {+ Aprint (decision), s0 p6 Z7 Y0 S1 F  n" L
    1 [7 Y$ F8 I4 V0 g9 v9 F
    # 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。: F& X: j* i# b/ t  v; i2 a
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    " T) X9 d# V# V4 y% [8 j* ~  z 6 J/ a$ U% @5 m* z
    # 生成器的损失和优化器
    3 L+ z2 }; ?5 v) E7 n0 j1 h/ M0 mdef generator_loss(fake_output):
    ' g" {, X8 T3 F; v0 d5 ?    return cross_entropy(tf.ones_like(fake_output), fake_output)  Q$ p3 u6 Z, Z8 O( g5 f' f
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)9 j: b( l) @1 V( [* M, R

    " e8 d. B1 \9 d/ ^! ^# 判别器的损失和优化器% |% e6 W6 ]; A2 a; B
    def discriminator_loss(real_output, fake_output):. M5 ~8 W' O5 t
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    & d* V) ~( u0 @4 z8 j) C" h/ b    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output), c+ w& P& T; @: v" e
        total_loss = real_loss + fake_loss! O+ [/ w% r) T: @. n- e
        return total_loss
      F6 p  p! s3 i9 I0 |$ y* I4 pdiscriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
    : Z4 C- `2 w; |4 m' o ) e4 L  P7 `4 J  P9 R5 c+ P" K
    # 保存检查点
    5 s% _: T6 X0 b* |  [2 E! N. Ocheckpoint_dir = './training_checkpoints'2 ]1 O: W. F8 o8 F4 g& L
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt"), o$ i& [) A$ q1 j' v) @" E  L9 k
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,; h" J4 o' L" t/ R
                                     discriminator_optimizer=discriminator_optimizer,7 m2 K, o+ j7 E: H
                                     generator=generator,
    5 N8 l5 W0 ~. b& m                                 discriminator=discriminator)
    " S) H" _- Z8 z! l
    % U/ y% Y$ ^2 U+ y) C# 定义训练过程0 Y2 x( M/ V3 A* }+ s. |
    EPOCHS = 50
    : L# s& b. J5 a6 D5 k$ \noise_dim = 100
    2 y$ \7 @4 Q8 A* K# nnum_examples_to_generate = 16
    2 u$ F$ ?5 ?4 P. j; m2 s) ^" ~1 @' U: _ # g; m0 d# n; V! `( j- e* Y
    # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
    $ n( B# X7 R) P8 G5 w' sseed = tf.random.normal([num_examples_to_generate, noise_dim])0 h* n+ F2 ^9 b0 d: n; C

    9 }9 e# F, H/ u0 S& r! X# 注意 `tf.function` 的使用
    " W. W2 Z$ L3 r& w2 `# 该注解使函数被“编译”
    : v% b6 _% h1 p@tf.function* D, |4 Y- `& a  F' j8 V9 g
    def train_step(images):
    : `5 B- j7 ?8 W( W! I    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    " L- A( `: h6 ]7 Q2 _  _4 ]0 e $ V9 W: T2 k7 n: a- p, S$ ?
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    9 l) x- Z5 t2 A3 }/ `/ Y1 T/ G; O      generated_images = generator(noise, training=True). K) p5 M( j3 P1 ?* L& |

    + D! b" |6 X; Y4 _% B      real_output = discriminator(images, training=True)6 Z% r; n1 g# S4 t& J: G  U
          fake_output = discriminator(generated_images, training=True); h1 T5 ]- ]1 q

    " X4 }8 V4 f! g& Q4 x      gen_loss = generator_loss(fake_output)1 t# A/ h, C4 O: y0 k/ I3 Y
          disc_loss = discriminator_loss(real_output, fake_output)1 h. x5 V: B% ], F/ h- O" N7 c3 D
    8 E" I/ o$ V* c# `5 V% q
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    ; ^; D# z. E' {. E) N# y    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    9 p4 C. S- n7 T- L' w4 F% O7 H5 I
    ) Y5 D1 W5 Q1 r' n5 R! N/ e7 F    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))0 F& J2 X$ E; f- J2 T
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    3 r; s; Z. W6 P( I7 m3 f
    ! Q! g+ m0 h7 ]4 c. j( |def train(dataset, epochs):
      b$ y2 K' Z9 v& S" P" R& x  for epoch in range(epochs):  |4 q+ m0 k) x* a! ~/ S, Q
        start = time.time()9 v! r& n# C; _3 r1 K' ]" Y
    : Q% ]1 ~' V& G, b% S1 d
        for image_batch in dataset:! r; T+ o* W+ p$ C1 G8 n5 ~
          train_step(image_batch)! Q; k- A0 S8 X- o2 [3 h

    * a) W" Q$ M5 A% l3 ^( X    # 继续进行时为 GIF 生成图像
    - v* m; q3 Y# L; m+ t' L    display.clear_output(wait=True)
    - I5 v, v4 Q0 J& m3 t2 o) C; U  ]! ?    generate_and_save_images(generator,
    ) Q' U& P2 G; m) w$ Z6 n                             epoch + 1,
    " M- T' v3 n1 L                             seed)
    : a' ?' ^& J/ G" j6 t & n( y. N4 ]9 G. L( T/ g! f7 j
        # 每 15 个 epoch 保存一次模型
    + c6 W" L& K& _) o: f: _    if (epoch + 1) % 15 == 0:8 e7 x- C# a4 K( t
          checkpoint.save(file_prefix = checkpoint_prefix)% @; ?) O2 ~0 g) }/ c- J8 h7 y

    ! b( c! g6 N: }! ?/ l/ o  l* D    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
    % M3 N; n$ m1 }1 a* }# T& t . c3 P# ^# O1 [! l. B7 a/ \, i
      # 最后一个 epoch 结束后生成图片8 q8 ], F/ x$ q1 d
      display.clear_output(wait=True)
    & Y# z& K) }7 @( o. Q  generate_and_save_images(generator,
    / D8 |7 A/ i+ ^! {' y- x  H                           epochs,, L5 |& o  v, Z& {4 P
                               seed)
    * {$ F( X7 }6 @
    ' B- _" f; W' z( @# 生成与保存图片( M6 o/ y; m1 M( `) G* n
    def generate_and_save_images(model, epoch, test_input):* h1 K8 D- V' _+ i! @7 I; y& N) \: J3 W
      # 注意 training` 设定为 False
    - R5 n% k! C7 `2 B  # 因此,所有层都在推理模式下运行(batchnorm)。) y) e# c* \4 Q) C0 ^# H: j7 R  }
      predictions = model(test_input, training=False)' l( b" w& C9 W- A. r; G

    * G+ h8 v) N5 q/ m. q% V% g  fig = plt.figure(figsize=(4,4))3 }6 c$ A+ n! \2 u: M- t7 p

    ' u( p+ ~- E, c* P  for i in range(predictions.shape[0]):) y# H' b4 }8 j. c4 c; [: C
          plt.subplot(4, 4, i+1)$ r( Y4 w. Q% w& \: Z9 i
          plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    & V3 ?6 q: q* E6 A      plt.axis('off')
    & C: @" Q7 ]7 U! v( L& D2 \" j+ D
    ' j7 m: `$ y! ?% R- m9 C  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    & ]" n! |7 X7 V6 R2 y, A* x# {7 v6 p  plt.show()4 q3 u9 G# E. W% i; m

    " l: k+ A$ R! E* P$ `# U# 训练模型$ a! ?3 ]; M- w, H5 [
    train(train_dataset, EPOCHS)
    4 q' Q/ r# Q* r' o2 h, O$ E
    / [' }+ \* o. J! s& y# 恢复最新的检查点
    . Y( Z- R* _$ y6 x9 W$ s" Hcheckpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
    4 c! V0 n2 _/ P. C3 n. t5 C 0 l5 t2 q! Q% y8 J7 ]$ z% }
    # 评估模型
    5 L  c* n* q0 ^0 S' D# 使用 epoch 数生成单张图片
    5 v; J$ i6 d7 U6 G" `6 Xdef display_image(epoch_no):
    * F/ P1 t+ s& u( i& w  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))! {' K# J( Y: C
    + q! b" z# S& ]) O! y) m, M
    display_image(EPOCHS)
    : T' y# A% z, E0 }
    3 T/ `* J7 ]* Z+ B9 T7 N8 f) uanim_file = 'dcgan.gif'
    ' S# d4 y: [% K* M$ S% N; V( p
    ! s/ _; _0 L9 u3 X% z( r% K: Fwith imageio.get_writer(anim_file, mode='I') as writer:
    8 m* k$ \7 m% t8 k9 C( D8 y  filenames = glob.glob('image*.png')
    , T3 T' t* k9 I6 \! o! T" s! |- p  filenames = sorted(filenames)
    # U4 |& D) C7 @3 T9 M6 G4 }% y  last = -1. x5 M9 D# f( s0 X( F9 m2 p0 K& G+ n
      for i,filename in enumerate(filenames):0 X8 O- f, ^* V# K/ u; N% g
        frame = 2*(i**0.5)
    - g" ]1 B3 \, I9 y" c6 D& D    if round(frame) > round(last):4 ]9 P3 d, H' n. m* ]
          last = frame" G! Y* k" [4 R: z
        else:
    8 v  u6 f7 R/ W5 q+ u      continue9 p+ a$ T/ T8 U+ c+ s0 w; ^; C
        image = imageio.imread(filename)9 |: A5 p8 ^2 L! W2 f. W  G
        writer.append_data(image)6 u; a: I  [, M# |0 w
      image = imageio.imread(filename)
    2 Z: n: G9 x- n' }* u, R  writer.append_data(image)% f0 ^( r6 g2 d# t

    ; f4 b$ o' R/ Z3 ]3 ~import IPython* }* Z' ^# k- {
    if IPython.version_info > (6,2,0,''):* b- U4 Q0 l. @' `  a
      display.Image(filename=anim_file)8 R5 ^+ [. t( [' W; p' H! q
    参考:https://www.tensorflow.org/tutorials/generative/dcgan* i# c  N  g: _# q1 w% ]/ Z" b
    ————————————————+ ?2 L) o4 B* i- a
    版权声明:本文为CSDN博主「一颗小树x」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。* u5 `  b6 f9 ], `8 ^, i, y1 e
    原文链接:https://blog.csdn.net/qq_41204464/article/details/118279111
    : ?7 U$ d; C* K/ W2 i& n( _  A4 A: R7 U$ W0 R) |& j
    : c* t9 w3 w) _* K( I
    zan
    转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
    您需要登录后才可以回帖 登录 | 注册地址

    qq
    收缩
    • 电话咨询

    • 04714969085
    fastpost

    关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

    手机版|Archiver| |繁體中文 手机客户端  

    蒙公网安备 15010502000194号

    Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

    GMT+8, 2025-6-25 01:13 , Processed in 0.425916 second(s), 50 queries .

    回顶部