QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2667|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1189

主题

4

听众

2934

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |正序浏览
|招呼Ta 关注Ta
SGD是什么
! ~$ C7 Q8 [- `3 l2 y* q# u- q2 ~SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
; L% \; E' D- L( \! u怎么理解梯度?
% A# s5 m2 s% u9 i  y4 H假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
& |* v& e3 T$ _# P& S: x- {2 b9 Q! d, @/ S; d- V0 t
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
0 s9 r: r  F/ o& t: m$ a* e- f1 A- i8 ~; l5 n0 m$ v  A! I
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。2 p$ B) S; R& |' D( q! k: G: @; a
% ]) h1 {& }+ I0 j
为什么引入SGD& y6 g! q# h+ @- a
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。1 t4 j8 q+ Q" g; L( V7 R! T

& f. `, h8 F. a6 R! y怎么用SGD
  1. import torch
    0 E# f* \$ n8 K4 H0 I

  2. 5 P7 X8 P0 q+ j+ H: V
  3. from torch import nn, e* ]8 r1 f( y- x: m$ e% p. |- G

  4. - R) Y, n) N0 Y( S; s1 l\" F
  5. from torch import optim
    ! b: ]5 c# N# X1 U% I- |
  6. % w4 \+ s* V8 F, a; x1 j' Z& Q$ I

  7. ; k9 R$ T% z6 p! V( b\" f

  8. 1 S! V5 @; V9 Q/ A+ {9 G5 ~
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    / I% O; W3 N$ h: ]5 e
  10. 4 r, m* _) n/ z\" i, ]
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True), c: g* ^. m  L( o3 z9 ^6 X

  12.   n8 u3 t\" S\" e- J9 }

  13. & n' \3 f% S# [9 p& r, b& x\" G

  14. , |: b9 {( o# C3 P* w8 Y) O# Q& z
  15. model = nn.Linear(2, 1)3 [6 U  L8 r5 e6 |0 i  T
  16. ! p- n- X1 @+ l* m/ _: `: K

  17. 1 u# V2 [+ s\" p7 I, ^, T2 M( r) B4 a' o

  18. ) T7 f( s! D. A, i
  19. def train():
    1 r$ R1 _3 C; b1 M

  20. 7 k9 _$ f( }( ?1 D
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)2 J. r/ b8 o0 a* Y; t4 y
  22. 4 s$ {8 a6 Z0 o3 b' ^
  23.     for iter in range(20):' j/ T, j. n+ e+ y

  24. : o! ~1 ]& P  h% T
  25.         # 1) 消除之前的梯度(如果存在)
    / [0 \/ m  f$ l

  26. ( |  W# K. r  E, K
  27.         opt.zero_grad()
    * ~9 e, K2 P$ F
  28. 5 A8 d; P, t3 W/ j

  29. , Y0 i! x; ^+ s' }7 P* z  L7 x& ~
  30. , }+ m6 [/ I* j/ c
  31.         # 2) 预测
    0 f9 ]7 T' t3 ]( S\" ~* H% m2 M

  32. & w, Z2 M( v# W7 B+ k* `' M
  33.         pred = model(data)
    + k$ t. y. `7 C2 |& T

  34. : b' a# a$ W1 Z8 s  \2 c4 Y

  35. : W5 [5 t+ l\" Q1 c, c

  36. % E: C7 t1 J$ e! X% t0 y, C6 W
  37.         # 3) 计算损失* I4 ~\" Z% O9 m9 P

  38. 7 Q3 U9 y( b% Q+ s
  39.         loss = ((pred - target)**2).sum()
    ) g- @2 A' S1 q; H

  40. 8 B! |. J- y- \6 _& Z
  41.   l: n6 p/ v% L0 @: I
  42. 5 e9 w3 H+ l6 \6 J7 \! \
  43.         # 4) 指出那些导致损失的参数(损失回传)- J2 c/ Q2 F* j9 H! H$ c) N* s
  44. $ _& p. h7 m/ K  O: t. ]
  45.         loss.backward()
    7 |) \3 w/ B1 G$ w8 A% B3 [: Z( V
  46. ! H% @' h( K3 r$ p. O) @
  47.     for name, param in model.named_parameters():
    1 c6 t- W8 I' e4 \( ^- s$ G

  48. \" X' o( d# K9 Y: }- M/ s2 H
  49.             print(name, param.data, param.grad)
    ; K  B  w0 z9 v' _0 l
  50. 0 H( m% n8 L! K8 Q+ H( c. y6 j
  51.         # 5) 更新参数. m- M; L1 \4 M6 O  Z# @

  52. 5 p3 i) ]5 J* S( I\" X1 z# x3 I
  53.         opt.step()  p\" n: ~( U' o6 Z% l

  54. * M\" N\" r+ F, n; C8 P$ x
  55. 6 t# ^8 }( n* T7 ]6 y

  56.   T1 u) K+ [6 h/ k& M% a. m3 h9 |
  57.         # 6) 打印进程1 O: J0 a! W  b
  58. * V( N& }% `- p9 R: O
  59.         print(loss.data)
    9 [5 R\" D0 _* a6 Y/ K

  60. ; _1 S. A: X; [* \* A) K  d( k2 q

  61. ; W8 W( y, j/ B$ i) l& q, t

  62. ) V* `  J8 X1 i% x$ k' u
  63. if __name__ == "__main__":
    7 g% G+ y$ ^, ?# O

  64. ( A4 \1 s# Q$ _
  65.     train()
    ) Q8 j# X  A2 a1 o
  66. % \) v. B4 Q+ J, T2 P; p) |. \& u
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
0 a( y+ r* u4 [% B  Z* i' S  l2 f0 v" O% G; J
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    ) b# e2 S) ^2 \! l+ W4 Q9 v4 X
  2. ; O: R' I, j! r: i- `. N6 f& _( h, B
  3. bias tensor([-0.2108]) tensor([-2.6971])\" X6 c$ k6 b% @% l$ q
  4. / H! t$ o' u, I) D0 P! Y
  5. tensor(0.8531)7 _, [2 D; r9 |% d
  6. : ]& Z\" I0 B/ N
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    \" H; `\" I2 C) B% U) u! b$ n7 E1 ]

  8. ! S- O) S8 s) O
  9. bias tensor([0.0589]) tensor([0.7416])/ ^4 S% ~7 b! K8 i; u. N8 _, e9 U

  10. 3 y9 s2 J+ x\" P6 Q; o+ B
  11. tensor(0.2712)
    # J  U. J; v1 h- @4 A0 d# f; v5 O

  12. 1 X5 {3 ]. x1 P( m9 P
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    $ ^3 j( l2 f% P8 r( Y+ O
  14. : F# i9 [. [: |. ?$ l! m
  15. bias tensor([-0.0152]) tensor([-0.2023])
    9 M% P  C( T0 C. O, t- u
  16. % C2 I7 W* X8 X% P! |4 p; s\" O! u
  17. tensor(0.1529)
    / ]* P7 T3 R1 c# Y
  18. & L: @3 d6 y0 d1 }' }
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    * @& G/ P9 y% _  _

  20. 6 r1 {/ V  [: M0 ]/ K( X/ N
  21. bias tensor([0.0050]) tensor([0.0566])3 E: |! d) }6 Y3 i7 U

  22. ; n6 O3 M3 h\" p% @) t- c\" m
  23. tensor(0.0963)- J- a: C' _& l2 c  [

  24. \" D0 M( N- {) i8 p7 B\" `% W
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])$ o! U3 g- w# z  Z2 h6 ]
  26. + d. }9 Q, J+ f# v1 x5 Z$ a
  27. bias tensor([-0.0006]) tensor([-0.0146])7 ^( _) K1 X! \7 r
  28. 2 _\" i! k3 H7 ?  a  C
  29. tensor(0.0615)\" e' s2 S% a! A0 u1 \3 h9 r+ `

  30. 8 G$ ?# ?- S2 B. @# Q( n
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])9 O# u- k7 H# j% E; ], T; D- B

  32. * s; z1 ]4 l: f$ \
  33. bias tensor([0.0008]) tensor([0.0048])- \( P) c! `/ t
  34. ! i  b) s- z' o! c: d2 _$ {5 ]
  35. tensor(0.0394)5 a7 }5 [. G6 @
  36. & N# ]5 f' b7 W
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])' o  Q) k  [( a2 E! y: ^) m
  38. ) y, }6 s. m6 D5 ~% h% t# [& H
  39. bias tensor([0.0003]) tensor([-0.0006])
    # R. T; \+ D5 V. w
  40. ! O: x* i3 b* b4 r
  41. tensor(0.0252). y3 [+ a3 H4 E! L# ?# |8 R\" }

  42. 7 W4 O+ ^1 {- X, N% w( _# A
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])- t/ l* }4 [; Y

  44. ' ]5 x4 y$ e  E
  45. bias tensor([0.0004]) tensor([0.0008])! C. f9 d$ j7 x
  46. 2 z: Y4 f\" @! M* R; n8 E
  47. tensor(0.0161)  G8 H8 X: u\" `2 D! @, I/ g

  48. ) L0 c/ B! ^! E  s
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])! s8 d% U, h. ]/ k+ X/ A
  50. $ \' ~5 U9 d- P0 f0 j5 W6 u# W
  51. bias tensor([0.0003]) tensor([0.0003])
    % ^4 C; i! y* `: V
  52. 4 F! F7 y9 D9 C% g\" U8 X
  53. tensor(0.0103)
    ( Y7 Z4 _/ @6 _- q: o4 F$ s
  54. ! i+ d* s. p7 D$ v/ e
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    5 U4 S& U/ C  e; l( R
  56. ! I: \3 H( D, x3 i+ P
  57. bias tensor([0.0003]) tensor([0.0004])
    5 t9 Z  R  x$ G' {2 e; Z* _
  58. / U; F* a7 K5 _- r! ?- L
  59. tensor(0.0066)/ R- j% M2 t! J/ T: n) w) z
  60. + q6 e( D7 U' S$ j* V' e
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    3 M( E3 Y+ E4 j$ V4 W
  62.   D0 n+ ?9 t* k5 v/ M
  63. bias tensor([0.0003]) tensor([0.0003])/ C7 A  {6 b) C* C# }3 {# {, A2 W

  64. & G4 d: m+ G8 ~) W* r
  65. tensor(0.0042): r& w) k4 a, _+ b

  66. $ a\" a( j  G7 r/ w6 T- r( L  Z
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])! s7 _9 c/ V( y4 q
  68. 5 \* J  s' |: W7 A8 `
  69. bias tensor([0.0002]) tensor([0.0003])
    8 z! h: t# Z4 ]4 r; e; h) @

  70. \" l% M' l' Z4 ~* M3 o& F4 R! L% J
  71. tensor(0.0027)
    , m' a. [6 b# M- Y5 t- L. F* s
  72. 6 a; E6 Z. F% |, l8 h% v
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])3 ~6 @$ i( P2 @4 N% Y  U& ?0 s

  74. 9 S8 D  a% C9 Z8 f- U) r! a
  75. bias tensor([0.0002]) tensor([0.0002])$ w0 z- }& `  w+ }) ]

  76. + k; R' u! n) m! p+ A0 u
  77. tensor(0.0017)
    $ m# O2 A+ g4 m( O
  78. $ U) w' w9 D3 N* t$ z\" C* r3 x- O
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])7 q$ X7 h: m. M) @+ D

  80. 8 y% L+ ?2 v# A- B% P% m2 N4 ~1 x
  81. bias tensor([0.0002]) tensor([0.0002])
    9 A  k7 C/ K5 p

  82. 3 B' {9 {, n1 n8 m9 q/ ^3 N
  83. tensor(0.0011)) v/ A% t; V: l' C1 c3 Z' m9 Q

  84. 6 M/ T\" p) @' I7 p
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])- L7 l: {& d6 l5 R( u; {
  86. 2 O- v' g- d8 w+ j% X8 B6 y/ @+ O3 v
  87. bias tensor([0.0001]) tensor([0.0002])
    1 _' d2 x2 H, i2 k. L\" s\" m
  88. 3 V6 e8 J( @( z: N
  89. tensor(0.0007)
    # f1 w5 s\" i' m7 d+ ^
  90. 7 T, f, P7 i4 L' Q1 d
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])* Q/ M3 A9 E\" S# f

  92. 8 b2 U* m7 g. X) u6 g
  93. bias tensor([0.0001]) tensor([0.0002])
    8 H; a/ @4 x( f* B
  94. 3 c# N) l- y& ]& u1 F! J$ B+ r; B
  95. tensor(0.0005)9 r\" }9 ~  n\" o/ V  j

  96. 4 G# ?. C# T0 {1 j
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]]), n1 l- l2 i. W0 y3 z

  98. 9 J0 }( U0 @# n3 I\" q) w( j
  99. bias tensor([0.0001]) tensor([0.0001])$ p9 j8 n  y4 ^: z0 d; l

  100. / e$ E/ M7 d  z: X
  101. tensor(0.0003)8 I$ y* q5 d! G, U- }

  102. 4 G7 [5 a& j! ^$ r! [6 v
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])# b2 I0 f' G; ^  m9 _

  104. / ?5 r\" z! _6 B  a' }
  105. bias tensor([9.7973e-05]) tensor([0.0001])* u; O/ S/ b5 X2 ^/ z

  106. $ u2 Z# h* }. Q8 R
  107. tensor(0.0002)
    + r& D& ~3 U- M# t1 o8 {+ Q3 ?
  108. # @) H0 P  G/ o2 F) C+ u3 T7 g
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])\" {5 {5 T7 o/ e1 l; S, ^

  110. * f+ n; Z, X  f; j3 L( ~& D
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    0 B- j7 M6 Y. \. M( w
  112. 2 S* A# j6 q0 q& S
  113. tensor(0.0001)$ A. i! Z9 i9 v' w5 T/ A

  114. : d& e1 P; ~4 w
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]]): h- Y, K) N5 N% |& o
  116. 4 F# J2 E, M# O; _2 z& ^
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])0 {3 X3 R8 G( ^* `5 }9 M: U$ ^

  118.   x1 U3 x+ g% ?! B8 z* h( R6 S
  119. tensor(7.6120e-05)
复制代码

* k; z9 l/ r! m
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-6-7 11:25 , Processed in 0.433308 second(s), 51 queries .

回顶部