QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2632|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
" N% V+ ]! p8 i2 xSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
2 f. ^2 B# G' p怎么理解梯度?
+ f, K. `5 J. L( A假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
0 B. y! K/ ^/ K1 J
3 h# y5 ~- \: h在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。( d1 d8 C1 q0 b1 g0 c  S

5 G" w4 Y8 ~; b) |3 B' s每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
: g/ I; _: v6 s2 ~9 }9 Q& l/ E( W$ r9 p
为什么引入SGD  O1 e' H9 J$ B4 [4 ^
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
: x2 F4 x/ U  [( y; b$ R$ ]
8 y1 n- P0 T6 W) }8 g- i4 q! H怎么用SGD
  1. import torch
    6 m+ c  l6 u0 A* b+ O* D6 B; F. c

  2. 8 Q* s5 o4 }( H4 J5 V9 L5 B
  3. from torch import nn9 o5 n+ x. k, z. l$ U6 f
  4. / B) e/ M- E\" R9 ]! Z
  5. from torch import optim
    ( s4 N& ]1 A\" z& n/ V6 [

  6. % A  T\" b\" ]% V8 l
  7. + R6 V( K' ?8 S. z- h6 Q( Z

  8. - y8 Z8 J8 i3 T. R
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)( I3 a8 z* p0 y' [  K
  10. 0 B8 s( [+ k& a+ V# f* t
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    + O2 \3 ~3 T* J' n- V\" v( |\" p5 r
  12. 7 c  p8 ^' x1 ~: T9 P8 X

  13. 8 }0 P$ C# U. d! w: g  y
  14. + ]% _( u9 u( M8 D! p9 s
  15. model = nn.Linear(2, 1)- k5 r5 |\" p6 q1 N

  16. $ @% |; h* D& z\" c9 r# I

  17. # Q0 l\" a4 X$ H  O3 y* q5 c( i
  18. & S- @4 {! v/ h* k& `
  19. def train():3 {! d: r+ f) C5 N! J/ ^$ b

  20. \" v! i; O% }: ?( f
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)- [\" X  s: B3 \$ U8 r  Y) n
  22. * r1 H! J. U* K5 F- @$ b/ n
  23.     for iter in range(20):
    9 T3 C/ K& p' x; h& `

  24. / N' V( r9 |8 Z! M
  25.         # 1) 消除之前的梯度(如果存在)
    ' D! }5 _+ {/ m- T0 T7 z

  26. + A! q. ~, ^4 O  k! _0 o/ w
  27.         opt.zero_grad()8 ~4 O, P# h9 p* v1 K. ~+ S2 q. S
  28. + n4 C8 p' p- r% I' g& ?% N( e\" j* I
  29. ' Y. a\" c; b- c% x* A

  30. # U  H6 U7 U( q5 `8 c
  31.         # 2) 预测9 W& a1 r6 e! F9 x: `
  32. \" {* ?* t; {; M0 R4 h) }. E3 d
  33.         pred = model(data)& K: o6 o: T4 q( n7 N7 n; ^' [
  34. \" ?  V+ c/ V+ V# d
  35. 0 ?: V; _7 O% L9 `( n
  36. ! {: F* N* T1 b  v% s1 n
  37.         # 3) 计算损失
    # i& d9 `: C/ S' Z

  38. $ Q+ S) x- `# |! K' h
  39.         loss = ((pred - target)**2).sum()
    $ N: X& T0 i, O5 R2 Q) [! v) X  y
  40. ! K; v. u; F7 e* W; s4 L# s
  41. $ [4 r& Y\" o1 l1 U- d
  42. 6 G, g\" V* Q! z' h. y  @* r
  43.         # 4) 指出那些导致损失的参数(损失回传); B$ y( x+ J\" _, T/ d8 R* q

  44. 8 t) h0 ~% [) f' q4 C& }
  45.         loss.backward()
    & l) S, }' l* L\" `2 [- ]$ X. B\" b
  46. ) V, u; C- O5 w4 a, j& c2 T
  47.     for name, param in model.named_parameters():, `! v$ m5 B& ?6 K8 ~\" O) p8 B\" l
  48. ( Y6 ]3 I( ^7 H\" Z. l
  49.             print(name, param.data, param.grad)
    0 b+ Y1 v5 v, u0 @  H8 @
  50. \" B' Y4 _4 \* P5 P1 D2 z# \
  51.         # 5) 更新参数; [- |9 T9 L* M; R) b
  52. 6 y3 i% Q1 K9 N
  53.         opt.step()8 ^. B' J$ Q. Q

  54. 5 I4 r\" j& z7 I
  55. 4 T, g) K& O2 R2 U% |. E, \. v# x

  56. + j2 X7 C\" F+ n' I3 b, Y
  57.         # 6) 打印进程  @9 B  p2 G7 b# E

  58. 3 z- m8 N+ i& y, c5 o- y2 E
  59.         print(loss.data)% f* g! r+ X. r: ]9 R' r3 K

  60. 7 d& N1 O8 X3 c; h

  61. ' Z3 e4 w' q+ v+ @# V

  62. & S/ A\" n% I0 L
  63. if __name__ == "__main__":
    & l7 ]. k! N% n

  64. : I3 |& B5 g8 q$ w\" p
  65.     train()
      `! B% I4 o9 y; I. S! [2 I

  66. \" e8 v: O9 h: ~( l- j) K
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
8 H$ a6 c( c, D. K5 i" H' H5 }+ L$ W7 x- [& t8 _2 `3 g( [. Z
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    , b8 ^2 E/ x1 ~/ B9 z2 `5 U+ Q
  2. # i2 `* V. z\" \  b, U
  3. bias tensor([-0.2108]) tensor([-2.6971])\" q8 A2 L/ ^) c9 F' V

  4. / b  ?: m\" g' {) o5 q/ H/ z
  5. tensor(0.8531)3 o4 y% ]6 f: o' x9 Y

  6. ( J) t) I* B, M& o
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    0 v* Q7 p  k/ O, J# a
  8. : }  u  v) }7 X\" Y- c
  9. bias tensor([0.0589]) tensor([0.7416])$ K! M1 K6 ?. Z. M$ B2 K/ m
  10. ! i  ~, D% |( z2 {9 @& c
  11. tensor(0.2712)8 O9 D+ `\" h! ?

  12. & n$ H+ d% b1 O& I0 \$ y2 d( H; D
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])' V+ }) Y6 Z5 D. V: w
  14. / v- z+ q5 s& ]
  15. bias tensor([-0.0152]) tensor([-0.2023]); _1 q  h\" t, y6 V, X6 s

  16. 9 F. k7 X6 R. k* d. G\" @
  17. tensor(0.1529)
    ! g8 p( J& A! U  p
  18. 8 m% o9 \! a8 C
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    + }, |& F, o/ W- L( ?7 \7 U4 C% v/ V
  20. , B. _8 |: ?( y4 [( T4 S) Z0 q  X
  21. bias tensor([0.0050]) tensor([0.0566])
    ! Y* g\" N- g) Z8 q
  22. # S0 K\" T1 i$ r9 _% Y
  23. tensor(0.0963)) F& E* P& V0 J/ W% D* p
  24. * h1 I+ S) L7 I9 }; N& G
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    6 r4 l\" r; K9 C: {% [& V

  26. ( f4 L: y, s1 B4 i, {
  27. bias tensor([-0.0006]) tensor([-0.0146])
    * d\" G/ z7 V# W  J
  28. % d5 R. g# q# a/ [
  29. tensor(0.0615)* g) O% t, u, A1 H* J$ J
  30. ; D, S4 R6 Q9 m, v  T9 X4 f5 @) Z
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]]): }, s\" N8 Q& b+ K7 S) |

  32. / J0 }3 m+ u, r. H- Y9 Z) q4 L\" [# F
  33. bias tensor([0.0008]) tensor([0.0048])
    7 N* [3 e# F: l& b

  34. 7 B. ^- R7 j  k
  35. tensor(0.0394)4 q8 R\" h+ `: D. K

  36. : M/ B5 a4 A8 O# n' |( [
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    5 Q4 L5 R+ W# D

  38. 7 i) q\" M9 y6 r( q% c! |' |
  39. bias tensor([0.0003]) tensor([-0.0006])
    ) m; A2 ]3 B6 D7 O

  40. 3 a6 p' X* X) ]% Q* d! {8 x
  41. tensor(0.0252)
    4 K! N+ D) i0 ^! _4 ^1 G: P' u\" C9 p
  42. 1 a3 @) x\" ?1 g\" `
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]]); D* A( e) i2 }: j1 ]- a
  44. 9 C0 [) ^$ [# `; \' [5 K9 r# b4 d\" c
  45. bias tensor([0.0004]) tensor([0.0008])
    - u2 j7 t9 s- V, G5 h

  46. 6 R! @/ t7 r; m+ {2 A$ u) B5 K) `
  47. tensor(0.0161)5 G\" h' P( }  V8 O$ \

  48. , H3 A* y% Q% m0 L3 c
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    ( y& [1 H% c4 R+ }0 s

  50. ! C1 p- R0 ?: X4 k
  51. bias tensor([0.0003]) tensor([0.0003])8 g9 y8 b1 g% C6 G
  52. 5 o/ A' K' ?3 B. _$ H
  53. tensor(0.0103): N) X! H! h$ f) @# I' Z- v
  54. 2 v8 U! j+ c' k. {' e1 t
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    3 c& I$ S* M# b' \
  56. , G, }6 A) d; j+ g' M! B$ o. M. X
  57. bias tensor([0.0003]) tensor([0.0004])& `8 H) C1 h) I; l) ^; s* ~) W2 e

  58. * X/ N  L8 B+ A7 D
  59. tensor(0.0066), \6 r- S5 n4 V) P! C

  60. ; m- R5 _# W  ]: Y* C$ Z, T
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])  Y. U( T5 v# ~' @* q

  62. ! B/ w0 `- z' G3 f. E: _
  63. bias tensor([0.0003]) tensor([0.0003])( r% h\" H0 S) z

  64. % M8 x: G0 n# e$ u* q3 O
  65. tensor(0.0042)2 n4 N/ C  \  U: p1 _
  66. ' B\" I4 t# l- K* t1 H
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    % \$ k$ l1 D/ y4 Z/ A0 {0 S: h
  68. & P- B4 ^: ?\" b+ W  s+ \/ n3 L\" q
  69. bias tensor([0.0002]) tensor([0.0003])0 l+ j) E( B; D& Q2 N' i

  70. & y' N; m# T7 Z0 l! N) h3 V
  71. tensor(0.0027)7 w. u4 C; I1 Y) P4 m) C
  72. 0 p; V$ X! S, T/ k# Y9 L7 G- J# ]
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    9 [0 Q, M3 S, U( l3 U4 H$ j\" Y! H

  74. ; p% g8 A% K; Y  N7 a1 J
  75. bias tensor([0.0002]) tensor([0.0002])% F( [4 ^5 p; F6 C, p$ ~6 @- Y; i% E
  76. $ K2 _) I8 d4 f
  77. tensor(0.0017)\" j$ k5 L8 i$ j8 ~0 V) D
  78. 6 |7 B* j- N& d% P1 Z. v8 r
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]]): V4 B( l& d! G/ u; x. o
  80. : t! B3 s  e+ N* X: M- \9 {
  81. bias tensor([0.0002]) tensor([0.0002]); _) {: i, x4 v0 B

  82. 0 X  q3 n% `) |4 U$ L+ A. L
  83. tensor(0.0011)+ |2 a& @8 M1 T7 Z

  84. 1 H. c) k! c\" ?1 O) C; ]: Z! ^
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    % a( ^: u\" f- q+ a
  86. - v. [\" T- @7 L' H$ B
  87. bias tensor([0.0001]) tensor([0.0002])- `6 U/ W( {2 d\" q
  88. ) j, i, O! ?2 h6 n, t8 e
  89. tensor(0.0007)8 c# V, K* i  g) v/ s
  90. : [* R- Z- _2 {\" S3 v3 B  Y
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])\" S( H% Z4 T\" p2 d

  92. 7 U! R8 A( l% \7 _+ H$ |
  93. bias tensor([0.0001]) tensor([0.0002])
    % U, q8 Y1 z+ I) G7 ~
  94. ) c: C! F  ]4 V\" ]$ U
  95. tensor(0.0005)% }' d( F0 h7 {& E7 o- U
  96. : m, r9 H* P8 l  i+ K- i
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    ' h: m0 P/ s& n4 r2 }! z

  98. 7 ?) q# u- l. {7 ]5 @; k
  99. bias tensor([0.0001]) tensor([0.0001])3 o, P' k6 O/ T9 S0 I9 ]

  100. 7 n- _2 {\" A* ]! d3 f
  101. tensor(0.0003)
    7 v8 v7 `\" H  ?1 F  V8 M$ X
  102. $ e( R4 G/ w5 l; D\" K3 K
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    9 O& x\" V$ |' {% `( n- c

  104. 0 t$ y2 ~9 n8 K
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    ' Z' s! [/ r' Q' O1 g

  106. 9 b7 a1 d. t7 s7 r\" _
  107. tensor(0.0002)
    9 s  n( Y* C* d; V3 s1 t

  108. ! B1 H( m$ L7 F; T! g\" q0 Z
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])! D+ h# r. P$ z0 D3 t
  110. 0 r2 p9 N; r8 g5 g; a& h2 t2 r\" K
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    / Z% h2 P' [+ \
  112. # u# F+ ^- L1 w4 D/ o+ N\" _
  113. tensor(0.0001)
    9 x# P$ P0 V\" d! _2 T( Q

  114. $ Y( p3 R- [* n' I& K) ]8 c
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])8 l9 K; y$ u5 y: k

  116. ! |# y# M7 I1 n6 ^
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])  z9 X4 ^0 s3 f! h* n1 U: O

  118. - Q1 f3 N  D7 T; g* N* W
  119. tensor(7.6120e-05)
复制代码
; T8 ^, }% T' t/ x
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-14 12:53 , Processed in 0.421636 second(s), 51 queries .

回顶部