diff --git a/CHANGES b/CHANGES index 6d34f7476..0eaa658ed 100644 --- a/CHANGES +++ b/CHANGES @@ -4,7 +4,7 @@ Pint Changelog 0.23 (unreleased) ----------------- -- Nothing changed yet. +- Defer import of `dask.array` to speed up import time of pint (PR #1795) 0.22 (2023-05-25) diff --git a/pint/compat.py b/pint/compat.py index 6be906f4d..b80b0f5c1 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -203,14 +203,6 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False): # Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast # types using guarded imports -try: - from dask import array as dask_array - from dask.base import compute, persist, visualize -except ImportError: - compute, persist, visualize = None, None, None - dask_array = None - - # TODO: merge with upcast_type_map #: List upcast type names diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py index 8d62f55d7..91d4086c9 100644 --- a/pint/facets/dask/__init__.py +++ b/pint/facets/dask/__init__.py @@ -14,7 +14,7 @@ from typing import Generic, Any import functools -from ...compat import compute, dask_array, persist, visualize, TypeAlias +from ...compat import TypeAlias from ..plain import ( GenericPlainRegistry, PlainQuantity, @@ -25,14 +25,20 @@ ) +def is_dask_array(obj): + return type(obj).__name__ == "Array" and "dask" == type(obj).__module__[:4] + + def check_dask_array(f): @functools.wraps(f) def wrapper(self, *args, **kwargs): - if isinstance(self._magnitude, dask_array.Array): + if is_dask_array(self._magnitude): return f(self, *args, **kwargs) else: - msg = "Method {} only implemented for objects of {}, not {}".format( - f.__name__, dask_array.Array, self._magnitude.__class__ + msg = ( + "Method {} only implemented for objects of dask array, not {}.".format( + f.__name__, self._magnitude.__class__.__name__ + ) ) raise AttributeError(msg) @@ -42,7 +48,9 @@ def wrapper(self, *args, **kwargs): class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): # Dask.array.Array ducking def __dask_graph__(self): - if isinstance(self._magnitude, dask_array.Array): + import dask.array as da + + if isinstance(self._magnitude, da.Array): return self._magnitude.__dask_graph__() return None @@ -57,11 +65,15 @@ def __dask_tokenize__(self): @property def __dask_optimize__(self): - return dask_array.Array.__dask_optimize__ + import dask.array as da + + return da.Array.__dask_optimize__ @property def __dask_scheduler__(self): - return dask_array.Array.__dask_scheduler__ + import dask.array as da + + return da.Array.__dask_scheduler__ def __dask_postcompute__(self): func, args = self._magnitude.__dask_postcompute__() @@ -89,6 +101,8 @@ def compute(self, **kwargs): pint.PlainQuantity A pint.PlainQuantity wrapped numpy array. """ + from dask.base import compute + (result,) = compute(self, **kwargs) return result @@ -106,6 +120,8 @@ def persist(self, **kwargs): pint.PlainQuantity A pint.PlainQuantity wrapped Dask array. """ + from dask.base import persist + (result,) = persist(self, **kwargs) return result @@ -124,6 +140,8 @@ def visualize(self, **kwargs): ------- """ + from dask.base import visualize + visualize(self, **kwargs) diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py index 0e6a1cfe7..2cfc3468c 100644 --- a/pint/testsuite/test_dask.py +++ b/pint/testsuite/test_dask.py @@ -162,9 +162,7 @@ def test_exception_method_not_implemented(local_registry, numpy_array, method): q = local_registry.Quantity(numpy_array, units_) exctruth = ( - f"Method {method} only implemented for objects of" - " , not" - " " + f"Method {method} only implemented for objects of" " dask array, not ndarray." ) with pytest.raises(AttributeError, match=exctruth): obj_method = getattr(q, method)