Skip to content

Commit

Permalink
Added setters for various Quantity properties to enable setting of Qu…
Browse files Browse the repository at this point in the history
…antity metadata and data properties.
  • Loading branch information
gmao-ckung committed Dec 18, 2024
1 parent 2e41349 commit 2e669db
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 8 deletions.
30 changes: 22 additions & 8 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,28 @@ def _create_all_reduce_quantity(
)
return all_reduce_quantity

def all_reduce_sum(self, quantity: Quantity):
reduced_quantity_data = self.comm.allreduce(quantity.data, MPI.SUM)
all_reduce_quantity = self._create_all_reduce_quantity(
quantity.metadata, reduced_quantity_data
)
return all_reduce_quantity
# quantity.data = reduced_quantity_data
# return quantity
def all_reduce_sum(
self, input_quantity: Quantity, output_quantity: Quantity = None
):
reduced_quantity_data = self.comm.allreduce(input_quantity.data, MPI.SUM)
if output_quantity is None:
all_reduce_quantity = self._create_all_reduce_quantity(
input_quantity.metadata, reduced_quantity_data
)
return all_reduce_quantity
else:
if output_quantity.data.shape != input_quantity.data.shape:
raise TypeError("Shapes not matching")

output_quantity.metadata.dims = input_quantity.metadata.dims
output_quantity.metadata.units = input_quantity.metadata.units
output_quantity.metadata.origin = input_quantity.metadata.origin
output_quantity.metadata.extent = input_quantity.metadata.extent
output_quantity.metadata.gt4py_backend = (
input_quantity.metadata.gt4py_backend
)

output_quantity.data = reduced_quantity_data

def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs):
with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer(
Expand Down
30 changes: 30 additions & 0 deletions ndsl/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,20 @@ def units(self) -> str:
"""units of the quantity"""
return self.metadata.units

@units.setter
def units(self, newUnits):
if type(newUnits) is str:
self.metadata.units = newUnits

@property
def gt4py_backend(self) -> Union[str, None]:
return self.metadata.gt4py_backend

@gt4py_backend.setter
def gt4py_backend(self, newBackend):
if type(newBackend) is Union[str, None]:
self.metadata.gt4py_backend = newBackend

@property
def attrs(self) -> dict:
return dict(**self._attrs, units=self._metadata.units)
Expand All @@ -471,6 +481,11 @@ def dims(self) -> Tuple[str, ...]:
"""names of each dimension"""
return self.metadata.dims

@dims.setter
def dims(self, newDims):
if type(newDims) is Tuple:
self.metadata.dims = newDims

@property
def values(self) -> np.ndarray:
warnings.warn(
Expand All @@ -492,16 +507,31 @@ def data(self) -> Union[np.ndarray, cupy.ndarray]:
"""the underlying array of data"""
return self._data

@data.setter
def data(self, inputData):
if type(inputData) in [np.ndarray, cupy.ndarray]:
self._data = inputData

@property
def origin(self) -> Tuple[int, ...]:
"""the start of the computational domain"""
return self.metadata.origin

@origin.setter
def origin(self, newOrigin):
if type(newOrigin) is Tuple:
self.metadata.origin = newOrigin

@property
def extent(self) -> Tuple[int, ...]:
"""the shape of the computational domain"""
return self.metadata.extent

@extent.setter
def extent(self, newExtent):
if type(newExtent) is Tuple:
self.metadata.extent = newExtent

@property
def data_array(self) -> xr.DataArray:
return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs)
Expand Down
47 changes: 47 additions & 0 deletions tests/mpi/test_mpi_all_reduce_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,50 @@ def test_all_reduce_sum(
global_sum_q = communicator.all_reduce_sum(testQuantity_3D)
assert global_sum_q.metadata == testQuantity_3D.metadata
assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all()

base_array = np.array([i for i in range(5)], dtype=Float)
testQuantity_1D_out = Quantity(
data=base_array,
dims=["K"],
units="New 1D unit",
gt4py_backend=backend,
origin=(8,),
extent=(7,),
)

base_array = np.array([i for i in range(5 * 5)], dtype=Float)
base_array = base_array.reshape(5, 5)

testQuantity_2D_out = Quantity(
data=base_array,
dims=["I", "J"],
units="Some 2D unit",
gt4py_backend=backend,
)

base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float)
base_array = base_array.reshape(5, 5, 5)

testQuantity_3D_out = Quantity(
data=base_array,
dims=["I", "J", "K"],
units="Some 3D unit",
gt4py_backend=backend,
)
communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out)
assert testQuantity_1D_out.metadata == testQuantity_1D_out.metadata
assert (
testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size)
).all()

communicator.all_reduce_sum(testQuantity_2D, testQuantity_2D_out)
assert testQuantity_2D_out.metadata == testQuantity_2D.metadata
assert (
testQuantity_2D_out.data == (testQuantity_2D.data * communicator.size)
).all()

communicator.all_reduce_sum(testQuantity_3D, testQuantity_3D_out)
assert testQuantity_3D_out.metadata == testQuantity_3D.metadata
assert (
testQuantity_3D_out.data == (testQuantity_3D.data * communicator.size)
).all()

0 comments on commit 2e669db

Please sign in to comment.