数学建模社区-数学中国

标题: 随机梯度下降算法SGD(Stochastic gradient descent) [打印本页]

作者: 2744557306    时间: 2023-11-28 14:57
标题: 随机梯度下降算法SGD(Stochastic gradient descent)
SGD是什么
/ p/ `- g8 z- c) k- NSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
4 o# T5 G9 D4 e' \- B+ z怎么理解梯度?% |% T2 F# E$ R: D
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。8 ], ?$ U: N: u3 i2 @6 J

, u( N. L; R0 r: E& ~在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
4 R9 ?6 M# [9 @" _% ^  i5 i  d' h+ Y4 a3 c0 w# f  h
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。5 q# E3 \$ @- V
+ B6 x/ ]( h( j4 I; x- w5 q
为什么引入SGD
3 o0 [- a7 X8 s  {深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
4 @7 ~7 ^# P& F' c
3 f8 o7 g9 |$ M' L$ D# F怎么用SGD
  1. import torch: B. ?7 e6 H4 }; [9 X& l& _4 n
  2. . b0 Y9 z5 \1 v# a
  3. from torch import nn7 u9 }( @& L" [3 l+ i; v# W

  4. 8 D: h% W, [* l' }
  5. from torch import optim
    / C' e* {9 O2 W) d
  6. $ r2 |& `+ u1 J" U* s0 _* K3 Q

  7. 7 T" R1 l" ~8 \8 t

  8. 0 k2 s$ B' t  v" q
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)* Y4 L$ L$ y4 s& x
  10. . R9 }2 V. G5 p2 o3 b7 q8 w
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    7 |7 b7 _0 L7 R: m5 y
  12. 1 Y9 v: {! w- h( K

  13. / T) ~0 c9 v4 D6 I0 E
  14. ( n' z8 Q* {+ P# `) E" q4 _
  15. model = nn.Linear(2, 1)
    * o' [( H7 a( ~% K  O

  16. : G5 Z, c/ x9 @' F5 ]
  17. - \6 s( k( R$ D$ N, ?6 ?/ }

  18. 7 X6 l4 Z- @+ N( W; y5 u# B
  19. def train():: G  L" t( L9 l/ h+ p

  20. 6 t! S- w* k) D- i3 E
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    0 t; Q6 V9 C/ G" `" K' p9 B/ h

  22. 3 l: j# l* S; S+ Z* u; T, L
  23.     for iter in range(20):) }  V2 w2 Y# f, e' C4 P" p
  24. 8 W" A9 |) J3 k: W* X" [1 c
  25.         # 1) 消除之前的梯度(如果存在): y3 C; X6 o% M! c! A0 e6 x$ ^3 x
  26. ( {, d' e/ F  t$ I* O! @
  27.         opt.zero_grad()( i& ^. O# @- f& o6 M5 g
  28. ! Q! B, o- ], c
  29. " [5 Y2 X7 [4 z* X7 F7 m
  30. # ~+ L# h. d: c3 a4 p2 t
  31.         # 2) 预测: y. ^5 v  o7 r7 T9 u2 V
  32. ! ]6 ?2 p* s. w& j: p$ A9 P
  33.         pred = model(data)
    , a0 f% R( R8 ^- C
  34. # Q0 f8 V# }) {, d  d8 [

  35. ( I. [& {2 s2 w; i$ c9 C* [

  36. ' C- ~( a' c" i: J3 W8 g2 {
  37.         # 3) 计算损失
    0 v- c8 L+ V2 I

  38. ! O  P+ k; T3 c5 E# @( V8 S
  39.         loss = ((pred - target)**2).sum()
    / ]7 `' |7 B6 o* t1 u3 M8 y

  40. - b+ N6 a0 l/ K# b- T9 B8 E

  41. : s7 ?/ Z3 @8 |9 ?
  42. ( @/ f- {3 T% J( Y9 f! d, ?. H# p. K# A( Y
  43.         # 4) 指出那些导致损失的参数(损失回传). x: n. K, o2 B5 k: m6 e
  44. % A: Y* b4 x4 F
  45.         loss.backward()
    1 t1 ^4 Y* {2 ]1 ?5 V6 o1 X# U# [" t

  46. 7 C# ~$ F7 D0 }. x* i, a
  47.     for name, param in model.named_parameters():- X8 d; P# C- O  Z+ z! \* t% s+ ~7 k
  48. 2 D+ y( h( i8 ]/ D' `% S
  49.             print(name, param.data, param.grad)
    : b7 s4 C% ^" [! d' d/ N! H
  50.   U+ g3 |( _$ E+ ?3 n1 B% h2 A1 a3 }
  51.         # 5) 更新参数
    ' r# [; I4 I% P, C. ]0 x7 o

  52. ! l7 a) c$ r2 l3 t. f
  53.         opt.step()
    / r2 p  a4 i0 N, \

  54. : b  s0 [: j6 L; Z6 Y4 F7 G

  55. - {, _$ {  _3 T& [

  56. " U+ x2 F) b1 f  t
  57.         # 6) 打印进程8 N2 `, ?4 K4 V- ^
  58. 3 J9 Q$ W. `2 t1 h
  59.         print(loss.data)
    5 {6 u7 ]8 i7 F8 Q  y8 j; W) Q
  60. , h/ \) @& u" n4 g/ J/ {7 [1 [! \
  61. . l' P4 u; g9 i: Z, E- [
  62. / @# j% P5 t& l! X! y9 u0 a! C1 Q
  63. if __name__ == "__main__":
    2 S* Z% @7 H, k! _( A" L
  64. # l' Q7 `' e$ i( P" b
  65.     train()3 T) K3 W& M, R) B. A% Z5 Q
  66. 6 w" o0 I+ d3 M1 f& m) w" W
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
: X* `: p2 Y: ~7 H2 o! N5 l1 r' `$ C  s8 {
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]]); v% L/ R8 z8 z6 W0 F  \0 V2 |" f
  2. $ o# F/ Y& s' T  o5 j! Y2 b' {8 Y
  3. bias tensor([-0.2108]) tensor([-2.6971])
    * k( f  L8 s9 I* S/ E
  4. ( C. z1 a/ B0 _9 J" J2 o" G
  5. tensor(0.8531)" m! e# n$ H( |

  6. ) A0 j  J" H: G0 b2 `, Y
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])# v- m% f- [# A7 s5 X* N2 h

  8. ' Q9 I1 W$ P# \1 e5 {' U" q
  9. bias tensor([0.0589]) tensor([0.7416])# u2 r% D( x# E' d/ H
  10. 4 S- w5 G" A& d" S
  11. tensor(0.2712); n; V* {& e3 S

  12. 4 n- ~8 V4 R1 S$ X- B; ^) Y  m6 N
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    + W3 c. o; ^4 V4 q. }3 L2 g

  14. , {0 d+ g6 J" J4 S) p* C  {
  15. bias tensor([-0.0152]) tensor([-0.2023])( @5 K4 C3 G9 P
  16. 7 r1 i8 s5 U+ n/ E1 w9 k
  17. tensor(0.1529)4 c% o3 F( r6 F$ J
  18. " D% g! J- M$ v  D; |
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])5 Z$ y; s1 S: Z( X9 V% \3 u
  20. + c& M+ E4 v1 ~
  21. bias tensor([0.0050]) tensor([0.0566])
    1 g( d) }4 g1 K( U& J
  22. $ X4 h* {4 t3 O$ ?# L$ D! t2 \5 T: Y
  23. tensor(0.0963)  m% Z( U+ p4 `9 ]

  24. % Y% B% c2 ~6 w
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    $ v2 j# l# i6 T9 i

  26. % c8 P/ m9 B; b( w% h6 D
  27. bias tensor([-0.0006]) tensor([-0.0146])* b/ I% `6 }  E

  28. 1 p0 A7 u, o# o8 B/ W) ~' {
  29. tensor(0.0615)
    ! p4 P2 L5 i. ?; m6 k( [& `, @
  30. ) R6 l; ^* S5 U% b2 ?
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    5 U0 O6 a# n9 a/ F& J

  32. : H$ }2 m; f3 \) x; k1 {
  33. bias tensor([0.0008]) tensor([0.0048])2 K. a5 S7 S' D  s: |

  34. $ A5 r+ f* m" T+ B9 {
  35. tensor(0.0394)/ Y3 Q8 V1 T8 Q4 B* E
  36. 9 e$ G3 {$ Q7 {
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    4 O1 K8 Y8 J4 y

  38.   S6 `/ c2 a9 H
  39. bias tensor([0.0003]) tensor([-0.0006])
    9 I* d& ]/ G9 k- W) Q$ f

  40. 9 h( Y+ c7 b1 ]3 L+ _
  41. tensor(0.0252)& [, `0 Y5 A9 V& r

  42. + B8 |7 l# |; a6 ~
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])6 s3 S2 p! Q( n" P. c6 r/ _2 m

  44. ) x( ]4 J, U8 J  t$ J
  45. bias tensor([0.0004]) tensor([0.0008])$ B% e7 o. {! @: ?
  46. . s  s. m/ s! F2 Y8 ~' c$ o* L
  47. tensor(0.0161)
    % L5 \+ Y* I5 _/ i
  48. / w, R1 l5 l5 g
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    + }2 G( D, V6 \$ `
  50. $ j- V( L. y% q4 s; _" U
  51. bias tensor([0.0003]) tensor([0.0003])
    ) S( U% O; M8 ?3 |0 L! p( u, G

  52. 1 j5 |# F! [8 O  J
  53. tensor(0.0103)( g1 A4 [; b  H3 f2 x8 }6 {7 _
  54. 5 Y* R! Q# _1 A% y9 B0 @8 E3 N
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    6 H8 q: s: r4 l% ~$ B

  56. : P8 P+ |9 f) A* [
  57. bias tensor([0.0003]) tensor([0.0004])  E; Z  o7 Q& b9 e! m, v
  58. ' H! u# g1 t4 N" B/ d
  59. tensor(0.0066)
    & J5 `. h9 }* G. r( L

  60. + v3 T. K( e* C9 n( c" b0 Z
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])& @5 h; R" h- {6 N7 h
  62.   v% O2 H! K4 [2 n
  63. bias tensor([0.0003]) tensor([0.0003])8 Y6 ?7 B, l0 k# Q" z/ Y! i0 W: h
  64. 9 I( D. t7 B' o' G1 k
  65. tensor(0.0042)
    ) y2 `/ h2 Q' l2 g

  66. : r* T& Y" _% g1 N% m
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])6 r" J5 G" F' [1 G; \
  68. * `6 z  s, m  M' I8 g  i
  69. bias tensor([0.0002]) tensor([0.0003])9 I1 S& B! A+ f0 k1 W: C1 `

  70. , @+ X" S5 B+ K; {0 l7 _
  71. tensor(0.0027)7 e/ t6 A. S0 L3 i8 Y  U1 E
  72.   o9 v0 i0 f5 t4 ^* O1 ?
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])9 w& q$ s+ `3 x$ T
  74. 8 L: O$ C6 x/ g+ j& m+ C$ Z  s2 V
  75. bias tensor([0.0002]) tensor([0.0002])1 L+ V1 h5 @9 q! ^7 g

  76. 8 N( B0 d! |( A: ^" D( t# ~8 Z
  77. tensor(0.0017)
    / L) U# J+ t" z* j: E3 s
  78. % V; |. {8 i1 r9 q
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])! a& a4 ]* i6 j0 g4 w" P; Q( d# c
  80. 1 M, ?. G0 G* _0 |1 @% E* X$ J& ~
  81. bias tensor([0.0002]) tensor([0.0002])
    / Z/ F" t5 C3 B( B+ F3 U: C
  82. " ]$ t( `- L& S8 [9 Z$ k
  83. tensor(0.0011)
    ' T6 x1 @4 T# V: L

  84. 3 _8 }* K. c" x
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])7 K2 d! ?  w  x6 \: C

  86. . B; v) E& }* O' Z
  87. bias tensor([0.0001]) tensor([0.0002])4 m% Z' v) x8 D+ |  s3 X0 C
  88. 7 T( z6 a9 p1 a5 {8 t0 M
  89. tensor(0.0007)
    3 s# D+ g+ m5 C1 H3 k9 Y( u) U+ e* a

  90. 6 z8 i) c' q  ^3 Y5 Z
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])0 d) k( P% `& s/ G1 o$ f

  92. ' Z# W4 r) M% T6 r8 R
  93. bias tensor([0.0001]) tensor([0.0002])
    . |5 n7 Y3 G% E" m0 D

  94. 4 N4 M/ ~' Q) S) j/ E+ r6 ^
  95. tensor(0.0005)
    9 |  u& z, C1 V: U( x7 ~: i, `

  96. + F8 a& T$ S( n  I& U3 H: |
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    4 K, {- h& M- `. x' o$ w

  98. 0 B0 T- G2 c- `' b+ t
  99. bias tensor([0.0001]) tensor([0.0001])* ^& |3 M* e! s8 P) R3 z

  100. ) f* V6 N1 c$ O
  101. tensor(0.0003)
    - F4 @+ U" M! B' Z& f) p/ h

  102. & x  P$ k0 b2 d6 J, O
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])4 q' \4 Z8 w3 Q' L2 B
  104. # K: {7 Z) U8 m& H3 f& {1 j6 j6 T. W
  105. bias tensor([9.7973e-05]) tensor([0.0001]). ]9 B. ]/ _: J+ r
  106. 5 V0 d% m1 |/ I! R, I8 {' y5 i
  107. tensor(0.0002)9 ~% M& r: A# i0 n4 a& f

  108. * g8 p. e6 J" ~9 u$ _& z
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])5 A5 E: o$ M+ \# j- U  P

  110. , R( M. E7 W; @/ s
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    $ P+ [7 W: i/ Y+ W7 f+ Q$ Z5 p
  112. ) E7 l) c$ c3 D
  113. tensor(0.0001)
    # J" m6 g4 Q5 Q/ R
  114. 1 X3 M8 K2 k, u) O. F
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])2 J' D2 F, O$ B( I. W, k' B7 W, |, n

  116. " T  J! f2 N9 Y6 n" ]* a
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    ; M9 Q4 j& T8 r' k) a* Y

  118. % _1 y8 V  Q0 Z8 q) B. ^. ]  U
  119. tensor(7.6120e-05)
复制代码

1 ~/ F  T/ ^& Z+ K2 I3 M




欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5