Skip to content

Commit

Permalink
[BugFix] Fix mp_start_method for ParallelEnv with single_for_serial (#…
Browse files Browse the repository at this point in the history
…2007)

(cherry picked from commit 2b8450c)
  • Loading branch information
vmoens committed Apr 7, 2024
1 parent a4795aa commit 11512c9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 14 additions & 5 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,13 +446,22 @@ def test_parallel_devices(
env.shared_tensordict_parent.device.type == torch.device(edevice).type
)

def test_serial_for_single(self, maybe_fork_ParallelEnv):
env = ParallelEnv(1, ContinuousActionVecMockEnv, serial_for_single=True)
@pytest.mark.parametrize("start_method", [None, "fork"])
def test_serial_for_single(self, maybe_fork_ParallelEnv, start_method):
env = ParallelEnv(
1,
ContinuousActionVecMockEnv,
serial_for_single=True,
mp_start_method=start_method,
)
assert isinstance(env, SerialEnv)
env = maybe_fork_ParallelEnv(1, ContinuousActionVecMockEnv)
env = ParallelEnv(1, ContinuousActionVecMockEnv, mp_start_method=start_method)
assert isinstance(env, ParallelEnv)
env = maybe_fork_ParallelEnv(
2, ContinuousActionVecMockEnv, serial_for_single=True
env = ParallelEnv(
2,
ContinuousActionVecMockEnv,
serial_for_single=True,
mp_start_method=start_method,
)
assert isinstance(env, ParallelEnv)

Expand Down
2 changes: 2 additions & 0 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def __call__(cls, *args, **kwargs):
serial_for_single = kwargs.pop("serial_for_single", False)
if serial_for_single:
num_workers = kwargs.get("num_workers", None)
# Remove start method from kwargs
kwargs.pop("mp_start_method", None)
if num_workers is None:
num_workers = args[0]
if num_workers == 1:
Expand Down

0 comments on commit 11512c9

Please sign in to comment.