
Commit 41cebc9

WIP: does not compile
1 parent e7c6b10

File tree

4 files changed (+645, -94 lines)


src/levanter/data/packing.py

Lines changed: 133 additions & 41 deletions
@@ -7,6 +7,7 @@
 
 This achieves about a 90% "real token" rate, compared to like 10% without packing.
 """
+
 import asyncio
 from dataclasses import dataclass
 from typing import Iterable, Iterator, Literal, Optional, Sequence, TypeVar
@@ -62,9 +63,17 @@ def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int):
         assert pad_token is not None, "pad_token must be set"
 
     def can_pack(self, ids: list[int]) -> bool:
-        return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size
+        return (
+            len(ids) + len(self._ids) <= self.Pos.size
+            and self.num_segments < self.max_pack_size
+        )
 
-    def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None):
+    def add_example(
+        self,
+        ids: list[int],
+        loss_mask: list[int] | np.ndarray,
+        segment_id: int | None = None,
+    ):
         if len(ids) != len(loss_mask):
             raise ValueError("ids and loss_mask must have the same length")
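For orientation, here is a minimal usage sketch of the packer class whose methods are reformatted above. It assumes the class is levanter's SequencePacker with the constructor shown in the hunk header; the toy documents and axis size are made up.

import haliax as hax
from levanter.data.packing import SequencePacker  # assumed import path for this file

Pos = hax.Axis("position", 16)
documents = [([5, 6, 7], [0, 1, 1]), ([8, 9], [1, 1])]  # (ids, loss_mask) pairs, illustrative

packer = SequencePacker(Pos, max_pack_size=4, pad_token=0)
packed = []
for ids, loss_mask in documents:
    if not packer.can_pack(ids):          # would the next example overflow Pos or max_pack_size?
        packed.append(packer.pack())      # flush the current pack as an LmExample
        packer = SequencePacker(Pos, max_pack_size=4, pad_token=0)
    packer.add_example(ids, loss_mask)
packed.append(packer.pack())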

@@ -90,7 +99,9 @@ def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment
     def pack(self) -> LmExample:
         ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids))
 
-        segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids))
+        segment_ids = self._segment_ids + [-1] * (
+            self.Pos.size - len(self._segment_ids)
+        )
 
         loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask))
 
@@ -173,7 +184,9 @@ def per_segment_loss(
     This code is designed to run in a jit-compiled function, meaning we have to careful of shapes
     """
 
-    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"
+    assert (
+        packed_example.attn_mask.segment_ids is not None
+    ), "segment_ids must be set in the AttentionMask"
 
     segment_ids = packed_example.attn_mask.segment_ids
     assert (
@@ -200,7 +213,9 @@ def per_segment_loss(
 def _unique_segment_ids(max_Segments, segment_ids):
     # Extract unique segment IDs with padding
     # TODO: add unique to haliax
-    unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1)
+    unique_segment_ids = jnp.unique(
+        segment_ids.array, size=max_Segments.size, fill_value=-1
+    )
     unique_segment_ids = hax.named(unique_segment_ids, max_Segments)
     return unique_segment_ids
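As a side note on the jnp.unique call above: passing size and fill_value gives a fixed-shape result, which is what makes it usable inside a jit-compiled function. A small illustrative example (input values are made up):

import jax.numpy as jnp

seg = jnp.array([2, 2, 5, -1, -1])
jnp.unique(seg, size=4, fill_value=-1)  # Array([-1, 2, 5, -1]): sorted uniques, padded to length 4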

@@ -219,7 +234,9 @@ def per_segment_correct(
     correct is a boolean array of the same shape as the losses array indicating whether the token was correct
     """
 
-    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"
+    assert (
+        packed_example.attn_mask.segment_ids is not None
+    ), "segment_ids must be set in the AttentionMask"
 
     segment_ids = packed_example.attn_mask.segment_ids
     assert (
@@ -250,6 +267,8 @@ def greedy_pack_prompt_completions(
     sequences: Iterable[PromptCompletion],
     pad_token: int,
     max_segments_per_example: int = 64,
+    pad_start: int = 0,
+    lengths: np.ndarray | None = None,
 ) -> list[LmExample]:
     """
     Greedy packing of prompt completions into LmExamples using [pack_documents][]
@@ -265,8 +284,12 @@ def make_loss_mask(id, prompt_length):
     ids = [sequence.ids for sequence in sequences]
 
     # Pack documents based on their lengths
+    pack_lengths = (
+        np.array([len(token_ids) for token_ids in ids]) if lengths is None else lengths
+    )
+    pack_lengths = pack_lengths + pad_start
     packs = pack_documents(
-        lengths=np.array([len(token_ids) for token_ids in ids]),
+        lengths=pack_lengths,
         max_length=Pos.size,
         max_segments_per_example=max_segments_per_example,
         slice_too_long_examples=True,
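A runnable sketch of the pack_lengths computation introduced above, with made-up document lengths; each document's packing budget grows by pad_start before the lengths reach pack_documents:

import numpy as np

pad_start = 2
doc_lengths = np.array([3, 5, 4])        # stand-in for [len(token_ids) for token_ids in ids]
pack_lengths = doc_lengths + pad_start   # array([5, 7, 6]); the lengths actually packed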
@@ -285,10 +308,21 @@ def make_loss_mask(id, prompt_length):
         concat_loss_mask = []
         segment_ids = []
 
-        for doc_id, seq, prompt_len in zip(docs_in_pack, pack_sequences, pack_prompt_lengths):
+        for doc_id, seq, prompt_len in zip(
+            docs_in_pack, pack_sequences, pack_prompt_lengths
+        ):
+            doc_length = len(seq.ids)
+            pad_end = pack_lengths[doc_id] - pad_start - doc_length
+
             concat_ids.extend(seq.ids)
+
+            concat_loss_mask.extend([0] * pad_start)
             concat_loss_mask.extend(make_loss_mask(seq.ids, prompt_len))
-            segment_ids.extend([doc_id] * len(seq.ids))
+            concat_loss_mask.extend([0] * pad_end)
+
+            segment_ids.extend([-1] * pad_start)
+            segment_ids.extend([doc_id] * pack_lengths[doc_id])
+            segment_ids.extend([-1] * pad_end)
 
         # Pad to max length
         pad_length = Pos.size - len(concat_ids)
@@ -300,7 +334,9 @@ def make_loss_mask(id, prompt_length):
         elif pad_length < 0:
             # too long, this should only happen if there's 1 document in the pack
             if len(pack_sequences) != 1:
-                raise ValueError("Too many tokens in a pack with more than one document")
+                raise ValueError(
+                    "Too many tokens in a pack with more than one document"
+                )
             concat_ids = concat_ids[-Pos.size :]
             concat_loss_mask = concat_loss_mask[-Pos.size :]
             segment_ids = segment_ids[-Pos.size :]
@@ -326,6 +362,7 @@ def _segment_ids_from_lengths(doc_ids: list[int], lengths: list[int]) -> list[in
 def pack_documents(
     lengths: PyTree[np.ndarray],
     max_length: PyTree[int],
+    pad_start: PyTree[int],
     max_segments_per_example: int | None = None,
     slice_too_long_examples: bool = False,
 ) -> list[range]:
@@ -347,17 +384,24 @@
     if max_segments_per_example is not None and (
         not isinstance(max_segments_per_example, int) or max_segments_per_example <= 0
     ):
-        raise ValueError(f"max_segments_per_example must be a positive integer, got {max_segments_per_example}")
+        raise ValueError(
+            f"max_segments_per_example must be a positive integer, got {max_segments_per_example}"
+        )
+
+    lengths_leaves = jax.tree.leaves(lengths)
+    leaf_names = jax.tree.leaves(leaf_key_paths(lengths))
 
     # Broadcast max_length to match the structure of lengths
     max_length_tree = tree_broadcast_to(max_length, lengths)
-
-    lengths_leaves = jax.tree.leaves(lengths)
     max_length_leaves = jax.tree.leaves(max_length_tree)
-    leaf_names = jax.tree.leaves(leaf_key_paths(lengths))
+
+    pad_start_tree = tree_broadcast_to(pad_start, lengths)
+    pad_start_leaves = jax.tree.leaves(pad_start_tree)
 
     if len(lengths_leaves) != len(max_length_leaves):
-        raise ValueError("Lengths and max_length PyTrees must have the same number of leaves.")
+        raise ValueError(
+            "Lengths and max_length PyTrees must have the same number of leaves."
+        )
 
     # Check that all leaves have the same number of documents.
     n_docs = None
@@ -370,12 +414,14 @@
     if n_docs is None:
         raise ValueError("Could not determine the number of documents from lengths.")
 
-    # Validate document lengths
+    # Validate document lengths (including pad_start)
     for lens, allowed, leaf_name in zip(lengths_leaves, max_length_leaves, leaf_names):
         for i in range(n_docs):
-            if lens[i] > allowed and not slice_too_long_examples:
+            effective_length = lens[i] + pad_start
+            if effective_length > allowed and not slice_too_long_examples:
                 raise ValueError(
-                    f"Document {i} in leaf '{leaf_name}' has length {lens[i]} which exceeds "
+                    f"Document {i} in leaf '{leaf_name}' has effective length {effective_length} "
+                    f"(document length {lens[i]} + pad_start {pad_start}) which exceeds "
                     f"maximum allowed length {allowed}. Consider setting slice_too_long_examples=True "
                     "or increasing max_length."
                 )
@@ -388,19 +434,27 @@
     # Accumulate documents while for each leaf the token span remains within the allowed max.
     while i < n_docs:
         # Check optional segment constraint: if adding one more document would exceed max_segments_per_example.
-        if max_segments_per_example is not None and (total_segments + 1) > max_segments_per_example:
+        if (
+            max_segments_per_example is not None
+            and (total_segments + 1) > max_segments_per_example
+        ):
             break
         # For each leaf, check if adding document i would keep the token count within allowed capacity.
         valid = True
-        for lens, allowed, leaf_name in zip(lengths_leaves, max_length_leaves, leaf_names, strict=True):
-            # Compute token count from document start to document i+1.
-            token_sum = sum(lens[start : i + 1])
+        for lens, allowed, leaf_name in zip(
+            lengths_leaves, max_length_leaves, leaf_names, strict=True
+        ):
+            # Compute token count from document start to document i+1, including pad_start for each doc.
+            num_docs_in_pack = i - start + 1
+            token_sum = sum(lens[start : i + 1]) + num_docs_in_pack * pad_start
             if token_sum > allowed:
                 valid = False
                 if not slice_too_long_examples and i == start:
                     # If this is the first document in a new pack and it's too long, raise an error
+                    effective_length = lens[i] + pad_start
                     raise ValueError(
-                        f"Document {i} in leaf '{leaf_name}' has length {lens[i]} which exceeds "
+                        f"Document {i} in leaf '{leaf_name}' has effective length {effective_length} "
+                        f"(document length {lens[i]} + pad_start {pad_start}) which exceeds "
                         f"maximum allowed length {allowed}. Consider setting slice_too_long_examples=True "
                         "or increasing max_length."
                     )
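To make the new token_sum accounting concrete, here is a worked sketch with made-up numbers, assuming a single flat array of lengths:

import numpy as np

lengths = np.array([5, 3, 4])
max_length, pad_start = 10, 1
# token_sum for docs start..i is sum(lengths[start : i + 1]) + (i - start + 1) * pad_start:
#   docs 0..0: 5     + 1 * 1 = 6   <= 10, keep extending the pack
#   docs 0..1: 5+3   + 2 * 1 = 10  <= 10, keep extending
#   docs 0..2: 5+3+4 + 3 * 1 = 15  >  10, close the pack at range(0, 2)
#   docs 2..2: 4     + 1 * 1 = 5   <= 10, final pack is range(2, 3)
# so pack_documents(lengths, max_length, pad_start) would be expected to return
# [range(0, 2), range(2, 3)] under these assumptions.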
@@ -445,6 +499,8 @@ def __init__(
         max_segments_per_example: int | None = None,
         pad_with_zeros: bool = True,
         slice_strategy: Literal["left", "right", "raise"] = "raise",
+        lengths: np.ndarray | None = None,
+        prefixes: PyTree[np.ndarray] | None = None,
     ):
         """
         Args:
@@ -457,29 +513,42 @@
         super().__init__()
 
         if slice_strategy not in ["left", "right", "raise"]:
-            raise ValueError(f"slice_strategy must be one of 'left', 'right', or 'raise', got {slice_strategy}")
+            raise ValueError(
+                f"slice_strategy must be one of 'left', 'right', or 'raise', got {slice_strategy}"
+            )
 
         self.dataset = dataset
         self.max_length = max_length
         self.max_segments_per_example = max_segments_per_example
         self.pad_with_zeros = pad_with_zeros
         self.slice_strategy = slice_strategy
 
-        _offsets = jax.tree.map(lambda store: store.offsets[0 : store.num_rows + 1].read(), self.dataset)
+        _offsets = jax.tree.map(
+            lambda store: store.offsets[0 : store.num_rows + 1].read(), self.dataset
+        )
         self._offsets = jax.tree.map(lambda fut: fut.result(), _offsets)
 
-        def diff_offsets(offsets: np.ndarray):
-            # fine to mutate since we have a copy
-            # the array store has the number of rows in the 0th offset
-            offsets[0] = 0
-            return offsets[1:] - offsets[:-1]
 
-        # Convert offsets to lengths
-        self._lengths = jax.tree.map(diff_offsets, self._offsets)
+        if lengths is not None:
+            self._lengths = lengths
+        else:
+            def diff_offsets(offsets: np.ndarray):
+                # fine to mutate since we have a copy
+                # the array store has the number of rows in the 0th offset
+                offsets[0] = 0
+                return offsets[1:] - offsets[:-1]
+
+            # Convert offsets to lengths
+            self._lengths = jax.tree.map(diff_offsets, self._offsets)
+
 
         # Build pack indices
         self._pack_indices: list[range] = pack_documents(
-            self._lengths, max_length, max_segments_per_example, slice_strategy != "raise"
+            self._lengths,
+            max_length,
+            pad_start,
+            max_segments_per_example,
+            slice_strategy != "raise",
        )
 
     def is_finite(self) -> bool:
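For reference, a small runnable sketch of what the diff_offsets fallback computes. The offsets values are illustrative; per the comment above, the 0th entry of the store's offsets holds the row count:

import numpy as np

offsets = np.array([3, 4, 9, 12])     # 0th entry = num_rows, the rest are cumulative token ends
offsets[0] = 0
lengths = offsets[1:] - offsets[:-1]  # array([4, 5, 3]): per-document token counts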
@@ -494,7 +563,9 @@ async def final_length_is_known(self) -> bool:
     async def current_len(self) -> Optional[int]:
         return len(self._pack_indices)
 
-    async def get_batch(self, indices: Sequence[int]) -> Sequence[tuple[PyTree[np.ndarray], PyTree[np.ndarray]]]:
+    async def get_batch(
+        self, indices: Sequence[int]
+    ) -> Sequence[tuple[PyTree[np.ndarray], PyTree[np.ndarray]]]:
         """
         For each requested packed example (by index into self._pack_indices), reconstruct the
         token data on the fly from the underlying dataset. In our packing scheme the pack holds, for each leaf,
@@ -508,7 +579,9 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[tuple[PyTree[np.nd
 
         pack_doc_ranges = [self._pack_indices[i] for i in indices]
 
-        async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarray], list[np.ndarray]]:
+        async def get_data_for_leaf(
+            store, offsets, allowed: int
+        ) -> tuple[list[np.ndarray], list[np.ndarray]]:
             out_data = []
             out_segment_ids = []
             # Using ts.Batch to group reads.
@@ -520,7 +593,9 @@ async def get_data_for_leaf(store, offsets, allowed: int) -> tuple[list[np.ndarr
                 token_count = token_end - token_start
                 if token_count > allowed:
                     if self.slice_strategy != "raise":
-                        assert len(dr) == 1, "We shouldn't have packed two examples together if one is too long."
+                        assert (
+                            len(dr) == 1
+                        ), "We shouldn't have packed two examples together if one is too long."
                         if self.slice_strategy == "right":
                             # slice from the right
                             token_start = token_end - allowed
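The "right" branch above simply moves token_start so only the last allowed tokens survive; an illustrative calculation with made-up numbers:

token_start, token_end, allowed = 0, 6000, 4096
if token_end - token_start > allowed:
    token_start = token_end - allowed   # 1904, so the read covers tokens 1904..5999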
@@ -533,12 +608,20 @@
                             f"{list(dr)}. Consider using a different slice_strategy or increasing max_length."
                         )
                 # Read the slice from the underlying data.
+                # TODO need to pad at start with token_idx and and at end up to length here
+                # the size of each example will differ, but the that way the output size of denoi
+                # or... maybe we identify the starts of segments using seg_ids (and pad this tensor?)
+                # then we roll & call on each of the rolled parts
+                # so this doesn't need to start with token_idx
+                # and we don't need pad_start
                 out_data.append(store.data[token_start:token_end].read())
 
                 # Create segment IDs for this pack
                 segment_ids = []
                 for doc_idx in range(len(dr)):
-                    doc_start = offsets[dr.start + doc_idx] if dr.start + doc_idx > 0 else 0
+                    doc_start = (
+                        offsets[dr.start + doc_idx] if dr.start + doc_idx > 0 else 0
+                    )
                     doc_end = offsets[dr.start + doc_idx + 1]
                     doc_length = doc_end - doc_start
                     # Use the global document index as the segment ID
@@ -555,7 +638,10 @@
 
             if self.pad_with_zeros:
                 out_data = [np.pad(x, (0, allowed - x.shape[0])) for x in out_data]
-                out_segment_ids = [np.pad(x, (0, allowed - x.shape[0]), constant_values=-1) for x in out_segment_ids]
+                out_segment_ids = [
+                    np.pad(x, (0, allowed - x.shape[0]), constant_values=-1)
+                    for x in out_segment_ids
+                ]
 
             return out_data, out_segment_ids
 
@@ -568,7 +654,9 @@
         # Use tree.map to combine the leaves from: dataset, max_length and, for each pack, its doc_range.
         # Note: jax.tree.map will map over each pack in parallel across the leaves.
         max_length_tree = tree_broadcast_to(self.max_length, self._offsets)
-        leaf_batch_futures = jax.tree.map(get_data_for_leaf, self.dataset, self._offsets, max_length_tree)
+        leaf_batch_futures = jax.tree.map(
+            get_data_for_leaf, self.dataset, self._offsets, max_length_tree
+        )
 
         # Flatten the resulting PyTree: each leaf is now an Awaitable returning a tuple of lists of np.ndarray—one per requested pack.
         leaves, treedef = jax.tree.flatten(leaf_batch_futures)
@@ -582,7 +670,9 @@
         results = []
         for i in range(len(indices)):
             data = jax.tree.unflatten(treedef, [leaf[0][i] for leaf in resolved_leaves])
-            segment_ids = jax.tree.unflatten(treedef, [leaf[1][i] for leaf in resolved_leaves])
+            segment_ids = jax.tree.unflatten(
+                treedef, [leaf[1][i] for leaf in resolved_leaves]
+            )
             results.append((data, segment_ids))
         return results
 
@@ -598,7 +688,9 @@
     store = JaggedArrayStore.open(path, mode="r", dtype=np.uint32, cache_metadata=True)
 
     time_in = time.time()
-    packed = GreedyPrepackedDataset(store, max_length=4096, pad_with_zeros=True, slice_strategy="right")
+    packed = GreedyPrepackedDataset(
+        store, max_length=4096, pad_with_zeros=True, slice_strategy="right"
+    )
     time_out = time.time()
     print(f"Took {time_out - time_in:.2f}s to build pack")
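A hypothetical continuation of the timing snippet above: pulling the first packed example back out. get_batch is an async method that returns one (data, segment_ids) pair of PyTrees per requested index; whether asyncio.run is the right way to drive it depends on the surrounding event loop, so treat this purely as a sketch.

import asyncio

first_pack = asyncio.run(packed.get_batch([0]))[0]
data, segment_ids = first_pack  # for a single JaggedArrayStore leaf these are plain np.ndarrays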
