QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2669|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1189

主题

4

听众

2934

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么* `: z6 J/ E( B& b' e
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
/ J8 f& ]1 f" O" }) a7 q怎么理解梯度?
% \1 T/ {. H6 Z; Z! h5 {假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
; H5 P  A/ ~1 n7 h) B& U5 _* L) _' k1 c5 u3 D! r
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
4 t; [6 l+ z) U: I, o, k, b$ @
0 H, G) s" T9 l( {( b' B每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。- k1 A& |' u; H3 X8 s
+ C( M% j) }3 N+ i
为什么引入SGD
, b6 ^0 V0 R: J/ i9 _深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。3 M' Q+ G  _& J3 J. ^
6 w6 w4 h2 O0 I# C+ ^
怎么用SGD
  1. import torch' _2 ^, t9 ?6 m

  2. 6 {6 H  A6 o0 z+ ]/ t
  3. from torch import nn7 o7 @( N5 V; E$ p, s* d

  4. . `5 q! K% o3 G) m( v8 N
  5. from torch import optim% ?/ c. P# o( g$ z
  6. # R# x2 ]% ]* [8 Y/ d  G
  7. 9 U: L) |5 G& B( U+ _5 b8 T& D2 e0 o

  8. # a3 \' K' \/ ]2 I/ ]! `
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)1 W( V% z7 J! c% J2 h* l2 D7 E

  10. : m+ K) c' V, `! R! \' _! i
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    + d) w6 p- V7 r( x! Y0 j$ X( R, ~

  12. ) Z# r+ D1 o* X% E* c) h7 a5 @
  13. 9 y/ F0 c4 Q4 {9 L1 |6 k

  14. 3 M\" N! `2 ~7 s4 y. ?2 N2 N
  15. model = nn.Linear(2, 1)% j! h! E9 g5 Y8 f6 g; x- m\" v1 L

  16. # S, J, T# @  G; q
  17. : e, q. g* z6 I' I

  18. $ N, h6 k+ M7 j
  19. def train():
    / a2 H5 ~* J/ F+ Q) }( F2 G. _. j* U9 z
  20. # }( A  F% c9 k; B
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)& Q! Y) t, r& E5 k& _/ b
  22. 0 ^' S# k; g5 ^. j
  23.     for iter in range(20):+ @8 {/ r6 E2 `' U! l& o

  24. ! y. H! p0 b3 Y, ?\" F6 x
  25.         # 1) 消除之前的梯度(如果存在)
    \" f, X; m( ?7 n
  26. 2 C2 v; h8 k( d\" ~/ S  v
  27.         opt.zero_grad()
    $ k* n! e\" n9 e\" ?
  28. 1 u& j, M/ c% O' n* }$ ?
  29. + o1 j) U4 x/ P& b. b* O
  30. ; z1 v3 U9 d  c( Y  ]% K8 b* S, }
  31.         # 2) 预测
    / _% q, [! Q$ `  \) Q  B
  32. ; k9 [\" B% D9 G  e& u3 L
  33.         pred = model(data)3 w9 _7 }8 X: G6 Q# e. O- G, Z' E
  34. ! Q* X: D4 R, ]) y* S9 G: B' p% {+ j

  35. 4 z1 `8 N% w% }3 P% ^  X( d. O
  36. 2 x. E, W- P\" Q; b) k6 M! G: Y
  37.         # 3) 计算损失, E5 b8 G, N! E; ~7 [& Q
  38. ) X8 w; u0 ?) i\" G, e
  39.         loss = ((pred - target)**2).sum()
    ! t/ S' K0 u+ s
  40. ! z5 T( J, G5 M0 c* }# y7 X

  41. 4 W4 @5 n' w- U- v
  42. ! s7 G9 S+ P# b2 R& `1 u
  43.         # 4) 指出那些导致损失的参数(损失回传)
    ' O- @5 O$ M! z; ~
  44. \" f1 e- W1 J5 b6 u' X
  45.         loss.backward()0 c  w# |: q- P' J$ G

  46. \" ?* K. H9 g/ z, c. e5 A
  47.     for name, param in model.named_parameters():8 j- d6 g/ x8 h

  48. / B4 j% f+ a$ Q) D
  49.             print(name, param.data, param.grad)2 s2 l: K! y, m0 z- c\" k6 g
  50.   V0 d1 B; L* \, {+ F' k9 v
  51.         # 5) 更新参数
    \" K* h$ C! t+ n: f

  52. 9 t9 }( X6 M4 b* i; A
  53.         opt.step()
      q- d2 J5 s% ^. i7 r# {

  54. 2 A/ _* l$ o, ^9 L
  55. ' R0 I  Q. N# R% `- h9 p  [

  56. 9 W1 o! ^) {% m\" i% D3 E2 s$ N$ I
  57.         # 6) 打印进程9 V: k% @6 s) [: O; H: r$ Y: z
  58. 3 i6 H# J5 |  k8 y& e
  59.         print(loss.data)
    ; \% z, X: w4 A3 Z5 u

  60. ! O4 n8 r. b; B! |/ t5 J5 b
  61. ' K  j3 B5 z2 Q1 p

  62.   `2 c+ T  t\" u& ~1 n' J
  63. if __name__ == "__main__":  E$ h6 y* g$ h

  64. - c  j! f! E% X  Q4 Z& F$ B7 s
  65.     train(), U- E8 Q, k/ Y) _$ l$ x! w
  66. ( W2 G  E. f6 ^) v% i. C
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
* V" \. G- c- m# s5 R$ b5 h, D7 I2 c5 H
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    ; J3 P* S4 L* v% I) d' s7 L
  2. * e' n2 h% F. a5 c
  3. bias tensor([-0.2108]) tensor([-2.6971])# K4 d# C' @# O+ _* A
  4. ; E9 V, d$ o\" {4 W! U. D) v
  5. tensor(0.8531)
    # F/ ~\" V  L, y
  6. # R7 A* i+ N* q; n2 `
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])2 ^( P5 v# u$ x# x

  8. 0 O; X5 B' P+ ]/ r/ Z
  9. bias tensor([0.0589]) tensor([0.7416])
    3 t8 f+ U. ~- K. W. d+ j8 j3 s

  10. ' g/ P3 W9 z/ |* [3 M\" t
  11. tensor(0.2712)1 S! ?9 }5 ^1 E: V: X

  12. ) h* x9 U8 H5 l
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]]), ?: G# M5 @) s) ^

  14. ( C% ^% K1 ^( @* m\" N3 [+ D3 J
  15. bias tensor([-0.0152]) tensor([-0.2023])8 L: [2 M, Q' u1 ]0 N- h* G: x

  16. 3 d$ }; F. j9 b# y* F4 X) R
  17. tensor(0.1529)
    6 M* H2 t  ~4 X0 |

  18. $ A& D6 d4 I4 O1 V  |( m) g
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
      r  E8 E7 E- L' `) w- t
  20. ' P8 @; w3 i# r7 j* N- L& B% A% }
  21. bias tensor([0.0050]) tensor([0.0566])
    - o, p4 G0 f9 m$ u, B' o

  22. ; A% O# u, T9 E6 ^- e
  23. tensor(0.0963)
    : o8 x% L9 d+ [1 {6 l
  24. ; L+ Z) f* _9 P; _5 w: M\" A
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])  F$ M& x  x0 r: K9 j
  26. * c( _7 L8 N  Y. H7 t6 P+ k. z' g
  27. bias tensor([-0.0006]) tensor([-0.0146])
    # f+ C4 Q  t, J) i) q; i1 z

  28. ) }\" B7 t9 z# A8 T9 A; D
  29. tensor(0.0615)
    8 @\" a. E\" e1 S& A

  30. 6 @$ a* I# ~- O& l2 c& F
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])  d/ H5 s5 t* C) s3 J- N

  32. 2 q2 m8 e! R) f% X3 M( F3 ^
  33. bias tensor([0.0008]) tensor([0.0048])
      m8 i, ~) U7 N# Z# |, q

  34. 0 Y, U# V. e: f7 z* }
  35. tensor(0.0394)' P0 f5 O1 j! \& z

  36.   N: \  F2 ?; L. E/ l
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]]). W* u& J- j& f. u1 n/ A5 l' N! K
  38. % q* F, r/ [- d  a3 T
  39. bias tensor([0.0003]) tensor([-0.0006])
    0 Y# s. {( {& I( ]$ D

  40. 7 w( V  F6 e. p8 {& x; n- D
  41. tensor(0.0252)7 V8 Q# a/ q. v2 }
  42. $ d; ]5 B6 v8 m9 X\" X# @
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])! D7 a\" _, ^: v4 b, I% ^
  44. # A' t2 g. O! c5 D
  45. bias tensor([0.0004]) tensor([0.0008])7 G- L8 F7 \8 I5 T7 k- V9 K$ F  Q
  46. % t: r( {7 U; U- A* a6 }  {( E
  47. tensor(0.0161)
    + h4 X9 [* l/ V: ?6 M0 D7 V4 s. q
  48. : _9 U! R& f) @2 O# j' C
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    4 D6 P5 q4 B1 p

  50. 3 \& {% i- ~- z+ t7 r4 G1 \
  51. bias tensor([0.0003]) tensor([0.0003])
    ! g8 I/ P3 `0 X- w' R
  52. 0 s7 I) X1 @' V4 {/ S! w
  53. tensor(0.0103)  T+ _* Z6 j% x' l# s; K( W2 g

  54. + [/ V\" M3 A; L
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    ( k0 D2 A* ?6 ]) `, F- _
  56. \" _6 N6 n) e4 \1 O& F( @
  57. bias tensor([0.0003]) tensor([0.0004])' E( I- |- b3 t% d. b- \6 J
  58. \" x8 y' c& J+ S0 y& T4 l3 Z# U
  59. tensor(0.0066)
    ! m0 H, L5 v6 R7 |+ B  O' J
  60. / v+ [\" h& W# e. H) I, C
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    # F, F( j* n2 z  p$ H: L$ i: n1 [

  62. 8 R  G- Y0 M% Z' I: D
  63. bias tensor([0.0003]) tensor([0.0003])
    ( A% \* S) `$ i, O+ [
  64. / u* r6 E! D, O0 ]/ J/ L* m\" J
  65. tensor(0.0042)% L! i7 }9 T! y( G
  66. ( E+ A+ V8 i$ y) C\" B9 k8 |
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    ' j8 X4 s- h3 J: {+ w9 [* g
  68. # b: D- L9 F. e7 Q
  69. bias tensor([0.0002]) tensor([0.0003])
    / t3 ^! t0 e, r, K: j! {2 M
  70. ) C. q( K( ~4 h7 w8 c; \) v) ^$ `
  71. tensor(0.0027)) K( z% ]' K( d' s
  72. 0 B8 R6 t, D3 H7 V4 w+ a
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])1 F! S; c, c6 V2 e- z2 _1 v

  74. ' n\" W. m# b, J# i* Y2 c
  75. bias tensor([0.0002]) tensor([0.0002])
    2 o8 x1 j6 W) n9 V1 t; }3 K

  76. 4 u& o: O2 a0 u
  77. tensor(0.0017), Y# X2 c6 c2 l) o* _% [; Y

  78. ) a; x, S/ F% v3 h0 M) D
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])1 O. v; W* P4 C; O) f6 e
  80. 5 |\" a9 M0 G, b1 R9 e7 w# u
  81. bias tensor([0.0002]) tensor([0.0002])
    & [# ?7 E4 d! c! \1 o2 k9 {5 ~

  82. 9 I( u) r% G2 s( O1 m, P
  83. tensor(0.0011). `: V& P% Y7 q* E7 o8 M' Z

  84. ' g) z, H$ o% g: t
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    3 a\" W1 X2 l) U* W# y; p8 y

  86. 3 y. J+ M( D9 P3 V& |& W\" T
  87. bias tensor([0.0001]) tensor([0.0002])
    9 g! a2 q# Z5 k( m

  88. $ G% n3 D\" P1 `4 {\" C4 U& a1 k
  89. tensor(0.0007)7 D( }6 W& a. c1 V' l

  90.   _4 E6 Z% ?( h\" V\" O
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    / S6 @2 x5 [9 @* O8 r

  92. ! A+ _, h+ N# }$ b( y8 R) L/ s
  93. bias tensor([0.0001]) tensor([0.0002])
    7 }, a2 `9 {6 z' \+ h) h8 a

  94. 0 q% m. V* s  {0 `: ~1 S5 P/ ^
  95. tensor(0.0005)
    7 l/ K! E3 G; m; R

  96. \" N7 p$ Z. S+ `( X. k: Q
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])# w, \( ~8 w( q1 _

  98. ! @# v8 ]2 q- Q- L8 R( G( N
  99. bias tensor([0.0001]) tensor([0.0001])
    1 y+ ^  v' P1 K& w8 i
  100. 3 w% G/ x$ l- S$ t1 r
  101. tensor(0.0003)/ q: Y9 S6 v: w1 Q

  102. 4 j8 z4 J# f* b6 @  X
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    6 H% k' {3 v; L% w/ F
  104. ; a' b; k2 Q- W$ k- [6 g
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    1 Y5 [) [, G7 O8 M2 p  N

  106. / _8 U* m' @: s
  107. tensor(0.0002)
    . N0 ^8 [% o- p, i- z# G
  108. \" n( z' N) W8 b; s! ^3 I
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])9 h% P7 Q8 z6 c6 f/ m6 `\" |( f
  110. 5 `5 ]& r! ^, W+ u\" a* I7 }
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    \" t7 I+ c5 p! a/ j0 M
  112. ' L+ y+ ?5 M7 b% `, W% _7 q
  113. tensor(0.0001)\" [  K. ^: [7 O2 A+ n# _\" r
  114. 1 R; M  e  E. I0 G5 d& e/ P8 W
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])  j; t- V9 ?7 k4 b4 b  A) [( {
  116. ) ^: P. _$ a; C. \1 F: c
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])5 w+ c. I- O  k. c* O, h  y: s

  118. & J0 q6 M) z' Q2 U# |7 i
  119. tensor(7.6120e-05)
复制代码
; r0 u7 z1 P: J
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-6-7 19:29 , Processed in 3.093429 second(s), 51 queries .

回顶部