Skip to content

Commit 5dc516e

Browse files
Merge pull request #343 from egraphs-good/set-cost
Add support for `set_cost` action to have row level costs for extraction
2 parents d3761d6 + 2dd4fe1 commit 5dc516e

File tree

9 files changed

+199
-8
lines changed

9 files changed

+199
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,4 @@ Source.*
8484
inlined
8585
visualizer.tgz
8686
package
87+
.mypy_cache/

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Add support for `set_cost` action to have row level costs for extraction [#343](https://github.com/egraphs-good/egglog-python/pull/343)
78
- Add `egraph.function_values(fn)` to export all function values like `print-function` [#340](https://github.com/egraphs-good/egglog-python/pull/340)
89
- Add `egraph.stats()` method to print overall stats [#339](https://github.com/egraphs-good/egglog-python/pull/339)
910
- Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338)

docs/reference/egglog-translation.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,23 @@ except BaseException as e:
268268
print(e)
269269
```
270270

271+
### Set Cost
272+
273+
You can also set the cost of individual values, as in the egglog-experimental `set-cost` feature, to override the default cost of the function that constructs them:
274+
275+
```{code-cell} python
276+
# egg: (set-cost (fib 0) 1)
277+
egraph.register(set_cost(fib(0), 1))
278+
```
279+
280+
This will be taken into account when extracting. Any value that can be converted to an `i64` is supported as a cost,
281+
so dynamic costs can be created in rules.
282+
283+
It does this by creating, for each function you set a cost on, a new table that maps that function's arguments to an `i64` cost.
284+
285+
_Note: Unlike in egglog, where you have to declare which functions support custom costs, in Python a custom cost
286+
table is created automatically for a function the first time a cost is set on one of its calls._
287+
271288
## Defining Rules
272289

273290
To define rules in Python, we create a rule with `rule(*facts).then(*actions)`, which corresponds to the `(rule ...)` command in egglog.

python/egglog/declarations.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"SaturateDecl",
7070
"ScheduleDecl",
7171
"SequenceDecl",
72+
"SetCostDecl",
7273
"SetDecl",
7374
"SpecialFunctions",
7475
"TypeOrVarRef",
@@ -854,7 +855,14 @@ class PanicDecl:
854855
msg: str
855856

856857

857-
ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
858+
@dataclass(frozen=True)
859+
class SetCostDecl:
860+
tp: JustTypeRef
861+
expr: CallDecl
862+
cost: ExprDecl
863+
864+
865+
ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl | SetCostDecl
858866

859867

860868
##

python/egglog/egraph.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from .version_compat import *
4141

4242
if TYPE_CHECKING:
43-
from .builtins import String, Unit
43+
from .builtins import String, Unit, i64Like
4444

4545

4646
__all__ = [
@@ -84,6 +84,7 @@
8484
"run",
8585
"seq",
8686
"set_",
87+
"set_cost",
8788
"subsume",
8889
"union",
8990
"unstable_combine_rulesets",
@@ -985,8 +986,14 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
985986
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
986987
self._add_decls(expr)
987988
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
989+
# If we have defined any cost tables use the custom extraction
990+
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
991+
if self._state.cost_callables:
992+
cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args))
993+
else:
994+
cmd = bindings.Extract(span(2), *args)
988995
try:
989-
return self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))[0]
996+
return self._egraph.run_program(cmd)[0]
990997
except BaseException as e:
991998
raise add_note("Extracting: " + str(expr), e) # noqa: B904
992999

@@ -1460,10 +1467,13 @@ def __bool__(self) -> bool:
14601467
"""
14611468
Returns True if the two sides of an equality are structurally equal.
14621469
"""
1463-
if not isinstance(self.fact, EqDecl):
1464-
msg = "Can only check equality facts"
1465-
raise TypeError(msg)
1466-
return self.fact.left == self.fact.right
1470+
match self.fact:
1471+
case EqDecl(_, left, right):
1472+
return left == right
1473+
case ExprFactDecl(TypedExprDecl(_, CallDecl(FunctionRef("!="), (left_tp, right_tp)))):
1474+
return left_tp != right_tp
1475+
msg = f"Can only check equality for == or != not {self}"
1476+
raise ValueError(msg)
14671477

14681478

14691479
@dataclass
@@ -1511,6 +1521,18 @@ def panic(message: str) -> Action:
15111521
return Action(Declarations(), PanicDecl(message))
15121522

15131523

1524+
def set_cost(expr: BaseExpr, cost: i64Like) -> Action:
1525+
"""Set the cost of the given expression."""
1526+
from .builtins import i64 # noqa: PLC0415
1527+
1528+
expr_runtime = to_runtime_expr(expr)
1529+
typed_expr_decl = expr_runtime.__egg_typed_expr__
1530+
expr_decl = typed_expr_decl.expr
1531+
assert isinstance(expr_decl, CallDecl), "Can only set cost of calls, not literals or vars"
1532+
cost_decl = to_runtime_expr(convert(cost, i64)).__egg_typed_expr__.expr
1533+
return Action(expr_runtime.__egg_decls__, SetCostDecl(typed_expr_decl.tp, expr_decl, cost_decl))
1534+
1535+
15141536
def let(name: str, expr: BaseExpr) -> Action:
15151537
"""Create a let binding."""
15161538
runtime_expr = to_runtime_expr(expr)

python/egglog/egraph_state.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import re
88
from collections import defaultdict
9-
from dataclasses import dataclass, field
9+
from dataclasses import dataclass, field, replace
1010
from typing import TYPE_CHECKING, Literal, overload
1111

1212
from typing_extensions import assert_never
@@ -71,6 +71,9 @@ class EGraphState:
7171
# Cache of egg expressions for converting to egg
7272
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
7373

74+
# Callables which have cost tables associated with them
75+
cost_callables: set[CallableRef] = field(default_factory=set)
76+
7477
def copy(self) -> EGraphState:
7578
"""
7679
Returns a copy of the state. The egraph reference is kept the same. Used for pushing/popping.
@@ -83,6 +86,7 @@ def copy(self) -> EGraphState:
8386
callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
8487
type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
8588
expr_to_egg_cache=self.expr_to_egg_cache.copy(),
89+
cost_callables=self.cost_callables.copy(),
8690
)
8791

8892
def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
@@ -212,9 +216,32 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin
212216
return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs))
213217
case PanicDecl(name):
214218
return bindings.Panic(span(), name)
219+
case SetCostDecl(tp, expr, cost):
220+
self.type_ref_to_egg(tp)
221+
cost_table = self.create_cost_table(expr.callable)
222+
args_egg = [self.typed_expr_to_egg(x, False) for x in expr.args]
223+
return bindings.Set(span(), cost_table, args_egg, self._expr_to_egg(cost))
215224
case _:
216225
assert_never(action)
217226

227+
def create_cost_table(self, ref: CallableRef) -> str:
228+
"""
229+
Creates the egg cost table if needed and gets the name of the table.
230+
"""
231+
name = self.cost_table_name(ref)
232+
if ref not in self.cost_callables:
233+
self.cost_callables.add(ref)
234+
signature = self.__egg_decls__.get_callable_decl(ref).signature
235+
assert isinstance(signature, FunctionSignature), "Can only add cost tables for functions"
236+
signature = replace(signature, return_type=TypeRefWithVars("i64"))
237+
self.egraph.run_program(
238+
bindings.FunctionCommand(span(), name, self._signature_to_egg_schema(signature), None)
239+
)
240+
return name
241+
242+
def cost_table_name(self, ref: CallableRef) -> str:
243+
return f"cost_table_{self.callable_ref_to_egg(ref)[0]}"
244+
218245
def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
219246
match fact:
220247
case EqDecl(tp, left, right):
@@ -350,11 +377,16 @@ def op_mapping(self) -> dict[str, str]:
350377
"""
351378
Create a mapping of egglog function name to Python function name, for use in the serialized format
352379
for better visualization.
380+
381+
Includes cost tables
353382
"""
354383
return {
355384
k: pretty_callable_ref(self.__egg_decls__, next(iter(v)))
356385
for k, v in self.egg_fn_to_callable_refs.items()
357386
if len(v) == 1
387+
} | {
388+
self.cost_table_name(ref): f"cost({pretty_callable_ref(self.__egg_decls__, ref, include_all_args=True)})"
389+
for ref in self.cost_callables
358390
}
359391

360392
def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:

python/egglog/examples/jointree.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# mypy: disable-error-code="empty-body"
2+
3+
"""
4+
Join Tree (custom costs)
5+
========================
6+
7+
Example of using custom cost functions for jointree.
8+
9+
From https://egraphs.zulipchat.com/#narrow/stream/328972-general/topic/How.20can.20I.20find.20the.20tree.20associated.20with.20an.20extraction.3F
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from egglog import *
15+
16+
17+
class JoinTree(Expr):
18+
def __init__(self, name: StringLike) -> None: ...
19+
20+
def join(self, other: JoinTree) -> JoinTree: ...
21+
22+
@method(merge=lambda old, new: old.min(new)) # type:ignore[prop-decorator]
23+
@property
24+
def size(self) -> i64: ...
25+
26+
27+
ra = JoinTree("a")
28+
rb = JoinTree("b")
29+
rc = JoinTree("c")
30+
rd = JoinTree("d")
31+
re = JoinTree("e")
32+
rf = JoinTree("f")
33+
34+
query = ra.join(rb).join(rc).join(rd).join(re).join(rf)
35+
36+
egraph = EGraph()
37+
egraph.register(
38+
set_(ra.size).to(50),
39+
set_(rb.size).to(200),
40+
set_(rc.size).to(10),
41+
set_(rd.size).to(123),
42+
set_(re.size).to(10000),
43+
set_(rf.size).to(1),
44+
)
45+
46+
47+
@egraph.register
48+
def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: i64):
49+
# cost of relation is its size minus 1, since the string arg will have a cost of 1 as well
50+
yield rule(JoinTree(s).size == asize).then(set_cost(JoinTree(s), asize - 1))
51+
# cost/size of join is product of sizes
52+
yield rule(a.join(b), a.size == asize, b.size == bsize).then(
53+
set_(a.join(b).size).to(asize * bsize), set_cost(a.join(b), asize * bsize)
54+
)
55+
# associativity
56+
yield rewrite(a.join(b)).to(b.join(a))
57+
# commutativity
58+
yield rewrite(a.join(b).join(c)).to(a.join(b.join(c)))
59+
60+
61+
egraph.register(query)
62+
egraph.run(1000)
63+
print(egraph.extract(query))
64+
print(egraph.extract(query.size))
65+
66+
67+
egraph

python/egglog/pretty.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def pretty_callable_ref(
9898
ref: CallableRef,
9999
first_arg: ExprDecl | None = None,
100100
bound_tp_params: tuple[JustTypeRef, ...] | None = None,
101+
include_all_args: bool = False,
101102
) -> str:
102103
"""
103104
Pretty print a callable reference, using a dummy value for
@@ -115,6 +116,13 @@ def pretty_callable_ref(
115116
# Either returns a function or a function with args. If args are provided, they would just be called,
116117
# on the function, so return them, because they are dummies
117118
if isinstance(res, tuple):
119+
# If we want to include all args as ARG_STR, then we need to figure out how many to use
120+
# used for set_cost so that `cost(E(...))` will show up as a call
121+
if include_all_args:
122+
signature = decls.get_callable_decl(ref).signature
123+
assert isinstance(signature, FunctionSignature)
124+
correct_args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * len(signature.arg_types)
125+
return f"{res[0]}({', '.join(context(a, parens=False, unwrap_lit=True) for a in correct_args)})"
118126
return res[0]
119127
return res
120128

@@ -190,6 +198,9 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90
190198
pass
191199
case DefaultRewriteDecl():
192200
pass
201+
case SetCostDecl(_, e, c):
202+
self(e)
203+
self(c)
193204
case _:
194205
assert_never(decl)
195206

@@ -285,6 +296,8 @@ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_na
285296
return f"{change}({self(expr)})", "action"
286297
case PanicDecl(s):
287298
return f"panic({s!r})", "action"
299+
case SetCostDecl(_, expr, cost):
300+
return f"set_cost({self(expr)}, {self(cost, unwrap_lit=True)})", "action"
288301
case EqDecl(_, left, right):
289302
return f"eq({self(left)}).to({self(right)})", "fact"
290303
case RulesetDecl(rules):

python/tests/test_high_level.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,3 +1063,33 @@ def f(x: i64Like) -> i64: ...
10631063
egraph.register(set_(f(i64(1))).to(i64(2)))
10641064
values = egraph.function_values(f)
10651065
assert values == {f(i64(1)): i64(2)}
1066+
1067+
1068+
def test_dynamic_cost():
1069+
"""
1070+
https://github.com/egraphs-good/egglog-experimental/blob/6d07a34ac76deec751f86f70d9b9358cd3e236ca/tests/integration_test.rs#L5-L35
1071+
"""
1072+
1073+
class E(Expr):
1074+
def __init__(self, x: i64Like) -> None: ...
1075+
def __add__(self, other: E) -> E: ...
1076+
@method(cost=200)
1077+
def __sub__(self, other: E) -> E: ...
1078+
1079+
egraph = EGraph()
1080+
egraph.register(
1081+
union(E(2)).with_(E(1) + E(1)),
1082+
set_cost(E(2), 1000),
1083+
set_cost(E(1), 100),
1084+
)
1085+
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203)
1086+
with egraph:
1087+
egraph.register(set_cost(E(1) + E(1), 800))
1088+
assert egraph.extract(E(2), include_cost=True) == (E(2), 1001)
1089+
with egraph:
1090+
egraph.register(set_cost(E(1) + E(1), 798))
1091+
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 1000)
1092+
egraph.register(union(E(2)).with_(E(5) - E(3)))
1093+
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203)
1094+
egraph.register(set_cost(E(5) - E(3), 198))
1095+
assert egraph.extract(E(2), include_cost=True) == (E(5) - E(3), 202)

0 commit comments

Comments
 (0)