Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 26, 2024
1 parent e8774c3 commit cc96914
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
23 changes: 10 additions & 13 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,22 +982,20 @@ def _find_start_stop_traj(

# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
if not at_capacity:
end = torch.cat([end, torch.ones_like(end[:1])], 0)
else:
end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
length = trajectory.shape[0]
else:
# TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True

# We presume that not done at the end means that the traj spans across end and beginning of storage
length = end.shape[0]
if not at_capacity:
end = end.clone()
end[length - 1] = True
ndim = end.ndim

if not at_capacity:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
else:
if at_capacity:
# we must have at least one end by traj to individuate trajectories
# so if no end can be found we set it manually
if cursor is not None:
Expand All @@ -1019,7 +1017,6 @@ def _find_start_stop_traj(
mask = ~end.any(0, True)
mask = torch.cat([torch.zeros_like(end[:-1]), mask])
end = torch.masked_fill(mask, end, 1)
ndim = end.ndim
if ndim == 0:
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
Expand Down Expand Up @@ -1126,7 +1123,7 @@ def _get_stop_and_length(self, storage, fallback=True):
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(
trajectory=trajectory,
trajectory=trajectory.clone(),
at_capacity=storage._is_full,
cursor=getattr(storage, "_last_cursor", None),
)
Expand Down
4 changes: 1 addition & 3 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,9 +901,7 @@ def max_size_along_dim0(data_shape):

if is_tensor_collection(data):
out = data.to(self.device)
out = out.expand(max_size_along_dim0(data.shape))
out = out.clone()
out = out.zero_()
out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
else:
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = tree_map(
Expand Down

0 comments on commit cc96914

Please sign in to comment.