数学建模社区-数学中国

标题: 随机梯度下降算法SGD(Stochastic gradient descent) [打印本页]

作者: 2744557306    时间: 2023-11-28 14:57
标题: 随机梯度下降算法SGD(Stochastic gradient descent)
SGD是什么  N4 V0 y" \0 |& E3 c; _* T% `% d
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。$ f0 G' o- {/ C' u  E( e; [2 I
怎么理解梯度?0 m6 b' L8 I/ c& B* B
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
8 i- t% V+ W# S  L6 P; o/ l6 A6 X- B$ A1 ]" g
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。, w9 s3 C+ [- X( h; @7 m+ u

. s: _3 H" _% V" e. z0 g每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。4 @+ m0 D$ ~4 k0 o# ^8 ~/ T

8 p6 f" d5 f. S) ^( ^1 K$ N% I为什么引入SGD
% w4 o; U. ~% F& u* L. [! c8 {深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。5 p0 S. |) o8 F+ W

6 B5 n3 E( [. D8 ~! D- {# Y+ W怎么用SGD
  1. import torch
    ! t( W+ f% _! l- `! H9 G" P

  2. 5 @# i  p3 m3 D" ]" z6 h- J& V5 M% W
  3. from torch import nn
    % s0 ~+ ]0 g$ o: C) n

  4. 8 Z  ~5 F& ]) [# q7 B; [6 \1 |
  5. from torch import optim
    . H; U5 V) z  z/ ?
  6. 5 T+ z8 e3 N& a. A, @
  7. 5 x) a+ d# ~5 A7 {4 n( }; y3 n
  8. 8 K% J+ j7 o6 M$ |: q$ X/ X
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    & F% y& w* l) x8 Z

  10. 9 D' T+ [- }) u& v
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True), y( \, w4 _, R) ?8 H

  12. / M, E1 ?: W, E
  13. , y2 N- ~! R7 \: G

  14. " x: }0 z% X. r- E# W6 B' o
  15. model = nn.Linear(2, 1); g) B: p; o. c
  16. 1 Y7 ~. T$ }; D
  17. 5 E& ?8 l3 J; P1 L8 M2 \

  18. $ R) J3 b3 p5 C5 ^8 Z: j8 J( h
  19. def train():- @! D. z6 x1 @/ E# j0 ?: k  n
  20. * s  T4 F$ a5 w# [' T# u
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    , n, T& [/ ?) b# I: \: w) ~% _

  22. ) R# R- ^  Z' J5 F" Y! S$ ^# b  `
  23.     for iter in range(20):
    , f2 e5 d8 a* Y
  24. , t" l% d0 W4 A# M2 p; B, i: j$ d  f
  25.         # 1) 消除之前的梯度(如果存在)
    5 Q5 |3 J: v! z3 I7 V
  26. 6 ^4 x5 F# I3 ~' A& ]
  27.         opt.zero_grad()
      ]$ ]3 A5 Q. U# J) |( I8 D2 {
  28. ! F7 ]  I! U; [8 j8 i2 V/ ^, y

  29. 1 n+ U8 k+ y: R; [6 b+ U

  30. 6 e$ N1 D" @5 P# m
  31.         # 2) 预测
    ; K  h* h) \( ~2 n# S/ b
  32. % e& C6 d( v; n" k& A
  33.         pred = model(data)
    5 _4 N4 Y1 S- v% Z! S

  34. : i2 d# n, \% L: c% x/ E

  35. 5 M7 x0 I+ y/ H% J, i
  36. 9 o% B; x6 o( T% Y- ?' ~* m1 f
  37.         # 3) 计算损失
    0 f1 L% ]! Q1 y* P: |/ D
  38. - e" k1 Z8 I7 T6 |3 o4 k# G$ h! ^2 n
  39.         loss = ((pred - target)**2).sum()# l8 K7 m$ p5 J( i% B: p

  40. 6 ]( j# Y  f0 |% G
  41. 7 `) T- C, t1 o

  42. % u7 O  l% A/ }
  43.         # 4) 指出那些导致损失的参数(损失回传)
    0 N# I7 U# G: c! L0 ?* n

  44. # r; v9 k3 d2 p3 s' ?
  45.         loss.backward()1 s  b* w! W/ I: }- J
  46. 0 ~. O! u8 t& }; u3 i2 F) k
  47.     for name, param in model.named_parameters():
    $ v$ h2 [% y* A) G/ s, _! E
  48. 7 n8 k, F2 I5 P, r+ @8 x! e6 x6 `
  49.             print(name, param.data, param.grad)7 {4 B+ Q6 w2 N! N3 V4 q
  50. & h/ {" y$ e# o1 X
  51.         # 5) 更新参数
    ! B8 c2 Q6 _& |# U* {8 O3 J7 T

  52. . x$ I  Z. R% I! J  ]0 M+ S, h
  53.         opt.step()
    ' A; P/ h* P  [$ h+ v1 u! h) `# u

  54. - c9 U' p, G( ~( I' H! e
  55. 0 w! O; V/ G( J

  56. - G# I$ L6 o1 A$ K7 w. f; s
  57.         # 6) 打印进程
    / l! p  j% ~7 X. F
  58. * N. E; `' N, q' G3 C3 B- S
  59.         print(loss.data)
    / A% o" `( C* t2 C' ?" L; W3 t
  60. # x- e+ S8 O( k, \: ?" R& ?* j

  61. 5 s  _) `) H: ^! L2 g: U8 V% P2 ?
  62. 8 {; m/ W/ {2 m1 U; }
  63. if __name__ == "__main__":) |; c$ s' Z. l* \
  64. 0 r/ l; q. O9 m' }; R4 {
  65.     train()
    " h' I1 `6 p7 S& b; J
  66. 0 z- g) \  ^& d- [5 `
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。) O; f5 K( A8 y/ \/ y1 G5 P

' k+ Z( k" s& b" e# n计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]]): m; V) @9 N& e( G* H
  2. 8 |0 J% v% S( P2 q
  3. bias tensor([-0.2108]) tensor([-2.6971])# k; e* {9 L, m0 d. [' d6 F. \1 A& |

  4. # S* t; ~9 M) I! m* K* E2 r
  5. tensor(0.8531)
    ) o( K$ t% _$ i5 }

  6. 4 {' P! |3 J" e3 y
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    9 h7 R5 f: h' q$ I( O2 v1 w

  8. 2 l& U6 L' ]# b
  9. bias tensor([0.0589]) tensor([0.7416])  ?' F) K8 O+ a& `- a9 y5 i) ?5 V) C! C  p
  10. 4 z0 Y5 _  O/ E" c
  11. tensor(0.2712)! R: Y! G# `: S6 _( p
  12. % S9 Z) e% f. r% K% h$ B
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    - }# ^& C. G' {3 H

  14. 7 i* B  }; }( G: ]) D+ F
  15. bias tensor([-0.0152]) tensor([-0.2023])
    ; T1 ?: S' U/ O/ o9 Z% C/ O

  16. , v+ w2 F2 c: f  J0 \# f# j
  17. tensor(0.1529)9 x( Z% T( i1 O9 |* t1 n

  18. , E6 s) Y3 e' v3 S# ^( E# w# p
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])
    / ]/ }; O8 H. l7 x; P2 L# }

  20. % j- [( r. Q: N" ~/ z& o, H, Z: Z
  21. bias tensor([0.0050]) tensor([0.0566])
    % f* J4 z2 @5 ]; [) o7 N

  22. 7 G; t4 B2 B( v" b; p8 A
  23. tensor(0.0963)! V% X( ^9 K( y% B

  24.   p( a- u, }0 G1 b
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])
    ; k: g9 z9 m" g6 E4 B! A

  26. , B# |2 Q4 @. I
  27. bias tensor([-0.0006]) tensor([-0.0146]), G  z' I2 Z# Y1 y9 T

  28. % f& T3 w5 [$ `1 P3 z1 a
  29. tensor(0.0615)
    / j. @" J3 a' A

  30. ' Y, _4 D& p" [- L
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])# m. n3 Q' ~1 j- \$ e
  32. & d$ x5 l2 G2 W( J8 c9 Z0 W" r. ~
  33. bias tensor([0.0008]) tensor([0.0048])! r" O% q" S) J8 Y# X; o3 k: I- l

  34. ) b$ J4 M, H3 A8 z1 R
  35. tensor(0.0394)
    8 t( Y. d# M  G6 G) t$ Z$ f

  36. 3 ]) g2 C; B& k1 x+ O" d
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])- G% J( c) e2 ~1 b

  38. ; l/ L- ]# w2 [4 k8 e$ X4 d2 A5 q& x
  39. bias tensor([0.0003]) tensor([-0.0006])
    - I4 J  \2 K# B2 q3 _/ s+ N3 {

  40. ) \  b, v. n1 h' ^+ b, u
  41. tensor(0.0252)0 y' V- B$ ?/ f- B- H4 o% c

  42. . U" \9 D( q9 t& K
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    ( F4 C% x$ @- B
  44. 1 E0 t$ F$ q: J) D8 x1 F6 O  T+ v% g
  45. bias tensor([0.0004]) tensor([0.0008])
    4 t6 O) ^5 H& c1 {1 L/ h

  46. $ a8 |; E! r$ A( E7 a" }& f7 J; _0 i
  47. tensor(0.0161)
    5 J. P- ?" L/ w& y% v

  48. * n3 Z6 N+ ?! l; r
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])& f( `3 |* w9 _8 k
  50.   r: \* P# h: V; `
  51. bias tensor([0.0003]) tensor([0.0003])0 b: ]1 v0 t6 K' a

  52. : O  J& Z8 S( t# T9 p2 M
  53. tensor(0.0103)) ?3 q/ d3 K5 Z: ?

  54. * k% l% v- Q, N3 l7 M( w
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])
    9 ]6 J" P" o, o+ z2 \

  56. 4 ]0 e1 ~6 u9 D
  57. bias tensor([0.0003]) tensor([0.0004])
    6 S0 N+ U3 @( P& \* @+ C% W7 ]* }

  58. 7 b1 |3 Q, I  @  M( n# \
  59. tensor(0.0066)& X( _2 C* R& l+ Y1 \* `

  60.   i: `1 d& m  S0 X8 ?! n3 K8 y
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]])# l! M8 L$ L3 z1 A% N: l0 J
  62. , b- h/ _3 p7 ]9 L# @7 c% ?
  63. bias tensor([0.0003]) tensor([0.0003])
    5 F  m1 i! _+ w* X
  64. 1 L& H  g4 {, o; J% W& Q  E
  65. tensor(0.0042)
    ( @9 k' v4 W7 ?2 Y7 F! d* S" ?
  66. * G! U, X& M, S# r4 [4 L
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])
    ) N9 }+ \- x6 r" u& N! H# D
  68. ; a, m$ L7 L5 D; L4 ~0 |
  69. bias tensor([0.0002]) tensor([0.0003])
    / S3 A0 |. G% w% c$ P# [* D, h

  70. ) X3 w0 O+ k, z. Q0 s! V
  71. tensor(0.0027)
    $ x! S  a7 Z$ V! Y6 U& H

  72. + d* p( X' \0 C1 v. ]  z! o/ d
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    ! i7 \! e4 a- i5 j  R/ [

  74. ' z5 ~2 N# a6 O9 K9 v+ R
  75. bias tensor([0.0002]) tensor([0.0002])4 p9 w- }1 b1 S
  76. ) U/ G- M" U4 ?
  77. tensor(0.0017)# F9 _8 H4 R5 Z1 F: R1 o3 P

  78. - D/ o2 d  A* c" n1 X
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    9 I' ]& H0 ~0 H. R0 Q* C1 X( w$ `+ C

  80. $ V7 p4 g: ?1 y+ k% G7 V* Z8 \- {) [
  81. bias tensor([0.0002]) tensor([0.0002]). B3 g4 G2 i  O" n9 E& v0 R- I
  82. 4 @6 z. V/ f4 _& p9 S
  83. tensor(0.0011)6 g0 R; p- m1 e/ p$ C# [+ A# B
  84.   o3 b8 ^7 g' J, V0 I2 g
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    3 f, v9 Q* O/ e, j
  86. ; o9 |6 i: {, x2 i4 B
  87. bias tensor([0.0001]) tensor([0.0002])) @7 V1 U; s) `5 l$ o3 f
  88. 8 u' d9 m4 @+ f: D& {* l
  89. tensor(0.0007)( C7 l' J, ^' i. b8 R2 m
  90. " h) C2 h" F2 l* u  e3 F7 l
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])
    % h9 K; y. R4 r/ q( F" |

  92. 4 {& ~5 t+ h/ i; T
  93. bias tensor([0.0001]) tensor([0.0002])  g4 ^1 `- ]2 w; b7 s; h* Y& J# g6 p

  94. 6 k7 c! Q% I9 [# j5 G# i
  95. tensor(0.0005)* ]3 C- O1 x; l% @  `$ T

  96. # p  `# }/ K. ?5 p7 E) F" E
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])) Z- u, }3 d3 E  m

  98. 3 j  S; ?% i4 U* D0 c6 V, s* g
  99. bias tensor([0.0001]) tensor([0.0001])
    ! f& ?5 ~9 c2 X. g* T
  100. - S& Q' o0 ]$ @7 Y* p+ U
  101. tensor(0.0003)
    7 D8 s/ r& U% p3 z
  102. ! v2 F' K# c( \& X/ a' I
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])+ S% \; h/ }% j& H2 W

  104. : M3 P* K( @- f0 W" r5 {/ K: [6 f9 T
  105. bias tensor([9.7973e-05]) tensor([0.0001])
    7 N' y5 }. n8 \: ^7 w7 M
  106. 9 t. O: k% ]) p; f" D
  107. tensor(0.0002)8 j0 J+ J/ n, c9 k. e$ S! G0 i

  108. ; |# [8 G4 t1 k! Z2 f; V$ h
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])% S4 p1 u# H7 u2 T) q6 X9 Z' t
  110. 8 t4 q, [& i9 x
  111. bias tensor([8.5674e-05]) tensor([0.0001])
    1 G8 [; Y+ U8 H9 T. d9 W# U

  112.   c5 `( ^9 E) N3 t. F
  113. tensor(0.0001)6 l" f: E' y& @7 w; u

  114. ) b) U' Q  `3 @6 A* P# z& X
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])5 G$ S9 p# U1 Z3 ^1 _& O6 ?

  116. % d9 {/ M3 G% p* @5 j9 J. S7 A& M# ^
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])# G6 h/ [. ]1 |' d, ?( O1 p

  118. 8 K/ F7 L# I' I) x
  119. tensor(7.6120e-05)
复制代码
1 n) \' d$ c- ?4 Y+ q





欢迎光临 数学建模社区-数学中国 (http://www.madio.net/) Powered by Discuz! X2.5