数学建模社区-数学中国
标题:
随机梯度下降算法SGD(Stochastic gradient descent)
[打印本页]
作者:
2744557306
时间:
2023-11-28 14:57
标题:
随机梯度下降算法SGD(Stochastic gradient descent)
SGD是什么
/ p/ `- g8 z- c) k- N
SGD是Stochastic Gradient Descent(随机梯度下降)的缩写,是深度学习中常用的优化算法之一。SGD是一种基于梯度的优化算法,用于更新深度神经网络的参数。它的基本思想是,在每一次迭代中,随机选择一个小批量的样本来计算损失函数的梯度,并用梯度来更新参数。这种随机性使得算法更具鲁棒性,能够避免陷入局部极小值,并且训练速度也会更快。
4 o# T5 G9 D4 e' \- B+ z
怎么理解梯度?
% |% T2 F# E$ R: D
假设你在爬一座山,山顶是你的目标。你知道自己的位置和海拔高度,但是不知道山顶的具体位置和高度。你可以通过观察周围的地形来判断自己应该往哪个方向前进,并且你可以根据海拔高度的变化来判断自己是否接近山顶。
8 ], ?$ U: N: u3 i2 @6 J
, u( N. L; R0 r: E& ~
在这个例子中,你就可以把自己看作是一个模型,而目标就是最小化海拔高度(损失函数)。你可以根据周围的地形(梯度)来判断自己应该往哪个方向前进,这就相当于使用梯度下降法来更新模型的参数(你的位置和海拔高度)。
4 R9 ?6 M# [9 @" _% ^ i5 i d
' h+ Y4 a3 c0 w# f h
每次你前进一步,就相当于模型更新一次参数,然后重新计算海拔高度。如果你发现海拔高度变小了,就说明你走对了方向,可以继续往这个方向前进;如果海拔高度变大了,就说明你走错了方向,需要回到上一个位置重新计算梯度并选择一个新的方向前进。通过不断重复这个过程,最终你会到达山顶,也就是找到了最小化损失函数的参数。
5 q# E3 \$ @- V
+ B6 x/ ]( h( j4 I; x- w5 q
为什么引入SGD
3 o0 [- a7 X8 s {
深度神经网络通常有大量的参数需要学习,因此优化算法的效率和精度非常重要。传统的梯度下降算法需要计算全部样本的梯度,非常耗时,并且容易受到噪声的影响。随机梯度下降算法则可以使用一小部分样本来计算梯度,从而大大提高了训练速度和鲁棒性。此外,SGD还可以避免陷入局部极小值,使得训练结果更加准确。
4 @7 ~7 ^# P& F' c
3 f8 o7 g9 |$ M' L$ D# F
怎么用SGD
import torch
: B. ?7 e6 H4 }; [9 X& l& _4 n
. b0 Y9 z5 \1 v# a
from torch import nn
7 u9 }( @& L" [3 l+ i; v# W
8 D: h% W, [* l' }
from torch import optim
/ C' e* {9 O2 W) d
$ r2 |& `+ u1 J" U* s0 _* K3 Q
7 T" R1 l" ~8 \8 t
0 k2 s$ B' t v" q
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
* Y4 L$ L$ y4 s& x
. R9 }2 V. G5 p2 o3 b7 q8 w
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
7 |7 b7 _0 L7 R: m5 y
1 Y9 v: {! w- h( K
/ T) ~0 c9 v4 D6 I0 E
( n' z8 Q* {+ P# `) E" q4 _
model = nn.Linear(2, 1)
* o' [( H7 a( ~% K O
: G5 Z, c/ x9 @' F5 ]
- \6 s( k( R$ D$ N, ?6 ?/ }
7 X6 l4 Z- @+ N( W; y5 u# B
def train():
: G L" t( L9 l/ h+ p
6 t! S- w* k) D- i3 E
opt = optim.SGD(params=model.parameters(), lr=0.1)
0 t; Q6 V9 C/ G" `" K' p9 B/ h
3 l: j# l* S; S+ Z* u; T, L
for iter in range(20):
) } V2 w2 Y# f, e' C4 P" p
8 W" A9 |) J3 k: W* X" [1 c
# 1) 消除之前的梯度(如果存在)
: y3 C; X6 o% M! c! A0 e6 x$ ^3 x
( {, d' e/ F t$ I* O! @
opt.zero_grad()
( i& ^. O# @- f& o6 M5 g
! Q! B, o- ], c
" [5 Y2 X7 [4 z* X7 F7 m
# ~+ L# h. d: c3 a4 p2 t
# 2) 预测
: y. ^5 v o7 r7 T9 u2 V
! ]6 ?2 p* s. w& j: p$ A9 P
pred = model(data)
, a0 f% R( R8 ^- C
# Q0 f8 V# }) {, d d8 [
( I. [& {2 s2 w; i$ c9 C* [
' C- ~( a' c" i: J3 W8 g2 {
# 3) 计算损失
0 v- c8 L+ V2 I
! O P+ k; T3 c5 E# @( V8 S
loss = ((pred - target)**2).sum()
/ ]7 `' |7 B6 o* t1 u3 M8 y
- b+ N6 a0 l/ K# b- T9 B8 E
: s7 ?/ Z3 @8 |9 ?
( @/ f- {3 T% J( Y9 f! d, ?. H# p. K# A( Y
# 4) 指出那些导致损失的参数(损失回传)
. x: n. K, o2 B5 k: m6 e
% A: Y* b4 x4 F
loss.backward()
1 t1 ^4 Y* {2 ]1 ?5 V6 o1 X# U# [" t
7 C# ~$ F7 D0 }. x* i, a
for name, param in model.named_parameters():
- X8 d; P# C- O Z+ z! \* t% s+ ~7 k
2 D+ y( h( i8 ]/ D' `% S
print(name, param.data, param.grad)
: b7 s4 C% ^" [! d' d/ N! H
U+ g3 |( _$ E+ ?3 n1 B% h2 A1 a3 }
# 5) 更新参数
' r# [; I4 I% P, C. ]0 x7 o
! l7 a) c$ r2 l3 t. f
opt.step()
/ r2 p a4 i0 N, \
: b s0 [: j6 L; Z6 Y4 F7 G
- {, _$ { _3 T& [
" U+ x2 F) b1 f t
# 6) 打印进程
8 N2 `, ?4 K4 V- ^
3 J9 Q$ W. `2 t1 h
print(loss.data)
5 {6 u7 ]8 i7 F8 Q y8 j; W) Q
, h/ \) @& u" n4 g/ J/ {7 [1 [! \
. l' P4 u; g9 i: Z, E- [
/ @# j% P5 t& l! X! y9 u0 a! C1 Q
if __name__ == "__main__":
2 S* Z% @7 H, k! _( A" L
# l' Q7 `' e$ i( P" b
train()
3 T) K3 W& M, R) B. A% Z5 Q
6 w" o0 I+ d3 M1 f& m) w" W
复制代码
param.data是参数的当前值,而param.grad是参数的梯度值。在进行反向传播计算时,每个参数都会被记录其梯度信息,以便在更新参数时使用。通过访问param.data和param.grad,可以查看参数当前的值和梯度信息。值得注意的是,param.grad在每次调用backward()后都会自动清空,因此如果需要保存梯度信息,应该在计算完梯度之后及时将其提取并保存到其他地方。
: X* `: p2 Y: ~7 H2 o
! N5 l1 r' `$ C s8 {
计算结果:
weight tensor([[0.4456, 0.3017]]) tensor([[-2.4574, -0.7452]])
; v% L/ R8 z8 z6 W0 F \0 V2 |" f
$ o# F/ Y& s' T o5 j! Y2 b' {8 Y
bias tensor([-0.2108]) tensor([-2.6971])
* k( f L8 s9 I* S/ E
( C. z1 a/ B0 _9 J" J2 o" G
tensor(0.8531)
" m! e# n$ H( |
) A0 j J" H: G0 b2 `, Y
weight tensor([[0.6913, 0.3762]]) tensor([[-0.2466, 1.1232]])
# v- m% f- [# A7 s5 X* N2 h
' Q9 I1 W$ P# \1 e5 {' U" q
bias tensor([0.0589]) tensor([0.7416])
# u2 r% D( x# E' d/ H
4 S- w5 G" A& d" S
tensor(0.2712)
; n; V* {& e3 S
4 n- ~8 V4 R1 S$ X- B; ^) Y m6 N
weight tensor([[0.7160, 0.2639]]) tensor([[-0.6692, 0.4266]])
+ W3 c. o; ^4 V4 q. }3 L2 g
, {0 d+ g6 J" J4 S) p* C {
bias tensor([-0.0152]) tensor([-0.2023])
( @5 K4 C3 G9 P
7 r1 i8 s5 U+ n/ E1 w9 k
tensor(0.1529)
4 c% o3 F( r6 F$ J
" D% g! J- M$ v D; |
weight tensor([[0.7829, 0.2212]]) tensor([[-0.4059, 0.4707]])
5 Z$ y; s1 S: Z( X9 V% \3 u
+ c& M+ E4 v1 ~
bias tensor([0.0050]) tensor([0.0566])
1 g( d) }4 g1 K( U& J
$ X4 h* {4 t3 O$ ?# L$ D! t2 \5 T: Y
tensor(0.0963)
m% Z( U+ p4 `9 ]
% Y% B% c2 ~6 w
weight tensor([[0.8235, 0.1741]]) tensor([[-0.3603, 0.3410]])
$ v2 j# l# i6 T9 i
% c8 P/ m9 B; b( w% h6 D
bias tensor([-0.0006]) tensor([-0.0146])
* b/ I% `6 } E
1 p0 A7 u, o# o8 B/ W) ~' {
tensor(0.0615)
! p4 P2 L5 i. ?; m6 k( [& `, @
) R6 l; ^* S5 U% b2 ?
weight tensor([[0.8595, 0.1400]]) tensor([[-0.2786, 0.2825]])
5 U0 O6 a# n9 a/ F& J
: H$ }2 m; f3 \) x; k1 {
bias tensor([0.0008]) tensor([0.0048])
2 K. a5 S7 S' D s: |
$ A5 r+ f* m" T+ B9 {
tensor(0.0394)
/ Y3 Q8 V1 T8 Q4 B* E
9 e$ G3 {$ Q7 {
weight tensor([[0.8874, 0.1118]]) tensor([[-0.2256, 0.2233]])
4 O1 K8 Y8 J4 y
S6 `/ c2 a9 H
bias tensor([0.0003]) tensor([-0.0006])
9 I* d& ]/ G9 k- W) Q$ f
9 h( Y+ c7 b1 ]3 L+ _
tensor(0.0252)
& [, `0 Y5 A9 V& r
+ B8 |7 l# |; a6 ~
weight tensor([[0.9099, 0.0895]]) tensor([[-0.1797, 0.1793]])
6 s3 S2 p! Q( n" P. c6 r/ _2 m
) x( ]4 J, U8 J t$ J
bias tensor([0.0004]) tensor([0.0008])
$ B% e7 o. {! @: ?
. s s. m/ s! F2 Y8 ~' c$ o* L
tensor(0.0161)
% L5 \+ Y* I5 _/ i
/ w, R1 l5 l5 g
weight tensor([[0.9279, 0.0715]]) tensor([[-0.1440, 0.1432]])
+ }2 G( D, V6 \$ `
$ j- V( L. y% q4 s; _" U
bias tensor([0.0003]) tensor([0.0003])
) S( U% O; M8 ?3 |0 L! p( u, G
1 j5 |# F! [8 O J
tensor(0.0103)
( g1 A4 [; b H3 f2 x8 }6 {7 _
5 Y* R! Q# _1 A% y9 B0 @8 E3 N
weight tensor([[0.9423, 0.0572]]) tensor([[-0.1152, 0.1146]])
6 H8 q: s: r4 l% ~$ B
: P8 P+ |9 f) A* [
bias tensor([0.0003]) tensor([0.0004])
E; Z o7 Q& b9 e! m, v
' H! u# g1 t4 N" B/ d
tensor(0.0066)
& J5 `. h9 }* G. r( L
+ v3 T. K( e* C9 n( c" b0 Z
weight tensor([[0.9538, 0.0458]]) tensor([[-0.0922, 0.0917]])
& @5 h; R" h- {6 N7 h
v% O2 H! K4 [2 n
bias tensor([0.0003]) tensor([0.0003])
8 Y6 ?7 B, l0 k# Q" z/ Y! i0 W: h
9 I( D. t7 B' o' G1 k
tensor(0.0042)
) y2 `/ h2 Q' l2 g
: r* T& Y" _% g1 N% m
weight tensor([[0.9630, 0.0366]]) tensor([[-0.0738, 0.0733]])
6 r" J5 G" F' [1 G; \
* `6 z s, m M' I8 g i
bias tensor([0.0002]) tensor([0.0003])
9 I1 S& B! A+ f0 k1 W: C1 `
, @+ X" S5 B+ K; {0 l7 _
tensor(0.0027)
7 e/ t6 A. S0 L3 i8 Y U1 E
o9 v0 i0 f5 t4 ^* O1 ?
weight tensor([[0.9704, 0.0293]]) tensor([[-0.0590, 0.0586]])
9 w& q$ s+ `3 x$ T
8 L: O$ C6 x/ g+ j& m+ C$ Z s2 V
bias tensor([0.0002]) tensor([0.0002])
1 L+ V1 h5 @9 q! ^7 g
8 N( B0 d! |( A: ^" D( t# ~8 Z
tensor(0.0017)
/ L) U# J+ t" z* j: E3 s
% V; |. {8 i1 r9 q
weight tensor([[0.9763, 0.0234]]) tensor([[-0.0472, 0.0469]])
! a& a4 ]* i6 j0 g4 w" P; Q( d# c
1 M, ?. G0 G* _0 |1 @% E* X$ J& ~
bias tensor([0.0002]) tensor([0.0002])
/ Z/ F" t5 C3 B( B+ F3 U: C
" ]$ t( `- L& S8 [9 Z$ k
tensor(0.0011)
' T6 x1 @4 T# V: L
3 _8 }* K. c" x
weight tensor([[0.9811, 0.0187]]) tensor([[-0.0378, 0.0375]])
7 K2 d! ? w x6 \: C
. B; v) E& }* O' Z
bias tensor([0.0001]) tensor([0.0002])
4 m% Z' v) x8 D+ | s3 X0 C
7 T( z6 a9 p1 a5 {8 t0 M
tensor(0.0007)
3 s# D+ g+ m5 C1 H3 k9 Y( u) U+ e* a
6 z8 i) c' q ^3 Y5 Z
weight tensor([[0.9848, 0.0150]]) tensor([[-0.0303, 0.0300]])
0 d) k( P% `& s/ G1 o$ f
' Z# W4 r) M% T6 r8 R
bias tensor([0.0001]) tensor([0.0002])
. |5 n7 Y3 G% E" m0 D
4 N4 M/ ~' Q) S) j/ E+ r6 ^
tensor(0.0005)
9 | u& z, C1 V: U( x7 ~: i, `
+ F8 a& T$ S( n I& U3 H: |
weight tensor([[0.9879, 0.0120]]) tensor([[-0.0242, 0.0240]])
4 K, {- h& M- `. x' o$ w
0 B0 T- G2 c- `' b+ t
bias tensor([0.0001]) tensor([0.0001])
* ^& |3 M* e! s8 P) R3 z
) f* V6 N1 c$ O
tensor(0.0003)
- F4 @+ U" M! B' Z& f) p/ h
& x P$ k0 b2 d6 J, O
weight tensor([[0.9903, 0.0096]]) tensor([[-0.0194, 0.0192]])
4 q' \4 Z8 w3 Q' L2 B
# K: {7 Z) U8 m& H3 f& {1 j6 j6 T. W
bias tensor([9.7973e-05]) tensor([0.0001])
. ]9 B. ]/ _: J+ r
5 V0 d% m1 |/ I! R, I8 {' y5 i
tensor(0.0002)
9 ~% M& r: A# i0 n4 a& f
* g8 p. e6 J" ~9 u$ _& z
weight tensor([[0.9922, 0.0076]]) tensor([[-0.0155, 0.0153]])
5 A5 E: o$ M+ \# j- U P
, R( M. E7 W; @/ s
bias tensor([8.5674e-05]) tensor([0.0001])
$ P+ [7 W: i/ Y+ W7 f+ Q$ Z5 p
) E7 l) c$ c3 D
tensor(0.0001)
# J" m6 g4 Q5 Q/ R
1 X3 M8 K2 k, u) O. F
weight tensor([[0.9938, 0.0061]]) tensor([[-0.0124, 0.0123]])
2 J' D2 F, O$ B( I. W, k' B7 W, |, n
" T J! f2 N9 Y6 n" ]* a
bias tensor([7.4933e-05]) tensor([9.4233e-05])
; M9 Q4 j& T8 r' k) a* Y
% _1 y8 V Q0 Z8 q) B. ^. ] U
tensor(7.6120e-05)
复制代码
1 ~/ F T/ ^& Z+ K2 I3 M
欢迎光临 数学建模社区-数学中国 (http://www.madio.net/)
Powered by Discuz! X2.5