1414from typing import Generic , Any
1515import functools
1616
17- from ...compat import compute , dask_array , persist , visualize , TypeAlias
17+ from ...compat import TypeAlias
1818from ..plain import (
1919 GenericPlainRegistry ,
2020 PlainQuantity ,
2525)
2626
2727
28+ def is_dask_array (obj ):
29+ return type (obj ).__name__ == "Array" and "dask" == type (obj ).__module__ [:4 ]
30+
31+
2832def check_dask_array (f ):
2933 @functools .wraps (f )
3034 def wrapper (self , * args , ** kwargs ):
31- if isinstance (self ._magnitude , dask_array . Array ):
35+ if is_dask_array (self ._magnitude ):
3236 return f (self , * args , ** kwargs )
3337 else :
34- msg = "Method {} only implemented for objects of {}, not {}" .format (
35- f .__name__ , dask_array .Array , self ._magnitude .__class__
38+ msg = (
39+ "Method {} only implemented for objects of dask array, not {}." .format (
40+ f .__name__ , self ._magnitude .__class__ .__name__
41+ )
3642 )
3743 raise AttributeError (msg )
3844
@@ -42,7 +48,9 @@ def wrapper(self, *args, **kwargs):
4248class DaskQuantity (Generic [MagnitudeT ], PlainQuantity [MagnitudeT ]):
4349 # Dask.array.Array ducking
4450 def __dask_graph__ (self ):
45- if isinstance (self ._magnitude , dask_array .Array ):
51+ import dask .array as da
52+
53+ if isinstance (self ._magnitude , da .Array ):
4654 return self ._magnitude .__dask_graph__ ()
4755
4856 return None
@@ -57,11 +65,15 @@ def __dask_tokenize__(self):
5765
5866 @property
5967 def __dask_optimize__ (self ):
60- return dask_array .Array .__dask_optimize__
68+ import dask .array as da
69+
70+ return da .Array .__dask_optimize__
6171
6272 @property
6373 def __dask_scheduler__ (self ):
64- return dask_array .Array .__dask_scheduler__
74+ import dask .array as da
75+
76+ return da .Array .__dask_scheduler__
6577
6678 def __dask_postcompute__ (self ):
6779 func , args = self ._magnitude .__dask_postcompute__ ()
@@ -89,6 +101,8 @@ def compute(self, **kwargs):
89101 pint.PlainQuantity
90102 A pint.PlainQuantity wrapped numpy array.
91103 """
104+ from dask .base import compute
105+
92106 (result ,) = compute (self , ** kwargs )
93107 return result
94108
@@ -106,6 +120,8 @@ def persist(self, **kwargs):
106120 pint.PlainQuantity
107121 A pint.PlainQuantity wrapped Dask array.
108122 """
123+ from dask .base import persist
124+
109125 (result ,) = persist (self , ** kwargs )
110126 return result
111127
@@ -124,6 +140,8 @@ def visualize(self, **kwargs):
124140 -------
125141
126142 """
143+ from dask .base import visualize
144+
127145 visualize (self , ** kwargs )
128146
129147
0 commit comments