WIP: Add JIT-ed UL2R functions #1158
base: main
Conversation
@codex can you look too

For now, I can only help with PRs you've created.
Pull Request Overview
This PR introduces JIT-compiled functions for UL2R (Unified Language Learner 2 Revised) preprocessing, implementing the R/X-denoising training methodology. The implementation provides JAX-native versions of functions from a previous NumPy implementation to enable efficient batched processing.
- Adds noise span generation, segmentation, and sentinel token replacement functions
- Implements UL2R-specific attention masks for PrefixLM-style bidirectional attention on inputs
- Extends the attention system to support prefix masks that can be combined with causal attention
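To make the denoising format concrete, here is a hypothetical T5-style span-corruption example (tokens and sentinel ids are illustrative, not taken from this PR):

```python
# T5-style span corruption, illustrative only:
# original: the quick brown fox jumps over the lazy dog
# inputs:   the quick <extra_id_0> fox jumps <extra_id_1> lazy dog
# targets:  <extra_id_0> brown <extra_id_1> over the <extra_id_2>
```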
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/levanter/data/ul2r.py | Core UL2R functions including noise mask generation, token replacement, and attention mask creation |
| src/levanter/layers/attention.py | Adds prefix_mask support to AttentionMask class for UL2R-style attention patterns |
| tests/test_ul2r.py | Comprehensive test suite covering all UL2R functions with edge cases and determinism checks |
| tests/test_attention.py | Integration test for prefix mask functionality with AttentionMask |
src/levanter/data/ul2r.py (Outdated)

```python
        f"Too many noise spans: {max_segments} > {len(sentinel_tokens)}",
    ),
    errors=checkify.index_checks,
)()
```
Copilot AI (Sep 15, 2025):
The checkify result is assigned but err is never used to actually check or handle the error. This means the check will be computed but any violations will be silently ignored.
Suggested change:

```python
)()
if err is not None:
    raise err
```
ah interesting, I guess I can ignore the return, but then unless the caller wraps this function with `checkify.checkify`, nothing happens? https://docs.jax.dev/en/latest/_autosummary/jax.experimental.checkify.check.html
```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
    checkify.check(x > 0, "{x} needs to be positive!", x=x)
    return 1 / x

checked_f = checkify.checkify(f)
err, out = jax.jit(checked_f)(-3.)
err.throw()
# Traceback (most recent call last):
#   ...
# jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
```
I don't see it used anywhere else so maybe not worth it, but I'll just remove the `err` value for now.
i don't use checkify so i dunno. i use eqx.error_if
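For reference, a minimal sketch of the `eqx.error_if` alternative mentioned above (the function `f` is hypothetical, not from the PR). Unlike `checkify.check`, the check fires under `jit` without the caller having to wrap anything:

```python
import equinox as eqx
import jax

@jax.jit
def f(x):
    # attaches a runtime check to x; raises when the predicate is true
    x = eqx.error_if(x, x <= 0, "x needs to be positive!")
    return 1 / x

f(3.0)   # ok
f(-3.0)  # raises a runtime error, even inside jit
```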
src/levanter/data/ul2r.py (Outdated)

```python
input_len = jnp.argmax(inputs == pad_token_id)
target_len = jnp.argmax(targets == pad_token_id)
```
Copilot AI (Sep 15, 2025):
Using `argmax` to find padding will return 0 if no padding tokens exist (when the entire array is non-padding), which would incorrectly indicate zero length. Consider using `jnp.where(inputs == pad_token_id, jnp.arange(len(inputs)), len(inputs)).min()` or handle the case where `argmax` returns 0 but the first element isn't actually padding.
Suggested change:

```python
input_len = jnp.where(inputs == pad_token_id, indices, padded_length).min()
target_len = jnp.where(targets == pad_token_id, indices, padded_length).min()
```
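A self-contained sketch of the suggested pattern, assuming `indices` is `jnp.arange(padded_length)` and `padded_length` is the padded array length (neither is defined in the snippet above):

```python
import jax.numpy as jnp

pad_token_id = 0
inputs = jnp.array([5, 7, 9, 0, 0])  # true length 3
full = jnp.array([5, 7, 9, 8, 6])    # no padding at all

padded_length = inputs.shape[0]
indices = jnp.arange(padded_length)

# padding positions contribute their index; others contribute padded_length,
# so min() is the first pad index, or padded_length when nothing is padded
jnp.where(inputs == pad_token_id, indices, padded_length).min()  # -> 3
jnp.where(full == pad_token_id, indices, padded_length).min()    # -> 5
jnp.argmax(full == pad_token_id)                                 # -> 0 (the bug)
```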
huh, nice!
tests/test_ul2r.py (Outdated)

```python
assert bool(has_sentinel_inputs), "Inputs should contain at least one sentinel token"
assert bool(has_sentinel_outputs), "Outputs should contain at least one sentinel token"
```
Copilot AI (Sep 15, 2025):
[nitpick] The `bool()` conversion is unnecessary since `assert` already evaluates the expression as a boolean. You can simplify to `assert has_sentinel_inputs` and `assert has_sentinel_outputs`.
Suggested change:

```python
assert has_sentinel_inputs, "Inputs should contain at least one sentinel token"
assert has_sentinel_outputs, "Outputs should contain at least one sentinel token"
```
dlwh left a comment:
I think this looks good! We probably need a version both with and without the prefix, and we shouldn't allow the better attention implementations when it's enabled (for now)...

It's probably not too hard to add those though.

I tried to check against my memory of the algorithms and they look right, but it's likely I missed something.
src/levanter/data/ul2r.py (Outdated)

```python
@functools.partial(
    jax.jit, static_argnames=["mean_noise_span_length", "random_roll", "padded_length"]
```
i don't think there's any reason for `mean_noise_span_length` to be static, and if you're going to use `cond` anyway, there's no reason for `random_roll` to be static either
src/levanter/data/ul2r.py (Outdated)

```python
@functools.partial(jax.jit, static_argnames=["padded_length"])
def random_segmentation_jax(
```
can you link to references for these (t5x or whatever)
src/levanter/data/ul2r.py (Outdated)

```python
    return mask


@functools.partial(jax.jit, static_argnames=["pad_token_id"])
```
probably fine for pad token to not be static
```python
    pad_token_id: int,
) -> jnp.ndarray:
    """
    Replace each run of consecutive noise tokens with a different sentinel.
```
maybe document that length is the true length?
src/levanter/data/ul2r.py (Outdated)

```python
        "pad_token_id",
        "mean_noise_span_length",
        "random_roll",
```
don't think these need to be static
Haha, thanks for pointing all these out, good to learn. I guess for arguments that will have the same shape across invocations it's generally not worth making them static, unless that somehow lets the compiler optimize control flow? (edit: I think I'm wrong)
oh hm yea, I guess I'm curious about the intuition around static! The docs say:

> Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static.

but I definitely don't have the full context
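A small sketch of that tradeoff (the functions here are hypothetical, not from the PR): a static argument is baked into the trace as a concrete Python value, so each new value triggers recompilation, while a traced argument compiles once per shape/dtype:

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=["n"])
def scale_static(x, n):
    # n is a concrete Python value here; it can drive Python control flow,
    # but calling with n=2 and then n=3 compiles twice
    return x * n

@jax.jit
def scale_traced(x, n):
    # n is a tracer here; new values reuse the same compiled program, but n
    # can no longer drive Python-level control flow (use lax.cond instead)
    return x * n

x = jnp.ones(4)
scale_static(x, 2)    # compile #1
scale_static(x, 3)    # compile #2
scale_traced(x, 2.0)  # compiles once
scale_traced(x, 3.0)  # reuses the compilation
```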
src/levanter/data/ul2r.py (Outdated)

```python
    input_masks: jnp.ndarray, segment_ids: jnp.ndarray
) -> jnp.ndarray:
    """
    - Input tokens can attend to all other input tokens bidirectionally within
```
it is probably convenient to have a version that is still causal b/c it will compose nicer with splash attention. (Or we'll need to write our own kernel...)
Ah interesting, so this function should generate the full PrefixLM mask instead of just the block-diagonal part for the prefixes? I was thinking we could take this mask and OR it with a causal mask, but I don't know much about splash attention; I'll look up the implementation in the meantime!
well i guess we'd want to AND it to get causal, but yeah i guess that would work
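A rough sketch of the OR composition discussed here (names hypothetical; per-example, single segment for simplicity): the block-diagonal part lets input positions attend to each other bidirectionally, and ORing with causal yields the full PrefixLM mask:

```python
import jax.numpy as jnp

def prefix_lm_mask(input_mask):
    """input_mask[i] is True where position i belongs to the (bidirectional) inputs."""
    L = input_mask.shape[0]
    causal = jnp.tril(jnp.ones((L, L), dtype=bool))
    # block-diagonal part: query and key both inside the input prefix
    prefix_block = input_mask[:, None] & input_mask[None, :]
    return causal | prefix_block

mask = prefix_lm_mask(jnp.array([True, True, False, False]))
# positions 0-1 (inputs) attend to each other bidirectionally;
# positions 2-3 (targets) remain strictly causal
```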
src/levanter/layers/attention.py (Outdated)

```python
segment_mask = _materialize_segment_mask(self.segment_ids, QPos, KPos, q_slice, k_slice)
mask = combine_masks_and(mask, segment_mask)

if self.prefix_mask is not None:
```
can you add code to raise if prefix_mask is encountered in splash_attention, flash_attention [which can probably handle this with a bit of work, but not a ton], or the nvidia path
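Something like this minimal guard could work (a sketch; it assumes the mask object carries the new `prefix_mask` field, and the helper name is a placeholder):

```python
def _check_no_prefix_mask(mask) -> None:
    """Reject masks this attention backend cannot honor (sketch only)."""
    prefix_mask = getattr(mask, "prefix_mask", None)
    if prefix_mask is not None:
        raise NotImplementedError(
            "prefix_mask is not supported in splash_attention/flash_attention "
            "or the NVIDIA path yet; use the reference implementation instead"
        )
```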
btw the whole plan looks good to me! forgot to say that!
Force-pushed 41cebc9 to c18f60d
Force-pushed 6eb67e8 to 0f4dbfc
Commits:
- pr: feedback
- wip: start work on Dataset
- wip: a little work on Dataset
- wip: out offsets actually will be different from in offsets
- wip: llm-generated tests pass
- wip: create_ul2r_example test passes@
- wip: Ul2rDataset tests pass!
- wip: didn't end up modifying that function
- wip: try make config run with python -m levanter.main.cache_dataset --config config/llama3_ul2r.yaml
- wip: make python -m levanter.main.cache_dataset --config config/llama3_ul2r.yaml run
- wip: try print ul2r-ed dataset (not working)
- wip: half-works! train_lm --data-only outputs data where some spans look right and other spans look like gibberish
- wip: fix some bugs, but there's still gibberish
- wip: fix last gibberish bug!! tokens look ok, now need to check loss/attn
- wip: autoformat
- wip: test that no new tokens were introduced (gibberish), take explicit QPos/KPos
Force-pushed 539ffe1 to 61f7b30
…hreshold); length calculation
…s were causing nans
Force-pushed df54427 to 0f7c870
…es from instead of 0
S-denoising is its own kind.

We encode task parameters for access by JAX as a tensor:
`task_params[task_idx] = [task_kind, task_token_id, mask_prob, mean_span_length]`
(a hypothetical sketch of this encoding follows after the list below)

- `to_ul2r_tokens`: dispatches to `to_ul2r_rx_tokens` (span corruption) or `to_ul2r_s_tokens`
- `ul2r_loss_mask`: creates the loss mask (loss computed only on outputs, not inputs/padding)
- `ul2r_block_diagonal_mask`: creates the prefix attention mask (inputs attend bidirectionally within a segment)
- `Ul2rDataset`: packs / reserves space using `GreedyPrepackedDataset`; denoises using `create_ul2r_example`
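The sketch below illustrates one plausible form of that table; the kind constants, token ids, and probabilities are invented for illustration, not the PR's actual values:

```python
import jax.numpy as jnp

# hypothetical task-kind constants and sentinel/task token ids
R_DENOISING, X_DENOISING, S_DENOISING = 0, 1, 2

task_params = jnp.array([
    # [task_kind, task_token_id, mask_prob, mean_span_length]
    [R_DENOISING, 32000, 0.15, 3.0],
    [X_DENOISING, 32001, 0.50, 32.0],
    [S_DENOISING, 32002, 0.25, 0.0],  # S-denoising ignores span length
])  # note: a single array stores everything as float; cast ids back as needed

# inside a jitted function, a traced task_idx can index the table directly:
def lookup(task_idx):
    kind, token_id, mask_prob, mean_len = task_params[task_idx]
    return kind, token_id, mask_prob, mean_len
```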
Pipeline

- `process_segment` rolls it to position 0, applies `to_ul2r_tokens`, creating `0...<task_token><denoised>...0...`
- OR these together: `task_token1 denoised1 task_token2 denoised2 ...`
- Generate an `LmExample` with `ul2r_loss_mask` and `ul2r_block_diagonal_mask` (ORed with the causal mask).

Data

- `doc1 doc2 ...` from TreeCache
- `doc1 doc2 0...` packed contiguously in `GreedyPrepackedDataset` with reserved space at the end
- `process_segment` (ul2r.py:749-781): `doc rest`
- `to_ul2r_tokens`: `task_token denoised 0...`
- `out_start`: `0... task_token denoised 0...`
- `task_token1 denoised1 task_token2 denoised2 ...`

JAX
Had to modify implementations to work in JIT-ed functions.
For example, the original `np.pad(np_rng.permutation(...))` pattern doesn't work when `num_items` and `num_segments` depend on example length.

- `random_segmentation_jax` and `random_spans_noise_mask_jax` won't produce identical results to their NumPy equivalents for the same key
- `noise_span_to_unique_sentinel` should behave identically (modulo padding)
- `to_ul2r_rx_tokens` corresponds to this code.
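To illustrate the kind of fixed-shape workaround involved, here is a sketch under assumed semantics (not the PR's actual `random_segmentation_jax`): choose `num_segments - 1` segment boundaries among the first `num_items - 1` gaps, with all shapes fixed at `padded_length` so `num_items` and `num_segments` can be traced values:

```python
import jax
import jax.numpy as jnp

def random_segmentation_sketch(key, num_items, num_segments, padded_length):
    # score every candidate boundary; invalid positions get +inf so they lose
    scores = jax.random.uniform(key, (padded_length,))
    valid = jnp.arange(padded_length) < (num_items - 1)
    scores = jnp.where(valid, scores, jnp.inf)
    # rank of each position; the (num_segments - 1) smallest become boundaries
    ranks = jnp.argsort(jnp.argsort(scores))
    starts_new_segment = ranks < (num_segments - 1)
    # item 0 always begins segment 0; a boundary after gap i starts at item i+1
    first_in_segment = jnp.concatenate(
        [jnp.zeros((1,), bool), starts_new_segment]
    )[:padded_length]
    return jnp.cumsum(first_in_segment)  # segment id for each (padded) item
```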
TODOs

- `cache_dataset` (debugging config errors right now)
- `KPos` as well as `Pos` (aka QPos?)
- `text.py` doesn't use it, we would just use it for the segment IDs in AttentionMask?
- `eos_id`
- `GreedyPrepacked` holding O(n) tensors for the entire dataset (iterate over chunks to trade space for time?)

Note: Not attached to any implementation details, names, etc. Feedback from JAX experts very welcome!