From fb4e74010615f15f962a8a0572ed48f1267b5581 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Fri, 13 Dec 2024 09:48:09 -0800 Subject: [PATCH] Linted --- tests/mpi/test_mpi_all_reduce_sum.py | 64 ++++++++++++++-------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index f03787e..728ec4f 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -1,6 +1,5 @@ -import pytest import numpy as np -from tests.mpi.mpi_comm import MPI +import pytest from ndsl import ( CubedSphereCommunicator, @@ -8,11 +7,9 @@ Quantity, TilePartitioner, ) - -from ndsl.quantity import Quantity -from ndsl.comm.partitioner import Partitioner - from ndsl.dsl.typing import Float +from tests.mpi.mpi_comm import MPI + @pytest.fixture def layout(): @@ -24,18 +21,22 @@ def layout(): else: return (1, 1) + @pytest.fixture(params=[0.1, 1.0]) def edge_interior_ratio(request): return request.param + @pytest.fixture def tile_partitioner(layout, edge_interior_ratio: float): return TilePartitioner(layout, edge_interior_ratio=edge_interior_ratio) + @pytest.fixture def cube_partitioner(tile_partitioner): return CubedSpherePartitioner(tile_partitioner) + @pytest.fixture() def communicator(cube_partitioner): return CubedSphereCommunicator( @@ -43,43 +44,40 @@ def communicator(cube_partitioner): partitioner=cube_partitioner, ) + def test_all_reduce_sum( - communicator, + communicator, ): - + backend = "numpy" base_array = np.array([i for i in range(5)], dtype=Float) testQuantity_1D = Quantity( - data=base_array, - dims=["K"], - units="Some 1D unit", - gt4py_backend=backend, - ) - - base_array = np.array([i for i in range(5*5)], dtype=Float) - base_array = base_array.reshape(5,5) + data=base_array, + dims=["K"], + units="Some 1D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) testQuantity_2D = Quantity( - data=base_array, - dims=["I","J"], - units="Some 2D unit", - gt4py_backend=backend, - ) + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) - base_array = np.array([i for i in range(5*5*5)], dtype=Float) - base_array = base_array.reshape(5,5,5) + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) testQuantity_3D = Quantity( - data=base_array, - dims=["I","J","K"], - units="Some 3D unit", - gt4py_backend=backend, - ) - - # print("Communicator rank = ", communicator.rank) - # print("Communicator size = ", communicator.size) - # print("nsize = ", nsize) + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) global_sum_q = communicator.all_reduce_sum(testQuantity_1D) assert global_sum_q.metadata == testQuantity_1D.metadata @@ -91,4 +89,4 @@ def test_all_reduce_sum( global_sum_q = communicator.all_reduce_sum(testQuantity_3D) assert global_sum_q.metadata == testQuantity_3D.metadata - assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() \ No newline at end of file + assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all()