
Commit 539ffe1

wip: test that no new tokens were introduced (gibberish), take explicit QPos/KPos
1 parent 9e43e60 commit 539ffe1

2 files changed: +68 -38 lines changed

src/levanter/data/ul2r.py

Lines changed: 20 additions & 15 deletions
@@ -637,17 +637,20 @@ def s_length() -> jnp.ndarray:
     return jax.lax.cond(task_kind == RX_TASK_KIND, rx_length, s_length)
 
 
-@functools.partial(jax.jit, static_argnames=("max_segments_per_example", "Pos"))
+@functools.partial(jax.jit, static_argnames=("max_segments_per_example", "QPos", "KPos"))
 def create_ul2r_example(
     key: PRNGKeyArray,
     task_params: jnp.ndarray,
     task_indices: jnp.ndarray,
     max_segments_per_example: int,
-    Pos: Axis,
+    QPos: Axis,
+    KPos: Axis,
     pad_token_id: int,
     tokens: hax.NamedArray,
     segment_ids: hax.NamedArray,
 ) -> LmExample:
+    jax.debug.print("create_ul2r_example start")
+
     # TODO Use NamedArrays more idiomatically
     # `unique_seg_ids = [3, 4, ..., -1, ...]`
     # Valid segment IDs come first, padded with -1.
@@ -710,7 +713,7 @@ def process_segment(key: PRNGKeyArray, id: int) -> tuple[jnp.ndarray, jnp.ndarra
         out_start = jnp.squeeze(out_starts[idx])
 
         segment = jnp.roll(tokens.array, -in_start)
-        inputs_len, denoising_tokens = to_ul2r_tokens(key, task_params[task_idx], segment, in_length, Pos.size)
+        inputs_len, denoising_tokens = to_ul2r_tokens(key, task_params[task_idx], segment, in_length, QPos.size)
 
         n_tokens = tokens.array.shape[0]
         input_mask = jnp.arange(n_tokens) < inputs_len
@@ -750,18 +753,17 @@ def loop(
     # TODO GreedyPrepackedDataset pads w/ zeros so can we end up with two
     # padding token IDs?
     loss_mask = ul2r_loss_mask(input_mask, out_seg_ids, denoising_tokens, pad_token_id)
-    loss_mask = hax.named(loss_mask, Pos)
+    loss_mask = hax.named(loss_mask, QPos)
 
-    KPos = Pos.alias("KPos")
     attn_mask = AttentionMask(
         is_causal=True,
         is_prefix=True,
-        input_mask=hax.named(input_mask, [Pos]),
-        segment_ids=(hax.named(out_seg_ids, [Pos]), hax.named(out_seg_ids, [KPos])),
+        input_mask=hax.named(input_mask, [QPos]),
+        segment_ids=(hax.named(out_seg_ids, [QPos]), hax.named(out_seg_ids, [KPos])),
     )
 
-    denoising_tokens = hax.named(denoising_tokens, Pos)
-    out_seg_ids = hax.named(out_seg_ids, Pos)
+    denoising_tokens = hax.named(denoising_tokens, QPos)
+    out_seg_ids = hax.named(out_seg_ids, QPos)
     return LmExample(tokens=denoising_tokens, loss_mask=loss_mask, attn_mask=attn_mask)
 
 
@@ -773,7 +775,8 @@ class Ul2rDataset(MappedAsyncDataset[tuple[TokenizedDict, TokenizedDict], LmExam
     def __init__(
         self,
         cache: TreeCache[TokenizedDict],
-        Pos: Axis,
+        QPos: Axis,
+        KPos: Axis,
         task_configs: typing.Dict[str, DenoisingConfig],
         task_probs: Dict[str, float],
         key: PRNGKeyArray,
@@ -826,14 +829,15 @@ def _compute_length(task_idx: jnp.ndarray, length: jnp.ndarray) -> int:
         # packed leaves and the second has the segment ids
         self.packed: GreedyPrepackedDataset[TokenizedDict] = GreedyPrepackedDataset(
             cache.store.tree,
-            Pos.size,
+            QPos.size,
             max_segments_per_example=max_segments_per_example,
             slice_strategy=slice_strategy,
             packing_lengths=out_lengths,
             # Reserve space for UL2R; denoising examples increase in length.
             pad_with_zeros=True,
         )
-        self.Pos = Pos
+        self.QPos = QPos
+        self.KPos = KPos
         self.pad_token_id = pad_token_id
 
         sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])
@@ -842,14 +846,15 @@ def _compute_length(task_idx: jnp.ndarray, length: jnp.ndarray) -> int:
         @functools.partial(eqx.filter_jit, out_shardings=sharding)
         def _create_lm_example(e: tuple[TokenizedDict, TokenizedDict]) -> LmExample:
             example, seg_ids = e
-            tokens = hax.named(example["input_ids"], self.Pos)
-            segment_ids = hax.named(seg_ids["input_ids"], self.Pos)
+            tokens = hax.named(example["input_ids"], self.QPos)
+            segment_ids = hax.named(seg_ids["input_ids"], self.QPos)
             return create_ul2r_example(
                 key,
                 task_params,
                 task_indices,
                 max_segments_per_example,
-                self.Pos,
+                self.QPos,
+                self.KPos,
                 self.pad_token_id,
                 tokens,
                 segment_ids,
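For context, a minimal sketch of how the new signature is meant to be driven with separate query/key axes, instead of aliasing KPos from a single Pos inside the function. The axis sizes and toy arrays below are illustrative assumptions; the AttentionMask fields and the materialize(QPos, KPos) call mirror what the diff itself uses on this branch.

    import jax.numpy as jnp
    import haliax as hax
    from haliax import Axis
    from levanter.layers.attention import AttentionMask

    # Separate query and key position axes (previously derived via Pos.alias("KPos")).
    QPos = Axis("QPos", 8)
    KPos = Axis("KPos", 8)

    # Toy prefix structure: the first 3 positions are bidirectional "input" tokens,
    # and everything lives in a single segment.
    input_mask = jnp.arange(QPos.size) < 3
    seg_ids = jnp.zeros(QPos.size, dtype=jnp.int32)

    attn_mask = AttentionMask(
        is_causal=True,
        is_prefix=True,
        input_mask=hax.named(input_mask, [QPos]),
        segment_ids=(hax.named(seg_ids, [QPos]), hax.named(seg_ids, [KPos])),
    )

    # Materializing the mask now names the query and key axes explicitly.
    dense = attn_mask.materialize(QPos, KPos)  # NamedArray over (QPos, KPos)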

tests/test_ul2r.py

Lines changed: 48 additions & 23 deletions
@@ -2,7 +2,7 @@
 from pathlib import Path
 import tempfile
 import typing
-from haliax import Axis
+from haliax import Axis, NamedArray
 from haliax.nn import hax
 import jax
 import jax.numpy as jnp
@@ -29,6 +29,9 @@
     XDenoisingConfig,
     SDenoisingConfig,
     Ul2rDataset,
+    R_TASK_TOKEN_ID,
+    X_TASK_TOKEN_ID,
+    S_TASK_TOKEN_ID,
 )
 from levanter.layers.attention import AttentionMask
 from levanter.models.lm_model import LmExample
@@ -398,7 +401,8 @@ def test_to_ul2r_s_tokens():
 
 
 def test_create_ul2r_example():
-    Pos = Axis("position", 128)
+    QPos = Axis("QPos", 128)
+    KPos = Axis("KPos", 128)
     pad_token_id = 0
     max_segments_per_example = 8
 
@@ -411,7 +415,7 @@ def test_create_ul2r_example():
             jnp.zeros(105, dtype=jnp.int32),  # padding
         ]
     )
-    tokens = hax.named(tokens, Pos)
+    tokens = hax.named(tokens, QPos)
 
     segment_ids = jnp.concatenate(
         [
@@ -421,7 +425,7 @@ def test_create_ul2r_example():
             jnp.full(105, -1),
         ]
     )
-    segment_ids = hax.named(segment_ids, Pos)
+    segment_ids = hax.named(segment_ids, QPos)
 
     task_configs = [
         RDenoisingConfig(mask_prob=0.15, mean_span_length=3.0),
@@ -438,15 +442,16 @@ def test_create_ul2r_example():
         task_params,
         task_indices,
         max_segments_per_example,
-        Pos,
+        QPos,
+        KPos,
         pad_token_id,
         tokens,
         segment_ids,
     )
 
     # Basic smoke checks
-    assert example.tokens.array.shape == (Pos.size,)
-    assert example.loss_mask.array.shape == (Pos.size,)
+    assert example.tokens.array.shape == (QPos.size,)
+    assert example.loss_mask.array.shape == (QPos.size,)
     assert example.attn_mask.is_causal
 
     # Should contain sentinel tokens after denoising
@@ -517,7 +522,8 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
     cache = typing.cast(TreeCache[TokenizedDict], cache)
 
     # Test Ul2rDataset
-    Pos = hax.Axis("position", 128)
+    QPos = hax.Axis("QPos", 128)
+    KPos = hax.Axis("KPos", 128)
     task_configs = {
         "r": RDenoisingConfig(mask_prob=0.15, mean_span_length=3.0),
         "x": XDenoisingConfig(mask_prob=0.5, mean_span_length=3.0),
@@ -526,7 +532,8 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
 
     dataset = Ul2rDataset(
         cache=cache,
-        Pos=Pos,
+        QPos=QPos,
+        KPos=KPos,
         task_configs=task_configs,
         task_probs={"r": 0.33, "x": 0.33, "s": 0.34},
         key=jax.random.PRNGKey(123),
@@ -541,8 +548,8 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
 
     # Structure checks
     assert isinstance(ex, LmExample)
-    assert ex.tokens.axes == (Pos,)
-    assert ex.loss_mask.axes == (Pos,)
+    assert ex.tokens.axes == (QPos,)
+    assert ex.loss_mask.axes == (QPos,)
     assert isinstance(ex.attn_mask, AttentionMask)
     assert ex.attn_mask.is_causal
 
@@ -557,22 +564,40 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
     assert not jnp.any(ex.loss_mask.array & ~non_padding)  # No loss on padding
     assert jnp.any(jnp.isin(ex.tokens.array, SENTINEL_TOKEN_IDS))  # Has sentinels from denoising
 
+    # Collect all original input tokens from the cache
+    original_tokens = set()
+    for item in cache_sync:
+        original_tokens.update(int(t) for t in item["input_ids"] if t != pad_id)
+
+    # Check that all output tokens (except pad, sentinels, and task tokens) were present in the input
+    # This helps verify we're not creating gibberish by overlapping spans
+    ul2r_special_tokens = set(SENTINEL_TOKEN_IDS.tolist()) | {
+        R_TASK_TOKEN_ID,
+        X_TASK_TOKEN_ID,
+        S_TASK_TOKEN_ID,
+        pad_id,
+    }
+    allowed_tokens = original_tokens | ul2r_special_tokens
+    output_tokens = set(int(t) for t in ex.tokens.array)
+    unexpected_tokens = output_tokens - allowed_tokens
+    assert len(unexpected_tokens) == 0, f"Found unexpected tokens not in input: {unexpected_tokens}"
+
     # Attention mask checks
-    if ex.attn_mask.prefix_mask is not None:
-        assert ex.attn_mask.prefix_mask.array.shape == (Pos.size, Pos.size)
-        # Materialize full attention mask (causal + prefix)
-        materialized = ex.attn_mask.materialize(Pos, Pos)
-        assert materialized is not None
-        # Diagonal should be True for all non-padding (tokens attend to themselves)
-        diag = jnp.diag(materialized.array)
-        assert jnp.all(diag[non_padding])
-        # Some off-diagonal should be True (bidirectional attention on input positions)
-        off_diag_sum = jnp.sum(materialized.array) - jnp.sum(diag)
-        assert off_diag_sum > 0, "Expected some bidirectional attention for input positions"
+    input_mask = typing.cast(NamedArray, ex.attn_mask.input_mask)
+    assert input_mask.array.shape == (QPos.size,)
+    # Materialize full attention mask (causal + prefix)
+    materialized = ex.attn_mask.materialize(QPos, KPos)
+    assert materialized is not None
+    # Diagonal should be True for all non-padding (tokens attend to themselves)
+    diag = jnp.diag(materialized.array)
+    assert jnp.all(diag[non_padding])
+    # Some off-diagonal should be True (bidirectional attention on input positions)
+    off_diag_sum = jnp.sum(materialized.array) - jnp.sum(diag)
+    assert off_diag_sum > 0, "Expected some bidirectional attention for input positions"
 
     # Check consistency across multiple examples
     for ex_i in [dataset_sync[i] for i in range(min(3, len(dataset_sync)))]:
-        assert ex_i.tokens.axes == (Pos,) and ex_i.loss_mask.axes == (Pos,)
+        assert ex_i.tokens.axes == (QPos,) and ex_i.loss_mask.axes == (QPos,)
         non_pad_i = jnp.sum(ex_i.tokens.array != pad_id)
         loss_i = jnp.sum(ex_i.loss_mask.array)
         assert 0 < loss_i < non_pad_i
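The core of the new "no gibberish" test above is a set-containment invariant: every token id emitted by the UL2R transform must either appear in the original input or be one of the known special ids (sentinels, task tokens, padding). A self-contained sketch of that invariant follows; the helper name and the numeric ids are made up for illustration and are not part of the diff.

    import jax.numpy as jnp

    def unexpected_token_ids(input_ids, output_ids, special_ids):
        """Return output token ids that are neither original inputs nor known special ids."""
        allowed = set(int(t) for t in input_ids) | set(special_ids)
        return set(int(t) for t in output_ids) - allowed

    # Toy example: 32000-32002 stand in for sentinel/task token ids, 0 for padding.
    original = jnp.array([5, 7, 7, 9, 11])
    denoised = jnp.array([32001, 5, 7, 32000, 9, 11, 0])
    assert unexpected_token_ids(original, denoised, {32000, 32001, 32002, 0}) == set()

    # A corrupted span boundary that fabricates a new id (here 42) would be caught:
    garbled = jnp.array([32001, 5, 42, 32000, 9, 0])
    assert unexpected_token_ids(original, garbled, {32000, 32001, 32002, 0}) == {42}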
