diff --git a/ndsl/dsl/typing.py b/ndsl/dsl/typing.py index 3b6ba44c..b3fa72d8 100644 --- a/ndsl/dsl/typing.py +++ b/ndsl/dsl/typing.py @@ -89,3 +89,14 @@ def cast_to_index3d(val: Tuple[int, ...]) -> Index3D: if len(val) != 3: raise ValueError(f"expected 3d index, received {val}") return cast(Index3D, val) + + +def is_float(dtype: type): + """Expected floating point type""" + return dtype in [ + Float, + float, + np.float16, + np.float32, + np.float64, + ] diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py new file mode 100644 index 00000000..fce0cf63 --- /dev/null +++ b/ndsl/quantity/__init__.py @@ -0,0 +1,9 @@ +from ndsl.quantity.metadata import QuantityHaloSpec, QuantityMetadata +from ndsl.quantity.quantity import Quantity + + +__all__ = [ + "Quantity", + "QuantityMetadata", + "QuantityHaloSpec", +] diff --git a/ndsl/quantity/bounds.py b/ndsl/quantity/bounds.py new file mode 100644 index 00000000..419cff22 --- /dev/null +++ b/ndsl/quantity/bounds.py @@ -0,0 +1,190 @@ +from typing import Sequence, Tuple, Union + +import numpy as np + +import ndsl.constants as constants +from ndsl.comm._boundary_utils import bound_default_slice, shift_boundary_slice_tuple + + +class BoundaryArrayView: + def __init__(self, data, boundary_type, dims, origin, extent): + self._data = data + self._boundary_type = boundary_type + self._dims = dims + self._origin = origin + self._extent = extent + + def __getitem__(self, index): + if len(self._origin) == 0: + if isinstance(index, tuple) and len(index) > 0: + raise IndexError("more than one index given for a zero-dimension array") + elif isinstance(index, slice) and index != slice(None, None, None): + raise IndexError("cannot slice a zero-dimension array") + else: + return self._data # array[()] does not return an ndarray + else: + return self._data[self._get_array_index(index)] + + def __setitem__(self, index, value): + self._data[self._get_array_index(index)] = value + + def _get_array_index(self, index): + if isinstance(index, list): + index = tuple(index) + if not isinstance(index, tuple): + index = (index,) + if len(index) > len(self._dims): + raise IndexError( + f"{len(index)} is too many indices for a " + f"{len(self._dims)}-dimensional quantity" + ) + if len(index) < len(self._dims): + index = index + (slice(None, None),) * (len(self._dims) - len(index)) + return shift_boundary_slice_tuple( + self._dims, self._origin, self._extent, self._boundary_type, index + ) + + def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: + """Convenience method to perform indexing using dimension names + without knowing dimension order. + + Args: + **kwargs: slice/index to retrieve for a given dimension name + + Returns: + view_selection: an ndarray-like selection of the given indices + on `self.view` + """ + return self[tuple(kwargs.get(dim, slice(None, None)) for dim in self._dims)] + + +class BoundedArrayView: + """ + A container of objects which provide indexing relative to corners and edges + of the computational domain for convenience. + + Default start and end indices for all dimensions are modified to be the + start and end of the compute domain. When using edge and corner attributes, it is + recommended to explicitly write start and end offsets to avoid confusion. + + Indexing on the object itself (view[:]) is offset by the origin, and default + start and end indices are modified to be the start and end of the compute domain. + + For corner attributes e.g. `northwest`, modified indexing is done for the two + axes according to the edges which make up the corner. In other words, indexing + is offset relative to the intersection of the two edges which make the corner. + + For `interior`, start indices of the horizontal dimensions are relative to the + origin, and end indices are relative to the origin + extent. For example, + view.interior[0:0, 0:0, :] would retrieve the entire compute domain for an x/y/z + array, while view.interior[-1:1, -1:1, :] would also include one halo point. + """ + + def __init__( + self, array, dims: Sequence[str], origin: Sequence[int], extent: Sequence[int] + ): + self._data = array + self._dims = tuple(dims) + self._origin = tuple(origin) + self._extent = tuple(extent) + self._northwest = BoundaryArrayView( + array, constants.NORTHWEST, dims, origin, extent + ) + self._northeast = BoundaryArrayView( + array, constants.NORTHEAST, dims, origin, extent + ) + self._southwest = BoundaryArrayView( + array, constants.SOUTHWEST, dims, origin, extent + ) + self._southeast = BoundaryArrayView( + array, constants.SOUTHEAST, dims, origin, extent + ) + self._interior = BoundaryArrayView( + array, constants.INTERIOR, dims, origin, extent + ) + + @property + def origin(self) -> Tuple[int, ...]: + """the start of the computational domain""" + return self._origin + + @property + def extent(self) -> Tuple[int, ...]: + """the shape of the computational domain""" + return self._extent + + def __getitem__(self, index): + if len(self.origin) == 0: + if isinstance(index, tuple) and len(index) > 0: + raise IndexError("more than one index given for a zero-dimension array") + elif isinstance(index, slice) and index != slice(None, None, None): + raise IndexError("cannot slice a zero-dimension array") + else: + return self._data # array[()] does not return an ndarray + else: + return self._data[self._get_compute_index(index)] + + def __setitem__(self, index, value): + self._data[self._get_compute_index(index)] = value + + def _get_compute_index(self, index): + if not isinstance(index, (tuple, list)): + index = (index,) + if len(index) > len(self._dims): + raise IndexError( + f"{len(index)} is too many indices for a " + f"{len(self._dims)}-dimensional quantity" + ) + index = _fill_index(index, len(self._data.shape)) + shifted_index = [] + for entry, origin, extent in zip(index, self.origin, self.extent): + if isinstance(entry, slice): + shifted_slice = _shift_slice(entry, origin, extent) + shifted_index.append( + bound_default_slice(shifted_slice, origin, origin + extent) + ) + elif entry is None: + shifted_index.append(entry) + else: + shifted_index.append(entry + origin) + return tuple(shifted_index) + + @property + def northwest(self) -> BoundaryArrayView: + return self._northwest + + @property + def northeast(self) -> BoundaryArrayView: + return self._northeast + + @property + def southwest(self) -> BoundaryArrayView: + return self._southwest + + @property + def southeast(self) -> BoundaryArrayView: + return self._southeast + + @property + def interior(self) -> BoundaryArrayView: + return self._interior + + +def _fill_index(index, length): + return tuple(index) + (slice(None, None, None),) * (length - len(index)) + + +def _shift_slice(slice_in, shift, extent): + start = _shift_index(slice_in.start, shift, extent) + stop = _shift_index(slice_in.stop, shift, extent) + return slice(start, stop, slice_in.step) + + +def _shift_index(current_value, shift, extent): + if current_value is None: + new_value = None + else: + new_value = current_value + shift + if new_value < 0: + new_value = extent + new_value + return new_value diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py new file mode 100644 index 00000000..d7ddba0f --- /dev/null +++ b/ndsl/quantity/metadata.py @@ -0,0 +1,70 @@ +import dataclasses +from typing import Any, Dict, Tuple, Union + +import numpy as np + +from ndsl.optional_imports import cupy +from ndsl.types import NumpyModule + + +if cupy is None: + import numpy as cupy + + +@dataclasses.dataclass +class QuantityMetadata: + origin: Tuple[int, ...] + "the start of the computational domain" + extent: Tuple[int, ...] + "the shape of the computational domain" + dims: Tuple[str, ...] + "names of each dimension" + units: str + "units of the quantity" + data_type: type + "ndarray-like type used to store the data" + dtype: type + "dtype of the data in the ndarray-like object" + gt4py_backend: Union[str, None] = None + "backend to use for gt4py storages" + + @property + def dim_lengths(self) -> Dict[str, int]: + """mapping of dimension names to their lengths""" + return dict(zip(self.dims, self.extent)) + + @property + def np(self) -> NumpyModule: + """numpy-like module used to interact with the data""" + if issubclass(self.data_type, cupy.ndarray): + return cupy + elif issubclass(self.data_type, np.ndarray): + return np + else: + raise TypeError( + f"quantity underlying data is of unexpected type {self.data_type}" + ) + + def duplicate_metadata(self, metadata_copy): + metadata_copy.origin = self.origin + metadata_copy.extent = self.extent + metadata_copy.dims = self.dims + metadata_copy.units = self.units + metadata_copy.data_type = self.data_type + metadata_copy.dtype = self.dtype + metadata_copy.gt4py_backend = self.gt4py_backend + + +@dataclasses.dataclass +class QuantityHaloSpec: + """Describe the memory to be exchanged, including size of the halo.""" + + n_points: int + strides: Tuple[int] + itemsize: int + shape: Tuple[int] + origin: Tuple[int, ...] + extent: Tuple[int, ...] + dims: Tuple[str, ...] + numpy_module: NumpyModule + dtype: Any diff --git a/ndsl/quantity.py b/ndsl/quantity/quantity.py similarity index 59% rename from ndsl/quantity.py rename to ndsl/quantity/quantity.py index a38a7a5d..c88ba140 100644 --- a/ndsl/quantity.py +++ b/ndsl/quantity/quantity.py @@ -1,283 +1,21 @@ -import dataclasses import warnings -from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, cast +from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast import matplotlib.pyplot as plt import numpy as np import ndsl.constants as constants -from ndsl.comm._boundary_utils import bound_default_slice, shift_boundary_slice_tuple -from ndsl.dsl.typing import Float +from ndsl.dsl.typing import Float, is_float from ndsl.optional_imports import cupy, dace, gt4py from ndsl.optional_imports import xarray as xr +from ndsl.quantity.bounds import BoundedArrayView +from ndsl.quantity.metadata import QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule if cupy is None: import numpy as cupy -__all__ = ["Quantity", "QuantityMetadata"] - - -@dataclasses.dataclass -class QuantityMetadata: - origin: Tuple[int, ...] - "the start of the computational domain" - extent: Tuple[int, ...] - "the shape of the computational domain" - dims: Tuple[str, ...] - "names of each dimension" - units: str - "units of the quantity" - data_type: type - "ndarray-like type used to store the data" - dtype: type - "dtype of the data in the ndarray-like object" - gt4py_backend: Union[str, None] = None - "backend to use for gt4py storages" - - @property - def dim_lengths(self) -> Dict[str, int]: - """mapping of dimension names to their lengths""" - return dict(zip(self.dims, self.extent)) - - @property - def np(self) -> NumpyModule: - """numpy-like module used to interact with the data""" - if issubclass(self.data_type, cupy.ndarray): - return cupy - elif issubclass(self.data_type, np.ndarray): - return np - else: - raise TypeError( - f"quantity underlying data is of unexpected type {self.data_type}" - ) - - def duplicate_metadata(self, metadata_copy): - metadata_copy.origin = self.origin - metadata_copy.extent = self.extent - metadata_copy.dims = self.dims - metadata_copy.units = self.units - metadata_copy.data_type = self.data_type - metadata_copy.dtype = self.dtype - metadata_copy.gt4py_backend = self.gt4py_backend - - -@dataclasses.dataclass -class QuantityHaloSpec: - """Describe the memory to be exchanged, including size of the halo.""" - - n_points: int - strides: Tuple[int] - itemsize: int - shape: Tuple[int] - origin: Tuple[int, ...] - extent: Tuple[int, ...] - dims: Tuple[str, ...] - numpy_module: NumpyModule - dtype: Any - - -class BoundaryArrayView: - def __init__(self, data, boundary_type, dims, origin, extent): - self._data = data - self._boundary_type = boundary_type - self._dims = dims - self._origin = origin - self._extent = extent - - def __getitem__(self, index): - if len(self._origin) == 0: - if isinstance(index, tuple) and len(index) > 0: - raise IndexError("more than one index given for a zero-dimension array") - elif isinstance(index, slice) and index != slice(None, None, None): - raise IndexError("cannot slice a zero-dimension array") - else: - return self._data # array[()] does not return an ndarray - else: - return self._data[self._get_array_index(index)] - - def __setitem__(self, index, value): - self._data[self._get_array_index(index)] = value - - def _get_array_index(self, index): - if isinstance(index, list): - index = tuple(index) - if not isinstance(index, tuple): - index = (index,) - if len(index) > len(self._dims): - raise IndexError( - f"{len(index)} is too many indices for a " - f"{len(self._dims)}-dimensional quantity" - ) - if len(index) < len(self._dims): - index = index + (slice(None, None),) * (len(self._dims) - len(index)) - return shift_boundary_slice_tuple( - self._dims, self._origin, self._extent, self._boundary_type, index - ) - - def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: - """Convenience method to perform indexing using dimension names - without knowing dimension order. - - Args: - **kwargs: slice/index to retrieve for a given dimension name - - Returns: - view_selection: an ndarray-like selection of the given indices - on `self.view` - """ - return self[tuple(kwargs.get(dim, slice(None, None)) for dim in self._dims)] - - -class BoundedArrayView: - """ - A container of objects which provide indexing relative to corners and edges - of the computational domain for convenience. - - Default start and end indices for all dimensions are modified to be the - start and end of the compute domain. When using edge and corner attributes, it is - recommended to explicitly write start and end offsets to avoid confusion. - - Indexing on the object itself (view[:]) is offset by the origin, and default - start and end indices are modified to be the start and end of the compute domain. - - For corner attributes e.g. `northwest`, modified indexing is done for the two - axes according to the edges which make up the corner. In other words, indexing - is offset relative to the intersection of the two edges which make the corner. - - For `interior`, start indices of the horizontal dimensions are relative to the - origin, and end indices are relative to the origin + extent. For example, - view.interior[0:0, 0:0, :] would retrieve the entire compute domain for an x/y/z - array, while view.interior[-1:1, -1:1, :] would also include one halo point. - """ - - def __init__( - self, array, dims: Sequence[str], origin: Sequence[int], extent: Sequence[int] - ): - self._data = array - self._dims = tuple(dims) - self._origin = tuple(origin) - self._extent = tuple(extent) - self._northwest = BoundaryArrayView( - array, constants.NORTHWEST, dims, origin, extent - ) - self._northeast = BoundaryArrayView( - array, constants.NORTHEAST, dims, origin, extent - ) - self._southwest = BoundaryArrayView( - array, constants.SOUTHWEST, dims, origin, extent - ) - self._southeast = BoundaryArrayView( - array, constants.SOUTHEAST, dims, origin, extent - ) - self._interior = BoundaryArrayView( - array, constants.INTERIOR, dims, origin, extent - ) - - @property - def origin(self) -> Tuple[int, ...]: - """the start of the computational domain""" - return self._origin - - @property - def extent(self) -> Tuple[int, ...]: - """the shape of the computational domain""" - return self._extent - - def __getitem__(self, index): - if len(self.origin) == 0: - if isinstance(index, tuple) and len(index) > 0: - raise IndexError("more than one index given for a zero-dimension array") - elif isinstance(index, slice) and index != slice(None, None, None): - raise IndexError("cannot slice a zero-dimension array") - else: - return self._data # array[()] does not return an ndarray - else: - return self._data[self._get_compute_index(index)] - - def __setitem__(self, index, value): - self._data[self._get_compute_index(index)] = value - - def _get_compute_index(self, index): - if not isinstance(index, (tuple, list)): - index = (index,) - if len(index) > len(self._dims): - raise IndexError( - f"{len(index)} is too many indices for a " - f"{len(self._dims)}-dimensional quantity" - ) - index = fill_index(index, len(self._data.shape)) - shifted_index = [] - for entry, origin, extent in zip(index, self.origin, self.extent): - if isinstance(entry, slice): - shifted_slice = shift_slice(entry, origin, extent) - shifted_index.append( - bound_default_slice(shifted_slice, origin, origin + extent) - ) - elif entry is None: - shifted_index.append(entry) - else: - shifted_index.append(entry + origin) - return tuple(shifted_index) - - @property - def northwest(self) -> BoundaryArrayView: - return self._northwest - - @property - def northeast(self) -> BoundaryArrayView: - return self._northeast - - @property - def southwest(self) -> BoundaryArrayView: - return self._southwest - - @property - def southeast(self) -> BoundaryArrayView: - return self._southeast - - @property - def interior(self) -> BoundaryArrayView: - return self._interior - - -def ensure_int_tuple(arg, arg_name): - return_list = [] - for item in arg: - try: - return_list.append(int(item)) - except ValueError: - raise TypeError( - f"tuple arg {arg_name}={arg} contains item {item} of " - f"unexpected type {type(item)}" - ) - return tuple(return_list) - - -def _validate_quantity_property_lengths(shape, dims, origin, extent): - n_dims = len(shape) - for var, desc in ( - (dims, "dimension names"), - (origin, "origins"), - (extent, "extents"), - ): - if len(var) != n_dims: - raise ValueError( - f"received {len(var)} {desc} for {n_dims} dimensions: {var}" - ) - - -def _is_float(dtype): - """Expected floating point type for Pace""" - return ( - dtype == Float - or dtype == float - or dtype == np.float32 - or dtype == np.float64 - or dtype == np.float16 - ) - class Quantity: """ @@ -311,7 +49,7 @@ def __init__( if ( not allow_mismatch_float_precision - and _is_float(data.dtype) + and is_float(data.dtype) and data.dtype != Float ): raise ValueError( @@ -371,8 +109,8 @@ def __init__( _validate_quantity_property_lengths(data.shape, dims, origin, extent) self._metadata = QuantityMetadata( - origin=ensure_int_tuple(origin, "origin"), - extent=ensure_int_tuple(extent, "extent"), + origin=_ensure_int_tuple(origin, "origin"), + extent=_ensure_int_tuple(extent, "extent"), dims=tuple(dims), units=units, data_type=type(self._data), @@ -598,10 +336,10 @@ def transpose( transpose_order = [self.dims.index(dim) for dim in target_dims] transposed = Quantity( self.np.transpose(self.data, transpose_order), # type: ignore[attr-defined] - dims=transpose_sequence(self.dims, transpose_order), + dims=_transpose_sequence(self.dims, transpose_order), units=self.units, - origin=transpose_sequence(self.origin, transpose_order), - extent=transpose_sequence(self.extent, transpose_order), + origin=_transpose_sequence(self.origin, transpose_order), + extent=_transpose_sequence(self.extent, transpose_order), gt4py_backend=self.gt4py_backend, allow_mismatch_float_precision=allow_mismatch_float_precision, ) @@ -625,7 +363,7 @@ def plot_k_level(self, k_index=0): plt.show() -def transpose_sequence(sequence, order): +def _transpose_sequence(sequence, order): return sequence.__class__(sequence[i] for i in order) @@ -655,21 +393,27 @@ def _collapse_dims(target_dims, dims): return return_list -def fill_index(index, length): - return tuple(index) + (slice(None, None, None),) * (length - len(index)) - - -def shift_slice(slice_in, shift, extent): - start = shift_index(slice_in.start, shift, extent) - stop = shift_index(slice_in.stop, shift, extent) - return slice(start, stop, slice_in.step) +def _validate_quantity_property_lengths(shape, dims, origin, extent): + n_dims = len(shape) + for var, desc in ( + (dims, "dimension names"), + (origin, "origins"), + (extent, "extents"), + ): + if len(var) != n_dims: + raise ValueError( + f"received {len(var)} {desc} for {n_dims} dimensions: {var}" + ) -def shift_index(current_value, shift, extent): - if current_value is None: - new_value = None - else: - new_value = current_value + shift - if new_value < 0: - new_value = extent + new_value - return new_value +def _ensure_int_tuple(arg, arg_name): + return_list = [] + for item in arg: + try: + return_list.append(int(item)) + except ValueError: + raise TypeError( + f"tuple arg {arg_name}={arg} contains item {item} of " + f"unexpected type {type(item)}" + ) + return tuple(return_list) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index a6de628b..61e92025 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -1,8 +1,8 @@ import numpy as np import pytest -import ndsl.quantity as qty from ndsl import Quantity +from ndsl.quantity.bounds import _shift_slice try: @@ -229,7 +229,7 @@ def test_compute_view_edit_all_domain(quantity, n_halo, n_dims, extent_1d): ], ) def test_shift_slice(slice_in, shift, extent, slice_out): - result = qty.shift_slice(slice_in, shift, extent) + result = _shift_slice(slice_in, shift, extent) assert result == slice_out