QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2310|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1176

主题

4

听众

2887

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |正序浏览
|招呼Ta 关注Ta
SGD是什么3 x) w7 r: ^: \5 G
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。# O, a3 |* C$ N+ ~
怎么理解梯度?
, D* \- l5 z  p% D  g# T( I; `假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
/ w6 g& s1 x7 L; Q+ B& N0 {) K( n+ p0 f
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。' w( P1 ]  n4 ~. I
. P% P. B* M  R, Y. f
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
0 L: m/ W, B5 ?7 n" s. E4 P+ A9 b' ?0 u) b: V
为什么引入SGD
0 C/ p2 K. |7 _9 c5 ^5 L$ ^$ y+ e+ S; F深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
) k' F3 K% y! F% q" {9 x2 |5 Y9 K; a
' m1 o% K- l1 W) o; X! f$ _" Q1 @怎么用SGD
  1. import torch
    * J+ o4 J$ M. U3 _. o

  2. + j: Y% @$ F  g/ v* Z
  3. from torch import nn
    , s( D$ s: f2 R

  4.   t; e; Y5 [  T1 r8 E2 R
  5. from torch import optim% @, N' k2 z( _  i' p% p: v
  6. ) S# c4 C% a5 R- H9 p+ F8 ]; T

  7. 0 F9 {6 C$ R$ S3 e8 N  v; A3 o) L, T
  8. : i7 A, `4 J6 o8 I. Q5 `
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)5 _& c! J  `/ [, ]( l

  10. ! T6 v5 K8 q7 t  k9 v& \( B
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)$ _6 P. L% u# E0 W. U& \9 \/ Y
  12. 9 S! W5 n, X  Z
  13. & i2 R  `, @' G* \9 H1 [

  14. : @& U4 m+ V\" c4 d! n6 {
  15. model = nn.Linear(2, 1)9 Y\" K0 ?! j' @! d6 Q
  16. : p9 m; h. B! O0 t

  17. , y( k9 U4 X' F$ v: G4 M3 i4 k

  18. 4 J: j; Y1 u& [( |2 e% h
  19. def train():
    , I4 d. @5 E7 B) ~; ~6 m5 h
  20. 8 p+ y- u& D7 v+ [! l
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    4 ]2 \, @- G( D+ q: s3 v; `& A

  22. 5 y6 K7 S: P1 P
  23.     for iter in range(20):
    ) l' o% A7 V) k6 |5 p
  24.   C: w4 {; r\" b; v! `
  25.         # 1) 消除之前的梯度(如果存在)
    $ n3 h3 e( _  F+ g$ b. L2 A
  26. 3 O7 t+ `( {, _) \- B
  27.         opt.zero_grad()2 q) z1 U% W3 z+ v6 s; E1 `

  28. # n3 P( O* y& J9 m: j( C7 d

  29. \" V- q4 |& e% f* J

  30. # d9 x& o0 L9 F' D6 H
  31.         # 2) 预测
    $ _2 j* e6 R/ h5 l1 e% Z  u

  32. - V; s( i0 P\" t) ~( R$ {! j
  33.         pred = model(data)
    , ]$ m. v& q8 V( r

  34. ) s\" p7 D# J. F8 y

  35. ! o( j% R% e4 E; T- {6 Y

  36. ( Y+ F* N, \9 z; U
  37.         # 3) 计算损失( `8 A$ q8 j  q7 n) N; p
  38. ) j4 b- c% O  D
  39.         loss = ((pred - target)**2).sum()
    & E/ N  I( ]: B' O% d( u' b
  40. 6 }1 K' h9 F0 Q8 |* j

  41. ' K' k% H4 u% {\" L

  42. ; y# k4 |0 `/ y/ G6 H! X7 t
  43.         # 4) 指出那些导致损失的参数(损失回传)
    5 ?$ l6 D; S6 }. [7 f! c
  44. ; ?9 n; A6 c& A7 P# `7 e- l\" O
  45.         loss.backward()0 ~, |4 {+ |% _7 m! t. P( T9 ?

  46. ( w/ H' I% j8 o; W2 X! a7 ^! y
  47.     for name, param in model.named_parameters():
    - C, j\" ~7 r$ x5 J$ K4 I# r  L' j
  48. 5 {# G$ _: U. Q, J3 X- N' X8 \! J
  49.             print(name, param.data, param.grad)
    % \5 Y1 S) J! c# \7 p) y
  50. * b) i1 B, d4 f. }3 w. l
  51.         # 5) 更新参数
    & w, h! c; _& z  L: a
  52. : C6 g  r4 P' T: J- \
  53.         opt.step()0 \: `: _$ `4 \/ w% k
  54. 0 j6 c  c6 ]# ?. I

  55. + w) v% }) ?$ P/ h6 e/ ?( y0 T6 e  r
  56. 1 N( v, x) o) ~  m; i
  57.         # 6) 打印进程
    ' Z' U1 S9 d6 X) ^! d

  58. 7 \\" t4 u\" s2 z  d7 P7 c5 d' x
  59.         print(loss.data)
    1 e\" S3 w8 y) w& y5 l9 `  U1 p
  60.   ^( V9 G1 j% b/ t) t3 j' y- U

  61. # n; g) q: @( r\" }7 }+ _3 l

  62. ) h\" p4 A% V( T9 r5 L1 d' w  i
  63. if __name__ == "__main__":! g* `; H/ b$ V$ G- k

  64. - K: Y  T; r* C( |- V# S
  65.     train()
    7 [9 ]' K3 A' b& e

  66. 7 L. @6 b6 z# [$ L! [* x: N$ o& O0 p
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。3 a: y/ [* I$ J" y
6 r3 J6 t5 \! K7 K5 n
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    ( F- b, V2 M! G$ L
  2. + f: t+ B' e, f$ e, E
  3. bias tensor([-0.2108]) tensor([-2.6971])
    $ |! p3 A# J. q# i4 w

  4. 1 j# W6 @  ?/ r\" t# t5 A
  5. tensor(0.8531)' z: ^+ G3 t) Y) `2 ]

  6.   ]) E2 ?6 E1 h
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]]): ]5 }! E8 E$ O- C! g. s

  8. ) [0 w9 o, x* P6 d1 b
  9. bias tensor([0.0589]) tensor([0.7416])# x3 l' ~* z0 x, X* Z; _
  10. ) h' Q$ u- u/ @6 K) i4 q\" m
  11. tensor(0.2712)+ q. B) u  A' o+ o% R

  12. # ^7 C- t2 \/ [9 ]
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    3 p8 A  |$ ?5 V  |7 Y4 o

  14. $ k& ~7 j4 |; Q
  15. bias tensor([-0.0152]) tensor([-0.2023])
    % G9 i\" t( a\" B

  16. 8 M; T3 ?  ^2 S2 M  }0 y7 k
  17. tensor(0.1529), z5 g( d$ P1 l3 D( ?
  18. / \% j' ~0 m& b: J
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    1 r4 x0 a' M+ N' ]

  20. ' E\" }3 N\" p& ~& D
  21. bias tensor([0.0050]) tensor([0.0566])4 n! r7 m7 g# a& \( n

  22. ! |3 E. D: d7 m& E6 h: k
  23. tensor(0.0963)5 D( L- b; N5 v0 k9 ?0 e: I9 b4 t8 M# n
  24. ' t- A0 z  u6 O* T2 }' t
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    / l* Y6 D! H; m7 g\" W
  26. 1 ?, O' {5 c6 d9 |5 b. t5 l
  27. bias tensor([-0.0006]) tensor([-0.0146])6 q- `2 ?* V4 L; w
  28. + E* A2 @. R+ g7 b: D\" @7 f
  29. tensor(0.0615)  x7 r( A2 _6 N

  30. + ]* S4 [( L: w5 l0 ~* |; w8 v1 P
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    5 L9 m$ `$ g# J/ W0 c6 N8 W

  32. % v4 U6 K  M$ U0 t1 H; W% g
  33. bias tensor([0.0008]) tensor([0.0048])/ ~6 b\" @4 {7 D% R: {
  34.   R1 X: n& p8 \
  35. tensor(0.0394)
    - t+ V( p* R6 W# ^9 _' ]' I, k

  36. ) m# }! i  K) v6 o: F- F) Z) ]
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    6 e! S3 j! d: R# V# \

  38. : b. W+ b+ I: `
  39. bias tensor([0.0003]) tensor([-0.0006])5 E! c  s. s$ _+ Z) s2 s
  40. : x8 @* l; i: ?
  41. tensor(0.0252)  b' w: {7 R- M) Y: E4 P
  42. / @6 P5 C' f3 W* y
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])* Z1 c  L) S; `; J

  44. - M) _5 X' r' k$ r5 S
  45. bias tensor([0.0004]) tensor([0.0008])
    * n4 K  v+ g) i( X* {( U

  46. / V2 g. d3 a7 ~\" _; Z8 |- _
  47. tensor(0.0161)
    . U$ U, z: s: I
  48. 6 P' u* C* g7 e+ _' w
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    6 v% |- u\" d/ H! J/ C

  50. 2 r! K: j. f\" {6 D
  51. bias tensor([0.0003]) tensor([0.0003])
      `* T8 S& S* F' @/ @7 v* u1 I

  52. $ I$ v( m4 G3 H- b- I
  53. tensor(0.0103)/ T/ [9 m  g# p0 o0 z. b/ i; x- x
  54. 6 J5 p+ R7 o6 c
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])7 N# n) f  }\" H# B5 E% _
  56.   U/ T1 S8 {# x
  57. bias tensor([0.0003]) tensor([0.0004])8 K\" b( r+ v/ V8 S5 y: P4 u

  58.   E! b! c# Q, c8 j
  59. tensor(0.0066)% h2 J- y/ |, V0 V+ {' J
  60. ) q1 \  D\" w) o5 N; O
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])\" Q( ?' A: V* D, V9 G% {, M5 d7 d

  62. ) p\" {4 T, h- I: ?
  63. bias tensor([0.0003]) tensor([0.0003])
    % C- l2 C- q; ~0 }$ G

  64. + d# Y$ {# j7 G* b4 l
  65. tensor(0.0042)1 j. h# q; ~! Q( }
  66. 2 V- p3 o6 b7 l( L+ }! i& W
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    5 X, N1 z' c3 x2 M+ d' k% z\" W
  68. * X0 f4 C) @4 J
  69. bias tensor([0.0002]) tensor([0.0003])3 ?+ J7 Q  H+ S5 G
  70. ! D) {! j6 F* E6 L. n
  71. tensor(0.0027)
    9 Y3 {' G& P, i3 Q* S' N6 ~' o

  72. - W; O; Q# p4 |\" [
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    3 N3 e+ s+ s4 p1 q
  74. 8 e7 V9 P* V' K+ f
  75. bias tensor([0.0002]) tensor([0.0002])
    1 }$ ^% q\" R* I
  76. + w: h3 a/ M. a! }! \9 U' o
  77. tensor(0.0017); x% k4 ?4 ]1 C/ S3 ^$ ?\" _+ G- u  Y
  78. % Z& V3 e$ C- V* N\" w6 _: j
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])& Z- a* k# U/ s5 Z\" x9 J4 T
  80. 4 i$ M4 U+ p5 G& w  m: J
  81. bias tensor([0.0002]) tensor([0.0002])) k; ]' y( `' z

  82. 3 `3 g& ]5 k% L/ g, J
  83. tensor(0.0011)- m- @$ M( r' @9 r- b3 w

  84. 5 Y4 f5 _( S0 h1 e4 \. @  p
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])! e/ J1 p) \5 F3 f* Y

  86. ( G8 p; O9 X: t3 o5 l  x( P: _
  87. bias tensor([0.0001]) tensor([0.0002])
    6 y3 E) G/ o3 I- a2 `7 J' S! c9 @
  88. 4 G  ?' k1 Y  p. {) X6 m9 i- }: `
  89. tensor(0.0007)8 p& H9 o7 O. a, k( m. @4 I+ u
  90. . \4 x0 N) H% }! N* ?% e  G  x
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])' `+ Z/ l. [. _# ~6 W

  92. $ J6 i$ m) j$ i4 d- h
  93. bias tensor([0.0001]) tensor([0.0002])
    . {! O; ~& A% ^\" K8 y2 L
  94. - c$ E: J8 @0 B. n7 \
  95. tensor(0.0005)
    . n; M- N6 y5 Y
  96. 0 u' M2 }\" _8 @* [& R1 \& U
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]]), W' U! b+ u+ }! R' ^2 s1 S$ E

  98. 4 f% _- y0 v1 g4 R) I0 D% a
  99. bias tensor([0.0001]) tensor([0.0001])
      O4 Z6 _\" F3 o0 E  P2 P- }
  100. - T; V/ g' |* W, q1 q4 e
  101. tensor(0.0003)
    - N7 P9 L! W& N  z

  102. ; }5 [$ e4 o# {: x! G- [
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])* y9 }) @' i' x: V+ i
  104. : W\" g8 R, H9 \% q' @
  105. bias tensor([9.7973e-05]) tensor([0.0001])4 d4 W8 c& D/ M4 |$ `9 R2 H1 x8 R

  106. 3 U\" X# f, G4 ^2 I% K$ f
  107. tensor(0.0002)/ k4 d) T' s9 O( ]0 b0 |
  108. 3 p' e+ A' Q$ w. n! Y2 J
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    2 w& G. p1 K$ |% H

  110.   D: v7 u) H2 g/ L( |
  111. bias tensor([8.5674e-05]) tensor([0.0001])0 U) O# q  s. H+ c+ n

  112. & F- N( _! i- j/ K1 p9 b4 R9 k
  113. tensor(0.0001)
    & e! T9 C4 {\" j: ?0 x( F& K

  114. & f. h0 \5 ^, G- v/ |: r( z+ G
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    / O& G. D: H; V: U6 y5 r* t\" y/ @. i

  116. 8 J6 R3 x. m2 _; x7 ~
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    % m9 G* ^9 ~* V; u6 V
  118. # r8 A/ `& a0 K9 M$ Q
  119. tensor(7.6120e-05)
复制代码
7 D3 `( Z, X$ P* i
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-10-31 04:51 , Processed in 0.598056 second(s), 51 queries .

回顶部