QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2621|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
! F% s8 B0 q' b3 V1 _SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
! g7 y& G- [; ?+ T7 g- H' b怎么理解梯度?
" E5 S1 V0 a5 L, B0 v& [假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
0 @9 [* }7 l6 G+ _8 @3 H
7 t7 Q4 Q, w" S" q在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。1 B% ~4 f+ V5 N" B7 r- B* _( j

' u$ w$ a* ^' X5 P每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
6 d+ }6 i3 C9 N! k
. ~1 T+ J2 v# y! N0 m! g/ [为什么引入SGD
$ y0 k. U) L  S4 x# u深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。3 f: U" B% s# s7 e# h1 Q7 l

2 p1 ^: W9 i' u怎么用SGD
  1. import torch
    6 n. L6 o4 x& ?. s
  2. # G/ e4 a8 p3 a8 K) N
  3. from torch import nn' B# p: _3 o6 R1 h* E/ k

  4. ' @4 g5 m, H; K/ x\" M
  5. from torch import optim, o& e, c+ i1 D; ?

  6. 2 c, z. i* a, t9 {' r, ?
  7. % }9 e/ I4 T6 `\" a$ ]

  8. ) i2 f. ]( r6 l7 T, F, x7 G! T7 p
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)' ~( J9 t\" h7 u' i! S: K
  10. ! a' ?1 V7 O) e- }: w1 o
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True): ?+ Z/ v1 S, s# H( U\" d& `

  12. 3 a# X4 C5 a- P3 {( g: w

  13. 6 U8 |% i* H. U8 f+ Z6 g
  14. 4 q0 f% E. F$ L/ }; q8 Z' C
  15. model = nn.Linear(2, 1)0 _' s7 @1 n4 k2 p+ p
  16. ; e. w6 q4 S. r! O0 e5 ~! Q

  17. % s7 i) P0 @0 [  W& r8 x

  18. $ w4 w+ v2 z+ `/ F
  19. def train():
    4 d7 G7 x( g9 j- Z, z
  20. 4 M) D5 m. G3 M% z* s) T; e
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    : p2 s4 K7 j+ ]4 u% f$ H% D\" S
  22. 9 H5 ~4 h+ [  O2 V+ T
  23.     for iter in range(20):, x0 y. t8 z2 l6 N0 q& ^
  24. ) f9 K) @\" x5 M* `. w3 R% r
  25.         # 1) 消除之前的梯度(如果存在)1 Z, F( q/ W\" y
  26. / L& S* B2 M; r5 f\" y6 W' x
  27.         opt.zero_grad()
    , y) I\" j) z' r8 t0 s. ?) |4 H; t0 q
  28. / H% D% [0 h. T9 y6 h* z\" c
  29. , n! U% ~! r0 p$ P

  30. \" d3 y- g( r  E, n
  31.         # 2) 预测
    % g- i9 O! q( \6 f
  32. ! ~( ?' E# {& c$ {
  33.         pred = model(data)+ j  {4 Z, W7 S  Z- r- I+ ?' z

  34. ! G& i* B2 @' \+ H% h\" B6 G# i
  35. , M5 e* t; \6 F0 H; A- n
  36. * E/ K8 v* I8 a  u
  37.         # 3) 计算损失
    9 \+ g3 b  `( z7 n! _7 M4 }# \
  38. 7 `% K4 a6 L9 _, P' G
  39.         loss = ((pred - target)**2).sum()
    \" [1 P! a8 S' J1 J, k
  40. 3 w0 ~1 B% ^! t1 n- O* G! w

  41. # I, i+ I% x# q8 C9 {7 J

  42. . i6 y% P9 x; h8 U
  43.         # 4) 指出那些导致损失的参数(损失回传)
    # M+ I& x# g' ^5 M% y  ^* g
  44. 3 t7 `# i) X3 Z\" q$ f' c
  45.         loss.backward()
    9 O) O. l0 j9 f# e

  46. : p/ l, {* |\" P& x. R- p
  47.     for name, param in model.named_parameters():
    - I: L. _& N# u1 A3 S

  48. % b! r( i4 A6 k& c
  49.             print(name, param.data, param.grad)
    , w8 l/ Z- m7 J# {

  50. : Q$ F+ ~. F6 w' Z5 z( M
  51.         # 5) 更新参数
    7 u* D, B5 n, L& l+ W( x

  52. 0 M0 F8 o* `6 h, p, E
  53.         opt.step()7 {1 _- R: Y7 E' _

  54. : c9 _) l% ]& d9 j( D
  55. + i: t* ~) E$ v' I

  56. 9 F- ^) |, M  L0 {+ o' c
  57.         # 6) 打印进程
    ' l' \, Q1 d5 P7 p1 l

  58. , I/ W0 s, X  I9 t0 s2 c% A
  59.         print(loss.data)
    0 V) a. \* z8 N

  60. 9 V! `1 _) _8 F

  61. ; @; l$ F* E, m8 I

  62. ( n1 t( N. W7 S* v+ i% V( F% H
  63. if __name__ == "__main__":
    ! T+ v/ Y6 ^* d, w

  64. * C7 w9 O' R, a\" a8 K
  65.     train()
    ' k* g0 P+ b) y1 T  ]! N4 r

  66. 0 l' `; `9 [4 o# e4 }( u, @- F8 e
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
8 N  p) }- i8 Q2 @* J7 d6 O  |
: v) ^; Z# t& C; n# @计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])% |/ y, }* {- [\" X2 h) Q
  2. / b1 p$ T1 E3 Z' t! `# c3 ~* S
  3. bias tensor([-0.2108]) tensor([-2.6971])( w* p2 f# U! f& \( r
  4. ' K7 ^7 b, R3 z' R: `7 s
  5. tensor(0.8531)  H  o: w7 b% b$ e
  6. 9 X/ C8 H2 e' N  H! J
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    & y' ]' \* e6 P  B# U7 B

  8. \" N% x3 a. I: }6 W* a
  9. bias tensor([0.0589]) tensor([0.7416])% S/ O) Y' d, F. U, J

  10. 9 {4 A% ^) _5 o
  11. tensor(0.2712)& Y; q\" X# K2 Z$ z5 B  n
  12. - `- j3 G% |1 S9 P
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    ; {2 w: n$ v. D

  14. # {9 a8 l1 q$ v( a0 a) \* f$ ]
  15. bias tensor([-0.0152]) tensor([-0.2023])
    4 p9 x! P) j3 e; k8 L7 Y* `4 Z
  16. & z* B+ z* T6 j, |
  17. tensor(0.1529)# U! j9 W\" f# Q! A& Y: v. Z- F' o

  18. ) Y. L\" k3 o# o
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])$ D5 I# H' K) R+ R
  20. 8 p/ ]* M$ N2 {4 `) s* x: |
  21. bias tensor([0.0050]) tensor([0.0566])
    ; o# r; p* r; \5 f0 T

  22. 3 y- }( {0 l$ I5 u
  23. tensor(0.0963)
    1 ^& o$ i/ V2 D

  24. 4 b9 z8 {: p  J
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])  H0 w4 I9 Y+ ~: R
  26. 6 ^0 Z+ E\" {1 `
  27. bias tensor([-0.0006]) tensor([-0.0146])
      f/ o6 |  V6 D1 _' [  i0 r

  28. ( k0 S; G3 U2 k7 }9 h) `* h) n, S0 @
  29. tensor(0.0615)  n& G' d* s+ G$ s
  30. - p\" i: R& ?5 F6 X# s6 A
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    % w, `% t7 m* D. v1 s

  32. 0 g: p1 r: _9 K4 @; |/ F- ?: R1 x7 @
  33. bias tensor([0.0008]) tensor([0.0048])
    9 @\" e- }# `4 c$ Q. O
  34. ! R2 r; @8 ?* ^. H% s3 t
  35. tensor(0.0394)
    3 |0 T7 ]; O- v5 I\" }1 S

  36. 4 L# j; J  V2 {/ h
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])\" k& K% r* j6 S; L2 u- T% t& I
  38. 3 p6 b, N2 o1 i& y. M
  39. bias tensor([0.0003]) tensor([-0.0006])
    ; `8 G3 \  n  {( |# O- S

  40. * X9 C0 ]( `0 t- n8 e  x
  41. tensor(0.0252); _$ n& P* N- A: S
  42. ' n$ [, O0 r( B7 I7 }) R2 }
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])  M\" ?) E& c2 n. h+ X; A3 A( r$ ]! T4 E
  44. # D1 N5 E\" {- p+ |, Q6 R, }; C
  45. bias tensor([0.0004]) tensor([0.0008])
    + B9 k# I7 X  G) `- ]# \

  46. , G7 Q/ }! g& k5 t+ S; ^' ^
  47. tensor(0.0161)
    \" o) M' `8 L, I8 L: x# d0 q5 Y; [
  48. \" P+ _5 W# M3 J5 [) P) v$ `
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])( @) q7 o. M  N' ~

  50. ! y: N& J: i' O$ E( H' I# N
  51. bias tensor([0.0003]) tensor([0.0003])
    2 e\" \$ v6 h% @( Z2 H1 x9 Z

  52. ; i3 V% Q! E3 [, Y( ^; q
  53. tensor(0.0103)
    - b& R& R7 A* K1 u1 O2 C/ A# Q8 F

  54. 9 ]/ p4 s, |5 o* o* _% r
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    ( W! ~* y6 a( h% {  D
  56. / y* y5 ~& W( c9 [! x/ A
  57. bias tensor([0.0003]) tensor([0.0004])
    4 b9 M; R2 e- u9 B

  58. : _- n( f- v1 E
  59. tensor(0.0066)( p\" ^, l: n- g7 B! L4 ]! `! Z6 Y5 N
  60. 5 }) c& Q\" ]3 g& p
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    3 L  j- R# x4 z( ^, R/ E2 W

  62. 9 B: b: R- G( Z- G( Y0 S6 `
  63. bias tensor([0.0003]) tensor([0.0003])/ N+ y8 z0 p5 M  c) y5 ?3 M) u6 j
  64. ! |% z1 l% |* E- T3 D: L8 r, l
  65. tensor(0.0042)
    $ c) a9 j) {) H& o! o
  66. # @% |# J. M/ S6 m
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])5 g3 j5 u3 d2 i9 P; d0 p4 I- p

  68. % o- d0 w, M2 O4 v8 b
  69. bias tensor([0.0002]) tensor([0.0003])
    % a8 W- V\" z% V) L( g  E
  70. $ j- S% Y2 L8 ?, a( _7 r
  71. tensor(0.0027)3 X6 w/ s0 ]2 T  S# \: T) R! [

  72. & ~  k/ k6 j\" {) |6 N
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    8 j, P/ O( Y) V2 a3 L, u

  74. + ^5 J% j1 A/ r4 P
  75. bias tensor([0.0002]) tensor([0.0002])! }4 h$ {' V4 l7 H, V- A

  76.   S5 i5 U# B* M* [* ]. y' M
  77. tensor(0.0017): o9 F& }7 {, R5 h

  78. ) b- \+ _4 s: \1 G  t/ O
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])& ]- \# \& ~3 U  l8 Q! @

  80. 8 ]9 ^% I7 ~2 [3 k$ e4 k) F\" L! J
  81. bias tensor([0.0002]) tensor([0.0002])5 {. ^: n6 p/ f/ v8 N3 s

  82. / w4 O) e0 {8 c2 a
  83. tensor(0.0011)5 Y( g/ B+ B# }

  84.   E: z6 @4 L2 j; s
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])- P0 ~1 |* r1 s! B- g

  86. : k* a! V7 j& L; z, H, \) d) R
  87. bias tensor([0.0001]) tensor([0.0002])  V7 t8 ~/ N, E% m# }

  88. . B( g  C# o9 @6 f5 T+ s
  89. tensor(0.0007)
    4 ^0 D. [3 s( K6 I) X  l

  90. 4 |) h- _/ T0 V
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])# T# P3 Q* {; M3 I3 ]( Y& {! L  S
  92. + }  Q# c( ^+ C
  93. bias tensor([0.0001]) tensor([0.0002])/ b8 }( V4 ~& S0 }
  94. - g8 A4 j/ g* b+ k# z
  95. tensor(0.0005)
    % K# n( G3 [1 ^6 i
  96. # S- Q) T% u! |( Y8 I
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])& K4 w' d5 N/ o9 W
  98. ) D7 e* j5 A( b/ D9 k' c
  99. bias tensor([0.0001]) tensor([0.0001])9 p0 a, r9 d. |2 D\" p
  100. , ~7 ~$ ?% _0 {1 @# s9 b; F- C# l
  101. tensor(0.0003)
    . n% f6 d8 c7 }( r

  102. ' {# {, {6 \\" B* j- V9 U( }6 z
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])1 ~  l\" c' t7 S\" h6 T. O4 P
  104. 9 K0 t2 X, {. y& O
  105. bias tensor([9.7973e-05]) tensor([0.0001])* \7 |4 o$ B& o

  106. $ ?8 _% |! v1 k
  107. tensor(0.0002)1 J4 N1 @3 e$ g( `! Z. @; n
  108. : \8 E- b! J! B( D
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])6 c* a+ e: D4 `$ R\" m- R\" l% T
  110. ) c( W8 ]  ~' ~+ ]( ?* H
  111. bias tensor([8.5674e-05]) tensor([0.0001])5 w. o( ?: L# D3 @$ s+ I
  112. * Q8 g/ h' Z7 c# V
  113. tensor(0.0001)
    & k& e! J9 `! p+ w8 G

  114. 7 o( A% R  g% w5 v* ~; j
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])) U\" q% s  R5 \) y+ N( N

  116. * A: t2 r: _7 N8 \$ U( [. K
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    $ \0 |4 Z2 _# M. ^

  118. 3 K9 E; X! P( ?% E
  119. tensor(7.6120e-05)
复制代码

( p0 h/ ~# Z6 |( b+ P
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-10 15:35 , Processed in 0.339380 second(s), 51 queries .

回顶部