Skip to content

Commit d7e81a3

Browse files
Merge pull request #340 from egraphs-good/tree
Add `egraph.function_values(fn)` to export all function values like `print-function`
2 parents b083415 + c3b7cb4 commit d7e81a3

File tree

4 files changed

+52
-10
lines changed

4 files changed

+52
-10
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Add `egraph.function_values(fn)` to export all function values like `print-function` [#340](https://github.com/egraphs-good/egglog-python/pull/340)
78
- Add `egraph.stats()` method to print overall stats [#339](https://github.com/egraphs-good/egglog-python/pull/339)
89
- Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338)
910
- Fix execution of docs [#337](https://github.com/egraphs-good/egglog-python/pull/337)

docs/reference/egglog-translation.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,15 @@ The `(print-stats)` command is translated into `egraph.stats()` to get overall s
531531
egraph.stats()
532532
```
533533

534+
## Function Values
535+
536+
The `print-function` command is translated into `egraph.function_values(fn, [length]?)` to get the values of a specific function. Note that the function provided must either return a primitive or be created with a merge function.
537+
538+
```{code-cell} python
539+
# (print-function fib 3)
540+
egraph.function_values(fib, length=3)
541+
```
542+
534543
## Include
535544

536545
The `(include <path>)` command is used to add modularity, by allowing you to pull in the source from another egglog file into the current file.

python/egglog/egraph.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -852,12 +852,12 @@ def input(self, fn: Callable[..., String], path: str) -> None:
852852
"""
853853
Loads a CSV file and sets it as *input, output of the function.
854854
"""
855-
self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn), path))
855+
self._egraph.run_program(bindings.Input(span(1), self._callable_to_egg(fn)[1], path))
856856

857-
def _callable_to_egg(self, fn: ExprCallable) -> str:
857+
def _callable_to_egg(self, fn: ExprCallable) -> tuple[CallableRef, str]:
858858
ref, decls = resolve_callable(fn)
859859
self._add_decls(decls)
860-
return self._state.callable_ref_to_egg(ref)[0]
860+
return ref, self._state.callable_ref_to_egg(ref)[0]
861861

862862
# TODO: Change let to be action...
863863
def let(self, name: str, expr: BASE_EXPR) -> BASE_EXPR:
@@ -961,15 +961,15 @@ def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tu
961961
runtime_expr = to_runtime_expr(expr)
962962
extract_report = self._run_extract(runtime_expr, 0)
963963
assert isinstance(extract_report, bindings.ExtractBest)
964-
(new_typed_expr,) = self._state.exprs_from_egg(
965-
extract_report.termdag, [extract_report.term], runtime_expr.__egg_typed_expr__.tp
966-
)
967-
968-
res = cast("BASE_EXPR", RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr))
964+
res = self._from_termdag(extract_report.termdag, extract_report.term, runtime_expr.__egg_typed_expr__.tp)
969965
if include_cost:
970966
return res, extract_report.cost
971967
return res
972968

969+
def _from_termdag(self, termdag: bindings.TermDag, term: bindings._Term, tp: JustTypeRef) -> Any:
970+
(new_typed_expr,) = self._state.exprs_from_egg(termdag, [term], tp)
971+
return RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)
972+
973973
def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
974974
"""
975975
Extract multiple expressions from the egraph.
@@ -1040,7 +1040,7 @@ def _serialize(
10401040
msg = ", ".join(set(self._state.possible_egglog_functions(serialized.truncated_functions)))
10411041
warn(f"Truncated: {msg}", stacklevel=3)
10421042
if split_primitive_outputs or split_functions:
1043-
additional_ops = set(map(self._callable_to_egg, split_functions))
1043+
additional_ops = {self._callable_to_egg(f)[1] for f in split_functions}
10441044
serialized.split_classes(self._egraph, additional_ops)
10451045
serialized.map_ops(self._state.op_mapping())
10461046

@@ -1191,7 +1191,7 @@ def function_size(self, fn: ExprCallable) -> int:
11911191
"""
11921192
Returns the number of rows in a certain function
11931193
"""
1194-
egg_name = self._callable_to_egg(fn)
1194+
egg_name = self._callable_to_egg(fn)[1]
11951195
(output,) = self._egraph.run_program(bindings.PrintSize(span(1), egg_name))
11961196
assert isinstance(output, bindings.PrintFunctionSize)
11971197
return output.size
@@ -1214,6 +1214,27 @@ def all_function_sizes(self) -> list[tuple[ExprCallable, int]]:
12141214
if (refs := self._state.egg_fn_to_callable_refs[name])
12151215
]
12161216

1217+
def function_values(
1218+
self, fn: Callable[..., BASE_EXPR] | BASE_EXPR, length: int | None = None
1219+
) -> dict[BASE_EXPR, BASE_EXPR]:
1220+
"""
1221+
Given a callable that is a "function", meaning it returns a primitive or has a merge set,
1222+
returns a mapping of the function applied with its arguments to its values
1223+
1224+
If length is specified, only the first `length` values will be returned.
1225+
"""
1226+
ref, egg_name = self._callable_to_egg(fn)
1227+
cmd = bindings.PrintFunction(span(1), egg_name, length, None, bindings.DefaultPrintFunctionMode())
1228+
(output,) = self._egraph.run_program(cmd)
1229+
assert isinstance(output, bindings.PrintFunctionOutput)
1230+
signature = self.__egg_decls__.get_callable_decl(ref).signature
1231+
assert isinstance(signature, FunctionSignature)
1232+
tp = signature.semantic_return_type.to_just()
1233+
return {
1234+
self._from_termdag(output.termdag, call, tp): self._from_termdag(output.termdag, res, tp)
1235+
for (call, res) in output.terms
1236+
}
1237+
12171238

12181239
# Either a constant or a function.
12191240
ExprCallable: TypeAlias = Callable[..., BaseExpr] | BaseExpr

python/tests/test_high_level.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,3 +1052,14 @@ def test_all_function_size():
10521052

10531053
def test_overall_run_report():
10541054
assert EGraph().stats()
1055+
1056+
1057+
def test_function_values():
1058+
egraph = EGraph()
1059+
1060+
@function
1061+
def f(x: i64Like) -> i64: ...
1062+
1063+
egraph.register(set_(f(i64(1))).to(i64(2)))
1064+
values = egraph.function_values(f)
1065+
assert values == {f(i64(1)): i64(2)}

0 commit comments

Comments
 (0)