QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2668|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1189

主题

4

听众

2934

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么! j1 q) r$ V6 ~3 T1 }, v4 t
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。$ \  F4 t$ [/ k$ Q" v
怎么理解梯度?2 I* C5 ~4 |; Z. ~6 \
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
5 P8 w  p- L: p# ?7 [/ a2 }' {7 N6 ^: ?
; u1 Z* X, _6 j/ J6 N在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。. l2 H8 L/ N$ |5 M5 Q% F3 a/ G

  O1 H, y6 X/ I! X9 b0 d每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
; c7 y  U' o7 Y+ y' K$ s' }/ x, o0 {% p
为什么引入SGD
  T9 k6 |, S$ P, B3 P" }- }深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
3 e- E) W' Z8 Q  _! }* e* r  I  R  Q/ F
怎么用SGD
  1. import torch
    8 ]1 c* F5 L% a
  2. 2 G3 U8 Q) \. y\" A2 X! j( B8 [
  3. from torch import nn
    : x! H% w, ]/ X2 X4 V: n

  4. : v; W! ]' O+ ]* O
  5. from torch import optim4 t- q' w5 W) x2 b3 k
  6. 2 L\" @- W% M+ H) [

  7. - ^* r5 b5 S1 t$ @9 J\" W
  8. 6 v; v8 k% O  M
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    % q: {$ [* o/ n
  10. % Z& X- n8 s2 S. Z4 L$ o
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)% Y* ?7 J7 I5 C5 d1 |% ?& e

  12. 9 i\" T8 \/ ^  i2 O4 @
  13. - m' _1 e* A1 h\" x4 Y( T

  14.   \/ U$ {* j! o% n  h5 H
  15. model = nn.Linear(2, 1)+ A* o/ X# [$ c9 [! t. D8 `
  16. \" b6 n  V1 q8 W( i$ d7 J

  17. 3 E( M2 s. U8 d& H, s& Q
  18. 8 e! K) R5 F9 J! V
  19. def train():. H% P1 F9 v# ~0 k7 \- J

  20. 0 u8 M7 R9 ?# ~& H% S. i/ G. c4 Q
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)& R' y. L  H$ k& R

  22. 9 ^8 W5 D6 f3 J' `0 Q
  23.     for iter in range(20):
    ' P+ i, R$ d+ t: D( ]1 v, W# p( H& d
  24. * x, ~7 ]; z) ]( @8 u. L# w
  25.         # 1) 消除之前的梯度(如果存在)' _4 \9 Y' [: T* p) R: Y( |
  26. - q/ p8 D0 O/ |' {4 a' d\" r+ \7 x
  27.         opt.zero_grad()1 G8 a7 h( o) Q8 |

  28. / w\" U2 m$ N\" u+ Q

  29. * _. w; s, ?- K6 j
  30. ( Y! {* B9 d& b/ O( a7 |5 a
  31.         # 2) 预测
    $ C\" \\" X$ j( M! m4 G$ [

  32. 2 k# r* R! S5 O
  33.         pred = model(data)( m( S3 n* Q! I\" x- o( L2 g
  34. : H$ d! {8 U/ u' ^# n8 w  V; E, z
  35. 3 [6 ?9 l  a( f) }! S
  36. . @+ F9 s) \4 o8 X/ h1 u
  37.         # 3) 计算损失; N+ T1 O! o\" c$ I2 }5 u  g8 k
  38. 8 {7 z' |\" E0 i# w( Z
  39.         loss = ((pred - target)**2).sum()
    $ _' C0 C1 r+ y8 Y# w3 D! _
  40. % P& E- E\" T, d4 c( ~- n

  41. 3 c  g  A# T$ j7 x% E) X* ?

  42. 0 t. k1 Q& M6 N& k
  43.         # 4) 指出那些导致损失的参数(损失回传)
    0 x7 a( U2 D7 z/ Z) N+ ^1 ?

  44. # u' r) U+ L# i
  45.         loss.backward()! ^2 V! M- C% Y3 m, h\" t
  46. % H% [5 x; `# N* M
  47.     for name, param in model.named_parameters():
    ( r3 x7 e% \( b  c  e; D3 z% z9 f\" B
  48. # |% k/ @\" r: q1 t5 ~* ?% v* p/ }
  49.             print(name, param.data, param.grad)) F1 o\" Q1 \7 t  ~
  50. - D3 Q# H4 A3 t! V$ \/ y/ c
  51.         # 5) 更新参数3 P0 N1 R$ m6 y& V, t

  52. ! K3 S. f. u\" N
  53.         opt.step()8 @3 \9 L* _7 x

  54. 4 J  n7 f5 C+ f8 T

  55. % c; n$ N. h\" h9 E9 D' q$ o) P

  56. 7 E% V$ b/ P, D1 ]& d
  57.         # 6) 打印进程
    8 {\" |4 \( Y5 F+ @0 v
  58. . Y( N! Z; \$ ^: F! ?
  59.         print(loss.data)
    1 V# y- S, s) v\" z! L( X5 S! B; j; s

  60. 9 M: `& L\" l% i
  61. 9 }7 z# Z: T4 Q8 `

  62. $ C% U$ m, g% Z
  63. if __name__ == "__main__":  Z( s  E1 }# l' L
  64. ( \* Z- y% N- A! k3 Q9 p
  65.     train()' {- D\" o; _\" {6 y

  66. 3 s  b- P  o, b7 S& |/ D+ B
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
) O5 F6 W) }( u* K7 J8 ]4 J% [1 W6 C) P. e
计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])) C: m: M- L6 |
  2. ) J3 p7 y, w9 Q5 u4 d
  3. bias tensor([-0.2108]) tensor([-2.6971])7 u6 ^; p7 m( T3 O( K- p

  4. 7 U* E: D# {/ s+ y& b2 N. X! E1 d8 d, _
  5. tensor(0.8531)- C3 [' h$ Q% d7 \
  6. . `5 [; p/ D8 F. ~0 X- x
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])% n1 x; L- C; z- ]9 x( m1 `
  8. $ g5 I7 I7 n$ _% v! s, \' @+ O; U$ m
  9. bias tensor([0.0589]) tensor([0.7416])/ B* g, j8 N: x3 d! s# T

  10. * {2 F' o6 d% N$ A
  11. tensor(0.2712)2 N\" [6 P3 k\" F2 k& K& s/ x: A) q

  12. 3 N3 H  N' |! u; D& h; z# \5 D
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])3 Q3 y) ~. w\" G) y( z9 Y4 m9 ?- ?9 G

  14. ; }- e7 O3 r: j9 }: q* z
  15. bias tensor([-0.0152]) tensor([-0.2023])
    ) F; J1 k1 |- U: k0 F8 ^
  16. 5 D/ C, q  g7 u3 l; x/ X
  17. tensor(0.1529). b1 C# e/ v5 G8 k\" K! w

  18. ) N9 B5 o% w# z2 L( e8 Q\" b2 Y
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    0 D2 g/ `6 o& ]; E: ^8 ?
  20. - U( T7 P9 G2 g4 v8 u1 {4 `( X9 z
  21. bias tensor([0.0050]) tensor([0.0566])
    & _4 r0 v% h- X) Z1 I
  22. : ~7 {7 c7 J' p6 u' |
  23. tensor(0.0963)  G* V7 m3 D/ `& u3 b1 y1 R- u6 ]
  24. , J8 H( L# n; \6 L- W0 W
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    ! V' M( J- U, ?% {5 {5 s- F. C# V( W: e
  26. 3 u6 I+ b' S0 b3 K
  27. bias tensor([-0.0006]) tensor([-0.0146])7 m# z5 s4 M: t7 z\" C

  28. 4 J5 F9 R, O6 F% }! P% P
  29. tensor(0.0615)3 e6 v9 F, R6 k/ {* l! R; A; i
  30. 4 R# I3 L: N  E; R
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])! w\" B, s9 p& r% g% C
  32. . n, S8 L) z\" V0 y
  33. bias tensor([0.0008]) tensor([0.0048])
    7 ^$ e0 l# {) `

  34. 4 u4 s! j+ G: M, }8 O7 o
  35. tensor(0.0394)! ~  E' z4 w6 b\" E

  36. 5 `; {( W* e/ ]) L\" y; f3 w( G
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])% M! d: v  T; m( s$ T9 e

  38. 1 ^0 R, B( Y7 k* L& v9 C
  39. bias tensor([0.0003]) tensor([-0.0006])
    ! D) i$ o* S# y; `5 a; m/ W5 ~
  40. . Y/ T6 ?! V( d# g
  41. tensor(0.0252)
    9 e\" W: U  o5 m$ Y/ J' f4 D3 w
  42. 9 P- b; O1 N2 p, T
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    9 |' D: ^( G8 o3 _4 o0 ~& D' e

  44. 1 Y2 O: t; P& R1 \- P; [/ a3 r9 M
  45. bias tensor([0.0004]) tensor([0.0008])
    ; q' M, [* Z3 w! Y( g

  46. 4 p! U' ]4 W5 I7 b* s
  47. tensor(0.0161); a, T0 f6 m9 _$ L
  48. - K! U% I. ]+ G' k! F9 m
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    - F' }5 B% f. g5 x2 Q

  50. & u; a* Y1 K+ s( M8 ~5 N! O4 w
  51. bias tensor([0.0003]) tensor([0.0003]): B  |5 u0 D& Y0 f' S7 `
  52. 3 Z' j3 M/ K. h: q. V6 p
  53. tensor(0.0103)& G& o: m3 F' Q' c& t) f

  54. 1 m9 a# ~* F5 z$ i. ~
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    + @/ l7 q! p) v. C, \* S

  56. 8 ~& [, ]/ M2 K+ R1 X- b3 h
  57. bias tensor([0.0003]) tensor([0.0004])
    / M- @  a% U5 u/ f0 g
  58. 9 E8 t& D; f% s9 g. _* h* H! E! k
  59. tensor(0.0066)
    8 \) C( U0 y) ?& L6 O, m& m/ r6 n0 C

  60. 4 R/ N* _6 |  r1 ~) _$ Q: v
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])
      a; r( n5 f; C: {/ I

  62. , k1 z5 D2 R% m2 P( q2 e9 V
  63. bias tensor([0.0003]) tensor([0.0003])0 P& Q3 U6 N# R0 M0 y$ u$ r
  64. + V! \! Y9 ~$ \3 e$ p\" O% P
  65. tensor(0.0042)
    + A3 ?6 G/ e5 D9 _

  66. % c$ [( b5 S4 f/ E
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    2 j( H! z( C+ A1 w
  68. \" s! Z2 t3 [8 a1 U+ w( G
  69. bias tensor([0.0002]) tensor([0.0003])' L) O& m$ f3 J  X6 i  ], m' K\" |

  70. * a) D\" c1 ~; y% e) f
  71. tensor(0.0027)
      G/ U4 U6 P3 D3 O) W$ Z

  72. 8 w3 C/ r1 H3 Z7 A  s$ h! @5 u- W
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])3 b3 _3 K( X) ]+ o7 x- k: G# P

  74. / ~* B  B/ ~% J6 m) d\" x, a2 _
  75. bias tensor([0.0002]) tensor([0.0002]). b, T( W$ R. r5 i0 K. S
  76. 4 V+ |( p- _9 |5 G9 p3 X9 l% v1 L
  77. tensor(0.0017)
    5 v0 [$ S+ u6 i0 `1 m0 g) e% ~7 D

  78. # ^* [\" V% [+ M% K+ D, }7 C
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])7 G\" _3 j, }, [
  80. 1 z( [, ^6 d! @2 b2 c4 g
  81. bias tensor([0.0002]) tensor([0.0002])
    9 z4 c+ }: L! m! P! n: H. n3 t9 x5 j
  82. + H. U9 F% u; U. D0 ^; S  Z
  83. tensor(0.0011)
    % @9 o' M9 h; c7 ~
  84. - g! u  s  m9 f+ _% u! d
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    3 G) T2 `, J; R2 c' L5 U9 a
  86. 3 I, t$ d\" T; ?) f' Q% ^, D
  87. bias tensor([0.0001]) tensor([0.0002])
    : Q3 Q  I1 V1 ^9 J9 R7 v  [; l\" U

  88. 1 J( S* x. v; }3 H
  89. tensor(0.0007)8 ~) {5 ]0 F) d7 J1 s

  90. 0 B0 U7 {$ ^$ f; z3 y% v9 f/ m
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])6 Q1 ?* M+ f( p, m

  92. * h. m) u. [& b9 X
  93. bias tensor([0.0001]) tensor([0.0002])
    * D8 m, O, g- U% L; W
  94. \" |4 V$ K( C, X! R
  95. tensor(0.0005)
      f1 e% O7 |: e- L3 d6 O# [

  96. $ y\" q  I, v/ u% O
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])4 t1 M! }. R' @/ b  c
  98.   X; [  f% z0 u3 Q6 K! [! t1 i
  99. bias tensor([0.0001]) tensor([0.0001])
      Y) `0 j: M- z2 @+ h1 P! M

  100. 6 Q5 {% D# m6 ^$ s4 X. z4 Q6 A
  101. tensor(0.0003)8 A5 R- r) J\" K: X6 ^6 J

  102. ; v4 v$ [7 G2 a/ q& E$ x
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])
    9 i6 N9 v' |. M6 }6 [7 C
  104. ' k6 a9 @( \/ {3 H1 @% R
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    0 M- ~7 F* \  d# G\" L1 `7 N; N
  106. % b1 O/ x1 L, w7 l
  107. tensor(0.0002), _4 h6 \! E, X
  108. 1 n6 g2 q% H5 ^7 G; n2 U
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]]): z! x1 N2 l& V

  110. 2 @; C2 c8 l7 q
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    ' q* S- [1 C. ]; L+ E\" v

  112. 8 D8 s% y% K1 W* `0 Y' R+ o0 E+ _
  113. tensor(0.0001)8 x# F4 a! a8 i  B) |\" \7 u
  114. ( A$ _7 i/ h& _  w( U0 |1 t
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])' O) B5 A  |9 O# B. o9 s5 |
  116. , D4 T: m- ~  j% B8 l7 K& R( ]. Z
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])5 ]) {8 K/ k6 s% F$ S
  118. 9 y! t6 d( ~2 e\" ~, ^3 X' ]1 _/ Z
  119. tensor(7.6120e-05)
复制代码
  E7 e" X8 C7 m; m5 K! o
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2026-6-7 14:26 , Processed in 0.430201 second(s), 51 queries .

回顶部