 
 This achieves about a 90% "real token" rate, compared to roughly 10% without packing.
 """
+
 import asyncio
 from dataclasses import dataclass
 from typing import Iterable, Iterator, Literal, Optional, Sequence, TypeVar
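
A minimal sketch of how one might measure that utilization for a packed batch, assuming padded positions are marked with segment id -1 (as the pack() method below does); `seg` is an illustrative NumPy array, not part of this module:

    import numpy as np

    def real_token_fraction(segment_ids: np.ndarray) -> float:
        # Positions whose segment id is not -1 carry real (non-padding) tokens.
        return float(np.count_nonzero(segment_ids != -1) / segment_ids.size)

    # e.g. two documents of 900 and 3000 tokens packed into one 4096-token example:
    seg = np.concatenate([np.zeros(900), np.ones(3000), -np.ones(196)])
    print(real_token_fraction(seg))  # ~0.95
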
@@ -62,9 +63,17 @@ def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int):
         assert pad_token is not None, "pad_token must be set"
 
     def can_pack(self, ids: list[int]) -> bool:
-        return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size
+        return (
+            len(ids) + len(self._ids) <= self.Pos.size
+            and self.num_segments < self.max_pack_size
+        )
 
-    def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None):
+    def add_example(
+        self,
+        ids: list[int],
+        loss_mask: list[int] | np.ndarray,
+        segment_id: int | None = None,
+    ):
         if len(ids) != len(loss_mask):
             raise ValueError("ids and loss_mask must have the same length")
 
@@ -90,7 +99,9 @@ def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment
     def pack(self) -> LmExample:
         ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids))
 
-        segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids))
+        segment_ids = self._segment_ids + [-1] * (
+            self.Pos.size - len(self._segment_ids)
+        )
 
         loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask))
 
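
A short usage sketch of the greedy packer shown above, assuming the `SequencePacker` constructor from this diff (`Pos` axis, `max_pack_size`, `pad_token`); `documents`, an iterable of `(ids, loss_mask)` pairs whose documents each fit within `Pos.size`, is hypothetical:

    import haliax as hax

    Pos = hax.Axis("position", 4096)
    packer = SequencePacker(Pos, max_pack_size=64, pad_token=0)

    packed = []
    for ids, loss_mask in documents:
        if not packer.can_pack(ids):
            packed.append(packer.pack())  # flush the full packer as one LmExample
            packer = SequencePacker(Pos, max_pack_size=64, pad_token=0)
        packer.add_example(ids, loss_mask)
    packed.append(packer.pack())  # flush the final, partially filled packer
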
@@ -173,7 +184,9 @@ def per_segment_loss(
     This code is designed to run in a jit-compiled function, meaning we have to be careful of shapes
     """
 
-    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"
+    assert (
+        packed_example.attn_mask.segment_ids is not None
+    ), "segment_ids must be set in the AttentionMask"
 
     segment_ids = packed_example.attn_mask.segment_ids
     assert (
@@ -200,7 +213,9 @@ def per_segment_loss(
 def _unique_segment_ids(max_Segments, segment_ids):
     # Extract unique segment IDs with padding
     # TODO: add unique to haliax
-    unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1)
+    unique_segment_ids = jnp.unique(
+        segment_ids.array, size=max_Segments.size, fill_value=-1
+    )
     unique_segment_ids = hax.named(unique_segment_ids, max_Segments)
     return unique_segment_ids
 
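
The fixed-size `jnp.unique(..., size=..., fill_value=-1)` call is what makes the per-segment reductions jit-friendly: the output shape is static and the padded slots reuse the -1 sentinel. A standalone sketch of the same idea (plain JAX, not this module's API):

    import jax.numpy as jnp

    def per_segment_sum(values, segment_ids, max_segments: int):
        # Static-shape unique: padded entries are -1, matching the padding sentinel.
        uniq = jnp.unique(segment_ids, size=max_segments, fill_value=-1)
        # mask[k, t] is True where token t belongs to unique segment k (and k is real).
        mask = (segment_ids[None, :] == uniq[:, None]) & (uniq[:, None] != -1)
        return uniq, jnp.where(mask, values[None, :], 0.0).sum(axis=1)

    values = jnp.array([1.0, 2.0, 3.0, 4.0, 0.0])
    segs = jnp.array([0, 0, 1, 1, -1])  # -1 marks padding positions
    print(per_segment_sum(values, segs, max_segments=4))  # sums 3.0 and 7.0 for segments 0 and 1
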
@@ -219,7 +234,9 @@ def per_segment_correct(
     correct is a boolean array of the same shape as the losses array indicating whether the token was correct
     """
 
-    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"
+    assert (
+        packed_example.attn_mask.segment_ids is not None
+    ), "segment_ids must be set in the AttentionMask"
 
     segment_ids = packed_example.attn_mask.segment_ids
     assert (
@@ -250,6 +267,8 @@ def greedy_pack_prompt_completions(
     sequences: Iterable[PromptCompletion],
     pad_token: int,
     max_segments_per_example: int = 64,
+    pad_start: int = 0,
+    lengths: np.ndarray | None = None,
 ) -> list[LmExample]:
     """
     Greedy packing of prompt completions into LmExamples using [pack_documents][]
@@ -265,8 +284,12 @@ def make_loss_mask(id, prompt_length):
     ids = [sequence.ids for sequence in sequences]
 
     # Pack documents based on their lengths
+    pack_lengths = (
+        np.array([len(token_ids) for token_ids in ids]) if lengths is None else lengths
+    )
+    pack_lengths = pack_lengths + pad_start
     packs = pack_documents(
-        lengths=np.array([len(token_ids) for token_ids in ids]),
+        lengths=pack_lengths,
         max_length=Pos.size,
         max_segments_per_example=max_segments_per_example,
         slice_too_long_examples=True,
@@ -285,10 +308,21 @@ def make_loss_mask(id, prompt_length):
         concat_loss_mask = []
         segment_ids = []
 
-        for doc_id, seq, prompt_len in zip(docs_in_pack, pack_sequences, pack_prompt_lengths):
+        for doc_id, seq, prompt_len in zip(
+            docs_in_pack, pack_sequences, pack_prompt_lengths
+        ):
+            doc_length = len(seq.ids)
+            pad_end = pack_lengths[doc_id] - pad_start - doc_length
+
             concat_ids.extend(seq.ids)
+
+            concat_loss_mask.extend([0] * pad_start)
             concat_loss_mask.extend(make_loss_mask(seq.ids, prompt_len))
-            segment_ids.extend([doc_id] * len(seq.ids))
+            concat_loss_mask.extend([0] * pad_end)
+
+            segment_ids.extend([-1] * pad_start)
+            segment_ids.extend([doc_id] * pack_lengths[doc_id])
+            segment_ids.extend([-1] * pad_end)
 
         # Pad to max length
         pad_length = Pos.size - len(concat_ids)
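
Note that with this layout the three buffers can grow at different rates: `concat_ids` receives only the raw tokens, `concat_loss_mask` receives `pack_lengths[doc_id]` entries, and `segment_ids` receives `pack_lengths[doc_id]` plus the two pad runs, so the `pad_length` computed from `concat_ids` below may not line up with the other two. A self-contained sketch of one consistent per-document slot, under the assumption that each document should occupy exactly `pad_start + len(ids) + pad_end` positions, with padding excluded from the loss and marked with segment id -1 (the helper name is illustrative, not part of this commit):

    def build_document_slot(ids, loss_mask, doc_id, slot_len, pad_start, pad_token):
        # Lay out one document in a fixed-size slot: pad_start pads, the tokens, trailing pads.
        pad_end = slot_len - pad_start - len(ids)
        assert pad_end >= 0, "slot_len must cover pad_start + len(ids)"
        slot_ids = [pad_token] * pad_start + list(ids) + [pad_token] * pad_end
        slot_loss = [0] * pad_start + list(loss_mask) + [0] * pad_end
        slot_segs = [-1] * pad_start + [doc_id] * len(ids) + [-1] * pad_end
        return slot_ids, slot_loss, slot_segs

    print(build_document_slot([5, 6, 7], [0, 1, 1], doc_id=3, slot_len=6, pad_start=1, pad_token=0))
    # ([0, 5, 6, 7, 0, 0], [0, 0, 1, 1, 0, 0], [-1, 3, 3, 3, -1, -1])
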
@@ -300,7 +334,9 @@ def make_loss_mask(id, prompt_length):
         elif pad_length < 0:
             # too long, this should only happen if there's 1 document in the pack
             if len(pack_sequences) != 1:
-                raise ValueError("Too many tokens in a pack with more than one document")
+                raise ValueError(
+                    "Too many tokens in a pack with more than one document"
+                )
             concat_ids = concat_ids[-Pos.size :]
             concat_loss_mask = concat_loss_mask[-Pos.size :]
             segment_ids = segment_ids[-Pos.size :]
@@ -326,6 +362,7 @@ def _segment_ids_from_lengths(doc_ids: list[int], lengths: list[int]) -> list[in
 def pack_documents(
     lengths: PyTree[np.ndarray],
     max_length: PyTree[int],
+    pad_start: PyTree[int],
     max_segments_per_example: int | None = None,
     slice_too_long_examples: bool = False,
 ) -> list[range]:
@@ -347,17 +384,24 @@ def pack_documents(
     if max_segments_per_example is not None and (
         not isinstance(max_segments_per_example, int) or max_segments_per_example <= 0
     ):
-        raise ValueError(f"max_segments_per_example must be a positive integer, got {max_segments_per_example}")
+        raise ValueError(
+            f"max_segments_per_example must be a positive integer, got {max_segments_per_example}"
+        )
+
+    lengths_leaves = jax.tree.leaves(lengths)
+    leaf_names = jax.tree.leaves(leaf_key_paths(lengths))
 
     # Broadcast max_length to match the structure of lengths
     max_length_tree = tree_broadcast_to(max_length, lengths)
-
-    lengths_leaves = jax.tree.leaves(lengths)
     max_length_leaves = jax.tree.leaves(max_length_tree)
-    leaf_names = jax.tree.leaves(leaf_key_paths(lengths))
+
+    pad_start_tree = tree_broadcast_to(pad_start, lengths)
+    pad_start_leaves = jax.tree.leaves(pad_start_tree)
 
     if len(lengths_leaves) != len(max_length_leaves):
-        raise ValueError("Lengths and max_length PyTrees must have the same number of leaves.")
+        raise ValueError(
+            "Lengths and max_length PyTrees must have the same number of leaves."
+        )
 
     # Check that all leaves have the same number of documents.
     n_docs = None
@@ -370,12 +414,14 @@ def pack_documents(
     if n_docs is None:
         raise ValueError("Could not determine the number of documents from lengths.")
 
-    # Validate document lengths
+    # Validate document lengths (including pad_start)
     for lens, allowed, leaf_name in zip(lengths_leaves, max_length_leaves, leaf_names):
         for i in range(n_docs):
-            if lens[i] > allowed and not slice_too_long_examples:
+            effective_length = lens[i] + pad_start
+            if effective_length > allowed and not slice_too_long_examples:
                 raise ValueError(
-                    f"Document {i} in leaf '{leaf_name}' has length {lens[i]} which exceeds "
+                    f"Document {i} in leaf '{leaf_name}' has effective length {effective_length} "
+                    f"(document length {lens[i]} + pad_start {pad_start}) which exceeds "
                     f"maximum allowed length {allowed}. Consider setting slice_too_long_examples=True "
                     "or increasing max_length."
                 )
@@ -388,19 +434,27 @@ def pack_documents(
         # Accumulate documents while, for each leaf, the token span remains within the allowed max.
         while i < n_docs:
             # Check optional segment constraint: if adding one more document would exceed max_segments_per_example.
-            if max_segments_per_example is not None and (total_segments + 1) > max_segments_per_example:
+            if (
+                max_segments_per_example is not None
+                and (total_segments + 1) > max_segments_per_example
+            ):
                 break
             # For each leaf, check if adding document i would keep the token count within allowed capacity.
             valid = True
-            for lens, allowed, leaf_name in zip(lengths_leaves, max_length_leaves, leaf_names, strict=True):
-                # Compute token count from document start to document i+1.
-                token_sum = sum(lens[start : i + 1])
+            for lens, allowed, leaf_name in zip(
+                lengths_leaves, max_length_leaves, leaf_names, strict=True
+            ):
+                # Compute token count from document start to document i+1, including pad_start for each doc.
+                num_docs_in_pack = i - start + 1
+                token_sum = sum(lens[start : i + 1]) + num_docs_in_pack * pad_start
                 if token_sum > allowed:
                     valid = False
                     if not slice_too_long_examples and i == start:
                         # If this is the first document in a new pack and it's too long, raise an error
+                        effective_length = lens[i] + pad_start
                         raise ValueError(
-                            f"Document {i} in leaf '{leaf_name}' has length {lens[i]} which exceeds "
+                            f"Document {i} in leaf '{leaf_name}' has effective length {effective_length} "
+                            f"(document length {lens[i]} + pad_start {pad_start}) which exceeds "
                             f"maximum allowed length {allowed}. Consider setting slice_too_long_examples=True "
                             "or increasing max_length."
                         )
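
For intuition, a compact standalone version of the greedy decision this loop makes: keep appending documents while the running token budget (each document charged its length plus `pad_start`) and the segment budget hold, otherwise start a new pack. This sketch assumes a flat list of lengths for a single leaf, unlike the PyTree-aware code above:

    def greedy_pack_ranges(lengths, max_length, pad_start=0, max_segments=None):
        # Returns ranges of document indices, packing greedily in order.
        packs, start, n = [], 0, len(lengths)
        while start < n:
            i, total = start, 0
            while i < n:
                if max_segments is not None and (i - start + 1) > max_segments:
                    break
                if total + lengths[i] + pad_start > max_length and i > start:
                    break
                total += lengths[i] + pad_start
                i += 1
            # A document longer than max_length still gets its own pack
            # (cf. slice_too_long_examples in pack_documents).
            packs.append(range(start, max(i, start + 1)))
            start = max(i, start + 1)
        return packs

    print(greedy_pack_ranges([1000, 3000, 1500, 4096, 200], max_length=4096))
    # [range(0, 2), range(2, 3), range(3, 4), range(4, 5)]
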
@@ -445,6 +499,8 @@ def __init__(
         max_segments_per_example: int | None = None,
         pad_with_zeros: bool = True,
         slice_strategy: Literal["left", "right", "raise"] = "raise",
+        lengths: np.ndarray | None = None,
+        prefixes: PyTree[np.ndarray] | None = None,
     ):
         """
         Args:
@@ -457,29 +513,42 @@ def __init__(
         super().__init__()
 
         if slice_strategy not in ["left", "right", "raise"]:
-            raise ValueError(f"slice_strategy must be one of 'left', 'right', or 'raise', got {slice_strategy}")
+            raise ValueError(
+                f"slice_strategy must be one of 'left', 'right', or 'raise', got {slice_strategy}"
+            )
 
         self.dataset = dataset
         self.max_length = max_length
         self.max_segments_per_example = max_segments_per_example
         self.pad_with_zeros = pad_with_zeros
         self.slice_strategy = slice_strategy
 
-        _offsets = jax.tree.map(lambda store: store.offsets[0 : store.num_rows + 1].read(), self.dataset)
+        _offsets = jax.tree.map(
+            lambda store: store.offsets[0 : store.num_rows + 1].read(), self.dataset
+        )
         self._offsets = jax.tree.map(lambda fut: fut.result(), _offsets)
 
-        def diff_offsets(offsets: np.ndarray):
-            # fine to mutate since we have a copy
-            # the array store has the number of rows in the 0th offset
-            offsets[0] = 0
-            return offsets[1:] - offsets[:-1]
 
-        # Convert offsets to lengths
-        self._lengths = jax.tree.map(diff_offsets, self._offsets)
+        if lengths is not None:
+            self._lengths = lengths
+        else:
+            def diff_offsets(offsets: np.ndarray):
+                # fine to mutate since we have a copy
+                # the array store has the number of rows in the 0th offset
+                offsets[0] = 0
+                return offsets[1:] - offsets[:-1]
+
+            # Convert offsets to lengths
+            self._lengths = jax.tree.map(diff_offsets, self._offsets)
+
 
         # Build pack indices
         self._pack_indices: list[range] = pack_documents(
-            self._lengths, max_length, max_segments_per_example, slice_strategy != "raise"
+            self._lengths,
+            max_length,
+            pad_start,
+            max_segments_per_example,
+            slice_strategy != "raise",
         )
 
     def is_finite(self) -> bool:
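
Two small notes on this constructor: the offsets-to-lengths conversion relies on the JaggedArrayStore convention that slot 0 of the offsets array holds the row count rather than a real offset, which is why it is zeroed before differencing; and `pad_start` is forwarded to `pack_documents` here even though it is not among the parameters added in this hunk, so presumably it is derived elsewhere (perhaps from `prefixes`). A tiny illustration of the conversion:

    import numpy as np

    def diff_offsets(offsets: np.ndarray) -> np.ndarray:
        offsets = offsets.copy()
        offsets[0] = 0  # slot 0 holds the row count in the store, not a real offset
        return offsets[1:] - offsets[:-1]

    # 3 documents whose tokens end at flat positions 4, 9 and 15
    print(diff_offsets(np.array([3, 4, 9, 15])))  # [4 5 6]
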
@@ -494,7 +563,9 @@ async def final_length_is_known(self) -> bool:
     async def current_len(self) -> Optional[int]:
         return len(self._pack_indices)
 
-    async def get_batch(self, indices: Sequence[int]) -> Sequence[tuple[PyTree[np.ndarray], PyTree[np.ndarray]]]:
+    async def get_batch(
+        self, indices: Sequence[int]
+    ) -> Sequence[tuple[PyTree[np.ndarray], PyTree[np.ndarray]]]:
         """
         For each requested packed example (by index into self._pack_indices), reconstruct the
         token data on the fly from the underlying dataset. In our packing scheme the pack holds, for each leaf,
@@ -508,7 +579,9 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[tuple[PyTree[np.nd
 
         pack_doc_ranges = [self._pack_indices[i] for i in indices]
 
-        async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarray], list[np.ndarray]]:
+        async def get_data_for_leaf(
+            store, offsets, allowed: int
+        ) -> tuple[list[np.ndarray], list[np.ndarray]]:
             out_data = []
             out_segment_ids = []
             # Using ts.Batch to group reads.
@@ -520,7 +593,9 @@ async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarr
                     token_count = token_end - token_start
                     if token_count > allowed:
                         if self.slice_strategy != "raise":
-                            assert len(dr) == 1, "We shouldn't have packed two examples together if one is too long."
+                            assert (
+                                len(dr) == 1
+                            ), "We shouldn't have packed two examples together if one is too long."
                             if self.slice_strategy == "right":
                                 # slice from the right
                                 token_start = token_end - allowed
@@ -533,12 +608,20 @@ async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarr
                                 f"{list(dr)}. Consider using a different slice_strategy or increasing max_length."
                             )
                     # Read the slice from the underlying data.
+                    # TODO: need to pad at start with token_idx and at end up to length here
+                    # the size of each example will differ, but that way the output size of denoi
+                    # or... maybe we identify the starts of segments using seg_ids (and pad this tensor?)
+                    # then we roll & call on each of the rolled parts
+                    # so this doesn't need to start with token_idx
+                    # and we don't need pad_start
                     out_data.append(store.data[token_start:token_end].read())
 
                     # Create segment IDs for this pack
                     segment_ids = []
                     for doc_idx in range(len(dr)):
-                        doc_start = offsets[dr.start + doc_idx] if dr.start + doc_idx > 0 else 0
+                        doc_start = (
+                            offsets[dr.start + doc_idx] if dr.start + doc_idx > 0 else 0
+                        )
                         doc_end = offsets[dr.start + doc_idx + 1]
                         doc_length = doc_end - doc_start
                         # Use the global document index as the segment ID
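
A standalone sketch of how these per-pack segment ids come together from the raw offsets, mirroring the loop above; it assumes the store's offset layout in which index 0 holds the row count, hence the `else 0` guard for the very first document:

    import numpy as np

    def pack_segment_ids(offsets: np.ndarray, dr: range) -> np.ndarray:
        # Label every token in the packed doc range with its global document index.
        parts = []
        for doc_idx in range(len(dr)):
            doc_start = offsets[dr.start + doc_idx] if dr.start + doc_idx > 0 else 0
            doc_end = offsets[dr.start + doc_idx + 1]
            parts.append(np.full(doc_end - doc_start, dr.start + doc_idx))
        return np.concatenate(parts)

    # offsets for 3 documents of lengths 4, 5, 6 (slot 0 holds the row count, 3)
    print(pack_segment_ids(np.array([3, 4, 9, 15]), range(0, 2)))
    # [0 0 0 0 1 1 1 1 1]
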
@@ -555,7 +638,10 @@ async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarr
 
             if self.pad_with_zeros:
                 out_data = [np.pad(x, (0, allowed - x.shape[0])) for x in out_data]
-                out_segment_ids = [np.pad(x, (0, allowed - x.shape[0]), constant_values=-1) for x in out_segment_ids]
+                out_segment_ids = [
+                    np.pad(x, (0, allowed - x.shape[0]), constant_values=-1)
+                    for x in out_segment_ids
+                ]
 
             return out_data, out_segment_ids
 
@@ -568,7 +654,9 @@ async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarr
         # Use tree.map to combine the leaves from: dataset, max_length and, for each pack, its doc_range.
         # Note: jax.tree.map will map over each pack in parallel across the leaves.
         max_length_tree = tree_broadcast_to(self.max_length, self._offsets)
-        leaf_batch_futures = jax.tree.map(get_data_for_leaf, self.dataset, self._offsets, max_length_tree)
+        leaf_batch_futures = jax.tree.map(
+            get_data_for_leaf, self.dataset, self._offsets, max_length_tree
+        )
 
         # Flatten the resulting PyTree: each leaf is now an Awaitable returning a tuple of lists of np.ndarray, one per requested pack.
         leaves, treedef = jax.tree.flatten(leaf_batch_futures)
@@ -582,7 +670,9 @@ async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarr
         results = []
         for i in range(len(indices)):
             data = jax.tree.unflatten(treedef, [leaf[0][i] for leaf in resolved_leaves])
-            segment_ids = jax.tree.unflatten(treedef, [leaf[1][i] for leaf in resolved_leaves])
+            segment_ids = jax.tree.unflatten(
+                treedef, [leaf[1][i] for leaf in resolved_leaves]
+            )
             results.append((data, segment_ids))
         return results
 
@@ -598,7 +688,9 @@ async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarr
     store = JaggedArrayStore.open(path, mode="r", dtype=np.uint32, cache_metadata=True)
 
     time_in = time.time()
-    packed = GreedyPrepackedDataset(store, max_length=4096, pad_with_zeros=True, slice_strategy="right")
+    packed = GreedyPrepackedDataset(
+        store, max_length=4096, pad_with_zeros=True, slice_strategy="right"
+    )
     time_out = time.time()
     print(f"Took {time_out - time_in:.2f}s to build pack")
 