-
Notifications
You must be signed in to change notification settings - Fork 0
/
optim.go
54 lines (41 loc) · 1.1 KB
/
optim.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package whale
import "github.com/hidetatz/whale/tensor"
// Optimizer performs parameter updates on Variables. The implementations in
// this file read a Variable's gradient and write new data back with SetData.
type Optimizer interface {
// Optimize performs one update step on v. The implementations in this file
// always return nil; the error result presumably exists so other optimizers
// can report failures — TODO confirm intended contract.
Optimize(v *Variable) error
}
// SGD is a plain (momentum-free) stochastic gradient descent optimizer.
// Construct it with NewSGD.
type SGD struct {
// learnRate is the scalar tensor (created by NewSGD via tensor.Scalar) that
// each gradient is multiplied by before being subtracted from the data.
learnRate *tensor.Tensor
}
// NewSGD returns a plain SGD optimizer whose Optimize step subtracts
// learnRate * gradient from a Variable's data. The rate is stored as a
// scalar tensor so it can be multiplied directly with gradient tensors.
func NewSGD(learnRate float32) *SGD {
	lr := tensor.Scalar(learnRate)
	return &SGD{learnRate: lr}
}
// Optimize applies one plain SGD step to v, replacing its data with
//
//	data - grad * learnRate
//
// via v.SetData. It always returns nil.
func (s *SGD) Optimize(v *Variable) error {
	grad := v.GetGrad().GetData()
	step := grad.Mul(s.learnRate)
	v.SetData(v.GetData().Sub(step))
	return nil
}
// MomentumSGD is a stochastic gradient descent optimizer with a momentum
// term. It keeps one velocity tensor per Variable. Construct it with
// NewMomentumSGD so that the velocities map is initialized.
type MomentumSGD struct {
// learnRate is a scalar tensor multiplied with the gradient on each step.
learnRate *tensor.Tensor
// momentum is a scalar tensor that the velocity read from velocities is
// multiplied by on each step.
momentum *tensor.Tensor
// velocities maps each Variable (by pointer identity) to its velocity
// tensor. Optimize creates entries lazily as zeros shaped like the data.
velocities map[*Variable]*tensor.Tensor
}
// NewMomentumSGD returns a momentum SGD optimizer.
//
// learnRate is multiplied with each gradient. momentum is the factor each
// per-Variable velocity is multiplied by on every step. Both are stored as
// scalar tensors. The velocity table starts empty; Optimize fills it in
// lazily on the first step for each Variable.
func NewMomentumSGD(learnRate, momentum float32) *MomentumSGD {
	opt := new(MomentumSGD)
	opt.learnRate = tensor.Scalar(learnRate)
	opt.momentum = tensor.Scalar(momentum)
	opt.velocities = map[*Variable]*tensor.Tensor{}
	return opt
}
// Optimize applies one momentum-SGD step to v:
//
//	velocity <- velocity*momentum - learnRate*grad
//	data     <- data + velocity
//
// On the first call for a given Variable, the velocity starts as a zero tensor
// shaped like v's data. The updated velocity is stored back into s.velocities,
// so it carries over to the next call for the same Variable. It always
// returns nil.
func (s *MomentumSGD) Optimize(v *Variable) error {
	velocity, ok := s.velocities[v]
	if !ok {
		velocity = tensor.ZerosLike(v.GetData())
	}
	velocity = velocity.Mul(s.momentum)
	delta := s.learnRate.Mul(v.GetGrad().GetData())
	velocity = velocity.Sub(delta)
	// The tensor ops return new tensors rather than mutating their receiver
	// (learnRate.Mul above would otherwise corrupt the shared rate), so the
	// new velocity must be persisted explicitly. Without this, every step
	// would restart from zero velocity and degenerate to plain SGD.
	s.velocities[v] = velocity
	v.SetData(v.GetData().Add(velocity))
	return nil
}