Skip to content

Commit 568f00d

Browse files
committed
[no-thunks] Reduce stack frames in cond and friends
Also deprecate the very old form of `cond`: `cond(predicate, true_arg, true_fun, false_arg, false_fun)`.
1 parent 176e3cb commit 568f00d

File tree

9 files changed

+83
-163
lines changed

9 files changed

+83
-163
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2400,7 +2400,7 @@ def trace_to_jaxpr(
24002400
in_tree: PyTreeDef,
24012401
in_avals_flat: Sequence[AbstractValue | core.AvalQDD],
24022402
debug_info: core.DebugInfo
2403-
) -> tuple[Jaxpr, PyTreeDef, list[Any]]:
2403+
) -> tuple[ClosedJaxpr, PyTreeDef, list[Any]]:
24042404
config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat))
24052405
parent_trace = core.trace_ctx.trace
24062406
trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace)
@@ -2424,6 +2424,8 @@ def trace_to_jaxpr(
24242424
del trace, fun, in_tracers_flat, in_tracers, out_tracers, ans, ans_flat
24252425

24262426
config.enable_checks.value and core.check_jaxpr(jaxpr)
2427+
# TODO(dougalm): remove this once we merge Jaxpr and ClosedJaxpr
2428+
jaxpr = close_jaxpr(convert_constvars_jaxpr(jaxpr))
24272429
return jaxpr, out_tree, consts
24282430

24292431

jax/_src/lax/control_flow/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@
5050
# Private utilities used elsewhere in JAX
5151
# TODO(sharadmv): lift them into a more common place
5252
from jax._src.lax.control_flow.common import (
53-
_initial_style_open_jaxpr as _initial_style_open_jaxpr,
54-
_initial_style_jaxpr as _initial_style_jaxpr,
55-
_initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts,
5653
_check_tree_and_avals as _check_tree_and_avals,
5754

5855
)

jax/_src/lax/control_flow/common.py

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -43,78 +43,54 @@ def _typecheck_param(prim, param, name, msg_required, pred):
4343
msg = sep.join([msg, param_str])
4444
raise core.JaxprTypeError(msg)
4545

46-
# TODO(dougalm): this is a silly wrapper now. Delete it.
47-
@weakref_lru_cache
48-
def _initial_style_open_jaxpr(fun: Callable,
49-
in_tree: PyTreeDef,
50-
in_avals: Sequence[core.AbstractValue | core.AvalQDD],
51-
debug_info: core.DebugInfo):
52-
jaxpr, out_tree, consts = pe.trace_to_jaxpr(fun, in_tree, in_avals, debug_info)
53-
return jaxpr, consts, out_tree
54-
55-
# TODO(dougalm): Delete. Make `trace_to_jaxpr` do the jaxpr-closing thing instead.
56-
@weakref_lru_cache
57-
def _initial_style_jaxpr(fun: Callable,
58-
in_tree: PyTreeDef,
59-
in_avals: Sequence[core.AbstractValue],
60-
debug_info: core.DebugInfo) -> tuple[core.ClosedJaxpr, Sequence[Any], PyTreeDef]:
61-
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
62-
fun, in_tree, in_avals, debug_info)
63-
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
64-
return closed_jaxpr, consts, out_tree
65-
66-
def _initial_style_jaxprs_with_common_consts(
67-
funs: Sequence[Callable],
68-
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue | core.AvalQDD],
69-
debug_infos: Sequence[core.DebugInfo]):
70-
jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info)
71-
for fn, debug_info in zip(funs, debug_infos)]
72-
if not jaxpr_data: return [], [], []
73-
jaxprs, all_consts, all_out_trees = zip(*jaxpr_data)
74-
46+
# TODO(dougalm): this seems way too complicated. Why not allow different consts for each
47+
# branch of a switch?
48+
def _merge_common_consts(
49+
jaxprs: Sequence[core.Jaxpr],
50+
all_consts: Sequence[Sequence[Any]]
51+
) -> tuple[Sequence[core.ClosedJaxpr], Sequence[Any]]:
7552
# Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars.
7653
lens = map(len, all_consts)
7754
consts = [c for cs in all_consts for c in cs]
7855
avalqdds = tuple(map(core.cur_aval_qdd, consts))
79-
jaxprs = [_pad_constvars(jaxpr, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):])
80-
for i, jaxpr in enumerate(jaxprs)]
56+
num_constss = [len(cs) for cs in all_consts]
57+
jaxprs = [_pad_constvars(jaxpr, num_consts, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):])
58+
for i, (jaxpr, num_consts) in enumerate(zip(jaxprs, num_constss))]
8159
# De-duplicate shared constants.
8260
const_ids = tuple(id(c) for c in consts)
8361
seen = set()
84-
consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore
85-
jaxprs = [_dedup_consts(jaxpr, const_ids) for jaxpr in jaxprs]
86-
87-
closed_jaxprs = [pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
88-
for jaxpr in jaxprs]
89-
return closed_jaxprs, consts, all_out_trees
62+
dd_consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore
63+
jaxprs = [_dedup_consts(jaxpr, len(consts), const_ids) for jaxpr in jaxprs]
64+
return jaxprs, dd_consts
9065

