QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2628|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1186

主题

4

听众

2922

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |正序浏览
|招呼Ta 关注Ta
SGD是什么
* A# D# y& s8 F* r* mSGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
# o3 _  w4 P; {; o6 s; y0 y怎么理解梯度?& v8 F9 o! o3 A: B  B
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
; ?* @1 d& E8 g! O3 k
; f* O% q" W( h# N在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
, e9 t% H$ g& g1 [
+ d1 |$ Z9 u" H( {/ Q每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
0 a, t6 t$ _( U& `4 S0 F7 A+ m: u
9 o% ^! c' r+ ^为什么引入SGD
6 W# Y4 N  }- G1 h6 z+ h" c深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
5 _* F+ b0 b1 X# g1 c
6 k' h' Q: a2 r( Z8 Q5 D7 B怎么用SGD
  1. import torch, g/ Y: f; l\" `\" b\" m' b2 K

  2. \" L) _1 c$ O8 l! J
  3. from torch import nn
    - @6 W8 [9 Z: X1 P  K
  4. 2 ]9 ?6 v8 q( C' ~2 u% T. z
  5. from torch import optim$ @# Z/ Z1 o3 Y\" k, u

  6. 9 ^* N- j) r7 ~5 O, ]9 O
  7. 5 E8 Q. B% i\" [# c9 T! i

  8. - Z\" t\" w, q5 X
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)$ }( G9 {- a, J& G
  10. 7 k1 Z0 f7 j7 U( A( q
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    2 p* F5 \- n# E( y& d
  12. % \4 [5 T% O- u6 `3 m\" k. s
  13. 8 u* O8 h3 Y( t/ K; `! H

  14. ; B& \; a8 }& X6 A* ]7 L
  15. model = nn.Linear(2, 1)
    ( d/ M; B( P! g- z& N
  16. $ |& x; V- v& O6 a. @1 N
  17. 9 ]6 i  A, ~. r. t\" g

  18. 7 J! M3 X2 |  }/ [, B( O- U
  19. def train():& K2 ^- l# H; T

  20. , Z2 m/ s1 S* @7 b( F, T9 `3 c, u, ~
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    : h' \; t& w. I/ G2 n9 \8 ~; s2 Q

  22. 4 j( C$ n7 B% K  g! O. f2 ^6 h
  23.     for iter in range(20):& k6 V  D( o9 _  C5 w* q# l
  24. 1 j/ ^( ]1 H0 e# T
  25.         # 1) 消除之前的梯度(如果存在)
    5 y, }3 {/ q/ _& n' u3 \5 {! k
  26. + _8 k5 N$ q\" L# G8 F  E  K
  27.         opt.zero_grad()
    ( Y- {/ {0 `( Y; o- v
  28. 4 }8 {0 t8 s) b4 f

  29. $ N, F4 F/ a: l5 u$ V
  30. 4 m' U  l  G' Q# D
  31.         # 2) 预测& ^6 C7 W; w/ z# a) x) C

  32. . E4 ?4 h/ J+ d: ?+ i: Q8 X( \. Y: k! T
  33.         pred = model(data)8 F7 O' q1 M0 I+ e

  34. & E$ G# _3 m1 k\" W

  35. 9 h, u+ P# \) l$ c

  36. - z$ r  K5 l, ?1 \7 u
  37.         # 3) 计算损失
    : X) S% Z' c+ ]% N
  38. 0 k: W9 M# \6 E1 Q1 w+ R
  39.         loss = ((pred - target)**2).sum()
    . ?# ]- ?4 ~% w9 w
  40. 2 L0 N1 j% A+ \8 D4 r8 Q

  41. $ I0 o  w( |* ^; v2 K0 O

  42. 2 Q2 y* J. D\" D5 j; y! {
  43.         # 4) 指出那些导致损失的参数(损失回传)
    ( h; a6 S+ @1 b* b- g
  44. 6 L9 m\" U) l5 e% t! V
  45.         loss.backward()
    + C# a' y; D8 o' x! ^( s

  46. 7 x\" V5 n+ y  {
  47.     for name, param in model.named_parameters():
    6 a- [/ h8 M& O' R

  48. 9 q( `9 N1 X6 k8 t
  49.             print(name, param.data, param.grad)
    ' G) D/ [( ]; ^9 A  P5 x1 ]

  50. \" R% m) Q6 q8 }/ }0 O: s: l2 }\" l
  51.         # 5) 更新参数
    4 P9 f: E- @\" l
  52. + j6 C8 i! {6 G7 V& \
  53.         opt.step()
    3 i% ~& k! K) l9 ~

  54. 7 ~9 T* m# E# {8 l4 W
  55. 4 s2 i1 ~- e( Z: S
  56. - V$ d* o6 B. c1 `( H4 Y) Q; `
  57.         # 6) 打印进程% a+ B' h! ~0 Y0 Q
  58. 2 Z# x4 }5 U' V7 K* _1 m7 B
  59.         print(loss.data)2 ~( O2 @) k/ _$ c; R) V9 R2 `

  60. ( ]1 ?- ]. V. u  Z, u7 N
  61. 7 f: ?! g. [: b, z) x% \
  62. & k3 J\" I( k& U0 f4 I
  63. if __name__ == "__main__":
    5 _1 d  H( t( F- l

  64. 7 e: s# F\" X  t. m9 z& K4 I* Y
  65.     train()
    9 r0 h# P% N; \4 P, d
  66. 4 I' l& Z( \2 {) p
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
3 K6 ~# Z* c2 m- P; E& i
* |4 U2 P% }; i* [+ B计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])) G6 v$ [& E% T3 K3 k0 V9 j
  2. 7 w! I, `: u# i- s* H( \
  3. bias tensor([-0.2108]) tensor([-2.6971])# Z7 m2 ]3 C\" a, s$ r

  4. - G- S# ?  n, B+ [0 L$ i- I
  5. tensor(0.8531)+ r6 j: T\" n9 j& ]' D4 q0 ?
  6. , ~5 ?) ^/ P4 y+ F. x
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    2 c  o% f* u) ]- [. T+ t
  8. . b) n3 @8 g) q) G  N$ B1 d, {
  9. bias tensor([0.0589]) tensor([0.7416])+ u+ q9 I* E( d0 {! N5 q
  10. 7 d/ p! k, f& v$ g
  11. tensor(0.2712)
    0 V5 K- N: f  y# L2 o* n# a1 s/ ?, P
  12. # P: N1 p; M) s1 Z# H6 i) i4 y
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])9 O2 s$ r4 D1 _2 f  H
  14. 6 C: u6 C+ V. P0 A6 s6 e* e
  15. bias tensor([-0.0152]) tensor([-0.2023])
    4 C. @5 c$ D, u& N0 F' t

  16. + ]( u& `( M- U5 Z# J& {' T
  17. tensor(0.1529)
    8 r* Z\" j, u  ~

  18. 2 n9 m& O7 s4 S7 |( o
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    1 z% a. Y6 {5 Y* a* l/ z+ @$ X

  20. $ E- }3 a, E4 H\" @$ H; f
  21. bias tensor([0.0050]) tensor([0.0566])4 _9 o. n4 I! x$ m
  22. : j1 m8 c\" @. C8 Z
  23. tensor(0.0963)
    6 ^3 f0 |/ C% D% O' D1 y7 L! b* k

  24. , H; |5 t5 K+ ^$ D! L9 y# ~( E
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]]); E\" n\" X( u+ ~% I! u  i; c, ^% V

  26. $ ~, D9 o3 V% }* B: B/ w
  27. bias tensor([-0.0006]) tensor([-0.0146])
    * L$ C1 `( _6 `& d9 t/ n/ v! K6 z

  28. \" `  Z+ g& q* X' l
  29. tensor(0.0615), H# n+ e! {: x! K0 z/ N+ b5 Y

  30. 5 h5 R3 f& D\" l/ Q( u. `
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    ; ]/ m: p+ \' Q\" R. ?
  32. ) K- M3 M. S  P& J4 x1 h
  33. bias tensor([0.0008]) tensor([0.0048])& r+ P; O: Y# N4 D4 ^3 ]9 M9 q8 p4 C

  34. + T5 \8 R5 @+ [
  35. tensor(0.0394)
    ' l' j7 k5 C) s  j% Q; U  H6 P( t
  36. . Z8 j  x& k) x  d: |- q
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
    $ f, `6 {+ b6 z& c' k; z  `& P
  38. - c% M. }5 Z5 w/ x! {0 n
  39. bias tensor([0.0003]) tensor([-0.0006])
    + F8 ]0 b$ M7 \: ~6 ?2 _7 s\" ?

  40. 4 n4 t  a; G& o3 j+ Z2 g
  41. tensor(0.0252)$ A7 ]: j, ^; s  `# K3 s

  42. 0 T# {/ o& _8 o. F6 J# s7 W+ w
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])1 z, }\" h8 @( Q' c# R+ P( Z: y
  44. ; x\" Q8 ]2 `& w% o  p1 @8 ]
  45. bias tensor([0.0004]) tensor([0.0008]), z, c+ X5 |& \

  46.   R+ M, ]  _) F5 {
  47. tensor(0.0161)
    $ }8 b$ ^, a4 P) ^

  48. ( }. s) c1 o' Y9 Z! v
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])& A1 j( t; {3 H7 h

  50. 0 e& |5 j6 |# @8 ?5 X
  51. bias tensor([0.0003]) tensor([0.0003])
    $ h* @9 Q6 R; r
  52. . b1 ~0 D% T  o# ]
  53. tensor(0.0103)/ r# l2 D' a$ V, [& g% _7 }$ x
  54. & f9 _6 [. ?5 c0 f\" h( }
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])* q1 e5 V, ^( g) C# ^0 H1 n
  56. 8 R* Y3 h* j* i1 m1 o6 s- u\" Q
  57. bias tensor([0.0003]) tensor([0.0004])
    4 R9 F& a+ |1 T# T

  58. 8 j\" ^1 J/ ?5 c; S2 E8 T6 R( V- q
  59. tensor(0.0066)
    , [) O3 r6 k6 [5 j* Z
  60. - G! x6 T8 |/ M1 S, {+ F8 ~5 w
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
    ) S) g- [$ m) e& c! Q

  62. ) p5 ~* p$ f% T3 ]; m5 i$ @
  63. bias tensor([0.0003]) tensor([0.0003])
    # ?& s8 g0 j5 v

  64. # h9 y! R! N. E% T' f0 I\" ^
  65. tensor(0.0042)
    5 E\" @! E6 ~$ }/ u* n3 L

  66. 1 l  s% @1 |, {/ l, m% _8 D2 p; z; Y
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    / v$ k7 e  O. Y  F4 H

  68. % _! w4 e# m2 k# H
  69. bias tensor([0.0002]) tensor([0.0003])
    4 K& X1 L$ b$ h
  70.   Y2 F5 y# }( ?, k% y
  71. tensor(0.0027)
      Z  x: G' L! h7 u$ S

  72. & {8 R2 N' d! L& }  j( Q9 h
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    : j  X/ \$ n* W0 C
  74. 4 s6 l+ M$ P  Z5 b8 D
  75. bias tensor([0.0002]) tensor([0.0002]): E( x- m# f. C& G6 @! _7 c- X

  76. + O* j/ Z+ e5 D! _
  77. tensor(0.0017)  z2 K) H! k0 C$ l+ T
  78. , O  j# A- d0 K7 N
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])! w/ U3 P\" }' p2 N: `- u# k1 w

  80.   Y$ L5 w0 w/ m
  81. bias tensor([0.0002]) tensor([0.0002])8 N  T8 L2 a! K/ Z3 S/ K& d% Q5 P

  82. 6 \; c' t$ W' C8 r0 M9 ~. ~
  83. tensor(0.0011)
    ) ^) |% Q7 c# f; K8 O0 J0 W

  84. 5 M+ T# h; k2 n, l, b
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    ) I! y$ |. E4 N+ K* T. f+ c
  86. / D9 N4 ^9 X) g2 L
  87. bias tensor([0.0001]) tensor([0.0002])
    & q6 r6 r9 ?( f+ K8 d
  88. 9 f+ _0 ]& j8 {7 \5 b
  89. tensor(0.0007)4 W, n/ i- D) y% Q( w% @2 L  e

  90. 5 U. X) G1 m$ G2 Z$ ~
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])5 E. F5 `* U; m, n\" |
  92. \" {: _5 v4 X  a- |; P; @
  93. bias tensor([0.0001]) tensor([0.0002])* I; s, [! H  ?\" i. x  w' e

  94. / }( a2 d$ E3 k. k% H+ M1 ^
  95. tensor(0.0005)
    6 J' ]$ b, a+ O2 p\" @

  96. 8 P5 c* K  q/ c; P+ G, ~$ f
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    / q. V9 |$ @1 \; H, v( [2 r& r( C

  98. ! j6 g8 D6 I\" C
  99. bias tensor([0.0001]) tensor([0.0001])
    * Y% {# T$ Z' Z2 G$ E* t/ C
  100. ; j' Y+ l1 m3 R2 f$ ?
  101. tensor(0.0003)
    $ K0 V3 f4 @/ _2 ^. n- H$ ]

  102. + S5 _8 w2 p! Z( s2 j
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])4 H' z( p% r' W- N! C/ V
  104. , t3 L( M/ s2 w3 X
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    3 O$ U0 a  g' S! _, ?! m8 x/ f2 z
  106. 6 D8 T9 B5 u$ i) O
  107. tensor(0.0002)
    ! q8 c4 U+ H) O/ N
  108. 5 j% a5 P. B$ G\" L8 D& ^# ?
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])
    ' U  R) _( m% u1 w) s3 M+ V# `

  110. 3 Q- G( d% _4 X( k* p
  111. bias tensor([8.5674e-05]) tensor([0.0001])5 J9 A) c# ~5 O5 O

  112. 4 l% }0 J* S% ~/ z. |( Z
  113. tensor(0.0001)
    3 X# x  i4 c% o6 [3 Z4 K
  114. 5 u) C4 {4 \# x  K  y9 t& X
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    6 c) [  j' \3 [3 r) W8 w: w
  116.   b0 i, v$ L1 W- o& Y2 B( c
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    / d, @: `6 x# j) z. o9 j

  118. ' [9 p* b8 p0 U7 n' l  |' @
  119. tensor(7.6120e-05)
复制代码
0 G; U3 Y) Z/ R( n9 q& g0 z
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-4-13 02:22 , Processed in 0.429470 second(s), 51 queries .

回顶部