diff --git a/torchx/components/test/dist_test.py b/torchx/components/test/dist_test.py index 90574f6de..22f108b85 100644 --- a/torchx/components/test/dist_test.py +++ b/torchx/components/test/dist_test.py @@ -38,6 +38,13 @@ def test_ddp_debug(self) -> None: for k, v in _TORCH_DEBUG_FLAGS.items(): self.assertEqual(env[k], v) + def test_ddp_rdzv_backend_static(self) -> None: + app = ddp(script="foo.py", rdzv_backend="static") + cmd = app.roles[0].args[1] + self.assertTrue("--rdzv_backend static" in cmd) + self.assertTrue("--node_rank" in cmd) + + class SpmdTest(ComponentTestCase): def test_validate_spmd(self) -> None: