-
Notifications
You must be signed in to change notification settings - Fork 3.3k
[no-thunks] Reduce stack frames in cond and friends
#33677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary of ChangesHello @dougalm, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refactors the internal implementation of JAX's control flow primitives, specifically Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors control flow primitives to reduce stack frames by removing wrapper functions and calling trace_to_jaxpr directly. The changes are mostly consistent and achieve the goal of simplification. However, I've identified a critical issue in _switch_internal that includes an assert False and incorrect variable usage, which will cause a crash. Additionally, there are several incorrect or missing type hints in jax/_src/lax/control_flow/common.py that should be addressed to improve code correctness and maintainability.
| assert False | ||
|
|
||
| jaxprs_, all_consts, out_trees = zip(*[pe.trace_to_jaxpr( | ||
| f, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbg)]) | ||
| jaxprs, consts = _merge_common_consts(jaxprs_, all_consts) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code has several issues that will cause it to fail:
assert Falseon line 151 will always be triggered, causing a crash.- In the list comprehension on lines 153-154,
dbgis used inzip(branches, dbg), but the list of debug info is nameddbgs. This should bezip(branches, dbgs). - In the same list comprehension,
fis passed tope.trace_to_jaxpr, but it's not defined. The function for the current branch isbranch.
jaxprs_, all_consts, out_trees = zip(*[pe.trace_to_jaxpr(
branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)])
jaxprs, consts = _merge_common_consts(jaxprs_, all_consts)| # TODO(dougalm): this seems way too complicated. Why not allow different consts for each | ||
| # branch of a switch? | ||
| def _merge_common_consts( | ||
| jaxprs: Sequence[core.Jaxpr], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint for the jaxprs parameter is Sequence[core.Jaxpr], but this function is called with a sequence of core.ClosedJaxpr objects. The implementation also relies on this, as it passes these objects to _pad_constvars which expects core.ClosedJaxpr. Please update the type hint to Sequence[core.ClosedJaxpr] for correctness.
| jaxprs: Sequence[core.Jaxpr], | |
| jaxprs: Sequence[core.ClosedJaxpr], |
jax/_src/lax/control_flow/common.py
Outdated
| def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...], | ||
| def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int, | ||
| left: tuple[core.AvalQDD, ...], | ||
| right: tuple[core.AbstractValue, ...]) -> core.Jaxpr: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function _pad_constvars is declared to return a core.Jaxpr, but it actually returns a core.ClosedJaxpr since it calls jaxpr.replace(...) on a ClosedJaxpr instance. Please update the return type hint to core.ClosedJaxpr to match the implementation.
| right: tuple[core.AbstractValue, ...]) -> core.Jaxpr: | |
| right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr: |
|
|
||
| @weakref_lru_cache | ||
| def _dedup_consts(jaxpr, const_ids): | ||
| def _dedup_consts(jaxpr, num_consts, const_ids): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function _dedup_consts lacks type hints, which makes it harder to understand and maintain. Based on its usage, it takes and returns a core.ClosedJaxpr. Please consider adding type hints to the function signature.
| def _dedup_consts(jaxpr, num_consts, const_ids): | |
| def _dedup_consts(jaxpr: core.ClosedJaxpr, num_consts: int, const_ids: tuple[int, ...]) -> core.ClosedJaxpr: |
744d1cf to
5d4f621
Compare
568f00d to
8e772e7
Compare
Also deprecate the very old form of `cond`: `cond(predicate, true_arg, true_fun, false_arg, false_fun)`.
8e772e7 to
bc41be8
Compare
Imported from GitHub PR jax-ml#33677 Copybara import of the project: -- 8e772e7 by Dougal <[email protected]>: [no-thunks] Reduce stack frames in `cond` and friends Also deprecate the very old form of `cond`: `cond(predicate, true_arg, true_fun, false_arg, false_fun)`. Merging this change closes jax-ml#33677 COPYBARA_INTEGRATE_REVIEW=jax-ml#33677 from jax-ml:fewer-cond-stack-frames 8e772e7 PiperOrigin-RevId: 839851241
No description provided.