From 1f03a07c4d377ac2172279708d8a3b445943fbdf Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 11 Jan 2024 10:46:43 -0800 Subject: [PATCH] basic exec protocol --- dask_expr/_collection.py | 13 +++++++++++++ dask_expr/_core.py | 5 +++++ dask_expr/_expr.py | 4 ++++ dask_expr/_groupby.py | 9 +++++++++ dask_expr/_merge.py | 5 +++++ dask_expr/_shuffle.py | 12 ++++++++++++ dask_expr/io/io.py | 6 ++++++ dask_expr/tests/test_collection.py | 8 ++++++++ 8 files changed, 62 insertions(+) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index 71dc25d18..294dd9b2e 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -292,6 +292,19 @@ def compute(self, fuse=True, **kwargs): def dask(self): return self.__dask_graph__() + def exec(self, simplify=True): + """Directly execute with the backend DataFrame library + + WARNING: This is an experimental feature. Use at your own risk. + + This function will NOT convert the expression to a task graph + and execute with dask. Instead, the backend library will be + used to execute the logic defined by the ``Expr.__exec__`` + protocols directly. + """ + out = self.simplify() if simplify else self + return out.expr.__exec__() + def __dask_graph__(self): out = self.expr out = out.lower_completely() diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 203a9f30f..696d28976 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -46,6 +46,11 @@ def __init__(self, *args, **kwargs): # avoid infinite recursion raise ValueError(f"{dep} has no attribute {self._required_attribute}") + def __exec__(self): + raise NotImplementedError( + f"Backend exec is not yet supported for {type(self)}." + ) + @property def _required_attribute(self) -> str: # Specify if the first `dependency` must support diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 88415845a..31bece13e 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -452,6 +452,10 @@ def _meta(self): args = [op._meta if isinstance(op, Expr) else op for op in self._args] return self.operation(*args, **self._kwargs) + def __exec__(self): + args = [op.__exec__() if isinstance(op, Expr) else op for op in self._args] + return self.operation(*args, **self._kwargs) + @functools.cached_property def _kwargs(self) -> dict: if self._keyword_only: diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index d7f460148..9f329f67e 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -286,6 +286,15 @@ class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase): } chunk = staticmethod(_groupby_apply_funcs) + def __exec__(self): + frame = self.frame.__exec__() + kwargs = { + "sort": self.sort, + **_as_dict("observed", self.observed), + **_as_dict("dropna", self.dropna), + } + return frame.groupby(self.by, **kwargs).aggregate(self.arg) + @functools.cached_property def spec(self): # Converts the `arg` operand into specific diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 5ade7353c..04799a991 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -93,6 +93,11 @@ def _meta(self): right = meta_nonempty(self.right._meta) return make_meta(left.merge(right, **self.kwargs)) + def __exec__(self): + left = self.left.__exec__() + right = self.right.__exec__() + return left.merge(right, **self.kwargs) + @functools.cached_property def _npartitions(self): if self.operand("_npartitions") is not None: diff --git a/dask_expr/_shuffle.py b/dask_expr/_shuffle.py index 9a8a31519..d362e4d1e 100644 --- a/dask_expr/_shuffle.py +++ b/dask_expr/_shuffle.py @@ -783,6 +783,14 @@ def _meta(self): other = self._other return self.frame._meta.set_index(other, drop=self.drop) + def __exec__(self): + frame = self.frame.__exec__() + if isinstance(self._other, Expr): + other = self._other.__exec__() + else: + other = self._other + return frame.set_index(other, drop=self.drop) + @property def _divisions_column(self): return self.other @@ -951,6 +959,10 @@ def sort_function_kwargs(self): def _meta(self): return self.frame._meta + def __exec__(self): + frame = self.frame.__exec__() + return self.sort_function(frame, **self.sort_function_kwargs) + @functools.cached_property def _meta_by_dtype(self): dtype = self._meta.dtypes[self.by] diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 081473ad9..950bb8798 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -337,6 +337,12 @@ def _meta(self): return meta[self.columns[0]] if self._series else meta[self.columns] return meta + def __exec__(self): + pdf = self.operand("frame")._data + if self.columns: + return pdf[self.columns[0]] if self._series else pdf[self.columns] + return pdf + @functools.cached_property def columns(self): columns_operand = self.operand("columns") diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 14fa4c767..c7df7c25a 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -2138,3 +2138,11 @@ def test_axes(df, pdf): [assert_eq(d, p) for d, p in zip(df.axes, pdf.axes)] assert len(df.x.axes) == len(pdf.x.axes) assert_eq(df.x.axes[0], pdf.x.axes[0]) + + +def test_exec(): + pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 6], "b": 1, "c": 2}) + df = from_pandas(pdf.copy()) + result = (df + 1).sort_values("a")["a"] + result_pd = (pdf + 1).sort_values("a")["a"] + assert_eq(result.exec(), result_pd)