[Refactor] Use empty_like in storage construction
ghstack-source-id: 28cd569bd4abf472991b82b3eba9fe333b5cd68f
Pull Request resolved: #2455
vmoens committed Sep 26, 2024
1 parent ca3a595 commit 9409cc8
Showing 2 changed files with 50 additions and 49 deletions.
95 changes: 49 additions & 46 deletions torchrl/data/replay_buffers/samplers.py
@@ -982,22 +982,20 @@ def _find_start_stop_traj(

             # faster
             end = trajectory[:-1] != trajectory[1:]
-            end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
+            if not at_capacity:
+                end = torch.cat([end, torch.ones_like(end[:1])], 0)
+            else:
+                end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
             length = trajectory.shape[0]
         else:
-            # TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True
-
             # We presume that not done at the end means that the traj spans across end and beginning of storage
             length = end.shape[0]
+            if not at_capacity:
+                end = end.clone()
+                end[length - 1] = True
+        ndim = end.ndim
 
-        if not at_capacity:
-            end = torch.index_fill(
-                end,
-                index=torch.tensor(-1, device=end.device, dtype=torch.long),
-                dim=0,
-                value=1,
-            )
-        else:
+        if at_capacity:
             # we must have at least one end by traj to individuate trajectories
             # so if no end can be found we set it manually
             if cursor is not None:
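This hunk replaces the unconditional `torch.index_fill` with branch-aware concatenation: when the buffer is not full, the last written element always closes a trajectory; when it is full (circular buffer), the last slot only closes a trajectory if its ID differs from the first slot. Below is a minimal, self-contained sketch of that detection logic, assuming a 1D tensor of trajectory IDs; the `find_ends` name and the driver values are illustrative, not torchrl API.

```python
import torch

def find_ends(trajectory: torch.Tensor, at_capacity: bool) -> torch.Tensor:
    # An "end" is any index whose trajectory ID differs from the next one.
    end = trajectory[:-1] != trajectory[1:]
    if not at_capacity:
        # Storage not full: the last written element always closes a trajectory.
        end = torch.cat([end, torch.ones_like(end[:1])], 0)
    else:
        # Storage full (circular buffer): the last slot closes a trajectory only
        # if its ID differs from the first slot, i.e. the traj does not wrap around.
        end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
    return end

traj = torch.tensor([0, 0, 1, 1, 1, 2])
print(find_ends(traj, at_capacity=False))
# tensor([False,  True, False, False,  True,  True])
```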
@@ -1019,7 +1017,6 @@ def _find_start_stop_traj(
                 mask = ~end.any(0, True)
                 mask = torch.cat([torch.zeros_like(end[:-1]), mask])
                 end = torch.masked_fill(mask, end, 1)
-        ndim = end.ndim
         if ndim == 0:
             raise RuntimeError(
                 "Expected the end-of-trajectory signal to be at least 1-dimensional."
@@ -1109,57 +1106,63 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length):
             result[:, 0] = result[:, 0] % storage_length
         return result
 
+    @torch.no_grad()
     def _get_stop_and_length(self, storage, fallback=True):
         if self.cache_values and "stop-and-length" in self._cache:
             return self._cache.get("stop-and-length")
 
+        current_storage = storage[:]
         if self._fetch_traj:
             # We first try with the traj_key
 
+            if isinstance(storage, TensorStorage):
+                key = self._used_traj_key
+            else:
+                key = self.traj_key
             try:
-                if isinstance(storage, TensorStorage):
-                    trajectory = storage[:][self._used_traj_key]
-                else:
-                    try:
-                        trajectory = storage[:][self.traj_key]
-                    except Exception:
-                        raise RuntimeError(
-                            "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
-                        )
-                vals = self._find_start_stop_traj(
-                    trajectory=trajectory,
-                    at_capacity=storage._is_full,
-                    cursor=getattr(storage, "_last_cursor", None),
-                )
-                if self.cache_values:
-                    self._cache["stop-and-length"] = vals
-                return vals
-            except KeyError:
+                trajectory = current_storage.get(key, default=None)
+            except Exception:
+                # eg, ListStorage
+                raise RuntimeError(
+                    "Could not get a tensordict out of the storage, "
+                    "which is required for SliceSampler to compute the trajectories."
+                )
+            if trajectory is None:
                 if fallback:
                     self._fetch_traj = False
                     return self._get_stop_and_length(storage, fallback=False)
-                raise
+                raise KeyError(f"Couldn't find key={key} in storage.")
+            vals = self._find_start_stop_traj(
+                trajectory=trajectory,
+                at_capacity=storage._is_full,
+                cursor=getattr(storage, "_last_cursor", None),
+            )
+            if self.cache_values:
+                self._cache["stop-and-length"] = vals
+            return vals
         else:
             try:
-                try:
-                    done = storage[:][self.end_key]
-                except Exception:
-                    raise RuntimeError(
-                        "Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
-                    )
-                vals = self._find_start_stop_traj(
-                    end=done.squeeze()[: len(storage)],
-                    at_capacity=storage._is_full,
-                    cursor=getattr(storage, "_last_cursor", None),
-                )
-                if self.cache_values:
-                    self._cache["stop-and-length"] = vals
-                return vals
-            except KeyError:
+                done = current_storage.get(self.end_key, None)
+            except Exception:
+                # eg, ListStorage
+                raise RuntimeError(
+                    "Could not get a tensordict out of the storage, "
+                    "which is required for SliceSampler to compute the trajectories."
+                )
+            if done is None:
                 if fallback:
                     self._fetch_traj = True
                     return self._get_stop_and_length(storage, fallback=False)
-                raise
+                raise KeyError(f"Couldn't find key={self.end_key} in storage.")
+
+            vals = self._find_start_stop_traj(
+                end=done.squeeze(),
+                at_capacity=storage._is_full,
+                cursor=getattr(storage, "_last_cursor", None),
+            )
+            if self.cache_values:
+                self._cache["stop-and-length"] = vals
+            return vals
 
     def _adjusted_batch_size(self, batch_size):
         if self.num_slices is not None:
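The main move in this hunk is to materialize the storage once (`current_storage = storage[:]`) and probe it with `get(..., default=None)` instead of indexing and catching `KeyError`; the method is also wrapped in `@torch.no_grad()`, since this bookkeeping never needs autograd. A minimal sketch of that lookup pattern on a plain TensorDict (keys and values here are illustrative, not the sampler's actual data):

```python
import torch
from tensordict import TensorDict

# Probe keys with get(..., default=None) rather than catching KeyError.
data = TensorDict({"traj": torch.tensor([0, 0, 1, 1])}, batch_size=[4])

trajectory = data.get("traj", default=None)       # tensor([0, 0, 1, 1])
done = data.get(("next", "done"), default=None)   # None: key absent, no exception
if done is None:
    # This is where _get_stop_and_length falls back to the other signal
    # (trajectory key <-> end key) before giving up with a KeyError.
    ...
```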
4 changes: 1 addition & 3 deletions torchrl/data/replay_buffers/storages.py
@@ -901,9 +901,7 @@ def max_size_along_dim0(data_shape):

         if is_tensor_collection(data):
             out = data.to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
-            out = out.clone()
-            out = out.zero_()
+            out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
         else:
             # if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
             out = tree_map(
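This is the change the commit title refers to: `expand` → `clone` → `zero_` collapses into a single uninitialized allocation. A short sketch of why `torch.empty_like` on an expanded view is sufficient, using a plain tensor (shapes are illustrative); note that unlike the old `zero_()` path the buffer starts uninitialized, on the assumption that every slot is written before it is read:

```python
import torch

data = torch.randn(3, 4)   # one sample
max_size = 100             # stand-in for max_size_along_dim0(...)

# expand creates a zero-stride view (no copy); empty_like then allocates
# fresh, contiguous, uninitialized memory of the expanded shape, skipping
# the extra clone + zero-fill pass of the old code.
view = data.expand(max_size, 3, 4)
out = torch.empty_like(view)

print(view.stride())   # (0, 4, 1)  -> broadcasted view
print(out.stride())    # (12, 4, 1) -> contiguous buffer
print(out.shape)       # torch.Size([100, 3, 4])
```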
