QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2040|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1175

主题

4

听众

2823

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么, P* }- F& V0 M: c
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。% b0 ~" b7 N* [" D* |2 O
怎么理解梯度?
/ r9 b& Y/ j" a9 K假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。" C. x' l4 a4 N' g
! S5 Q4 m" v5 |9 p- B% `+ z
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
3 j+ p- k8 i! W9 m7 Z
5 g2 U: u0 K3 \3 N* U5 q2 z% j每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。+ x- e. Z4 k$ l

4 V1 H8 k2 m7 U; G5 {为什么引入SGD" {/ L0 B4 Z9 k. r3 k; l# a
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
6 N6 S/ n2 E! f4 s& d; q( j# X- q- y+ E! N3 k3 Z
怎么用SGD
  1. import torch) _% d3 R6 f- \- l( F2 ]+ C3 R2 d8 ?

  2. # A9 T5 t$ ^7 M( o# v) i
  3. from torch import nn
    : m, J7 o. w7 I, J# w3 D0 b) Q% P

  4. 5 E8 y5 t. ~# O( R1 S
  5. from torch import optim
    7 V. v0 T2 d8 W. ^1 Z) y

  6. 6 `\" U/ C* T1 t- @; l- _, U/ L8 r
  7. ' q4 K( d5 k+ l' ~% ~
  8. - ?3 n7 b3 ^\" n4 N% s
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True); u1 e( p' j) W0 x
  10. + B, ^0 a9 G& k9 N
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)9 j& V8 m4 [/ S' F6 O9 E
  12. # E. U/ B6 _; u. @+ t2 p

  13. 9 Z) {+ {. W$ K( x) Q

  14. - C. }5 W\" W* Q- T+ W4 y/ O( S
  15. model = nn.Linear(2, 1)) y! f& A) S( x7 b' r
  16. : \) }9 j6 K% S- [: d\" c8 O4 J- k3 b
  17. 8 K6 s# I# ^% `* [1 z7 T4 U
  18.   V9 K6 o4 f3 s+ ~; G
  19. def train():
    , Z* W$ M& U\" u2 |2 K

  20. 0 w/ K8 t* K, u
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)$ g1 L; _8 @9 k1 Z

  22. % N( P\" W$ @) l  z3 X3 J: D8 r
  23.     for iter in range(20):
    * ]5 T% W  Q# D8 \$ s: x  p
  24. 2 s$ N1 m\" K. m0 m
  25.         # 1) 消除之前的梯度(如果存在)
    ( M, l7 F2 f# H& |. a# {

  26. 9 U+ ?' P1 i# ]( O5 q' m
  27.         opt.zero_grad()
    - n& S8 y0 W; c1 ^. q8 `

  28. & g( J' L$ \( N) _

  29. . {* X. |\" u\" q, y( L. A/ |
  30. 7 C7 S  [$ r+ O
  31.         # 2) 预测
    , W2 {% I- Y, w( I; K  f4 E$ \) d
  32. 5 S' O# T- j- F6 [& v) d4 O\" a! L& `
  33.         pred = model(data)
    * e' k\" ?4 q' A1 k
  34. + o* X- p6 C0 d0 g) v% k, {  u- x
  35. 1 T( w- \/ D% x1 p

  36. & t, Z% n# ]! V/ g- p5 K
  37.         # 3) 计算损失
    8 N4 Y\" e: n9 L8 m1 g. O
  38. % b1 l' M. \' C0 H9 c4 f
  39.         loss = ((pred - target)**2).sum()- H: x: D9 |3 g. |

  40. 5 f' E% x2 [* p, b& [, \

  41. 0 L: K0 {& e  T% o3 x

  42. 4 {7 f2 m. M1 m
  43.         # 4) 指出那些导致损失的参数(损失回传)
    ( a! k& ~/ k6 k# k$ F

  44. * K0 k/ S: P\" u# H/ A
  45.         loss.backward()
    ) M$ k* l( F- E' i, C) u

  46. 4 _6 \\" z: {. q\" v
  47.     for name, param in model.named_parameters():  _6 a- q. \8 @; x  I& G. O

  48. % g% C! L2 n4 w: I7 B. A8 d
  49.             print(name, param.data, param.grad)+ r: ^1 t% e3 E; l1 @

  50. 9 b% r0 x- A2 ^9 E* m+ D' j
  51.         # 5) 更新参数/ |; `# s& [& s* U

  52. 5 b# [' [9 t/ \) Y& W8 p) [, A- v
  53.         opt.step()
    7 i; {/ C/ _  b$ e2 \3 W% T

  54. * c* y4 c1 c* W4 x

  55. 0 K% L* ?8 E7 W& X
  56. 3 R7 q0 m5 K1 N2 ^  |+ a1 }2 }
  57.         # 6) 打印进程* t\" B) i6 n/ C' K; K* d

  58. % d- y' ]' z% J( |8 a
  59.         print(loss.data)
    5 Z! d% N4 z8 k3 r& b% C1 `
  60. & i$ g: p\" g- ?\" o
  61. * @) F$ Y- A/ x+ q
  62. 9 O( e( D, x$ Y* n( I; F/ S- h
  63. if __name__ == "__main__":
    5 k. i\" i5 `9 N, ?* k/ Q

  64. 1 I% ]' N& y7 G( I2 |
  65.     train()
    ; a0 K* |\" I3 q0 h' @8 X: p% v6 g! W

  66. # f0 Y5 w0 p, l! L: J! W
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
4 ^% Z; S1 M) J: T5 C
' u2 j0 ~! [  D& g' t8 g9 N计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])  f8 `7 h5 t6 V& W

  2. ; H/ U; m\" S3 X) V7 C% K
  3. bias tensor([-0.2108]) tensor([-2.6971])1 p& x, }6 E' ?1 d/ B: H2 n1 R6 A

  4. 1 v3 T1 {/ r\" I- H$ A( ~) m
  5. tensor(0.8531)$ |% p\" j5 d$ a& _1 Y' ?+ }' X4 r
  6. 0 l& \- F# w( V( c
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])) g! ^& Z; |; {4 }0 R' N9 Z2 @
  8. \" h2 G9 w: v  H7 \  r
  9. bias tensor([0.0589]) tensor([0.7416])
    3 y9 e. o\" C) l, y

  10. . @% G4 N: Y/ [* R; Z* A
  11. tensor(0.2712)
    : [% d) U- j, y\" Y: D* w
  12. ( Y: h8 @  @0 n+ T! c4 _' ?5 s
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    ' h' i\" }3 G5 T7 I, \4 ?1 g
  14. ) i7 q# q( P7 g
  15. bias tensor([-0.0152]) tensor([-0.2023])
    : l* F: L- [: S4 z! ~

  16. 5 U4 H7 F# v8 H5 W2 d3 c* O
  17. tensor(0.1529)
    6 Y\" i: j7 {# z
  18.   _, k& l( b& B
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    1 ~3 U5 ]3 m8 A9 X- J% Q; U
  20. & v' |( p: a% ?5 b- F- S  s
  21. bias tensor([0.0050]) tensor([0.0566]). w: O* H5 e( {0 z! r% P
  22. / W\" u* v3 v$ @1 d; |
  23. tensor(0.0963)\" K$ r6 i2 ?; Y: G
  24. 3 l4 J, o8 }/ e6 E  c
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    ) R) ?+ w: K- C

  26. - l6 l; {: Q; k( ]% W' I
  27. bias tensor([-0.0006]) tensor([-0.0146])! F, w+ W- ?) q8 J2 @0 M4 J
  28. ' a: r; |5 x/ {1 [\" H! @* e
  29. tensor(0.0615)
    5 E; @. y- h  [$ l4 I7 n, w\" F
  30. \" ?' s2 r8 ?1 z. ]. A% d
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])2 y8 F0 E- A1 ~) [
  32. 8 ?; \1 ~- f' [; ?* x- [2 ^, I$ L7 \$ w
  33. bias tensor([0.0008]) tensor([0.0048]); K! r7 L' e; h( y( J\" S& b

  34. ; W- ^* K0 x! d# Z/ m4 d' F: S
  35. tensor(0.0394)- D! b\" v/ y7 f  Z# a4 s7 l8 v
  36. # [' g, q% c, t! G, L( M
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]]); p* V  {: X$ i% k+ G( f9 E1 y5 j* ~

  38. * _' s4 [\" f* f6 _
  39. bias tensor([0.0003]) tensor([-0.0006])
    ) G# T* Z% m; W0 M

  40. 7 O1 b1 [' [6 n$ L# Z4 N: E6 U
  41. tensor(0.0252)
    + H+ M; a' c9 e* t
  42. $ r# I/ T/ T: L4 q0 G
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])& X. t! ]1 i- s- x6 q0 x% a

  44. ; S0 l6 Z. h9 U$ B( G. }
  45. bias tensor([0.0004]) tensor([0.0008]). S: ^2 B# }6 D. J4 g$ M

  46. 8 F* E6 J: d9 _, V
  47. tensor(0.0161)
    % W5 |+ d: k' g) K: M/ i
  48. 9 ]* ~0 o+ d! k+ {$ k4 P* m- ^
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])' m) l$ M( w- x: X! T0 u# A
  50. 4 j6 ]$ I& W9 c3 T
  51. bias tensor([0.0003]) tensor([0.0003])
    * Q: _6 b# e: K\" O/ M7 N: i2 O: E0 G7 }

  52. # q4 L8 l( V2 @; N
  53. tensor(0.0103)
    - ^8 V2 W# J( b# S4 w, t1 h0 u

  54. ) ]( M+ Y6 n3 N( `/ ^' x- a
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    3 M, @+ X# _0 ?

  56. 0 u: B+ A2 Z/ O5 v; M; U
  57. bias tensor([0.0003]) tensor([0.0004])
    2 X- x$ U) r\" R1 r( O2 `/ u
  58. * l* @& J  ]1 s9 F% |
  59. tensor(0.0066)
    : o, F; B/ k: ^. P/ g7 S. C4 J
  60. * Q5 t9 U/ l( ~! `
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    1 f' [4 S3 t  ?0 u
  62. 9 W  Q# A$ H+ @3 `0 v
  63. bias tensor([0.0003]) tensor([0.0003]): b1 i* g/ |/ k+ }% ~
  64. 1 C% k\" |$ l+ E! L6 O- u. o3 i# p
  65. tensor(0.0042)
    & ^2 A& Q5 w+ `: t! A

  66. - g- k9 J, S5 y  P\" y/ C
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    ; L5 q; i, H3 G% i

  68. ' z$ a7 C. g- |! e( o) v3 s/ k1 U
  69. bias tensor([0.0002]) tensor([0.0003])
    $ S, N4 E* a$ V5 `5 y7 ]1 U# n
  70. 7 B) f0 x* j6 Y* p
  71. tensor(0.0027)
    ( A0 X) Y9 \* ]- ~' M
  72. \" V4 `+ `! k7 t5 X; o# _
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]]): ]+ _/ W; Q$ u& I8 S$ n

  74. 3 ^' ]3 R! I, O: J- |% m\" }6 W1 c
  75. bias tensor([0.0002]) tensor([0.0002])6 _4 X7 l  W3 |2 H% m+ }/ c

  76. ! ~\" W) W  w, x+ Q4 g
  77. tensor(0.0017)
    2 o. N3 d7 E; l1 E3 y; {\" v7 ], p

  78. # y6 o/ t2 D. \3 }
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])8 f4 D! d2 _7 o. S
  80. * C7 c4 d6 E6 z- E& e: d6 J: [
  81. bias tensor([0.0002]) tensor([0.0002])7 p6 F1 Q0 }% M9 I
  82. ( Q; r' Z7 z+ s  L) C( R! {
  83. tensor(0.0011)
    - ^6 Y: N9 p) N' _: N! t5 g

  84. ) b# ]/ I\" y4 s0 ]2 e) {
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    ! G( \9 h! N0 f

  86. ' N0 m2 I( O7 {; N4 f+ x
  87. bias tensor([0.0001]) tensor([0.0002])
    + E: p  [0 D) \3 {& E
  88. - a1 @+ o; J* q% S/ t
  89. tensor(0.0007)
    % k* o$ x) G' C8 s, h# g7 r

  90. 2 H0 s3 A# Y9 o( v% ]6 U) r* T
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    # N! U\" ~7 S, P4 @# @% j8 P
  92. . z+ y/ e0 ?* c8 k7 o0 T2 a
  93. bias tensor([0.0001]) tensor([0.0002])
    2 K3 {2 s  o6 S! \\" Z

  94. . g4 J6 {1 g\" K0 ^
  95. tensor(0.0005)) I9 J8 H2 l4 H* ^4 [3 {3 m

  96. : y5 d7 a0 c0 T3 Q6 x! `8 i
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    + Q: _: R8 [, `! k3 g2 s
  98. % G\" ?8 h- L+ c5 z* L( n! x
  99. bias tensor([0.0001]) tensor([0.0001])
    ; ~* ]/ w, N7 e0 v

  100. 6 x# @4 ^1 j1 q* j, ?7 p; {
  101. tensor(0.0003)
    % ~4 J- Q8 z7 `' p; }
  102. 3 Q6 P& b! D2 ^' p- F% K4 G: L
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    , B\" a) @7 h, D- D4 P. p/ B6 {% i$ V
  104. \" ~& y: f  L3 r; i5 S. n. d+ }
  105. bias tensor([9.7973e-05]) tensor([0.0001])6 B; ]/ D, B3 Z4 q9 x. y
  106. 3 \' S* L; }  t+ o, p- g% M\" Y
  107. tensor(0.0002)
    5 R  h% G. m9 ^, g! d1 g9 a9 I
  108. % B: S: q5 p5 X& R7 \: e
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])  G& y1 A/ g$ J( O- |

  110. 7 N8 L; [8 J) v' H) L
  111. bias tensor([8.5674e-05]) tensor([0.0001])9 Q) K8 @3 D3 i4 g3 e

  112. \" A& ^2 u% }8 P: m7 N2 n1 @
  113. tensor(0.0001)5 h% ^% \0 R+ Q, ^; [
  114. # c' o& W3 \7 |8 [4 h\" s
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    & W: y8 I% {, P, e

  116. # J# r! k; f4 c+ _
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    + n1 K7 Y7 @  u\" _7 I+ q! i

  118. , K  F! U: f- l3 @5 E2 f/ A' w* Z
  119. tensor(7.6120e-05)
复制代码
; Q# o' d6 l& q4 q7 t( O
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-7-23 14:28 , Processed in 0.450658 second(s), 51 queries .

回顶部