Skip to content

Commit

Permalink
Zero out sequence values when an episode terminates before nsteps is over
Browse files Browse the repository at this point in the history
  • Loading branch information
zacwellmer authored and Ubuntu committed Sep 9, 2018
1 parent 95b9466 commit 2f08595
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
4 changes: 2 additions & 2 deletions treeqn/nstep_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def train(self, obs, next_obs, returns, rewards, masks, actions, values):

# compute the sequences we need to get back reward predictions
action_sequences, reward_sequences, sequence_mask = build_sequences(
[torch.from_numpy(actions), torch.from_numpy(rewards)], self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
[torch.from_numpy(actions), torch.from_numpy(rewards)], masks, self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
action_sequences = cudify(action_sequences.long().squeeze(-1))
reward_sequences = make_variable(reward_sequences.squeeze(-1))
sequence_mask = make_variable(sequence_mask.squeeze(-1))
Expand Down Expand Up @@ -124,7 +124,7 @@ def train(self, obs, next_obs, returns, rewards, masks, actions, values):

if self.use_st_loss:
st_embeddings = tree_result["embeddings"][0]
st_targets, st_mask = build_sequences([st_embeddings.data], self.nenvs, self.nsteps, self.tree_depth, return_mask=True, offset=1)
st_targets, st_mask = build_sequences([st_embeddings.data], masks, self.nenvs, self.nsteps, self.tree_depth, return_mask=True, offset=1)
st_targets = make_variable(st_targets.view(self.batch_size, -1))
st_mask = make_variable(st_mask.view(self.batch_size, -1))

Expand Down
19 changes: 16 additions & 3 deletions treeqn/utils/treeqn_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn.functional as F
from treeqn.utils.pytorch_utils import cudify
Expand All @@ -13,12 +14,23 @@ def discount_with_dones(rewards, dones, gamma):
discounted.append(r)
return discounted[::-1]

def make_seq_mask(mask):
    """Build a multiplicative keep-mask from a window of per-step flags.

    Every position before the first nonzero flag is kept (1.0); the first
    nonzero flag and everything after it is zeroed (0.0).

    Args:
        mask: CPU torch tensor of flags, scanned along its first axis.
            Nonzero entries are treated as "stop here" markers
            (presumably episode-termination flags -- confirm with caller).

    Returns:
        A float64 torch tensor with the same shape as ``mask`` plus a
        trailing singleton axis, so it broadcasts against per-step feature
        vectors.
    """
    # Builtin ``bool`` replaces the ``np.bool`` alias that NumPy 1.24 removed.
    # ``astype`` copies, so the caller's tensor storage is never mutated.
    flags = mask.numpy().astype(bool)
    # Running OR along the step axis: once a flag is seen, every later step is
    # flagged too. Unlike ``np.argmax`` this is also well defined for an empty
    # window.
    stopped = np.logical_or.accumulate(flags, axis=0)
    keep = ~np.expand_dims(stopped, axis=-1)  # tilde flips True and False
    # ``np.float64`` is what the removed ``np.float`` alias resolved to.
    return torch.from_numpy(keep.astype(np.float64))

# some utilities for interpreting the trees we return
def build_sequences(sequences, nenvs, nsteps, depth, return_mask=False, offset=0):
def build_sequences(sequences, masks, nenvs, nsteps, depth, return_mask=False, offset=0):
# sequences are bs x size, containing e.g. rewards, actions, state reps
# returns bs x depth x size processed sequences with a sliding window set by 'depth', padded with 0's
# if return_mask=True also returns a mask showing where the sequences were padded
# This can be used to produce targets for tree outputs, from the true observed sequences
tmp_masks = torch.from_numpy(masks.reshape(nenvs, nsteps).astype(np.int))
tmp_masks = F.pad(tmp_masks, (0, 0, 0, depth+offset), mode="constant", value=0).data

sequences = [s.view(nenvs, nsteps, -1) for s in sequences]
if return_mask:
mask = torch.ones_like(sequences[0]).float()
Expand All @@ -29,7 +41,8 @@ def build_sequences(sequences, nenvs, nsteps, depth, return_mask=False, offset=0
proc_seq = []
for env in range(seq.shape[0]):
for t in range(nsteps):
proc_seq.append(seq[env, t+offset:t+offset+depth, :])
seq_done_mask = make_seq_mask(tmp_masks[env, t+offset:t+offset+depth])
proc_seq.append(seq[env, t+offset:t+offset+depth, :].float() * seq_done_mask.float())
proc_sequences.append(torch.stack(proc_seq))
return proc_sequences

Expand Down Expand Up @@ -101,4 +114,4 @@ def append_list(run, key, val):
if key in run.info:
run.info[key].extend(val)
else:
run.info[key] = [val]
run.info[key] = [val]

0 comments on commit 2f08595

Please sign in to comment.