数学建模社区-数学中国
标题:
随机梯度下降算法SGD(Stochastic gradient descent)
[打印本页]
作者:
2744557306
时间:
2023-11-28 14:57
标题:
随机梯度下降算法SGD(Stochastic gradient descent)
SGD是什么
# g1 h2 D/ @9 y! ?6 X
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
7 y% L t# M" m' M6 r" Z: Y/ E
怎么理解梯度?
3 u# J% X# T- N& `
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
4 ~# A2 N! P( O% ?% I5 U: q
! D' _5 }" Q/ t7 }' G3 F1 U
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
$ {: M& U. {! M
" t' P& c' W/ z. \# f
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
& D# W7 L, e" B* h$ Z* r4 z
: d- I; h. M8 K, _" v
为什么引入SGD
% ?( Q* b" b( V- D* d
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
9 k) G, ?+ i/ e- y
L& u9 g6 p9 H' l, K9 S
怎么用SGD
import torch
E3 n! G8 p" D) c) |1 S
% v/ O$ K* J6 U7 c5 D8 k# i; J- R& I
from torch import nn
; F8 b$ p; L' k& U+ C5 J' p; ]
+ u3 b! r* c2 d8 U X
from torch import optim
; i9 A5 d7 O9 D% E5 M1 h- `
2 j2 o% B: r6 _9 h" w2 t& Q
9 ^. z- I1 Y' m9 L! R8 M
- H& N* n3 r( y5 Y1 {6 [1 H
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
' t5 q1 |5 B& |+ Z" c
* k4 z% _5 h" i z* s
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
3 m5 s7 z+ J- e! f# t
" W: u+ V8 G7 w! }! ]; K
4 f* K% R) N# u- U
( T! U7 p# S3 S2 t* _$ V
model = nn.Linear(2, 1)
% R; R- O& Z3 y7 H* T: q) s& V
+ }" _1 O. w( \
/ v9 c7 A- h- X) G) m. x. l- R
+ w: \2 }9 r3 R1 B/ Z
def train():
9 |2 ?( |8 ~% o' O) q
" h. D7 \: h' x6 W' o1 d
opt = optim.SGD(params=model.parameters(), lr=0.1)
# K0 r; o- n R# m% y; j
% g$ D& _. ~4 m6 G3 j
for iter in range(20):
$ d7 K t; ?; J7 E
# H- L( @* K6 z; Y2 F
# 1) 消除之前的梯度(如果存在)
1 g, Q% ?! R( p0 N7 W4 u4 ]6 b
& j! T0 Q3 B/ u7 C
opt.zero_grad()
( ~- b% d: M a( h
+ ?2 k) w& l8 h0 x( q0 Y
+ k8 ]2 a' o4 X9 A' I
1 S# d% i, k: j/ h* L" B3 S
# 2) 预测
# B) M a: N# c/ m. m
6 _+ p @5 @7 K+ Q
pred = model(data)
( m2 _. D6 A0 m0 C2 s
/ y a& R4 X$ j
' T5 K3 E2 O! a/ x! \8 r
7 t, ]2 u+ p0 F9 L0 q1 l+ L
# 3) 计算损失
, X6 o+ A7 q, w) |
- A3 M0 x) T$ `; T( e
loss = ((pred - target)**2).sum()
. G0 f- I$ P$ ~! `1 ]1 a
x8 p/ v3 v1 x( Q
; v* d, @, g5 u- q( ]
! Y3 }: Z& U- f; a+ }" }
# 4) 指出那些导致损失的参数(损失回传)
$ ~/ C& r6 ]/ f2 h! }9 X5 G
" y$ W4 @, A4 l R+ t% Z, E$ l8 u
loss.backward()
3 Z3 V; \% v% i! F) q% \
7 H5 F+ x Q( h: L
for name, param in model.named_parameters():
`: L* T- z, q3 W) x! T
, p) y; J# x! o5 e# E
print(name, param.data, param.grad)
7 B. ?6 p0 Y4 D( @/ X7 o
/ l' ]# n- j4 ~3 o9 [) u
# 5) 更新参数
; N) O- W' Z- n9 a9 Y, i- R) P
# n0 W- Z) j7 M( r
opt.step()
! t9 ]7 C3 T& @1 P5 E( X
1 z/ d! n; b8 I% j; B
# E* x; [3 }# ]
+ y! G$ ^2 F0 H# p
# 6) 打印进程
6 e' S' r9 r/ A; b! \3 L
; |6 d: z# I- \- Z/ K/ B
print(loss.data)
' T0 `. A* j5 n( o; r& R: A
2 m$ N j$ l4 Z/ {4 S+ N
) a( M! K, W/ C" C2 F& M
4 c9 P3 m0 `/ ]3 _+ y( V
if __name__ == "__main__":
! o& U8 p E# n1 R& [
* E& @7 S7 M* u4 j
train()
. X1 t$ n6 b+ i. x* h2 W# R" k, d
! g6 \4 g$ a6 j, e& B8 r& \, U7 k% [4 w
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
4 B5 T8 b1 F3 x
& A& V* s! e5 v0 @0 e- u/ N" Z
计算结果:
weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
4 k1 z) H; {* ~2 v( k. X9 E
* O9 o/ `* n5 ?* q
bias tensor([-0.2108]) tensor([-2.6971])
0 |; j3 r% X9 J5 @
! [/ y- S( R3 p) h
tensor(0.8531)
: s+ Y/ `9 B# O2 H$ @3 ^- N
; b' @' V! [! j# w( E
weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466, 1.1232]])
* i. c, L$ q. d+ X8 i: B
& i6 d* d, V; j3 A7 r
bias tensor([0.0589]) tensor([0.7416])
+ S a9 Y" n* R+ c
8 r. L1 R+ Y5 K. U$ \3 }, @4 j/ x
tensor(0.2712)
# n" X. j1 H+ _
! v z" y. x. y
weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692, 0.4266]])
3 g) S9 E& F( L0 j" Z
1 A1 k- G7 n! n2 P1 c4 C
bias tensor([-0.0152]) tensor([-0.2023])
+ @, K7 \' F9 O; q+ ?
" C9 l, q8 X8 z/ Y/ ~5 L
tensor(0.1529)
+ ?3 J3 L# J* S }& N
8 ]4 L; g& P6 e- Q
weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059, 0.4707]])
) d* U( W6 G, m Z+ y
' E& L: z2 Y. ^! A3 x2 ~
bias tensor([0.0050]) tensor([0.0566])
" u) {, e" x% v7 E& y) `
, d+ q# ?$ `( M
tensor(0.0963)
$ X) x @ D- [& r/ H1 }; k) _. o
, c" \! a) ?- {. g- f
weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603, 0.3410]])
' o4 E% v& X/ U. Z1 P
: G- s7 A- ]& S% c9 B$ U
bias tensor([-0.0006]) tensor([-0.0146])
8 e' g' I: l6 t
. u1 X. r0 n& T, m' _* y- r1 o7 I
tensor(0.0615)
2 }1 R$ Z0 ~: R6 F- Z+ l+ f
1 R* r/ P% _7 w/ p
weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786, 0.2825]])
% w! n% b$ P) q- O4 `1 [0 k0 @: m G9 G
- W+ R* l* G9 J& c
bias tensor([0.0008]) tensor([0.0048])
{& G& F) O$ ?
. u5 o" m' Y# l( ?. r! V/ x
tensor(0.0394)
2 N7 }& X& Y9 \" _
5 n a+ Z* W! V, w0 x v: a% u
weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256, 0.2233]])
x% p b$ s, m4 y* [( f1 m
* y/ i, K/ Z7 G
bias tensor([0.0003]) tensor([-0.0006])
4 G; t1 l/ W1 `6 O. }
2 `% z( y6 X4 i& v+ t# v
tensor(0.0252)
/ {: G+ l( c+ ]( s7 h
8 O# t6 `% H6 C$ g9 E
weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797, 0.1793]])
2 H. {0 f' z) h4 u" r
4 g/ ?4 K, _( z* [0 f6 A% B
bias tensor([0.0004]) tensor([0.0008])
8 o7 N+ R( N5 }
( o& m/ k. x: j9 b% S8 Z% w
tensor(0.0161)
: g, Z0 @* X: `, K9 j
/ }4 w4 y! g/ e' T, X
weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440, 0.1432]])
% x; C3 r4 T) x; G) \
( a: x z2 N; S/ U! m" ]) b& S
bias tensor([0.0003]) tensor([0.0003])
( l* t L: w, j$ ~: c/ Q/ Q
5 g: P1 _0 y! L% b( d& m H. U# {
tensor(0.0103)
$ V5 V3 }* e r6 r# ]% H6 N
; t3 D3 S) N/ r. T
weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152, 0.1146]])
: Z% h. L, c- [" j. c
( I# V' p* k5 J# z' \
bias tensor([0.0003]) tensor([0.0004])
; ?- L9 L+ A' o; R+ c- N
9 J5 Q: o% h: M9 n) A3 K
tensor(0.0066)
/ [5 Z" X9 i3 Y& f5 ?0 O
. N+ p3 S3 W! z9 @' c. M
weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922, 0.0917]])
0 l: [9 @" t/ I/ r# d/ q* i% o
1 `" c- A- [' t4 L
bias tensor([0.0003]) tensor([0.0003])
; N5 p% P0 [$ o
- E g' e% }1 ~0 V* G+ K2 H
tensor(0.0042)
3 o9 P# H a* B: F
1 L. h! c2 `- e% Q! H* ~) C
weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738, 0.0733]])
3 A+ W+ G: _. o+ |
9 w) r4 z1 N) c: i
bias tensor([0.0002]) tensor([0.0003])
' J; a& I% q% x: v1 z
8 C5 u% s, j6 y4 o# V: S
tensor(0.0027)
8 d8 J! f# J x
5 e( M! h# x. t Z5 ?" g+ O9 t
weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590, 0.0586]])
# i9 O& o3 g" k, [4 Q$ O4 V( D5 [
" S1 s. f" |: q+ @. e& i+ E" x
bias tensor([0.0002]) tensor([0.0002])
$ p& v a' y# _. }; d8 o
3 V3 x3 N) Q7 q' _; e
tensor(0.0017)
" ~1 T; F4 H5 O7 W; f
6 [0 @' H: e7 |" R) j0 M
weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472, 0.0469]])
, ?' ]! ^: I M: J" ?4 p- X
+ K8 ]! m. t: t" X7 m
bias tensor([0.0002]) tensor([0.0002])
0 S6 E) z, U4 r1 r
1 X) s3 i. L3 Q3 @# T$ s0 C
tensor(0.0011)
% r& U% |- a3 P8 L3 G% i
) p' R' I6 A7 r9 v6 g
weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378, 0.0375]])
) k+ L6 @# ?, U G l
`; D" w9 s8 P% U3 u
bias tensor([0.0001]) tensor([0.0002])
/ i. D. r9 W$ n; K% y+ m) n
% ^6 Z% g8 ~8 |' r" r& Z
tensor(0.0007)
1 _/ n, i8 U/ B/ l
* k7 S/ g' W/ V% h/ T P$ r0 \
weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303, 0.0300]])
' V n; p& ]8 E* A
+ }( [& N# K* v: J
bias tensor([0.0001]) tensor([0.0002])
% y. q4 N/ D1 h3 g( B/ V$ ?" W
3 r) {% @8 E/ T: ?9 [9 V) A8 X
tensor(0.0005)
! o4 L+ r8 `; v, \) S! g$ n+ G* j% S
% y1 _+ B) c2 z* R
weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242, 0.0240]])
5 I1 O, a9 z0 ^2 Q
: X5 g6 X8 _2 X U- }9 ^
bias tensor([0.0001]) tensor([0.0001])
! B# v0 K6 }3 G1 R
# ?; O" Z. ]$ z- W
tensor(0.0003)
1 H q( x5 E" z5 P; l
5 w0 @9 d+ \2 r- _
weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194, 0.0192]])
/ f: i* p0 k, [- A
, [* p" b3 w5 |
bias tensor([9.7973e-05]) tensor([0.0001])
4 w) P" Q8 z2 J
/ |" D' b I: P0 o$ ?4 s" N
tensor(0.0002)
- @7 R0 o6 L' _5 t4 S3 \& g. I) n
$ v% b# [0 ]0 `" V N5 d) V
weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155, 0.0153]])
/ H' b8 z0 y4 R' G4 b$ W
: i/ {! Z" L, X6 A" ^1 U) T
bias tensor([8.5674e-05]) tensor([0.0001])
: h2 T% K" {1 s+ u) X
# V9 L! Z) J5 ~8 {4 Y/ y
tensor(0.0001)
$ g) f9 `2 w& g P' ]: x
+ ~. j2 ~2 C$ k
weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124, 0.0123]])
2 K+ Y3 s1 N: t l
}6 o# b/ j: l% k Y
bias tensor([7.4933e-05]) tensor([9.4233e-05])
. G2 L" Z, |3 L
- }. a- j0 g+ h1 C8 [
tensor(7.6120e-05)
复制代码
" j- ^0 ?, E: b/ [2 c
欢迎光临 数学建模社区-数学中国 (http://www.madio.net/)
Powered by Discuz! X2.5