Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defer import of dask.array to speed up import #1795

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Pint Changelog
0.23 (unreleased)
-----------------

- Nothing changed yet.
- Defer import of `dask.array` to speed up import time of pint (PR #1795)


0.22 (2023-05-25)
Expand Down
8 changes: 0 additions & 8 deletions pint/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,6 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
# Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast
# types using guarded imports

try:
from dask import array as dask_array
from dask.base import compute, persist, visualize
except ImportError:
compute, persist, visualize = None, None, None
dask_array = None


# TODO: merge with upcast_type_map

#: List upcast type names
Expand Down
32 changes: 25 additions & 7 deletions pint/facets/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Generic, Any
import functools

from ...compat import compute, dask_array, persist, visualize, TypeAlias
from ...compat import TypeAlias
from ..plain import (
GenericPlainRegistry,
PlainQuantity,
Expand All @@ -25,14 +25,20 @@
)


def is_dask_array(obj):
return type(obj).__name__ == "Array" and "dask" == type(obj).__module__[:4]


def check_dask_array(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
if isinstance(self._magnitude, dask_array.Array):
if is_dask_array(self._magnitude):
return f(self, *args, **kwargs)
else:
msg = "Method {} only implemented for objects of {}, not {}".format(
f.__name__, dask_array.Array, self._magnitude.__class__
msg = (
"Method {} only implemented for objects of dask array, not {}.".format(
f.__name__, self._magnitude.__class__.__name__
)
)
raise AttributeError(msg)

Expand All @@ -42,7 +48,9 @@ def wrapper(self, *args, **kwargs):
class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
# Dask.array.Array ducking
def __dask_graph__(self):
if isinstance(self._magnitude, dask_array.Array):
import dask.array as da

if isinstance(self._magnitude, da.Array):
return self._magnitude.__dask_graph__()

return None
Expand All @@ -57,11 +65,15 @@ def __dask_tokenize__(self):

@property
def __dask_optimize__(self):
return dask_array.Array.__dask_optimize__
import dask.array as da

return da.Array.__dask_optimize__

@property
def __dask_scheduler__(self):
return dask_array.Array.__dask_scheduler__
import dask.array as da

return da.Array.__dask_scheduler__

def __dask_postcompute__(self):
func, args = self._magnitude.__dask_postcompute__()
Expand Down Expand Up @@ -89,6 +101,8 @@ def compute(self, **kwargs):
pint.PlainQuantity
A pint.PlainQuantity wrapped numpy array.
"""
from dask.base import compute

(result,) = compute(self, **kwargs)
return result

Expand All @@ -106,6 +120,8 @@ def persist(self, **kwargs):
pint.PlainQuantity
A pint.PlainQuantity wrapped Dask array.
"""
from dask.base import persist

(result,) = persist(self, **kwargs)
return result

Expand All @@ -124,6 +140,8 @@ def visualize(self, **kwargs):
-------

"""
from dask.base import visualize

visualize(self, **kwargs)


Expand Down
4 changes: 1 addition & 3 deletions pint/testsuite/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ def test_exception_method_not_implemented(local_registry, numpy_array, method):
q = local_registry.Quantity(numpy_array, units_)

exctruth = (
f"Method {method} only implemented for objects of"
" <class 'dask.array.core.Array'>, not"
" <class 'numpy.ndarray'>"
f"Method {method} only implemented for objects of" " dask array, not ndarray."
)
with pytest.raises(AttributeError, match=exctruth):
obj_method = getattr(q, method)
Expand Down