From 502486f4847bc574bc84730d303097ece2f455d4 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 10 Dec 2024 09:37:44 -0500 Subject: [PATCH 1/8] Move `is_float` to `dsl.typing` --- ndsl/dsl/typing.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ndsl/dsl/typing.py b/ndsl/dsl/typing.py index 69b35c8..513ab1f 100644 --- a/ndsl/dsl/typing.py +++ b/ndsl/dsl/typing.py @@ -79,3 +79,14 @@ def cast_to_index3d(val: Tuple[int, ...]) -> Index3D: if len(val) != 3: raise ValueError(f"expected 3d index, received {val}") return cast(Index3D, val) + + +def is_float(dtype: type): + """Expected floating point type""" + return dtype in [ + Float, + float, + np.float16, + np.float32, + np.float64, + ] From a13776fa2d91f558be4d8ce914022e4e6ccc9c4f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 10 Dec 2024 09:38:16 -0500 Subject: [PATCH 2/8] Move Quantity to sub-directory + breakout the subcomponent --- ndsl/quantity/__init__.py | 9 + ndsl/quantity/bounds.py | 190 +++++++++++++++++++ ndsl/quantity/metadata.py | 56 ++++++ ndsl/{ => quantity}/quantity.py | 317 ++++---------------------------- 4 files changed, 288 insertions(+), 284 deletions(-) create mode 100644 ndsl/quantity/__init__.py create mode 100644 ndsl/quantity/bounds.py create mode 100644 ndsl/quantity/metadata.py rename ndsl/{ => quantity}/quantity.py (59%) diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py new file mode 100644 index 0000000..c8f68ea --- /dev/null +++ b/ndsl/quantity/__init__.py @@ -0,0 +1,9 @@ +from ndsl.quantity.quantity import Quantity +from ndsl.quantity.metadata import QuantityMetadata, QuantityHaloSpec + + +__all__ = [ + "Quantity", + "QuantityMetadata", + "QuantityHaloSpec", +] diff --git a/ndsl/quantity/bounds.py b/ndsl/quantity/bounds.py new file mode 100644 index 0000000..419cff2 --- /dev/null +++ b/ndsl/quantity/bounds.py @@ -0,0 +1,190 @@ +from typing import Sequence, Tuple, Union + +import numpy as np + +import ndsl.constants as constants +from ndsl.comm._boundary_utils import bound_default_slice, shift_boundary_slice_tuple + + +class BoundaryArrayView: + def __init__(self, data, boundary_type, dims, origin, extent): + self._data = data + self._boundary_type = boundary_type + self._dims = dims + self._origin = origin + self._extent = extent + + def __getitem__(self, index): + if len(self._origin) == 0: + if isinstance(index, tuple) and len(index) > 0: + raise IndexError("more than one index given for a zero-dimension array") + elif isinstance(index, slice) and index != slice(None, None, None): + raise IndexError("cannot slice a zero-dimension array") + else: + return self._data # array[()] does not return an ndarray + else: + return self._data[self._get_array_index(index)] + + def __setitem__(self, index, value): + self._data[self._get_array_index(index)] = value + + def _get_array_index(self, index): + if isinstance(index, list): + index = tuple(index) + if not isinstance(index, tuple): + index = (index,) + if len(index) > len(self._dims): + raise IndexError( + f"{len(index)} is too many indices for a " + f"{len(self._dims)}-dimensional quantity" + ) + if len(index) < len(self._dims): + index = index + (slice(None, None),) * (len(self._dims) - len(index)) + return shift_boundary_slice_tuple( + self._dims, self._origin, self._extent, self._boundary_type, index + ) + + def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: + """Convenience method to perform indexing using dimension names + without knowing dimension order. + + Args: + **kwargs: slice/index to retrieve for a given dimension name + + Returns: + view_selection: an ndarray-like selection of the given indices + on `self.view` + """ + return self[tuple(kwargs.get(dim, slice(None, None)) for dim in self._dims)] + + +class BoundedArrayView: + """ + A container of objects which provide indexing relative to corners and edges + of the computational domain for convenience. + + Default start and end indices for all dimensions are modified to be the + start and end of the compute domain. When using edge and corner attributes, it is + recommended to explicitly write start and end offsets to avoid confusion. + + Indexing on the object itself (view[:]) is offset by the origin, and default + start and end indices are modified to be the start and end of the compute domain. + + For corner attributes e.g. `northwest`, modified indexing is done for the two + axes according to the edges which make up the corner. In other words, indexing + is offset relative to the intersection of the two edges which make the corner. + + For `interior`, start indices of the horizontal dimensions are relative to the + origin, and end indices are relative to the origin + extent. For example, + view.interior[0:0, 0:0, :] would retrieve the entire compute domain for an x/y/z + array, while view.interior[-1:1, -1:1, :] would also include one halo point. + """ + + def __init__( + self, array, dims: Sequence[str], origin: Sequence[int], extent: Sequence[int] + ): + self._data = array + self._dims = tuple(dims) + self._origin = tuple(origin) + self._extent = tuple(extent) + self._northwest = BoundaryArrayView( + array, constants.NORTHWEST, dims, origin, extent + ) + self._northeast = BoundaryArrayView( + array, constants.NORTHEAST, dims, origin, extent + ) + self._southwest = BoundaryArrayView( + array, constants.SOUTHWEST, dims, origin, extent + ) + self._southeast = BoundaryArrayView( + array, constants.SOUTHEAST, dims, origin, extent + ) + self._interior = BoundaryArrayView( + array, constants.INTERIOR, dims, origin, extent + ) + + @property + def origin(self) -> Tuple[int, ...]: + """the start of the computational domain""" + return self._origin + + @property + def extent(self) -> Tuple[int, ...]: + """the shape of the computational domain""" + return self._extent + + def __getitem__(self, index): + if len(self.origin) == 0: + if isinstance(index, tuple) and len(index) > 0: + raise IndexError("more than one index given for a zero-dimension array") + elif isinstance(index, slice) and index != slice(None, None, None): + raise IndexError("cannot slice a zero-dimension array") + else: + return self._data # array[()] does not return an ndarray + else: + return self._data[self._get_compute_index(index)] + + def __setitem__(self, index, value): + self._data[self._get_compute_index(index)] = value + + def _get_compute_index(self, index): + if not isinstance(index, (tuple, list)): + index = (index,) + if len(index) > len(self._dims): + raise IndexError( + f"{len(index)} is too many indices for a " + f"{len(self._dims)}-dimensional quantity" + ) + index = _fill_index(index, len(self._data.shape)) + shifted_index = [] + for entry, origin, extent in zip(index, self.origin, self.extent): + if isinstance(entry, slice): + shifted_slice = _shift_slice(entry, origin, extent) + shifted_index.append( + bound_default_slice(shifted_slice, origin, origin + extent) + ) + elif entry is None: + shifted_index.append(entry) + else: + shifted_index.append(entry + origin) + return tuple(shifted_index) + + @property + def northwest(self) -> BoundaryArrayView: + return self._northwest + + @property + def northeast(self) -> BoundaryArrayView: + return self._northeast + + @property + def southwest(self) -> BoundaryArrayView: + return self._southwest + + @property + def southeast(self) -> BoundaryArrayView: + return self._southeast + + @property + def interior(self) -> BoundaryArrayView: + return self._interior + + +def _fill_index(index, length): + return tuple(index) + (slice(None, None, None),) * (length - len(index)) + + +def _shift_slice(slice_in, shift, extent): + start = _shift_index(slice_in.start, shift, extent) + stop = _shift_index(slice_in.stop, shift, extent) + return slice(start, stop, slice_in.step) + + +def _shift_index(current_value, shift, extent): + if current_value is None: + new_value = None + else: + new_value = current_value + shift + if new_value < 0: + new_value = extent + new_value + return new_value diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py new file mode 100644 index 0000000..0051f6b --- /dev/null +++ b/ndsl/quantity/metadata.py @@ -0,0 +1,56 @@ +import dataclasses +from typing import Any, Dict, Tuple, Union + +import numpy as np +from ndsl.optional_imports import cupy +from ndsl.types import NumpyModule + + +@dataclasses.dataclass +class QuantityMetadata: + origin: Tuple[int, ...] + "the start of the computational domain" + extent: Tuple[int, ...] + "the shape of the computational domain" + dims: Tuple[str, ...] + "names of each dimension" + units: str + "units of the quantity" + data_type: type + "ndarray-like type used to store the data" + dtype: type + "dtype of the data in the ndarray-like object" + gt4py_backend: Union[str, None] = None + "backend to use for gt4py storages" + + @property + def dim_lengths(self) -> Dict[str, int]: + """mapping of dimension names to their lengths""" + return dict(zip(self.dims, self.extent)) + + @property + def np(self) -> NumpyModule: + """numpy-like module used to interact with the data""" + if issubclass(self.data_type, cupy.ndarray): + return cupy + elif issubclass(self.data_type, np.ndarray): + return np + else: + raise TypeError( + f"quantity underlying data is of unexpected type {self.data_type}" + ) + + +@dataclasses.dataclass +class QuantityHaloSpec: + """Describe the memory to be exchanged, including size of the halo.""" + + n_points: int + strides: Tuple[int] + itemsize: int + shape: Tuple[int] + origin: Tuple[int, ...] + extent: Tuple[int, ...] + dims: Tuple[str, ...] + numpy_module: NumpyModule + dtype: Any diff --git a/ndsl/quantity.py b/ndsl/quantity/quantity.py similarity index 59% rename from ndsl/quantity.py rename to ndsl/quantity/quantity.py index b95a9aa..00e66f0 100644 --- a/ndsl/quantity.py +++ b/ndsl/quantity/quantity.py @@ -1,273 +1,16 @@ -import dataclasses import warnings -from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, cast +from typing import Any, Iterable, Optional, Sequence, Tuple, Union, cast import matplotlib.pyplot as plt import numpy as np import ndsl.constants as constants -from ndsl.comm._boundary_utils import bound_default_slice, shift_boundary_slice_tuple -from ndsl.dsl.typing import Float +from ndsl.dsl.typing import Float, is_float from ndsl.optional_imports import cupy, dace, gt4py from ndsl.optional_imports import xarray as xr from ndsl.types import NumpyModule - - -if cupy is None: - import numpy as cupy - -__all__ = ["Quantity", "QuantityMetadata"] - - -@dataclasses.dataclass -class QuantityMetadata: - origin: Tuple[int, ...] - "the start of the computational domain" - extent: Tuple[int, ...] - "the shape of the computational domain" - dims: Tuple[str, ...] - "names of each dimension" - units: str - "units of the quantity" - data_type: type - "ndarray-like type used to store the data" - dtype: type - "dtype of the data in the ndarray-like object" - gt4py_backend: Union[str, None] = None - "backend to use for gt4py storages" - - @property - def dim_lengths(self) -> Dict[str, int]: - """mapping of dimension names to their lengths""" - return dict(zip(self.dims, self.extent)) - - @property - def np(self) -> NumpyModule: - """numpy-like module used to interact with the data""" - if issubclass(self.data_type, cupy.ndarray): - return cupy - elif issubclass(self.data_type, np.ndarray): - return np - else: - raise TypeError( - f"quantity underlying data is of unexpected type {self.data_type}" - ) - - -@dataclasses.dataclass -class QuantityHaloSpec: - """Describe the memory to be exchanged, including size of the halo.""" - - n_points: int - strides: Tuple[int] - itemsize: int - shape: Tuple[int] - origin: Tuple[int, ...] - extent: Tuple[int, ...] - dims: Tuple[str, ...] - numpy_module: NumpyModule - dtype: Any - - -class BoundaryArrayView: - def __init__(self, data, boundary_type, dims, origin, extent): - self._data = data - self._boundary_type = boundary_type - self._dims = dims - self._origin = origin - self._extent = extent - - def __getitem__(self, index): - if len(self._origin) == 0: - if isinstance(index, tuple) and len(index) > 0: - raise IndexError("more than one index given for a zero-dimension array") - elif isinstance(index, slice) and index != slice(None, None, None): - raise IndexError("cannot slice a zero-dimension array") - else: - return self._data # array[()] does not return an ndarray - else: - return self._data[self._get_array_index(index)] - - def __setitem__(self, index, value): - self._data[self._get_array_index(index)] = value - - def _get_array_index(self, index): - if isinstance(index, list): - index = tuple(index) - if not isinstance(index, tuple): - index = (index,) - if len(index) > len(self._dims): - raise IndexError( - f"{len(index)} is too many indices for a " - f"{len(self._dims)}-dimensional quantity" - ) - if len(index) < len(self._dims): - index = index + (slice(None, None),) * (len(self._dims) - len(index)) - return shift_boundary_slice_tuple( - self._dims, self._origin, self._extent, self._boundary_type, index - ) - - def sel(self, **kwargs: Union[slice, int]) -> np.ndarray: - """Convenience method to perform indexing using dimension names - without knowing dimension order. - - Args: - **kwargs: slice/index to retrieve for a given dimension name - - Returns: - view_selection: an ndarray-like selection of the given indices - on `self.view` - """ - return self[tuple(kwargs.get(dim, slice(None, None)) for dim in self._dims)] - - -class BoundedArrayView: - """ - A container of objects which provide indexing relative to corners and edges - of the computational domain for convenience. - - Default start and end indices for all dimensions are modified to be the - start and end of the compute domain. When using edge and corner attributes, it is - recommended to explicitly write start and end offsets to avoid confusion. - - Indexing on the object itself (view[:]) is offset by the origin, and default - start and end indices are modified to be the start and end of the compute domain. - - For corner attributes e.g. `northwest`, modified indexing is done for the two - axes according to the edges which make up the corner. In other words, indexing - is offset relative to the intersection of the two edges which make the corner. - - For `interior`, start indices of the horizontal dimensions are relative to the - origin, and end indices are relative to the origin + extent. For example, - view.interior[0:0, 0:0, :] would retrieve the entire compute domain for an x/y/z - array, while view.interior[-1:1, -1:1, :] would also include one halo point. - """ - - def __init__( - self, array, dims: Sequence[str], origin: Sequence[int], extent: Sequence[int] - ): - self._data = array - self._dims = tuple(dims) - self._origin = tuple(origin) - self._extent = tuple(extent) - self._northwest = BoundaryArrayView( - array, constants.NORTHWEST, dims, origin, extent - ) - self._northeast = BoundaryArrayView( - array, constants.NORTHEAST, dims, origin, extent - ) - self._southwest = BoundaryArrayView( - array, constants.SOUTHWEST, dims, origin, extent - ) - self._southeast = BoundaryArrayView( - array, constants.SOUTHEAST, dims, origin, extent - ) - self._interior = BoundaryArrayView( - array, constants.INTERIOR, dims, origin, extent - ) - - @property - def origin(self) -> Tuple[int, ...]: - """the start of the computational domain""" - return self._origin - - @property - def extent(self) -> Tuple[int, ...]: - """the shape of the computational domain""" - return self._extent - - def __getitem__(self, index): - if len(self.origin) == 0: - if isinstance(index, tuple) and len(index) > 0: - raise IndexError("more than one index given for a zero-dimension array") - elif isinstance(index, slice) and index != slice(None, None, None): - raise IndexError("cannot slice a zero-dimension array") - else: - return self._data # array[()] does not return an ndarray - else: - return self._data[self._get_compute_index(index)] - - def __setitem__(self, index, value): - self._data[self._get_compute_index(index)] = value - - def _get_compute_index(self, index): - if not isinstance(index, (tuple, list)): - index = (index,) - if len(index) > len(self._dims): - raise IndexError( - f"{len(index)} is too many indices for a " - f"{len(self._dims)}-dimensional quantity" - ) - index = fill_index(index, len(self._data.shape)) - shifted_index = [] - for entry, origin, extent in zip(index, self.origin, self.extent): - if isinstance(entry, slice): - shifted_slice = shift_slice(entry, origin, extent) - shifted_index.append( - bound_default_slice(shifted_slice, origin, origin + extent) - ) - elif entry is None: - shifted_index.append(entry) - else: - shifted_index.append(entry + origin) - return tuple(shifted_index) - - @property - def northwest(self) -> BoundaryArrayView: - return self._northwest - - @property - def northeast(self) -> BoundaryArrayView: - return self._northeast - - @property - def southwest(self) -> BoundaryArrayView: - return self._southwest - - @property - def southeast(self) -> BoundaryArrayView: - return self._southeast - - @property - def interior(self) -> BoundaryArrayView: - return self._interior - - -def ensure_int_tuple(arg, arg_name): - return_list = [] - for item in arg: - try: - return_list.append(int(item)) - except ValueError: - raise TypeError( - f"tuple arg {arg_name}={arg} contains item {item} of " - f"unexpected type {type(item)}" - ) - return tuple(return_list) - - -def _validate_quantity_property_lengths(shape, dims, origin, extent): - n_dims = len(shape) - for var, desc in ( - (dims, "dimension names"), - (origin, "origins"), - (extent, "extents"), - ): - if len(var) != n_dims: - raise ValueError( - f"received {len(var)} {desc} for {n_dims} dimensions: {var}" - ) - - -def _is_float(dtype): - """Expected floating point type for Pace""" - return ( - dtype == Float - or dtype == float - or dtype == np.float32 - or dtype == np.float64 - or dtype == np.float16 - ) +from ndsl.quantity.bounds import BoundedArrayView +from ndsl.quantity.metadata import QuantityMetadata, QuantityHaloSpec class Quantity: @@ -302,7 +45,7 @@ def __init__( if ( not allow_mismatch_float_precision - and _is_float(data.dtype) + and is_float(data.dtype) and data.dtype != Float ): raise ValueError( @@ -362,8 +105,8 @@ def __init__( _validate_quantity_property_lengths(data.shape, dims, origin, extent) self._metadata = QuantityMetadata( - origin=ensure_int_tuple(origin, "origin"), - extent=ensure_int_tuple(extent, "extent"), + origin=_ensure_int_tuple(origin, "origin"), + extent=_ensure_int_tuple(extent, "extent"), dims=tuple(dims), units=units, data_type=type(self._data), @@ -584,10 +327,10 @@ def transpose( transpose_order = [self.dims.index(dim) for dim in target_dims] transposed = Quantity( self.np.transpose(self.data, transpose_order), # type: ignore[attr-defined] - dims=transpose_sequence(self.dims, transpose_order), + dims=_transpose_sequence(self.dims, transpose_order), units=self.units, - origin=transpose_sequence(self.origin, transpose_order), - extent=transpose_sequence(self.extent, transpose_order), + origin=_transpose_sequence(self.origin, transpose_order), + extent=_transpose_sequence(self.extent, transpose_order), gt4py_backend=self.gt4py_backend, allow_mismatch_float_precision=allow_mismatch_float_precision, ) @@ -611,7 +354,7 @@ def plot_k_level(self, k_index=0): plt.show() -def transpose_sequence(sequence, order): +def _transpose_sequence(sequence, order): return sequence.__class__(sequence[i] for i in order) @@ -641,21 +384,27 @@ def _collapse_dims(target_dims, dims): return return_list -def fill_index(index, length): - return tuple(index) + (slice(None, None, None),) * (length - len(index)) - - -def shift_slice(slice_in, shift, extent): - start = shift_index(slice_in.start, shift, extent) - stop = shift_index(slice_in.stop, shift, extent) - return slice(start, stop, slice_in.step) +def _validate_quantity_property_lengths(shape, dims, origin, extent): + n_dims = len(shape) + for var, desc in ( + (dims, "dimension names"), + (origin, "origins"), + (extent, "extents"), + ): + if len(var) != n_dims: + raise ValueError( + f"received {len(var)} {desc} for {n_dims} dimensions: {var}" + ) -def shift_index(current_value, shift, extent): - if current_value is None: - new_value = None - else: - new_value = current_value + shift - if new_value < 0: - new_value = extent + new_value - return new_value +def _ensure_int_tuple(arg, arg_name): + return_list = [] + for item in arg: + try: + return_list.append(int(item)) + except ValueError: + raise TypeError( + f"tuple arg {arg_name}={arg} contains item {item} of " + f"unexpected type {type(item)}" + ) + return tuple(return_list) From 937417beda47d9d941a7872792d302c33c125715 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 10 Dec 2024 09:38:40 -0500 Subject: [PATCH 3/8] Fix tests --- tests/quantity/test_quantity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index a6de628..61e9202 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -1,8 +1,8 @@ import numpy as np import pytest -import ndsl.quantity as qty from ndsl import Quantity +from ndsl.quantity.bounds import _shift_slice try: @@ -229,7 +229,7 @@ def test_compute_view_edit_all_domain(quantity, n_halo, n_dims, extent_1d): ], ) def test_shift_slice(slice_in, shift, extent, slice_out): - result = qty.shift_slice(slice_in, shift, extent) + result = _shift_slice(slice_in, shift, extent) assert result == slice_out From 45c318050ef652e3f80775c1b1376a424bc1294f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 10 Dec 2024 09:51:12 -0500 Subject: [PATCH 4/8] Lint --- ndsl/quantity/__init__.py | 2 +- ndsl/quantity/metadata.py | 1 + ndsl/quantity/quantity.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ndsl/quantity/__init__.py b/ndsl/quantity/__init__.py index c8f68ea..fce0cf6 100644 --- a/ndsl/quantity/__init__.py +++ b/ndsl/quantity/__init__.py @@ -1,5 +1,5 @@ +from ndsl.quantity.metadata import QuantityHaloSpec, QuantityMetadata from ndsl.quantity.quantity import Quantity -from ndsl.quantity.metadata import QuantityMetadata, QuantityHaloSpec __all__ = [ diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index 0051f6b..611cf90 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Tuple, Union import numpy as np + from ndsl.optional_imports import cupy from ndsl.types import NumpyModule diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 00e66f0..a7d89df 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -8,9 +8,9 @@ from ndsl.dsl.typing import Float, is_float from ndsl.optional_imports import cupy, dace, gt4py from ndsl.optional_imports import xarray as xr -from ndsl.types import NumpyModule from ndsl.quantity.bounds import BoundedArrayView -from ndsl.quantity.metadata import QuantityMetadata, QuantityHaloSpec +from ndsl.quantity.metadata import QuantityHaloSpec, QuantityMetadata +from ndsl.types import NumpyModule class Quantity: From 0330cdbea8c65760bcb4f3541556b82b30bc6623 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 10 Dec 2024 09:56:01 -0500 Subject: [PATCH 5/8] Remove `cp.ndarray` since cupy is optional --- ndsl/quantity/quantity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index a7d89df..35c232e 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -231,7 +231,7 @@ def view(self) -> BoundedArrayView: return self._compute_domain_view @property - def data(self) -> Union[np.ndarray, cupy.ndarray]: + def data(self) -> np.ndarray: """the underlying array of data""" return self._data From 18b2f3f8ed38b456e28e003b0a9735d498ee5599 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 10 Dec 2024 11:00:56 -0500 Subject: [PATCH 6/8] Restore workaround for optional cupy --- ndsl/quantity/quantity.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 35c232e..5e8b317 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -13,6 +13,10 @@ from ndsl.types import NumpyModule +if cupy is None: + import numpy as cupy + + class Quantity: """ Data container for physical quantities. @@ -231,7 +235,7 @@ def view(self) -> BoundedArrayView: return self._compute_domain_view @property - def data(self) -> np.ndarray: + def data(self) -> Union[np.ndarray, cupy.ndarray]: """the underlying array of data""" return self._data From a8a7c85accfac3ce62c992984008fe27b450756d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 10 Dec 2024 11:08:00 -0500 Subject: [PATCH 7/8] Cupy trick for metadata --- ndsl/quantity/metadata.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index 611cf90..e859139 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -7,6 +7,10 @@ from ndsl.types import NumpyModule +if cupy is None: + import numpy as cupy + + @dataclasses.dataclass class QuantityMetadata: origin: Tuple[int, ...] From fcfb0584540eeaee558ac26bd40be74ceb47ca4d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 7 Jan 2025 13:52:33 -0500 Subject: [PATCH 8/8] Fix merge --- ndsl/quantity/metadata.py | 9 +++++++++ ndsl/quantity/quantity.py | 9 --------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ndsl/quantity/metadata.py b/ndsl/quantity/metadata.py index e859139..d7ddba0 100644 --- a/ndsl/quantity/metadata.py +++ b/ndsl/quantity/metadata.py @@ -45,6 +45,15 @@ def np(self) -> NumpyModule: f"quantity underlying data is of unexpected type {self.data_type}" ) + def duplicate_metadata(self, metadata_copy): + metadata_copy.origin = self.origin + metadata_copy.extent = self.extent + metadata_copy.dims = self.dims + metadata_copy.units = self.units + metadata_copy.data_type = self.data_type + metadata_copy.dtype = self.dtype + metadata_copy.gt4py_backend = self.gt4py_backend + @dataclasses.dataclass class QuantityHaloSpec: diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 6ddc90e..c88ba14 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -262,15 +262,6 @@ def data_array(self) -> xr.DataArray: def np(self) -> NumpyModule: return self.metadata.np - def duplicate_metadata(self, metadata_copy): - metadata_copy.origin = self.origin - metadata_copy.extent = self.extent - metadata_copy.dims = self.dims - metadata_copy.units = self.units - metadata_copy.data_type = self.data_type - metadata_copy.dtype = self.dtype - metadata_copy.gt4py_backend = self.gt4py_backend - @property def __array_interface__(self): return self.data.__array_interface__