Skip to content

Commit

Permalink
control ucx presence in install_info more carefully (#882)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryevdv authored Nov 6, 2023
1 parent a4b5430 commit 8f6d882
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions tests/unit/legate/driver/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,10 +1046,13 @@ def test_default_single_rank(self, genobjs: GenObjs) -> None:

assert result == ()

def test_utility_1_single_rank(self, genobjs: GenObjs) -> None:
def test_utility_1_single_rank_no_ucx(self, genobjs: GenObjs) -> None:
config, system, launcher = genobjs(["--utility", "1"])

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ()

Expand All @@ -1064,12 +1067,15 @@ def test_utility_1_single_rank_and_ucx(self, genobjs: GenObjs) -> None:
assert result == ()

@pytest.mark.parametrize("value", ("2", "3", "10"))
def test_utiltity_n_single_rank(
def test_utiltity_n_single_rank_no_ucx(
self, genobjs: GenObjs, value: str
) -> None:
config, system, launcher = genobjs(["--utility", value])

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ()

Expand All @@ -1088,14 +1094,17 @@ def test_utiltity_n_single_rank_and_ucx(

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
def test_default_multi_rank(
def test_default_multi_rank_no_ucx(
self, genobjs: GenObjs, rank: str, rank_var: dict[str, str]
) -> None:
config, system, launcher = genobjs(
[], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2")

Expand All @@ -1117,14 +1126,17 @@ def test_default_multi_rank_and_ucx(

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
def test_utility_1_multi_rank_no_launcher(
def test_utility_1_multi_rank_no_launcher_no_ucx(
self, genobjs: GenObjs, rank: str, rank_var: dict[str, str]
) -> None:
config, system, launcher = genobjs(
["--utility", "1"], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2")

Expand All @@ -1145,14 +1157,17 @@ def test_utility_1_multi_rank_no_launcher_and_ucx(
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
def test_utility_1_multi_rank_with_launcher(
def test_utility_1_multi_rank_with_launcher_no_ucx(
self, genobjs: GenObjs, launch: str
) -> None:
config, system, launcher = genobjs(
["--utility", "1", "--launcher", launch], multi_rank=(2, 2)
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", "2")

Expand All @@ -1174,14 +1189,17 @@ def test_utility_1_multi_rank_with_launcher_and_ucx(
@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
@pytest.mark.parametrize("value", ("2", "3", "10"))
def test_utility_n_multi_rank_no_launcher(
def test_utility_n_multi_rank_no_launcher_no_ucx(
self, genobjs: GenObjs, value: str, rank: str, rank_var: dict[str, str]
) -> None:
config, system, launcher = genobjs(
["--utility", value], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", value)

Expand All @@ -1204,14 +1222,17 @@ def test_utility_n_multi_rank_no_launcher_and_ucx(

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
@pytest.mark.parametrize("value", ("2", "3", "10"))
def test_utility_n_multi_rank_with_launcher(
def test_utility_n_multi_rank_with_launcher_no_ucx(
self, genobjs: GenObjs, value: str, launch: str
) -> None:
config, system, launcher = genobjs(
["--utility", value, "--launcher", launch], multi_rank=(2, 2)
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ("-ll:bgwork", value)

Expand Down

0 comments on commit 8f6d882

Please sign in to comment.