66
77import re
88from collections import defaultdict
9- from dataclasses import dataclass , field
9+ from dataclasses import dataclass , field , replace
1010from typing import TYPE_CHECKING , Literal , overload
1111
1212from typing_extensions import assert_never
@@ -71,6 +71,9 @@ class EGraphState:
7171 # Cache of egg expressions for converting to egg
7272 expr_to_egg_cache : dict [ExprDecl , bindings ._Expr ] = field (default_factory = dict )
7373
74+ # Callables which have cost tables associated with them
75+ cost_callables : set [CallableRef ] = field (default_factory = set )
76+
7477 def copy (self ) -> EGraphState :
7578 """
7679 Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping.
@@ -83,6 +86,7 @@ def copy(self) -> EGraphState:
8386 callable_ref_to_egg_fn = self .callable_ref_to_egg_fn .copy (),
8487 type_ref_to_egg_sort = self .type_ref_to_egg_sort .copy (),
8588 expr_to_egg_cache = self .expr_to_egg_cache .copy (),
89+ cost_callables = self .cost_callables .copy (),
8690 )
8791
8892 def schedule_to_egg (self , schedule : ScheduleDecl ) -> bindings ._Schedule :
@@ -212,9 +216,32 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin
212216 return bindings .Union (span (), self ._expr_to_egg (lhs ), self ._expr_to_egg (rhs ))
213217 case PanicDecl (name ):
214218 return bindings .Panic (span (), name )
219+ case SetCostDecl (tp , expr , cost ):
220+ self .type_ref_to_egg (tp )
221+ cost_table = self .create_cost_table (expr .callable )
222+ args_egg = [self .typed_expr_to_egg (x , False ) for x in expr .args ]
223+ return bindings .Set (span (), cost_table , args_egg , self ._expr_to_egg (cost ))
215224 case _:
216225 assert_never (action )
217226
227+ def create_cost_table (self , ref : CallableRef ) -> str :
228+ """
229+ Creates the egg cost table if needed and gets the name of the table.
230+ """
231+ name = self .cost_table_name (ref )
232+ if ref not in self .cost_callables :
233+ self .cost_callables .add (ref )
234+ signature = self .__egg_decls__ .get_callable_decl (ref ).signature
235+ assert isinstance (signature , FunctionSignature ), "Can only add cost tables for functions"
236+ signature = replace (signature , return_type = TypeRefWithVars ("i64" ))
237+ self .egraph .run_program (
238+ bindings .FunctionCommand (span (), name , self ._signature_to_egg_schema (signature ), None )
239+ )
240+ return name
241+
242+ def cost_table_name (self , ref : CallableRef ) -> str :
243+ return f"cost_table_{ self .callable_ref_to_egg (ref )[0 ]} "
244+
218245 def fact_to_egg (self , fact : FactDecl ) -> bindings ._Fact :
219246 match fact :
220247 case EqDecl (tp , left , right ):
@@ -350,11 +377,16 @@ def op_mapping(self) -> dict[str, str]:
350377 """
351378 Create a mapping of egglog function name to Python function name, for use in the serialized format
352379 for better visualization.
380+
381+ Includes cost tables
353382 """
354383 return {
355384 k : pretty_callable_ref (self .__egg_decls__ , next (iter (v )))
356385 for k , v in self .egg_fn_to_callable_refs .items ()
357386 if len (v ) == 1
387+ } | {
388+ self .cost_table_name (ref ): f"cost({ pretty_callable_ref (self .__egg_decls__ , ref , include_all_args = True )} )"
389+ for ref in self .cost_callables
358390 }
359391
360392 def possible_egglog_functions (self , names : list [str ]) -> Iterable [str ]:
0 commit comments