Originally posted in #1257 about the HuggingFaceDataset, but actually the discussion applies just the same to most of our datasets, e.g. HDFDataset, OggZipDataset, etc.
> [We] do random access here (`self.dataset[int(corpus_seq_idx)]`). Is this efficient?

On the speed: In the HF datasets process doc on shuffling (`shuffle`):

> Shuffling takes the list of indices `[0:len(my_dataset)]` and shuffles it to create an indices mapping. However as soon as your Dataset has an indices mapping, the speed can become 10x slower. This is because there is an extra step to get the row index to read using the indices mapping, and most importantly, you aren’t reading contiguous chunks of data anymore.
`shuffle` sounds very much like what we do in `init_seq_order`, i.e. building a list of indices (an indices mapping). Then (if it is not the default order), we aren’t reading contiguous chunks of data anymore.

So, is this an issue? It sounds like one. Is it still an issue when we have the data on local disk? (Does the format matter? What about Parquet? Arrow? `save_to_disk` uses Arrow?)

The doc also mentions "fast approximate shuffling" via `IterableDataset.shuffle()`. As I understand the implementation, it prefetches some amount of entries into a buffer, but this is done from the beginning of the dataset, and then it randomly samples an item from the buffer (see the sketch below). I see that this is fast, because we again read contiguous chunks of data. But at the same time, this is not really very random at all, which sounds quite bad. However, it seems this would at least also shuffle over shards, if there are shards. That makes it better.
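For illustration, here is a minimal sketch of such a buffer-based approximate shuffle (my own reconstruction of the idea, not the actual HF implementation):

```python
import random


def buffered_shuffle(stream, buffer_size, rng):
    """Approximate shuffling: fill a buffer from the stream, then repeatedly
    yield a random buffer element and replace it with the next stream item."""
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        pos = rng.randrange(buffer_size)
        yield buffer[pos]
        buffer[pos] = item
    # Stream exhausted: flush the remaining buffer in random order.
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(20), buffer_size=5, rng=random.Random(42))))
```

Note that in this sketch, an item can never appear more than `buffer_size` positions earlier than its original stream position, which is exactly why it is only approximately random.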
We could maybe also adapt our shuffling to operate more in chunks, i.e. first segment the data into chunks of some fixed chunk size, then shuffle the chunks, then shuffle within each chunk. That way, most of the random accesses would be closer together. But I don't know whether the underlying Arrow dataset would already benefit from that, or whether it still needs extra logic (like prefetching the whole chunk). Also, I am not sure about a good chunk size.
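A minimal sketch of that chunked shuffling (the function name and `chunk_size` value are made up for illustration):

```python
import random


def chunked_shuffle_indices(num_seqs, chunk_size, rng):
    """Shuffle chunks of consecutive indices, then shuffle within each chunk,
    so that consecutive accesses mostly stay within one chunk."""
    chunks = [
        list(range(start, min(start + chunk_size, num_seqs)))
        for start in range(0, num_seqs, chunk_size)
    ]
    rng.shuffle(chunks)  # shuffle the order of the chunks
    for chunk in chunks:
        rng.shuffle(chunk)  # shuffle within each chunk
    return [idx for chunk in chunks for idx in chunk]


# Example: 10 seqs, chunk size 4.
print(chunked_shuffle_indices(10, 4, random.Random(42)))
```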
I read a bit more about that, e.g. here on HF datasets Arrow:

> Arrow’s standard format allows zero-copy reads which removes virtually all serialization overhead.

> You can obtain the best performance by accessing slices of data (or “batches”), in order to reduce the amount of lookups on disk.

They have this example code:
```python
from datasets import load_dataset

wiki = load_dataset("wikipedia", "20220301.en", split="train")
batch_size = 1000
for i in range(0, len(wiki), batch_size):
    batch = wiki[i:i + batch_size]
```

So it seems that the underlying Arrow dataset would not internally do this, and I need to take care of such logic myself (?).
But that wouldn't really be such a big issue.
**Edit:** I just checked `datasets.arrow_dataset.Dataset.__iter__` (that is what a common dataset `__iter__` will likely be):

```python
def __iter__(self):
    if self._indices is None:
        # Fast iteration
        # Benchmark: https://gist.github.com/mariosasko/0248288a2e3a7556873969717c1fe52b (fast_iter_batch)
        ...
        batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER
        for pa_subtable in table_iter(self.data, batch_size=batch_size):
            for i in range(pa_subtable.num_rows):
                ...
                yield formatted_output
    else:
        for i in range(self.num_rows):
            yield self._getitem(i)
```

So you see why it is slower with `shuffle` (where `self._indices` would be set): it does not use the batched reading. However, if we made sure that even with shuffling, most of the seq indices are still close by, I think we could still use batched reading, and speed it up again quite a bit.
I think we can improve our current logic: first, by more batched read access to the underlying data (e.g. like `ARROW_READER_BATCH_SIZE_IN_DATASET_ITER`), and second, by more structured sorting/shuffling which makes sure that such batched reads are actually useful. A rough sketch of how the two could fit together follows below.
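For illustration only (not actual RETURNN code; `hf_dataset`, `read_batch_size`, and the alignment assumption are mine):

```python
def iter_with_batched_reads(hf_dataset, shuffled_indices, read_batch_size=1000):
    """Serve shuffled seq indices out of cached contiguous slices.

    Works best when shuffled_indices comes from chunked shuffling as sketched
    above, with chunk boundaries aligned to multiples of read_batch_size:
    then each chunk needs exactly one contiguous read.
    """
    cache_start = None
    cache = None
    for idx in shuffled_indices:
        start = (idx // read_batch_size) * read_batch_size
        if start != cache_start:
            # One contiguous (batched) read, like in the wiki example above.
            cache = hf_dataset[start:start + read_batch_size]
            cache_start = start
        # HF dataset slicing returns a dict of columns (lists);
        # pick out the single row idx - start.
        yield {key: values[idx - start] for key, values in cache.items()}
```

With the chunk size equal to `read_batch_size`, this does one slice read per chunk instead of one row lookup per seq.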
But before we invest time in this, we should really verify that it is actually needed. Accessing the data from local disk already makes it much faster (see DistributeFilesDataset with file caching, the HuggingFaceDataset `use_file_cache` option, etc.). Then you can also use MultiProcDataset. At least for my current training runs, I can always make the dataset loading fast enough that it is never a bottleneck.
(cc @NeoLegends @dthulke)