9166
@weakref_lru_cache
92-
def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...],
93-
right: tuple[core.AbstractValue, ...]) -> core.Jaxpr:
67+
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
68+
left: tuple[core.AvalQDD, ...],
69+
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
9470
def make_var(aq):
9571
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
96-
constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
97-
effs = pe._renumber_effects([*constvars, *jaxpr.invars],
98-
[*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
99-
jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
72+
invars = [*map(make_var, left), *jaxpr.invars[:num_consts],
73+
*map(make_var, right), *jaxpr.invars[num_consts:]]
74+
effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects)
75+
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
10076
config.enable_checks.value and core.check_jaxpr(jaxpr)
10177
return jaxpr
10278

10379
@weakref_lru_cache
104-
def _dedup_consts(jaxpr, const_ids):
80+
def _dedup_consts(jaxpr, num_consts, const_ids):
10581
newvars = {}
10682
canonicalize = {v: newvars.setdefault(constid, v)
107-
for constid, v in zip(const_ids, jaxpr.constvars)}
83+
for constid, v in zip(const_ids, jaxpr.invars[:num_consts])}
10884
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
10985
else x for x in e.invars]) for e in jaxpr.eqns]
11086
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
11187
for x in jaxpr.outvars]
112-
constvars = list(newvars.values())
113-
effs = pe._renumber_effects(
114-
[*constvars, *jaxpr.invars],
115-
[*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
116-
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
117-
effects=effs)
88+
invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]]
89+
effs = pe._renumber_effects(invars,
90+
[*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]],
91+
jaxpr.effects)
92+
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
93+
effects=effs))
11894
config.enable_checks.value and core.check_jaxpr(jaxpr)
11995
return jaxpr
12096

jax/_src/lax/control_flow/conditionals.py

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
import numpy as np
5454

5555
from jax._src.lax.control_flow.common import (
56-
_avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts,
56+
_avals_short, _typecheck_param, _merge_common_consts,
5757
_make_closed_jaxpr, _prune_zeros)
5858

