Skip to content

Commit

Permalink
Fix comm object in serial utest
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Dec 22, 2024
1 parent 7ad271f commit 07cd0f3
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions tests/dsl/test_compilation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CubedSpherePartitioner,
RunMode,
TilePartitioner,
NullComm,
)


Expand All @@ -33,8 +34,7 @@ def test_check_communicator_valid(
partitioner = CubedSpherePartitioner(
TilePartitioner((int(sqrt(size / 6)), int((sqrt(size / 6)))))
)
comm = unittest.mock.MagicMock()
comm.Get_size.return_value = size
comm = NullComm(rank=0, total_ranks=size)
cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner)
config = CompilationConfig(
run_mode=run_mode, use_minimal_caching=use_minimal_caching
Expand All @@ -52,8 +52,7 @@ def test_check_communicator_invalid(
nx: int, ny: int, use_minimal_caching: bool, run_mode: RunMode
):
partitioner = CubedSpherePartitioner(TilePartitioner((nx, ny)))
comm = unittest.mock.MagicMock()
comm.Get_size.return_value = nx * ny * 6
comm = NullComm(rank=0, total_ranks=nx * ny * 6)
cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner)
config = CompilationConfig(
run_mode=run_mode, use_minimal_caching=use_minimal_caching
Expand Down Expand Up @@ -91,9 +90,7 @@ def test_get_decomposition_info_from_comm(
partitioner = CubedSpherePartitioner(
TilePartitioner((int(sqrt(size / 6)), int(sqrt(size / 6))))
)
comm = unittest.mock.MagicMock()
comm.Get_rank.return_value = rank
comm.Get_size.return_value = size
comm = NullComm(rank=rank, total_ranks=size)
cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner)
config = CompilationConfig(use_minimal_caching=True, run_mode=RunMode.Run)
(
Expand Down Expand Up @@ -133,8 +130,7 @@ def test_determine_compiling_equivalent(
TilePartitioner((sqrt(size / 6), sqrt(size / 6)))
)
comm = unittest.mock.MagicMock()
comm.Get_rank.return_value = rank
comm.Get_size.return_value = size
comm = NullComm(rank=rank, total_ranks=size)
cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner)
assert (
config.determine_compiling_equivalent(rank, cubed_sphere_comm.partitioner)
Expand Down

0 comments on commit 07cd0f3

Please sign in to comment.