数学建模社区-数学中国
标题:
随机梯度下降算法SGD(Stochastic gradient descent)
[打印本页]
作者:
2744557306
时间:
2023-11-28 14:57
标题:
随机梯度下降算法SGD(Stochastic gradient descent)
SGD是什么
N4 V0 y" \0 |& E3 c; _* T% `% d
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
$ f0 G' o- {/ C' u E( e; [2 I
怎么理解梯度?
0 m6 b' L8 I/ c& B* B
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
8 i- t% V+ W# S L6 P; o
/ l6 A6 X- B$ A1 ]" g
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
, w9 s3 C+ [- X( h; @7 m+ u
. s: _3 H" _% V" e. z0 g
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
4 @+ m0 D$ ~4 k0 o# ^8 ~/ T
8 p6 f" d5 f. S) ^( ^1 K$ N% I
为什么引入SGD
% w4 o; U. ~% F& u* L. [! c8 {
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
5 p0 S. |) o8 F+ W
6 B5 n3 E( [. D8 ~! D- {# Y+ W
怎么用SGD
import torch
! t( W+ f% _! l- `! H9 G" P
5 @# i p3 m3 D" ]" z6 h- J& V5 M% W
from torch import nn
% s0 ~+ ]0 g$ o: C) n
8 Z ~5 F& ]) [# q7 B; [6 \1 |
from torch import optim
. H; U5 V) z z/ ?
5 T+ z8 e3 N& a. A, @
5 x) a+ d# ~5 A7 {4 n( }; y3 n
8 K% J+ j7 o6 M$ |: q$ X/ X
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
& F% y& w* l) x8 Z
9 D' T+ [- }) u& v
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
, y( \, w4 _, R) ?8 H
/ M, E1 ?: W, E
, y2 N- ~! R7 \: G
" x: }0 z% X. r- E# W6 B' o
model = nn.Linear(2, 1)
; g) B: p; o. c
1 Y7 ~. T$ }; D
5 E& ?8 l3 J; P1 L8 M2 \
$ R) J3 b3 p5 C5 ^8 Z: j8 J( h
def train():
- @! D. z6 x1 @/ E# j0 ?: k n
* s T4 F$ a5 w# [' T# u
opt = optim.SGD(params=model.parameters(), lr=0.1)
, n, T& [/ ?) b# I: \: w) ~% _
) R# R- ^ Z' J5 F" Y! S$ ^# b `
for iter in range(20):
, f2 e5 d8 a* Y
, t" l% d0 W4 A# M2 p; B, i: j$ d f
# 1) 消除之前的梯度(如果存在)
5 Q5 |3 J: v! z3 I7 V
6 ^4 x5 F# I3 ~' A& ]
opt.zero_grad()
]$ ]3 A5 Q. U# J) |( I8 D2 {
! F7 ] I! U; [8 j8 i2 V/ ^, y
1 n+ U8 k+ y: R; [6 b+ U
6 e$ N1 D" @5 P# m
# 2) 预测
; K h* h) \( ~2 n# S/ b
% e& C6 d( v; n" k& A
pred = model(data)
5 _4 N4 Y1 S- v% Z! S
: i2 d# n, \% L: c% x/ E
5 M7 x0 I+ y/ H% J, i
9 o% B; x6 o( T% Y- ?' ~* m1 f
# 3) 计算损失
0 f1 L% ]! Q1 y* P: |/ D
- e" k1 Z8 I7 T6 |3 o4 k# G$ h! ^2 n
loss = ((pred - target)**2).sum()
# l8 K7 m$ p5 J( i% B: p
6 ]( j# Y f0 |% G
7 `) T- C, t1 o
% u7 O l% A/ }
# 4) 指出那些导致损失的参数(损失回传)
0 N# I7 U# G: c! L0 ?* n
# r; v9 k3 d2 p3 s' ?
loss.backward()
1 s b* w! W/ I: }- J
0 ~. O! u8 t& }; u3 i2 F) k
for name, param in model.named_parameters():
$ v$ h2 [% y* A) G/ s, _! E
7 n8 k, F2 I5 P, r+ @8 x! e6 x6 `
print(name, param.data, param.grad)
7 {4 B+ Q6 w2 N! N3 V4 q
& h/ {" y$ e# o1 X
# 5) 更新参数
! B8 c2 Q6 _& |# U* {8 O3 J7 T
. x$ I Z. R% I! J ]0 M+ S, h
opt.step()
' A; P/ h* P [$ h+ v1 u! h) `# u
- c9 U' p, G( ~( I' H! e
0 w! O; V/ G( J
- G# I$ L6 o1 A$ K7 w. f; s
# 6) 打印进程
/ l! p j% ~7 X. F
* N. E; `' N, q' G3 C3 B- S
print(loss.data)
/ A% o" `( C* t2 C' ?" L; W3 t
# x- e+ S8 O( k, \: ?" R& ?* j
5 s _) `) H: ^! L2 g: U8 V% P2 ?
8 {; m/ W/ {2 m1 U; }
if __name__ == "__main__":
) |; c$ s' Z. l* \
0 r/ l; q. O9 m' }; R4 {
train()
" h' I1 `6 p7 S& b; J
0 z- g) \ ^& d- [5 `
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
) O; f5 K( A8 y/ \/ y1 G5 P
' k+ Z( k" s& b" e# n
计算结果:
weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
: m; V) @9 N& e( G* H
8 |0 J% v% S( P2 q
bias tensor([-0.2108]) tensor([-2.6971])
# k; e* {9 L, m0 d. [' d6 F. \1 A& |
# S* t; ~9 M) I! m* K* E2 r
tensor(0.8531)
) o( K$ t% _$ i5 }
4 {' P! |3 J" e3 y
weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466, 1.1232]])
9 h7 R5 f: h' q$ I( O2 v1 w
2 l& U6 L' ]# b
bias tensor([0.0589]) tensor([0.7416])
?' F) K8 O+ a& `- a9 y5 i) ?5 V) C! C p
4 z0 Y5 _ O/ E" c
tensor(0.2712)
! R: Y! G# `: S6 _( p
% S9 Z) e% f. r% K% h$ B
weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692, 0.4266]])
- }# ^& C. G' {3 H
7 i* B }; }( G: ]) D+ F
bias tensor([-0.0152]) tensor([-0.2023])
; T1 ?: S' U/ O/ o9 Z% C/ O
, v+ w2 F2 c: f J0 \# f# j
tensor(0.1529)
9 x( Z% T( i1 O9 |* t1 n
, E6 s) Y3 e' v3 S# ^( E# w# p
weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059, 0.4707]])
/ ]/ }; O8 H. l7 x; P2 L# }
% j- [( r. Q: N" ~/ z& o, H, Z: Z
bias tensor([0.0050]) tensor([0.0566])
% f* J4 z2 @5 ]; [) o7 N
7 G; t4 B2 B( v" b; p8 A
tensor(0.0963)
! V% X( ^9 K( y% B
p( a- u, }0 G1 b
weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603, 0.3410]])
; k: g9 z9 m" g6 E4 B! A
, B# |2 Q4 @. I
bias tensor([-0.0006]) tensor([-0.0146])
, G z' I2 Z# Y1 y9 T
% f& T3 w5 [$ `1 P3 z1 a
tensor(0.0615)
/ j. @" J3 a' A
' Y, _4 D& p" [- L
weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786, 0.2825]])
# m. n3 Q' ~1 j- \$ e
& d$ x5 l2 G2 W( J8 c9 Z0 W" r. ~
bias tensor([0.0008]) tensor([0.0048])
! r" O% q" S) J8 Y# X; o3 k: I- l
) b$ J4 M, H3 A8 z1 R
tensor(0.0394)
8 t( Y. d# M G6 G) t$ Z$ f
3 ]) g2 C; B& k1 x+ O" d
weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256, 0.2233]])
- G% J( c) e2 ~1 b
; l/ L- ]# w2 [4 k8 e$ X4 d2 A5 q& x
bias tensor([0.0003]) tensor([-0.0006])
- I4 J \2 K# B2 q3 _/ s+ N3 {
) \ b, v. n1 h' ^+ b, u
tensor(0.0252)
0 y' V- B$ ?/ f- B- H4 o% c
. U" \9 D( q9 t& K
weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797, 0.1793]])
( F4 C% x$ @- B
1 E0 t$ F$ q: J) D8 x1 F6 O T+ v% g
bias tensor([0.0004]) tensor([0.0008])
4 t6 O) ^5 H& c1 {1 L/ h
$ a8 |; E! r$ A( E7 a" }& f7 J; _0 i
tensor(0.0161)
5 J. P- ?" L/ w& y% v
* n3 Z6 N+ ?! l; r
weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440, 0.1432]])
& f( `3 |* w9 _8 k
r: \* P# h: V; `
bias tensor([0.0003]) tensor([0.0003])
0 b: ]1 v0 t6 K' a
: O J& Z8 S( t# T9 p2 M
tensor(0.0103)
) ?3 q/ d3 K5 Z: ?
* k% l% v- Q, N3 l7 M( w
weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152, 0.1146]])
9 ]6 J" P" o, o+ z2 \
4 ]0 e1 ~6 u9 D
bias tensor([0.0003]) tensor([0.0004])
6 S0 N+ U3 @( P& \* @+ C% W7 ]* }
7 b1 |3 Q, I @ M( n# \
tensor(0.0066)
& X( _2 C* R& l+ Y1 \* `
i: `1 d& m S0 X8 ?! n3 K8 y
weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922, 0.0917]])
# l! M8 L$ L3 z1 A% N: l0 J
, b- h/ _3 p7 ]9 L# @7 c% ?
bias tensor([0.0003]) tensor([0.0003])
5 F m1 i! _+ w* X
1 L& H g4 {, o; J% W& Q E
tensor(0.0042)
( @9 k' v4 W7 ?2 Y7 F! d* S" ?
* G! U, X& M, S# r4 [4 L
weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738, 0.0733]])
) N9 }+ \- x6 r" u& N! H# D
; a, m$ L7 L5 D; L4 ~0 |
bias tensor([0.0002]) tensor([0.0003])
/ S3 A0 |. G% w% c$ P# [* D, h
) X3 w0 O+ k, z. Q0 s! V
tensor(0.0027)
$ x! S a7 Z$ V! Y6 U& H
+ d* p( X' \0 C1 v. ] z! o/ d
weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590, 0.0586]])
! i7 \! e4 a- i5 j R/ [
' z5 ~2 N# a6 O9 K9 v+ R
bias tensor([0.0002]) tensor([0.0002])
4 p9 w- }1 b1 S
) U/ G- M" U4 ?
tensor(0.0017)
# F9 _8 H4 R5 Z1 F: R1 o3 P
- D/ o2 d A* c" n1 X
weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472, 0.0469]])
9 I' ]& H0 ~0 H. R0 Q* C1 X( w$ `+ C
$ V7 p4 g: ?1 y+ k% G7 V* Z8 \- {) [
bias tensor([0.0002]) tensor([0.0002])
. B3 g4 G2 i O" n9 E& v0 R- I
4 @6 z. V/ f4 _& p9 S
tensor(0.0011)
6 g0 R; p- m1 e/ p$ C# [+ A# B
o3 b8 ^7 g' J, V0 I2 g
weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378, 0.0375]])
3 f, v9 Q* O/ e, j
; o9 |6 i: {, x2 i4 B
bias tensor([0.0001]) tensor([0.0002])
) @7 V1 U; s) `5 l$ o3 f
8 u' d9 m4 @+ f: D& {* l
tensor(0.0007)
( C7 l' J, ^' i. b8 R2 m
" h) C2 h" F2 l* u e3 F7 l
weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303, 0.0300]])
% h9 K; y. R4 r/ q( F" |
4 {& ~5 t+ h/ i; T
bias tensor([0.0001]) tensor([0.0002])
g4 ^1 `- ]2 w; b7 s; h* Y& J# g6 p
6 k7 c! Q% I9 [# j5 G# i
tensor(0.0005)
* ]3 C- O1 x; l% @ `$ T
# p `# }/ K. ?5 p7 E) F" E
weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242, 0.0240]])
) Z- u, }3 d3 E m
3 j S; ?% i4 U* D0 c6 V, s* g
bias tensor([0.0001]) tensor([0.0001])
! f& ?5 ~9 c2 X. g* T
- S& Q' o0 ]$ @7 Y* p+ U
tensor(0.0003)
7 D8 s/ r& U% p3 z
! v2 F' K# c( \& X/ a' I
weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194, 0.0192]])
+ S% \; h/ }% j& H2 W
: M3 P* K( @- f0 W" r5 {/ K: [6 f9 T
bias tensor([9.7973e-05]) tensor([0.0001])
7 N' y5 }. n8 \: ^7 w7 M
9 t. O: k% ]) p; f" D
tensor(0.0002)
8 j0 J+ J/ n, c9 k. e$ S! G0 i
; |# [8 G4 t1 k! Z2 f; V$ h
weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155, 0.0153]])
% S4 p1 u# H7 u2 T) q6 X9 Z' t
8 t4 q, [& i9 x
bias tensor([8.5674e-05]) tensor([0.0001])
1 G8 [; Y+ U8 H9 T. d9 W# U
c5 `( ^9 E) N3 t. F
tensor(0.0001)
6 l" f: E' y& @7 w; u
) b) U' Q `3 @6 A* P# z& X
weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124, 0.0123]])
5 G$ S9 p# U1 Z3 ^1 _& O6 ?
% d9 {/ M3 G% p* @5 j9 J. S7 A& M# ^
bias tensor([7.4933e-05]) tensor([9.4233e-05])
# G6 h/ [. ]1 |' d, ?( O1 p
8 K/ F7 L# I' I) x
tensor(7.6120e-05)
复制代码
1 n) \' d$ c- ?4 Y+ q
欢迎光临 数学建模社区-数学中国 (http://www.madio.net/)
Powered by Discuz! X2.5