
Correct cell state for LSTM. && Small changes.
xfffrank committed Feb 24, 2020
1 parent 476708e commit 5d4100a
Showing 6 changed files with 35 additions and 24 deletions.
8 changes: 5 additions & 3 deletions early_stop_compare.py
@@ -99,7 +99,7 @@ def set_seed(seed):
# the number of training epoches
num_of_epoch = 10
# the number of batch size for gradient descent when training
- batch_sz = 50
+ batch_sz = 64

# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
@@ -161,7 +161,8 @@ def main():
text = train.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
curr_step = 0
# set up the initial input for lstm
- h_0 = torch.zeros([1,1,128]).to(device)
+ h_0 = torch.zeros([1,1,128]).to(device)
+ c_0 = torch.zeros([1,1,128]).to(device)
saved_log_probs = []
baseline_value_ep = []
while (curr_step < 20):
@@ -172,8 +173,9 @@ def main():
# read a chunk
text_input = text[curr_step]
# hidden state
- ht = clstm(text_input, h_0) # 1 * 128
+ ht, ct = clstm(text_input, h_0, c_0) # 1 * 128
h_0 = ht.unsqueeze(0).to(device) # 1 * 1 * 128, next input of lstm
+ c_0 = ct
# compute a baseline value for the value network
ht_ = ht.clone().detach().requires_grad_(True).to(device)
bi = value_net(ht_)
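Not part of the commit, but for context: a minimal, self-contained sketch of the state-threading pattern the hunks above introduce. A plain nn.LSTM stands in for the CNN-LSTM encoder (clstm), and the dummy input is hypothetical; only the idea of feeding both h_0 and c_0 back in at every chunk follows the diff.

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# stand-in for the CNN-LSTM encoder; same hidden size (128) as in the diff
lstm = nn.LSTM(input_size=128, hidden_size=128).to(device)

chunks = torch.randn(20, 1, 128).to(device)  # 20 dummy chunk features, batch size 1
h_0 = torch.zeros([1, 1, 128]).to(device)    # hidden state
c_0 = torch.zeros([1, 1, 128]).to(device)    # cell state, now carried across chunks

for curr_step in range(20):
    chunk = chunks[curr_step].unsqueeze(0)   # 1 * 1 * 128
    _, (h_t, c_t) = lstm(chunk, (h_0, c_0))
    h_0, c_0 = h_t, c_t                      # thread BOTH states to the next chunk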
6 changes: 4 additions & 2 deletions early_stop_main.py
@@ -94,7 +94,7 @@ def set_seed(seed):
# the number of training epoches
num_of_epoch = 10
# the number of batch size for gradient descent when training
- batch_sz = 50
+ batch_sz = 64

# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
@@ -157,6 +157,7 @@ def main():
curr_step = 0
# set up the initial input for lstm
h_0 = torch.zeros([1,1,128]).to(device)
+ c_0 = torch.zeros([1,1,128]).to(device)
saved_log_probs = []
baseline_value_ep = []
cost_ep = [] # collect the computational costs for every time step
@@ -168,8 +169,9 @@ def main():
# read a chunk
text_input = text[curr_step]
# hidden state
- ht = clstm(text_input, h_0) # 1 * 128
+ ht, ct = clstm(text_input, h_0, c_0) # 1 * 128
h_0 = ht.unsqueeze(0).to(device) # 1 * 1 * 128, next input of lstm
+ c_0 = ct
# compute a baseline value for the value network
ht_ = ht.clone().detach().requires_grad_(True).to(device)
bi = value_net(ht_)
15 changes: 7 additions & 8 deletions networks.py
@@ -15,10 +15,10 @@ def __init__(self, input_dim, embedding_dim, ker_size, n_filters, hidden_dim):
self.embedding = nn.Embedding(input_dim, embedding_dim)
self.conv = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(ker_size, embedding_dim))
self.lstm = nn.LSTM(input_size=n_filters, hidden_size=hidden_dim)
- self.dropout = nn.Dropout(p=0.1)
+ self.dropout = nn.Dropout(p=0.2)
self.relu = nn.ReLU()

- def forward(self, text, h_0):
+ def forward(self, text, h_0, c_0):
# CNN and LSTM network
'''
At every time step, the model reads one chunk which has a size of 20 words.
@@ -56,10 +56,9 @@ def forward(self, text, h_0):
conved = conved.squeeze(3) # 1 * 128 * 16
conved = torch.transpose(conved, 1, 2) # 1 * 16 * 128
conved = torch.transpose(conved, 1, 0) # 16 * 1 * 128
- c_0 = torch.zeros([1, batch, 128]).to(device)
output, (hidden, cell) = self.lstm(conved, (h_0, c_0))
ht = hidden.squeeze(0) # 1 * 128
- return ht
+ return ht, cell


class Policy_S(nn.Module):
@@ -73,7 +72,7 @@ def __init__(self, input_dim, hidden_dim, output_dim):
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, output_dim)
- self.dropout = nn.Dropout(p=0.1)
+ self.dropout = nn.Dropout(p=0.5)
self.relu = nn.ReLU()

def forward(self, ht):
@@ -104,7 +103,7 @@ def __init__(self, input_dim, hidden_dim, output_dim):
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, output_dim)
- self.dropout = nn.Dropout(p=0.1)
+ self.dropout = nn.Dropout(p=0.5)
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=1)

@@ -135,7 +134,7 @@ def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
#self.fc = nn.Linear(input_dim, output_dim)
self.fc1 = nn.Linear(input_dim, hidden_dim)
- self.dropout = nn.Dropout(p=0.1)
+ self.dropout = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(hidden_dim, output_dim)
self.relu = nn.ReLU()

@@ -158,7 +157,7 @@ def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
#self.fc = nn.Linear(input_dim, output_dim)
self.fc1 = nn.Linear(input_dim, hidden_dim)
- self.dropout = nn.Dropout(p=0.1)
+ self.dropout = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(hidden_dim, output_dim)
self.relu = nn.ReLU()

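The substance of the fix is visible in the forward hunk above: previously c_0 was re-created as zeros inside forward on every call, so the LSTM's cell state never carried over between chunks; now the caller passes c_0 in and the updated cell state is returned alongside ht. Not part of the commit, but a simplified, runnable sketch of that interface; embedding and vocabulary sizes are placeholders, and only the (h_0, c_0) -> (ht, cell) signature and the 128-dim hidden size follow the diff.

import torch
import torch.nn as nn

class CLSTMSketch(nn.Module):
    # simplified stand-in for the repository's CNN-LSTM encoder (sizes are illustrative)
    def __init__(self, input_dim=10000, embedding_dim=100, ker_size=5,
                 n_filters=128, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.conv = nn.Conv2d(1, n_filters, kernel_size=(ker_size, embedding_dim))
        self.lstm = nn.LSTM(input_size=n_filters, hidden_size=hidden_dim)
        self.relu = nn.ReLU()

    def forward(self, text, h_0, c_0):
        emb = self.embedding(text).unsqueeze(1)        # 1 * 1 * 20 * embedding_dim
        conved = self.relu(self.conv(emb)).squeeze(3)  # 1 * n_filters * 16
        conved = conved.permute(2, 0, 1)               # 16 * 1 * n_filters
        output, (hidden, cell) = self.lstm(conved, (h_0, c_0))
        ht = hidden.squeeze(0)                         # 1 * hidden_dim
        return ht, cell                                # cell state goes back to the caller

The calling loops in the other files then do ht, ct = clstm(text_input, h_0, c_0) and feed ht.unsqueeze(0) and ct back in as the next (h_0, c_0).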
10 changes: 6 additions & 4 deletions skim_reread_es_compare.py
@@ -97,9 +97,9 @@ def set_seed(seed):
learning_rate = 0.001

# the number of training epoches
- num_of_epoch = 10
+ num_of_epoch = 8
# the number of batch size for gradient descent when training
- batch_sz = 50
+ batch_sz = 64

# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
@@ -169,6 +169,7 @@ def main():
text = train.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
curr_step = 0 # the position of the current chunk
h_0 = torch.zeros([1,1,128]).to(device) # run on GPU
+ c_0 = torch.zeros([1,1,128]).to(device) # run on GPU
count = 0 # maximum skim/reread time: 5
baseline_value_ep = []
saved_log_probs = [] # for the use of policy gradient update
@@ -178,13 +179,14 @@ def main():
count += 1
# pass the input through cnn-lstm and policy s
text_input = text[curr_step] # text_input 1*20
- ht = clstm(text_input, h_0) # 1 * 128
+ ht, ct = clstm(text_input, h_0, c_0) # 1 * 128
# separate the value which is the input of value net
ht_ = ht.clone().detach().requires_grad_(True)
# compute a baseline value for the value network
bi = value_net(ht_)
# 1 * 1 * 128, next input of lstm
h_0 = ht.unsqueeze(0)
+ c_0 = ct
# draw a stop decision
stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
stop_decision = stop_decision.item()
@@ -233,7 +235,7 @@ def main():
policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
baseline_value_batch.append(torch.cat(value_losses).sum())
# update gradients
- if (index + 1) % batch_sz == 0: # take the average of 50 samples
+ if (index + 1) % batch_sz == 0: # take the average of samples, backprop
finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]

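For context (not part of the commit): the (index + 1) % batch_sz == 0 check above accumulates one loss term per episode and triggers a single gradient update every batch_sz samples, which is why the hard-coded "50 samples" comment was generalized. A minimal sketch of that accumulate-then-step pattern, with finish_episode's internals assumed rather than taken from the repository:

import torch

def finish_episode_sketch(policy_loss_sum, baseline_value_batch, optimizer):
    # assumption: average the accumulated episode losses and take one optimizer step
    loss = torch.stack(policy_loss_sum).mean() + torch.stack(baseline_value_batch).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # afterwards the accumulators are cleared, as in the diff:
    # del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]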
10 changes: 6 additions & 4 deletions skim_reread_es_main.py
@@ -90,9 +90,9 @@ def set_seed(seed):
learning_rate = 0.001

# the number of training epoches
- num_of_epoch = 10
+ num_of_epoch = 8
# the number of batch size for gradient descent when training
- batch_sz = 50
+ batch_sz = 64

# set up the criterion
criterion = nn.CrossEntropyLoss().to(device)
@@ -163,6 +163,7 @@ def main():
text = train.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
curr_step = 0 # the position of the current chunk
h_0 = torch.zeros([1,1,128]).to(device) # run on GPU
+ c_0 = torch.zeros([1,1,128]).to(device)
count = 0 # maximum skim/reread time: 5
baseline_value_ep = []
saved_log_probs = [] # for the use of policy gradient update
@@ -174,13 +175,14 @@ def main():
count += 1
# pass the input through cnn-lstm and policy s
text_input = text[curr_step] # text_input 1*20
- ht = clstm(text_input, h_0) # 1 * 128
+ ht, ct = clstm(text_input, h_0, c_0) # 1 * 128
# separate the value which is the input of value net
ht_ = ht.clone().detach().requires_grad_(True)
# compute a baseline value for the value network
bi = value_net(ht_)
# 1 * 1 * 128, next input of lstm
h_0 = ht.unsqueeze(0)
+ c_0 = ct
# draw a stop decision
stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
stop_decision = stop_decision.item()
@@ -223,7 +225,7 @@ def main():
policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
baseline_value_batch.append(torch.cat(value_losses).sum())
# update gradients
- if (index + 1) % batch_sz == 0: # take the average of 50 samples
+ if (index + 1) % batch_sz == 0: # take the average of samples, backprop
finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]

10 changes: 7 additions & 3 deletions utils/utils.py
@@ -100,13 +100,15 @@ def evaluate(clstm, policy_s, policy_n, policy_c, iterator):
text = batch.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
curr_step = 0
h_0 = torch.zeros([1,1,128]).to(device)
+ c_0 = torch.zeros([1,1,128]).to(device)
count = 0
while curr_step < 20 and count < 5: # loop until a text can be classified or currstep is up to 20
count += 1
# pass the input through cnn-lstm and policy s
text_input = text[curr_step] # text_input 1*20
- ht = clstm(text_input, h_0) # 1 * 128
+ ht, ct = clstm(text_input, h_0, c_0) # 1 * 128
h_0 = ht.unsqueeze(0) # 1 * 1 * 128, next input of lstm
+ c_0 = ct
# draw a stop decision
stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
flops_sum += clstm_cost + s_cost
@@ -152,7 +154,8 @@ def evaluate_earlystop(clstm, policy_s, policy_c, iterator):
text = batch.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
curr_step = 0
# set up the initial input for lstm
- h_0 = torch.zeros([1,1,128]).to(device)
+ h_0 = torch.zeros([1,1,128]).to(device)
+ c_0 = torch.zeros([1,1,128]).to(device)
saved_log_probs = []
while (curr_step < 20):
'''
@@ -162,8 +165,9 @@ def evaluate_earlystop(clstm, policy_s, policy_c, iterator):
# read a chunk
text_input = text[curr_step]
# hidden state
- ht = clstm(text_input, h_0) # 1 * 128
+ ht, ct = clstm(text_input, h_0, c_0) # 1 * 128
h_0 = ht.unsqueeze(0).cuda() # 1 * 1 * 128, next input of lstm
+ c_0 = ct
# draw a stop decision
stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
stop_decision = stop_decision.item()
