diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py
index a7c88baef..348e78812 100644
--- a/src/levanter/store/cache.py
+++ b/src/levanter/store/cache.py
@@ -473,7 +473,11 @@ def _monitor_metrics(self):
 class CacheLedger:
     # NB: unlike the old cache, the mere existence of a ledger doesn't mean the cache is finished
     total_num_rows: int
-    shard_rows: Dict[str, int]
+    """Total number of output rows in the cache"""
+    shard_rows_in: Dict[str, int]
+    """Number of rows read in from each shard"""
+    shard_rows_out: Dict[str, int]
+    """Number of rows written out for each shard"""
     is_finished: bool = False
     finished_shards: List[str] = dataclasses.field(default_factory=list)
     field_counts: Dict[str, int] = dataclasses.field(default_factory=dict)
@@ -487,7 +491,8 @@ def load_or_initialize(cache_dir: str, source: ShardedDataSource, processor: Bat
         except FileNotFoundError:
             return CacheLedger(
                 total_num_rows=0,
-                shard_rows={shard: 0 for shard in source.shard_names},
+                shard_rows_in={shard: 0 for shard in source.shard_names},
+                shard_rows_out={shard: 0 for shard in source.shard_names},
                 is_finished=False,
                 metadata=metadata,
             )
@@ -569,7 +574,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         ledger = CacheLedger(
             total_num_rows=len(self._tree_store),
             is_finished=True,
-            shard_rows={"": len(self._tree_store)},
+            shard_rows_in={"": len(self._tree_store)},  # This is ok since we directly write batches
+            shard_rows_out={"": len(self._tree_store)},
             finished_shards=[""],
             field_counts={},
             metadata=self.metadata or CacheMetadata.empty(),
@@ -740,9 +746,17 @@ def _notify_updated_ledger(self, ledger: CacheLedger):
         if ledger.total_num_rows < self._ledger.total_num_rows:
             raise RuntimeError(f"Ledger went backwards: {ledger.total_num_rows} < {self._ledger.total_num_rows}")
 
-        for shard, rows in ledger.shard_rows.items():
-            if rows < self._ledger.shard_rows.get(shard, 0):
-                raise RuntimeError(f"Shard {shard} went backwards: {rows} < {self._ledger.shard_rows.get(shard, 0)}")
+        for shard in ledger.shard_rows_in.keys():
+            rows_in = ledger.shard_rows_in[shard]
+            rows_out = ledger.shard_rows_out[shard]
+            if rows_in < self._ledger.shard_rows_in.get(shard, 0):
+                raise RuntimeError(
+                    f"Shard {shard} went backwards (IN): {rows_in} < {self._ledger.shard_rows_in.get(shard, 0)}"
+                )
+            if rows_out < self._ledger.shard_rows_out.get(shard, 0):
+                raise RuntimeError(
+                    f"Shard {shard} went backwards (OUT): {rows_out} < {self._ledger.shard_rows_out.get(shard, 0)}"
+                )
 
         if was_finished:
             raise RuntimeError("Ledger was already finished")
@@ -1007,10 +1021,14 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None)
 def _merge_ledgers(dest: CacheLedger, source: CacheLedger):
     assert not dest.is_finished
     dest.total_num_rows += source.total_num_rows
-    for shard, rows in source.shard_rows.items():
-        current_value = dest.shard_rows.get(shard, 0)
-        assert current_value == 0, f"Shard {shard} already has {current_value} rows"
-        dest.shard_rows[shard] = rows
+    for shard in source.shard_rows_in.keys():
+        current_rows_in = dest.shard_rows_in.get(shard, 0)
+        assert current_rows_in == 0, f"Shard {shard} already has {current_rows_in} IN rows"
+        dest.shard_rows_in[shard] = source.shard_rows_in[shard]
+
+        current_rows_out = dest.shard_rows_out.get(shard, 0)
+        assert current_rows_out == 0, f"Shard {shard} already has {current_rows_out} OUT rows"
+        dest.shard_rows_out[shard] = source.shard_rows_out[shard]
 
     dest.finished_shards.extend(source.finished_shards)
     for field, count in source.field_counts.items():
@@ -1347,37 +1365,44 @@ def _tokenize_one_shard_group(
 
     for shard_name in shards:
         if shard_name in ledger.finished_shards:
-            report_fn(_ProgressReport(new_rows=ledger.shard_rows[shard_name], new_shards=1), ledger)
+            report_fn(_ProgressReport(new_rows=ledger.shard_rows_out[shard_name], new_shards=1), ledger)
             logger.info(f"Shard {shard_name} already processed.")
             continue
 
         logger.debug(f"Processing {shard_name}.")
 
-        rows_this_shard = ledger.shard_rows.get(shard_name, 0)
+        rows_in_this_shard = ledger.shard_rows_in.get(shard_name, 0)
+        rows_out_this_shard = ledger.shard_rows_out.get(shard_name, 0)
 
-        if found_shard_with_rows and rows_this_shard != 0:
+        if found_shard_with_rows and rows_in_this_shard != 0:
             raise ValueError(
                 "Found more than one partially processed shard in this group. This indicates that the"
                 "number of groups has changed, which is not supported."
             )
 
-        if rows_this_shard != 0:
-            report_fn(_ProgressReport(new_rows=rows_this_shard), ledger)
+        if rows_in_this_shard != 0:
+            report_fn(_ProgressReport(new_rows=rows_out_this_shard), ledger)
             found_shard_with_rows = True
 
-        shard_iterator = source.open_shard_at_row(shard_name, rows_this_shard)
+        # We open at rows_in_this_shard since that's how many rows we've already read from the input
+        shard_iterator = source.open_shard_at_row(shard_name, rows_in_this_shard)
 
         prepared_batch: PyTree[PreparedBatch] | None = None
-        this_batch_size = 0
+        this_batch_size_in = 0
+        this_batch_size_out = 0
 
         for batch in batched(shard_iterator, options.batch_size):
            tokenized = processor(batch)
            tokenized = _canonicalize_batch(tokenized)  # type: ignore
            this_prepared = writer._tree_store.batch_preparer(tokenized)
 
-            this_batch_size += len(batch)
-            rows_this_shard += len(batch)
-            total_rows += len(batch)
+            rows_in_this_shard += len(batch)
+            this_batch_size_in += len(batch)
+
+            rows_out_this_shard += len(tokenized)
+            this_batch_size_out += len(tokenized)
+
+            total_rows += len(tokenized)
 
             if prepared_batch is None:
                 prepared_batch = this_prepared
@@ -1389,31 +1414,35 @@ def _tokenize_one_shard_group(
 
             batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch))
 
             if batch_byte_size > options.target_bytes_per_flush:
-                writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
-                report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)
+                writer.write_prepared_batch(shard_name, this_batch_size_in, this_batch_size_out, prepared_batch)
+                report_fn(_ProgressReport(new_rows=this_batch_size_out, new_bytes=batch_byte_size), writer.ledger)
 
                 nice_bytes = humanfriendly.format_size(batch_byte_size)
                 logger.debug(
-                    f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})"
+                    f"Processed {rows_in_this_shard} rows. Wrote {this_batch_size_out} rows to {shard_name}."
+                    f" ({nice_bytes})"
                 )
                 # print(f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})", flush=True)
