Skip to content

Commit a2ecf1c

Browse files
committed
Defer import of dask.array
1 parent 3fb63af commit a2ecf1c

File tree

3 files changed

+26
-18
lines changed

3 files changed

+26
-18
lines changed

pint/compat.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,14 +203,6 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False):
203203
# Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast
204204
# types using guarded imports
205205

206-
try:
207-
from dask import array as dask_array
208-
from dask.base import compute, persist, visualize
209-
except ImportError:
210-
compute, persist, visualize = None, None, None
211-
dask_array = None
212-
213-
214206
# TODO: merge with upcast_type_map
215207

216208
#: List upcast type names

pint/facets/dask/__init__.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Generic, Any
1515
import functools
1616

17-
from ...compat import compute, dask_array, persist, visualize, TypeAlias
17+
from ...compat import TypeAlias
1818
from ..plain import (
1919
GenericPlainRegistry,
2020
PlainQuantity,
@@ -25,14 +25,20 @@
2525
)
2626

2727

28+
def is_dask_array(obj):
29+
return type(obj).__name__ == "Array" and "dask" == type(obj).__module__[:4]
30+
31+
2832
def check_dask_array(f):
2933
@functools.wraps(f)
3034
def wrapper(self, *args, **kwargs):
31-
if isinstance(self._magnitude, dask_array.Array):
35+
if is_dask_array(self._magnitude):
3236
return f(self, *args, **kwargs)
3337
else:
34-
msg = "Method {} only implemented for objects of {}, not {}".format(
35-
f.__name__, dask_array.Array, self._magnitude.__class__
38+
msg = (
39+
"Method {} only implemented for objects of dask array, not {}.".format(
40+
f.__name__, self._magnitude.__class__.__name__
41+
)
3642
)
3743
raise AttributeError(msg)
3844

@@ -42,7 +48,9 @@ def wrapper(self, *args, **kwargs):
4248
class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
4349
# Dask.array.Array ducking
4450
def __dask_graph__(self):
45-
if isinstance(self._magnitude, dask_array.Array):
51+
import dask.array as da
52+
53+
if isinstance(self._magnitude, da.Array):
4654
return self._magnitude.__dask_graph__()
4755

4856
return None
@@ -57,11 +65,15 @@ def __dask_tokenize__(self):
5765

5866
@property
5967
def __dask_optimize__(self):
60-
return dask_array.Array.__dask_optimize__
68+
import dask.array as da
69+
70+
return da.Array.__dask_optimize__
6171

6272
@property
6373
def __dask_scheduler__(self):
64-
return dask_array.Array.__dask_scheduler__
74+
import dask.array as da
75+
76+
return da.Array.__dask_scheduler__
6577

6678
def __dask_postcompute__(self):
6779
func, args = self._magnitude.__dask_postcompute__()
@@ -89,6 +101,8 @@ def compute(self, **kwargs):
89101
pint.PlainQuantity
90102
A pint.PlainQuantity wrapped numpy array.
91103
"""
104+
from dask.base import compute
105+
92106
(result,) = compute(self, **kwargs)
93107
return result
94108

@@ -106,6 +120,8 @@ def persist(self, **kwargs):
106120
pint.PlainQuantity
107121
A pint.PlainQuantity wrapped Dask array.
108122
"""
123+
from dask.base import persist
124+
109125
(result,) = persist(self, **kwargs)
110126
return result
111127

@@ -124,6 +140,8 @@ def visualize(self, **kwargs):
124140
-------
125141
126142
"""
143+
from dask.base import visualize
144+
127145
visualize(self, **kwargs)
128146

129147

pint/testsuite/test_dask.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@ def test_exception_method_not_implemented(local_registry, numpy_array, method):
162162
q = local_registry.Quantity(numpy_array, units_)
163163

164164
exctruth = (
165-
f"Method {method} only implemented for objects of"
166-
" <class 'dask.array.core.Array'>, not"
167-
" <class 'numpy.ndarray'>"
165+
f"Method {method} only implemented for objects of" " dask array, not ndarray."
168166
)
169167
with pytest.raises(AttributeError, match=exctruth):
170168
obj_method = getattr(q, method)

0 commit comments

Comments
 (0)