Skip to content

Commit

Permalink
Update test_command.py
Browse files Browse the repository at this point in the history
check with original test
  • Loading branch information
mag1cp1n authored Nov 16, 2023
1 parent 1d42ffc commit c55182f
Showing 1 changed file with 12 additions and 29 deletions.
41 changes: 12 additions & 29 deletions tests/unit/legate/driver/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,10 +1106,7 @@ def test_default_multi_rank_no_ucx(
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2")

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1124,10 +1121,8 @@ def test_default_multi_rank_and_ucx(
install_info.networks.append("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]
if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", value, "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", value)

assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1143,10 +1138,7 @@ def test_utility_1_multi_rank_no_launcher_no_ucx(
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2")

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1161,10 +1153,8 @@ def test_utility_1_multi_rank_no_launcher_and_ucx(
install_info.networks.append("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]
if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", "2")

assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
def test_utility_1_multi_rank_with_launcher_no_ucx(
Expand All @@ -1179,10 +1169,7 @@ def test_utility_1_multi_rank_with_launcher_no_ucx(
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2")

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
def test_utility_1_multi_rank_with_launcher_and_ucx(
Expand Down Expand Up @@ -1213,10 +1200,8 @@ def test_utility_n_multi_rank_no_launcher_no_ucx(
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]
if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", value, "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", value)

assert result == ("-ll:bgwork", value)

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand Down Expand Up @@ -1248,10 +1233,8 @@ def test_utility_n_multi_rank_with_launcher_no_ucx(
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]
if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", value, "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", value)

assert result == ("-ll:bgwork", value)

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
@pytest.mark.parametrize("value", ("2", "3", "10"))
Expand Down Expand Up @@ -1630,4 +1613,4 @@ def test_with_legate_opts(self, genobjs: GenObjs, opts: list[str]) -> None:


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
sys.exit(pytest.main(sys.argv))

0 comments on commit c55182f

Please sign in to comment.