Skip to content

Commit

Permalink
Linted
Browse files Browse the repository at this point in the history
  • Loading branch information
gmao-ckung committed Dec 13, 2024
1 parent 8c5b5d5 commit fb4e740
Showing 1 changed file with 31 additions and 33 deletions.
64 changes: 31 additions & 33 deletions tests/mpi/test_mpi_all_reduce_sum.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import pytest
import numpy as np
from tests.mpi.mpi_comm import MPI
import pytest

from ndsl import (
CubedSphereCommunicator,
CubedSpherePartitioner,
Quantity,
TilePartitioner,
)

from ndsl.quantity import Quantity
from ndsl.comm.partitioner import Partitioner

from ndsl.dsl.typing import Float
from tests.mpi.mpi_comm import MPI


@pytest.fixture
def layout():
Expand All @@ -24,62 +21,63 @@ def layout():
else:
return (1, 1)


@pytest.fixture(params=[0.1, 1.0])
def edge_interior_ratio(request):
return request.param


@pytest.fixture
def tile_partitioner(layout, edge_interior_ratio: float):
return TilePartitioner(layout, edge_interior_ratio=edge_interior_ratio)


@pytest.fixture
def cube_partitioner(tile_partitioner):
return CubedSpherePartitioner(tile_partitioner)


@pytest.fixture()
def communicator(cube_partitioner):
return CubedSphereCommunicator(
comm=MPI.COMM_WORLD,
partitioner=cube_partitioner,
)


def test_all_reduce_sum(
communicator,
communicator,
):

backend = "numpy"
base_array = np.array([i for i in range(5)], dtype=Float)

testQuantity_1D = Quantity(
data=base_array,
dims=["K"],
units="Some 1D unit",
gt4py_backend=backend,
)
base_array = np.array([i for i in range(5*5)], dtype=Float)
base_array = base_array.reshape(5,5)
data=base_array,
dims=["K"],
units="Some 1D unit",
gt4py_backend=backend,
)

base_array = np.array([i for i in range(5 * 5)], dtype=Float)
base_array = base_array.reshape(5, 5)

testQuantity_2D = Quantity(
data=base_array,
dims=["I","J"],
units="Some 2D unit",
gt4py_backend=backend,
)
data=base_array,
dims=["I", "J"],
units="Some 2D unit",
gt4py_backend=backend,
)

base_array = np.array([i for i in range(5*5*5)], dtype=Float)
base_array = base_array.reshape(5,5,5)
base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float)
base_array = base_array.reshape(5, 5, 5)

testQuantity_3D = Quantity(
data=base_array,
dims=["I","J","K"],
units="Some 3D unit",
gt4py_backend=backend,
)

# print("Communicator rank = ", communicator.rank)
# print("Communicator size = ", communicator.size)
# print("nsize = ", nsize)
data=base_array,
dims=["I", "J", "K"],
units="Some 3D unit",
gt4py_backend=backend,
)

global_sum_q = communicator.all_reduce_sum(testQuantity_1D)
assert global_sum_q.metadata == testQuantity_1D.metadata
Expand All @@ -91,4 +89,4 @@ def test_all_reduce_sum(

global_sum_q = communicator.all_reduce_sum(testQuantity_3D)
assert global_sum_q.metadata == testQuantity_3D.metadata
assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all()
assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all()

0 comments on commit fb4e740

Please sign in to comment.