diff --git a/treeqn/nstep_learn.py b/treeqn/nstep_learn.py
index a07515f..57cd39f 100644
--- a/treeqn/nstep_learn.py
+++ b/treeqn/nstep_learn.py
@@ -84,7 +84,7 @@ def train(self, obs, next_obs, returns, rewards, masks, actions, values):
 
         # compute the sequences we need to get back reward predictions
         action_sequences, reward_sequences, sequence_mask = build_sequences(
-            [torch.from_numpy(actions), torch.from_numpy(rewards)], self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
+            [torch.from_numpy(actions), torch.from_numpy(rewards)], masks, self.nenvs, self.nsteps, self.tree_depth, return_mask=True)
         action_sequences = cudify(action_sequences.long().squeeze(-1))
         reward_sequences = make_variable(reward_sequences.squeeze(-1))
         sequence_mask = make_variable(sequence_mask.squeeze(-1))
@@ -124,7 +124,7 @@ def train(self, obs, next_obs, returns, rewards, masks, actions, values):
 
         if self.use_st_loss:
             st_embeddings = tree_result["embeddings"][0]
-            st_targets, st_mask = build_sequences([st_embeddings.data], self.nenvs, self.nsteps, self.tree_depth, return_mask=True, offset=1)
+            st_targets, st_mask = build_sequences([st_embeddings.data], masks, self.nenvs, self.nsteps, self.tree_depth, return_mask=True, offset=1)
             st_targets = make_variable(st_targets.view(self.batch_size, -1))
             st_mask = make_variable(st_mask.view(self.batch_size, -1))
 
diff --git a/treeqn/utils/treeqn_utils.py b/treeqn/utils/treeqn_utils.py
index f45fe3b..4a8166d 100644
--- a/treeqn/utils/treeqn_utils.py
+++ b/treeqn/utils/treeqn_utils.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
 from treeqn.utils.pytorch_utils import cudify
@@ -13,12 +14,25 @@ def discount_with_dones(rewards, dones, gamma):
         discounted.append(r)
     return discounted[::-1]
 
+def make_seq_mask(mask):
+    # given a window of done flags, returns a float column mask that keeps
+    # the steps before the first done (1.0) and zeroes every step from it onward
+    mask = mask.numpy().astype(np.bool)
+    max_i = np.argmax(mask, axis=0)
+    if mask[max_i]:
+        mask[max_i:] = True
+    mask = ~np.expand_dims(mask, axis=-1)  # invert: True now marks steps to keep
+    return torch.from_numpy(mask.astype(np.float))
+
 # some utilities for interpreting the trees we return
-def build_sequences(sequences, nenvs, nsteps, depth, return_mask=False, offset=0):
+def build_sequences(sequences, masks, nenvs, nsteps, depth, return_mask=False, offset=0):
     # sequences are bs x size, containing e.g. rewards, actions, state reps
     # returns bs x depth x size processed sequences with a sliding window set by 'depth', padded with 0's
     # if return_mask=True also returns a mask showing where the sequences were padded
     # This can be used to produce targets for tree outputs, from the true observed sequences
+    tmp_masks = torch.from_numpy(masks.reshape(nenvs, nsteps).astype(np.int))
+    tmp_masks = F.pad(tmp_masks, (0, depth + offset), mode="constant", value=0).data  # pad the time dim, matching the sequence padding below
+
     sequences = [s.view(nenvs, nsteps, -1) for s in sequences]
     if return_mask:
         mask = torch.ones_like(sequences[0]).float()
@@ -29,7 +43,8 @@ def build_sequences(sequences, nenvs, nsteps, depth, return_mask=False, offset=0
         proc_seq = []
         for env in range(seq.shape[0]):
             for t in range(nsteps):
-                proc_seq.append(seq[env, t+offset:t+offset+depth, :])
+                seq_done_mask = make_seq_mask(tmp_masks[env, t+offset:t+offset+depth])
+                proc_seq.append(seq[env, t+offset:t+offset+depth, :].float() * seq_done_mask.float())
         proc_sequences.append(torch.stack(proc_seq))
     return proc_sequences
 
@@ -101,4 +116,4 @@ def append_list(run, key, val):
     if key in run.info:
         run.info[key].extend(val)
     else:
-        run.info[key] = [val]
\ No newline at end of file
+        run.info[key] = [val]
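
For reference, here is a minimal self-contained sketch of the masking rule that make_seq_mask applies to each window. It uses plain bool/float dtypes rather than the patch's np.bool/np.float aliases, and the toy input is an assumption chosen so that the episode ends at the window's second step:

    import numpy as np
    import torch

    def make_seq_mask_sketch(mask):
        # mask: 1-D int tensor of done flags for one depth-length window
        mask = mask.numpy().astype(bool)
        first_done = np.argmax(mask)           # index of the first True (0 if none)
        if mask[first_done]:                   # a done actually occurred in the window
            mask[first_done:] = True           # flag everything after it too
        keep = ~np.expand_dims(mask, axis=-1)  # invert: True = keep this step
        return torch.from_numpy(keep.astype(np.float64))

    print(make_seq_mask_sketch(torch.tensor([0, 1, 0, 0])).squeeze(-1))
    # tensor([1., 0., 0., 0.], dtype=torch.float64)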
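And a toy end-to-end walk-through of the episode-masked sliding windows that the patched build_sequences is meant to produce for reward targets. Nothing below is imported from treeqn, and the shapes (nenvs=1, nsteps=4, depth=2) are illustrative assumptions:

    import numpy as np
    import torch
    import torch.nn.functional as F

    nenvs, nsteps, depth, offset = 1, 4, 2, 0
    rewards = torch.arange(1., nsteps + 1).view(nenvs, nsteps, 1)  # r_t = [1, 2, 3, 4]
    dones = np.array([0, 1, 0, 0])                                 # episode ends at t=1

    # pad the time dimension so every window has `depth` entries
    seq = F.pad(rewards, (0, 0, 0, depth + offset))
    done = F.pad(torch.from_numpy(dones), (0, depth + offset)).numpy().astype(bool)

    windows = []
    for t in range(nsteps):
        w = seq[0, t + offset:t + offset + depth, 0].clone()  # clone: windows overlap
        d = done[t + offset:t + offset + depth].copy()
        if d.any():
            d[np.argmax(d):] = True  # flag everything from the first done onward
        w[torch.from_numpy(d)] = 0.  # zero targets that cross the episode boundary
        windows.append(w)

    print(torch.stack(windows))
    # tensor([[1., 0.],
    #         [0., 0.],
    #         [3., 4.],
    #         [4., 0.]])

The windows starting at t=0 and t=1 are cut off at the episode boundary, so tree targets never leak rewards across episodes; the windows at t=2 and t=3 only pick up the zero-padding at the end of the rollout.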