QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2623|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
" p0 A8 V0 {$ xSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。; t0 E! ]* r0 j2 D) Z
怎么理解梯度?
8 j! c9 K6 G1 [- p假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。, {: s. ^$ t& c" T3 O

; R( x3 K& M+ ]% O在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
+ F2 [0 h6 H0 L; o8 o0 T& O/ {6 F: [$ D0 s
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。6 S( D4 {* J* m  I5 W4 p
, g+ v' _7 T2 Z/ b' f
为什么引入SGD+ y5 ]6 z) ]* ^, f6 V1 O. Q
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。9 w6 O* c+ a: M( B9 p

8 G. b2 r/ ^$ @) B& J! x怎么用SGD
  1. import torch  Z! U\" ]8 ~9 o$ P# `

  2. 2 o+ P' p. N\" G3 g+ Q$ E
  3. from torch import nn
    1 S2 z( }1 T! j9 Z7 A5 v* w! W
  4. % k3 v; L( K. n2 s/ D! U% ^* J$ n
  5. from torch import optim
    & W, `# r7 n7 [/ r4 b, F5 |9 _

  6. & {- ^/ p2 X) |/ w
  7. 2 O) s! R3 W5 m
  8.   i, I  @' m; }1 V. B0 R+ H1 r
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    , f/ B+ ?' ]9 d8 ~! H
  10. 4 I3 l! q- ^1 u1 H8 `- a
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)+ J' k$ B7 O$ Z$ X/ }

  12. 4 n/ O0 J- m/ _  Y4 _: ?
  13. - Z! Z. z$ I\" R6 O2 e( d+ \7 `/ H

  14. 0 D) X, x6 ^  z& S& `6 ?
  15. model = nn.Linear(2, 1)0 }2 A1 J5 S% S
  16. / n. A1 E5 L! l6 m1 A/ _
  17. + v1 @0 J9 c/ h. e

  18. 3 z2 c; I( Y  u7 j! u4 U5 E0 h$ G3 ]
  19. def train():5 W7 c5 o3 g* H% k' _\" n7 T: l

  20. & Q+ V% l& c8 N
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)8 ], @5 K9 p) q9 w5 G  w
  22. 1 J% Z6 G& d: A8 g
  23.     for iter in range(20):% L8 B% ]( M; Z+ V2 k

  24. % `* j6 x, ?- v/ a
  25.         # 1) 消除之前的梯度(如果存在)  Z' ]' Y; d5 I. X1 `9 d
  26. 7 `2 f. W+ E# c\" z. x  ~2 x
  27.         opt.zero_grad()
    8 \6 E$ c. r9 Y0 M) L
  28. $ J% c6 Z. {1 ^( {6 n# [, X

  29. $ w0 H' N9 Q# e9 w8 j/ x
  30. ! X4 P- [, u3 @\" M0 L; e! `0 s. E
  31.         # 2) 预测4 c9 o) `/ C( V2 K. c5 s) G

  32. 5 p, y9 q' Q3 J. B
  33.         pred = model(data)6 f+ H% ^1 j) @. v) @0 ?( t
  34. , `( \$ l+ z4 m5 t6 d

  35. ' P1 @, {3 ~\" ?5 U* G6 q9 g
  36. * ~9 j7 f2 D2 k, J. V9 ^2 t# C
  37.         # 3) 计算损失( u5 n4 [\" x& m\" f6 N
  38. 1 O! \+ o' M9 u9 y0 V7 c6 g
  39.         loss = ((pred - target)**2).sum()# {2 j% V8 O  Z3 t1 U
  40. ( e6 S5 b# S; j

  41. ; V. S8 o& |. d* p! l, z9 f# Q+ Z
  42. - W# N4 h- w6 ?0 y9 i/ Z
  43.         # 4) 指出那些导致损失的参数(损失回传)
    $ |* y  R; e- a% ~8 V  @: s
  44. 1 }* C, h; T\" ~  b% n) ^% O
  45.         loss.backward()
    5 A% d: i) m+ q
  46. \" z, c6 x$ s0 {! z
  47.     for name, param in model.named_parameters():  I; @# l0 l1 `
  48. 0 o' l: T) k% U$ ?4 c
  49.             print(name, param.data, param.grad)0 P- N7 b4 u4 Y
  50. ) q& y$ c9 f! [1 P5 ?1 v
  51.         # 5) 更新参数
    $ z% Z4 E) ?- ?5 Q  ?8 n, c5 w. z

  52.   P/ C1 b) r2 H6 m; H$ O6 o
  53.         opt.step()) n  M, g; z! R$ c( Z

  54. * C+ q; m# f  T% a- ]5 p

  55. ! m, V  S# A. N7 S* ^  L
  56. ! [5 ^- N9 N/ O; |9 o
  57.         # 6) 打印进程\" x: R& E+ \; ~8 v3 @

  58. 4 m7 @4 l# F4 ~- e( O5 G
  59.         print(loss.data)
    6 t' P$ v* B& O7 u1 G\" [% K

  60. 7 O6 \3 {1 o0 O! l, p* u  ]

  61. \" X4 f+ C4 O6 n5 |3 v$ p. N/ K/ ]

  62. / r) |& b. q$ ]' b
  63. if __name__ == "__main__":! _, @3 M% V: M5 s  }

  64.   q* h\" Z6 ^* ?; \
  65.     train()
    & B& ]0 B- q0 {/ J- D7 ]- J
  66. 8 j# ~6 W7 b4 F5 d0 q! ]3 ~  D8 j& v
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
7 T2 S* C! t  T* U* }9 i& e  |9 C( `6 F+ n4 T
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])9 b- G0 g! }5 z) H
  2. 0 Y; L! }7 J4 z* D% d1 i8 J8 ^% p6 p/ D
  3. bias tensor([-0.2108]) tensor([-2.6971])
      n+ o- \; P9 W/ O( F
  4. ( S/ B$ G& s3 l- Z- w
  5. tensor(0.8531)* G( [$ J& x8 ~' S; a! Q, g

  6. : W1 d\" \  P. Q' i% [- a9 o
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    % e6 p\" u4 B( }' C% z
  8. $ o) C6 m( A; O7 Q6 I4 p* D
  9. bias tensor([0.0589]) tensor([0.7416])
    1 c! N. F4 h6 ?

  10. # l- V& Z. f/ t# _  a4 R, l
  11. tensor(0.2712); w' v1 ^) v. x* ]) n7 l# A

  12. ; c% k' |6 C$ m* `* w, n5 a, M
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])/ ^+ t, [; J8 w: E0 N
  14. 8 U. a; f1 I4 g) [9 ^& J  H
  15. bias tensor([-0.0152]) tensor([-0.2023])
    - @3 l& f8 n/ D  }! [& g! O
  16. # G0 X0 w8 |( C$ ]  r3 Z
  17. tensor(0.1529)
    $ E' h, D# a; g' k
  18.   [! M# W% g9 T: D- ~; O3 n
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    ) K& ~/ }+ a% U$ R. p

  20. 5 `9 Z: Y: D7 r, J7 V) o
  21. bias tensor([0.0050]) tensor([0.0566])
    , P2 \% W8 C1 U) o9 j+ l( {) H

  22. 4 v( J1 C) D& ~1 w) L
  23. tensor(0.0963)( V- H  h/ A+ I, D/ ~3 q- o

  24. 9 c/ Q6 Y+ z- \. u8 i8 g3 Y2 I
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])* Z' M/ n( w5 {\" N# n. F8 b
  26. % \+ S: A% ~( K/ C4 j
  27. bias tensor([-0.0006]) tensor([-0.0146])) P' x( Q0 i9 S; i: A4 m

  28. 6 B/ Y# @' @  `8 m. H
  29. tensor(0.0615)* h) k& w- o) H. Y4 U6 e  K* a! N
  30. % d( h- i& x  s
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])3 ?$ A% I( k7 r

  32. 9 ?2 ^3 A' W( d
  33. bias tensor([0.0008]) tensor([0.0048])$ `* j& `  [% j7 m- G
  34. ( a8 F3 v8 S: y8 a, r- H
  35. tensor(0.0394)
    & v6 S. \' t$ S9 `

  36. 5 C2 D' z2 \' J% Y
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    : [$ M% l. l' y8 i

  38. % n% }6 @* t% ]: `5 n5 E+ u0 U8 _
  39. bias tensor([0.0003]) tensor([-0.0006])+ \) e( N+ u* F, P$ G# n# Z
  40. . {5 H# g) Q, J! B9 F# n; X3 Z
  41. tensor(0.0252)
    ) e\" W3 \# k: O+ J8 }; K# y

  42. ! C9 C8 p4 G% l
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]]): [2 k$ \' x/ ?\" ~8 a! h

  44. 7 \4 ?7 l; O% B7 Q8 t
  45. bias tensor([0.0004]) tensor([0.0008])
    1 |: c\" F& y* R/ ~3 C

  46. : K1 [6 y: g- G9 ~
  47. tensor(0.0161)9 N7 ^8 O. d! E7 c) g+ f

  48. ; X$ R0 G\" _8 s
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]]), i% a; W5 G3 G
  50. & J5 `: e3 h% w7 R  Y# W! Z6 q
  51. bias tensor([0.0003]) tensor([0.0003])
    . k( }1 z. F5 v  j0 @* D) Q7 ^

  52. . o4 {7 f2 c: {
  53. tensor(0.0103)4 x0 g# Z3 |' c. E3 D' E/ R: P: ?
  54. ! G. {2 D& k* S- u+ T8 u
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])8 Q' Z& {6 j3 C& e5 a
  56. 6 u1 \- H7 N1 S4 K\" |# f7 @$ M
  57. bias tensor([0.0003]) tensor([0.0004])% l* g, l  Z- o* }
  58. : s* E% a$ [( m  f
  59. tensor(0.0066)% Y* w$ {+ [3 B- J
  60. \" a! a! W# J$ o3 j( b
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    . @% h& M% h/ j

  62. * w0 b. G4 @. e
  63. bias tensor([0.0003]) tensor([0.0003])
    : [; @0 Q9 O. r9 w8 u2 N; m
  64. , W& o8 F- a* C# h: S5 }- `3 o1 Z
  65. tensor(0.0042)& y  [; b6 z8 Z! E\" W

  66. 0 g6 y9 x1 ^/ i+ i0 E  n2 v\" d
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])! ?: y\" G+ j- Z( ~

  68. ! e$ J8 b- Y1 Y( a5 l
  69. bias tensor([0.0002]) tensor([0.0003])
    . D' P& w# e1 D# n
  70. ' B& p0 A4 V$ K8 q$ z
  71. tensor(0.0027)3 \\" m) t+ u4 a$ I! B% ~/ Q4 ]( R( I) Z

  72. + S2 G9 u4 W& B2 u: O& y  I+ Y
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])7 C: |) }$ P' M* z5 ^! a

  74. ! b& k* p. o5 t! l  {
  75. bias tensor([0.0002]) tensor([0.0002])7 s8 s: J- V\" ^8 g# B8 \/ E* I
  76. . O6 V3 U9 d* X4 n
  77. tensor(0.0017)
    5 D4 I: X6 L; w, Y) \; j' k( Z
  78. $ d. `; D- V9 `& w  U
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])) i+ r& Z0 }( W) n/ h2 P

  80. 1 h5 I3 [* P\" C2 @3 Y' z
  81. bias tensor([0.0002]) tensor([0.0002])
    3 B: G4 ^. Q& H7 [4 g% M' t

  82. , r1 o. {\" N0 [0 T0 K\" n
  83. tensor(0.0011)
    ' J% I/ y; O! Y' _9 v! [
  84. ; A- [* C2 E7 j/ L6 F
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])$ O' `# q# k3 Y2 w) l3 P5 n# x

  86. / f0 E; u% W8 r/ q
  87. bias tensor([0.0001]) tensor([0.0002])  A- h  F( k1 t$ e

  88. % E: H. f  I, h
  89. tensor(0.0007)8 q6 z  z4 h$ E
  90. 9 p\" q3 g9 g: ]2 m! ?6 t
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    ) Y& N8 V) M5 `% p: t/ ^

  92. / x. a3 [& j, j$ O( U\" _, ?
  93. bias tensor([0.0001]) tensor([0.0002])
    ' V. r, j\" t. k$ W  B9 G1 c7 X
  94.   R$ n+ c8 C0 N: c
  95. tensor(0.0005)
    6 ^& v2 P4 N0 T* y' b0 R8 }1 G. ^( f
  96. . I* b1 j8 h' S* a
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])4 k# l$ K2 m# E) m

  98. $ D$ q$ j/ L* w& E' T
  99. bias tensor([0.0001]) tensor([0.0001]), b; H; d+ b( ^3 F& h

  100. \" n7 F9 b7 e. w: P( F\" x6 t\" w4 H( ^
  101. tensor(0.0003): @. }9 [8 d5 t% q$ E\" u/ g

  102. : i2 {2 L. r- g/ s6 E
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])5 N5 X/ m8 P, t1 W

  104. ( r3 B\" U: I  V; r% }
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    : z. K! O' S5 A- f* m

  106. * N8 T6 R( H7 f. J: B
  107. tensor(0.0002)
    ' v7 \# F\" Z, d) I0 W
  108. & e7 E; ~/ d$ A4 C) r/ z, k2 U( B
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    9 a4 ?0 H7 Y\" `- q1 H- h
  110. & E5 D1 g8 \' w' H, y9 I! w
  111. bias tensor([8.5674e-05]) tensor([0.0001])& {\" P+ W8 u8 c. Y# V

  112. 2 P% @- l0 g0 @9 a) n
  113. tensor(0.0001)2 n+ G# a1 b$ p( u8 D$ X
  114. $ }- ]5 n2 P  d1 s6 l) V# L
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]]); _+ H+ K1 R) X8 }  Z  {& H0 Y* `

  116.   o+ Y, b0 d6 l, E+ M
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])2 _+ H8 q% d  H# j# m

  118. + Z1 L# v% O) L7 v( ~0 y* t; m4 _
  119. tensor(7.6120e-05)
复制代码
0 m' M  }6 G6 p# |, w& J" O
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-12 02:19 , Processed in 0.407827 second(s), 51 queries .

回顶部