Skip to content

Commit

Permalink
add node_rank to torchrun cmd if rdzv_backend is 'static' (pytorch#761)
Browse files Browse the repository at this point in the history
* add node_rank to torchrun cmd if rdzv_backend is 'static'

* add unit test for static backend option
  • Loading branch information
MichaelClifford authored and KPostOffice committed Sep 7, 2023
1 parent b2e3075 commit 9882f12
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torchx/components/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def ddp(
"--role",
"",
]
# TODO 'node_rank' is made optional as it currently does not work with the AWS Batch scheduler.
# node_rank is only used when rdzv_backend is 'static'
if rdzv_backend == "static":
cmd += ["--node_rank", f"{macros.replica_id}"]
if script is not None:
cmd += [script]
elif m is not None:
Expand Down
6 changes: 6 additions & 0 deletions torchx/components/test/dist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def test_ddp_debug(self) -> None:
for k, v in _TORCH_DEBUG_FLAGS.items():
self.assertEqual(env[k], v)

def test_ddp_rdzv_backend_static(self) -> None:
    """Check that ``ddp`` emits ``--node_rank`` when ``rdzv_backend`` is ``"static"``."""
    app = ddp(script="foo.py", rdzv_backend="static")
    # The second role argument is searched for the launcher flags below
    # (presumably it is the joined command string — confirm against ``ddp``).
    cmd = app.roles[0].args[1]
    # assertIn reports both the needle and the haystack on failure, unlike
    # assertTrue(... in ...), which only reports "False is not true".
    self.assertIn("--rdzv_backend static", cmd)
    self.assertIn("--node_rank", cmd)


class SpmdTest(ComponentTestCase):
def test_validate_spmd(self) -> None:
Expand Down

0 comments on commit 9882f12

Please sign in to comment.