-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgrefenstette.py
81 lines (69 loc) · 2.44 KB
/
grefenstette.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import typing
import torch
from torch_extras.layer import Layer
from .common import StackRNNBase
class GrefenstetteRNN(StackRNNBase):
def __init__(self,
input_size: int,
stack_embedding_size: int,
controller: typing.Callable
):
super().__init__(input_size, stack_embedding_size, controller)
self.action_layer = Layer(
self.controller.output_size(),
2,
torch.nn.Sigmoid()
)
self.push_value_layer = Layer(
self.controller.output_size(),
stack_embedding_size,
torch.nn.Tanh()
)
def initial_stack(self, batch_size, reading_size):
return GrefenstetteStack(
elements=[],
bottom=torch.zeros((batch_size, reading_size), device=self.device)
)
class State(StackRNNBase.State):
def compute_stack(self, hidden_state, stack):
actions = self.rnn.action_layer(hidden_state)
push_value = self.rnn.push_value_layer(hidden_state)
return stack.next(actions, push_value), actions
class GrefenstetteStack:
def __init__(self, elements, bottom):
self.elements = elements
self.bottom = bottom
def reading(self):
device = self.bottom.device
batch_size = self.bottom.size(0)
result = self.bottom
strength_left = torch.ones((batch_size,), device=device)
for value, strength in reversed(self.elements):
result = result + value * torch.min(
strength,
torch.nn.functional.relu(strength_left)
)[:, None]
strength_left = strength_left - strength
return result
def next(self, actions, push_value):
return GrefenstetteStack(
self.next_elements(actions, push_value),
self.bottom
)
def next_elements(self, actions, push_value):
push_strength = actions[:, 0]
pop_strength = actions[:, 1]
result = []
strength_left = pop_strength
for value, strength in reversed(self.elements):
result.append((
value,
torch.nn.functional.relu(
strength -
torch.nn.functional.relu(strength_left)
)
))
strength_left = strength_left - strength
result.reverse()
result.append((push_value, push_strength))
return result