
Conversation

@jyc jyc (Contributor) commented Sep 12, 2025

  • Task types: R, X, S (the task mix is given in terms of task types). TODO: maybe clean this up and make the type either rx/s, then just call these task config names?
  • Task kinds: R- and X- denoising have the same implementation, just different parameters, so they have the same kind.
    S-denoising is its own kind.
    We encode task parameters for access by JAX as a tensor task_params[task_idx] = [task_kind, task_token_id, mask_prob, mean_span_length].
  • to_ul2r_tokens: Dispatches to to_ul2r_rx_tokens (span corruption) or to_ul2r_s_tokens.
  • ul2r_loss_mask: Creates loss mask (loss computed only on outputs, not inputs/padding)
  • ul2r_block_diagonal_mask: Creates prefix attention mask (inputs attend bidirectionally within segment)
  • Ul2rDataset: Packs / reserves space using GreedyPrepackedDataset; denoises using create_ul2r_example
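
A rough sketch of what the two masks compute (hypothetical standalone versions, not the actual `ul2r_loss_mask`/`ul2r_block_diagonal_mask` signatures):

```python
import jax.numpy as jnp

def loss_mask_sketch(is_input, is_padding):
    # Loss is computed only on output tokens: not inputs, not padding.
    return jnp.logical_not(is_input | is_padding)

def block_diagonal_prefix_mask_sketch(is_input, segment_ids):
    # Positions (q, k) may attend bidirectionally iff both are input (prefix)
    # tokens of the same packed segment.
    same_segment = segment_ids[:, None] == segment_ids[None, :]
    both_inputs = is_input[:, None] & is_input[None, :]
    return same_segment & both_inputs

def full_mask_sketch(is_input, segment_ids):
    # The block-diagonal prefix mask is ORed with a causal mask to get the
    # final attention mask.
    n = segment_ids.shape[0]
    causal = jnp.tril(jnp.ones((n, n), dtype=bool))
    return causal | block_diagonal_prefix_mask_sketch(is_input, segment_ids)
```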

Pipeline

  1. Compute denoising lengths
  2. Reserve space in GreedyPrepackedDataset: Documents remain contiguous; padding reserved at end only.
  3. Create denoising examples in loop: For each segment, process_segment rolls it to position 0, applies to_ul2r_tokens creating 0...<task_token><denoised>...0....
    OR these together: task_token1 denoised1 task_token2 denoised2 ....
    Generate LmExample with ul2r_loss_mask and ul2r_block_diagonal_mask (ORed with causal mask).

Data

  1. (NumPy) doc1 doc2 ... from TreeCache
  2. (NumPy) doc1 doc2 0... packed contiguously in GreedyPrepackedDataset with reserved space at end
  3. (JAX) For each segment i, process_segment (ul2r.py:749-781):
    • Rolled input: doc rest
    • After to_ul2r_tokens: task_token denoised 0...
    • Rolled to out_start: 0... task_token denoised 0...
  4. (JAX) OR all segments together: task_token1 denoised1 task_token2 denoised2 ...
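
The roll-and-OR combination in steps 3-4 can be sketched like this (simplified stand-ins for `process_segment` and the packing step, assuming each per-segment buffer is zero outside its own span):

```python
import jax.numpy as jnp

def place_segment_sketch(denoised, out_start):
    # `denoised` holds <task_token><denoised tokens> at position 0 and zeros
    # everywhere else; roll it so the span starts at out_start.
    return jnp.roll(denoised, out_start)

def pack_sketch(rolled):
    # rolled: [num_segments, length]. Spans don't overlap and everything else
    # is zero, so bitwise OR (equivalently sum) packs all segments together.
    out = jnp.zeros(rolled.shape[1], dtype=rolled.dtype)
    for seg in rolled:  # unrolled at trace time for a fixed num_segments
        out = out | seg
    return out
```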

JAX

Had to modify implementations to work in JIT-ed functions.
For example, the original np.pad(np_rng.permutation(...)) pattern doesn't work when num_items and num_segments depend on example length.

  • random_segmentation_jax and random_spans_noise_mask_jax won't produce identical results to their NumPy equivalents for the same key
  • noise_span_to_unique_sentinel should behave identically (modulo padding)
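
Concretely, the rewrite has to replace data-dependent shapes with a static `padded_length` plus masking, since all shapes must be known at trace time under `jit`. A sketch of one way to express the fixed-shape trick (hypothetical; not the actual `random_segmentation_jax`, which as noted also won't match the NumPy RNG stream):

```python
import jax
import jax.numpy as jnp

def random_segmentation_sketch(key, num_items, num_segments, padded_length: int):
    # Fixed-shape analogue of the np.pad(np_rng.permutation(...)) pattern:
    # pick (num_segments - 1) random boundaries among the first (num_items - 1)
    # positions, then cumsum boundary markers to get per-item segment ids.
    idx = jnp.arange(padded_length)
    scores = jax.random.uniform(key, (padded_length,))
    # Positions at or beyond num_items - 1 can never be boundaries.
    scores = jnp.where(idx < num_items - 1, scores, jnp.inf)
    # Rank of each score; the (num_segments - 1) smallest become boundaries.
    ranks = jnp.argsort(jnp.argsort(scores))
    is_boundary = ranks < num_segments - 1
    # A boundary after item b means item b + 1 starts a new segment.
    starts_new = jnp.roll(is_boundary, 1).at[0].set(False)
    segment_ids = jnp.cumsum(starts_new)
    # Mark padding positions with -1.
    return jnp.where(idx < num_items, segment_ids, -1)
```

Only `padded_length` affects shapes, so `num_items` and `num_segments` can be traced values.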

to_ul2r_rx_tokens corresponds to this code.

TODOs

  • Try generating some data using cache_dataset (debugging config errors right now)
  • Do we need to pass in KPos as well as Pos (aka QPos)? text.py doesn't use it; we would just use it for the segment IDs in AttentionMask.
  • Do we need eos_id?
  • Do we need to make sure we don't mask out bos/eos_id, or have the prefix/continuation consist entirely of them?

  • Make prefix attention work with splash attention
  • Implement versions both with and without PrefixLM; disable better attention implementations when prefix is enabled (from @dlwh)
  • Later: avoid GreedyPrepacked holding O(n) tensors for entire dataset (iterate over chunks to trade space for time?)

Note: Not attached to any implementation details, names, etc. Feedback from JAX experts very welcome!

@dlwh dlwh requested a review from Copilot September 15, 2025 04:24
dlwh commented Sep 15, 2025

@codex can you look too

@chatgpt-codex-connector replied:

> @codex can you look too

For now, I can only help with PRs you've created.

Copilot AI left a comment

Pull Request Overview

This PR introduces JIT-compiled functions for UL2R (Unified Language Learner 2 Revised) preprocessing, implementing the R/X-denoising training methodology. The implementation provides JAX-native versions of functions from a previous NumPy implementation to enable efficient batched processing.

  • Adds noise span generation, segmentation, and sentinel token replacement functions
  • Implements UL2R-specific attention masks for PrefixLM-style bidirectional attention on inputs
  • Extends the attention system to support prefix masks that can be combined with causal attention

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| src/levanter/data/ul2r.py | Core UL2R functions including noise mask generation, token replacement, and attention mask creation |
| src/levanter/layers/attention.py | Adds prefix_mask support to the AttentionMask class for UL2R-style attention patterns |
| tests/test_ul2r.py | Comprehensive test suite covering all UL2R functions with edge cases and determinism checks |
| tests/test_attention.py | Integration test for prefix mask functionality with AttentionMask |


```python
        f"Too many noise spans: {max_segments} > {len(sentinel_tokens)}",
    ),
    errors=checkify.index_checks,
)()
```
Copilot AI commented Sep 15, 2025

The checkify result is assigned but err is never used to actually check or handle the error. This means the check will be computed but any violations will be silently ignored.

Suggested change

```diff
-)()
+)()
+if err is not None:
+    raise err
```

@jyc (author) replied:

ah interesting, I guess I can ignore the return, but then unless the caller wraps this function with checkify.checkify, nothing happens? https://docs.jax.dev/en/latest/_autosummary/jax.experimental.checkify.check.html

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
    checkify.check(x > 0, "{x} needs to be positive!", x=x)
    return 1 / x

checked_f = checkify.checkify(f)
err, out = jax.jit(checked_f)(-3.)
err.throw()
# Traceback (most recent call last):
#   ...
# jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
```

I don't see it used anywhere else so maybe not worth it, but will just remove err value for now.

dlwh (Member) replied:

i don't use checkify so i dunno. i use eqx.error_if

Comment on lines 231 to 232
```python
input_len = jnp.argmax(inputs == pad_token_id)
target_len = jnp.argmax(targets == pad_token_id)
```
Copilot AI commented Sep 15, 2025

Using argmax to find padding will return 0 if no padding tokens exist (when the entire array is non-padding), which would incorrectly indicate zero length. Consider using jnp.where(inputs == pad_token_id, jnp.arange(len(inputs)), len(inputs)).min() or handle the case where argmax returns 0 but the first element isn't actually padding.

Suggested change

```diff
-input_len = jnp.argmax(inputs == pad_token_id)
-target_len = jnp.argmax(targets == pad_token_id)
+input_len = jnp.where(inputs == pad_token_id, indices, padded_length).min()
+target_len = jnp.where(targets == pad_token_id, indices, padded_length).min()
```

@jyc (author) replied:

huh, nice!
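
Self-contained, the suggested pattern looks something like this (a sketch; `padded_length` doubles as the "no padding found" answer):

```python
import jax.numpy as jnp

def true_length(tokens, pad_token_id):
    # Index of the first padding token, or len(tokens) if there is none.
    # Plain argmax would incorrectly return 0 for a fully unpadded array.
    padded_length = tokens.shape[0]
    indices = jnp.arange(padded_length)
    return jnp.where(tokens == pad_token_id, indices, padded_length).min()
```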

Comment on lines 412 to 413
```python
assert bool(has_sentinel_inputs), "Inputs should contain at least one sentinel token"
assert bool(has_sentinel_outputs), "Outputs should contain at least one sentinel token"
```
Copilot AI commented Sep 15, 2025

[nitpick] The bool() conversion is unnecessary since assert already evaluates the expression as a boolean. You can simplify to assert has_sentinel_inputs and assert has_sentinel_outputs.

Suggested change

```diff
-assert bool(has_sentinel_inputs), "Inputs should contain at least one sentinel token"
-assert bool(has_sentinel_outputs), "Outputs should contain at least one sentinel token"
+assert has_sentinel_inputs, "Inputs should contain at least one sentinel token"
+assert has_sentinel_outputs, "Outputs should contain at least one sentinel token"
```

@dlwh dlwh (Member) left a comment

I think this looks good! We need to have a version both with and without prefix probably, and not allow the better attention implementations when it's enabled (for now)...

It's probably not too hard to add those though.

I tried to check against my memory of the algorithms and they look right but it's likely i missed something



```python
@functools.partial(
    jax.jit, static_argnames=["mean_noise_span_length", "random_roll", "padded_length"]
```
dlwh (Member) commented:

i don't think there's any reason for mean_noise_span_length to be static and if you're going to use cond anyway, no reason for random_roll to be static



```python
@functools.partial(jax.jit, static_argnames=["padded_length"])
def random_segmentation_jax(
```
dlwh (Member) commented:

can you link to references for these (t5x or whatever)

```python
    return mask


@functools.partial(jax.jit, static_argnames=["pad_token_id"])
```
dlwh (Member) commented:

probably fine for pad token to not be static

```python
    pad_token_id: int,
) -> jnp.ndarray:
    """
    Replace each run of consecutive noise tokens with a different sentinel.
```
dlwh (Member) commented:

maybe document that length is the true length?

Comment on lines 186 to 188
```python
        "pad_token_id",
        "mean_noise_span_length",
        "random_roll",
```
dlwh (Member) commented:

don't think these need to be static

@jyc (author) replied Sep 16, 2025:

Haha, thanks for pointing all these out, good to learn. I guess for arguments that will have the same shape in different invocations it's generally not worth making them static unless that somehow lets the compiler optimize control flow? (edit: I think I'm wrong)

@jyc (author) replied:

oh hm yea I guess curious about the intuition around static! The docs say:

Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not array-like or containers thereof must be marked as static.

but I definitely don't have the full context

```python
    input_masks: jnp.ndarray, segment_ids: jnp.ndarray
) -> jnp.ndarray:
    """
    - Input tokens can attend to all other input tokens bidirectionally within
```
dlwh (Member) commented:

it is probably convenient to have a version that is still causal b/c it will compose nicer with splash attention. (Or we'll need to write our own kernel...)

@jyc (author) replied:

Ah interesting, so this function should generate the full PrefixLM mask instead of just the block-diagonal part for the prefixes? I was thinking we could take this mask and OR it with a causal mask, but I don't know much about splash attention; will look up the implementation in the meantime!

dlwh (Member) replied:

well i guess we'd want to AND it to get causal but yeah i guess that would work

```python
        segment_mask = _materialize_segment_mask(self.segment_ids, QPos, KPos, q_slice, k_slice)
        mask = combine_masks_and(mask, segment_mask)

        if self.prefix_mask is not None:
```
dlwh (Member) commented:

can you add code to raise if prefix_mask is encountered in splash_attention, flash_attention [which can probably handle this with a bit of work, but not a ton], or the nvidia path


dlwh commented Sep 16, 2025

btw the whole plan looks good to me! forgot to say that!

@jyc jyc force-pushed the jyc-ul2r-denoising-jax branch from 41cebc9 to c18f60d Compare September 18, 2025 03:42
@jyc jyc force-pushed the jyc-ul2r-denoising-jax branch from 6eb67e8 to 0f4dbfc Compare October 2, 2025 05:34
pr: feedback

wip: start work on Dataset

wip: a little work on Dataset

wip: out offsets actually will be different from in offsets

wip: llm-generated tests pass

wip: create_ul2r_example test passes

wip: Ul2rDataset tests pass!

wip: didn't end up modifying that function

wip: try make config run with python -m levanter.main.cache_dataset --config config/llama3_ul2r.yaml

wip: make python -m levanter.main.cache_dataset --config config/llama3_ul2r.yaml run

wip: try print ul2r-ed dataset (not working)

wip: half-works! train_lm --data-only outputs data where some spans look right and other spans look like gibberish

wip: fix some bugs, but there's still gibberish

wip: fix last gibberish bug!! tokens look ok, now need to check loss/attn

wip: autoformat

wip: test that no new tokens were introduced (gibberish), take explicit QPos/KPos
@jyc jyc force-pushed the jyc-ul2r-denoising-jax branch from 539ffe1 to 61f7b30 Compare October 20, 2025 20:58
@jyc jyc force-pushed the jyc-ul2r-denoising-jax branch from df54427 to 0f7c870 Compare November 5, 2025 01:29