QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2658|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1188

主题

4

听众

2931

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么- g! r8 [& G6 t+ W
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。; H9 c8 Y& c6 f4 G1 C
怎么理解梯度?0 S5 v" \  N, t; ]6 ~% `
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。! f1 x/ @" ]+ n, E" H0 G

  ~3 i5 |/ r4 i; G在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
0 M2 m0 }; _$ a" {: N* s
& T7 E* o  V+ a  i" t每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。* P; {( v! u' O4 t1 ?: I  [

4 n( T8 N' [4 ^" O- [$ F4 ]- s为什么引入SGD" j+ l3 d2 u0 }
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
  C3 E! l/ E8 P& P/ _/ B, J, e' q5 [/ p+ T. O) r7 h' P- j
怎么用SGD
  1. import torch
    : D  t$ G, ^2 w
  2. 3 J4 R; I, Q% ^9 g
  3. from torch import nn! s5 h8 {' h; N: D5 u: ?
  4. / G5 @0 b( x0 a: A
  5. from torch import optim
    . R- b. S7 j0 q% B' V

  6. - u8 y- ]& Q& Y; y& a' k5 I
  7. 4 u3 A3 o9 c9 H  S, @
  8. # h% x' Y9 i\" z* E' L! e3 C: _+ H
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)7 q9 G: H: C% k! |, w9 \( S

  10. 4 Y\" w1 u/ O5 m/ U
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    ) d% k) L! z& }9 y0 p

  12. , w% h\" k) G  y/ j4 O' m. y! l

  13. \" `% B\" o  f% R2 c
  14. ( B$ C& q# O! g9 }8 Y
  15. model = nn.Linear(2, 1)
    ) O; b0 [; a; N2 O
  16. 3 ^8 I; G- m* p/ Q9 M1 d
  17. % j) k* J* v# i$ i- k/ T: z

  18. 3 t\" R6 T% |& D5 E
  19. def train():
    \" c  o$ z( e- F9 D( e( o; q
  20. * i4 w. k+ p' P: r8 B5 s& N
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)& y: }. B, E, H- ^+ y, T: N8 g9 R6 i/ g

  22. 9 N9 Y' w8 P' S! M( s( K* P  F
  23.     for iter in range(20):
    5 y  }7 Y, f' }\" `; t6 @

  24. & ^/ Q- P$ o/ S2 ?
  25.         # 1) 消除之前的梯度(如果存在)
    . x2 I9 E\" [$ U

  26. % ?1 q- d0 t7 R4 G* G: K
  27.         opt.zero_grad()
    ) ~\" }  W3 h) z; o. ]1 G4 L
  28. ( I3 |( O* K' ?$ f. s3 y
  29. 2 e: T, D5 f% K+ `. C- m5 V
  30. 2 `2 ]0 u$ `0 q8 ?# }& h
  31.         # 2) 预测+ U2 \4 ]# N$ \5 k0 I
  32.   l' v' \- ?2 {% l2 j1 i9 s! t
  33.         pred = model(data)5 J. C7 d; j  [, a, D
  34. 4 d2 g8 a' b0 `9 W: a
  35. \" s/ Y' q6 o  p+ c# X& ^6 _: u  o% g

  36. ; E! u- Y# \' V# V7 q- i. h  m, l\" |
  37.         # 3) 计算损失$ M/ z. G5 R8 V# T
  38.   W; b2 y' y, Q  L4 Y
  39.         loss = ((pred - target)**2).sum()
    ; i8 J- d\" A& L& E3 J5 k7 v

  40. / d9 k5 ^+ S& D  j6 V
  41. ; \+ K# ?# K3 T0 L! E

  42. % R6 f6 Y5 d3 X% a4 v% m0 O1 Y( m; P
  43.         # 4) 指出那些导致损失的参数(损失回传)6 q7 n5 A8 D9 F2 C

  44. 3 Y! C\" s/ ^$ n' h5 a- [
  45.         loss.backward()
      |0 d9 `3 ^- ]% x& v& g+ l

  46. 0 v+ T7 `' }7 ?2 d) ?1 e* [; R) U) q
  47.     for name, param in model.named_parameters():
    / @  o2 j  @% W# i' E1 z; V6 C2 W

  48. . R; S8 |1 e' w3 Y: M4 s0 u0 Z# t
  49.             print(name, param.data, param.grad)  o9 ?; {- U1 _\" j! u& G( O

  50. 6 h; K( k/ E3 r7 ]3 d3 E% V/ f
  51.         # 5) 更新参数- ]+ |3 r( x- j8 |
  52. \" j1 Q; s% x- |4 Z; a) C
  53.         opt.step()
    ( q4 P& s! H; m  h- U; V' f6 F
  54. 0 P$ y. b( f1 @$ f( l
  55. : e. K- w6 T3 m: u
  56. : s4 d4 W3 ^  g# _
  57.         # 6) 打印进程
    ( w# }9 l, V: i# M% u6 E

  58. / {% G7 v/ N2 A- ]  {- r* c2 A7 s\" b
  59.         print(loss.data)
    2 e5 g$ W2 Q3 x

  60. 9 ]# X/ g- c) x6 V& n! |; S

  61. # Y, G) d3 A* y8 c

  62. 6 }( W( D1 j- |' D/ }
  63. if __name__ == "__main__":
    ) y( h# r8 L+ F( J# Q' A\" z: i

  64. % w  \; E/ _* m8 A# E' Z$ p# r
  65.     train()& [3 X1 G. D1 M! z( w  F& T

  66. ( K5 D0 L: y8 G0 ^
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
/ E0 A8 \7 s: M* u" q- f' A
5 W. ~# w9 z( P* B8 M9 J% \计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    9 o7 j% W+ ?: E

  2. , A6 P( m. ~$ x$ g+ M0 C$ h
  3. bias tensor([-0.2108]) tensor([-2.6971])8 R2 q3 {( U/ u! l# a

  4. ; O9 e9 K( |$ j1 Y% I, V
  5. tensor(0.8531)
    6 ]# m1 m( t7 w- A+ b0 x
  6. 4 h2 v* Z. i- X1 i! e
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
      ]' Q) V; J2 {! v; ?
  8. 6 o. @' J/ G& p% `1 k& h+ n
  9. bias tensor([0.0589]) tensor([0.7416])3 Z) l0 s( l) g2 O4 J0 x2 q\" ?
  10. / x\" |7 N$ z' z7 M5 f8 J7 _
  11. tensor(0.2712)1 m; i\" [. o1 n% X4 P; C' L% ?\" M9 |

  12. : {  l9 a  K; a/ q, v
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])! a% n, q, L( k7 X, c0 T\" Z

  14. ) z( ]  Q; k+ J$ J! t
  15. bias tensor([-0.0152]) tensor([-0.2023])
    7 ]+ e. Z2 F+ t( A/ h  t6 i

  16. ! ~9 G) Z5 Y0 M
  17. tensor(0.1529)
    ! M+ h+ G9 h9 P7 d! W, ?* v

  18. ) V: s5 \' D6 y\" L7 u4 q
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])* y5 P0 P. L5 ~4 A9 {( n& f

  20. + e3 b; a  O* r0 \# \7 U# k/ W1 p* ~
  21. bias tensor([0.0050]) tensor([0.0566]); O. X/ L# c) Z8 {6 Z9 ]  }; f

  22. : D5 H/ Q5 c$ N8 ]
  23. tensor(0.0963)/ i3 S/ q8 b9 R\" u\" u7 p  X- ?
  24. 2 a7 ~9 n$ h3 B6 r
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
      I8 ^\" a9 a  ^+ }# q6 K9 [& l& ~\" @

  26. ! P\" u7 |* T) F5 p
  27. bias tensor([-0.0006]) tensor([-0.0146]), L' S# \' F0 ~: G2 `; f5 i

  28. 1 p2 P1 X+ t+ F& O
  29. tensor(0.0615)
    # |9 D! z2 O1 ?% F0 F
  30. , A! X# {5 J4 @# M  a( w
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    ; M. I( p9 C  u5 S6 S\" m, B

  32. & H' U0 w& [2 d, u6 i0 S- l% ]4 y
  33. bias tensor([0.0008]) tensor([0.0048])4 K$ z% Q% y* b. J0 Q' t* z+ j! T
  34. 6 p  ]! k# x/ E  N) }+ G
  35. tensor(0.0394)
    . w6 O  V\" l# g9 `+ h; f
  36. ( |' s# z( l! h+ O& a
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])' u% w/ o+ Q6 e; h( x/ Y

  38. 2 ~% X8 a- n) B8 f1 T4 a
  39. bias tensor([0.0003]) tensor([-0.0006])
    9 w/ X, r+ ]. g\" }! X3 s, U\" f! y
  40. : M' G% |: Y( W3 a
  41. tensor(0.0252)7 ~* D) b# C8 V; M* ]
  42. ; H* R# W4 p/ D4 ~. b8 p
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    / d! q, o$ H\" z! q: I/ O
  44. # f, ^/ k0 }1 r\" R2 y$ g
  45. bias tensor([0.0004]) tensor([0.0008])6 Y4 ^% e4 b\" A$ f* H
  46. * Z2 ~5 F0 H' h( G
  47. tensor(0.0161)2 N! P+ w/ ^4 r, x
  48. ) x5 e) T0 f9 g/ U; h: _
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    ( H' Z1 _3 e4 }! V7 V; H9 i

  50. # B2 L% ~5 `! I7 {
  51. bias tensor([0.0003]) tensor([0.0003])
      G! V( J) Q- Z- q/ ?
  52. + Y- z9 G' R+ B, c' D
  53. tensor(0.0103)  ]: e; I; v8 H! I* f9 y\" L6 d

  54.   i* y6 ?\" ~5 p  s
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])9 E3 v; N0 H6 V( x$ x$ A\" Z

  56. 3 d* _) Z1 q) t. M- |- A0 k\" v
  57. bias tensor([0.0003]) tensor([0.0004])+ i, F/ Y8 M5 E& b
  58. $ Y2 q\" k& p/ |: I& t; {% J& i
  59. tensor(0.0066)
    : @/ h. V5 [\" N0 C0 `! a$ ^
  60. 6 t2 m0 b5 h) l( ?
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    \" p- r\" }$ V6 Z5 i  ^

  62. ) W, |2 ^8 _' i3 q
  63. bias tensor([0.0003]) tensor([0.0003])
    ( [% n7 K; y; k& D, M
  64. 2 }5 x; p5 x\" f% h3 }3 C
  65. tensor(0.0042)
    ; i: y5 B2 U6 n3 I: S

  66. 8 e7 D2 F' e+ @  @% J! v& `  }( F
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])# `7 E6 e# p6 f& y
  68.   h1 a+ m1 V4 ~% |3 N/ a
  69. bias tensor([0.0002]) tensor([0.0003])+ |/ {- Q& S/ N- w. D
  70. , W; c1 p5 u3 o4 Z
  71. tensor(0.0027)
    - a( I7 C  v, J/ }\" H( A; G: _
  72. 9 o6 W: m# [* t4 Q1 D$ x5 |& a
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    / S* g, k  s( ~8 ^2 k' \
  74. 3 K! J* q3 C3 F& v
  75. bias tensor([0.0002]) tensor([0.0002])
    ) q1 a2 g9 \2 q( L4 \( F# k! ^9 C
  76. % s9 B' y9 e; V
  77. tensor(0.0017)& N  @$ k1 f8 Y- ]2 z

  78. # [, ]) l/ |  b\" v  B7 O
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    0 T8 M  Z1 q( x  K9 w# P9 o

  80. 2 o. }/ W3 L6 E; ~' p5 u
  81. bias tensor([0.0002]) tensor([0.0002])
    - G\" L; E( q2 N3 `/ ?& G1 R
  82. 9 n( D! I5 ]0 ~0 J
  83. tensor(0.0011)- \) L$ u$ Y$ _. B  J, H4 L
  84. % H  b0 ?! s1 G; l
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])  B0 Z6 q. u- e: x
  86. 7 c. B% R0 p) t# N; j4 O2 g
  87. bias tensor([0.0001]) tensor([0.0002])) ?+ s: {1 g- D' i4 x8 Y) [; V4 x  D

  88. 7 N& S: K\" @5 S$ }
  89. tensor(0.0007)2 M/ i& J2 _6 H9 K
  90. - D3 q* g+ p$ E- \% }
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])# q. h1 u: W% D4 @0 C

  92.   B+ y3 i5 o+ w* D- I; W
  93. bias tensor([0.0001]) tensor([0.0002])! ?& S; a$ G1 W' `; K' V
  94. , [1 ^# V2 P+ n\" V% R% `
  95. tensor(0.0005)
    - l+ e4 p\" m9 ~5 G# h
  96. # w6 y# S+ \$ k! ~0 y% \
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    8 K$ H! E6 }% D& \
  98. - N% Q7 z) [1 z; [8 y
  99. bias tensor([0.0001]) tensor([0.0001])
    % h0 w6 s2 |* Y

  100. 0 C* F5 P1 d+ `\" w- x7 |
  101. tensor(0.0003)4 j& r; ~' v* b3 z
  102. & t% Q4 G$ I. j6 h
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])% h1 o  y! T& M: a+ i( D

  104. / N8 ^2 u9 R; ]- m
  105. bias tensor([9.7973e-05]) tensor([0.0001])- L: K0 [' V+ Q0 R- O7 v. g
  106. & e& ]\" A' g7 @3 f4 F9 `- \- j
  107. tensor(0.0002)
    , V; u5 X\" Z3 u2 K3 a

  108. - y$ F3 t$ |6 K$ M0 J: x, Q
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    2 Q8 s0 ]- ^9 u( w
  110. ' k9 s6 P4 f& j0 b  a
  111. bias tensor([8.5674e-05]) tensor([0.0001])
      `0 b\" j- z' U# X. M! E7 e
  112. + [: Q! g\" w8 B
  113. tensor(0.0001)+ M% v  T+ d+ P; Z  Y) e- U
  114. ' U8 y5 `. P. A0 }  Y
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]]): O$ I! w- W: V2 S' A# x! k$ }

  116. / x/ l4 v7 \% {, |5 f
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    ! x1 C, S5 e# A\" x# k! c5 y3 L5 T& H, |
  118. $ H' X8 d; D4 s$ E: f( [
  119. tensor(7.6120e-05)
复制代码
3 D9 n$ f; n9 J4 i3 H$ U9 ^
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-5-26 05:13 , Processed in 0.416003 second(s), 50 queries .

回顶部