QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2309|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1176

主题

4

听众

2887

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么* L, b  w, ~, P- L& B4 c% H
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
7 g" Y( x( a1 K, ]% \/ ?怎么理解梯度?
$ _4 r7 {, e8 ?6 i$ }$ f' K* T: x假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。% u) s0 J5 G- I5 v  w5 B$ P% v1 b

8 ~5 x2 |! d/ @! W在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。7 u9 _% T4 ^- J' c, u
5 }  K! F  P$ W5 }4 h
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。+ \$ M7 k; U" |$ Y9 r7 B, U0 u# Y
4 }+ ]0 i6 q; t, o9 N5 D8 y
为什么引入SGD
3 F1 N4 c1 w& M5 g深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。; S3 \# `3 f: L7 e  h3 }* l! n

' e. i9 N% b# t; [. c$ |怎么用SGD
  1. import torch) Q& b% w/ |/ o* w! l  r\" f

  2. 7 y& U; c% e\" z& _
  3. from torch import nn4 T: @6 m5 Q3 D$ ^

  4. * {1 U1 U# u9 ?
  5. from torch import optim
    , G2 U6 c& `; f& |4 g8 n# Q
  6. ( `! @  a$ L/ e! x
  7. 7 C1 E( A- n, B: r* `  {% s
  8. / H# \! h  c4 b- A2 k* n( Q8 T
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)2 v3 s& T# b; e/ T- n$ [

  10. % N, S9 e8 X. `, O0 a! L
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    6 i% b4 b1 D\" g4 B+ L( m

  12. 8 o0 V& Z( r! l6 `$ s( a1 _
  13. * x1 @, a: x( a* G/ U# R2 I% o' H

  14. : g7 C7 D; D' k& v
  15. model = nn.Linear(2, 1)
    , H5 w0 H' F: f: _4 t- c

  16. : u4 ~( u/ G0 Y* _
  17. % ^) b- H! h; W# j9 k9 k' n% M. \

  18. 2 K+ z( i5 F' y- o2 g( J) _+ y' T
  19. def train():( |7 k( F% I; ~8 F3 J

  20. 3 O8 I  I9 V; y: n, [! s* [  O) G' E
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    3 q/ z* W% K8 t7 B1 v1 }' f
  22. 5 y\" v8 J9 [, n, }' p; Z
  23.     for iter in range(20):+ ~( d* A$ [6 u2 G  p

  24. * u& f1 s+ Q7 K- j% T
  25.         # 1) 消除之前的梯度(如果存在)% e% C9 {* O( I& z\" A\" d7 \8 O3 H
  26. / F, X3 K! r1 x( ]4 U7 s
  27.         opt.zero_grad()) V+ |2 C/ u* a
  28. 7 S/ L; \  }2 v1 Q; }+ I
  29. 5 V7 ^- l6 Y\" E. n

  30. 9 [( m) {. q6 }& z& \
  31.         # 2) 预测
    2 j/ n% o5 g$ T! J9 j

  32. - X6 i: f* ~, M1 f% }$ N; I
  33.         pred = model(data)
    / `6 v% f9 }1 j& r

  34. . H; e1 f1 N( ]4 H1 g
  35. 4 X2 j5 t2 _6 n: Q# C
  36. 3 \2 }, H$ R$ z
  37.         # 3) 计算损失
    1 G4 u* _0 i3 r6 d' G

  38. # d* [/ A3 n* _
  39.         loss = ((pred - target)**2).sum()+ t( U2 }+ m2 f
  40. ) W! o  m, o/ L# Z# Z0 [% w

  41. : a. @' j. ]2 }1 b$ I

  42. 4 w) L: u0 a3 b; ]
  43.         # 4) 指出那些导致损失的参数(损失回传)
    # c. D) S- c2 A& M
  44. - Z' J* o7 z1 k
  45.         loss.backward()
    2 ~0 q$ E+ O5 _% j/ T/ B7 ?% e6 _

  46. - t/ e& _' H8 U2 b/ R- e' G
  47.     for name, param in model.named_parameters():
    # W# N0 ~8 _) N9 h& V\" t; H
  48. ( a- A! M) s  \# P2 t
  49.             print(name, param.data, param.grad)& j. ^+ u2 r- K! Z
  50.   y. H0 Z5 B\" @; v& p
  51.         # 5) 更新参数, ~\" h  E9 k& ~1 d5 I& l

  52. 3 g; N$ R, J! O) m, Y: G
  53.         opt.step()
    , u* w/ e: W- H; Y

  54. % K5 ^8 V3 a: q; Y' \\" q

  55. + {( |5 ~6 Z: D9 c

  56. 7 e$ ?1 ~1 C0 E. u\" z+ W) \6 B5 o
  57.         # 6) 打印进程9 r) X6 J8 R0 g. J  B. A* y

  58. ' h2 M5 J3 b7 e! ^; j
  59.         print(loss.data)) ^% ?! _7 T) \& p! v* l

  60. 3 B5 g\" {. n5 i8 W2 F# p+ o

  61. # G: L# H0 Q1 H\" T* b, f1 ?

  62. ( N7 r5 u* o& G( d- o6 P: H% x1 Y
  63. if __name__ == "__main__":5 p/ y  s3 u' B5 t. ^5 n
  64. 3 o1 i! f/ r- q8 H& r% W* y+ K
  65.     train()
    5 i: L1 l; `9 _
  66. 0 f7 n; d' t* a  e
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。& ?/ ]0 X+ l2 H# F

& K% c3 M7 B6 \- _4 G6 W计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])% _5 e+ l) G8 M4 `! E2 J

  2. 4 ]2 s6 I' G( L& q
  3. bias tensor([-0.2108]) tensor([-2.6971])
    # @  M0 p\" @0 m! \4 M; l/ F

  4. 6 z/ Y) s4 l\" W# s+ [
  5. tensor(0.8531)
    # B: _\" a& `\" p6 M/ k3 n+ o

  6. * x9 D' ~) Q+ v# {3 A0 j; J
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])' O' q- t! h; v* ~% Y% _3 v
  8. : P1 |& T: X6 c# P# U( j
  9. bias tensor([0.0589]) tensor([0.7416])
    1 D* Y* Z8 k2 Z9 _  X% _% S- Z0 ?0 k
  10. ! n1 ~; X+ g$ I9 i$ ?) J$ q\" O
  11. tensor(0.2712)
    9 m% X/ S0 z7 k5 @0 g2 Q! w

  12.   V. j\" p+ s- ]( n) o. M  f
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    2 v$ Q. r! G8 _. `8 t* b5 _1 R

  14. 9 l1 s% }: \5 h) ]& f
  15. bias tensor([-0.0152]) tensor([-0.2023])$ _. Z* S. T: W: [9 I2 z1 I

  16. 9 Y+ l6 `- Q+ q: ^6 P
  17. tensor(0.1529)1 ^* I9 [/ ?7 {6 J# o6 `% E

  18. ' i3 T0 v& k1 x4 B7 s: B
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])* j+ D- l$ x/ T. E4 O
  20. 6 P0 m0 h3 h2 ?
  21. bias tensor([0.0050]) tensor([0.0566])
    8 O, u5 |) i& p2 J$ E

  22. # w% D% h  E! ]: E
  23. tensor(0.0963)
    ! P/ q' Q  ~' w  R$ n  N- q
  24. 8 X# _# ?8 P, J8 a$ U6 Y# C6 J
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])& T6 _, C7 C. N- w7 v
  26. ; ]1 C& z8 A2 A1 O, K
  27. bias tensor([-0.0006]) tensor([-0.0146])
    ( o5 r$ l6 T# S3 ~- i# }& x+ L- ~
  28. / ~2 g0 A* x7 {# u
  29. tensor(0.0615)
    3 C& B( X4 I4 E\" p# D; C

  30. ! `3 _1 a6 V# ]9 j4 F
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    3 e# I! c9 Q; [& P

  32. ( A/ r6 g5 v5 w+ V% U\" F, Z- Z
  33. bias tensor([0.0008]) tensor([0.0048])7 [2 _; h7 O, ?( |7 Y
  34. 3 p! s) e: `& s1 j\" G4 E! k
  35. tensor(0.0394)& ~8 x+ G) f- N' \) X- O, d) A

  36.   ?1 f( M( t) e1 _; b- C, h
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    ! m% {; }! g- f. S$ m7 ?
  38. 8 N. P; Q4 o1 L4 i\" B0 M# }7 e  [
  39. bias tensor([0.0003]) tensor([-0.0006])
    4 ?% X6 W6 o. q6 m7 `) d
  40. , u5 V8 |  S9 T; K. _9 `' F
  41. tensor(0.0252)
    \" \$ t8 s5 Q. k/ S# d. K+ W\" u2 V1 @
  42. 7 v) j  v- m0 u# e
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    . O1 ?$ p6 C& \+ ?, P$ e% R
  44. / f+ }% O, n1 ]9 \/ t! Z0 W6 T
  45. bias tensor([0.0004]) tensor([0.0008])
    % H* ~8 M- S# `5 h+ f* ?

  46. 8 |' I  S$ K( r7 P
  47. tensor(0.0161)
      U. r: G5 v; t8 q1 E. P5 H

  48. / R8 L' N  t* Q$ @- K; S
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    0 i3 p6 R2 W2 n+ y; |! z5 v9 ^
  50. 0 N$ ~/ Z) ]\" v& s) K# {
  51. bias tensor([0.0003]) tensor([0.0003])' j5 j* _- n! N8 f! C
  52. 3 A3 d( ?, K; V6 N  d2 e. W  i$ a
  53. tensor(0.0103)
    / y' N7 L/ \7 H* P. @- h
  54. 3 j6 f\" ^\" O! f2 L5 Y) I6 H
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])5 d, D  f: K: L( l2 J. G
  56. . I- A& @6 ]% S; i6 w
  57. bias tensor([0.0003]) tensor([0.0004])6 d* F0 \: X) Y+ ^9 L4 M

  58. 7 s2 W1 Z2 U3 `& h6 A) r) v! v
  59. tensor(0.0066)
    ; y# l5 G4 m( R% E5 X' k8 j4 s7 e+ W

  60. $ P3 P) e/ n4 G% J3 c
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])  l, q! V0 `; b/ }* [0 B
  62. 2 ?/ ~1 @  r3 o8 t- m
  63. bias tensor([0.0003]) tensor([0.0003])
    ; p7 H* @# i\" C+ E7 n+ A
  64. % D4 K. j) K& n
  65. tensor(0.0042)\" U7 P0 G% g1 t# o8 O
  66. 0 q% i! |# ?6 U! }
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    ! F9 t% M) @# u* J, W\" O
  68. * e  [& z4 {( r# N6 t+ F
  69. bias tensor([0.0002]) tensor([0.0003])
    ( z, `( t4 f. z* ]0 ^

  70. \" L( X' O& n9 l3 R* F+ }/ s$ ~\" W: o$ _
  71. tensor(0.0027)
    6 A7 H  Z: C5 C: z! w

  72. 3 E. ^3 O\" V) n- n7 k. S0 ?
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])+ t2 U& k5 X0 V7 A/ ^

  74. ! R$ l- h6 c7 R5 {* y' m9 Z
  75. bias tensor([0.0002]) tensor([0.0002])
    6 _4 q# [% T% O: w, n

  76. & h* o; X! B2 u! X# y' g
  77. tensor(0.0017)* Q3 A! ?5 H0 g$ D9 W- c

  78. 3 F+ `- U4 |0 s5 x; D9 A
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    3 j3 p# V7 N( l\" O8 F
  80. % ?( v& \- U/ \
  81. bias tensor([0.0002]) tensor([0.0002])
    ! O9 R. Z, J3 V2 h* i9 a$ g0 i
  82. + x1 Q' e$ s( r; K6 t\" \/ P
  83. tensor(0.0011)  T1 N2 u( F& F6 |\" g/ B1 [

  84. / g0 u8 K- Z& f1 V6 x
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    2 y, o. p4 C# J5 F
  86. 7 _# }\" l& A$ \' Y! l* }; a
  87. bias tensor([0.0001]) tensor([0.0002])
    + i0 |/ t/ K2 E# Z+ t0 p
  88. ( _& d, }0 E! E
  89. tensor(0.0007)
    : H0 G+ n- ?! l3 D* }
  90. ! v8 _3 t! P- Q- U: x5 v
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    : ], I* ^4 H  J- K* A* k' N\" {$ p! y

  92. ; d2 N\" ^$ y' V, b, `) a
  93. bias tensor([0.0001]) tensor([0.0002])3 s* s\" }- T' k7 p

  94. - \# J6 ]* \% v! Y2 b0 F1 z5 i9 l! N
  95. tensor(0.0005)
    % B$ f# X\" M3 Q2 s; `
  96. ' t7 D! H\" t* Q\" e
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    \" P+ j% t) N& q; \# Y& y& |! O
  98. 9 q, K/ D+ t/ e  P& {
  99. bias tensor([0.0001]) tensor([0.0001])\" F5 \( C6 x, Q( @9 G0 l\" d
  100. 6 R2 d, w- [: c. o
  101. tensor(0.0003)( m0 f0 a7 K& `$ |, a! j
  102. / x$ @/ g! G2 t! t
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])9 i% ~5 m: I1 f2 ]( v9 O6 ?% v

  104. 3 s( v+ J6 t- u3 u: p+ [* M
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    ) N0 z8 j' F2 f1 D8 \8 c! ~
  106. ) J/ Y5 |2 S' D/ w* |7 d8 s
  107. tensor(0.0002)* W/ S6 G, h  z, Q2 r6 X* L

  108. 5 o3 Z3 Q# c! p% A. e
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    0 ]; z6 }$ y+ d. k

  110. 7 g2 l: D3 c3 Z4 V( c9 Z; h
  111. bias tensor([8.5674e-05]) tensor([0.0001])7 }- j( F# ~, p( r- A' Z

  112. 0 [, Q9 Z5 U+ ]$ a
  113. tensor(0.0001)! E$ y6 D& J5 s5 i
  114. , R4 C# o  t) x8 ?
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])1 E! E' n/ h- o% q& G3 X\" U+ J

  116. 4 W2 Q3 J% j8 D5 K- [  C9 t
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])% H. S; x/ p1 l) ~; }, q

  118. # u( O! N3 h) M- ^# X
  119. tensor(7.6120e-05)
复制代码

# V) v2 g) s3 g; U# E: K3 U3 ~
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-10-30 17:03 , Processed in 0.627216 second(s), 51 queries .

回顶部