[Performance] consolidate TDs in ParallelEnv without buffers (#2231)
vmoens authored Jun 28, 2024
1 parent eb6c85d commit 443620f
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions torchrl/envs/batched_envs.py
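The change below replaces per-worker unbind-and-send with TensorDict.consolidate, which packs a tensordict into a single (optionally shared-memory) storage before it crosses process boundaries. A minimal sketch of the call, assuming a tensordict release that provides this API (the same call the diff relies on):

import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(4, 3), ("next", "reward"): torch.zeros(4, 1)},
    batch_size=[4],
)
# consolidate() copies every leaf into one contiguous storage; with
# share_memory=True that storage lives in shared memory, so pickling the
# tensordict for another process ships one buffer plus metadata rather
# than one tensor per key.
ctd = td.consolidate(share_memory=True, num_threads=1)
assert (ctd["next", "reward"] == td["next", "reward"]).all()  # contents unchanged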
@@ -1434,8 +1434,13 @@ def _step_and_maybe_reset_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
 
-        for i, _data in enumerate(tensordict.unbind(0)):
-            self.parent_channels[i].send(("step_and_maybe_reset", _data))
+        td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        for i in range(td.shape[0]):
+            # We send the same td multiple times as it is in shared mem and we just need to index it
+            # in each process.
+            # If we don't do this, we need to unbind it but then the custom pickler will require
+            # some extra metadata to be collected.
+            self.parent_channels[i].send(("step_and_maybe_reset", (td, i)))
 
         results = [None] * self.num_workers

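Note on the dispatch above: every worker receives the same consolidated, shared-memory tensordict together with its row index, rather than a pre-sliced copy. A hedged toy illustration of that pattern (the worker function and pipe setup are hypothetical stand-ins, not TorchRL's):

import torch
import torch.multiprocessing as mp
from tensordict import TensorDict


def _toy_worker(pipe):
    # Receive the shared tensordict plus this worker's index, then slice locally.
    cmd, (td, idx) = pipe.recv()
    pipe.send(td[idx]["action"].sum().item())


if __name__ == "__main__":
    num_workers = 2
    td = TensorDict({"action": torch.ones(num_workers, 4)}, [num_workers])
    td = td.consolidate(share_memory=True, inplace=True, num_threads=1)
    pipes, procs = [], []
    for i in range(num_workers):
        parent, child = mp.Pipe()
        proc = mp.Process(target=_toy_worker, args=(child,))
        proc.start()
        parent.send(("step", (td, i)))  # same td object, different index
        pipes.append(parent)
        procs.append(proc)
    print([p.recv() for p in pipes])  # [4.0, 4.0]
    for p in procs:
        p.join()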
@@ -1556,8 +1561,11 @@ def step_and_maybe_reset(
     def _step_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
-        for i, data in enumerate(tensordict.unbind(0)):
-            self.parent_channels[i].send(("step", data))
+        data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        for i, local_data in enumerate(data.unbind(0)):
+            self.parent_channels[i].send(("step", local_data))
+        # for i in range(data.shape[0]):
+        #     self.parent_channels[i].send(("step", (data, i)))
         out_tds = []
         for i, channel in enumerate(self.parent_channels):
             self._events[i].wait()
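In _step_no_buffers the consolidated tensordict is unbound and one slice is sent per channel; since the slices are views into the single shared-memory storage created by consolidate, each send stays lightweight. A small sketch under the same API assumption (the channel send is replaced by a comment):

import torch
from tensordict import TensorDict

data = TensorDict({"obs": torch.randn(3, 5)}, batch_size=[3])
data = data.consolidate(share_memory=True, inplace=True, num_threads=1)
for i, local_data in enumerate(data.unbind(0)):
    # in the real method: self.parent_channels[i].send(("step", local_data))
    assert local_data["obs"].shape == (5,)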
@@ -1663,17 +1671,24 @@ def _reset_no_buffers(
         reset_kwargs_list,
         needs_resetting,
     ) -> Tuple[TensorDictBase, TensorDictBase]:
-        tdunbound = (
-            tensordict.unbind(0)
-            if is_tensor_collection(tensordict)
-            else [None] * self.num_workers
-        )
+        if is_tensor_collection(tensordict):
+            # tensordict = tensordict.consolidate(share_memory=True, num_threads=1)
+            tensordict = tensordict.consolidate(
+                share_memory=True, num_threads=1
+            ).unbind(0)
+        else:
+            tensordict = [None] * self.num_workers
         out_tds = [None] * self.num_workers
-        for i, (data, reset_kwargs) in enumerate(zip(tdunbound, reset_kwargs_list)):
+        for i, (local_data, reset_kwargs) in enumerate(
+            zip(tensordict, reset_kwargs_list)
+        ):
             if not needs_resetting[i]:
-                out_tds[i] = tdunbound[i].exclude(*self.reset_keys)
+                localtd = local_data
+                if localtd is not None:
+                    localtd = localtd.exclude(*self.reset_keys)
+                out_tds[i] = localtd
                 continue
-            self.parent_channels[i].send(("reset", (data, reset_kwargs)))
+            self.parent_channels[i].send(("reset", (local_data, reset_kwargs)))
 
         for i, channel in enumerate(self.parent_channels):
             if not needs_resetting[i]:
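The reset path now consolidates only when a tensordict was actually passed; a worker that does not need resetting gets its slice back with the reset keys stripped, and a None placeholder stays None. A toy walk-through of that branch (no processes involved; needs_resetting and reset_keys are stand-ins for the attributes used above):

import torch
from tensordict import TensorDict, is_tensor_collection

num_workers = 3
reset_keys = ["_reset"]
needs_resetting = [True, False, True]
tensordict = TensorDict(
    {"_reset": torch.ones(num_workers, 1, dtype=torch.bool)}, [num_workers]
)

if is_tensor_collection(tensordict):
    tensordict = tensordict.consolidate(share_memory=True, num_threads=1).unbind(0)
else:
    tensordict = [None] * num_workers

out_tds = [None] * num_workers
for i, local_data in enumerate(tensordict):
    if not needs_resetting[i]:
        # keep this worker's content, minus the reset signals
        out_tds[i] = local_data.exclude(*reset_keys) if local_data is not None else None
        continue
    # in the real method: self.parent_channels[i].send(("reset", (local_data, reset_kwargs)))
print([None if td is None else list(td.keys()) for td in out_tds])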
@@ -1995,10 +2010,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
             i = 0
             next_shared_tensordict = shared_tensordict.get("next")
             root_shared_tensordict = shared_tensordict.exclude("next")
-            if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
-                raise RuntimeError(
-                    "tensordict must be placed in shared memory (share_memory_() or memmap_())"
-                )
+            # TODO: restore this
+            # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
+            #     raise RuntimeError(
+            #         "tensordict must be placed in shared memory (share_memory_() or memmap_())"
+            #     )
             shared_tensordict = shared_tensordict.clone(False).unlock_()
 
             initialized = True
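The guard commented out here used to require the worker's buffer to be either in shared memory or memory-mapped. As a reminder of what those two flags check (standard tensordict API, CPU tensors assumed):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(2)}, batch_size=[2])
assert not td.is_shared() and not td.is_memmap()
td.share_memory_()  # moves every leaf to shared memory in place
assert td.is_shared()

mm = TensorDict({"a": torch.zeros(2)}, batch_size=[2]).memmap_()  # disk-backed storage
assert mm.is_memmap()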
@@ -2243,6 +2259,8 @@ def _run_worker_pipe_direct(
                 raise RuntimeError("call 'init' before resetting")
             # we use 'data' to pass the keys that we need to pass to reset,
             # because passing the entire buffer may have unwanted consequences
+            # data, idx, reset_kwargs = data
+            # data = data[idx]
             data, reset_kwargs = data
             if data is not None:
                 data._fast_apply(
@@ -2256,25 +2274,33 @@
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(cur_td)
+            child_pipe.send(
+                cur_td.consolidate(share_memory=True, inplace=True, num_threads=1)
+            )
             del cur_td
 
         elif cmd == "step":
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
+            # data, idx = data
+            # data = data[idx]
             next_td = env._step(data)
             if event is not None:
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(next_td)
+            child_pipe.send(
+                next_td.consolidate(share_memory=True, inplace=True, num_threads=1)
+            )
             del next_td
 
         elif cmd == "step_and_maybe_reset":
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
+            # data, idx = data
+            # data = data[idx]
             data._fast_apply(
                 lambda x: x.clone() if x.device.type == "cuda" else x, out=data
             )
