Skip to content

Conversation

@dougalm
Copy link
Collaborator

@dougalm dougalm commented Dec 2, 2025

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @dougalm, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the internal implementation of JAX's control flow primitives, specifically cond and while_loop, to reduce the number of stack frames involved in their execution. By having trace_to_jaxpr directly return a ClosedJaxpr and streamlining the handling of constants and Jaxpr generation, the changes aim to improve the efficiency and debuggability of these core operations.

Highlights

  • Direct ClosedJaxpr Generation: The trace_to_jaxpr function now directly produces a ClosedJaxpr, removing an intermediate step and simplifying the Jaxpr generation process for control flow primitives.
  • Streamlined Control Flow Internals: Several internal helper functions (_initial_style_open_jaxpr, _initial_style_jaxpr, _initial_style_jaxprs_with_common_consts) have been removed or refactored, leading to a more direct and efficient implementation of control flow primitives like cond and while_loop.
  • Reduced Stack Frames: The changes result in a reduction of stack frames for cond and related operations, as confirmed by updated test expectations, which improves debugging and potentially performance.
  • Public cond API: The internal _cond function has been promoted to the public cond API, simplifying its usage and removing a previous wrapper.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors control flow primitives to reduce stack frames by removing wrapper functions and calling trace_to_jaxpr directly. The changes are mostly consistent and achieve the goal of simplification. However, I've identified a critical issue in _switch_internal that includes an assert False and incorrect variable usage, which will cause a crash. Additionally, there are several incorrect or missing type hints in jax/_src/lax/control_flow/common.py that should be addressed to improve code correctness and maintainability.

Comment on lines 151 to 155
assert False

jaxprs_, all_consts, out_trees = zip(*[pe.trace_to_jaxpr(
f, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbg)])
jaxprs, consts = _merge_common_consts(jaxprs_, all_consts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This block of code has several issues that will cause it to fail:

  1. assert False on line 151 will always be triggered, causing a crash.
  2. In the list comprehension on lines 153-154, dbg is used in zip(branches, dbg), but the list of debug info is named dbgs. This should be zip(branches, dbgs).
  3. In the same list comprehension, f is passed to pe.trace_to_jaxpr, but it's not defined. The function for the current branch is branch.
  jaxprs_, all_consts, out_trees = zip(*[pe.trace_to_jaxpr(
      branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)])
  jaxprs, consts = _merge_common_consts(jaxprs_, all_consts)

# TODO(dougalm): this seems way too complicated. Why not allow different consts for each
# branch of a switch?
def _merge_common_consts(
jaxprs: Sequence[core.Jaxpr],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hint for the jaxprs parameter is Sequence[core.Jaxpr], but this function is called with a sequence of core.ClosedJaxpr objects. The implementation also relies on this, as it passes these objects to _pad_constvars which expects core.ClosedJaxpr. Please update the type hint to Sequence[core.ClosedJaxpr] for correctness.

Suggested change
jaxprs: Sequence[core.Jaxpr],
jaxprs: Sequence[core.ClosedJaxpr],

def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...],
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.Jaxpr:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function _pad_constvars is declared to return a core.Jaxpr, but it actually returns a core.ClosedJaxpr since it calls jaxpr.replace(...) on a ClosedJaxpr instance. Please update the return type hint to core.ClosedJaxpr to match the implementation.

Suggested change
right: tuple[core.AbstractValue, ...]) -> core.Jaxpr:
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:


@weakref_lru_cache
def _dedup_consts(jaxpr, const_ids):
def _dedup_consts(jaxpr, num_consts, const_ids):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function _dedup_consts lacks type hints, which makes it harder to understand and maintain. Based on its usage, it takes and returns a core.ClosedJaxpr. Please consider adding type hints to the function signature.

Suggested change
def _dedup_consts(jaxpr, num_consts, const_ids):
def _dedup_consts(jaxpr: core.ClosedJaxpr, num_consts: int, const_ids: tuple[int, ...]) -> core.ClosedJaxpr:

@dougalm dougalm force-pushed the fewer-cond-stack-frames branch 3 times, most recently from 744d1cf to 5d4f621 Compare December 3, 2025 03:43
@dougalm dougalm added the pull ready Ready for copybara import and testing label Dec 3, 2025
@dougalm dougalm force-pushed the fewer-cond-stack-frames branch 5 times, most recently from 568f00d to 8e772e7 Compare December 3, 2025 19:35
Also deprecate the very old form of `cond`:
`cond(predicate, true_arg, true_fun, false_arg, false_fun)`.
@dougalm dougalm force-pushed the fewer-cond-stack-frames branch from 8e772e7 to bc41be8 Compare December 3, 2025 19:51
@copybara-service copybara-service bot closed this in 23cd412 Dec 3, 2025
partev pushed a commit to partev/jax that referenced this pull request Dec 4, 2025
Imported from GitHub PR jax-ml#33677

Copybara import of the project:

--
8e772e7 by Dougal <[email protected]>:

[no-thunks] Reduce stack frames in `cond` and friends

Also deprecate the very old form of `cond`:
`cond(predicate, true_arg, true_fun, false_arg, false_fun)`.

Merging this change closes jax-ml#33677

COPYBARA_INTEGRATE_REVIEW=jax-ml#33677 from jax-ml:fewer-cond-stack-frames 8e772e7
PiperOrigin-RevId: 839851241
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant