Skip to content

Commit

Permalink
install_info.networks.remove("ucx") --2
Browse files Browse the repository at this point in the history
  • Loading branch information
mag1cp1n authored Oct 18, 2023
1 parent 5368dff commit e187250
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions tests/unit/legate/driver/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,9 +1095,12 @@ def test_default_multi_rank(
[], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks.remove("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2") #PRGX

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1111,7 +1114,6 @@ def test_default_multi_rank_and_ucx(
networks_orig = list(install_info.networks)
install_info.networks.append("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks.remove("ucx")
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
Expand All @@ -1125,9 +1127,12 @@ def test_utility_1_multi_rank_no_launcher(
["--utility", "1"], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks.remove("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2") #PRGX

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1141,7 +1146,6 @@ def test_utility_1_multi_rank_no_launcher_and_ucx(
networks_orig = list(install_info.networks)
install_info.networks.append("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks.remove("ucx")
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
Expand All @@ -1154,9 +1158,12 @@ def test_utility_1_multi_rank_with_launcher(
["--utility", "1", "--launcher", launch], multi_rank=(2, 2)
)

networks_orig = list(install_info.networks)
install_info.networks.remove("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2") #PRGX

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
def test_utility_1_multi_rank_with_launcher_and_ucx(
Expand All @@ -1169,7 +1176,6 @@ def test_utility_1_multi_rank_with_launcher_and_ucx(
networks_orig = list(install_info.networks)
install_info.networks.append("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks.remove("ucx")
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
Expand All @@ -1184,9 +1190,12 @@ def test_utility_n_multi_rank_no_launcher(
["--utility", value], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks.remove("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", value)
assert result == ("-ll:bgwork", value) #PRGX

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1201,7 +1210,6 @@ def test_utility_n_multi_rank_no_launcher_and_ucx(
networks_orig = list(install_info.networks)
install_info.networks.append("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks.remove("ucx")
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", value, "-ll:bgworkpin", "1")
Expand All @@ -1215,9 +1223,12 @@ def test_utility_n_multi_rank_with_launcher(
["--utility", value, "--launcher", launch], multi_rank=(2, 2)
)

networks_orig = list(install_info.networks)
install_info.networks.remove("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", value)
assert result == ("-ll:bgwork", value) #PRGX

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
@pytest.mark.parametrize("value", ("2", "3", "10"))
Expand All @@ -1231,7 +1242,6 @@ def test_utility_n_multi_rank_with_launcher_and_ucx(
networks_orig = list(install_info.networks)
install_info.networks.append("ucx")
result = m.cmd_bgwork(config, system, launcher)
install_info.networks.remove("ucx")
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", value, "-ll:bgworkpin", "1")
Expand Down Expand Up @@ -1597,4 +1607,4 @@ def test_with_legate_opts(self, genobjs: GenObjs, opts: list[str]) -> None:


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
sys.exit(pytest.main(sys.argv))

0 comments on commit e187250

Please sign in to comment.