@@ -2,7 +2,7 @@
 from pathlib import Path
 import tempfile
 import typing
-from haliax import Axis
+from haliax import Axis, NamedArray
 import haliax as hax
 import jax
 import jax.numpy as jnp
@@ -29,6 +29,9 @@
     XDenoisingConfig,
     SDenoisingConfig,
     Ul2rDataset,
+    R_TASK_TOKEN_ID,
+    X_TASK_TOKEN_ID,
+    S_TASK_TOKEN_ID,
 )
 from levanter.layers.attention import AttentionMask
 from levanter.models.lm_model import LmExample
@@ -398,7 +401,8 @@ def test_to_ul2r_s_tokens():
 
 
 def test_create_ul2r_example():
-    Pos = Axis("position", 128)
+    QPos = Axis("QPos", 128)
+    KPos = Axis("KPos", 128)
     pad_token_id = 0
     max_segments_per_example = 8
 
@@ -411,7 +415,7 @@ def test_create_ul2r_example():
             jnp.zeros(105, dtype=jnp.int32),  # padding
         ]
     )
-    tokens = hax.named(tokens, Pos)
+    tokens = hax.named(tokens, QPos)
 
     segment_ids = jnp.concatenate(
         [
@@ -421,7 +425,7 @@ def test_create_ul2r_example():
             jnp.full(105, -1),
         ]
     )
-    segment_ids = hax.named(segment_ids, Pos)
+    segment_ids = hax.named(segment_ids, QPos)
 
     task_configs = [
         RDenoisingConfig(mask_prob=0.15, mean_span_length=3.0),
@@ -438,15 +442,16 @@ def test_create_ul2r_example():
         task_params,
         task_indices,
         max_segments_per_example,
-        Pos,
+        QPos,
+        KPos,
         pad_token_id,
         tokens,
         segment_ids,
     )
 
     # Basic smoke checks
-    assert example.tokens.array.shape == (Pos.size,)
-    assert example.loss_mask.array.shape == (Pos.size,)
+    assert example.tokens.array.shape == (QPos.size,)
+    assert example.loss_mask.array.shape == (QPos.size,)
     assert example.attn_mask.is_causal
 
     # Should contain sentinel tokens after denoising
@@ -517,7 +522,8 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
     cache = typing.cast(TreeCache[TokenizedDict], cache)
 
     # Test Ul2rDataset
-    Pos = hax.Axis("position", 128)
+    QPos = hax.Axis("QPos", 128)
+    KPos = hax.Axis("KPos", 128)
     task_configs = {
         "r": RDenoisingConfig(mask_prob=0.15, mean_span_length=3.0),
         "x": XDenoisingConfig(mask_prob=0.5, mean_span_length=3.0),
@@ -526,7 +532,8 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
 
     dataset = Ul2rDataset(
         cache=cache,
-        Pos=Pos,
+        QPos=QPos,
+        KPos=KPos,
         task_configs=task_configs,
         task_probs={"r": 0.33, "x": 0.33, "s": 0.34},
         key=jax.random.PRNGKey(123),
@@ -541,8 +548,8 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
 
     # Structure checks
     assert isinstance(ex, LmExample)
-    assert ex.tokens.axes == (Pos,)
-    assert ex.loss_mask.axes == (Pos,)
+    assert ex.tokens.axes == (QPos,)
+    assert ex.loss_mask.axes == (QPos,)
     assert isinstance(ex.attn_mask, AttentionMask)
     assert ex.attn_mask.is_causal
 
@@ -557,22 +564,40 @@ def test_ul2r_dataset_build(dummy_text_data, hf_tokenizer):
     assert not jnp.any(ex.loss_mask.array & ~non_padding)  # No loss on padding
     assert jnp.any(jnp.isin(ex.tokens.array, SENTINEL_TOKEN_IDS))  # Has sentinels from denoising
 
+    # Collect all original input tokens from the cache
+    original_tokens = set()
+    for item in cache_sync:
+        original_tokens.update(int(t) for t in item["input_ids"] if t != pad_id)
+
+    # Check that all output tokens (except pad, sentinels, and task tokens) were present in the input
+    # This helps verify we're not creating gibberish by overlapping spans
+    ul2r_special_tokens = set(SENTINEL_TOKEN_IDS.tolist()) | {
+        R_TASK_TOKEN_ID,
+        X_TASK_TOKEN_ID,
+        S_TASK_TOKEN_ID,
+        pad_id,
+    }
+    allowed_tokens = original_tokens | ul2r_special_tokens
+    output_tokens = set(int(t) for t in ex.tokens.array)
+    unexpected_tokens = output_tokens - allowed_tokens
+    assert len(unexpected_tokens) == 0, f"Found unexpected tokens not in input: {unexpected_tokens}"
+
     # Attention mask checks
-    if ex.attn_mask.prefix_mask is not None:
-        assert ex.attn_mask.prefix_mask.array.shape == (Pos.size, Pos.size)
-        # Materialize full attention mask (causal + prefix)
-        materialized = ex.attn_mask.materialize(Pos, Pos)
-        assert materialized is not None
-        # Diagonal should be True for all non-padding (tokens attend to themselves)
-        diag = jnp.diag(materialized.array)
-        assert jnp.all(diag[non_padding])
-        # Some off-diagonal should be True (bidirectional attention on input positions)
-        off_diag_sum = jnp.sum(materialized.array) - jnp.sum(diag)
-        assert off_diag_sum > 0, "Expected some bidirectional attention for input positions"
+    input_mask = typing.cast(NamedArray, ex.attn_mask.input_mask)
+    assert input_mask.array.shape == (QPos.size,)
+    # Materialize full attention mask (causal + prefix)
+    materialized = ex.attn_mask.materialize(QPos, KPos)
+    assert materialized is not None
+    # Diagonal should be True for all non-padding (tokens attend to themselves)
+    diag = jnp.diag(materialized.array)
+    assert jnp.all(diag[non_padding])
+    # Some off-diagonal should be True (bidirectional attention on input positions)
+    off_diag_sum = jnp.sum(materialized.array) - jnp.sum(diag)
+    assert off_diag_sum > 0, "Expected some bidirectional attention for input positions"
 
     # Check consistency across multiple examples
     for ex_i in [dataset_sync[i] for i in range(min(3, len(dataset_sync)))]:
-        assert ex_i.tokens.axes == (Pos,) and ex_i.loss_mask.axes == (Pos,)
+        assert ex_i.tokens.axes == (QPos,) and ex_i.loss_mask.axes == (QPos,)
         non_pad_i = jnp.sum(ex_i.tokens.array != pad_id)
         loss_i = jnp.sum(ex_i.loss_mask.array)
         assert 0 < loss_i < non_pad_i