QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2670|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1189

主题

4

听众

2934

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么: p2 U2 `; C( V0 B, T6 T# Z. v3 p' h3 F
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
( q+ L$ j- p  i  z) I怎么理解梯度?6 e- ^$ r$ B: j9 |2 E7 `' ^% e
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
- B/ d! G0 i3 u7 ^8 p) s
8 ~3 a- d# [* k/ C$ _7 b在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。& U, d1 d( ]! K* d9 M
3 m4 }' G+ b" f  N' P, L. E, R
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
$ e: J* U; ?) ~) Z! Y
& _2 i2 c5 }  J# M$ I& u# ^为什么引入SGD
$ l, h7 T( u: }  Z1 [0 Z深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
: p( k4 T6 h8 m8 q& G
4 Y5 k# a$ g. I4 R8 ~' [* W$ f7 g怎么用SGD
  1. import torch
    # i7 q8 R0 `, V  \+ |' @
  2. / K4 W; K$ [! Y; |& s
  3. from torch import nn5 x: d3 n7 b8 Q1 B; H% |' c2 }
  4. - T/ b1 g' S) I8 i- o1 B4 r9 \
  5. from torch import optim% j( k7 I' P9 s4 ^) Q2 e* {4 L. b

  6. & I0 I; ?6 w7 G/ U2 C

  7. , N: E) s( I1 p' X. h

  8. ! w* ~. X, X$ d% L
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    # a1 L3 `, a3 {  m+ \

  10. # F\" P/ x\" q: }\" x+ V
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    - j\" g) F% F8 j4 D! @0 L/ Y' [

  12. # W% P4 Y\" c2 z; @# i* ^- _5 |$ i

  13. . z2 u. i$ g# \; x5 N8 k1 q2 ^6 c% Z

  14. , o3 r' Y# |) O8 m- \8 c* m
  15. model = nn.Linear(2, 1)
    5 P5 ^3 z2 \( x. K
  16. 8 G& }# ^  a$ k5 n; u0 J$ P

  17. - ?1 }# j0 G2 ?9 d
  18. * e1 O) A8 S7 |- O9 Y, W- F4 G6 x: M
  19. def train():9 \) a$ D6 T\" U0 L/ X\" N

  20. ; t4 }3 V7 P1 b. I
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    + ^5 ]' }/ i% A
  22. ! g# D. o' \8 L; o
  23.     for iter in range(20):# S6 P* Q5 X7 a% H
  24. 9 Z6 `4 `  I# t% w& c+ |$ a' a+ m
  25.         # 1) 消除之前的梯度(如果存在)0 m7 h9 Q0 n5 u9 A; v' H

  26. ! N  R- A  e, L4 U+ w* S2 z; r
  27.         opt.zero_grad()
    2 S( I- F  J/ B  E, s; k5 @

  28. 9 T6 T1 {1 i2 x' i
  29. 5 R* w) e! e) Q

  30. / K4 f! S0 K/ g& H
  31.         # 2) 预测% J7 R- y5 K- q3 R3 j
  32. ) E. t& P: ~& ~5 O
  33.         pred = model(data)' e2 \0 h1 i' d

  34. \" F, l( g) S, B1 a2 J\" I; B, F
  35. 5 B4 H. \* F2 g$ l& U- Z( G
  36. 9 B$ u+ ^% s4 h
  37.         # 3) 计算损失3 D3 t- k# }6 v- I, i
  38. \" Y! K1 U# f$ e9 u9 P) s. @
  39.         loss = ((pred - target)**2).sum()
    + @- d  ^( @! b\" w

  40. # d6 G3 }& [: |. K' ~* y- d/ m) W
  41. ! S2 B. Y, @, @- @3 i- i

  42. - b) @( ~/ z7 L4 w# A- E
  43.         # 4) 指出那些导致损失的参数(损失回传): x% C; \* K) m) r\" T

  44. ! ]; Y$ W! m8 U6 B# D
  45.         loss.backward()
    $ N. z% h* o0 [* W* g
  46. / h8 U$ p, d9 ]\" J0 U- N- l) {2 [+ q, t
  47.     for name, param in model.named_parameters():
    - U) g6 i4 O& W3 R$ x  _

  48. / ?3 o( N2 _- p! a# j9 M
  49.             print(name, param.data, param.grad)
    9 i1 L$ \7 ]! m: G/ }
  50. 7 D. j8 h. N7 [1 J
  51.         # 5) 更新参数) y7 Z- f0 u6 _\" L

  52. 8 Y( y7 a* g# J7 ?5 @$ [6 v. M
  53.         opt.step()6 C- p: Y1 i+ }- k5 Z
  54. ; w5 N  N- V% L; r# E8 p. d: R

  55. ) Q1 L& V; J5 J) {) S; t% Y
  56. $ b3 l' n4 ]( C. ~7 I3 `# \0 X
  57.         # 6) 打印进程\" p0 {. e  t) H7 R/ E4 e& E
  58. ( ]+ R. c7 L* X. C
  59.         print(loss.data)
    , [+ A9 |+ J7 [7 {, M

  60. * }. x4 y* X\" u) v/ m1 {  a
  61. & I8 a+ D5 L0 z; E6 P
  62. + h* g# _* X5 k1 h9 y  t
  63. if __name__ == "__main__":
    : Y  r  A. ~- Q% g7 U' t$ Q

  64. ) d0 G) q/ U. A* o. T- y* l$ T* O
  65.     train()0 C& n! V% n# n' S6 F3 p

  66. - q; [2 K8 |9 t1 a4 f
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
/ U0 r$ E' t# g3 I, B
$ b' N+ A4 D' S% `计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])  U- M2 N$ x# }! A
  2. 6 @( X; @3 }8 l/ Q8 Q# U$ o
  3. bias tensor([-0.2108]) tensor([-2.6971])1 n& H) u: C1 w# u8 Y( v
  4. 1 [, \( |* |- l2 O7 K- I
  5. tensor(0.8531)
    \" w/ D$ j' O, Q0 I
  6. 7 T\" L4 j' ]4 e
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]]); J: `: u; A7 o% S  l! W

  8. \" R1 Z( M$ ~\" ^. C8 H: m
  9. bias tensor([0.0589]) tensor([0.7416])
    5 P' S# H8 c7 ~, v. _- P1 D

  10. 3 M  s\" l0 ?% |2 h/ K& q
  11. tensor(0.2712)
    \" H* J% T- c' ~
  12. ! U, C- w) w  I( N$ Q
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]]); m/ g\" z, F! Y! C
  14. : w  l% H% i% J* B& ?
  15. bias tensor([-0.0152]) tensor([-0.2023]). S' t3 q! [- ?2 h8 a
  16. + f* \# Q: ?9 n9 z
  17. tensor(0.1529)( I6 @\" u4 E, U
  18. $ f- \3 O4 Z: u$ n+ x- \1 U
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    # i4 d4 O! g. j
  20. ' p/ Y% v4 y4 M# m9 `/ x
  21. bias tensor([0.0050]) tensor([0.0566])
    \" Y1 J& r/ a+ K+ b6 h6 C  G( ^

  22. : X7 @: Y* D/ J! u9 K
  23. tensor(0.0963)( w# m\" P/ P/ c

  24. & g  Z4 Q  l8 I& f  A) w( O
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])# s8 X5 ~5 \+ D9 e9 |. C& Y

  26. : Z' ~( v: T0 c\" Z% x; U+ B: _
  27. bias tensor([-0.0006]) tensor([-0.0146])
    7 o# A5 t, b8 A4 s9 j' b: n8 ^- n( F
  28. / Q( e' n7 V7 D' p' M5 r
  29. tensor(0.0615)$ i. n  J, ^/ P' A/ X! s

  30. : h& S, J8 I$ y: K9 t& d: v& q4 n
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    4 }4 V2 @' P! X  [
  32. 6 a* h  g6 |6 ^% ^9 s' [
  33. bias tensor([0.0008]) tensor([0.0048])
    1 m; X) q4 Y! i! A/ _& A( a' M
  34. + U. f8 W  E\" O- k8 f9 l  Z/ ?
  35. tensor(0.0394)' z6 ]5 t/ S: [% `0 }* v5 }\" P

  36.   r5 w: p% g( ?; u4 K$ v1 u
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    \" A% X; y& j- K6 R, l/ k  q

  38. & v9 S  e6 \3 n' M# y
  39. bias tensor([0.0003]) tensor([-0.0006])4 H; L! i1 k! j2 ~2 R. A5 A

  40. 8 n\" u* Q- I- z$ Y! T$ h
  41. tensor(0.0252)\" P. ^/ }5 P+ c; |- M! {* {/ h: w
  42. 2 c\" E8 J( K4 o! E, {& k9 k2 b9 A
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    5 ?6 S, [7 Y7 R  k

  44. ; f, H3 H1 p# j) D0 D! [6 B
  45. bias tensor([0.0004]) tensor([0.0008])3 P9 T$ J7 {0 W) y% U

  46. 3 s& B( [\" R% }$ B
  47. tensor(0.0161)
    & ^( @( H8 n( ~

  48. ! B8 C8 K8 L\" k7 q
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])4 O3 x4 C8 R! N  m5 K
  50. / ^( c4 H6 n0 L# R0 b
  51. bias tensor([0.0003]) tensor([0.0003])
    \" \. i* D- w1 m# F\" @5 Z\" R
  52. + Z% v- G' S. ]( ?* [. D) m
  53. tensor(0.0103)0 h4 }$ ^# Y; F; `- d

  54. : m4 Z0 @* S, n2 l3 r4 L0 t; d1 m
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])( H\" ^+ l- @8 j
  56. ! J# e2 a! F3 {! I! s) S
  57. bias tensor([0.0003]) tensor([0.0004])
    6 e6 y0 y' |: c  a
  58. $ S- E: m0 ?5 Y- H' M( x, U5 x! v
  59. tensor(0.0066)
      K7 R& Y) x! ?. z- ^! o& _2 t, G

  60. 0 j+ E7 Q1 A2 E
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    % J$ B6 Z! [# y$ ]' C$ b  I

  62. , P9 O& \: X/ a9 C8 t( u- L6 }' s0 W
  63. bias tensor([0.0003]) tensor([0.0003])6 k2 q+ p- \1 A3 e0 a; k2 p
  64. 2 U: a) N. m0 D( c! b
  65. tensor(0.0042)
      R: [1 x! g- K& t7 O' z' [
  66. : u4 ^4 F# a( X1 b. y. N* n9 d
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    2 e0 N  V# c8 G; s; S
  68. 9 j  u' R: W1 `1 c+ c$ _  m, t
  69. bias tensor([0.0002]) tensor([0.0003])
    # i8 u- Z1 B& b# W# a  s

  70. # y4 K- C; i  C& F9 B
  71. tensor(0.0027)
    9 H/ w& \. y& C& \  c) ^: a8 b

  72. 2 Y' L3 q. T) V$ ]- r0 K0 }\" k
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    8 z! t! a0 W9 K+ R& v# i

  74. ; Q* p& m; n3 ^
  75. bias tensor([0.0002]) tensor([0.0002])6 R2 l* E, g% t3 C& ?& z! t
  76. 1 E0 T) D: z+ D4 v0 V9 u
  77. tensor(0.0017): T) X' ]' J% B! I

  78. / |' ~# x# c3 L5 d\" a
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])/ ]9 y+ P/ G' H' q
  80. 4 n! `( J9 A; u8 z
  81. bias tensor([0.0002]) tensor([0.0002])
    : l* n5 ?/ Q6 J  ~0 K, [& A

  82. 2 u) H/ x$ R3 _. }5 X/ j4 R
  83. tensor(0.0011)$ H( N* e. J* |

  84. 3 y. i/ X1 a% S) \% r- g
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])0 e# k% y1 q- S& L  L
  86. ! e9 n6 j/ J( O5 M/ [
  87. bias tensor([0.0001]) tensor([0.0002])* z7 P8 D0 g/ H\" {& Z
  88. ' f: R8 I$ u- F2 [
  89. tensor(0.0007)
    0 R$ C5 i5 o$ j
  90. \" f, F/ s! n1 u, o5 V8 B
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    1 U7 m% m( E& O6 O3 a

  92. ! r+ @/ Z& V! e  h% O6 C' T/ k
  93. bias tensor([0.0001]) tensor([0.0002])# \( F8 P8 S, [3 x

  94. 7 |# s- N' o6 \\" [6 P7 r& p# ~# X
  95. tensor(0.0005)
    + D\" a  F: G. ~- g6 M3 m5 A/ ~
  96. - y, O7 g$ x6 `8 {) L& M
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    \" [2 `\" O+ ~9 J$ V6 n

  98. : q8 g\" m/ m0 Q5 d& E8 m- @
  99. bias tensor([0.0001]) tensor([0.0001])$ B5 y, ~4 g- q& Z6 I
  100. 8 N( [1 E1 P+ {6 C
  101. tensor(0.0003)
    ' t6 H4 B3 ]1 a% h  @2 d
  102. . s; j+ I1 T6 r# j: W  u
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])3 r0 Y6 s7 f0 ?: Y, m
  104. 3 H) [, v\" Q1 b# o7 o: o9 N
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    ! D5 [; h. E7 {: t

  106. 1 |2 N8 L) H( [% g1 @
  107. tensor(0.0002): }' f: I( \  m6 h# b1 q
  108. 2 |& l\" r; L  d- G7 [+ t
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])* g3 E, o5 @, U

  110. 6 O. ~) K9 J7 I7 Y
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    ( V9 s! ~4 `' h- C1 W3 P
  112. 3 n7 \' u* D2 Q3 D+ P0 m\" X
  113. tensor(0.0001)
    1 e1 i9 G% ~\" U9 ~5 T- J

  114. 0 f* {% e2 _% v: k- l$ F
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    ( ~+ q/ F! P+ p

  116. : b% l3 f6 _9 _5 h  r) `
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    % b\" a& Y& h\" z3 t5 b

  118. 2 a# y\" `) h( T% D2 r6 p: X% P
  119. tensor(7.6120e-05)
复制代码

4 ~2 x2 Z: L  [+ R1 v1 R
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-6-8 09:22 , Processed in 0.423558 second(s), 51 queries .

回顶部