QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2666|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1189

主题

4

听众

2934

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
8 d% Y( K! r) W5 P: l+ VSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
4 L" c- P, @! F& p$ M) L+ N怎么理解梯度?
  x+ [, w$ N9 q+ A假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
$ C( g- G, i' k8 l$ b* n- @! g+ {! O. \
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。( o2 U2 i$ t4 P- @" C8 Q9 e" Z5 M, ?
4 s5 h; r7 j- P
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。) g6 T9 H( {% T; f

, ^, w2 s* h7 l/ u' [为什么引入SGD7 Z! K5 M8 i! j) y. `
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。  z+ ?2 c8 F  m9 R- q3 L

- i5 H/ p" ]! t7 T* z" Y+ S- g) O怎么用SGD
  1. import torch
    9 q2 L, A9 m* b- W& h6 b1 L+ O
  2. * p- A+ r6 W, S1 d
  3. from torch import nn  A4 ~% Y. |; ]\" L6 v# ]. ^  ]( \& R! D

  4. & z; D, Z! W# q4 M  }\" x
  5. from torch import optim0 Z, ^: g( k; G& q6 d0 t

  6. / c, ~. G0 m/ q) Q% U+ A
  7. & `2 u+ n; F# b* B
  8. $ t+ N* C% D0 z' j
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)/ h+ z/ {, j, z! h
  10. 8 I1 O# F; U1 C
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)% X+ q0 k% h  a

  12. 3 [+ s# Q) _  p2 ~

  13. : E, J- H8 V; `( `9 Q3 t
  14. + W  A5 j+ D$ H0 h$ L0 L
  15. model = nn.Linear(2, 1)
    , T% _( b* f/ l\" @- {5 {4 C
  16. # f* m' D& @% G8 P6 N3 ~- c

  17. \" e! c' p$ L6 f* ?6 v2 _
  18. % H  e. p( j( y# d# g$ P8 f! A
  19. def train():
    ' d8 ^# L8 c9 l  C, k: n
  20. ! ^& S/ ~: P/ H, a$ ~: D' p
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1). B# D) G3 @$ x4 \& P
  22. 2 h  z' [8 L: P1 S: g& O  [/ {
  23.     for iter in range(20):5 s2 }! L. W$ D1 s

  24. ; j  ^5 S/ O1 Q, B6 H3 @) M
  25.         # 1) 消除之前的梯度(如果存在)3 t  J3 U. x$ Y2 A) w# g

  26. 4 N; _3 _& G. ~$ `
  27.         opt.zero_grad(). p& r+ W! P2 K+ }# d4 e

  28. . z! s& T& \# k) S/ o
  29. 6 L6 a8 }- v) V8 J& @3 @
  30. 6 v! [4 Q3 R& V
  31.         # 2) 预测, L0 ^0 q/ ~- r1 e/ M* R1 `
  32. : [' Z, X- {, W
  33.         pred = model(data)
    1 j  A* k. }* C- R
  34. % U! C6 R  V+ v! L

  35. 9 p8 |/ E- B' \, R- A: N3 M
  36. ( n) z% Q# ^5 m6 k0 h
  37.         # 3) 计算损失, Q' G% ]: E5 S/ |9 `
  38. 2 j- R& ^1 s% v! o8 h
  39.         loss = ((pred - target)**2).sum()
    , ~8 U1 a3 u2 O# v; ]8 t

  40.   \. y  O6 j6 L1 {/ |/ W( t) o
  41. 7 q, X% j# B+ d3 o- i8 O$ n& E, d
  42. # O( S7 I0 Y6 b$ [. w2 L1 A& m, V
  43.         # 4) 指出那些导致损失的参数(损失回传)0 t- Q% c' Q5 f, R

  44. * E\" w8 ~  E  ?5 b0 A$ Z
  45.         loss.backward()5 e' }% D: P  A/ m
  46. 0 W% w! U9 q\" f! t1 W
  47.     for name, param in model.named_parameters():6 d2 N/ h& \. e6 Q
  48. 7 m- \6 g' G8 ~3 |. C; h$ I3 Y% Y
  49.             print(name, param.data, param.grad)% [: u) y, j( f4 r6 W+ d
  50. 4 |+ R3 w8 ~8 ?\" H
  51.         # 5) 更新参数' Y9 `9 X& C\" [
  52. 1 v. v% X5 W0 o$ x) n1 ?
  53.         opt.step()8 n# ?. H7 h: P

  54. & O% X1 ^  u7 B
  55. , P& U8 [1 V% p\" g  }1 ]* F4 T

  56. . v3 Y% Y5 l\" H$ n' Q2 ]4 E, m
  57.         # 6) 打印进程+ z+ q# Z3 }. E+ _

  58. 7 `: z$ M8 b% y\" ?3 R5 Y
  59.         print(loss.data)  {\" P- ^& @( n* [4 i5 \' E

  60. : V' O, ^' ?% `2 _9 Q/ Q
  61. \" V  N# A' K; g: Q

  62. : w7 t, e7 |* T4 c# s3 q. P
  63. if __name__ == "__main__":$ S3 {$ w7 m+ m
  64. ( T/ d\" \$ z, ~: L7 w  @5 M6 ?
  65.     train()
    5 @9 l% m% P$ f\" F% X1 G/ s; Y

  66. ; j8 r0 r% j; D6 W4 J. H/ |
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。$ H( f; u& D' \0 J5 K, d) Y8 \

% v' v- L* G; h4 O$ ?- w% Z计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    * `8 H- `: q$ X4 J& t, z
  2. ; T& Y% Q8 u+ H8 v# ]( J2 T
  3. bias tensor([-0.2108]) tensor([-2.6971])% i\" N2 z. |) A+ z

  4. % }4 O# Y* z* F5 u6 s( c
  5. tensor(0.8531)% o% Z6 ^7 k6 A

  6. * `3 x1 ~7 S9 N) @9 z
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])% \; ~( W- A- q% W; O

  8. + C. z# o\" g. k$ T, V
  9. bias tensor([0.0589]) tensor([0.7416])
    4 ~. y; E% U+ n) ^
  10. ' @# D4 K% o7 f$ n9 a
  11. tensor(0.2712)
    4 F; [+ k5 `! S9 S0 c! F8 v( {
  12. , j7 m' S$ _; v8 e3 O  K, L: L
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    . h4 k; y( c; r9 w
  14. ! q9 W+ S/ G, [
  15. bias tensor([-0.0152]) tensor([-0.2023])3 G: |' ]! Z1 b+ Z

  16. ! W. J) W  c& C1 k$ c( J3 O: n5 c
  17. tensor(0.1529)5 b# ?! M! b7 m/ G- X* {
  18. & K/ _7 H$ i3 B9 g4 V( o- j0 j
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])! a5 k; r* u% @! W

  20. ) `3 l( p: [! ]* r0 o9 d  \/ @
  21. bias tensor([0.0050]) tensor([0.0566])
    / v  F+ S8 H/ W5 G: Y

  22. 9 c& b8 u# Y1 A) F2 |& a
  23. tensor(0.0963)
    / X; F\" s: a( g. j
  24. 2 W6 R\" t. u+ w. g* u1 `
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    3 g; k* P+ ^- b2 q5 X* Q

  26.   c: D3 b6 O% `. O& i2 ?, b
  27. bias tensor([-0.0006]) tensor([-0.0146])1 `6 u; E# r1 D& ^& [  L% b6 \
  28. # y, |* b% v3 q) j
  29. tensor(0.0615)3 G# C2 w7 D& x  U, H

  30. , M; p6 o9 R  _8 `
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    # r9 P' y+ C1 ^0 I  D

  32. # \5 w8 k2 [* F$ ?
  33. bias tensor([0.0008]) tensor([0.0048])( g8 k/ Y& O' a. m: t\" k2 v4 Z# ^
  34. ! R5 N8 ?) N& ~! N# |
  35. tensor(0.0394)
    ! V) o5 w+ U$ `# u
  36. ( f7 m5 s/ X1 h! q
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]]), Y0 P/ g) |7 s7 m! Z! N
  38. & [( y' v* A7 G1 t8 K
  39. bias tensor([0.0003]) tensor([-0.0006])+ y1 X; F' f0 Z  ?* `) W0 U* d
  40. % Z2 e+ N6 E, T\" j( j! N
  41. tensor(0.0252)
    , Y9 b3 }- b- k8 O1 v4 \3 x) n

  42. + n1 f5 W- M3 i
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])' i: ?\" R7 p( l& g( V6 n% l( ]
  44. # [1 \1 R2 e0 }7 @8 N9 b% F4 X
  45. bias tensor([0.0004]) tensor([0.0008])
    # u$ O8 f0 L$ R6 |

  46. 5 S/ _* H0 k; P/ S7 q
  47. tensor(0.0161)5 y& |8 P) M) r7 C' k! l
  48. 4 k; c& `, V* H8 K$ u; g7 H
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])) Y0 a7 @) z8 ~7 H
  50. \" l5 m0 s/ f$ r
  51. bias tensor([0.0003]) tensor([0.0003])
    $ p3 e- v' X' a

  52. - g8 ], c/ r0 q6 D, f9 n0 r5 _! _6 Q
  53. tensor(0.0103)
    ) g\" l, ^+ V6 q( `

  54. 4 t# U. ~0 r* h* a8 x! a+ L/ R
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])$ \# z+ I4 b% g% @  V4 o
  56. 4 b: h) I7 F3 L7 C1 F, ?
  57. bias tensor([0.0003]) tensor([0.0004])
    $ t: }+ h+ E: K% m  a5 g% Y7 C4 l
  58. - t$ _! _4 T\" }. m5 m& l  W* ?
  59. tensor(0.0066)5 L6 o1 y2 t* M0 v, `

  60. 3 t- f9 i  v\" F\" H( H6 \+ e
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])  C1 \% ~% p; F9 D* T( g2 A
  62. 1 D% g) m4 Y' w+ g\" E
  63. bias tensor([0.0003]) tensor([0.0003])0 G- L# K6 @# K* s  w
  64. 0 n+ P& I% [6 ?! w$ t0 [* @! Q
  65. tensor(0.0042). ^; ^. o3 Q& @! n9 L
  66. 6 ~/ M' h\" f1 o* `' c% G, t: J3 k
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    - f\" o7 y2 U4 {2 p3 p  g. `( r
  68. % c, ~+ S. G! f4 x+ _+ v
  69. bias tensor([0.0002]) tensor([0.0003]). T6 Y: }* f+ W: M4 n
  70. % ]9 f# L3 l2 v. y5 F2 R4 h
  71. tensor(0.0027)4 `% I2 a! Z/ W- \5 D- f0 n; a% E

  72. 2 O+ h# t  P0 X8 U4 x' x2 M
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    9 Q3 @\" ~. ^8 M7 l

  74. ; k  S2 C8 r- @  ^  b
  75. bias tensor([0.0002]) tensor([0.0002])
    , q& ^+ O& q0 Z\" X
  76. 3 K- ?& {4 M1 A' l0 t
  77. tensor(0.0017); X7 x% S8 G& I7 `' l3 p

  78. ; W: W+ E. q, L+ a8 ~
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    \" w# O5 E/ E\" k/ A0 ~
  80. 0 @0 C4 v# q- q* f/ A  |\" E( u
  81. bias tensor([0.0002]) tensor([0.0002])
    5 z+ r% L- ]  L4 Z& ]5 w
  82. * q: \) O; s, a( D1 g
  83. tensor(0.0011)
    + a$ |$ i9 b- ]3 T1 u4 c2 w9 h
  84. ' E* W2 @- o/ Z3 i( _, V
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]]); |+ P4 f( z0 K9 a

  86. ! c$ X- @( T  N2 m) E$ A; @
  87. bias tensor([0.0001]) tensor([0.0002])
    5 h4 n) [4 `7 O. u5 p
  88. $ ^+ \8 Q/ r5 F5 L1 n( b( R
  89. tensor(0.0007)* K) ]( [0 L+ D4 Y& x% t5 \

  90.   q: [; l5 M& V
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    6 x6 m- C\" O& ^, [+ E& I. N

  92. ; r3 s; e, Q) |
  93. bias tensor([0.0001]) tensor([0.0002])
    6 v0 {. V+ x; O) U! C

  94. - X! s4 u  W7 C2 ?& |& }3 r
  95. tensor(0.0005)9 W+ N# Y* f\" z0 q  |
  96. 2 ]4 P6 [5 V  j& I# K$ a/ h7 \
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])& ]& j\" v' F  X2 Y$ W; I
  98. 5 ~' l0 |6 ^8 @; i  Y+ W* ]4 H\" B3 \
  99. bias tensor([0.0001]) tensor([0.0001])
    4 `1 u3 H. [3 d5 }. B6 V$ \! p

  100. 2 v3 U7 ?: f5 t. N2 n
  101. tensor(0.0003)
    / m& f9 {' ?6 m5 m1 T; k3 G
  102. & O: O; A\" K: T8 Y
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    ' H\" t- s$ E) [2 v% ]* Q0 S9 t
  104. ' F+ h% ?$ ~1 Y/ c8 Q1 r
  105. bias tensor([9.7973e-05]) tensor([0.0001])
      I1 ^* g; M4 n* C
  106. ! y9 l: j, [) w6 _1 I( Z
  107. tensor(0.0002)5 X7 Y- k8 ^$ G; s$ a\" B
  108. 4 ^& _5 m) P7 A( C( ?( J7 L7 M
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    $ h% N! R' @0 Y
  110. ! D+ G1 d1 X5 M. W) T; h# O
  111. bias tensor([8.5674e-05]) tensor([0.0001])% j; O% T' V# k- Y( q

  112. + i( B' z' L' T\" [- E) U
  113. tensor(0.0001)) G, Y2 F4 Z\" V& {' O  [) F$ t6 p

  114. % ^) f& P% o/ I; O: Q
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])  d, |5 z+ u  V% l2 I* j8 c- c
  116. ( {, y3 v! B4 {: C( q
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    6 r- I; q3 b1 j\" J3 u
  118. & S2 V/ q0 y- }4 Z1 Z( w
  119. tensor(7.6120e-05)
复制代码
, l1 N( w6 D6 v8 o* q/ ~2 w
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-6-7 06:43 , Processed in 0.873953 second(s), 50 queries .

回顶部