Skip to content

Commit e1cf1cd

Browse files
committed
add top-k utility
1 parent ed72668 commit e1cf1cd

File tree

2 files changed

+191
-2
lines changed

2 files changed

+191
-2
lines changed

ibis/backends/materialize/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import ibis
1111
import ibis.expr.operations as ops
1212
import ibis.expr.types as ir
13-
from ibis.backends.materialize.api import mz_now
13+
from ibis.backends.materialize.api import mz_now, mz_top_k
1414
from ibis.backends.postgres import Backend as PostgresBackend
1515
from ibis.backends.sql.compilers.materialize import MaterializeCompiler
1616

17-
__all__ = ("Backend", "mz_now")
17+
__all__ = ("Backend", "mz_now", "mz_top_k")
1818

1919

2020
class Backend(PostgresBackend):

ibis/backends/materialize/api.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import ibis
56
import ibis.expr.types as ir
67
from ibis.backends.materialize import operations as mz_ops
78

@@ -67,3 +68,191 @@ def mz_now() -> ir.TimestampScalar:
6768
- Idiomatic patterns: https://materialize.com/docs/transform-data/idiomatic-materialize-sql/#temporal-filters
6869
"""
6970
return mz_ops.MzNow().to_expr()
71+
72+
73+
def mz_top_k(
74+
table: ir.Table,
75+
k: int,
76+
by: list[str] | str,
77+
order_by: list[str] | str | list[tuple[str, bool]],
78+
desc: bool = True,
79+
group_size: int | None = None,
80+
) -> ir.Table:
81+
"""Get top-k rows per group using idiomatic Materialize SQL.
82+
83+
Parameters
84+
----------
85+
table : Table
86+
The input table
87+
k : int
88+
Number of rows to keep per group
89+
by : str or list of str
90+
Column(s) to group by (partition keys)
91+
order_by : str or list of str or list of (str, bool)
92+
Column(s) to order by within each group.
93+
If tuple, second element is True for DESC, False for ASC.
94+
desc : bool, default True
95+
Default sort direction when order_by is just column names
96+
group_size : int, optional
97+
Materialize-specific query hint to control memory usage.
98+
For k=1: Sets DISTINCT ON INPUT GROUP SIZE
99+
For k>1: Sets LIMIT INPUT GROUP SIZE
100+
Ignored for non-Materialize backends.
101+
102+
Returns
103+
-------
104+
Table
105+
Top k rows per group
106+
107+
Examples
108+
--------
109+
>>> import ibis
110+
>>> from ibis.backends.materialize.api import mz_top_k
111+
>>> con = ibis.materialize.connect(...)
112+
>>> orders = con.table("orders")
113+
>>>
114+
>>> # Top 3 items per order by subtotal
115+
>>> mz_top_k(orders, k=3, by="order_id", order_by="subtotal", desc=True)
116+
>>>
117+
>>> # Top seller per region (k=1 uses DISTINCT ON)
118+
>>> sales = con.table("sales")
119+
>>> mz_top_k(sales, k=1, by="region", order_by="total_sales")
120+
>>>
121+
>>> # Multiple order-by columns with explicit direction
122+
>>> events = con.table("events")
123+
>>> mz_top_k(
124+
... events,
125+
... k=10,
126+
... by="user_id",
127+
... order_by=[
128+
... ("priority", True), # DESC (high priority first)
129+
... ("timestamp", False) # ASC (oldest first)
130+
... ]
131+
... )
132+
>>>
133+
>>> # Use group_size hint to optimize memory usage
134+
>>> mz_top_k(
135+
... orders,
136+
... k=5,
137+
... by="customer_id",
138+
... order_by="order_date",
139+
... group_size=1000 # Hint: expect ~1000 orders per customer
140+
... )
141+
142+
Notes
143+
-----
144+
The `group_size` parameter helps Materialize optimize memory usage by
145+
providing an estimate of the expected number of rows per group. This is
146+
particularly useful for large datasets.
147+
148+
References
149+
----------
150+
https://materialize.com/docs/transform-data/idiomatic-materialize-sql/top-k/
151+
https://materialize.com/docs/transform-data/optimization/#query-hints
152+
"""
153+
from ibis.backends.materialize import Backend as MaterializeBackend
154+
155+
# Normalize inputs
156+
if isinstance(by, str):
157+
by = [by]
158+
159+
# Normalize order_by to list of (column, desc) tuples
160+
if isinstance(order_by, str):
161+
order_by = [(order_by, desc)]
162+
elif isinstance(order_by, list):
163+
if order_by and not isinstance(order_by[0], tuple):
164+
order_by = [(col, desc) for col in order_by]
165+
166+
backend = table._find_backend()
167+
168+
if isinstance(backend, MaterializeBackend):
169+
if k == 1:
170+
return _top_k_distinct_on(table, by, order_by, group_size)
171+
else:
172+
return _top_k_lateral(table, k, by, order_by, group_size)
173+
else:
174+
return _top_k_generic(table, k, by, order_by)
175+
176+
177+
def _top_k_distinct_on(table, by, order_by, group_size):
178+
"""Use DISTINCT ON for k=1 in Materialize."""
179+
backend = table._find_backend()
180+
table_name = table.get_name()
181+
182+
# Build column lists
183+
by_cols = ", ".join(by)
184+
order_exprs = ", ".join(
185+
[f"{col} {'DESC' if desc else 'ASC'}" for col, desc in order_by]
186+
)
187+
188+
# Add group size hint if provided
189+
options_clause = ""
190+
if group_size is not None:
191+
options_clause = f"\n OPTIONS (DISTINCT ON INPUT GROUP SIZE = {group_size})"
192+
193+
sql = f"""
194+
SELECT DISTINCT ON({by_cols}) *
195+
FROM {table_name}{options_clause}
196+
ORDER BY {by_cols}, {order_exprs}
197+
"""
198+
199+
return backend.sql(sql)
200+
201+
202+
def _top_k_lateral(table, k, by, order_by, group_size):
203+
"""Use LATERAL join pattern for k>1 in Materialize."""
204+
backend = table._find_backend()
205+
table_name = table.get_name()
206+
207+
# Build column lists
208+
by_cols = ", ".join(by)
209+
210+
# Get all columns except group by columns for the lateral select
211+
all_cols = list(table.columns)
212+
lateral_cols = [col for col in all_cols if col not in by]
213+
lateral_select = ", ".join(lateral_cols)
214+
215+
# Build WHERE clause for lateral join
216+
where_clause = " AND ".join([f"{col} = grp.{col}" for col in by])
217+
218+
# Build ORDER BY for lateral subquery
219+
lateral_order = ", ".join(
220+
[f"{col} {'DESC' if desc else 'ASC'}" for col, desc in order_by]
221+
)
222+
223+
# Build final ORDER BY (group keys + order keys)
224+
final_order_cols = ", ".join(
225+
[f"{col} {'DESC' if desc else 'ASC'}" for col, desc in order_by]
226+
)
227+
228+
# Add group size hint if provided
229+
options_clause = ""
230+
if group_size is not None:
231+
options_clause = f"\n OPTIONS (LIMIT INPUT GROUP SIZE = {group_size})"
232+
233+
sql = f"""
234+
SELECT grp.{by_cols}, lateral_data.*
235+
FROM (SELECT DISTINCT {by_cols} FROM {table_name}) grp,
236+
LATERAL (
237+
SELECT {lateral_select}
238+
FROM {table_name}
239+
WHERE {where_clause}{options_clause}
240+
ORDER BY {lateral_order}
241+
LIMIT {k}
242+
) lateral_data
243+
ORDER BY {by_cols}, {final_order_cols}
244+
"""
245+
246+
return backend.sql(sql)
247+
248+
249+
def _top_k_generic(table, k, by, order_by):
250+
"""Generic ROW_NUMBER() implementation for non-Materialize backends."""
251+
# Build window function
252+
order_keys = [ibis.desc(col) if desc else ibis.asc(col) for col, desc in order_by]
253+
254+
return (
255+
table.mutate(_rn=ibis.row_number().over(group_by=by, order_by=order_keys))
256+
.filter(ibis._["_rn"] <= k)
257+
.drop("_rn")
258+
)

0 commit comments

Comments
 (0)