QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 2541|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1184

主题

4

听众

2916

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么
  d  |: H4 E- R# [# e) v2 }SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。3 O8 i- w& B  _* g; }- w
怎么理解梯度?
& r; l+ }' u  N# e- }+ B假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
0 _' o6 N0 O+ }0 t" e1 t
6 S6 `6 u0 n* H& D2 ~/ V4 d在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。7 {4 ~2 x& W+ ?7 B: v

/ G1 d% M: D$ y9 ~( x  p- p9 t每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。) X& |7 b3 Z' _% i4 T6 S
( `$ ^) l$ z7 N9 Z8 w5 ~" l( V/ J+ q, |
为什么引入SGD
7 F- j' x0 M! |% {深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。1 a  N+ {- p( J7 g4 n7 J6 i' C( u

, X& m' p$ D, V" E怎么用SGD
  1. import torch
    ' N, z9 s7 ]2 u8 O' n5 _+ R
  2. ! U4 Z% g) V% V# w/ v
  3. from torch import nn
    7 e0 m\" n. p. S% h/ U9 Y\" O
  4. 6 G! p- l8 O. z\" ^0 m% G/ `- |
  5. from torch import optim' v$ W; M+ G2 [- N: x

  6. * I. O* j$ ~9 O

  7. - Y& Y  k% V& U

  8. ) G- W% W: H2 g
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
      j9 ~8 I# G9 t5 d; @/ D+ G2 M5 W

  10. / d7 R: ^/ k& m2 U; Z) F% B0 @
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)3 E5 J/ Z6 D& K6 d* L7 q; A

  12. % S  @3 `, x' u

  13. $ T+ @' ~+ q2 ]. p1 d! ]; U

  14. + \0 ~# c9 x. `; S
  15. model = nn.Linear(2, 1)
    + I) g4 q: t) A5 n) Q' B
  16. & g7 `8 p! J% `- R, ?/ S  @
  17.   u9 n! W% h2 G7 e\" n$ w

  18.   c! A: e/ f- b. s  R' o
  19. def train():% X6 _  a/ K1 m- D7 G$ G
  20. 7 F! O! W- j& l7 y\" ~9 J
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    8 f/ d/ p( f6 @+ ^9 d

  22. . p! p* |% W/ ~3 ?2 Z4 q3 f5 T/ m3 J
  23.     for iter in range(20):9 E3 e( m. x5 H. F' S$ x1 V) n1 x
  24. : e+ r4 l! p9 m2 m6 ^3 ~
  25.         # 1) 消除之前的梯度(如果存在)
    : c% j0 v; p& a7 }

  26. . j0 x: K, B* ]- E
  27.         opt.zero_grad()
      u  ^+ V( h+ i  |: i

  28. 3 [1 \1 {1 A& _( h

  29. 0 U  X% m3 X; y8 i2 [

  30. . s: F3 g: {2 E2 B8 s
  31.         # 2) 预测
    9 J' H5 p9 ~/ j8 n- G9 _+ S

  32. * j& t, Y' P$ l: I
  33.         pred = model(data): ^! r  C# _# \# e+ D

  34. , f; Z5 J) o% g0 q3 U( L

  35. / h; p( f0 d! p! ^. R

  36. & r8 ^) C$ h: L2 I% y9 f
  37.         # 3) 计算损失$ D0 X5 R! r/ U/ i* k1 D' J! J
  38. 6 B3 B& k\" [( N$ @
  39.         loss = ((pred - target)**2).sum()
    - h; E4 ^0 g2 S( y' o9 E6 q# u
  40. + Y5 p6 C) M. d# S* }& ^$ O, t
  41. 2 |9 Q! S5 ?) }4 E& v) x0 Y
  42. 1 k* S: I8 o& U; h6 W, g
  43.         # 4) 指出那些导致损失的参数(损失回传)) b7 c7 a\" D; J, T\" r6 W! R
  44. % [' N7 k6 T/ R  p, m
  45.         loss.backward()
    $ O' W\" X! `% t% b# C6 Y. B. x
  46. $ o$ T; g0 F7 o. i# _- G
  47.     for name, param in model.named_parameters():
    - s* h9 m5 O7 f3 l2 ~
  48. 1 |\" p7 J: L$ V# T% g
  49.             print(name, param.data, param.grad)2 e, i3 d+ o7 Y  d- ~

  50. ' O* g7 U$ u, c/ A
  51.         # 5) 更新参数
    & }3 ]2 ~8 L! J* h& a4 }  N: ^' }
  52. # P' @1 }0 x5 z  r
  53.         opt.step()
    $ l1 a/ f1 O) y# P8 k8 G3 w
  54. - A+ e; |/ z/ c  k! Y* |

  55. ; @4 L, t) Y6 ]0 V

  56. 9 \: I( D$ t# x7 O9 T
  57.         # 6) 打印进程( _8 Y; \3 Z* b
  58. $ ^  `: V  N' V1 n\" z
  59.         print(loss.data)
    \" V0 \6 f$ ]! ^. I4 X
  60. 7 ?: L0 g7 z5 h! a, ?* J
  61. 0 v2 ]3 M2 D/ ]8 Z1 y; V

  62. 2 a) j3 w' l9 r. l
  63. if __name__ == "__main__":
    ) v$ w9 ?9 r6 D( x% I

  64. \" s% I! q+ g9 t& y\" O
  65.     train()
    & C' r  B' ?* @* d

  66. 9 Y% {2 Q, C2 @& x2 L\" q7 K
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
) P1 U7 O  y5 Y( u( U  q# O2 I) S
9 s6 V6 t% Y2 }$ {计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])1 H: Q% [, ?; N2 d: A- A6 N+ Y; l7 [

  2. * x1 `1 H$ P7 i* \$ x2 ]/ H) N
  3. bias tensor([-0.2108]) tensor([-2.6971])
    3 k$ O( y  m  X

  4. ! g6 i! `+ q3 j8 \( z# d
  5. tensor(0.8531)
    . ]1 x8 C8 t\" W\" }5 i
  6. . C% l. ?! B) `) V) m, x
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    , g* \9 b  X' W
  8. ' v+ \# s) H9 M: l4 F
  9. bias tensor([0.0589]) tensor([0.7416]): l+ S5 @  U: Y% M, S. e9 n

  10. , {3 n  i; m% v$ _
  11. tensor(0.2712)
    7 L\" c. X# d/ K0 ^. E1 ^$ g
  12.   ~2 i+ B4 p2 J; U
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    4 R1 j- ~# r# `9 ~, P; ^  m+ X
  14. % M2 C! x7 @/ s8 B+ D8 z  m\" G
  15. bias tensor([-0.0152]) tensor([-0.2023])2 f& G4 G* X, A7 n. l  `' C& E; H

  16. ) C+ e# v' w! F. O
  17. tensor(0.1529)2 k* V% M* R8 N# `* X/ c- m9 k' _
  18. 5 ?9 f( c; T* d- ^' W: n4 R+ w2 @
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])3 M( H( E. Z, N, R( `
  20. : G9 q) [1 g3 [) m* R
  21. bias tensor([0.0050]) tensor([0.0566])
    ) s. l3 M. b/ T/ k
  22. ) D4 U* z* W0 K
  23. tensor(0.0963)) b' a4 m# ?. q, V! Y4 R8 G1 d

  24. 8 N- T3 [  P; A) v/ G& K\" T
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    9 T  K# g\" S\" V0 n8 M
  26. - q$ z' e& |  w
  27. bias tensor([-0.0006]) tensor([-0.0146])
      ~6 y, H4 L2 y: j# m& o* Y
  28. & G0 p2 ~  ~4 S6 @
  29. tensor(0.0615)
    / _# t  M\" E6 [# i1 a( F

  30. - k, y; @) Q+ A\" B
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])- N! g  x2 m0 I% Z3 t! C1 I
  32. / \( V. a6 @/ E3 I) Y5 d1 K. I
  33. bias tensor([0.0008]) tensor([0.0048])
    * K7 `4 Y) Y: m
  34. ' c+ c0 c( x% X/ R- g
  35. tensor(0.0394)( T8 g7 Z. G6 B' g/ m1 d
  36. + p, n. O- [: g2 p; y
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])
      l8 B2 k) u4 p  S
  38. 7 ~2 }\" g0 ^6 ~
  39. bias tensor([0.0003]) tensor([-0.0006])
    3 ~& z1 n1 w( n; K\" `) T

  40. 2 q% q- x) O* K5 H
  41. tensor(0.0252)7 |1 D, {$ u2 c! U+ ], W# ?8 r) h
  42. 1 V3 Y; K/ [, Y1 m0 o, D
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]]). {' ^9 \% A1 K) Z/ w

  44. ; x& [1 N* ?' n9 _& w% o  {# y
  45. bias tensor([0.0004]) tensor([0.0008])$ n' S  D6 f8 c2 }7 A0 u6 i

  46. \" O7 b; s1 N; J\" f
  47. tensor(0.0161); @1 o! A) n) J3 N

  48. & x+ Z4 j) i! f8 d: g: ~3 |
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])
    $ P0 a  T) l) h+ L2 e% X$ V  C
  50. . @( ]7 w1 |; @$ Y: m- B- h
  51. bias tensor([0.0003]) tensor([0.0003]). y6 B( z+ h5 x# G' w2 H4 y2 q

  52. 0 H3 \8 X0 T; U: s7 L\" M* S5 W5 {
  53. tensor(0.0103)( i. m% E* z2 W

  54. / d  O\" j& v/ ^5 \) X* h+ q
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])6 z; c& c4 G5 a; s( \

  56. + a1 d# v2 l+ h$ h0 v' U\" U
  57. bias tensor([0.0003]) tensor([0.0004])
    % ]$ l& {2 P\" u& \, w) s- h8 g

  58. . O4 z' R% n+ d# {
  59. tensor(0.0066)% [8 c- Q  L6 K( o- t0 V) A$ g9 b
  60. ! i0 w4 h, @# g: C\" z. j
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])+ A9 E, \3 ~3 N, j- B/ r\" e$ o

  62. 8 R; U  w0 m4 o; \; X' }2 Z
  63. bias tensor([0.0003]) tensor([0.0003])\" v  @% M/ a  Y! H* }8 o

  64. 3 X$ T0 N2 U& k\" G  I  g' U# y
  65. tensor(0.0042)
    1 F; Q/ q( w+ J  h$ w6 p

  66. 9 z' n3 U+ U% S5 M
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    : F* W5 G* H. {

  68. # r  v4 C' M9 s9 n8 _2 Q% r% S5 s6 L
  69. bias tensor([0.0002]) tensor([0.0003])
    & n! J9 S2 I+ S7 m& L- O% Z

  70. & _) m6 z  v- x5 \6 o( m5 z
  71. tensor(0.0027)
    4 j\" n/ Y/ X5 G9 u4 Z4 _; W# h
  72.   u9 E+ g' B% C# K* v8 u' |
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]]): {* b$ F6 t! S5 ~\" v/ ?4 Z; z
  74. 3 s- q+ l- ?% ]
  75. bias tensor([0.0002]) tensor([0.0002])
    7 g- U* {1 Y% f: n8 V5 _- ^

  76. & Z$ H; I\" X# s* z
  77. tensor(0.0017)2 f8 G6 t7 W$ r) B: d# \

  78. & g$ h( }1 e& G; a\" a. H3 i
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    + Y. a7 T* ~, S2 V* }
  80. - t$ u4 }7 O8 I  l6 {- x/ J
  81. bias tensor([0.0002]) tensor([0.0002])/ L) @6 Y  W- Q# [
  82. / x2 D' I+ |* H& F
  83. tensor(0.0011)
    : s  X\" P% Y' g: b' d! n, M
  84. 2 j/ p/ P: I3 n, l5 b5 q6 U: D
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    0 A0 k7 W% Y2 Z6 `& K! _! C
  86. 7 d\" d\" m9 J3 f2 }* G+ j* w
  87. bias tensor([0.0001]) tensor([0.0002])2 z  t/ n- K5 C9 I
  88. ' S, |' d, D- s# B% ?. v1 D
  89. tensor(0.0007)! e* p% ~4 a; L3 q
  90. $ J  p8 H+ M; Y# c3 f
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    / ]0 k4 \. o9 X1 Y

  92. * I0 l/ S7 y3 P* l3 X# h
  93. bias tensor([0.0001]) tensor([0.0002]); }& z8 i4 A3 p' ~! @. B! n2 t

  94. + N+ D/ m4 r/ Q4 ]+ z
  95. tensor(0.0005)
    + f# X, g* E5 F% {
  96. ! K9 H; W, k$ S( N6 t
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])4 y4 G  O% U, v7 [: H

  98. ) S  r7 b$ l8 H( ~- k
  99. bias tensor([0.0001]) tensor([0.0001])
    2 i# j, N6 k' T2 t0 f& A0 c/ T

  100. , N1 M% K( }% Y$ U0 s% n! g
  101. tensor(0.0003)
    3 H. d9 L% ~$ }1 u2 w3 l
  102. 4 ?) \( S6 N( i4 g
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])\" N8 |$ i+ I9 D- _) w3 m$ N) K0 x

  104. . C5 b8 {5 d1 k* `
  105. bias tensor([9.7973e-05]) tensor([0.0001])3 @2 E2 }; i# O- X+ d1 w. b
  106. 0 b' F% s2 }7 f7 A9 I' D
  107. tensor(0.0002)
    7 F, u- x\" ^8 C* k9 \

  108. ' M: I9 j9 @' f4 r. T
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])/ s/ ?0 E# v; H* T& u* R  P0 T

  110. 4 n, B2 h( N9 S2 K2 ]' n& ?9 }8 f. G
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    8 i0 o& L% |6 O3 l7 b

  112. - C2 U( J1 ~# n- v4 @* F! d8 U- ~
  113. tensor(0.0001): D- m2 f- A+ {4 @
  114. - r0 g+ S8 u! L8 `  ]/ Z
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])' ?! h( s: W  A- P+ p7 Q* k

  116. : h$ Q# U$ s6 r) e
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])
    / e' T& [8 Y6 H

  118. 0 Y5 X$ B7 Z: d$ Q& @4 {
  119. tensor(7.6120e-05)
复制代码

9 E" ]& B' c; q+ m1 {; G
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-12-28 16:58 , Processed in 0.753640 second(s), 50 queries .

回顶部