QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2622|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
/ j7 s% j9 S* uSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
& E  Y: b  k+ g! N0 {怎么理解梯度?
, z# [4 A# n: G$ S假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
0 w5 v+ i8 ^8 `  e, D6 ]4 G; Z9 z1 ~
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
4 x! |, Y3 t' u; J
$ t% }* M+ N5 V: @" S, _每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
9 `! k: \* ?, |8 d, f
  M- q9 N4 L+ d5 }% J为什么引入SGD
& p" A* P3 d0 x. p4 l深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
  r% a" J' e0 R2 u) L7 V
' e- s' R4 h* n6 h# s( X. k# m* U4 b怎么用SGD
  1. import torch
    9 t+ ~. W( ^' U* z
  2. ; P\" v5 l! ~7 G# s) l
  3. from torch import nn. T6 ^4 r. D3 S/ s( k( ?' g

  4. 5 I+ G, P1 o6 s/ k/ M  ^$ L! B# v7 i
  5. from torch import optim$ B* E4 [# z% ]$ t: T- [
  6. $ A- g. w7 G) k6 _' V\" j% z  f; }, P
  7. : ]3 x% M4 X( r3 B- b

  8. . |+ T' W: N, ?
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    . T+ k7 C; x4 c* D! `# f8 U
  10. 9 H( z7 [. t  p) d! q  a
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    ) ]% E2 k: o2 g  @$ v

  12. / F8 _7 K* I, w3 ?2 t4 C

  13. \" g\" c+ V- j% D& v( L+ n$ Z* J
  14. 1 {0 C\" K' p- t7 o
  15. model = nn.Linear(2, 1)- V$ @2 J7 t' W2 a' V\" H: u\" v

  16. + K% x8 z0 @7 V5 i
  17. : Y) A2 [% k9 Z
  18. 9 ~/ ~- m5 l- }0 V$ d
  19. def train():
    - a/ @( C/ t% p6 e# j. N. {& X\" q. k$ P

  20. ! `# \9 u3 _\" J
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
      }1 a. I7 C% a9 L

  22. & S( p7 y# m  c5 R) q- Z
  23.     for iter in range(20):9 G5 n- O4 g2 L! `5 |
  24. # D- b6 c5 n/ S9 S\" V7 ^* H
  25.         # 1) 消除之前的梯度(如果存在)0 A( n* d+ m2 c
  26. 0 U1 j' D3 c5 p3 n3 \* D
  27.         opt.zero_grad()
    6 z) G# [/ O+ e9 p6 s% s' g  b

  28. $ {$ q- |- o4 R- I$ r& q

  29. * U+ \& J( C5 t8 c3 I5 A7 i2 Z9 P; O
  30. 1 H0 G( U5 U\" `3 A/ G, o
  31.         # 2) 预测
    : B8 _2 M6 \, p1 g8 e6 r

  32. ! p7 F+ y% H) b- i+ V
  33.         pred = model(data)
    4 ?# ]) X# R( L% J, U& R

  34. 3 L! c- y0 r) r
  35. # P$ P- e% B/ f

  36. % _$ \9 @/ [2 ]; {: X
  37.         # 3) 计算损失
    , L& [; \3 v  l9 d# W: [. k& x
  38. \" d. v* \- H- Z9 ]1 I# R0 O+ k
  39.         loss = ((pred - target)**2).sum()1 X( q- q, {% _  |

  40. ! q9 I3 M\" N1 M
  41. - ^; u/ J; E2 n8 A/ @3 j

  42. 0 N! v% ?! S/ l# E
  43.         # 4) 指出那些导致损失的参数(损失回传)
    ; Y. Z8 t8 E  o8 g
  44. * Y3 Y& l, d7 @1 T
  45.         loss.backward(), L/ Z$ U, t: T) k6 p( N0 i

  46. % C7 ?0 Y1 K6 p\" \; v
  47.     for name, param in model.named_parameters():: O; E3 ?: C0 {* [& }' d! D1 a/ F! C
  48. 8 C8 p, I4 L$ b, r$ Z
  49.             print(name, param.data, param.grad)
    ' W8 ]- H7 r; }/ q. l* ^

  50. + T- ^4 L# q! X\" b: S
  51.         # 5) 更新参数9 I' P; Y; x* B\" g/ t1 V
  52. 5 L* N! m# V5 a6 _1 P% |1 l
  53.         opt.step()
      J+ ]2 o/ {. c2 O8 I

  54. $ [7 |9 w& }: @

  55. : D: x! e1 W0 X! V. w/ K

  56. + ?+ A9 d5 l; d9 n% @' ]
  57.         # 6) 打印进程
    ' j: s9 I8 H& T) P
  58. 2 k; \8 F' ~/ Y  A6 l6 ?# `
  59.         print(loss.data)
    2 d  {8 N/ {3 R0 h
  60. \" K+ f\" t; }& r7 I' C

  61. $ W. \* ~( ?6 F+ x+ J  \5 L1 R\" B
  62. , Y, @6 M% o9 Z
  63. if __name__ == "__main__":4 W3 `0 w8 ~' y6 c4 |/ H# y

  64. 0 D$ M: B9 ~$ C) J) r
  65.     train()  L/ I. F0 \2 |& C1 r
  66. 8 X! e* J: w% p8 @* s0 V- i$ r
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
) H% N4 n9 x0 O" P0 ~
0 k- T0 x. ?& H; S" n计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])6 A9 t' V7 r- ]& t
  2. . W/ y& U/ Y- p- o$ b8 {\" o. T
  3. bias tensor([-0.2108]) tensor([-2.6971])
    8 l( l7 ~  F  s6 O6 _' H. d

  4. % V1 I% g  O3 \' T2 G$ N* c4 \\" V
  5. tensor(0.8531)
    , P4 A8 Z! x$ e  B2 M  i+ [$ q0 `

  6. 5 h$ T\" h  G: y, b7 d- L' V5 n\" A4 I
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    $ `\" p' b/ t1 _5 `

  8. 0 @9 s1 h; U8 T. x& |, i
  9. bias tensor([0.0589]) tensor([0.7416])
    + M- t: I\" T5 Y+ ^, k5 g
  10. 1 ?8 `0 L+ i1 G\" j
  11. tensor(0.2712)
    2 o+ V5 H! e; C+ M3 x- f- X
  12. ' ^, i3 A) f. h
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])- i4 M. E1 |  l\" P
  14. 1 G$ U# n3 I# f! V
  15. bias tensor([-0.0152]) tensor([-0.2023])
    , c; ~* W: D! o2 M9 ^4 D2 T5 x
  16. % B5 |  [# i4 J9 K% o! v/ x
  17. tensor(0.1529)6 ?' n, x( c2 e# ^4 C, a

  18. ! p. C, P$ s: v+ P6 ^
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])3 \6 A/ ]6 r3 J0 i, F+ G: g
  20. 8 r: u0 M6 x\" d  d* Q3 N& W
  21. bias tensor([0.0050]) tensor([0.0566])
    7 J, N  u5 ?% p# c
  22. ( y( \2 _% v7 O' n3 C2 ?\" p
  23. tensor(0.0963)
    . O) Q' n! @6 r/ P9 b6 M- o* z
  24. 5 F* |- t8 @, i- x0 L
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])1 L3 v' N! j: n, M5 e7 C

  26. 2 K0 I0 P. g  O# w; i! M$ }
  27. bias tensor([-0.0006]) tensor([-0.0146])
    - G7 Y8 `/ ]8 |: o

  28. 6 P# M# |! C\" i# I
  29. tensor(0.0615)
    3 r% s3 Y+ {  n5 O\" G* F% Y) S
  30. - Z/ x$ N! I: ~* ^/ H
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    : c, w% y# e& ?6 v! j! u; F

  32. 9 A9 F6 K+ R. n0 T9 q$ B, n
  33. bias tensor([0.0008]) tensor([0.0048])8 {( r7 L1 K9 r

  34. - E. @  {6 X7 z1 }
  35. tensor(0.0394)
    ' e7 t\" x\" S8 M7 G

  36. & q) }4 G( h8 d; {
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    ; g- P! v. u& g6 L

  38. 7 k# {. c+ [' a2 S
  39. bias tensor([0.0003]) tensor([-0.0006])
    2 J4 a\" }) d3 w. A% k

  40. \" z& W2 ]1 y- q6 z4 B. U1 Q
  41. tensor(0.0252)
    6 N( o+ D. [0 |5 j

  42. 4 o4 ~\" j$ f! k6 @  q! H: W! `\" \
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    - E5 u1 d, k- i0 V1 E

  44. 0 c1 F1 S1 U: P  Q1 n' {  a
  45. bias tensor([0.0004]) tensor([0.0008])\" E1 u7 I: E) e3 E3 v$ I0 r! J

  46. ! `8 U8 c\" @) v: f) d
  47. tensor(0.0161)8 u4 _4 k8 l4 V' _

  48. $ R6 _5 y+ ~\" i0 P  m# o
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])0 r5 u6 h2 e- u5 A8 u
  50. . |2 j4 F4 E+ E\" _/ O
  51. bias tensor([0.0003]) tensor([0.0003])5 i( l0 y, ?# ^

  52. + L( ]. v& n5 S0 d/ Q
  53. tensor(0.0103)3 q) [% Q/ ^1 _6 \. I

  54. 5 D4 W4 L- A0 U! T, b2 D! d5 s
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    , v0 q8 b9 m) M3 o' b, T

  56. 7 k, `3 t$ B7 F) X
  57. bias tensor([0.0003]) tensor([0.0004])7 p7 x7 @  I$ n
  58. 5 c1 B& O4 f! g5 h! o  P. a1 C
  59. tensor(0.0066)) [+ d# Y0 v$ M( T, h
  60. ! E) N* L8 X1 X; R8 y) V) H2 U; t
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])5 F; w/ c) M) a+ K9 d
  62. - u& B\" @0 j+ _, N; j& y9 G9 c
  63. bias tensor([0.0003]) tensor([0.0003])9 y- {% g/ t6 Q3 x+ H& d

  64. # g% r  [. I' v3 v5 r2 x8 t
  65. tensor(0.0042)& e5 R7 D( l\" h! B, x$ E( s# J

  66.   d& Q0 b! g6 k8 \! z4 y  D) X0 }9 C
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    1 z9 f+ Q! ?- q3 k
  68. # u. i( F( S* V4 u: N  z
  69. bias tensor([0.0002]) tensor([0.0003]). r' p  g! k1 l* |2 S+ j/ P/ W
  70. $ y( L- J6 {8 ]2 d
  71. tensor(0.0027)1 _( X# w* I. B' M' F* g4 w

  72. \" e3 e$ H+ \7 m- e; y\" Q
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])( S# f2 K/ C1 K8 b: J\" D9 P- f3 S
  74. * F* T' @5 P! Y8 \+ b
  75. bias tensor([0.0002]) tensor([0.0002])
    3 s/ K, T3 Z  ]( d
  76. ; J1 o0 `0 @, d6 F/ I/ t, j
  77. tensor(0.0017)
    ) z! e+ y8 x. X0 D9 u2 Y0 z
  78. ' v\" E, _! e5 ?7 z' m
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    ; b7 O6 p* c6 h: a

  80. , N# }3 p' k6 _7 d
  81. bias tensor([0.0002]) tensor([0.0002])# g6 m* q9 t- D\" B  O6 ^

  82. , j1 s& n, w: Z& R! m
  83. tensor(0.0011)9 B9 K+ r5 K; a# ?\" L

  84. + T  ]\" Z1 _; z
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])( k/ G5 V, K9 i: o% s\" ?
  86. - ~4 o* r$ t1 Y# ]5 F4 x
  87. bias tensor([0.0001]) tensor([0.0002])
    8 T& w  \# B/ E

  88. $ ]( z8 u- Y6 U+ |! j; h) W8 p
  89. tensor(0.0007)# z  I: h5 O: a, _  C' I

  90. / Z$ i- c4 u- x# e. c! _7 ^  s0 _9 a4 T4 x
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]]), N( F! v6 |4 I& [0 w* W1 }\" }( i

  92. + H) G& i1 W# D+ `1 D
  93. bias tensor([0.0001]) tensor([0.0002])2 s6 O/ O( g0 e( G

  94. 7 R. j1 n8 h+ n; [
  95. tensor(0.0005)
    4 G/ e8 s5 q- j1 S, G8 q
  96. 3 [- B9 ?7 _' e7 x- r2 \
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    ' Q\" D3 k. O\" \! o* ~. ?

  98. 7 D0 h4 U/ {# Z
  99. bias tensor([0.0001]) tensor([0.0001])  j: u: s\" p4 n) O5 @

  100. ! ^; I4 M4 N* p
  101. tensor(0.0003)
    : e9 q& o) c) m. V6 v

  102. 3 |+ ~; J) o$ H* {0 \5 d1 m
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    2 T# B: [3 i. m9 X

  104. % {* v* U5 a& V  I
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    & f2 E# i$ X, G8 ~6 Q% r0 Y$ G
  106. ' {/ u' O! d$ N
  107. tensor(0.0002)
    0 d; Q+ U( g+ n

  108.   i0 Y* u% \2 Z) E, \
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])\" V6 A1 I( r! e) S' o
  110. ! I6 _+ a  M8 s7 [. \2 q
  111. bias tensor([8.5674e-05]) tensor([0.0001])\" W' H0 U9 c+ g7 f- ?# o. T
  112. ; w\" x/ F9 F( q# r* n( c
  113. tensor(0.0001)
    ! e8 H/ I# m' C$ {. a4 N* |# u

  114. 2 g! f0 t5 _# q* u  W/ U
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    1 F# I* A- N\" p! ?' J8 g2 Z
  116. 8 o; c5 G; o& J# b  o4 b+ H+ l
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    ) a& T& N\" H( N  Z. W  X# h% B5 A
  118. 4 L  v$ B' j% W0 i
  119. tensor(7.6120e-05)
复制代码

1 D7 B9 [$ N5 z! t
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-11 01:46 , Processed in 0.296512 second(s), 51 queries .

回顶部