QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2635|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么+ n2 w* R9 P3 P% a% t% u
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
% I+ `& f6 H# u怎么理解梯度?! V9 a: k: J' g$ [" v
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。. @( P$ \! L, m* D( J8 l

" O3 g. {/ g9 ?* H在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。3 D8 u; Z4 v% Y" w* K% D

1 [  C1 a# i* H* o8 W每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
* d" C! L/ z+ i7 z5 d  _  M, h" u8 Y7 X" F
为什么引入SGD6 [: Y  M% J/ h
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。0 t+ |! Y  `- W  V3 {0 y5 F

8 W0 ]% h' G/ k; R$ z怎么用SGD
  1. import torch
    # A: w\" V8 Y: Z1 K5 R
  2. # {  g0 r\" ?$ O2 R2 B
  3. from torch import nn
    ! A! C  ~1 l! Y0 t; n0 h1 ]

  4. # v! ^6 F: C$ K
  5. from torch import optim
    5 D3 `7 j8 h- W2 A9 y( L9 g7 e

  6. % F' K# E* i6 G0 F+ K5 C

  7. 6 \2 R( z! o/ H3 ?
  8. + H) f) y# e. `' |# p, x
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)3 s0 C  |' ~\" }- p

  10. 4 h2 u; W+ X: p! v$ C4 L: r1 j/ b1 u$ T
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True); L8 f$ F4 V& b4 D% d

  12. * k7 b) w5 ~1 ], D5 c% A  H! E

  13. ( r0 |# X) t% a; [  a

  14. / h6 h\" c7 L0 c7 U
  15. model = nn.Linear(2, 1)
    2 E7 |, t- v8 Y* U7 b6 f5 @  r
  16. 1 i& ^* o4 A/ q& W2 H$ n, W* z

  17. * T: W; ?. [: s* z

  18. ) y+ B0 M; X' O2 M
  19. def train():
    ) u* v6 \* m: a
  20. ; \/ n% G$ M: E\" U
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    % c, _9 a8 N\" q5 h, b6 N7 _+ r
  22. 6 }1 U& b% a: ^  c* y
  23.     for iter in range(20):; `( h  S9 |. \* t$ P

  24. - ?# R* {\" e* B$ m) m
  25.         # 1) 消除之前的梯度(如果存在)# ]& F5 p) M% q
  26. \" F3 P: D3 P4 K4 X
  27.         opt.zero_grad()
    8 L, b& _6 e# G9 z: v& ~
  28. ' |! B# a1 P5 q% n9 Z

  29.   e: i( V) L8 l9 T) L
  30. 0 O2 }7 t9 I8 J* _8 u
  31.         # 2) 预测) B, Q6 ?) T. F* O! ?

  32. % [8 ?- G9 ]+ v- S; }- b: ^' V3 j3 d
  33.         pred = model(data)
    , Y* a% N8 v! D\" s1 I

  34. # M# r3 Y$ s8 e/ T* R. o4 c

  35. 8 j& D& N) ~0 `' w3 V
  36. + o* i3 a7 Z# _
  37.         # 3) 计算损失
    ' H. R3 X& K) V8 Y1 n2 b

  38. - @* ^4 {\" R5 p1 Y
  39.         loss = ((pred - target)**2).sum()
    7 k% W  ~2 Z' g

  40. ! R$ y0 a& Q, s3 L2 |- W
  41. : G, F$ A5 j; ~* c

  42. ' `. D. r- ~' D$ r
  43.         # 4) 指出那些导致损失的参数(损失回传)\" w0 Q6 V2 o  V0 |  l
  44. 6 t1 o# G! {; [/ [% L
  45.         loss.backward(); w  C+ Z& C# Q1 F2 e% ]7 L, K

  46. 3 R0 A5 O0 w: }) `  K- M; q
  47.     for name, param in model.named_parameters():
    / y! S: r* `) n2 ^: S
  48. 7 p9 s! p8 L$ x& V
  49.             print(name, param.data, param.grad)/ j9 p' ?. U1 V5 A; ^0 r6 c* Z

  50. % N0 v  d- g: {2 U* ^8 Y7 `
  51.         # 5) 更新参数
    # U% J( @& i, _* c+ q
  52. & @- b7 r+ T5 ?  _5 [/ O/ ]( R- O
  53.         opt.step()
    * i8 A% N' i) ?& l7 e

  54. 1 w. x3 O! M9 x  w; s$ F/ G- X

  55.   I6 @4 l% J( X4 v6 f

  56. 9 u; E4 P7 F\" c
  57.         # 6) 打印进程4 Q# Z5 U4 i- a& g' _
  58. 2 f8 @- `% W0 B+ w! a
  59.         print(loss.data)8 M; k# E2 T& F
  60. ) U2 U6 ]\" w4 F) f7 L0 x, o6 Y
  61. \" p8 a( t\" u* d; R\" |3 r  e7 e! @

  62. 2 N% O& p9 C# c% B% ?
  63. if __name__ == "__main__":
    % Q2 w+ p$ [* A' ?2 g! s

  64. \" ]1 w$ l\" d( S- l) W
  65.     train()8 _0 I\" [9 K, X6 [\" U; R

  66. * x/ V2 Z. r6 E! w; X8 O$ K6 p
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。0 ?0 C# ]. h0 ~5 F* L, O
$ l" J* f1 ^- U6 k9 _; {0 ]
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])2 J2 H% d' H4 F- T0 O# ]
  2. 1 v+ Z+ D5 h' d% T* t4 W+ `
  3. bias tensor([-0.2108]) tensor([-2.6971]), F9 Q* w* }* \, E4 y
  4. 0 A1 a$ |* d* B
  5. tensor(0.8531)
    $ d- T% M. B( t$ X/ T9 P0 E. g
  6. ( q6 \. D( M* n6 L
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])7 O% O* I5 X4 ~# ~& F2 l; W
  8. * C( l; l0 a4 H6 r
  9. bias tensor([0.0589]) tensor([0.7416])
    ) F* e5 a' \5 Q5 K& [- g

  10. / h) C) c: q$ M( p
  11. tensor(0.2712)
    % U0 M1 }  U& _# b, [

  12. 1 t# Z+ x7 h1 d# @, W  W
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    + N- E) K7 _' ~- i/ n9 U- N
  14. 9 o0 Z, d1 q6 A1 W: \4 T
  15. bias tensor([-0.0152]) tensor([-0.2023])3 s' A/ b0 e$ ^# t5 m
  16. # [+ V. m- O\" }8 B: G
  17. tensor(0.1529)
    7 W! c, }! t5 x\" }% Z7 J$ T
  18. 3 W- e: a$ M7 _, H: l/ Y
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])- \+ M\" k: i: _% Q; c8 k3 s4 Q
  20. # b* K/ j( K$ N& @
  21. bias tensor([0.0050]) tensor([0.0566])4 y; J4 d3 q6 z# C+ s- m( D0 P+ W
  22. 4 c, A* q# ~\" X$ M5 m2 R
  23. tensor(0.0963)
    $ C, k1 n' j0 l3 x6 l* p1 |
  24. # g6 f2 p, U8 V! C9 `4 t5 J
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    3 _1 _1 A1 U+ k3 o; V/ H$ x  `2 q: j
  26. * I% y/ Q, N  ^& O, q. w, y/ |
  27. bias tensor([-0.0006]) tensor([-0.0146])
    7 m9 c* }. ^\" o9 p$ y- _

  28. : V1 K9 H% `$ a0 Y: E. A  F* E
  29. tensor(0.0615)- N9 U+ F' l0 A: T\" E

  30. ' n9 X0 \. G, y3 s  r, Z  O5 R\" C
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    4 u$ E! J7 y' t\" j3 x5 Q) q) P) Y1 T

  32. - ]* y$ |% B* y; L
  33. bias tensor([0.0008]) tensor([0.0048])
    : D7 M- l( m; o8 n/ \0 l, ?

  34. 3 o' J7 D  |3 x2 d) r3 x2 m5 [) v* t7 r
  35. tensor(0.0394)
    # y7 ?$ V# r/ N' O
  36. ( v6 T; ?7 l7 c% m2 a* n
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])$ w7 U. d8 B) J\" q) a. Z
  38. + G& N8 Y! r' m  {% l; E
  39. bias tensor([0.0003]) tensor([-0.0006])
    - O- n  a& z6 B) [9 h

  40. ( w: e% R4 u. y\" Z* u) X3 }
  41. tensor(0.0252)
    $ h% H& T; t! o' n$ M, r

  42. ! S' \  s) e6 E* [4 l& j- x9 P
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])+ U( f6 t: E- g: \+ x7 @\" A
  44. ( H, a. P, x5 V
  45. bias tensor([0.0004]) tensor([0.0008])
    ( `; N1 N5 c( L6 @
  46. \" A7 W+ D+ J' s; w! V
  47. tensor(0.0161): T* e# C4 c. i0 c* x, f4 [
  48. ( P2 P7 N5 B5 V
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    ' @6 z$ b\" \7 u0 r' u

  50. * B/ Q/ a2 v2 r. Z/ @' [
  51. bias tensor([0.0003]) tensor([0.0003])
    * Q% _# c. D1 F4 w, i, c
  52. : X# B\" F8 C( n! U
  53. tensor(0.0103)( Q- W& P, c) D4 j! V9 M! Q
  54. 4 t4 X. W( h1 h! Z
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])* _1 q1 x, K0 w5 z\" O4 y  e6 J

  56. & v3 u  q5 j& o$ H\" A& M
  57. bias tensor([0.0003]) tensor([0.0004])
    ! `, p% x\" c! T5 b& Q& E

  58. 4 Q8 X: G* x5 k7 K. D% o
  59. tensor(0.0066)
    % e6 C( q8 @7 W7 R

  60. - |: l6 g$ D% S4 p/ D
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])6 ~; L\" c0 M) _2 U( G
  62. 3 q9 d9 m1 H5 V8 P
  63. bias tensor([0.0003]) tensor([0.0003])% Q! N4 |) A$ i0 g3 ^3 w3 R
  64. 1 _: t# T3 C8 _
  65. tensor(0.0042)
    ( v1 `* r( ^% i5 z$ a
  66. \" U7 n0 |+ S  L( S% Y
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    % S4 j6 d4 ^% q6 |# {8 b: a0 R

  68. % U& i# u' \; s  d! [9 p1 x1 x
  69. bias tensor([0.0002]) tensor([0.0003])) G9 L* g# R* C$ F
  70. , T% u2 q7 d# p# Y
  71. tensor(0.0027)0 _( }9 h2 [0 o' U
  72. 9 J% M3 H4 u& o7 G* V
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    & H- L7 r$ {\" @* k8 c
  74. + r  H1 Q- P2 q. Q) a3 x, \
  75. bias tensor([0.0002]) tensor([0.0002])
    , T1 q! j) I' A5 M+ h
  76. # z5 `\" N- c, b4 C
  77. tensor(0.0017)
    8 J, v. m) g! _- {8 k2 ^' i* y/ r

  78. : G% p% m! V% [8 q' v
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    2 y% a' z: v7 n+ t4 V1 \( r

  80. % a$ T) j8 @; ]
  81. bias tensor([0.0002]) tensor([0.0002])  k- M2 ^3 `3 k0 r9 i; c/ W
  82. \" Q8 ~1 E1 M& k8 p2 `6 l
  83. tensor(0.0011)
    7 B- |' ^) }7 t! O3 ~. d
  84. 7 ?. S8 K& h\" o) T1 H2 C  N! ^) G
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    6 Y! Q7 f. }' a/ W
  86. 0 u- y' n9 M4 P$ {
  87. bias tensor([0.0001]) tensor([0.0002])
    . U, g* l! d) ^/ x$ T

  88. * }$ y, i: ], w. v! e. i\" g
  89. tensor(0.0007)
    $ G( Q: H  T/ e+ _8 L

  90. % u9 Y' B) h: L4 K: p( L
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    ) q1 D8 l\" N; o9 G# H9 i7 [

  92. 3 v4 Y$ j6 v$ E1 a3 g
  93. bias tensor([0.0001]) tensor([0.0002])
    / e) I+ s7 p/ E. ^, N

  94. , x+ C3 S& i/ J\" `. ]/ B/ Z6 I
  95. tensor(0.0005)
    + X/ i  M: m  A2 S+ e

  96. 3 Q8 p- H2 _+ G1 e5 U$ `
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])- f5 W4 N- W- Y9 m2 P6 K
  98. 6 C- V8 T/ A0 h& u
  99. bias tensor([0.0001]) tensor([0.0001])  O\" A; B9 `! e6 y  P$ B: j7 E

  100. ) L3 l5 R, f( r, `, P# F
  101. tensor(0.0003)
    9 T3 L: T- Z6 x( d7 s

  102. 6 d9 |  `, ^& C7 w& C% V& Y
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])* E3 l5 G( C# a& L  \7 ~- C

  104. 1 w2 E5 K- y6 |
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    . U1 F$ h0 u3 a' W2 H, \8 J/ [- s
  106. # K* W0 T. p: H8 X! h9 y( u
  107. tensor(0.0002)
    0 U) O- o% Z+ L! B' s

  108. 8 [. ?% f. L, y  z7 P- l% N) K
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])/ U# Z& d5 t: S! P8 a! J/ m

  110. # u% V: }( M, Q2 J& E) O
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    \" J  y/ u3 u; Q. \5 M- g

  112. * O) [' o% ?) ~- }
  113. tensor(0.0001)
    2 h. H1 r9 O, q9 o

  114. # `, ]2 d\" K0 j: `( i
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    # V8 c' U/ H\" P2 ^
  116. ' Z/ m% ~. Y3 [
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])2 h0 E% F0 G- M, c
  118. 7 \( v2 ^' g5 e2 x& p( n
  119. tensor(7.6120e-05)
复制代码

" L) U# U- S' M6 u4 W, l" a' o5 E
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-16 08:06 , Processed in 0.399219 second(s), 51 queries .

回顶部