QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2633|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么3 V+ |# Q5 y0 r7 W6 G, {" w8 |
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
$ P* ~; B' S4 k3 G9 ?7 C怎么理解梯度?
4 B4 E8 ^+ f0 e3 P& @2 E; b. U* V, m假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。  E2 I2 g# ?8 U, T- {

! H/ x+ {" L) g% h/ j在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。; g9 m  B# f5 P+ u1 [5 x
) ]2 y8 Z7 z9 w5 k1 ?
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。; m( J9 l/ V+ Y9 T: Q

6 l8 I! n- Q6 G) V6 Q$ ^' L为什么引入SGD
- @8 s% M  F$ L1 B; O6 K深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
: X6 b( p+ Q8 r' x0 s( X4 T5 N& ^# H8 U1 |
怎么用SGD
  1. import torch
    ' P0 t$ n\" D. c, h

  2. ; k1 r2 G) @& }( ~) x' H. ~& d
  3. from torch import nn8 F; M8 F. k5 X
  4. ) d* m/ C4 V1 ~, `0 P# X( f
  5. from torch import optim' M+ f3 Z5 K) P, D* t1 ]3 r/ n

  6. ; @# Z9 M) F2 e- ^5 I
  7. ) h5 a8 Q5 G# ~8 Z% U5 @; ^

  8. \" v( q$ a! B! \6 D4 k. s
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    & q( ]4 _: S8 G4 j9 ~. Z

  10. 2 X* y: j- o& ]- e  S$ g: |
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)\" v# O2 ?) F4 Q$ q' n4 r0 A, @
  12. - C  l* v, V, V3 Q2 T, D& K
  13. 3 M. {5 M1 o2 d8 R3 H8 f& ?

  14. 6 \# d3 m' P! `# K1 Z+ J8 P
  15. model = nn.Linear(2, 1)
      {' \; y' m; V! c! R5 X5 Q2 i
  16. * w2 Q, X- m1 |6 X+ P8 X' o
  17. & l6 M1 [& |/ {4 p  N
  18. + \\" ]9 {6 o2 Z6 }. t0 _
  19. def train():) {4 P: f4 q/ n/ k2 [! b

  20. ) F0 G+ p2 }+ H- x$ Q; ]* u/ U
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    7 N' A2 ~& M+ P( N! _2 b: n
  22. % |3 n0 g! q6 q
  23.     for iter in range(20):1 d+ l9 Q# Y$ L+ Z' j+ ~, t+ O

  24. 1 W! J% [: Y1 K4 s9 z+ p$ T4 z
  25.         # 1) 消除之前的梯度(如果存在)
    . O  `\" m( r6 i
  26. + F0 f% {\" }; s6 e; ?5 S+ t
  27.         opt.zero_grad()$ V# D\" f1 D$ q* n7 T+ }

  28. ' Z! a2 J$ G2 m# m) b

  29. 6 e+ F! x4 l* g1 ~

  30. 8 e2 Q9 o4 V( T# u
  31.         # 2) 预测* G+ @- O0 x; P' P/ U8 L
  32. . e/ }& j1 e4 ^' S# Q4 P, ]
  33.         pred = model(data)
    5 B\" M  z; a& y( K1 [' o

  34. ! R5 N* X2 u6 P3 h: I% [
  35.   R; f$ P2 p8 I! e& t+ h  z3 o# a\" _
  36. 0 H! r4 L5 d8 a( D  g1 [: R% J
  37.         # 3) 计算损失
    3 _% ?* s: F0 e( H- `1 \
  38.   s7 X, s1 O  e
  39.         loss = ((pred - target)**2).sum()4 n  \! d, k5 X- n6 G# W

  40. 8 j: N2 n8 `, ^; m
  41. ! ?( V* z4 v7 l9 p0 u5 A\" n
  42.   Z8 T& u' {* L, g
  43.         # 4) 指出那些导致损失的参数(损失回传); R( G& b, D1 ^9 r8 Z

  44. 2 d( j, b1 g9 I1 d
  45.         loss.backward()
    * T7 U) _- f! n0 L9 @) B0 M
  46. + M0 S$ l& A0 z
  47.     for name, param in model.named_parameters():
    $ k\" i8 A! Z# j  c9 w# G2 S

  48. * F  C+ R. N6 p0 i1 g# R  Q
  49.             print(name, param.data, param.grad)7 m6 ?, G8 G; v# J6 e
  50. 8 }* K) [! F; `% \$ q7 @
  51.         # 5) 更新参数
    , K: }! z  N  m/ P3 _' W. }$ Z
  52. ! s3 Z& G! j; X+ F$ v) I% {, h. v
  53.         opt.step()
    . Z\" U! X7 v# [/ h& c( b8 v
  54. ( j4 G5 r& W# s4 \\" c. Z5 C
  55. 9 a: A- U2 ^/ d3 a% F
  56. 7 T7 H& _; z+ j. y! B
  57.         # 6) 打印进程
    2 e$ a4 @, Y2 l& v% K' B4 E
  58. / G0 X$ e0 Y: ]
  59.         print(loss.data)
    + Y1 v7 R6 A! @, d

  60.   f8 q- H! n! t8 ?& L\" O

  61. 5 D4 J4 O\" a9 P) |
  62. / F2 ?. |/ ?7 U; V# b. C
  63. if __name__ == "__main__":
    . k, h/ Z# O! m, P
  64. 1 m* p5 V2 u# D  e. U, s
  65.     train()4 |- W( B+ A  J3 e3 D
  66. ' a4 `) s5 O; D1 G9 f  J# U. x
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
' u: W, `' V2 J5 P% |
1 V3 }( x3 [  j) x计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    + J9 }) N' ?, ], U\" P: Z. e\" Z+ S
  2. 1 ?9 C7 `\" Y5 e% Z$ P# [
  3. bias tensor([-0.2108]) tensor([-2.6971])+ I( T' j: Y1 `8 e7 h2 u
  4. 0 \2 S. ?6 F\" ]7 y/ R1 x! h
  5. tensor(0.8531)
    ! d4 ~+ I8 `/ [6 l# p' O
  6. ' R- X3 c3 k- x: N
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])\" @* H( V) h! c

  8. 7 X6 ^9 z% t/ P0 h
  9. bias tensor([0.0589]) tensor([0.7416])
    & R7 u# a9 U' {8 g$ \1 }
  10. ) r; f) b\" r\" M' s
  11. tensor(0.2712)
    & O9 R; h  l; S3 Z8 W% [\" x- j& }

  12. # g5 D, K9 G# j& q; N& B) E3 W( z
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    : q$ m! S\" O: }8 V
  14. * u. R3 S. @! l) b  f6 V
  15. bias tensor([-0.0152]) tensor([-0.2023])) Z3 g1 t! C# Y
  16. : I& x- L+ K) c8 K+ F: K- D
  17. tensor(0.1529)0 T! i4 ~) f1 C

  18. , O* h& S\" ^) U3 I, f+ Z0 {
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    1 \% k6 @! l5 z8 Y- D! T$ T! J
  20. # g& H3 `% \) e+ f6 g( \( i( M9 N: M
  21. bias tensor([0.0050]) tensor([0.0566])' L% [9 V& h0 h9 N

  22. ; r; g( J) z8 n1 }4 }
  23. tensor(0.0963)6 z. g& N. w) @' m7 E7 ]* Z1 E: [# w
  24. ) A- c- x5 m2 ]. E: m& a6 l
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    . A) x( O& c8 c: S, D* ^

  26. 4 H4 m\" r1 S) p. p3 e
  27. bias tensor([-0.0006]) tensor([-0.0146])
    ! D9 z6 P( k: O

  28. / U- v5 S2 {& ~' X) ?
  29. tensor(0.0615)
    8 `8 j+ n# l. g) @  _
  30. ' I9 [/ J) v% N9 L9 R7 E
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])( o& [\" R5 n& `- A9 B0 y) X0 Q
  32. 4 ~9 k) I' B3 a1 _) R
  33. bias tensor([0.0008]) tensor([0.0048])
    ) q+ L( H\" [1 r5 ]* H
  34. 9 O$ }4 m! z1 q# A( h  b
  35. tensor(0.0394)( e; h. b# h# Q3 z- p

  36. ; H1 A' e8 r+ z: _% {& f* P# {/ |5 t
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])# U7 i0 [, j1 F# Y8 G6 i+ v0 f
  38. : Y  A# n0 E\" B& u' c
  39. bias tensor([0.0003]) tensor([-0.0006]), |' c: h. _  _3 B3 N6 F: x
  40. ' S8 a, g% ?, _/ y
  41. tensor(0.0252)0 Q0 N/ |: Z3 V) x) W
  42. - Q; @( P) {( j0 s% q3 x
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])* ]4 o1 ]/ e\" B
  44. 0 ?4 ?3 V! _) a4 R- D
  45. bias tensor([0.0004]) tensor([0.0008])
    \" W* Q4 A6 h  K) w% r
  46. & i3 R' S$ Q& a& o
  47. tensor(0.0161)7 S; X& `\" T( |' ?3 |
  48. # o, C' }6 m\" y1 t+ ~3 |. v; L
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    9 ]- v4 Q- q5 B\" g\" Y
  50. # X( Q\" t/ F. e0 [% T% P
  51. bias tensor([0.0003]) tensor([0.0003])1 |0 g- s. o) L$ K* T
  52. 7 W+ Z& @/ I3 t7 Y, b
  53. tensor(0.0103); ^* B0 \& R5 W5 `) Q  t
  54. ! X9 a- t! r( p% R
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])3 `4 G8 T, K* X% E; s

  56. 8 y0 e- U* v* K9 ^( a% t+ s
  57. bias tensor([0.0003]) tensor([0.0004])! ^, j4 g+ r$ E( J4 \
  58. & {7 j9 A/ I3 @
  59. tensor(0.0066)
    7 F% H/ e# M0 r+ R5 R( x1 e; r

  60. . ]\" s! ^0 I( b9 I# f: ^5 y$ l4 r
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])2 C' J' K' X7 I* |! o5 S/ u4 }

  62. # `1 B: a2 j; z1 _
  63. bias tensor([0.0003]) tensor([0.0003])4 G/ m; N) [* X, p% f( Y- J

  64. - S\" X6 M, X2 A4 f, h3 f: {
  65. tensor(0.0042)$ Q' z# J8 ^% D
  66. 3 m9 F( a  Y' \
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])/ U3 t4 @\" f; d8 s1 G0 q: }
  68. 9 ^0 D# h% E4 t/ G/ H5 x7 ~
  69. bias tensor([0.0002]) tensor([0.0003])
    : M& {\" Q: V- `2 o+ C
  70. ! j' Z. y0 j5 |/ [8 _5 F& j
  71. tensor(0.0027)8 w* N4 G) m* O* p+ e
  72. \" k! {# f6 X8 C5 F% n1 k% e
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    2 h' C; t1 {0 K* w) }
  74. 7 O2 Y. i! b8 o+ b( i) T
  75. bias tensor([0.0002]) tensor([0.0002]): g4 p1 c: ]/ h
  76. ( z, B5 u/ B/ b# L
  77. tensor(0.0017)
    ; Q* v8 z5 P\" W# S, f

  78. 0 ?- w\" K, l' N# C
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])/ b% L; L+ [, @, E4 ?3 E

  80. 8 [0 s4 X+ s- I1 d+ Q+ p
  81. bias tensor([0.0002]) tensor([0.0002])
    / v5 H7 _* s1 t5 J* t
  82. * n\" k3 ~6 M) @/ X* V$ h$ C+ W) H
  83. tensor(0.0011)\" _  p' L! g+ e$ z% N: a
  84. , W2 _/ g( |# b; p- ^
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    ( h- b/ `- t# P7 Z) L1 T
  86. 2 I\" t( u0 Q0 {1 Q7 _
  87. bias tensor([0.0001]) tensor([0.0002])
    / K  l( ?. k( y- A$ B& Z

  88. 5 @) x4 G7 h% ~$ ~# g
  89. tensor(0.0007)
    5 `  d\" ]( k! d, E! y! g9 h( B# k
  90. 6 Q0 Q8 ~+ R6 v: L
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    , K  r8 Q+ z+ H, m  t
  92. , G# z0 a7 [: W1 Q
  93. bias tensor([0.0001]) tensor([0.0002])
    # A1 B, o5 y+ g; x9 q4 W9 A
  94. 3 }7 R5 R: `% w$ U: |, Q0 }; c\" I
  95. tensor(0.0005)# \; Z4 ?! y( |8 R, f! k' M
  96. % U+ |, G, d- o8 A
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])3 a* b2 [# E( ?) k% W# d
  98. ' A2 V+ X\" o. B1 @. E  P+ g
  99. bias tensor([0.0001]) tensor([0.0001])
    - {0 Q. V: ?) h6 ?, S1 ^9 W
  100. ' d4 }5 B& ~' W: O7 \3 Y- L
  101. tensor(0.0003)/ q7 K& X. F' S  p; X! {
  102. % F8 |9 [* W6 b; R. [! o
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])8 v- C; N* n6 M3 r$ a5 m: F! A, U

  104. ) d0 y% d$ T: r- q
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    1 w' w2 ^* n# m' s

  106. 2 G1 I. K6 b# O' y0 z7 C. b: f
  107. tensor(0.0002)$ x$ d5 W6 W# [% z: n! M4 j$ ?9 t$ x& t

  108.   Q/ s+ U& a  Q/ D( O6 t
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])2 C0 C! P, x/ Q% a: z
  110. ! s) `4 P& p8 N
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    4 k* U6 [$ ^! \6 W: ?% j, w: d

  112. ( K- ?3 a* |: |0 h' ?
  113. tensor(0.0001)
    - A2 y5 P! K0 s% ^/ x

  114. ! m0 S+ q* U, `
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])8 |6 M- C; x! K, z- ~- `2 p
  116. 9 N  @$ m4 }5 c\" t5 ~  p
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    3 E. I5 J+ U\" e- {

  118. * @7 D* N  C4 |& [. ~4 s3 l, f
  119. tensor(7.6120e-05)
复制代码
5 S( d% r. K% \. z3 |9 x
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-14 13:06 , Processed in 0.429286 second(s), 51 queries .

回顶部