@@ -43,78 +43,54 @@ def _typecheck_param(prim, param, name, msg_required, pred):
4343 msg = sep .join ([msg , param_str ])
4444 raise core .JaxprTypeError (msg )
4545
46- # TODO(dougalm): this is a silly wrapper now. Delete it.
47- @weakref_lru_cache
48- def _initial_style_open_jaxpr (fun : Callable ,
49- in_tree : PyTreeDef ,
50- in_avals : Sequence [core .AbstractValue | core .AvalQDD ],
51- debug_info : core .DebugInfo ):
52- jaxpr , out_tree , consts = pe .trace_to_jaxpr (fun , in_tree , in_avals , debug_info )
53- return jaxpr , consts , out_tree
54-
55- # TODO(dougalm): Delete. Make `trace_to_jaxpr` do the jaxpr-closing thing instead.
56- @weakref_lru_cache
57- def _initial_style_jaxpr (fun : Callable ,
58- in_tree : PyTreeDef ,
59- in_avals : Sequence [core .AbstractValue ],
60- debug_info : core .DebugInfo ) -> tuple [core .ClosedJaxpr , Sequence [Any ], PyTreeDef ]:
61- jaxpr , consts , out_tree = _initial_style_open_jaxpr (
62- fun , in_tree , in_avals , debug_info )
63- closed_jaxpr = pe .close_jaxpr (pe .convert_constvars_jaxpr (jaxpr ))
64- return closed_jaxpr , consts , out_tree
65-
66- def _initial_style_jaxprs_with_common_consts (
67- funs : Sequence [Callable ],
68- in_tree : PyTreeDef , in_avals : Sequence [core .AbstractValue | core .AvalQDD ],
69- debug_infos : Sequence [core .DebugInfo ]):
70- jaxpr_data = [_initial_style_open_jaxpr (fn , in_tree , in_avals , debug_info )
71- for fn , debug_info in zip (funs , debug_infos )]
72- if not jaxpr_data : return [], [], []
73- jaxprs , all_consts , all_out_trees = zip (* jaxpr_data )
74-
46+ # TODO(dougalm): this seems way too complicated. Why not allow different consts for each
47+ # branch of a switch?
48+ def _merge_common_consts (
49+ jaxprs : Sequence [core .Jaxpr ],
50+ all_consts : Sequence [Sequence [Any ]]
51+ ) -> tuple [Sequence [core .ClosedJaxpr ], Sequence [Any ]]:
7552 # Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars.
7653 lens = map (len , all_consts )
7754 consts = [c for cs in all_consts for c in cs ]
7855 avalqdds = tuple (map (core .cur_aval_qdd , consts ))
79- jaxprs = [_pad_constvars (jaxpr , avalqdds [:sum (lens [:i ])], avalqdds [sum (lens [:i + 1 ]):])
80- for i , jaxpr in enumerate (jaxprs )]
56+ num_constss = [len (cs ) for cs in all_consts ]
57+ jaxprs = [_pad_constvars (jaxpr , num_consts , avalqdds [:sum (lens [:i ])], avalqdds [sum (lens [:i + 1 ]):])
58+ for i , (jaxpr , num_consts ) in enumerate (zip (jaxprs , num_constss ))]
8159 # De-duplicate shared constants.
8260 const_ids = tuple (id (c ) for c in consts )
8361 seen = set ()
84- consts = [c for c in consts if id (c ) not in seen and not seen .add (id (c ))] # type: ignore
85- jaxprs = [_dedup_consts (jaxpr , const_ids ) for jaxpr in jaxprs ]
86-
87- closed_jaxprs = [pe .close_jaxpr (pe .convert_constvars_jaxpr (jaxpr ))
88- for jaxpr in jaxprs ]
89- return closed_jaxprs , consts , all_out_trees
62+ dd_consts = [c for c in consts if id (c ) not in seen and not seen .add (id (c ))] # type: ignore
63+ jaxprs = [_dedup_consts (jaxpr , len (consts ), const_ids ) for jaxpr in jaxprs ]
64+ return jaxprs , dd_consts
9065
9166@weakref_lru_cache
92- def _pad_constvars (jaxpr : core .Jaxpr , left : tuple [core .AvalQDD , ...],
93- right : tuple [core .AbstractValue , ...]) -> core .Jaxpr :
67+ def _pad_constvars (jaxpr : core .ClosedJaxpr , num_consts : int ,
68+ left : tuple [core .AvalQDD , ...],
69+ right : tuple [core .AbstractValue , ...]) -> core .ClosedJaxpr :
9470 def make_var (aq ):
9571 return core .Var (aq .aval , initial_qdd = aq .qdd , final_qdd = aq .qdd )
96- constvars = [* map (make_var , left ), * jaxpr .constvars , * map ( make_var , right )]
97- effs = pe . _renumber_effects ([ * constvars , * jaxpr .invars ],
98- [ * jaxpr . constvars , * jaxpr .invars ] , jaxpr .effects )
99- jaxpr = jaxpr .replace (constvars = constvars , effects = effs )
72+ invars = [* map (make_var , left ), * jaxpr .invars [: num_consts ],
73+ * map ( make_var , right ), * jaxpr .invars [ num_consts :]]
74+ effs = pe . _renumber_effects ( invars , jaxpr .invars , jaxpr .effects )
75+ jaxpr = jaxpr .replace (jaxpr = jaxpr . jaxpr . replace ( invars = invars , effects = effs ) )
10076 config .enable_checks .value and core .check_jaxpr (jaxpr )
10177 return jaxpr
10278
10379@weakref_lru_cache
104- def _dedup_consts (jaxpr , const_ids ):
80+ def _dedup_consts (jaxpr , num_consts , const_ids ):
10581 newvars = {}
10682 canonicalize = {v : newvars .setdefault (constid , v )
107- for constid , v in zip (const_ids , jaxpr .constvars )}
83+ for constid , v in zip (const_ids , jaxpr .invars [: num_consts ] )}
10884 eqns = [e .replace (invars = [canonicalize .get (x , x ) if isinstance (x , core .Var )
10985 else x for x in e .invars ]) for e in jaxpr .eqns ]
11086 outvars = [canonicalize .get (x , x ) if isinstance (x , core .Var ) else x
11187 for x in jaxpr .outvars ]
112- constvars = list (newvars .values ())
113- effs = pe ._renumber_effects (
114- [* constvars , * jaxpr .invars ],
115- [ * map ( canonicalize . get , jaxpr . constvars ), * jaxpr . invars ], jaxpr .effects )
116- jaxpr = jaxpr .replace (constvars = constvars , eqns = eqns , outvars = outvars ,
117- effects = effs )
88+ invars = [ * list (newvars .values ()), * jaxpr . invars [ num_consts :]]
89+ effs = pe ._renumber_effects (invars ,
90+ [* map ( canonicalize . get , jaxpr . invars [: num_consts ]), * jaxpr .invars [ num_consts :] ],
91+ jaxpr .effects )
92+ jaxpr = jaxpr .replace (jaxpr = jaxpr . jaxpr . replace ( invars = invars , eqns = eqns , outvars = outvars ,
93+ effects = effs ))
11894 config .enable_checks .value and core .check_jaxpr (jaxpr )
11995 return jaxpr
12096
0 commit comments