diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index e5aaa873870..500f457ad20 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1434,8 +1434,13 @@ def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, _data in enumerate(tensordict.unbind(0)): - self.parent_channels[i].send(("step_and_maybe_reset", _data)) + td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) + for i in range(td.shape[0]): + # We send the same td multiple times as it is in shared mem and we just need to index it + # in each process. + # If we don't do this, we need to unbind it but then the custom pickler will require + # some extra metadata to be collected. + self.parent_channels[i].send(("step_and_maybe_reset", (td, i))) results = [None] * self.num_workers @@ -1556,8 +1561,11 @@ def step_and_maybe_reset( def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - for i, data in enumerate(tensordict.unbind(0)): - self.parent_channels[i].send(("step", data)) + data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) + for i, local_data in enumerate(data.unbind(0)): + self.parent_channels[i].send(("step", local_data)) + # for i in range(data.shape[0]): + # self.parent_channels[i].send(("step", (data, i))) out_tds = [] for i, channel in enumerate(self.parent_channels): self._events[i].wait() @@ -1663,17 +1671,24 @@ def _reset_no_buffers( reset_kwargs_list, needs_resetting, ) -> Tuple[TensorDictBase, TensorDictBase]: - tdunbound = ( - tensordict.unbind(0) - if is_tensor_collection(tensordict) - else [None] * self.num_workers - ) + if is_tensor_collection(tensordict): + # tensordict = tensordict.consolidate(share_memory=True, num_threads=1) + tensordict = tensordict.consolidate( + share_memory=True, num_threads=1 + ).unbind(0) + else: + tensordict = [None] * self.num_workers out_tds = [None] * self.num_workers - for i, (data, reset_kwargs) in enumerate(zip(tdunbound, reset_kwargs_list)): + for i, (local_data, reset_kwargs) in enumerate( + zip(tensordict, reset_kwargs_list) + ): if not needs_resetting[i]: - out_tds[i] = tdunbound[i].exclude(*self.reset_keys) + localtd = local_data + if localtd is not None: + localtd = localtd.exclude(*self.reset_keys) + out_tds[i] = localtd continue - self.parent_channels[i].send(("reset", (data, reset_kwargs))) + self.parent_channels[i].send(("reset", (local_data, reset_kwargs))) for i, channel in enumerate(self.parent_channels): if not needs_resetting[i]: @@ -1995,10 +2010,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda): i = 0 next_shared_tensordict = shared_tensordict.get("next") root_shared_tensordict = shared_tensordict.exclude("next") - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): - raise RuntimeError( - "tensordict must be placed in shared memory (share_memory_() or memmap_())" - ) + # TODO: restore this + # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): + # raise RuntimeError( + # "tensordict must be placed in shared memory (share_memory_() or memmap_())" + # ) shared_tensordict = shared_tensordict.clone(False).unlock_() initialized = True @@ -2243,6 +2259,8 @@ def _run_worker_pipe_direct( raise RuntimeError("call 'init' before resetting") # we use 'data' to pass the keys that we need to pass to reset, # because passing the entire buffer may have unwanted consequences + # data, idx, reset_kwargs = data + # data = data[idx] data, reset_kwargs = data if data is not None: data._fast_apply( @@ -2256,25 +2274,33 @@ def _run_worker_pipe_direct( event.record() event.synchronize() mp_event.set() - child_pipe.send(cur_td) + child_pipe.send( + cur_td.consolidate(share_memory=True, inplace=True, num_threads=1) + ) del cur_td elif cmd == "step": if not initialized: raise RuntimeError("called 'init' before step") i += 1 + # data, idx = data + # data = data[idx] next_td = env._step(data) if event is not None: event.record() event.synchronize() mp_event.set() - child_pipe.send(next_td) + child_pipe.send( + next_td.consolidate(share_memory=True, inplace=True, num_threads=1) + ) del next_td elif cmd == "step_and_maybe_reset": if not initialized: raise RuntimeError("called 'init' before step") i += 1 + # data, idx = data + # data = data[idx] data._fast_apply( lambda x: x.clone() if x.device.type == "cuda" else x, out=data )