QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2249|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1176

主题

4

听众

2884

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
2 B9 `% Y# B+ ]$ M3 y  QSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
  {, {* Z, i: A怎么理解梯度?  b+ i1 h2 g! ~) j- u% G
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。3 Z  T/ s8 d- R; }) p- {- n8 `
: I9 y7 n& f( u, d6 v6 S
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。; A- P+ \) P0 ?: p  p* j5 [. Y' e
# N8 N4 H! W+ W
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。" t, l0 E+ M3 I6 R7 L, \

3 e  b( ~4 \/ G4 C$ z$ L& o为什么引入SGD
& G' ^9 I& y4 ?2 H/ Z& v深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
/ M* }  K$ g; }+ C( a3 Q) k7 b, o9 ?% `+ I- E, w0 q
怎么用SGD
  1. import torch- J3 z9 A7 a1 v: K  O5 w

  2. . A! X4 e# N, @2 N# R! \
  3. from torch import nn
    + J. }  F) c! i, R. ]( u3 r

  4. ' W' X1 e( g* H9 c1 O; e4 i; U0 @
  5. from torch import optim
    ' F* {\" |2 z* S; @9 l% q
  6. 4 t4 Q- P8 n5 M' R
  7. 8 ?( g. H$ J6 x, H2 \+ `, w
  8. . E+ \5 G* g) h# Q* `1 H, X
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)# n5 b3 p5 i. Z7 Y$ q
  10. 9 v8 _8 c3 v6 j/ Z: n\" e$ \4 D3 c
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)6 h# L5 ?( k6 m, K0 u. N

  12. ! c3 `' E0 I& n
  13. ; H% v# o( h5 Z8 N0 z9 @+ z
  14. % p6 w, y( z; v' U2 \; ]
  15. model = nn.Linear(2, 1)
    ( E! Q! i5 J# l\" A

  16. 2 F; s6 v0 f# x3 \6 k

  17. % V  |; k. H$ _; E: e( A
  18. 4 ^9 e! _1 |, `+ `& ?7 a# H\" e
  19. def train():
    7 }  V7 G. F/ p1 \, J+ A) j1 X
  20. 6 ~8 `' `2 C  f; \& ^
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)\" J6 r0 n. ]2 N5 R

  22. ; Z5 @& R8 d$ k; ]: ^# q
  23.     for iter in range(20):
    ; J5 ~; P% K, }0 G+ f\" x; A2 z; e6 V9 X
  24. % G% k8 f- D9 |4 R
  25.         # 1) 消除之前的梯度(如果存在)4 L3 T7 p0 @2 V1 o+ @

  26. % S  r5 E1 u+ T/ t
  27.         opt.zero_grad()
    % d! Y$ M$ E& A0 ^, [+ g

  28. 8 }% @* X\" j  w9 Y, Q* N8 J
  29. 6 s, q( v- i9 e' Z; {% j

  30. 5 {3 v* ?% ?3 N+ u& h) s1 u
  31.         # 2) 预测
    # V  ]/ S, H, Z9 j3 Q: [) D# \
  32. 1 Z8 y% r. Z* Z
  33.         pred = model(data)- s! j  L6 S3 d5 R. T# j4 Q' a0 ?\" E; R) N
  34. / i, V( D( ~: q4 h2 ~

  35. 4 |' T% r$ a% e7 ~1 I. o
  36. , R2 K3 T+ o0 |  k4 q: S\" R6 O$ B
  37.         # 3) 计算损失
    : @1 F* |& R, p- |. u

  38. \" f\" B% w\" h7 z) x8 a/ J
  39.         loss = ((pred - target)**2).sum(). V5 a1 t$ q# e- k

  40. / s' }  |3 m# O& O

  41. 7 U\" y) ~' M  A$ _1 W

  42. 6 m& i( \) D4 }5 }1 O2 h
  43.         # 4) 指出那些导致损失的参数(损失回传)
    3 D* I8 {6 C) e\" o$ Y

  44. 0 @5 D\" v8 N2 `; s* J) v
  45.         loss.backward()
    & q: Q& w% q4 b% g' E/ R
  46. , X0 q2 H; |/ k! u  m2 i4 W! Q
  47.     for name, param in model.named_parameters():
    ( q' w1 J2 e& y1 `: f

  48. & F6 U9 h4 I, o
  49.             print(name, param.data, param.grad)) ?1 K) Z$ a! [. ^( t0 l

  50. ( \, u0 H. y1 G$ n; H6 G
  51.         # 5) 更新参数4 z* s' L* `, u
  52. 1 w* T) S0 e+ g. q! s
  53.         opt.step()
    0 u/ G8 a2 K) T
  54. 1 G7 @+ x, a' I9 [7 ]# {* ]3 i1 B

  55. 9 h% v1 |( ^' r: y2 r
  56. ' L8 w  _3 k\" D7 P( \
  57.         # 6) 打印进程
    9 i# M\" e+ z& o' Z

  58. ! L  ^$ \- F: U0 g( b6 h
  59.         print(loss.data)
    + q% d\" W4 K/ g8 M6 d4 C

  60. 7 i% r7 }% S, \$ y1 C* X
  61.   w( T, f3 A$ p2 Y0 f; v3 R

  62. : [3 a, q; ?! V4 T% K( h/ x% D
  63. if __name__ == "__main__":3 w. z- a\" t, v2 |# I/ k

  64. . c  s  b+ ?8 |' |$ L' G\" [
  65.     train()
    * M& y% a7 @0 [9 F

  66. 8 ?# B: \) S5 y) c* S8 ^
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。: a7 [$ ?# P1 `4 B7 r) A
: ]- B5 k, S1 S
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    . X! T; P' `' L0 g
  2. : ?) P3 Z' v8 H2 \3 ]/ b/ M) b7 }
  3. bias tensor([-0.2108]) tensor([-2.6971])
    ( Z* i3 V  i% l$ X0 a

  4. ( U; B  |% `' w% g
  5. tensor(0.8531)7 N\" Z8 L* X7 G* H. @, |; K

  6. 8 n9 r5 W\" c' A\" J' W/ G
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
      K, y; a$ i$ }- ^

  8. . I7 d8 D6 C8 N. b: B2 k  t
  9. bias tensor([0.0589]) tensor([0.7416])3 x; m% f, C2 y+ R/ m6 I. V

  10. ) ]9 h# `+ m9 b) k9 ^, \; t7 \
  11. tensor(0.2712)
    3 C' v1 `& k! s1 `
  12. 4 x3 p. A2 M\" e3 [\" e8 m
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])5 j% }  P. Z8 k* t* `5 e
  14. 1 V) \, s& h& U\" T$ g+ k
  15. bias tensor([-0.0152]) tensor([-0.2023])
    0 `. ]7 A8 W- L% L; w
  16. - I\" y6 Z( k6 E; M3 `
  17. tensor(0.1529)
    : S- r& w+ L1 W2 e/ F9 n
  18. ) J+ v4 X# N. b& D) J, X
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    7 I- I: \% `1 e, v/ M+ }) G
  20. 4 R4 h5 u# r: ~0 Q+ N1 Q
  21. bias tensor([0.0050]) tensor([0.0566])) B7 m+ p+ T/ P( q

  22. - j( V) ~0 t7 _  _  c9 p7 [0 \
  23. tensor(0.0963)
    : \9 R# v/ L8 t, ]0 A

  24. 3 Z3 W. f/ ?, d0 i% X% H
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    * I( P/ T. d, J8 D9 B. U
  26. 4 {3 a* c4 }\" O2 J9 q; o
  27. bias tensor([-0.0006]) tensor([-0.0146]); R2 |& y0 x. W. `% J\" w

  28. . ?3 x9 H  V' u7 w6 s+ N. ?2 y
  29. tensor(0.0615)
    ) A' A4 Z' u% _: r
  30. ( P' W% j8 q; n- u7 k
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    1 w% n( e  k  s9 d5 b+ B

  32.   F1 L\" v7 u! I# O
  33. bias tensor([0.0008]) tensor([0.0048])
    # P/ Z+ `2 N+ B. [# q
  34. 3 ?3 c) s: ]) Z' h* N
  35. tensor(0.0394)
    2 d& N+ O/ m, H9 R8 u& K

  36. 0 n- \8 ]9 t% S4 J
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]]), f+ o9 Q! ]! R- C' q\" |* n\" W& Y8 V- V

  38. * @( P' H* B- o  C
  39. bias tensor([0.0003]) tensor([-0.0006])
    $ X7 e& _: q! G1 Z7 N, @
  40. 4 P$ K, x# j! r5 m3 Q& F
  41. tensor(0.0252)
    . C. `9 c4 M8 T/ u6 ?$ i- b
  42. 5 e& B! V0 h, W0 l
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]]): o! v7 N$ @3 y  P
  44. . G& U' B( W9 t6 V, d0 n1 ~
  45. bias tensor([0.0004]) tensor([0.0008])
    6 @1 S6 Z6 [: T# K8 m/ Z% S
  46. $ S- j0 \- C# U6 E2 J
  47. tensor(0.0161)
    ; G3 `3 r4 E+ ^9 U/ W9 h+ N

  48. - B- o) p/ p6 N
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    + M* \\" m9 }% U5 k+ L& X

  50. 8 l9 b# ~/ E& y- r4 A2 U
  51. bias tensor([0.0003]) tensor([0.0003])& s& p! ^% O6 [5 W: {
  52. / ~$ e! a/ r. Z' P( C
  53. tensor(0.0103)
    ! g( d' ^# p# {2 L- c4 v

  54. 3 i: T8 P$ W6 c5 O0 n/ N5 ^. i
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    0 r, e1 R  ~& v7 S: m8 c

  56. % d\" j) k0 K% [8 G\" i& e. P4 j
  57. bias tensor([0.0003]) tensor([0.0004])4 Q* K% P7 y3 s\" j$ Z8 H4 b

  58. 4 o( X+ W  B! Y# n3 `! v9 b+ y
  59. tensor(0.0066)' c' ]& ^9 [  J$ F7 h
  60. & A7 \+ C  L2 a6 p
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    # _\" A0 V! d* q# F( _

  62. & T: Q! {# d9 L$ Z) F
  63. bias tensor([0.0003]) tensor([0.0003])
      H9 S- X$ x; \( d4 v

  64. ) \, J' Y; q9 \8 b3 M( u' ]9 h
  65. tensor(0.0042)- K4 q+ O( J; I0 z; T) z% i
  66. , L1 D\" r# \! A
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    / ]9 V\" ]8 ]) R4 o
  68. 3 h$ f$ @1 f: y. ~1 N, L
  69. bias tensor([0.0002]) tensor([0.0003])1 _6 v6 Z\" N. }* B, M
  70. - G0 T4 C. M\" O2 T2 K+ ^
  71. tensor(0.0027)1 A9 b2 ~$ @) H6 h5 k
  72. 0 \* C% y: c: f5 P' @# E1 y* l
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    ) h. l# K3 h( U

  74. , Z4 M6 _$ ~! P( @' F+ Q0 K
  75. bias tensor([0.0002]) tensor([0.0002])% g# f( c0 V$ D6 z
  76. : E. r' x: R8 \1 x\" ~  O3 y  O4 ~5 W
  77. tensor(0.0017)
    . G' x2 o  P* T/ P6 x$ y

  78. % x# w) p% f2 |! X. Z# b
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])\" f) b8 `. y! R# _
  80. 1 I0 k: K0 h3 H
  81. bias tensor([0.0002]) tensor([0.0002])) g: u$ C% u/ j! b5 e* `

  82. 9 E8 a\" V6 q) t/ i
  83. tensor(0.0011)
    2 o$ C9 w* m# K  r3 A( ]

  84.   @( r; l, j\" a# x1 |9 |: t
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])+ v, m6 N% ~9 e, A2 N8 E, b3 b
  86. & [9 n4 B/ O\" K
  87. bias tensor([0.0001]) tensor([0.0002])- B. C5 n5 T8 A\" {, ]5 z

  88. % J  e7 y; w: E! `- `
  89. tensor(0.0007)
    \" v% z: l) y/ ?' L
  90. 8 c4 e; @- j% M5 B2 E8 l
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    . X\" l4 \/ _, t% H' t

  92. 4 B7 w8 {0 r5 p# V; d
  93. bias tensor([0.0001]) tensor([0.0002])
    3 g: L5 s\" a9 Q7 H
  94. $ K$ C% N2 B4 S
  95. tensor(0.0005). E7 @# E6 N1 K, E5 D4 [
  96. . A  t1 R& b2 ?7 j$ l9 D' a# o* B
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    0 ?3 j! R; {' p\" q% |

  98. 4 }; J4 D0 {; m; D, w1 J
  99. bias tensor([0.0001]) tensor([0.0001])
    0 w\" e\" O2 |  m* @\" J
  100.   O  e) x8 T* C5 J' U( I: F0 y
  101. tensor(0.0003)
    ' r5 B6 d' m5 }! N) }
  102. , r, j+ O: O\" C8 }: ^. m
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    ! z. N7 C5 T: |7 f+ |
  104. + a$ @/ |4 h! N8 |$ [/ `& e! [
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    # }2 W3 P( o  Q1 ]7 s6 T

  106. . m2 D8 o' }, f$ b0 b
  107. tensor(0.0002)
    # m5 a- r4 t, v0 t: P+ s5 x* B8 Q
  108. 1 R\" A! K  [) P* w( R
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])8 I( E2 C( V1 V

  110. ' J9 u7 M/ p5 T7 p2 ]
  111. bias tensor([8.5674e-05]) tensor([0.0001])/ a' T0 ?! @* M! X
  112. % [' R+ E: W( V3 |
  113. tensor(0.0001)- {' [3 a6 {% r$ _7 ]

  114. \" @' q+ e% H! D- b& @; w% s5 H
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    ( {0 Z1 M6 B0 F! n% C, L

  116. 7 s! u+ }' N! _6 b( ?. i
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])( z6 c4 Z9 ]% n# h  v9 u9 [7 n

  118. 9 M! \3 e# B' \0 |; v, ~+ t1 ^
  119. tensor(7.6120e-05)
复制代码

% L+ W7 L! f  j3 p* i  q4 b' B- C, r6 H
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-9-17 15:08 , Processed in 0.498650 second(s), 50 queries .

回顶部