From 2e669dbae2fccce6c65dac33db9acf6a5ec564ac Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Wed, 18 Dec 2024 11:36:07 -0800 Subject: [PATCH] Added setters for various Quantity properties to enable setting of Quantity metadata and data properties. --- ndsl/comm/communicator.py | 30 +++++++++++++----- ndsl/quantity.py | 30 ++++++++++++++++++ tests/mpi/test_mpi_all_reduce_sum.py | 47 ++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 8 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 5f19b2e..55efa96 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -109,14 +109,28 @@ def _create_all_reduce_quantity( ) return all_reduce_quantity - def all_reduce_sum(self, quantity: Quantity): - reduced_quantity_data = self.comm.allreduce(quantity.data, MPI.SUM) - all_reduce_quantity = self._create_all_reduce_quantity( - quantity.metadata, reduced_quantity_data - ) - return all_reduce_quantity - # quantity.data = reduced_quantity_data - # return quantity + def all_reduce_sum( + self, input_quantity: Quantity, output_quantity: Quantity = None + ): + reduced_quantity_data = self.comm.allreduce(input_quantity.data, MPI.SUM) + if output_quantity is None: + all_reduce_quantity = self._create_all_reduce_quantity( + input_quantity.metadata, reduced_quantity_data + ) + return all_reduce_quantity + else: + if output_quantity.data.shape != input_quantity.data.shape: + raise TypeError("Shapes not matching") + + output_quantity.metadata.dims = input_quantity.metadata.dims + output_quantity.metadata.units = input_quantity.metadata.units + output_quantity.metadata.origin = input_quantity.metadata.origin + output_quantity.metadata.extent = input_quantity.metadata.extent + output_quantity.metadata.gt4py_backend = ( + input_quantity.metadata.gt4py_backend + ) + + output_quantity.data = reduced_quantity_data def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( diff --git a/ndsl/quantity.py b/ndsl/quantity.py index b95a9aa..80bb4d0 100644 --- a/ndsl/quantity.py +++ b/ndsl/quantity.py @@ -458,10 +458,20 @@ def units(self) -> str: """units of the quantity""" return self.metadata.units + @units.setter + def units(self, newUnits): + if type(newUnits) is str: + self.metadata.units = newUnits + @property def gt4py_backend(self) -> Union[str, None]: return self.metadata.gt4py_backend + @gt4py_backend.setter + def gt4py_backend(self, newBackend): + if type(newBackend) is Union[str, None]: + self.metadata.gt4py_backend = newBackend + @property def attrs(self) -> dict: return dict(**self._attrs, units=self._metadata.units) @@ -471,6 +481,11 @@ def dims(self) -> Tuple[str, ...]: """names of each dimension""" return self.metadata.dims + @dims.setter + def dims(self, newDims): + if type(newDims) is Tuple: + self.metadata.dims = newDims + @property def values(self) -> np.ndarray: warnings.warn( @@ -492,16 +507,31 @@ def data(self) -> Union[np.ndarray, cupy.ndarray]: """the underlying array of data""" return self._data + @data.setter + def data(self, inputData): + if type(inputData) in [np.ndarray, cupy.ndarray]: + self._data = inputData + @property def origin(self) -> Tuple[int, ...]: """the start of the computational domain""" return self.metadata.origin + @origin.setter + def origin(self, newOrigin): + if type(newOrigin) is Tuple: + self.metadata.origin = newOrigin + @property def extent(self) -> Tuple[int, ...]: """the shape of the computational domain""" return self.metadata.extent + @extent.setter + def extent(self, newExtent): + if type(newExtent) is Tuple: + self.metadata.extent = newExtent + @property def data_array(self) -> xr.DataArray: return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 9ba01e0..9c2b3a3 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -95,3 +95,50 @@ def test_all_reduce_sum( global_sum_q = communicator.all_reduce_sum(testQuantity_3D) assert global_sum_q.metadata == testQuantity_3D.metadata assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() + + base_array = np.array([i for i in range(5)], dtype=Float) + testQuantity_1D_out = Quantity( + data=base_array, + dims=["K"], + units="New 1D unit", + gt4py_backend=backend, + origin=(8,), + extent=(7,), + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) + + testQuantity_2D_out = Quantity( + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) + + testQuantity_3D_out = Quantity( + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) + communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out) + assert testQuantity_1D_out.metadata == testQuantity_1D_out.metadata + assert ( + testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size) + ).all() + + communicator.all_reduce_sum(testQuantity_2D, testQuantity_2D_out) + assert testQuantity_2D_out.metadata == testQuantity_2D.metadata + assert ( + testQuantity_2D_out.data == (testQuantity_2D.data * communicator.size) + ).all() + + communicator.all_reduce_sum(testQuantity_3D, testQuantity_3D_out) + assert testQuantity_3D_out.metadata == testQuantity_3D.metadata + assert ( + testQuantity_3D_out.data == (testQuantity_3D.data * communicator.size) + ).all()