lstm_input.py
# This script shows how batching in an LSTM works
# and how to prepare input for an LSTM.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(1)
def init_inputs(num=1, seq_len=4, shape=(1, 2)):
    inputs = list()
    for i in range(num):
        # sequence i consists of seq_len tensors, each filled with (i+1)
        inputs.append([(i+1)*torch.ones(shape) for k in range(seq_len)])
    return inputs
def init_hidden(batch_size=1, hidden_size=3):
    return (2*torch.ones(1, batch_size, hidden_size),
            2*torch.ones(1, batch_size, hidden_size))
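# note (added for clarity): nn.LSTM expects a (h_0, c_0) tuple, each of shape
# [num_layers * num_directions, batch, hidden_size]; both are filled with the
# constant 2 here so the step-by-step and batched runs start from the same state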
lstm = nn.LSTM(2, 3) # Input dim is 2, output dim is 3
linear = nn.Linear(3, 2)
# fill the linear weight with 1.0 (the bias stays randomly initialized)
linear.weight.data.fill_(1.0)
# generate some input: one sequence of 4 vectors
inputs = init_inputs(num=1, shape=(1, 2))
# initialize the hidden state.
hidden = init_hidden(hidden_size=3)
for seq in inputs:
    for i in seq:
        # Step through the sequence one element at a time;
        # after each step, hidden contains the hidden state.
        # i.view(1, 1, -1) reshapes each (1, 2) item to the
        # [seq_len=1, batch=1, input_size=2] layout nn.LSTM expects
        out, hidden = lstm(i.view(1, 1, -1), hidden)
        print(out, hidden)
        lin_out = linear(out)
        print(lin_out)
# Alternatively, we can process the entire sequence all at once.
# The first value returned by the LSTM is all of the hidden states throughout
# the sequence; the second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below, they are the same).
# The reason for this is that:
# "out" gives you access to all hidden states in the sequence;
# "hidden" allows you to continue the sequence and backpropagate,
# by passing it as an argument to the lstm at a later time.
# Below we add the extra batch dimension (dim 1) via view.
print(inputs)
inputs = inputs[0]
print(inputs)
inputs2 = torch.cat(inputs)
print(inputs2)
inputs2 = inputs2.view(len(inputs2), 1, -1)
print(inputs2)
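# equivalent shortcut (a sketch, not in the original): torch.stack inserts a
# new leading dimension, so the list of four (1, 2) tensors becomes a
# [4, 1, 2] tensor, i.e. [seq_len, batch, seq_item_size], in one call
stacked = torch.stack(inputs)
assert torch.equal(stacked, inputs2)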
hidden = init_hidden(hidden_size=3)
out, hidden = lstm(inputs2, hidden)
print(out)
print(hidden)
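# sanity check (added; not in the original): the last time step of "out"
# equals the final hidden state h_n, as the comment above promises
assert torch.allclose(out[-1], hidden[0][0])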
lin_out = linear(out)
print(lin_out)
# Create a 2D structure: a list of lists of tensors of the default shape.
# Each inner list represents a sequence (time series) of items;
# each item is a tensor of shape (1, dim), e.g. one word embedding.
inputs_batch = init_inputs(2)
print(inputs_batch)
# create a 3D tensor as a placeholder for the transformed inputs;
# this tensor has shape [batch, seq_len, seq_item_size]
# tensor([
#     [                  # batch 1
#         [0., 0.],      # seq item 1
#         [0., 0.],      # seq item 2
#         [0., 0.],      # seq item 3
#         [0., 0.],      # seq item 4
#     ],
#     [                  # batch 2
#         [0., 0.],      # seq item 1
#         [0., 0.],      # seq item 2
#         [0., 0.],      # seq item 3
#         [0., 0.],      # seq item 4
#     ],
# ])
input_tensor = torch.zeros(2, 4, 2)
# go through the list of sequences and concat all sequence items into 2D tensor
for i, seq in enumerate(inputs_batch):
    print(seq, i)
    # put each 2D sequence of shape [seq_len, seq_item_size]
    # into the 3D tensor; the first dim is the batch index
    input_tensor[i] = torch.cat(seq)
# we now have a 3D tensor;
# permute the dimensions to get [seq_len, batch, seq_item_size]
input_tensor = input_tensor.permute([1, 0, 2])
print(input_tensor)
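# shape check (added for illustration): after the permute the tensor layout
# matches the default nn.LSTM input of [seq_len, batch, seq_item_size]
print(input_tensor.shape)  # torch.Size([4, 2, 2])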
hidden = init_hidden(batch_size=2, hidden_size=3)
out, hidden = lstm(input_tensor, hidden)
print(out)
print(hidden)
lin_out = linear(out)
print(lin_out)
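# alternative batching (a sketch, not part of the original walkthrough):
# nn.LSTM accepts batch_first=True, which consumes [batch, seq_len, size]
# tensors directly and makes the permute above unnecessary; this is a fresh
# module with its own randomly initialized weights, so its outputs will
# differ numerically from the run above
lstm_bf = nn.LSTM(2, 3, batch_first=True)
batch_first_input = input_tensor.permute([1, 0, 2])  # back to [batch, seq_len, size]
out_bf, hidden_bf = lstm_bf(batch_first_input, init_hidden(batch_size=2, hidden_size=3))
print(out_bf.shape)  # torch.Size([2, 4, 3]): [batch, seq_len, hidden_size]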