[Refactor] Use empty_like in storage construction
ghstack-source-id: 28cd569bd4abf472991b82b3eba9fe333b5cd68f
Pull Request resolved: #2455
vmoens committed Sep 26, 2024
1 parent 8542d2e commit b4d543e
Showing 2 changed files with 11 additions and 16 deletions.
torchrl/data/replay_buffers/samplers.py (10 additions, 13 deletions)

--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -982,22 +982,20 @@ def _find_start_stop_traj(

                 # faster
                 end = trajectory[:-1] != trajectory[1:]
-                end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
+                if not at_capacity:
+                    end = torch.cat([end, torch.ones_like(end[:1])], 0)
+                else:
+                    end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
                 length = trajectory.shape[0]
         else:
             # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True
 
             # We presume that not done at the end means that the traj spans across end and beginning of storage
             length = end.shape[0]
+            if not at_capacity:
+                end = end.clone()
+                end[length - 1] = True
+        ndim = end.ndim
 
-        if not at_capacity:
-            end = torch.index_fill(
-                end,
-                index=torch.tensor(-1, device=end.device, dtype=torch.long),
-                dim=0,
-                value=1,
-            )
-        else:
+        if at_capacity:
             # we must have at least one end by traj to individuate trajectories
             # so if no end can be found we set it manually
             if cursor is not None:
@@ -1019,7 +1017,6 @@
                 mask = ~end.any(0, True)
                 mask = torch.cat([torch.zeros_like(end[:-1]), mask])
                 end = torch.masked_fill(mask, end, 1)
-        ndim = end.ndim
         if ndim == 0:
             raise RuntimeError(
                 "Expected the end-of-trajectory signal to be at least 1-dimensional."
@@ -1126,7 +1123,7 @@ def _get_stop_and_length(self, storage, fallback=True):
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(
trajectory=trajectory,
trajectory=trajectory.clone(),
at_capacity=storage._is_full,
cursor=getattr(storage, "_last_cursor", None),
)
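
The `trajectory.clone()` at the call site reads as a defensive copy; the commit message does not state the motivation, so this is an assumption: since `_find_start_stop_traj` can now write into tensors derived from its inputs, cloning keeps the storage's cached trajectory data insulated from any in-place mutation. A hypothetical illustration of the pattern:

import torch

def mark_last(t):
    # Stands in for a callee that writes into its argument in place.
    t[-1] = -1
    return t

storage_traj = torch.tensor([0, 0, 1, 1])
# Passing a clone shields the caller's buffer from the in-place write.
result = mark_last(storage_traj.clone())
assert storage_traj[-1] == 1  # the original is untouched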
torchrl/data/replay_buffers/storages.py (1 addition, 3 deletions)

--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -902,9 +902,7 @@ def max_size_along_dim0(data_shape):

         if is_tensor_collection(data):
             out = data.to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
-            out = out.clone()
-            out = out.zero_()
+            out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
         else:
             # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
             out = tree_map(
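
This hunk is the refactor named in the commit title. `expand().clone().zero_()` materializes the full-size buffer and then pays for an extra zero-fill pass; `torch.empty_like` on the expanded view allocates a buffer of the same shape, dtype, and device without initializing it, which is safe when the storage is written before it is read. A minimal sketch with made-up shapes:

import torch

data = torch.randn(4, 8)  # stands in for one stored element
max_size = 100            # stands in for max_size_along_dim0(data.shape)

# Old approach: materialize the expanded view, then zero it out.
out_old = data.expand(max_size, 4, 8).clone().zero_()

# New approach: allocate uninitialized memory of the same shape/dtype/device.
out_new = torch.empty_like(data.expand(max_size, 4, 8))

assert out_old.shape == out_new.shape == (max_size, 4, 8)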

1 comment on commit b4d543e

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'GPU Benchmark Results'.
The benchmark result of this commit is worse than the previous one by a factor exceeding the alert threshold of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400]
Current (b4d543e):  36.302446727976736 iter/sec (stddev: 0.15686418142162384)
Previous (8542d2e): 185.54331334899317 iter/sec (stddev: 0.0008683457050601657)
Ratio: 5.11
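
For reference, the ratio is previous over current throughput: 185.54 / 36.30 ≈ 5.11, well above the alert threshold of 2.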

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