-                this_batch_size = 0
+                this_batch_size_in = 0
+                this_batch_size_out = 0
                 prepared_batch = None
 
         if prepared_batch is not None:
             batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch))
             nice_bytes = humanfriendly.format_size(batch_byte_size)
 
-            report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger)
+            report_fn(_ProgressReport(new_rows=this_batch_size_out, new_bytes=batch_byte_size), writer.ledger)
 
-            writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch)
+            writer.write_prepared_batch(shard_name, this_batch_size_in, this_batch_size_out, prepared_batch)
             logger.debug(
-                f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})"
+                f"Processed {rows_in_this_shard} rows. Wrote {this_batch_size_out} rows to {shard_name}."
+                f" ({nice_bytes})"
             )
-            this_batch_size = 0
+            this_batch_size_in = 0
+            this_batch_size_out = 0
             prepared_batch = None
 
-        writer.finish_shard(shard_name, rows_this_shard)
+        writer.finish_shard(shard_name, rows_in_this_shard, rows_out_this_shard)
 
         report_fn(_ProgressReport(new_shards=1), writer.ledger)
@@ -1451,26 +1480,37 @@ def get_ledger(self):
     def is_finished(self):
         return self._ledger.is_finished
 
-    def finish_shard(self, shard_name: str, num_rows: int):
+    def finish_shard(self, shard_name: str, num_rows_in: int, num_rows_out: int):
         if shard_name not in self.shards:
             raise ValueError(f"Shard {shard_name} not in tracked shards")
 
-        current_rows = self._ledger.shard_rows.get(shard_name, 0)
-        if current_rows != num_rows:
-            raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}")
+        current_rows_in = self._ledger.shard_rows_in.get(shard_name, 0)
+        if current_rows_in != num_rows_in:
+            raise ValueError(
+                f"Expected {num_rows_in} rows IN for finished shard {shard_name}, but found {current_rows_in}"
+            )
+
+        current_rows_out = self._ledger.shard_rows_out.get(shard_name, 0)
+        if current_rows_out != num_rows_out:
+            raise ValueError(
+                f"Expected {num_rows_out} rows OUT for finished shard {shard_name}, but found {current_rows_out}"
+            )
 
         self._ledger.finished_shards.append(shard_name)
         self._ledger._serialize_and_commit(self.cache_dir)
 
-    def write_prepared_batch(self, shard_name: str, row_count: int, batch: PyTree[PreparedBatch]):
+    def write_prepared_batch(
+        self, shard_name: str, row_count_in: int, row_count_out: int, batch: PyTree[PreparedBatch]
+    ):
         if self.is_finished:
             raise RuntimeError("Cannot write to a finished cache")
         self._tree_store.extend_with_batch(batch)
 
         if shard_name not in self.shards:
             raise ValueError(f"Shard {shard_name} not in tracked shards")
 
-        self._ledger.shard_rows[shard_name] += row_count
-        self._ledger.total_num_rows += row_count
+        self._ledger.shard_rows_in[shard_name] += row_count_in
+        self._ledger.shard_rows_out[shard_name] += row_count_out
+        self._ledger.total_num_rows += row_count_out
 
         self._ledger._serialize_and_commit(self.cache_dir)