QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2620|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么: O/ |, d! |" R6 [8 i
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
- O& y  |! p: x0 x怎么理解梯度?! ]5 j! T0 \  v8 Q) G3 D  g
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。: t7 j7 d& n3 w1 Q. I; d
, o  C3 }/ j( ^2 I1 P' {0 M% S
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
6 J, y. U4 X) n+ {  x! F: b$ a! L" I3 o8 k4 N7 L
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。9 n! i* P* k3 I8 B8 c
/ r. k. O# a' n2 q( g
为什么引入SGD, t7 v% z( ]0 o* O8 ?1 R4 P2 J
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。0 ]5 s% ~( ]5 M3 j8 ?4 j$ t3 u3 P
. x; E9 i) H  L! A
怎么用SGD
  1. import torch! W+ ?6 M  o2 P2 U

  2. 9 p9 I& G1 R7 M0 W6 ~3 y7 [2 z
  3. from torch import nn
    5 w+ k$ j$ n! K& A+ K5 H. ~) d. q

  4. # T8 d& G9 q7 G4 A
  5. from torch import optim0 @' U1 |, z7 Q. |
  6. 2 F- I7 J+ O% H3 `/ Z, T9 S7 @2 k

  7. 0 b1 G, M3 z6 p+ ]. F# j
  8. 6 N\" T& w. U1 o; K/ f+ W
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)6 z6 V. D3 v; {- `
  10.   |* k9 R\" V, C+ `
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    ! y; @5 M$ J' C, s: K/ `$ H

  12. 9 o  `/ F- q! \
  13. - @( G% Y; X+ N# s( m, F; t
  14. , f* d1 [9 ]\" ~4 a5 f3 z, Q7 _
  15. model = nn.Linear(2, 1)) J3 e\" }( }+ a, O6 Z2 C7 f  q# E3 x
  16. - O8 N  p% v( m3 _, i9 C0 Z5 P

  17. $ H/ S  t/ P6 W5 P
  18. 8 F: A: ~& A  I$ l* k9 ?, e
  19. def train():
    3 ^5 ?+ F$ L) N. o$ [
  20.   s( ?+ K2 u- J% T# v+ d/ m$ Z- u. d( U
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    ' b\" c# B* W- y7 a$ t$ \

  22. 9 Q: _/ W) E! {
  23.     for iter in range(20):
    7 ~) z* j. P6 z! K+ i; @% \

  24. 7 Y6 I\" G/ i& ~: j8 q\" e
  25.         # 1) 消除之前的梯度(如果存在)
    1 }( \% c! n) a
  26. 8 m7 {+ i; E$ A& t3 \$ n
  27.         opt.zero_grad(): ?1 S/ y1 S4 G1 ]+ z: a, |& R) i

  28. 9 D8 A5 }5 R; F; b

  29. ; E, J' J\" ^( E6 \3 @7 {. q% {

  30. 3 J$ i* `9 {. ?
  31.         # 2) 预测
    / ~+ ]4 z2 L; t1 M2 b
  32. / d1 n+ g* ^9 B9 J3 X5 D
  33.         pred = model(data)\" `; K1 P: w6 o$ U4 V, E8 I

  34. & g, N9 Y: V: j& _3 e  R/ B/ `

  35. 4 y, E' [2 z* W. j& X- v! R$ |
  36. 1 o- T7 y% U, P. j6 G
  37.         # 3) 计算损失  c2 g+ a* w5 v\" w1 ^
  38. 2 p\" i( |- p8 a( Y+ t6 u2 f
  39.         loss = ((pred - target)**2).sum()# E, J  E7 l\" C, m; ]/ N2 S  Z# ?

  40. - p4 K% m% E' W$ {
  41. 9 J8 r7 A: T9 _8 t) _
  42. # R- ~+ O$ I1 M/ C) Q
  43.         # 4) 指出那些导致损失的参数(损失回传)
    / w* |4 w: h; {4 A
  44. ( J3 O) v! H3 Z0 ?
  45.         loss.backward()! {3 w1 P0 Y- F$ @
  46. \" P, a8 L' A! `* ~
  47.     for name, param in model.named_parameters():# W8 g  T2 l4 Q5 ]
  48. ! s3 G8 ]: L& f' w+ j
  49.             print(name, param.data, param.grad)' n# g/ Y4 b) D! Q, Y
  50. : m# d& Y; F. G6 H
  51.         # 5) 更新参数
    3 Z/ e\" W5 B! ]; D

  52. % R( A+ m8 ^5 c9 Q
  53.         opt.step()
    6 i1 h2 a9 H6 f9 F4 K* E. m& q  o: I
  54. + ^+ u* D& E: I# e& k6 q

  55. - }5 T2 P( i. I5 }3 b0 v* [
  56. + D- D* O6 R1 p\" G: H! @6 R
  57.         # 6) 打印进程
    . \! C# Y% ]4 h3 P5 W$ B

  58. ! q, z5 m4 J5 \# T# M! r' G
  59.         print(loss.data)* d% h, E7 j9 K

  60. 3 m) N: @& i$ O* s- P1 k5 L% }
  61. 5 s8 ]( Z\" W. t9 n1 c
  62. ; I& m! N7 n' J
  63. if __name__ == "__main__":3 a2 j, H' R: m8 E
  64. & n; C. ~5 ^9 j\" Q7 O
  65.     train()9 j\" j. ^6 Z7 e/ D0 `) ^9 i; Z
  66. \" ^! a$ q8 \; ^2 }. u1 I3 V\" H
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
5 I* a, A. I$ \
6 F. h" q3 t8 @3 E3 X* h计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    ) T1 m' }\" x\" V/ C& E( P/ t

  2. ) j* S. z$ m- }/ {2 W
  3. bias tensor([-0.2108]) tensor([-2.6971])9 e! Z. ]. }+ Y5 j1 A

  4. 2 I: L8 N4 \! g
  5. tensor(0.8531)6 [: S, y7 i! b4 F3 C+ a
  6. 6 N9 ~\" r& O- `# D) [7 r
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])5 Z* T, p+ y2 K+ `0 q. \3 R' H
  8. 9 b0 v8 y\" |7 h# U8 k% @1 J# J
  9. bias tensor([0.0589]) tensor([0.7416])- q5 b$ v  k; |% C: K
  10. ) V1 W' D# B  {8 r
  11. tensor(0.2712)& E3 b, n, A1 \; t8 X# X
  12. + e; Y9 q. k' U5 u7 C3 W$ W
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    7 S9 i. G) I' E' ]

  14. % R% y9 C  x& C( W3 U, W
  15. bias tensor([-0.0152]) tensor([-0.2023])
    0 P0 x' }+ M% R3 i1 U

  16. 4 Q\" p9 F* A( w. {4 q$ C
  17. tensor(0.1529)* z( Q/ y! o7 o& h! r
  18. 7 o\" f. t5 S: E% Z, ]' E  k
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    % B, k$ s8 B2 @, I5 D. e2 s& S$ u4 U

  20.   o' f% ^/ r# x% H\" k. _) k- D$ `; a
  21. bias tensor([0.0050]) tensor([0.0566])) N\" F3 z1 V4 y6 S5 i( T- F. F2 T
  22. 8 f9 m\" ^7 x% J, m9 ?* X9 A. V
  23. tensor(0.0963)
    , V, x0 |0 `1 i7 T) I. w

  24. , u2 D$ F$ W% v0 H* m, ~9 G, n& R5 q
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])# j; L0 h: P1 O( B  y

  26. # n, G! [0 u; y9 X7 K0 G
  27. bias tensor([-0.0006]) tensor([-0.0146])) N: Q+ Z& f% R% H\" l. R) M- Q
  28. : w0 m; S3 p# n' B! H+ T+ J' u
  29. tensor(0.0615)/ U; x  ~9 A, n) S\" T8 V9 M4 V

  30. 3 U  ?8 O) o! F/ n9 y2 a
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])  `/ T, g) J7 Z! M+ R

  32. * S2 `2 X% Q+ \
  33. bias tensor([0.0008]) tensor([0.0048])9 v: q\" A6 N+ R

  34. 4 g7 i; e2 t& f  S2 v5 h, v
  35. tensor(0.0394)2 D/ F, f8 t- ?$ _- B' j
  36. 6 t\" K+ m# j! Q  c0 ~+ v
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])8 I, d( K1 p, X4 T
  38. . j( r8 Q5 H$ _3 A; @
  39. bias tensor([0.0003]) tensor([-0.0006])9 ~  }' g1 b- |% Y6 c# I
  40. * W4 X0 @+ q\" V& b  k$ m4 ~
  41. tensor(0.0252)
    : W! x' T. c! e4 m8 S) W\" p+ Y

  42. 2 u2 Y5 @9 J: Y* N. S( B4 I% Q
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])4 X3 Z! m2 U7 u- b4 f. y8 \7 j
  44. ! B7 T\" ?8 e1 f/ X; ]
  45. bias tensor([0.0004]) tensor([0.0008])# L% ]; a4 i: N

  46. \" P9 ~\" l, z* x. t. V
  47. tensor(0.0161)! R2 Z) T, B0 ~9 p7 w

  48.   Y5 y% e9 S( B3 ]3 l, S; `: n
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    % Q2 r! q$ N$ U, {1 L* o4 Z( N
  50. 9 T4 B+ y3 K+ R3 L: G
  51. bias tensor([0.0003]) tensor([0.0003])
    ; r) N  f* v- E( `# S
  52. : S* E) ~' q* O
  53. tensor(0.0103)2 N) Z6 B  @3 U! c) l5 Y
  54. 4 y% c8 T8 g9 W+ `6 G  P
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])( G/ a4 B, R\" w$ J
  56. ) f# k8 f+ Q8 U& T3 ?7 \9 n! M7 W
  57. bias tensor([0.0003]) tensor([0.0004])  m8 x0 D, C7 l# W' @9 j. z
  58. ! E/ y, P  {- }- }$ F
  59. tensor(0.0066)* U\" Y) J* Y5 ?

  60. 6 }; n2 `! }8 x0 e2 g% O: a! l
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    # ?% n( d! e0 M, D- e+ e9 G: @

  62. ; e& j+ v  Z& f1 k) |
  63. bias tensor([0.0003]) tensor([0.0003])
    & ?. G: N9 j4 ?3 c5 ]

  64. : j4 R. Q4 g\" f4 D, _
  65. tensor(0.0042)* C( B/ r0 z- I+ C5 ~, u, u

  66. 2 D\" `# B) m9 s4 u\" ~8 |
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    % U0 o5 O2 Z\" m6 [- D
  68. 1 P5 V3 e8 X8 s1 }
  69. bias tensor([0.0002]) tensor([0.0003])
    # `( O8 k! o2 P5 A* E( ^9 H
  70. + q, n1 D( {7 c$ t  K  f
  71. tensor(0.0027)& k) c  v) s% Q
  72. % V9 z4 ?' h& u- o: A
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]]); ~1 I2 Z  w) g0 e5 Y* Y8 }* B
  74. / ?\" Y. r( `/ D3 _
  75. bias tensor([0.0002]) tensor([0.0002])
    ; g3 k: v. [2 u0 K\" V3 Q

  76. 6 F' b! u1 b2 B9 j8 l0 ?
  77. tensor(0.0017)2 l1 ^8 |% X6 u; g\" o

  78. 7 f: z6 I; ?# L1 n
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]]), j6 C1 o6 y# Q+ x7 z

  80. ' X* @. n\" K# [  J( Z\" v
  81. bias tensor([0.0002]) tensor([0.0002])  ~1 Y7 T/ ^, Q' k8 t

  82. & N# B* l* \# R  N( @8 r4 o* Y
  83. tensor(0.0011)1 S+ i3 l* E3 b& D

  84. \" l5 l3 }  Y$ _; ]  c0 ?& k6 t
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    5 o% [8 z! k$ D! n& `
  86. + k! M$ B; y\" L
  87. bias tensor([0.0001]) tensor([0.0002])9 j( T7 c% B: z
  88. ) a& ?3 J' T( z  y& l9 n1 c; @! Q
  89. tensor(0.0007)+ {' x* A7 @6 Z7 e, u

  90. & D1 [2 M8 A\" k) Y5 D. u
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    ' f\" h4 x9 r7 X
  92. ( J, P( Z7 E\" a  v7 o2 V) Z
  93. bias tensor([0.0001]) tensor([0.0002])
      _\" R. G4 a2 c0 K) m

  94. $ V& k: V\" a) D( w1 f! r; F
  95. tensor(0.0005)
    4 Z& h7 H; S# q! R
  96. : t1 e# a0 J! v; m$ O5 z7 }. G0 @
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])# R! Z# V8 V8 z$ _- E0 y( g, U
  98. . q/ C: x' Z6 L
  99. bias tensor([0.0001]) tensor([0.0001])1 S! {) S' Y! O) E\" K; r( Z
  100. 9 |* A0 J  X  P1 Q
  101. tensor(0.0003)
    2 S+ }2 s  Q* K7 [: _

  102. 1 o; b! y$ i$ F% S. b& h9 X
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    . R) B$ o$ |) `' A: }

  104. 5 B, m* h( j! k3 k- z* ^4 n9 Q: {
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    8 [% m! \4 [1 y0 e+ N' p# }* o

  106. 8 t  A8 @) H7 M' Y2 c
  107. tensor(0.0002): k: F) @6 Q. L; ^# G7 l4 D+ _* Y
  108. ! F. e1 k- ~1 @! l0 D, I0 Q$ V% O' c
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])5 }% G3 v- J. D& E% [* [; q
  110. % }/ L! [$ Y& }
  111. bias tensor([8.5674e-05]) tensor([0.0001])& H4 x8 Z9 v% @2 \4 R
  112. $ Y3 H' S* a2 y& Y; Y( m
  113. tensor(0.0001)
    ) u' s! Z, e  R' d

  114. ) B- i. [' @! ~: F5 U  n8 S
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])' P  w0 U+ l; a7 w4 Q( e& ~, Y

  116. # l/ s  F9 L# U1 l
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    - X9 t  V7 f' o+ n+ Q

  118. # l+ i0 g: M0 E; X; w4 o
  119. tensor(7.6120e-05)
复制代码
& f  A+ U) D2 g9 n$ f; k
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-10 13:42 , Processed in 0.344390 second(s), 51 queries .

回顶部