forked from gy910210/rnn-from-scratch
-
Notifications
You must be signed in to change notification settings - Fork 3
/
layer.py
24 lines (21 loc) · 869 Bytes
/
layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from activation import Tanh
from gate import AddGate, MultiplyGate
# Gate/activation objects shared by every RNNLayer instance; RNNLayer.forward
# and RNNLayer.backward look them up as module-level globals.
# NOTE(review): sharing one instance across all layers/time steps is only safe
# if these objects keep no per-call state -- confirm in gate.py / activation.py.
mulGate = MultiplyGate()
addGate = AddGate()
activation = Tanh()
class RNNLayer:
    """A single time step of a recurrent layer with a Tanh activation.

    ``forward`` evaluates, using the module-level gate objects::

        mulu = mulGate.forward(U, x)         # presumably U @ x
        mulw = mulGate.forward(W, prev_s)    # presumably W @ prev_s
        add  = addGate.forward(mulw, mulu)   # presumably mulw + mulu
        s    = activation.forward(add)       # Tanh
        mulv = mulGate.forward(V, s)         # presumably V @ s

    Every intermediate is cached on the instance (``self.mulu``,
    ``self.mulw``, ``self.add``, ``self.s``, ``self.mulv``) so that
    ``backward`` and callers can read them afterwards. Presumably one
    instance is kept per time step of an unrolled sequence -- confirm with
    the caller.
    """

    def forward(self, x, prev_s, U, W, V):
        """Run one step forward and cache the intermediates.

        Args:
            x: Input for this time step.
            prev_s: State produced by the previous time step.
            U: Weights applied to ``x``.
            W: Weights applied to ``prev_s``.
            V: Weights applied to the new state ``s``.

        Returns:
            The cached ``self.mulv`` (result of applying ``V`` to ``s``).
            Existing callers that ignore the return value and read the
            cached attributes are unaffected.
        """
        self.mulu = mulGate.forward(U, x)
        self.mulw = mulGate.forward(W, prev_s)
        self.add = addGate.forward(self.mulw, self.mulu)
        self.s = activation.forward(self.add)
        self.mulv = mulGate.forward(V, self.s)
        return self.mulv

    def backward(self, x, prev_s, U, W, V, diff_s, dmulv):
        """Backpropagate through this time step.

        Args:
            x, prev_s, U, W, V: The same inputs given to ``forward`` for
                this step.
            diff_s: Extra gradient w.r.t. this step's state ``s``; it is
                added to the gradient arriving through ``V``. Presumably the
                gradient flowing back from the next time step -- confirm
                with the caller.
            dmulv: Gradient w.r.t. this step's ``mulv``.

        Returns:
            Tuple ``(dprev_s, dU, dW, dV)``: gradients w.r.t. ``prev_s``,
            ``U``, ``W`` and ``V`` for this step.
        """
        # Recompute the forward pass so the cached intermediates are
        # guaranteed to correspond to these arguments, regardless of what
        # this instance last ran forward on.
        self.forward(x, prev_s, U, W, V)

        # Output path: mulv depends on V and s.
        dV, dsv = mulGate.backward(V, self.s, dmulv)
        # s receives gradient both through V and through diff_s, so its
        # total gradient is the sum of the two contributions.
        ds = dsv + diff_s
        dadd = activation.backward(self.add, ds)
        dmulw, dmulu = addGate.backward(self.mulw, self.mulu, dadd)
        dW, dprev_s = mulGate.backward(W, prev_s, dmulw)
        # The gradient w.r.t. the input x is not returned, so discard it.
        dU, _ = mulGate.backward(U, x, dmulu)
        return (dprev_s, dU, dW, dV)