Skip to content

Commit

Permalink
Added function in QuantityMetadata class that allows copying of Metad…
Browse files Browse the repository at this point in the history
…ata properties from one class to another. Subsequent Quantity setters that performed the copying of QuantityMetadata properties were removed
  • Loading branch information
gmao-ckung committed Dec 19, 2024
1 parent 2e669db commit fd2fa97
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 33 deletions.
8 changes: 1 addition & 7 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,7 @@ def all_reduce_sum(
if output_quantity.data.shape != input_quantity.data.shape:
raise TypeError("Shapes not matching")

output_quantity.metadata.dims = input_quantity.metadata.dims
output_quantity.metadata.units = input_quantity.metadata.units
output_quantity.metadata.origin = input_quantity.metadata.origin
output_quantity.metadata.extent = input_quantity.metadata.extent
output_quantity.metadata.gt4py_backend = (
input_quantity.metadata.gt4py_backend
)
input_quantity.metadata.duplicate_metadata(output_quantity.metadata)

output_quantity.data = reduced_quantity_data

Expand Down
34 changes: 9 additions & 25 deletions ndsl/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def np(self) -> NumpyModule:
f"quantity underlying data is of unexpected type {self.data_type}"
)

def duplicate_metadata(self, metadata_copy):
metadata_copy.origin = self.origin
metadata_copy.extent = self.extent
metadata_copy.dims = self.dims
metadata_copy.units = self.units
metadata_copy.data_type = self.data_type
metadata_copy.dtype = self.dtype
metadata_copy.gt4py_backend = self.gt4py_backend


@dataclasses.dataclass
class QuantityHaloSpec:
Expand Down Expand Up @@ -458,20 +467,10 @@ def units(self) -> str:
"""units of the quantity"""
return self.metadata.units

@units.setter
def units(self, newUnits):
if type(newUnits) is str:
self.metadata.units = newUnits

@property
def gt4py_backend(self) -> Union[str, None]:
return self.metadata.gt4py_backend

@gt4py_backend.setter
def gt4py_backend(self, newBackend):
if type(newBackend) is Union[str, None]:
self.metadata.gt4py_backend = newBackend

@property
def attrs(self) -> dict:
return dict(**self._attrs, units=self._metadata.units)
Expand All @@ -481,11 +480,6 @@ def dims(self) -> Tuple[str, ...]:
"""names of each dimension"""
return self.metadata.dims

@dims.setter
def dims(self, newDims):
if type(newDims) is Tuple:
self.metadata.dims = newDims

@property
def values(self) -> np.ndarray:
warnings.warn(
Expand Down Expand Up @@ -517,21 +511,11 @@ def origin(self) -> Tuple[int, ...]:
"""the start of the computational domain"""
return self.metadata.origin

@origin.setter
def origin(self, newOrigin):
if type(newOrigin) is Tuple:
self.metadata.origin = newOrigin

@property
def extent(self) -> Tuple[int, ...]:
"""the shape of the computational domain"""
return self.metadata.extent

@extent.setter
def extent(self, newExtent):
if type(newExtent) is Tuple:
self.metadata.extent = newExtent

@property
def data_array(self) -> xr.DataArray:
return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs)
Expand Down
2 changes: 1 addition & 1 deletion tests/mpi/test_mpi_all_reduce_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_all_reduce_sum(
gt4py_backend=backend,
)
communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out)
assert testQuantity_1D_out.metadata == testQuantity_1D_out.metadata
assert testQuantity_1D_out.metadata == testQuantity_1D.metadata
assert (
testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size)
).all()
Expand Down

0 comments on commit fd2fa97

Please sign in to comment.