5959
map, unsafe_map = safe_map, map
@@ -149,8 +149,10 @@ def _switch_internal(
149149
if config.mutable_array_checks.value:
150150
api_util.check_no_aliased_ref_args(lambda: dbgs[0], ops_avals, ops)
151151

152-
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
153-
branches, ops_tree, ops_avals, dbgs)
152+
jaxprs_, out_trees, all_consts = zip(*[pe.trace_to_jaxpr(
153+
branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)])
154+
jaxprs, consts = _merge_common_consts(jaxprs_, all_consts)
155+
154156
if config.mutable_array_checks.value:
155157
api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops)
156158
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
@@ -184,7 +186,7 @@ def _switch_internal(
184186
return tree_unflatten(out_trees[0], out)
185187

186188
@partial(api_boundary, repro_api_name="jax_cond")
187-
def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
189+
def cond(pred, true_fun: Callable, false_fun: Callable, *operands,
188190
operand=_no_operand_sentinel):
189191
"""Conditionally apply ``true_fun`` or ``false_fun``.
190192
@@ -270,14 +272,16 @@ def cond(pred, true_fun, false_fun, *operands):
270272
if config.mutable_array_checks.value:
271273
api_util.check_no_aliased_ref_args(lambda: dbg_true_fun, ops_avals, ops)
272274
dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
273-
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
274-
(true_fun, false_fun), ops_tree, ops_avals,
275-
[dbg_true_fun, dbg_false_fun])
276-
true_jaxpr, false_jaxpr = jaxprs
275+
276+
true_jaxpr_, out_tree, true_consts = pe.trace_to_jaxpr(
277+
true_fun, ops_tree, ops_avals, dbg_true_fun)
278+
false_jaxpr_, false_out_tree, false_consts = pe.trace_to_jaxpr(
279+
false_fun, ops_tree, ops_avals, dbg_false_fun)
280+
(true_jaxpr, false_jaxpr), consts = _merge_common_consts(
281+
(true_jaxpr_, false_jaxpr_), (true_consts, false_consts))
277282
if config.mutable_array_checks.value:
278283
api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops)
279284

280-
out_tree, false_out_tree = out_trees
281285
if any(isinstance(out_aval, AbstractRef) for out_aval in
282286
true_jaxpr.out_avals + false_jaxpr.out_avals):
283287
raise ValueError("Cannot return `Ref`s from `cond`.")
@@ -399,48 +403,6 @@ def _capitalize(s):
399403
# s.capitalize() converts s[1:] to lowercase which we don't want.
400404
return s[0].capitalize() + s[1:]
401405

402-
@api_boundary
403-
@functools.wraps(_cond)
404-
def cond(*args, **kwargs):
405-
# detect an attempt to call the former, deprecated cond
406-
try:
407-
ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
408-
except TypeError:
409-
pass
410-
else:
411-
assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch
412-
_, true_operand, true_fun, false_operand, false_fun = ba.args
413-
if callable(true_operand) and callable(true_fun):
414-
# treat this as modern cond (with two operands)
415-
return _cond(*args, **kwargs)
416-
if callable(true_fun) and callable(false_fun):
417-
return _cond_with_per_branch_args(*ba.args)
418-
419-
return _cond(*args, **kwargs)
420-
421-
@partial(api_boundary, repro_api_name="jax_cond_with_per_branch_args")
422-
def _cond_with_per_branch_args(pred,
423-
true_operand, true_fun: Callable,
424-
false_operand, false_fun: Callable):
425-
"""Conditionally apply ``true_fun`` or ``false_fun``.
426-
427-
Has equivalent semantics to this Python implementation::
428-
429-
def cond(pred, true_operand, true_fun, false_operand, false_fun):
430-
if pred:
431-
return true_fun(true_operand)
432-
else:
433-
return false_fun(false_operand)
434-
435-
Pred has to be a scalar type, collection types (list, tuple) are not supported
436-
"""
437-
if not (callable(true_fun) and callable(false_fun)):
438-
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
439-
return _cond(pred,
440-
lambda op: true_fun(op[0]),
441-
lambda op: false_fun(op[1]),
442-
(true_operand, false_operand))
443-
444406
def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects:
445407
joined_effects = set()
446408
for b in branches:

jax/_src/lax/control_flow/loops.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from jax._src.lax import slicing
5151
from jax._src.lax import windowed_reductions
5252
from jax._src.lax.control_flow.common import (
53-
_avals_short, _initial_style_jaxpr, _prune_zeros, _typecheck_param,
53+
_avals_short, _prune_zeros, _typecheck_param,
5454
_make_closed_jaxpr)
5555
from jax._src.lax.other import logaddexp
5656
from jax._src.pjit import auto_axes, PartitionSpec as P, reshard
@@ -281,9 +281,8 @@ def _create_jaxpr(init):
281281
init_flat, init_tree = tree_flatten(init)
282282
in_flat, in_tree = tree_flatten((init, xs))
283283
carry_avals = tuple(_map(core.get_aval, init_flat))
284-
open_jaxpr, out_tree, consts = pe.trace_to_jaxpr(
284+
jaxpr, out_tree, consts = pe.trace_to_jaxpr(
285285
f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body)
286-
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(open_jaxpr))
287286
if config.mutable_array_checks.value:
288287
_check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat)
289288
out_tree_children = out_tree.children()
@@ -1712,10 +1711,10 @@ def _create_jaxpr(init_val):
17121711
init_vals, in_tree = tree_flatten((init_val,))
17131712
init_avals = tuple(_map(core.get_aval, init_vals))
17141713
cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
1715-
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
1714+
cond_jaxpr, cond_tree, cond_consts = pe.trace_to_jaxpr(
17161715
cond_fun, in_tree, init_avals, cond_dbg)
17171716
body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
1718-
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
1717+
body_jaxpr, body_tree, body_consts = pe.trace_to_jaxpr(
17191718
body_fun, in_tree, init_avals, body_dbg)
17201719
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
17211720
msg = "cond_fun must return a boolean scalar, but got pytree {}."

jax/_src/lax/control_flow/solves.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from jax._src.interpreters import ad
2828
from jax._src.interpreters import batching
2929
from jax._src.interpreters import mlir
30+
from jax._src.interpreters import partial_eval as pe
3031
from jax._src.interpreters import pxla
3132
from jax._src.traceback_util import api_boundary
3233
from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves,
@@ -36,7 +37,6 @@
3637

3738
from jax._src.lax.control_flow.common import (
3839
_check_tree,
39-
_initial_style_jaxpr,
4040
)
4141

4242
_map = safe_map
@@ -95,7 +95,7 @@ def custom_root(f: Callable,
9595
guess_flat, in_args_tree = tree_flatten((initial_guess,))
9696
guess_avals = tuple(_map(core.get_aval, guess_flat))
9797
f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {})
98-
f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
98+
f_jaxpr, out_tree, f_consts = pe.trace_to_jaxpr(
9999
f, in_args_tree, guess_avals, f_debug)
100100

101101
in_tree, = treedef_children(in_args_tree)
@@ -104,7 +104,7 @@ def custom_root(f: Callable,
104104
solve_debug = api_util.debug_info("custom_root solve", solve,
105105
(f, initial_guess), {},
106106
static_argnums=(0,))
107-
solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
107+
solve_jaxpr, solution_tree, solve_consts = pe.trace_to_jaxpr(
108108
partial(solve, f), in_args_tree, guess_avals, solve_debug)
109109
_check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)
110110

@@ -114,7 +114,7 @@ def linearize_and_solve(x, b):
114114

115115
linearize_and_solve_dbg = api_util.debug_info("custom_root tangent_solve",
116116
tangent_solve, (initial_guess, initial_guess), {})
117-
l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
117+
l_and_s_jaxpr, out_tree, l_and_s_consts = pe.trace_to_jaxpr(
118118
linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2,
119119
linearize_and_solve_dbg)
120120
_check_tree("tangent_solve", "x", out_tree, in_tree, False)
@@ -268,15 +268,15 @@ def f_aux(x):
268268
matvec_debug = api_util.debug_info("custom_linear_solve",
269269
matvec, (b,), {})
270270
# no auxiliary data assumed for matvec
271-
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
271+
matvec_jaxpr, out_tree, matvec_consts = pe.trace_to_jaxpr(
272272
_shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
273273
matvec_debug)
274274
_check_tree("matvec", "b", out_tree, tree, False)
275275

276276
solve_debug = api_util.debug_info("custom_linear_solve solve",
277277
solve, (matvec, b), {},
278278
static_argnums=(0,))
279-
solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
279+
solve_jaxpr, out_tree, solve_consts = pe.trace_to_jaxpr(
280280
_shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
281281
solve_debug)
282282
_check_tree("solve", "b", out_tree, tree, has_aux)
@@ -294,11 +294,11 @@ def f_aux(x):
294294
vecmat_consts = matvec_consts
295295
else:
296296
vecmat = _transpose_one_output(matvec, b)
297-
vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
297+
vecmat_jaxpr, out_tree, vecmat_consts = pe.trace_to_jaxpr(
298298
vecmat, in_args_tree, b_avals, transpose_solve_debug)
299299
assert out_tree == tree
300300

301-
tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
301+
tr_solve_jaxpr, out_tree, tr_solve_consts = pe.trace_to_jaxpr(
302302
_shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
303303
in_args_tree, b_avals, transpose_solve_debug)
304304
_check_tree("transpose_solve", "b", out_tree, tree, has_aux)

tests/api_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7207,10 +7207,10 @@ def fun(x):
72077207
def test_cond(self):
72087208
def f(x):
72097209
return lax.cond(x >= 0.,
7210+
lambda xt, _: xt + x,
7211+
lambda _, xf: xf - x,
72107212
x + 1.,
7211-
lambda xt: xt + x,
7212-
x + 2.,
7213-
lambda xf: xf - x)
7213+
x + 2.)
72147214
expected = """{ lambda ; a:f32[]. let
72157215
b:bool[] = ge a 0.0:f32[]
72167216
c:f32[] = add a 1.0:f32[]
@@ -7941,10 +7941,10 @@ def f(c, x):
79417941
jax.lax.scan(f, 0, jnp.arange(4))
79427942

79437943
def test_cond_traceback(self):
7944-
if sys.version_info < (3, 14):
7944+
if sys.version_info < (3, 13):
79457945
# Fails because 3.11 adds an extra stack frame due to a list comprehension
79467946
self.skipTest("Expected failure.")
7947-
expected_depth = 8
7947+
expected_depth = 4
79487948
init_depth = self.cur_depth()
79497949

79507950
def f():

tests/core_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -448,15 +448,11 @@ class JaxprTypeChecks(jtu.JaxTestCase):
448448

449449
def setUp(self):
450450
super().setUp()
451-
lax_control_flow._initial_style_open_jaxpr.cache_clear()
452-
lax_control_flow._initial_style_jaxpr.cache_clear()
453451
lax_control_flow.common._dedup_consts.cache_clear()
454452
lax_control_flow.common._pad_constvars.cache_clear()
455453

456454
def tearDown(self):
457455
super().tearDown()
458-
lax_control_flow._initial_style_open_jaxpr.cache_clear()
459-
lax_control_flow._initial_style_jaxpr.cache_clear()
460456
lax_control_flow.common._dedup_consts.cache_clear()
461457
lax_control_flow.common._pad_constvars.cache_clear()
462458

0 commit comments

Comments
 (0)