diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 55efa96..ead2489 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -122,13 +122,7 @@ def all_reduce_sum( if output_quantity.data.shape != input_quantity.data.shape: raise TypeError("Shapes not matching") - output_quantity.metadata.dims = input_quantity.metadata.dims - output_quantity.metadata.units = input_quantity.metadata.units - output_quantity.metadata.origin = input_quantity.metadata.origin - output_quantity.metadata.extent = input_quantity.metadata.extent - output_quantity.metadata.gt4py_backend = ( - input_quantity.metadata.gt4py_backend - ) + input_quantity.metadata.duplicate_metadata(output_quantity.metadata) output_quantity.data = reduced_quantity_data diff --git a/ndsl/quantity.py b/ndsl/quantity.py index 80bb4d0..a38a7a5 100644 --- a/ndsl/quantity.py +++ b/ndsl/quantity.py @@ -53,6 +53,15 @@ def np(self) -> NumpyModule: f"quantity underlying data is of unexpected type {self.data_type}" ) + def duplicate_metadata(self, metadata_copy): + metadata_copy.origin = self.origin + metadata_copy.extent = self.extent + metadata_copy.dims = self.dims + metadata_copy.units = self.units + metadata_copy.data_type = self.data_type + metadata_copy.dtype = self.dtype + metadata_copy.gt4py_backend = self.gt4py_backend + @dataclasses.dataclass class QuantityHaloSpec: @@ -458,20 +467,10 @@ def units(self) -> str: """units of the quantity""" return self.metadata.units - @units.setter - def units(self, newUnits): - if type(newUnits) is str: - self.metadata.units = newUnits - @property def gt4py_backend(self) -> Union[str, None]: return self.metadata.gt4py_backend - @gt4py_backend.setter - def gt4py_backend(self, newBackend): - if type(newBackend) is Union[str, None]: - self.metadata.gt4py_backend = newBackend - @property def attrs(self) -> dict: return dict(**self._attrs, units=self._metadata.units) @@ -481,11 +480,6 @@ def dims(self) -> Tuple[str, ...]: """names of each dimension""" return self.metadata.dims - @dims.setter - def dims(self, newDims): - if type(newDims) is Tuple: - self.metadata.dims = newDims - @property def values(self) -> np.ndarray: warnings.warn( @@ -517,21 +511,11 @@ def origin(self) -> Tuple[int, ...]: """the start of the computational domain""" return self.metadata.origin - @origin.setter - def origin(self, newOrigin): - if type(newOrigin) is Tuple: - self.metadata.origin = newOrigin - @property def extent(self) -> Tuple[int, ...]: """the shape of the computational domain""" return self.metadata.extent - @extent.setter - def extent(self, newExtent): - if type(newExtent) is Tuple: - self.metadata.extent = newExtent - @property def data_array(self) -> xr.DataArray: return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 9c2b3a3..858a7f9 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -126,7 +126,7 @@ def test_all_reduce_sum( gt4py_backend=backend, ) communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out) - assert testQuantity_1D_out.metadata == testQuantity_1D_out.metadata + assert testQuantity_1D_out.metadata == testQuantity_1D.metadata assert ( testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size) ).all()