QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2659|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1188

主题

4

听众

2931

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么4 P8 N: g' s2 ]8 L4 z4 {" V
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。* D6 c7 @. S: g3 ]* C6 O- Z' O1 B3 n
怎么理解梯度?
% e$ W- w  x4 @/ ~假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
# V0 g8 d  T" ?7 {, T1 H. U4 T) y1 m7 F) s+ b
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。5 w8 B, f. b% ]

! X! ^$ ~6 K8 C, M每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
# S. A4 ]% R* ]0 }5 i+ u! f- s
: i* N. s, W( I% ?$ T, D0 p5 d- j5 L为什么引入SGD
2 E- g! Y" b4 G- C7 O- a1 I  T深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
: r! q+ Y4 I9 _/ H  s8 i) @; u+ D# o- O* [) t7 v0 U" H
怎么用SGD
  1. import torch. e% ~7 g4 E/ B$ Y- q

  2. : F\" [7 F4 j; k; e4 g
  3. from torch import nn7 `+ c4 |& b  G: a4 `3 p
  4. : \  Q: C: J6 H9 c6 n1 [7 a
  5. from torch import optim
    \" T6 a\" H; `/ J9 t

  6. 4 k  ]3 H. K: J; [* D: N

  7. & g: m4 g( L& D6 Z; t& b) A* o

  8. . S6 X8 S* v) B: L. P
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)0 H. H, O& J2 Q. k) _* h: [- E, n1 L. R
  10. ! _: y/ U5 L9 I8 [' e  ]
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)/ d* d5 ^* o  c: e0 Q8 l

  12. 2 v' ]$ J* h( q$ T\" E
  13. + \0 g, p; p: A3 t! _! K
  14. / h8 D6 I0 y  s$ [& M5 A- S
  15. model = nn.Linear(2, 1)
    ' @4 W5 l, C# G
  16. - F5 X$ O( [/ q, m3 B( i
  17. ; [\" Y2 U6 G! m  y
  18. - j! }7 R% a\" M1 ?* D+ k
  19. def train():& f) O3 Y\" `0 [4 p; ^

  20. ; E+ k5 q+ v& ?( ?% k
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    , I# q6 D+ R% n: o( W
  22. 5 Q- s# H$ O& p- B
  23.     for iter in range(20):
    8 i: l7 R+ Y4 q( G4 e+ ~
  24. 4 d- J2 e! {7 n
  25.         # 1) 消除之前的梯度(如果存在)
    ' [8 a7 `' M/ Q+ l

  26. \" Z4 N3 g7 B4 a( {* ?# {
  27.         opt.zero_grad()# X* y+ Z4 C( O& i

  28. 1 d9 @# B; N5 X2 z9 k
  29. 9 ?! q6 @% b2 |& D# u9 E0 j
  30. 3 E\" k3 ?0 |1 ]: d, B! b
  31.         # 2) 预测
    7 v& w: R9 k. h# [; }3 c# d( [
  32. 1 A8 C\" a( U0 L9 Z' K- I
  33.         pred = model(data)
    \" B) x' Z; `) p) H
  34. 8 U  t% U- K! o7 `
  35. ' g: T- p+ r; e& v6 `& P

  36. 6 D\" P$ J; K2 K' W
  37.         # 3) 计算损失
      A% c% e5 o6 ]/ _1 d, f) x1 u

  38. ' i5 h3 J9 l' P
  39.         loss = ((pred - target)**2).sum()
    + F1 |0 ?: A8 ]+ V0 I

  40. 8 m' D: R9 }% v\" o6 |

  41. 2 X# P+ j$ z# [+ `/ z: v4 O

  42. 4 ^6 R7 Q' O1 Y) n
  43.         # 4) 指出那些导致损失的参数(损失回传)7 j+ l' B0 z& X- ~
  44. % w9 B6 v! p- r% T: ~
  45.         loss.backward()
    ! x- @( \, Y% P4 M$ l

  46. 1 B4 B1 a2 N' ^1 a0 P. s1 s
  47.     for name, param in model.named_parameters():
    5 g5 Y0 W( P! n' R4 b* V

  48. , z. U% P% H# p
  49.             print(name, param.data, param.grad)& S, w- c5 V6 t7 e. X/ O
  50. . q' G; B) t, O: Q# w) X3 @
  51.         # 5) 更新参数
    ; V2 t4 q* U! x4 ?

  52. 2 U$ M$ E& R5 m$ ^* o+ V- ]; t4 V/ `5 C
  53.         opt.step()& }' Z4 {$ _+ X9 ~; q, c
  54. 7 r; ^0 o  q: @

  55. : q8 B! O3 u# ^% w3 A4 [9 x

  56. 6 i( ~# P* P  c4 F
  57.         # 6) 打印进程9 R' w# [  [( @

  58. ' a1 M8 z- i2 V. G6 E# x
  59.         print(loss.data)
    $ r\" d/ b, o# z) F! X, T3 ?

  60. & c4 F* p; }3 D3 X! ~1 h5 }* {
  61. # E1 y1 Q8 \( `$ I3 t6 [' Z8 m$ [

  62. $ s2 L1 e+ f; c% _6 `2 a( |5 V3 O+ [
  63. if __name__ == "__main__":2 E5 l, A+ G/ W! O- I
  64. & K. n2 @* o+ a
  65.     train()
    $ J; G* I) v6 Q* f

  66. ) u: C9 r& F* X4 B& o/ \% f& C) K
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。0 Z0 K5 i) F# E9 D% V- r( z
& J: e5 Y2 U* N# o
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    # H% m0 i0 Y; l, v, A\" X

  2. 5 ~- T, f6 y( M
  3. bias tensor([-0.2108]) tensor([-2.6971])* v: O& }  U' h: k\" X/ A
  4. 4 l6 k. f- z- g1 l( ^7 }\" n
  5. tensor(0.8531)\" w2 q/ k/ W; Z1 @2 }0 B
  6. # y: _% H# O3 C- S7 k1 \
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
      A/ ^( N1 f/ l( m5 ?! q/ {7 t
  8. 6 m, ?8 _* U' h
  9. bias tensor([0.0589]) tensor([0.7416])
    , @2 ~2 l  r0 Q; H9 z; ]

  10. 6 R2 A* z7 @1 `9 S
  11. tensor(0.2712)
    ; m+ q; f% t1 h/ [) o

  12. $ p  g  A) x% ^
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    \" a2 y; N4 r& X' m\" Y

  14. , b* f9 S. Q. Y  Q& n0 X
  15. bias tensor([-0.0152]) tensor([-0.2023])
    - e# e$ d7 p+ ?9 {. s2 W
  16. * Z0 r/ b5 ^( p) m: g- q7 [
  17. tensor(0.1529)' K, |: u7 H. {- c

  18. $ ^; H4 p' l0 N
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    - n3 k5 s. ~& S3 c6 I2 C- t8 y
  20.   F9 m! [: {& p+ _- l. {
  21. bias tensor([0.0050]) tensor([0.0566])
    - _  n5 s7 h& W1 y5 a( ?, q

  22.   f5 I% D' s\" E0 E8 k  `# ?$ o
  23. tensor(0.0963)
    $ ]- [; @9 U) g0 Q. }8 H  f

  24. 3 R8 M! _9 g, E' d* X
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    6 p. A  E# N9 w6 c0 q

  26. % J+ B& |5 i1 t4 ?
  27. bias tensor([-0.0006]) tensor([-0.0146])
    : Q  M0 O! P4 N

  28. 7 f1 ~2 h$ k, T0 u5 C
  29. tensor(0.0615)0 i, `$ C$ g% M. N' O

  30. ' b' X2 x& J* P1 a, [4 K/ }
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])8 P  J% i- ^# y# \- X

  32. ) K+ Q* L, f0 `1 @
  33. bias tensor([0.0008]) tensor([0.0048])' x8 C5 Y& h\" U0 P2 K. x% V6 Q

  34. 5 _2 b+ w( Q- A  H5 c# L
  35. tensor(0.0394)) I, q$ J0 D4 {) X

  36. * {% X% c9 U6 Q9 k* @; [, p
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])5 L1 j9 e& z9 K; d9 [

  38. ! y) m; j) y9 D( t, \5 |9 k
  39. bias tensor([0.0003]) tensor([-0.0006])  o% _4 Y5 E; C; M7 ~/ t

  40. ( `& H0 k( N1 {9 j
  41. tensor(0.0252)
    7 V8 F\" K  L( o4 o- |! u
  42. 8 K3 _: |7 W& p) ~. g
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    + @5 [8 A6 E% T6 R: ]$ }1 J
  44. 6 F5 k# W0 j6 ^1 x8 X; |
  45. bias tensor([0.0004]) tensor([0.0008]). G9 L& T/ D1 P. H9 _

  46. 1 q9 U7 t# P5 ~4 z
  47. tensor(0.0161)3 s3 z. {6 j: _9 ~  X2 ]
  48. 2 g& m\" J  ^6 E4 ]+ |$ v4 F
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    : u- H' R* S$ E& |
  50. 7 B5 \, R\" g3 D% E
  51. bias tensor([0.0003]) tensor([0.0003])
    0 ~$ k8 i+ p! G7 _
  52. ) p& p4 c* |* U\" }$ j% f
  53. tensor(0.0103)! s) D$ P! A- h/ ]

  54. % l$ O8 h% v( e6 i
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])- I7 K( F% b' j0 _

  56. 0 e6 m3 v9 k# h2 Q
  57. bias tensor([0.0003]) tensor([0.0004])0 a3 Y) e) }: U! Z7 R: z  {
  58. 2 K% r4 v1 A* C* c7 Y- |
  59. tensor(0.0066)
    ) N4 s3 }& F- @; M  N3 }' K6 x

  60. - ]7 f5 P! r9 N
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
      ~7 m% F2 _; }4 C% N
  62. 0 l& I9 h% s/ A! e5 {  h7 |
  63. bias tensor([0.0003]) tensor([0.0003])
    # C4 w+ H* L0 x$ m! e3 y1 S

  64. * X3 Q4 ?1 [( Z9 i1 p. U( ~' J
  65. tensor(0.0042)+ a+ P. @/ y# Z, v: I
  66. , @0 S. C) L% t4 U- \2 G
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    ' s# K- l# \0 s2 i: i\" S% H
  68. 0 X4 D, R: N  h+ @) G+ F
  69. bias tensor([0.0002]) tensor([0.0003])
    1 S* g6 z, ?% G/ d3 O- L) A/ h9 r
  70. 1 b& h( k* {6 |+ {% }
  71. tensor(0.0027)2 s/ W. `. Z4 H! v% U

  72. / \- ~: Y( T9 k0 ^, s
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    9 ^- a  u( R* L\" A2 z
  74. ; e( p/ ^$ e9 d' L
  75. bias tensor([0.0002]) tensor([0.0002])
    9 V6 a1 a3 ~& e3 V
  76. ) m4 ^% q, S# A0 T
  77. tensor(0.0017)
    + O: K2 N\" o+ l) |5 K

  78. 9 C0 Q! n- h! y. d5 m- l
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    $ C' F+ ]3 V, L( s
  80. ( e/ R. W8 _. Q3 J! x/ z
  81. bias tensor([0.0002]) tensor([0.0002]): b5 C+ V  \7 E7 Y2 X2 T1 I+ b

  82. 9 O0 {1 x4 Z6 ?! m  \
  83. tensor(0.0011)1 E, w9 s* d\" n- L* \3 G
  84. 2 d8 l* j8 C\" Q\" d8 O: h# [
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    ; r! L% Z  X6 H4 |; {  r7 H1 n

  86. 7 W, r) U# j) l, z. W$ w) _
  87. bias tensor([0.0001]) tensor([0.0002])
    + c! W( j; x7 s$ j3 I% @8 E5 q

  88. 6 K& p9 U' @2 n# G  U/ n
  89. tensor(0.0007)
    9 b% t% D2 f! Q6 c

  90. 8 v2 w! k+ e/ k& m, S% o2 O. W8 C
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    # x9 O: {/ p' r# G0 j2 c3 s' y
  92. ! w3 k5 X3 W\" v% L8 C5 V
  93. bias tensor([0.0001]) tensor([0.0002])5 s+ d  _: A: ?4 |- p
  94. , g4 e% l  M$ n
  95. tensor(0.0005)
    , f% b8 Y1 m) O( s

  96. ( C0 F& ^  S4 q
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])0 j- V* p% q! H. X* P7 O

  98. 9 \4 }- n6 Y6 M6 ?; b$ T9 ~
  99. bias tensor([0.0001]) tensor([0.0001])4 I5 m2 Q# E2 W  {

  100. / [! X  I/ y6 i( o. B  e9 J; A; R
  101. tensor(0.0003)\" F6 g) W: v* B5 \5 o. ?: K

  102. # f' d. r' k\" F3 e! A/ p
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    ; C) K9 R5 C1 V8 o

  104. + v$ q8 O, ^  k$ G+ D
  105. bias tensor([9.7973e-05]) tensor([0.0001])/ \4 i+ }4 V, D

  106. 8 m, }\" Z  t% ], y: O) w
  107. tensor(0.0002)
    1 E/ U; h& T% S+ r2 @) ?5 J

  108. / N+ L; o# d$ \
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    \" a% u, {! ^' z! v7 E\" {/ u$ E
  110. ' i7 z2 I  u; ?3 m5 @7 I5 I
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    0 w. J/ i+ m1 A6 w' h0 I
  112. # T9 T2 l- u\" Y& O
  113. tensor(0.0001), x* S2 j! }$ F( C7 m  l
  114. 1 i/ s& ~& Y! Q+ a
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])3 N; m- b; ]! s6 H; Z

  116. . t* L; A0 O7 M
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    + u/ L6 {# I' S0 V9 ?3 g1 {
  118. 4 Q( R) ^/ ~, ]5 J
  119. tensor(7.6120e-05)
复制代码
, B$ Z8 }+ }8 D1 Z7 {8 D$ i
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-5-26 10:37 , Processed in 0.434010 second(s), 51 queries .

回顶部