QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2089|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1175

主题

4

听众

2848

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么% J5 I: N8 ~" B5 ~$ F3 ~7 l1 `$ I
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。# v/ D5 Q  a0 R* N3 |9 @  U
怎么理解梯度?
: W/ Q8 C% x/ H: G( ?假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。0 Z& P! @! L( K1 L
) Z) b) O8 U/ A9 a4 Z6 K9 U. h
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。" i/ o4 Q5 v3 ^- W, t+ f

$ a- L' a& O* C4 d( N每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。/ ~5 g1 u  K' M, ?7 I# E: z8 n  g/ d
) V  Y! Y) R2 O7 _
为什么引入SGD! l, J  _: Q$ J0 f: K/ r! p, Q
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。, r% @; F2 ~" D% z

. ~0 @0 ^3 u& p5 N9 P  L2 i7 b, g怎么用SGD
  1. import torch  J0 c% o) Q2 L+ [0 w3 v( I8 Y7 y# D
  2. 6 F% G! K% m1 `5 Z$ N
  3. from torch import nn
    5 R# x' v! }2 X) X/ f$ S

  4. 6 r+ J4 `9 b9 [, g0 B; E* b& @) }
  5. from torch import optim
    ) i3 ^/ z7 c' B  J/ u1 i- `: I  P

  6. 1 m( a3 G3 U/ W/ X' O. |

  7. 9 i5 z6 m7 A0 ~; S
  8. 1 \; F% s9 v, O* Y- z* m. m  _
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    * ~8 j! {: z2 O7 Z- l
  10. / r  f1 m0 t% x, z& J9 _
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    $ m* i\" K\" f7 |\" a! z
  12. - Z; t5 [; T2 l
  13. . M! A9 w: C* G! G6 x7 w. a6 h
  14. / n5 x2 i  s\" J, B
  15. model = nn.Linear(2, 1). L  j9 O& z1 J5 A1 M; W, D
  16. ( D) T) o2 B2 t# z

  17. / s% t3 p5 e$ \+ B0 ~( h
  18. 7 m8 H3 z( [: N7 e0 l# D' v0 z
  19. def train():7 x+ t# H0 o/ y

  20. ( }\" z4 z% Q' D\" D
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)5 `# K3 @  o6 m1 M7 g0 F1 |( d

  22. 9 u8 }+ A$ T$ a
  23.     for iter in range(20):( R. l9 v0 w& P; g\" r
  24. * L, j0 j/ _& j\" C& H# v
  25.         # 1) 消除之前的梯度(如果存在)' e4 c/ Z; U) V  v5 d+ R

  26. 4 a2 @7 E, w' H3 s
  27.         opt.zero_grad()
    ; f+ z! U8 i1 e0 P5 z5 E* B. M; m+ [

  28. # ^- i$ k* `8 k$ V1 ?7 `& x; J3 [
  29. 3 Q( ?% Q8 n. v; M  W
  30. \" {: ]1 n5 n* Y4 x- g
  31.         # 2) 预测
    \" S5 v0 r$ s! d; @1 I
  32. ; V& Q( |+ `! C$ C4 e! @- }% n
  33.         pred = model(data)
    & a4 r9 w1 q! d4 f: u2 ]8 F
  34. 1 `8 J+ _% u& U/ ?* x. o\" D  j

  35. ; u& J7 v\" m0 f

  36. ' O. z* a3 I; E: Y$ l
  37.         # 3) 计算损失
    ! ~) Y\" p! `7 n7 }

  38. + C2 ^* Z& S1 ~; _, |
  39.         loss = ((pred - target)**2).sum()
      Y, R. Y3 q$ E) V+ s
  40. , l9 O6 H# t6 D+ z
  41. ) v+ q/ w% s' I* K) J$ e
  42. \" w3 W! _# W( Q
  43.         # 4) 指出那些导致损失的参数(损失回传)
    1 e* F7 d+ q3 Q7 ]
  44. 6 n5 ^0 M+ ^6 V
  45.         loss.backward(), Y- d5 n* p4 A& u- v
  46. % m& r& ]& U: i) [$ `' C3 o
  47.     for name, param in model.named_parameters():) L0 w; E2 }$ J3 S, |

  48. / ]8 z7 T! }0 h$ D$ f- P: q
  49.             print(name, param.data, param.grad)
    / [  L) P4 l3 U/ \4 M2 B
  50. ( b& D- I( u* V& `9 @8 T8 \/ [4 V
  51.         # 5) 更新参数
    : a! m( e; S, H

  52. ; p3 N3 h& U; H* K
  53.         opt.step()
    + M$ J: h9 S* ]( E$ B- R
  54. / A: y) ?3 O' K! @0 \. X- u
  55. ; \* ?7 [; l( E' k* U

  56. 1 }$ {6 r# ]: N) t( D, ^( o6 q
  57.         # 6) 打印进程
    & @7 w) N- z3 W1 ]

  58. 2 r, e6 \6 `1 c+ r) ~
  59.         print(loss.data)
    . ]! T/ o$ o6 z
  60. / i  ~+ I1 j. V1 ^
  61. ( k9 f6 r+ W4 @! n3 s5 W

  62. % M+ J7 Y0 t: \+ l) h, [2 r
  63. if __name__ == "__main__":
    ' D\" v- s. y2 t: l$ w

  64. 0 r, l1 ^! S( G
  65.     train()
    9 |+ q0 u/ v' b
  66. - N/ Q+ }  ]2 W. S6 N
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。) L  F* n# {  @  u
' T% [  G+ d- D# O* A7 d" d
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    0 C5 A- C+ v  d( g& V3 v
  2. 0 P+ w2 k' k! h$ _# f# Z4 T
  3. bias tensor([-0.2108]) tensor([-2.6971])
    5 d, p3 C' O3 _4 i1 C
  4. # f  C( ]1 G: u4 }
  5. tensor(0.8531)
    $ }0 P  l! _* ?3 R$ {. e9 u; ?' {5 ~9 F

  6. + @  K# U1 K! _
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])/ p9 v\" p' t2 r! B3 D

  8. 2 ~7 u3 G9 U# B& m
  9. bias tensor([0.0589]) tensor([0.7416])
    4 [& J, o# m% q: ^8 A
  10.   ^; I8 P0 u9 l% b9 Z6 v- v: c, f
  11. tensor(0.2712)
    1 ^! I7 K4 M  v7 ^2 e\" D; P
  12. - }/ z8 B! a& Q3 L$ O4 G$ X( o, a
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    0 }. g. L* F: ]6 z2 ]
  14. / I, p5 G9 q% }6 `6 W
  15. bias tensor([-0.0152]) tensor([-0.2023])
    7 j1 l, i7 K% q# O# {' z* a
  16. 6 Y1 Q8 T+ P1 m; W. d
  17. tensor(0.1529)
    9 {0 a  H# ?+ F+ `  H

  18. 5 H4 J' m0 a& Z, y\" u$ Q\" l& j/ {
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    0 n; Z8 E8 n& _% [\" }3 e3 t

  20. # p# j% i0 H1 L: i
  21. bias tensor([0.0050]) tensor([0.0566])4 l0 X9 a# }5 S( {. W

  22. 1 G2 W! H) W, b. i
  23. tensor(0.0963)
    - z4 g( N# o\" ?. c( u, x7 s( s. y
  24. % l% p\" P- ^3 o$ E/ G6 e) q
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])# U: G' E0 ^+ r9 y

  26.   ^: D- @4 R% r0 i$ G; D1 R. I
  27. bias tensor([-0.0006]) tensor([-0.0146])* K6 k% |1 {6 V3 K2 D
  28. % [, U0 S' u6 X
  29. tensor(0.0615)
    ) X: h6 a0 m- n. R/ c. R5 Q( n$ m
  30. 6 W# V& p- O2 D! e6 w0 N- L
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    ' Z* Z3 |1 I6 l
  32. \" d9 ?5 h$ q8 n+ o* Q' e9 a. x1 e
  33. bias tensor([0.0008]) tensor([0.0048])
    # H$ ^1 u* C: Q
  34. , c& g2 ^; A9 _, ?4 I: U2 Q
  35. tensor(0.0394)
    ) p& w0 `2 p1 j  Q$ ]
  36. ! h: \& M( R# }3 b
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    % ~$ A, Z( c2 X$ u: r9 m) ]; |* J
  38. ) {! f& o- D. f) {' {! I2 S! N& t) F
  39. bias tensor([0.0003]) tensor([-0.0006])
    9 E/ X; B& p2 U0 o4 \+ x
  40. 9 B7 r  m\" n1 v: [
  41. tensor(0.0252)
    ! P5 Y\" q% V4 S% b5 t

  42. - m: F6 P2 `7 u( e) j' `- z- T; O
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]]). Q7 h4 z: y/ c2 W9 H  `* v% E0 {

  44. 3 X: a  K/ S8 _, K& k* S/ p\" c
  45. bias tensor([0.0004]) tensor([0.0008])2 j7 f# O. w5 X

  46. ) z0 b: @& Y# ]8 {) p. ^
  47. tensor(0.0161)
    ! G\" |% E( I  q* k8 J+ Q+ }

  48. , R' N$ I6 _% J: J% l
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]]); ~7 b) B. J+ A! n/ i

  50. * d$ X% F\" G9 S+ z
  51. bias tensor([0.0003]) tensor([0.0003])
    3 D( ^6 M- H& |% S& I% E7 b1 b

  52. 2 W7 I* C% }( K
  53. tensor(0.0103)/ T7 g/ w. b/ k' o! |

  54. 9 P9 `* C% }  F
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])  Z  |3 a- O8 P, U6 Q/ A- J
  56. 7 E/ s7 D( s- S0 s
  57. bias tensor([0.0003]) tensor([0.0004])! n0 x' Z2 y8 C! {1 ~
  58. : H( X, o8 ~) _. |' T' m
  59. tensor(0.0066)$ m3 [7 ^% Y3 X
  60. ! y8 A8 y. j; b
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])/ l9 K: ~9 X; W

  62. 8 N8 t+ y/ S% B; V( u
  63. bias tensor([0.0003]) tensor([0.0003])
    4 b) E$ T7 ?5 |4 N, j0 |8 z6 O, j
  64. 2 ?4 C1 K' ^; Y- f# `7 @( U
  65. tensor(0.0042)
    & E, X% [* `. U. R0 S
  66. : ?\" n- l\" e3 h* k+ J3 u3 }
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])2 |+ E; r; y/ Q- D, T$ b' U

  68. 7 P' K+ `& L# k* R
  69. bias tensor([0.0002]) tensor([0.0003])7 n( z9 M! }. Z4 W1 J/ B

  70. ! C7 ?1 {4 v+ r+ P% ~. T
  71. tensor(0.0027). x\" a8 r4 H, |, R\" P

  72. 2 ~' ]- P, T0 i; R. L6 V' K
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])5 u  g5 P8 B2 A- p) q0 A
  74. * f\" j& G* f/ I1 J/ V7 _# q6 v8 r2 S
  75. bias tensor([0.0002]) tensor([0.0002])
    - l8 _6 S- i. f* U9 b
  76.   C, D# N; J6 w7 A5 D8 Q2 k- p
  77. tensor(0.0017)( s+ U# O+ w& Y& v! y9 O

  78. 5 o: ]% s/ n# c3 w\" }
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    & b2 w5 a& Q- H& D% R! g4 n

  80. / k4 S3 A* [- x6 V* ?
  81. bias tensor([0.0002]) tensor([0.0002])
    \" S* [+ ?+ ?# U$ ^) {, h% f% s

  82. ; x8 ~7 v2 Y$ t3 C
  83. tensor(0.0011)
    : T! y. S\" |1 m

  84. 0 @# {; c/ Z0 P4 P' E. t
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])- Z\" k* j. q7 K3 f! L% Y/ ?
  86. * ]3 z: g6 o: \
  87. bias tensor([0.0001]) tensor([0.0002])\" Y4 @. L; e$ A/ S# [% n1 w
  88. + j/ Z0 {( o6 v1 z/ A
  89. tensor(0.0007)
    7 g8 r3 t) s( B& B

  90. . T/ ?; x# U  u! S0 W  _, a
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])! N, s/ t\" r' a
  92. - Q8 w\" t; s7 u
  93. bias tensor([0.0001]) tensor([0.0002]), O* h/ o! J+ i
  94. $ u, [* Q/ r- i) t5 ?
  95. tensor(0.0005)% J9 k, H  A+ @* [2 B* J( t+ v5 C

  96. , R6 i1 ?: B1 C& c\" Q: [+ G' T2 t4 `$ `
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])) A\" Y0 `2 A6 B9 S; ^

  98. % Z3 ~2 |3 x- ?1 g9 g. F6 ^
  99. bias tensor([0.0001]) tensor([0.0001])\" t5 Y. U6 @8 |! ~/ [
  100.   B% r1 J) L+ l) p) \2 T% T# S) z
  101. tensor(0.0003)3 o6 M/ X\" @/ Y7 |+ x6 @2 d6 J
  102. 4 x! q8 f0 r& R. B2 S
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    1 k) w# O9 l2 m* y: H\" P* o1 I

  104. / D7 u. z5 [3 l% a, Q! o; R+ K
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    5 Y3 K7 B: S1 ^

  106. / E$ j/ O- j2 n
  107. tensor(0.0002)
    9 @! k( g  L1 Z# X; }\" o% }1 Z
  108. ( t3 r' Y+ }- `
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    5 b, ~' l# V7 Z2 [! |
  110.   A  n0 `2 q2 E( y* T; j( }& H$ W
  111. bias tensor([8.5674e-05]) tensor([0.0001])# Z5 D6 h# W$ K. ~

  112. - p* h9 Z' z0 {; ]
  113. tensor(0.0001)
    $ Y4 t3 L& [\" p
  114. ' B! b6 U* C( z, A/ o6 ]( F8 K
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])+ e) w6 E5 F4 b; P( S0 f# @' ]

  116. + m\" v. }. ^; ?4 t2 _
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])! s( Z( w' H# S: L9 X\" w8 J\" W
  118. 0 v9 b- l2 W8 N) R5 g0 p2 O
  119. tensor(7.6120e-05)
复制代码

  u2 o4 `7 i/ d
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-8-2 09:26 , Processed in 0.592614 second(s), 50 queries .

回顶部