QQ登录

只需要一步,快速开始

 注册地址  找回密码
查看: 1698|回复: 0
打印 上一主题 下一主题

随机梯度下降算法SGD(Stochastic gradient descent)

[复制链接]
字体大小: 正常 放大

1171

主题

4

听众

2749

积分

该用户从未签到

跳转到指定楼层
1#
发表于 2023-11-28 14:57 |只看该作者 |倒序浏览
|招呼Ta 关注Ta
SGD是什么) c" |! a; E# w+ K& ~% b
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。, r% S, _2 o% i' v- Y7 G& I2 j
怎么理解梯度?
8 k9 N1 O; s1 ~1 n% y( H  b( x# `: }假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
1 V' F5 m$ f7 A4 R2 V8 H# A! H( c+ A7 \9 N
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
& D7 _( C( ?, ~) t+ k/ r+ y7 v$ ?, [! w( h1 c7 W7 U
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。. ~, Y0 T- b8 ^; W3 g
9 @/ a$ s* `* p: K9 i
为什么引入SGD1 Y* D; y3 m9 w- `9 a
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
' n& T) p  v7 l" G; f: l- m7 T' z  Z3 d2 P& Z8 W8 N  S
怎么用SGD
  1. import torch
    \" I2 `6 Z2 l4 w\" p

  2. - P4 p+ e! B7 `
  3. from torch import nn
    % d! L' L8 M% z+ ~8 Z) F+ r' j
  4. 4 L* `% ~# r; t* v) ]
  5. from torch import optim& I7 K4 ^5 g4 e$ `$ q8 t: M
  6. & W  w3 @. l: M+ T( J7 Z

  7. $ r! f. ~) P\" Y  G( k
  8. . o# u' |% b  {4 T4 I- W1 @
  9. data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
    / V6 V7 K$ p6 f1 s7 k' M

  10. ; A5 U, e( q( u$ A0 Z: l
  11. target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
    7 h+ O* O8 b- G
  12. - ^2 @+ s# J% K
  13. . A* Y( g+ a& d3 C+ i, I/ b
  14. / p, k# m8 u. p) }
  15. model = nn.Linear(2, 1)4 a! e/ G4 x5 o( V) `

  16. ) {* ~! ^; ~\" D3 }1 o
  17. ; }& V& s# s# L: ^7 o; i. u
  18. . X; Y1 e% [0 u) G0 _
  19. def train():
    / R6 z4 d- ~! ^0 w

  20. ; S) u; I4 m4 L1 h1 F) L) a
  21.     opt = optim.SGD(params=model.parameters(), lr=0.1)
    # W+ ~2 ?+ y8 J# m# E0 ~: i

  22. + a) ]* w9 Y% m( b# K( _! r
  23.     for iter in range(20):' C$ I$ q/ _# p$ v& Z1 w* b( W: n
  24. - C  l' I. r8 e& M8 m+ C+ C
  25.         # 1) 消除之前的梯度(如果存在)/ |2 D3 s, o) h/ G

  26. 6 U1 t) t1 C\" n3 v' k7 z
  27.         opt.zero_grad()5 D: f1 E6 M6 L; b5 h. [' |* i
  28. + Q. Y5 ^# i1 `; M0 _& J, l
  29. : L  c+ G2 L' k3 l6 r' \

  30. 5 o6 d3 k$ H  h% N\" R0 d. w
  31.         # 2) 预测
    \" x7 i3 P1 ~+ l1 m- }- \
  32. ( Q. P; D* H3 C0 w\" [  c
  33.         pred = model(data)
    + s0 g6 g% K' _3 B
  34. . L7 X9 ^7 X* x* X) R) Y6 ^9 p
  35. 0 G- F' ]2 h' C; Z5 `% v/ a
  36. # h( r) M# K6 {! k6 l2 _- ]& p# `! Q
  37.         # 3) 计算损失. g; Y% j9 t7 }/ ~9 _0 i) A\" D

  38. 8 s: v3 x) `  T1 t$ w+ C
  39.         loss = ((pred - target)**2).sum()- U0 h) W/ m5 m, q7 @- V1 j9 n  `5 m

  40. # z/ q% K2 i. |

  41. 4 {# D. J7 [3 B! ^! I& i2 T  h2 W: ?7 b
  42. % b, }5 ~$ G; s
  43.         # 4) 指出那些导致损失的参数(损失回传)
    7 M0 u\" N# J! g2 }9 d/ f

  44. ( Z9 _  w0 l1 U7 |/ K
  45.         loss.backward()( }3 X  K; X' X( p4 J

  46. / X% T4 \0 V: [& A/ |
  47.     for name, param in model.named_parameters():% Q% D( C/ U! S
  48. $ `9 Y: T\" Q0 R' v  B( e* \
  49.             print(name, param.data, param.grad)
    # |. j) _8 Z, i( j. {7 x\" {

  50. 5 [$ d' |  x  r\" L' J( d8 v
  51.         # 5) 更新参数
    * ~- q: _- c: f4 a2 u
  52. 0 {% s  m  i/ z
  53.         opt.step(), s- E' Y3 Z* n1 ]1 l( [. E$ F

  54. : p\" D' b4 f/ o5 o5 E( f' L
  55. # [7 u& X/ T) s0 D$ O

  56. # U* z! K8 z# u, @
  57.         # 6) 打印进程
    \" z# q3 u  E# s( t! a) {- W; `& `

  58. & C8 ^5 S7 K4 {9 J
  59.         print(loss.data)2 ^* _- T9 Y* x, o& l+ s) y
  60. 9 Y  C\" K! r* L# r5 V5 m

  61. ( i+ S0 ?. E  s% X6 E3 M
  62. 1 E' t( g  N' }
  63. if __name__ == "__main__":
    2 |' m* N! Q) v7 y

  64. 8 X; s% ?# s' u5 d+ b6 ?6 ~
  65.     train()
    3 p! o; G3 H2 A* e' m. i; ]. x  Z$ S
  66. + R* a& N6 P, D$ `9 W2 X7 T5 n
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
" W( S) D6 g; ~8 j8 b! Z% O- f
2 n# a0 e& \' u3 X! z2 X8 C; ~计算结果:
  1. weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
    # a# G5 A% S. \. J

  2. , |5 J5 Q$ u8 v; G; x: p
  3. bias tensor([-0.2108]) tensor([-2.6971])
    / j- e2 E9 o- g# e

  4. 1 v+ ]! v+ z6 E& m5 b
  5. tensor(0.8531)
    ) q+ `% M3 P9 s
  6. 2 P& b8 V0 c, W: l% o
  7. weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466,  1.1232]])
    & W1 Y) f) k8 s0 P9 p) d1 t
  8. 8 v, F6 l8 H3 k\" W# o
  9. bias tensor([0.0589]) tensor([0.7416])
    9 O* q8 j4 e. A2 j' @* I. b/ `
  10. 6 X$ M; L' s( S- H7 v
  11. tensor(0.2712), j  v% Z\" p: n7 Q) X
  12. 3 Z  d9 b8 M; n
  13. weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692,  0.4266]])
    . X1 Q1 \* m; D' ]- j; H2 R

  14.   V  X' G$ P5 n; K+ g
  15. bias tensor([-0.0152]) tensor([-0.2023]); G8 `1 O; z. H
  16. ) K0 v. @  M$ B
  17. tensor(0.1529)
      L2 U% D0 Y* P0 b4 V4 F% R$ V

  18. 0 \1 R  Q& X4 W% L8 }
  19. weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059,  0.4707]])% H3 I3 U+ E# @- E

  20. 3 ]# y$ w+ V% l& p, Q! V/ k  o
  21. bias tensor([0.0050]) tensor([0.0566])
      N2 j* S  A3 A4 n8 H8 A

  22. 3 O- H1 `. K& L\" A- l3 A
  23. tensor(0.0963)7 Z; ~0 o7 R7 C3 U) C* i( u

  24. 5 u\" p0 J7 ^0 G  h$ w
  25. weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603,  0.3410]])  f* Q0 F$ `- ]( u
  26. / e+ ?8 o) M& ^  N5 Z
  27. bias tensor([-0.0006]) tensor([-0.0146])
    , \3 H! [) _1 w( `* y4 F( n

  28. + Z( e: V6 b2 ]$ i2 y
  29. tensor(0.0615)$ [0 o. _8 o1 n\" x5 H
  30. $ D( R; e- [' B8 l: u9 O
  31. weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786,  0.2825]])
    ) L; L- p6 v6 d, a4 B
  32. ; a/ O) X, i6 O- Q4 |! f% j$ O. N; `
  33. bias tensor([0.0008]) tensor([0.0048])
    , y, Q3 r' y+ j4 N
  34. # L: S6 O: b% \& J' X; H+ m# ~0 c
  35. tensor(0.0394)0 [8 u) m) ?+ u0 F1 t\" ?

  36. + t; }! _' [! f2 Y$ l! m
  37. weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256,  0.2233]])0 Z8 X4 U! L6 i6 Y

  38. : z( p7 Q# R3 v' `+ r3 C( Y5 ^
  39. bias tensor([0.0003]) tensor([-0.0006])3 V; N8 Y% i* I\" M, P6 V2 J

  40. 3 Q3 p! f( `0 y! C/ z! \
  41. tensor(0.0252)
    - p7 L7 t0 |  G: v9 w3 X4 H
  42. 1 E5 u) Q0 y( g% t) m0 Y! |) ~$ X/ ?
  43. weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797,  0.1793]])
    $ d$ c+ e4 R/ `7 T# G

  44. 7 g, c; n! c* {8 Y
  45. bias tensor([0.0004]) tensor([0.0008])/ K4 H7 t& A9 M4 Z; T/ D4 @
  46. / }( t6 r# h* {) Q* s
  47. tensor(0.0161)9 O$ a6 h2 h; k, s

  48. 1 Y( Z- T) M- f6 k2 P
  49. weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440,  0.1432]])6 Y9 o) W7 l' H% {
  50. & Q: O8 _1 r) W; |
  51. bias tensor([0.0003]) tensor([0.0003]); V# o% b; `8 B  u) a! b( v

  52. . P+ G; \- w5 n% d9 J% f. i
  53. tensor(0.0103)& {+ N; Z8 l& j  G) V5 r( C

  54. & A, T+ V' O4 _' V; V
  55. weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152,  0.1146]])! C, U; R9 u) S' \& f* g

  56. 3 ?! K% ~/ Z3 K7 o3 q- `6 ]; v2 z
  57. bias tensor([0.0003]) tensor([0.0004])
    & n+ S  d0 U- u+ w& l) [( W
  58. # ]9 |# A0 x! v7 M
  59. tensor(0.0066)
    ; K5 z6 Q* o# |! T
  60. \" R# S3 K1 V\" [, M# L$ f/ M
  61. weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922,  0.0917]]), X1 I9 l! {  H4 |1 |
  62. ) U* C( f6 T5 m8 O9 p
  63. bias tensor([0.0003]) tensor([0.0003])
    5 K: [0 o; ], Y8 X9 c5 R  ?
  64. 8 m- ^% ]6 ~6 Y4 a
  65. tensor(0.0042)
      K' s& g8 @( J9 r8 E

  66. 9 g4 ]% V) z& H, `; k2 r/ s' A
  67. weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738,  0.0733]])) r6 g6 w% P9 h; N7 v# J/ r

  68. 0 k; a0 _4 |9 S\" _) l% `/ ^6 c  A
  69. bias tensor([0.0002]) tensor([0.0003])* X  h9 S6 V# L
  70. & L( E3 k3 ]2 T9 q
  71. tensor(0.0027)& ~8 o1 i* n* X3 [2 r
  72. 3 i- a& A5 i( F, |
  73. weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590,  0.0586]])
    2 [. L- a. N0 _; K( ]  R: B6 f

  74. , t9 e9 Z, I( p/ C1 M
  75. bias tensor([0.0002]) tensor([0.0002])* O( e$ r+ j( Q# T5 h
  76. & q: b2 d) T% s' r% r
  77. tensor(0.0017)
    ' o! g$ X3 M: z

  78. $ S% l: X; ]& X8 ]  o
  79. weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472,  0.0469]])
    ' [8 C* B  Q9 L% R5 D& P
  80. % K$ A7 _7 z/ B8 Y, d2 l/ R. U
  81. bias tensor([0.0002]) tensor([0.0002])
    : j) \8 u. ?8 r9 J5 J( J  O  v' E
  82. 9 h# R# D( E! A( u  _, g( W8 I\" v- Z
  83. tensor(0.0011)
    / |9 n8 K) K0 _, c# F

  84. # n/ @0 H* d+ F6 u2 ?! U
  85. weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378,  0.0375]])
    \" s( ^- L' f9 g' [# z) b! J/ H

  86. 1 ]' @; q- D( z$ @
  87. bias tensor([0.0001]) tensor([0.0002])6 G/ w- ^# |7 R6 D( i

  88. ! B  K  ?  {1 x6 j: _5 f
  89. tensor(0.0007)
    4 e: z$ l\" {/ [# T
  90. $ D\" G/ E4 @9 r' o
  91. weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303,  0.0300]])  p# l- X+ L; v

  92. \" g* Y# ]& p' _' [8 r
  93. bias tensor([0.0001]) tensor([0.0002])  s- v. i) R# H  ?* A  N! [% k

  94. ' [\" n  a( S9 H
  95. tensor(0.0005)
    & }0 q6 }( e  }8 S; b

  96.   C3 F/ m4 Z' |* p! s
  97. weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242,  0.0240]])
    \" Y, K0 X: P* r( Y3 {; c1 r' o+ t2 c

  98. , ^! H5 Q: L0 f+ j/ o/ A# t6 K
  99. bias tensor([0.0001]) tensor([0.0001])0 p9 B\" b3 [) P9 j7 r& g
  100. ) D: F% _9 c. }2 p; e: @# V  Z4 N' k
  101. tensor(0.0003)
    - f( W6 X% J# B4 N( W: R  K
  102. 1 d1 ~  E* }1 g2 ?
  103. weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194,  0.0192]])8 G, u( M7 y5 u  V  Y* b

  104. ) \; R& e& e5 x! ?8 w: G
  105. bias tensor([9.7973e-05]) tensor([0.0001])% s! P\" l\" f1 M0 w

  106. : x; D: A) f) i% O* J, T
  107. tensor(0.0002)) S; B' ]+ v$ U+ K0 ~8 s

  108. / b$ n) n: e\" i
  109. weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155,  0.0153]])7 K/ d: _% _! b% S4 v: |- n3 O  F
  110. ' x9 Q' h* v) ^8 [. C( Z% J: P
  111. bias tensor([8.5674e-05]) tensor([0.0001]). [9 J1 l\" k! e  j! a9 `9 I7 T
  112. ! Z( i9 e  @1 U5 m
  113. tensor(0.0001)
    4 \6 a8 c7 J% {; {6 `

  114. , b5 [- E  {; t( J
  115. weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124,  0.0123]])
    $ @/ }* V' h8 a9 f6 l\" A2 S6 ?0 S

  116. 3 e. u+ z; u, N\" H0 R8 ]0 _, ?
  117. bias tensor([7.4933e-05]) tensor([9.4233e-05])  E* ]$ e- w/ a7 V+ O, t

  118. # Z2 \# S: s+ F+ p$ k+ z* l
  119. tensor(7.6120e-05)
复制代码

8 K+ t. ~0 P! |
zan
转播转播0 分享淘帖0 分享分享0 收藏收藏0 支持支持0 反对反对0 微信微信
您需要登录后才可以回帖 登录 | 注册地址

qq
收缩
  • 电话咨询

  • 04714969085
fastpost

关于我们| 联系我们| 诚征英才| 对外合作| 产品服务| QQ

手机版|Archiver| |繁體中文 手机客户端  

蒙公网安备 15010502000194号

Powered by Discuz! X2.5   © 2001-2013 数学建模网-数学中国 ( 蒙ICP备14002410号-3 蒙BBS备-0002号 )     论坛法律顾问:王兆丰

GMT+8, 2025-5-11 19:46 , Processed in 0.495187 second(s), 50 queries .

回顶部