数学建模社区-数学中国

标题: 随机梯度下降算法SGD(Stochastic gradient descent) [打印本页]

作者: 2744557306    时间: 2023-11-28 14:57
标题: 随机梯度下降算法SGD(Stochastic gradient descent)
SGD是什么# g1 h2 D/ @9 y! ?6 X
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。7 y% L  t# M" m' M6 r" Z: Y/ E
怎么理解梯度?
3 u# J% X# T- N& `假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。4 ~# A2 N! P( O% ?% I5 U: q
! D' _5 }" Q/ t7 }' G3 F1 U
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
$ {: M& U. {! M
" t' P& c' W/ z. \# f每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。& D# W7 L, e" B* h$ Z* r4 z
: d- I; h. M8 K, _" v
为什么引入SGD% ?( Q* b" b( V- D* d
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
9 k) G, ?+ i/ e- y
  L& u9 g6 p9 H' l, K9 S怎么用SGD
  1. import torch  E3 n! G8 p" D) c) |1 S

  2. % v/ O$ K* J6 U7 c5 D8 k# i; J- R& I
  3. from torch import nn
    ; F8 b$ p; L' k& U+ C5 J' p; ]
  4. + u3 b! r* c2 d8 U  X
  5. from torch import optim; i9 A5 d7 O9 D% E5 M1 h- `

  6. 2 j2 o% B: r6 _9 h" w2 t& Q
  7. 9 ^. z- I1 Y' m9 L! R8 M

  8. - H& N* n3 r( y5 Y1 {6 [1 H
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)' t5 q1 |5 B& |+ Z" c

  10. * k4 z% _5 h" i  z* s
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)3 m5 s7 z+ J- e! f# t
  12. " W: u+ V8 G7 w! }! ]; K
  13. 4 f* K% R) N# u- U
  14. ( T! U7 p# S3 S2 t* _$ V
  15. model = nn.Linear(2, 1)% R; R- O& Z3 y7 H* T: q) s& V
  16. + }" _1 O. w( \

  17. / v9 c7 A- h- X) G) m. x. l- R

  18. + w: \2 }9 r3 R1 B/ Z
  19. def train():
    9 |2 ?( |8 ~% o' O) q
  20. " h. D7 \: h' x6 W' o1 d
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    # K0 r; o- n  R# m% y; j
  22. % g$ D& _. ~4 m6 G3 j
  23.     for iter in range(20):$ d7 K  t; ?; J7 E

  24. # H- L( @* K6 z; Y2 F
  25.         # 1) 消除之前的梯度(如果存在)
    1 g, Q% ?! R( p0 N7 W4 u4 ]6 b

  26. & j! T0 Q3 B/ u7 C
  27.         opt.zero_grad()
    ( ~- b% d: M  a( h
  28. + ?2 k) w& l8 h0 x( q0 Y

  29. + k8 ]2 a' o4 X9 A' I

  30. 1 S# d% i, k: j/ h* L" B3 S
  31.         # 2) 预测
    # B) M  a: N# c/ m. m

  32. 6 _+ p  @5 @7 K+ Q
  33.         pred = model(data)( m2 _. D6 A0 m0 C2 s
  34. / y  a& R4 X$ j
  35. ' T5 K3 E2 O! a/ x! \8 r
  36. 7 t, ]2 u+ p0 F9 L0 q1 l+ L
  37.         # 3) 计算损失
    , X6 o+ A7 q, w) |

  38. - A3 M0 x) T$ `; T( e
  39.         loss = ((pred - target)**2).sum()
    . G0 f- I$ P$ ~! `1 ]1 a

  40.   x8 p/ v3 v1 x( Q

  41. ; v* d, @, g5 u- q( ]
  42. ! Y3 }: Z& U- f; a+ }" }
  43.         # 4) 指出那些导致损失的参数(损失回传)$ ~/ C& r6 ]/ f2 h! }9 X5 G
  44. " y$ W4 @, A4 l  R+ t% Z, E$ l8 u
  45.         loss.backward()3 Z3 V; \% v% i! F) q% \

  46. 7 H5 F+ x  Q( h: L
  47.     for name, param in model.named_parameters():
      `: L* T- z, q3 W) x! T

  48. , p) y; J# x! o5 e# E
  49.             print(name, param.data, param.grad)7 B. ?6 p0 Y4 D( @/ X7 o
  50. / l' ]# n- j4 ~3 o9 [) u
  51.         # 5) 更新参数; N) O- W' Z- n9 a9 Y, i- R) P
  52. # n0 W- Z) j7 M( r
  53.         opt.step()
    ! t9 ]7 C3 T& @1 P5 E( X
  54. 1 z/ d! n; b8 I% j; B

  55. # E* x; [3 }# ]

  56. + y! G$ ^2 F0 H# p
  57.         # 6) 打印进程6 e' S' r9 r/ A; b! \3 L
  58. ; |6 d: z# I- \- Z/ K/ B
  59.         print(loss.data)' T0 `. A* j5 n( o; r& R: A
  60. 2 m$ N  j$ l4 Z/ {4 S+ N
  61. ) a( M! K, W/ C" C2 F& M
  62. 4 c9 P3 m0 `/ ]3 _+ y( V
  63. if __name__ == "__main__":
    ! o& U8 p  E# n1 R& [
  64. * E& @7 S7 M* u4 j
  65.     train(). X1 t$ n6 b+ i. x* h2 W# R" k, d
  66. ! g6 \4 g$ a6 j, e& B8 r& \, U7 k% [4 w
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。4 B5 T8 b1 F3 x
& A& V* s! e5 v0 @0 e- u/ N" Z
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    4 k1 z) H; {* ~2 v( k. X9 E
  2. * O9 o/ `* n5 ?* q
  3. bias tensor([-0.2108]) tensor([-2.6971])
    0 |; j3 r% X9 J5 @

  4. ! [/ y- S( R3 p) h
  5. tensor(0.8531)
    : s+ Y/ `9 B# O2 H$ @3 ^- N
  6. ; b' @' V! [! j# w( E
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])* i. c, L$ q. d+ X8 i: B

  8. & i6 d* d, V; j3 A7 r
  9. bias tensor([0.0589]) tensor([0.7416])+ S  a9 Y" n* R+ c
  10. 8 r. L1 R+ Y5 K. U$ \3 }, @4 j/ x
  11. tensor(0.2712)
    # n" X. j1 H+ _
  12. ! v  z" y. x. y
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    3 g) S9 E& F( L0 j" Z
  14. 1 A1 k- G7 n! n2 P1 c4 C
  15. bias tensor([-0.0152]) tensor([-0.2023])
    + @, K7 \' F9 O; q+ ?
  16. " C9 l, q8 X8 z/ Y/ ~5 L
  17. tensor(0.1529)
    + ?3 J3 L# J* S  }& N
  18. 8 ]4 L; g& P6 e- Q
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])) d* U( W6 G, m  Z+ y

  20. ' E& L: z2 Y. ^! A3 x2 ~
  21. bias tensor([0.0050]) tensor([0.0566])
    " u) {, e" x% v7 E& y) `

  22. , d+ q# ?$ `( M
  23. tensor(0.0963)
    $ X) x  @  D- [& r/ H1 }; k) _. o
  24. , c" \! a) ?- {. g- f
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])' o4 E% v& X/ U. Z1 P

  26. : G- s7 A- ]& S% c9 B$ U
  27. bias tensor([-0.0006]) tensor([-0.0146])
    8 e' g' I: l6 t
  28. . u1 X. r0 n& T, m' _* y- r1 o7 I
  29. tensor(0.0615)2 }1 R$ Z0 ~: R6 F- Z+ l+ f
  30. 1 R* r/ P% _7 w/ p
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])% w! n% b$ P) q- O4 `1 [0 k0 @: m  G9 G

  32. - W+ R* l* G9 J& c
  33. bias tensor([0.0008]) tensor([0.0048])  {& G& F) O$ ?
  34. . u5 o" m' Y# l( ?. r! V/ x
  35. tensor(0.0394)2 N7 }& X& Y9 \" _

  36. 5 n  a+ Z* W! V, w0 x  v: a% u
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
      x% p  b$ s, m4 y* [( f1 m
  38. * y/ i, K/ Z7 G
  39. bias tensor([0.0003]) tensor([-0.0006])4 G; t1 l/ W1 `6 O. }
  40. 2 `% z( y6 X4 i& v+ t# v
  41. tensor(0.0252)
    / {: G+ l( c+ ]( s7 h

  42. 8 O# t6 `% H6 C$ g9 E
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])2 H. {0 f' z) h4 u" r

  44. 4 g/ ?4 K, _( z* [0 f6 A% B
  45. bias tensor([0.0004]) tensor([0.0008])
    8 o7 N+ R( N5 }

  46. ( o& m/ k. x: j9 b% S8 Z% w
  47. tensor(0.0161)
    : g, Z0 @* X: `, K9 j

  48. / }4 w4 y! g/ e' T, X
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])% x; C3 r4 T) x; G) \

  50. ( a: x  z2 N; S/ U! m" ]) b& S
  51. bias tensor([0.0003]) tensor([0.0003])( l* t  L: w, j$ ~: c/ Q/ Q

  52. 5 g: P1 _0 y! L% b( d& m  H. U# {
  53. tensor(0.0103)$ V5 V3 }* e  r6 r# ]% H6 N

  54. ; t3 D3 S) N/ r. T
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]]): Z% h. L, c- [" j. c
  56. ( I# V' p* k5 J# z' \
  57. bias tensor([0.0003]) tensor([0.0004]); ?- L9 L+ A' o; R+ c- N

  58. 9 J5 Q: o% h: M9 n) A3 K
  59. tensor(0.0066)
    / [5 Z" X9 i3 Y& f5 ?0 O
  60. . N+ p3 S3 W! z9 @' c. M
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    0 l: [9 @" t/ I/ r# d/ q* i% o

  62. 1 `" c- A- [' t4 L
  63. bias tensor([0.0003]) tensor([0.0003]); N5 p% P0 [$ o
  64. - E  g' e% }1 ~0 V* G+ K2 H
  65. tensor(0.0042)3 o9 P# H  a* B: F

  66. 1 L. h! c2 `- e% Q! H* ~) C
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    3 A+ W+ G: _. o+ |
  68. 9 w) r4 z1 N) c: i
  69. bias tensor([0.0002]) tensor([0.0003])' J; a& I% q% x: v1 z

  70. 8 C5 u% s, j6 y4 o# V: S
  71. tensor(0.0027)8 d8 J! f# J  x
  72. 5 e( M! h# x. t  Z5 ?" g+ O9 t
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])# i9 O& o3 g" k, [4 Q$ O4 V( D5 [
  74. " S1 s. f" |: q+ @. e& i+ E" x
  75. bias tensor([0.0002]) tensor([0.0002])
    $ p& v  a' y# _. }; d8 o
  76. 3 V3 x3 N) Q7 q' _; e
  77. tensor(0.0017)
    " ~1 T; F4 H5 O7 W; f
  78. 6 [0 @' H: e7 |" R) j0 M
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    , ?' ]! ^: I  M: J" ?4 p- X

  80. + K8 ]! m. t: t" X7 m
  81. bias tensor([0.0002]) tensor([0.0002])0 S6 E) z, U4 r1 r
  82. 1 X) s3 i. L3 Q3 @# T$ s0 C
  83. tensor(0.0011)
    % r& U% |- a3 P8 L3 G% i
  84. ) p' R' I6 A7 r9 v6 g
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    ) k+ L6 @# ?, U  G  l
  86.   `; D" w9 s8 P% U3 u
  87. bias tensor([0.0001]) tensor([0.0002])
    / i. D. r9 W$ n; K% y+ m) n

  88. % ^6 Z% g8 ~8 |' r" r& Z
  89. tensor(0.0007)
    1 _/ n, i8 U/ B/ l
  90. * k7 S/ g' W/ V% h/ T  P$ r0 \
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    ' V  n; p& ]8 E* A

  92. + }( [& N# K* v: J
  93. bias tensor([0.0001]) tensor([0.0002])% y. q4 N/ D1 h3 g( B/ V$ ?" W
  94. 3 r) {% @8 E/ T: ?9 [9 V) A8 X
  95. tensor(0.0005)
    ! o4 L+ r8 `; v, \) S! g$ n+ G* j% S

  96. % y1 _+ B) c2 z* R
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])5 I1 O, a9 z0 ^2 Q
  98. : X5 g6 X8 _2 X  U- }9 ^
  99. bias tensor([0.0001]) tensor([0.0001])
    ! B# v0 K6 }3 G1 R
  100. # ?; O" Z. ]$ z- W
  101. tensor(0.0003)1 H  q( x5 E" z5 P; l
  102. 5 w0 @9 d+ \2 r- _
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    / f: i* p0 k, [- A
  104. , [* p" b3 w5 |
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    4 w) P" Q8 z2 J

  106. / |" D' b  I: P0 o$ ?4 s" N
  107. tensor(0.0002)- @7 R0 o6 L' _5 t4 S3 \& g. I) n

  108. $ v% b# [0 ]0 `" V  N5 d) V
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    / H' b8 z0 y4 R' G4 b$ W
  110. : i/ {! Z" L, X6 A" ^1 U) T
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    : h2 T% K" {1 s+ u) X
  112. # V9 L! Z) J5 ~8 {4 Y/ y
  113. tensor(0.0001)
    $ g) f9 `2 w& g  P' ]: x
  114. + ~. j2 ~2 C$ k
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    2 K+ Y3 s1 N: t  l
  116.   }6 o# b/ j: l% k  Y
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    . G2 L" Z, |3 L
  118. - }. a- j0 g+ h1 C8 [
  119. tensor(7.6120e-05)
复制代码
" j- ^0 ?, E: b/ [2 c





欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5