QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2250|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1176

主题

4

听众

2884

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么; x4 u5 k3 q- \7 ~! W4 a
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
) s4 k, p4 i2 E: N) c% d% b怎么理解梯度?6 S% I& ]5 k# ]( m
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。. R0 z0 w3 `! s- k2 W- K$ t

! v# R$ x+ b8 W' W, q6 M7 w- q! Q在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。/ u; d2 r/ @8 @# u

' R% F# S! x+ T* h每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
# F1 k$ e0 |3 b( A8 \5 H! d8 R' w  z
为什么引入SGD6 k' t/ T; |" l2 `/ V" W, Q0 j3 s8 _
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。$ e5 |) S9 U: \( b: e; k
7 O/ d3 ~& l  d/ E
怎么用SGD
  1. import torch0 S& [2 W; H. ~; Q1 }+ v; B) C
  2. : F: U. d% l9 }' K3 K
  3. from torch import nn0 T+ S\" Y+ Y! t# q; v1 v% w0 j

  4. / Z) F\" E- f0 W# j$ O& e2 S8 T
  5. from torch import optim- {# \1 ?- W. B! r6 M, r

  6. + a  _\" i; M+ r8 L3 ^  c
  7. 7 f- W7 W) ]9 G
  8. 9 M- ^; b4 g% h+ s+ Q4 B/ H5 ?& N1 g
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    ) V# y0 _5 E\" i) v\" p, c% \- a
  10. * V/ j+ R6 e' Z
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True), p0 l4 G2 m7 n2 E
  12. 5 Q0 R: D4 ~1 E& a# s& _
  13. / H8 ^; i0 h& E. S5 c0 Q

  14. $ F( [, u+ I6 F% I, p- d( E
  15. model = nn.Linear(2, 1)1 E\" y; p$ N3 m

  16. / Y9 ?) Y: c8 g# L( m6 h; l- _6 W
  17. $ V* S) j% Z' g6 c) z

  18. & n) p' d% p7 h9 s4 X\" b+ O/ Y
  19. def train():' Q$ H; @9 Q, s( p# `+ w( C- J
  20. 9 G. t5 x' s% v) a5 H/ [
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    , ?6 D& h5 x4 i9 b

  22. 5 |( A5 g  u, G1 n; N( \+ I* a1 U; @2 }
  23.     for iter in range(20):\" x/ `2 O6 o# O( F8 a8 ]
  24. $ b9 x4 B. ?; t1 Y1 F+ s1 T: {' b
  25.         # 1) 消除之前的梯度(如果存在)$ Z1 P# g! v, J3 g6 v
  26. / e7 l4 ^$ q9 a$ f
  27.         opt.zero_grad()/ U, a! N; J2 y
  28. 1 j1 H, d\" Y\" ]9 \7 ]0 c+ E
  29. ( ]$ P! j\" V( Y% C! X
  30. 2 t0 T) O, R' u# q
  31.         # 2) 预测. f# D! E& \1 w: M
  32. % y7 D( r- W2 l8 B; B% N* l% |
  33.         pred = model(data): d: }& l; J( S. X) b

  34. & O4 z& |% w! ^$ M\" |- I2 Y
  35. 7 c# ?7 R& I, m6 h# i
  36. * D$ m2 N  L  D0 m) [9 n3 I% ^& t& [) k
  37.         # 3) 计算损失
    ( o( V! J# G. X\" n5 i$ U$ V. _7 q
  38. . ]; f! w5 }! z/ ^6 O* q
  39.         loss = ((pred - target)**2).sum()' v! J$ P5 o$ e
  40. & c3 |4 \\" ~( F5 t1 q- A
  41. & H' t  b* N! N8 }4 b) A: ]4 |& |6 N

  42. 8 s. j4 E& e( t* r
  43.         # 4) 指出那些导致损失的参数(损失回传)- J8 u: O. h8 f( V! D

  44. 0 U! d$ `& y; K
  45.         loss.backward(); R+ ^+ k$ H3 R+ l
  46. ' l9 D+ p& P  j6 D
  47.     for name, param in model.named_parameters():
      q! j9 c# ?8 Q) W, N6 c

  48. \" o, f/ v% U! I) |5 w
  49.             print(name, param.data, param.grad)
    * n) T, B; P, @- z  X

  50. 8 C: s/ D/ C1 S  w  i
  51.         # 5) 更新参数
    ( ?\" c1 a* h6 O7 p( Z/ M6 U8 r

  52. 2 h8 M: s& r3 J
  53.         opt.step()
      `5 _' L6 Z: z! ~; M# H\" Z
  54. \" t, Y! N7 H4 b  m. N/ a$ ]- p! V
  55. ) A- A$ f3 X. ]( p& ]/ t

  56. 3 h0 ?3 c4 d- U6 @
  57.         # 6) 打印进程9 b! L5 u6 X; [9 p

  58. - V2 @3 f) i% M
  59.         print(loss.data)  B, e7 ]$ Y1 {7 `9 t
  60. & h# \! R5 F' U8 Z
  61. 0 L0 G6 X  Z: M7 Z4 H3 t

  62. 1 z2 v3 ~7 R) y7 M9 \\" u
  63. if __name__ == "__main__":. j  \, _$ k  h6 G: z* H8 f

  64. + f# L; ^* D% Q\" m' ?\" r, g5 z: y8 B% B
  65.     train()7 R\" l% t7 N9 m0 |

  66. \" c1 ~0 x. d  J
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。9 z' ?; j# R" U% [& H6 h

% M, J" F2 q3 |  X计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])7 n1 o/ R, ?; u8 \

  2. : [4 J3 C6 s  a* r\" U; C
  3. bias tensor([-0.2108]) tensor([-2.6971])
    ) h# I( o3 X# P+ S, C1 W
  4. 5 j$ g1 X7 j, B
  5. tensor(0.8531)+ ~: p  A6 H+ L. d2 w
  6.   O, z; P# |2 O5 s' c9 R
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
      s) e% M( v+ }( H; u. M
  8. ! C0 Z' i/ [& ^/ p9 q' p% Y0 N% N
  9. bias tensor([0.0589]) tensor([0.7416])
    7 ~7 N* v9 T6 z

  10. + v% D/ [/ m6 Q3 F
  11. tensor(0.2712)0 \4 ^0 X0 k\" m\" a) b8 M  d/ S
  12. / |! ]7 l6 I$ t+ r* R( i5 o+ I
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])- A1 A( y* y) F' Z1 F

  14. ! \, y- G# p\" u& `' J2 m# O! q
  15. bias tensor([-0.0152]) tensor([-0.2023]). M  B: n( A$ E
  16. 3 [' E, B1 v! d  U( o
  17. tensor(0.1529)
    + \/ ?! v! `% m

  18. * c2 q, y' X5 w: k! [5 H) O
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])+ p  a- }* S6 j4 c3 _

  20. 1 X\" @' L3 V& ?& H8 d
  21. bias tensor([0.0050]) tensor([0.0566]): F& E5 R% x* z5 B
  22. 5 b6 |- d- }2 S
  23. tensor(0.0963)
    / J# K, a4 t$ m7 L$ G8 g; j! d

  24. 0 i! u% [# E* b1 v2 c
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]]). D) ^* |) H0 U

  26. ! U& V0 c+ g. l# O( O4 ~
  27. bias tensor([-0.0006]) tensor([-0.0146])
    7 D5 _2 g# C6 F. Z

  28. 3 t; d2 h1 k( ]& a5 E. O9 t; Y
  29. tensor(0.0615)
    . @% L0 g# f7 h6 R2 E: n
  30. 9 w- g$ |5 J' Q0 F& Y2 z) g
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    7 ^  K4 k  }7 d8 ~) L, ?# W
  32. 6 E9 T8 Q# w* U- J9 Y
  33. bias tensor([0.0008]) tensor([0.0048])\" K! u$ A/ X* b9 |* \

  34. , b; \3 ]9 W1 H
  35. tensor(0.0394)
    5 m) k3 l5 C/ L
  36. , L0 V, R5 r& g/ e$ @
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    0 R# e: m3 Q, ~! F1 J. ]% C3 z
  38. + V2 k/ W* |6 g  W' M. q: }  Z+ P: g
  39. bias tensor([0.0003]) tensor([-0.0006]), L  A; c$ I. L& n6 w; M, Z

  40. \" f\" a/ \9 N- ^( b5 p\" b
  41. tensor(0.0252)
    8 Z, d; H) {; n% d5 l

  42. # J& b# X9 k- k+ R3 [# v; q
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    5 k+ ~  B9 e7 Z0 X. M( ^

  44. \" U4 E! Y1 r7 J' t& y5 x' v, S
  45. bias tensor([0.0004]) tensor([0.0008])
    , }\" L0 i9 U% }7 A

  46. ' r( Q\" Z. ?2 B9 m' h' _, V& R
  47. tensor(0.0161)\" y* G6 |7 B4 G, g- Q% S

  48. $ `- \8 a2 F2 b3 K9 \7 R
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    ) {4 F\" d! }$ h8 G8 q8 z  H

  50. 4 |! J2 b. Q9 S1 ~
  51. bias tensor([0.0003]) tensor([0.0003])4 V: z: M. y4 Y2 O! D8 v8 t  q  r

  52. ! p8 u9 L- g6 L1 v
  53. tensor(0.0103)' w; e- _* p2 Y
  54. 3 [# Q7 p8 J. k4 E, H
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])( R9 L4 G& m* \+ F\" q\" N9 ]
  56. # Q: N7 t7 H  a! p9 d' d, k/ T
  57. bias tensor([0.0003]) tensor([0.0004])
    ! X( @: I7 s+ T
  58. + ?  t* s. Y2 n1 ]1 i4 u1 C
  59. tensor(0.0066)
    0 h: u6 K3 r1 S! v
  60. ( p0 N% F4 y6 t) |/ z
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])) l2 {& S0 j1 H) c) ~1 Z% A
  62. ) T5 f$ X. j% p% a! J  i5 q
  63. bias tensor([0.0003]) tensor([0.0003])
    # N/ x* H. w* g3 d& J8 _* z

  64. 3 B3 D5 ~: w8 ~2 @
  65. tensor(0.0042)
    ' A: v9 s8 K8 D( N3 a# v

  66. * V4 [1 v$ ~0 o. m* V6 |. U
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    4 ]2 i& H) M& h# L
  68. % Q- ?$ O4 ]2 x7 G! o. u
  69. bias tensor([0.0002]) tensor([0.0003])
    \" w4 k( o4 E3 Y  p3 s1 G: E% X

  70. 6 C5 y- D, V( g
  71. tensor(0.0027), n2 r8 k! k7 F9 a9 W) _6 L1 ~; u$ L0 y

  72. 5 K& O/ A' S, k- }
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
      x% y\" K7 h( K

  74. $ _6 I0 j# @' t  N7 D( t4 M
  75. bias tensor([0.0002]) tensor([0.0002])& h# y& C! r5 ?8 J. X# D2 _% {

  76. / i% c; s; r2 H! o
  77. tensor(0.0017)8 A: r\" G/ `1 [- Z$ t# g' w: k
  78. 3 j4 W. t% k2 Y. p% d; a& R) H& ^* u8 d( H# @
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])  o6 p* G; t7 U# P4 ?  q/ S8 [+ B, Z\" m
  80. - H- m5 [* ?! Q3 I0 s5 t
  81. bias tensor([0.0002]) tensor([0.0002])- n+ e7 V( I4 Y- N. |  l7 P& ]

  82. 7 i5 w& R% z/ ^' [1 {$ B
  83. tensor(0.0011)
    # @. x& o; m  |* L2 {
  84. \" \2 e9 Q# ?. N  d
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])7 ]7 y7 T  Q' d% M4 u5 o\" r

  86. : B9 o+ g3 o3 J  v+ z
  87. bias tensor([0.0001]) tensor([0.0002])
    / T& w7 @- z1 ?& z* q

  88. ' _7 t\" x1 H; R$ v) c& p- a
  89. tensor(0.0007)
    $ u( ~+ X/ _3 _* ~# V
  90. / b- o8 Y7 O: |- d/ k2 z\" b
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    ! A( Z# \* |+ v6 d

  92. * T# T0 J1 l: k3 L2 P0 h0 }
  93. bias tensor([0.0001]) tensor([0.0002])
      i) h4 H3 p! T: v: g2 T+ H$ ~$ B

  94. . z6 ^& s8 D\" d
  95. tensor(0.0005)0 Q. L( K# U7 H$ \. S8 i

  96. # q# ~7 M. _* A# ]# o. M# c\" L3 _
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])\" `/ ]\" Y' q. e$ Q9 J
  98. 0 O6 ~) f/ @. s. J0 c: V; H& D1 e* a
  99. bias tensor([0.0001]) tensor([0.0001])# f8 ?' m% j( `\" N3 D$ B
  100. : K/ Y- u# O+ @1 P. W
  101. tensor(0.0003)+ L( G/ U6 p6 i% Y

  102. ) p6 }! D3 W7 Z, j& g) I+ w' f
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])/ x$ |8 ]\" s% x4 `9 V- [1 P; S

  104. 4 D+ g- K/ G6 o% K( O
  105. bias tensor([9.7973e-05]) tensor([0.0001]); ~1 z5 j\" D5 q, m
  106. 5 K( f6 b: E! q, n. }  x, A, {
  107. tensor(0.0002)6 P0 m/ i& d8 f6 b! F

  108. 9 a2 h$ _0 w5 M
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    , O7 v5 D% w5 C1 G\" ?' d

  110. ! d4 X/ r  x- y$ H: T
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    6 q- f' x) [7 [6 _  \

  112. 5 j% J, D: O, O% a- O! y5 ~
  113. tensor(0.0001), m& b9 ]* q) M6 t

  114. * Z\" ^  a1 y* t\" I
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]]), ]4 R  p# t# x$ W- t
  116. , c- C7 z( u$ o' p5 q
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    ; Q& D  ]: T& {! N) @$ j: Y
  118. / g\" N  t, M# ^4 O% P\" a1 |0 }! @6 g
  119. tensor(7.6120e-05)
复制代码

' v: ]2 a# L/ b' k! g: ]6 e
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-9-17 20:34 , Processed in 1.081605 second(s), 50 queries .

回顶部