111 changes: 75 additions & 36 deletions src/levanter/store/cache.py
@@ -473,7 +473,11 @@ def _monitor_metrics(self):
class CacheLedger:
# NB: unlike the old cache, the mere existence of a ledger doesn't mean the cache is finished
total_num_rows: int
shard_rows: Dict[str, int]
"""Number of outputted rows in the cache"""
shard_rows_in: Dict[str, int]

Member: sorry, can we rename this one back to just shard_rows (leave comment) so that we don't invalidate all other caches

Contributor Author: Ah gotcha yeah... Do I need to do some other logic to make shard_rows_out an optional key during deserialization?

Member: i'm pretty sure if you give it a default value it will be fine. I'll check it against a cache before merging

"""Numbers of rows read in from each shard"""
shard_rows_out: Dict[str, int]
"""Number of rows written out of each shard"""
is_finished: bool = False
finished_shards: List[str] = dataclasses.field(default_factory=list)
field_counts: Dict[str, int] = dataclasses.field(default_factory=dict)
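
Following up on the thread above: one way to keep previously serialized ledgers loadable is to give the new counters default values, as the reviewer suggests. A minimal sketch of that idea (the defaults and the standalone class below are illustrative assumptions, not necessarily what was merged):

```python
import dataclasses
from typing import Dict, List


@dataclasses.dataclass
class CacheLedger:
    total_num_rows: int
    # Assumed defaults: empty dicts let ledgers serialized before this change
    # (which only carry the old `shard_rows` key) deserialize without errors.
    shard_rows_in: Dict[str, int] = dataclasses.field(default_factory=dict)
    shard_rows_out: Dict[str, int] = dataclasses.field(default_factory=dict)
    is_finished: bool = False
    finished_shards: List[str] = dataclasses.field(default_factory=list)
    field_counts: Dict[str, int] = dataclasses.field(default_factory=dict)
```

Whether a default alone is enough depends on how the ledger is actually deserialized; as noted above, the reviewer plans to check against an existing cache before merging.
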
@@ -487,7 +491,8 @@ def load_or_initialize(cache_dir: str, source: ShardedDataSource, processor: Bat
except FileNotFoundError:
return CacheLedger(
total_num_rows=0,
shard_rows={shard: 0 for shard in source.shard_names},
shard_rows_in={shard: 0 for shard in source.shard_names},
shard_rows_out={shard: 0 for shard in source.shard_names},
is_finished=False,
metadata=metadata,
)
@@ -569,7 +574,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
ledger = CacheLedger(
total_num_rows=len(self._tree_store),
is_finished=True,
shard_rows={"": len(self._tree_store)},
shard_rows_in={"": len(self._tree_store)}, # This is ok since we directly write batches
shard_rows_out={"": len(self._tree_store)},
finished_shards=[""],
field_counts={},
metadata=self.metadata or CacheMetadata.empty(),
@@ -740,9 +746,17 @@ def _notify_updated_ledger(self, ledger: CacheLedger):
if ledger.total_num_rows < self._ledger.total_num_rows:
raise RuntimeError(f"Ledger went backwards: {ledger.total_num_rows} < {self._ledger.total_num_rows}")

for shard, rows in ledger.shard_rows.items():
if rows < self._ledger.shard_rows.get(shard, 0):
raise RuntimeError(f"Shard {shard} went backwards: {rows} < {self._ledger.shard_rows.get(shard, 0)}")
for shard in ledger.shard_rows_in.keys():
rows_in = ledger.shard_rows_in[shard]
rows_out = ledger.shard_rows_out[shard]
if rows_in < self._ledger.shard_rows_in.get(shard, 0):
raise RuntimeError(
f"Shard {shard} went backwards (IN): {rows_in} < {self._ledger.shard_rows_in.get(shard, 0)}"
)
if rows_out < self._ledger.shard_rows_out.get(shard, 0):
raise RuntimeError(
f"Shard {shard} went backwards (OUT): {rows_in} < {self._ledger.shard_rows_out.get(shard, 0)}"
)

if was_finished:
raise RuntimeError("Ledger was already finished")
@@ -1007,10 +1021,14 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None)
def _merge_ledgers(dest: CacheLedger, source: CacheLedger):
assert not dest.is_finished
dest.total_num_rows += source.total_num_rows
for shard, rows in source.shard_rows.items():
current_value = dest.shard_rows.get(shard, 0)
assert current_value == 0, f"Shard {shard} already has {current_value} rows"
dest.shard_rows[shard] = rows
for shard in source.shard_rows_in.keys():
current_rows_in = dest.shard_rows_in.get(shard, 0)
assert current_rows_in == 0, f"Shard {shard} already has {current_rows_in} IN rows"
dest.shard_rows_in[shard] = source.shard_rows_in[shard]

current_rows_out = dest.shard_rows_out.get(shard, 0)
assert current_rows_out == 0, f"Shard {shard} already has {current_rows_out} OUT rows"
dest.shard_rows_out[shard] = source.shard_rows_out[shard]

dest.finished_shards.extend(source.finished_shards)
for field, count in source.field_counts.items():
@@ -1347,37 +1365,44 @@ def _tokenize_one_shard_group(

for shard_name in shards:
if shard_name in ledger.finished_shards:
report_fn(_ProgressReport(new_rows=ledger.shard_rows[shard_name], new_shards=1), ledger)
report_fn(_ProgressReport(new_rows=ledger.shard_rows_out[shard_name], new_shards=1), ledger)
logger.info(f"Shard {shard_name} already processed.")
continue

logger.debug(f"Processing {shard_name}.")

rows_this_shard = ledger.shard_rows.get(shard_name, 0)
rows_in_this_shard = ledger.shard_rows_in.get(shard_name, 0)
rows_out_this_shard = ledger.shard_rows_out.get(shard_name, 0)

if found_shard_with_rows and rows_this_shard != 0:
if found_shard_with_rows and rows_in_this_shard != 0:
raise ValueError(
"Found more than one partially processed shard in this group. This indicates that the"
"number of groups has changed, which is not supported."
)

if rows_this_shard != 0:
report_fn(_ProgressReport(new_rows=rows_this_shard), ledger)
if rows_in_this_shard != 0:
report_fn(_ProgressReport(new_rows=rows_out_this_shard), ledger)
found_shard_with_rows = True

shard_iterator = source.open_shard_at_row(shard_name, rows_this_shard)
        # Resume at rows_in_this_shard: it counts rows consumed from the input shard; the number of output rows may differ
shard_iterator = source.open_shard_at_row(shard_name, rows_in_this_shard)

prepared_batch: PyTree[PreparedBatch] | None = None
this_batch_size = 0
this_batch_size_in = 0
this_batch_size_out = 0

for batch in batched(shard_iterator, options.batch_size):
tokenized = processor(batch)
tokenized = _canonicalize_batch(tokenized) # type: ignore
this_prepared = writer._tree_store.batch_preparer(tokenized)

this_batch_size += len(batch)
rows_this_shard += len(batch)
total_rows += len(batch)
rows_in_this_shard += len(batch)
this_batch_size_in += len(batch)

rows_out_this_shard += len(tokenized)
this_batch_size_out += len(tokenized)

total_rows += len(tokenized)

if prepared_batch is None:
prepared_batch = this_prepared
@@ -1389,31 +1414,34 @@ def _tokenize_one_shard_group(
batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch))

if batch_byte_size > options.target_bytes_per_flush:
writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)
writer.write_prepared_batch(shard_name, this_batch_size_in, this_batch_size_out, prepared_batch)
report_fn(_ProgressReport(new_rows=this_batch_size_out, new_bytes=batch_byte_size), writer.ledger)

nice_bytes = humanfriendly.format_size(batch_byte_size)
logger.debug(
f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})"
f"Processed {rows_in_this_shard} rows. Wrote {this_batch_size_out} rows to {shard_name}."
f" ({nice_bytes})"
)
# print(f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})", flush=True)
this_batch_size = 0
                        this_batch_size_in = 0  # reset both per-flush counters so the next flush doesn't double-count
                        this_batch_size_out = 0
prepared_batch = None

if prepared_batch is not None:
batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch))
nice_bytes = humanfriendly.format_size(batch_byte_size)

report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)
report_fn(_ProgressReport(new_rows=this_batch_size_out, new_bytes=batch_byte_size), writer.ledger)

writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
writer.write_prepared_batch(shard_name, this_batch_size_in, this_batch_size_out, prepared_batch)
logger.debug(
f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})"
f"Processed {rows_in_this_shard} rows. Wrote {this_batch_size_out} rows to {shard_name}."
f" ({nice_bytes})"
)
this_batch_size = 0
this_batch_size_in = 0
this_batch_size_out = 0
prepared_batch = None

writer.finish_shard(shard_name, rows_this_shard)
writer.finish_shard(shard_name, rows_in_this_shard, rows_out_this_shard)

report_fn(_ProgressReport(new_shards=1), writer.ledger)
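
The split between `rows_in_this_shard` and `rows_out_this_shard` above exists because `len(batch)` counts rows consumed from the source while `len(tokenized)` counts rows produced by the processor, and these can differ. A toy, self-contained sketch of that mismatch (the packing processor is invented purely for illustration):

```python
from typing import List


def toy_processor(batch: List[str]) -> List[List[int]]:
    # Invented example processor: packs every two input documents into a single
    # output row, so the number of rows out differs from the number of rows in.
    packed = []
    for i in range(0, len(batch), 2):
        packed.append([len(doc) for doc in batch[i : i + 2]])
    return packed


batch = ["a", "bb", "ccc", "dddd", "eeeee"]
rows_in = len(batch)                   # 5 rows read from the shard
rows_out = len(toy_processor(batch))   # 3 rows written to the cache
print(rows_in, rows_out)
```
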

@@ -1451,26 +1479,37 @@ def get_ledger(self):
def is_finished(self):
return self._ledger.is_finished

def finish_shard(self, shard_name: str, num_rows: int):
def finish_shard(self, shard_name: str, num_rows_in: int, num_rows_out: int):
if shard_name not in self.shards:
raise ValueError(f"Shard {shard_name} not in tracked shards")

current_rows = self._ledger.shard_rows.get(shard_name, 0)
if current_rows != num_rows:
raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}")
current_rows_in = self._ledger.shard_rows_in.get(shard_name, 0)
if current_rows_in != num_rows_in:
raise ValueError(
f"Expected {num_rows_in} rows IN to finished shard {shard_name}, but found {current_rows_in}"
)

current_rows_out = self._ledger.shard_rows_out.get(shard_name, 0)
if current_rows_out != num_rows_out:
raise ValueError(
f"Expected {num_rows_out} rows OUT to finished shard {shard_name}, but found {current_rows_out}"
)

self._ledger.finished_shards.append(shard_name)
self._ledger._serialize_and_commit(self.cache_dir)

def write_prepared_batch(self, shard_name: str, row_count: int, batch: PyTree[PreparedBatch]):
def write_prepared_batch(
self, shard_name: str, row_count_in: int, row_count_out: int, batch: PyTree[PreparedBatch]
):
if self.is_finished:
raise RuntimeError("Cannot write to a finished cache")
self._tree_store.extend_with_batch(batch)

if shard_name not in self.shards:
raise ValueError(f"Shard {shard_name} not in tracked shards")
self._ledger.shard_rows[shard_name] += row_count
self._ledger.total_num_rows += row_count
self._ledger.shard_rows_in[shard_name] += row_count_in
self._ledger.shard_rows_out[shard_name] += row_count_out
self._ledger.total_num_rows += row_count_out

self._ledger._serialize_and_commit(self.cache_dir)
