QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2631|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
8 }  }3 h$ U6 X. pSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。5 Q& \/ h- v7 Q6 n1 X* Q  z3 v& ^
怎么理解梯度?9 B; H4 i8 K2 A$ H7 s# X
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
, h1 {2 @4 y/ f' x$ w8 p4 b$ h3 ^' _/ @
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
+ Q. [8 e; w, |2 g5 X# S
6 X: h0 ~+ v- N/ |每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。' l/ N, R& ^' ^
5 @/ _3 i# e+ U9 |
为什么引入SGD
' P' _1 i! M0 ]9 N; E/ y  O- H+ K深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。5 [1 w" u& P7 {/ x% ^/ ~

$ W  C+ Z( a0 D怎么用SGD
  1. import torch\" ~\" E. B6 _; C; u& s
  2. \" f) `* L: ?; J5 u2 C; Q
  3. from torch import nn& N9 L/ C' W$ x+ s
  4. $ y3 h5 v6 U2 m\" @
  5. from torch import optim
    + [9 F. o& B( Y7 J
  6. 7 q5 Y+ d$ M# Z  S  X

  7. . a6 E/ K/ y5 y. g5 v\" |5 w( {
  8. # D( _7 G6 a/ ]1 r
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    ( Q5 S# h% }0 O4 ^- B5 r( C
  10. 4 D' ?9 _# S* L& {; c1 s- ~8 e
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    8 g4 k. Q, S6 A+ D

  12. 4 I# f& b0 p+ c# H/ y\" J

  13. + j& J, ?; I: _1 j6 z$ a
  14. 7 `; B( N3 e! o: f( b4 l
  15. model = nn.Linear(2, 1)
    \" H- Y; A$ @$ F4 \& k9 Q# m% e; q

  16. / i& E! ~+ \% o  z\" H+ f$ R' A

  17. 5 H\" _0 z& J. B! ^
  18. 9 q9 D7 k2 b, @2 A6 b8 u
  19. def train():
    9 y. U6 L8 ^, S1 w# r

  20. 1 v! l7 `1 i+ p. q  Y
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)1 K' d6 y  F6 {: G+ f* u
  22. / c+ @8 F$ |! C% Y4 U3 m, Y
  23.     for iter in range(20):
    ; l6 }, |, b3 R6 ~
  24. 6 s5 X\" e0 i+ V
  25.         # 1) 消除之前的梯度(如果存在)5 V% U% L9 K3 P2 h* j: h5 d
  26. 4 F; ^& Q' D) K7 X/ E1 `\" B
  27.         opt.zero_grad()* ~' h2 f1 I/ G2 C
  28. % ]- _- G2 ~1 D; x
  29. ' c  K1 x/ }* u& C: L5 l! Z1 R2 D
  30. 8 b4 y, y6 l: b' A& h- [& q7 x
  31.         # 2) 预测/ O- \; ]8 r& A\" _

  32. 0 [/ m1 C- L. s9 u* y( [/ {# N
  33.         pred = model(data)0 O& }+ `# o: v3 z9 N

  34. : ?/ |8 \% H+ @\" y
  35. 8 b  f; P3 L: }/ \7 z

  36.   G: ^+ C2 l/ X5 T- D0 y& \
  37.         # 3) 计算损失
    $ ~! L! c. T& Y( v

  38. / J) ~4 K6 h) I, s. ?
  39.         loss = ((pred - target)**2).sum()
    3 {4 B$ M$ ^0 ~# G  z

  40. ; }- H0 o5 ~\" T$ ~, u\" A

  41. 2 Z3 a$ |1 c( {
  42. 2 z6 x! I, v) u& l- Z+ l3 P
  43.         # 4) 指出那些导致损失的参数(损失回传)
    # |2 \- ^- _\" j8 o, X- c  Y
  44. * G8 I$ _\" z\" M0 _' E
  45.         loss.backward()
    $ t\" N4 P% r( h, f1 Z: r
  46. 1 Q; n\" q/ y9 G& Q
  47.     for name, param in model.named_parameters():
    6 F  x) h) V2 n/ t

  48. ) t4 O/ E3 Y6 y4 r$ C; B
  49.             print(name, param.data, param.grad)) V8 r\" L3 H  `3 W$ W& n0 Z. M7 K

  50. / n1 H- n  K/ l8 }
  51.         # 5) 更新参数
    . [+ z* g. R+ t* t# b5 P
  52. + k% v1 n5 x2 t+ z5 Z4 U+ `# V3 \
  53.         opt.step()
    ! m# |$ n* x6 ?- g* m6 C
  54. 2 A) A/ n$ K& `

  55. 3 ^- Z, ^' J- f# S+ a! F
  56. + U! D  K. z3 K# ?\" z
  57.         # 6) 打印进程
    4 c6 q+ @7 X% U) m+ c9 l3 C
  58. 1 o8 ^5 Q. c+ K
  59.         print(loss.data)
    9 j: S( z& A% ^  Y: ]1 R
  60. * @, D. q% M/ H

  61. ( \1 n/ ^, ?* w/ X% E* p) i

  62. 2 {7 X7 X) d3 X1 I/ [: g
  63. if __name__ == "__main__":+ U$ E0 ~) g; k; [# G

  64. : h9 l; |: r; K' q# b\" C( U
  65.     train()8 y. b: D2 Z9 h/ H9 N- F% b

  66. 6 r; c+ @1 Y/ a\" n: [2 h
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
5 X7 c, ^8 m7 c: L3 c. s  t& S& I0 a( F9 u/ b7 l1 B. J
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    7 C; M5 E( n; Z
  2. 2 K- W! v+ b, v! {) D% H4 O\" I0 x) L
  3. bias tensor([-0.2108]) tensor([-2.6971])
    , |8 q, ]0 ^5 w7 p
  4. 6 |. K% ?+ S4 U
  5. tensor(0.8531)
    6 j; z% k: C- s; v7 U& e) L- _  X; F

  6. 5 [3 p7 U% O: }0 H# Y! U% A  }1 S
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    3 {! |2 {- y6 D5 V% n# o9 p5 g

  8. 2 h! ~: w7 X! S\" \& z3 E; e
  9. bias tensor([0.0589]) tensor([0.7416])
    4 \* E# \\" {' e6 O
  10. ! y% D6 Z\" q- I* M, z2 m+ w
  11. tensor(0.2712)
    3 Y' A4 y# s0 L1 d# d) d8 j
  12. 7 o5 J! K0 j! t1 m
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])* \5 I) X  M. n\" w* o* t

  14. : d7 P: a+ c. ]. R
  15. bias tensor([-0.0152]) tensor([-0.2023])
    $ Z/ |: h5 S$ l: C6 |4 G( e
  16. & K\" e3 ]9 f* |: g0 G: _
  17. tensor(0.1529)9 {, S9 F' {. c! p0 F8 \  ~
  18. 0 @) q3 _- G6 p7 k
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    3 `' l4 u4 z$ b, u( W# {  a\" v/ K/ q
  20. 8 k2 ?0 O' a6 T& q0 X( \
  21. bias tensor([0.0050]) tensor([0.0566])7 [7 p1 i3 D$ z8 c/ D

  22. & W; Z9 B, P: v- [\" X6 E
  23. tensor(0.0963)( C8 C/ [4 F\" X$ x* e# e

  24. # e4 z0 J4 K! L- O& U  a
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])0 h. u. b9 u% c$ i0 X: o5 A- f

  26. $ h- H! E6 I# u\" y
  27. bias tensor([-0.0006]) tensor([-0.0146])4 Z* [, _6 v9 z9 \
  28. + y0 y6 v9 O' j+ v/ W. P0 d\" y
  29. tensor(0.0615)
    1 p6 z7 J. y) v4 X; V& U

  30. 8 F; l, t- h6 G9 o\" t\" @8 `
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    % v( [& n\" z' i* _- G' h- h\" K7 `- j
  32. 1 @+ e; ^6 L( q' x
  33. bias tensor([0.0008]) tensor([0.0048])3 t- F$ d\" c* X. v( A
  34. / Z4 n- q7 ^' b0 l! I1 Z
  35. tensor(0.0394), `+ _/ \8 E! F3 O; F
  36. 1 Q  X, L* C& o( V! G
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    ' Y\" f\" i% G& i8 y
  38. , }; j6 X- O% o8 g# C% v* G
  39. bias tensor([0.0003]) tensor([-0.0006]). ]& r3 {0 I: A, s2 \2 Q1 w1 E! _
  40. 1 J  H' v3 ^  d: S# Y0 b; v' u
  41. tensor(0.0252)
    & p& _5 K% d) {/ w: ?1 C

  42. $ _3 W# _4 \, [6 x
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    + K  U& g\" a% _1 C
  44. ' a\" g0 L\" r, ~/ b% ^
  45. bias tensor([0.0004]) tensor([0.0008])7 p! {, c1 f8 M6 j

  46. 6 _; v2 ~8 {: r6 z  F* b( A
  47. tensor(0.0161)  g6 ^0 D0 {, F0 \# M

  48. 3 _\" Y. }( H' |2 y2 N# ?
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    - w) q/ L* M( Z\" Q$ w4 F& J; B
  50. 5 ~4 e- t' p( h/ W- B4 w
  51. bias tensor([0.0003]) tensor([0.0003])
    % C3 i\" T( Y! u# A\" }
  52. 1 [5 u) S9 X- I\" J  t1 c\" K( M
  53. tensor(0.0103)
    5 C) Z\" s4 H0 p0 e8 c' t0 X
  54. # \1 ~+ E% \9 j
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])$ H1 j* _! m( F
  56. 7 h( b4 P' Z2 E+ q8 B/ E\" B# B1 u
  57. bias tensor([0.0003]) tensor([0.0004])
    - b4 ?7 X0 p/ k

  58. - O2 X) q7 L8 b5 B3 D
  59. tensor(0.0066)
    4 {' `3 Z( d  m! H

  60. + Y# B9 c/ z+ {0 p) \; \( Z
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])( ~4 z2 J2 o: W2 e, y: |6 R

  62. \" @, I7 w- h1 h# X  v. m$ m\" @
  63. bias tensor([0.0003]) tensor([0.0003]); y$ o9 F+ ~$ Q( k; i* \
  64. ) g  A/ a8 p3 Q/ U3 U
  65. tensor(0.0042)
    ( _7 k, B+ x* N  }+ z& s

  66. 3 \3 e% [! |1 t* {# I
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])0 X, H7 C0 f% f, g, W, c
  68. ; U9 Q1 M3 ?5 Y+ K
  69. bias tensor([0.0002]) tensor([0.0003])* h$ Z% ~7 V4 V+ Y9 W( i\" C
  70. & L5 P6 ^6 }) Y; |+ i: h
  71. tensor(0.0027)\" s+ D9 `4 |5 R: g
  72. ; @& c0 T8 _2 ?' ?8 I- B
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])\" _& g. u8 E0 q# z! \

  74. + P  O0 A- k* M: Y\" |+ A
  75. bias tensor([0.0002]) tensor([0.0002])
    3 T: }! G. [6 r8 }& g. E
  76. * v3 E6 d  d8 o# l, h* m
  77. tensor(0.0017)
    # I% x5 u+ B6 N: C5 }3 Z

  78. ! C! T* I- g0 X0 E, {
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]]). y# K+ p+ {/ y+ K  u
  80. 0 @, T: N3 n+ J5 M/ w' V4 d* D$ L
  81. bias tensor([0.0002]) tensor([0.0002])
    . U# k& b% a: k: y6 d0 A; P/ v

  82. ' A  a* U. `7 z
  83. tensor(0.0011)8 ~( V8 l) C3 C- i6 p- p8 Y* x1 `
  84. + d  l: I/ ?) L! e4 _. P
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])$ h7 z& v: y, S
  86. % W8 T# F/ ^: H) k
  87. bias tensor([0.0001]) tensor([0.0002])) q* v; e+ {, X. H0 R
  88. % I\" g  {' u3 Q8 h& W; N
  89. tensor(0.0007)& m0 y4 U4 m1 ]( H# }
  90. 3 f3 n, A& D/ E8 O; }: i8 f3 e
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    ( D9 \- a8 F/ h, P. Q3 _1 E
  92. & b* M8 e4 S; Q, Z
  93. bias tensor([0.0001]) tensor([0.0002])0 `# O. q4 H' R8 v6 `1 J
  94. 3 w7 h) |& V1 Z- `
  95. tensor(0.0005)
    9 k2 H4 n% Z+ b# T2 i9 T1 K& B

  96. 5 L7 ?- {5 l5 B# A; g
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])2 |* ~6 r) N) T) K/ h  U( }% w' X6 p9 q
  98. . h) k1 E. E7 t/ z, M, @
  99. bias tensor([0.0001]) tensor([0.0001])) Q7 c; A; ?3 g% h( \  u
  100. . ^% o! e; }( V0 d% F
  101. tensor(0.0003)
    % w9 p) }9 H5 F' I

  102. : ]' s3 w) h+ c6 c( \
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    / Q* t2 `# ?) R* c4 Q9 ~+ q

  104. 1 _- M0 H/ B3 p7 M
  105. bias tensor([9.7973e-05]) tensor([0.0001]); p\" ?* G; [) A2 p* L0 S
  106. . _& X  ]* e$ \: G
  107. tensor(0.0002)6 q- H4 X* J& |, w, d8 o+ R\" t

  108. ' U! B# c# ^+ p5 z; r
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])5 U- i! K4 ]4 J8 E% Z$ K

  110. ! n; H% e3 r$ F
  111. bias tensor([8.5674e-05]) tensor([0.0001])  F$ o5 Z& Q4 ~; b+ j% v, N$ z
  112. ; Q4 [' ^' J0 C& q9 j
  113. tensor(0.0001)8 P. F7 x, J9 U( Z; q1 r

  114. - j- M  b* d+ ~) O4 z; h
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]]). x+ d6 P) A9 q, O3 u

  116. & S9 }0 _8 p; F6 s* q' e
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])( S7 j\" C# a/ f% F- H6 ?: f6 Y
  118. & E\" |2 I6 y) V. }0 V; M6 t
  119. tensor(7.6120e-05)
复制代码

! S" M2 Z$ ]$ M, f! d5 ?
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-13 16:27 , Processed in 0.425113 second(s), 51 queries .

回顶部