Skip to content

Commit 0b73a90

Browse files
Fix horovod dataset partition for CombinedDataset with sampling_sizes
1 parent e49a2e4 commit 0b73a90

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

returnn/datasets/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __init__(self, name=None,
121121
self.random_seed_offset = random_seed_offset
122122
self.partition_epoch = partition_epoch or 1
123123
self.repeat_epoch = repeat_epoch or 1
124+
self.disable_horovod_partition = False # can be set by meta-dataset to handle multi-gpu partitioning on meta-level
124125
self.seq_tags_filter = set(self._load_seq_list_file(seq_list_filter_file)) if seq_list_filter_file else None
125126
self.unique_seq_tags = unique_seq_tags
126127
self._seq_order_seq_lens_file = seq_order_seq_lens_file
@@ -483,7 +484,8 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
483484
seq_index = self._apply_partition_epoch(seq_index, partition_epoch, epoch)
484485
if repeat_epoch > 1:
485486
seq_index = seq_index * repeat_epoch
486-
seq_index = self._apply_multi_gpu_partition(seq_index)
487+
if not self.disable_horovod_partition:
488+
seq_index = self._apply_multi_gpu_partition(seq_index)
487489
if self.seq_tags_filter is not None:
488490
# Note: This is as generic as possible, but requires that get_all_tags is implemented.
489491
assert seq_index

returnn/datasets/meta.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,9 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
913913
# partition epoch of the individual sub-datasets is still supported. Later we will call init_seq_order again with a
914914
# sequence list to e.g. apply joint sorting or partition epoch of all sequences.
915915
for dataset in self.datasets.values():
916+
if self.sampling_sizes:
917+
# Partitioning does not make sense if we sample a fixed number of sequences anyway.
918+
dataset.disable_horovod_partition = True
916919
dataset.init_seq_order(epoch=epoch)
917920

918921
# noinspection PyBroadException
@@ -1076,6 +1079,8 @@ def _get_sampling_seq_order(self):
10761079
# We want to additionally sort the sequences in the current sample. For this, create a sequence order on a
10771080
# range of length of the number of sequences in the sample. Note that we have to map the indices to make use
10781081
# of self._get_seq_length here.
1082+
# This get_seq_order_for_epoch call now also handles horovod_dataset_distribution = 'partition', which we
1083+
# disabled on sub-dataset level via 'disable_horovod_partition' above.
10791084
seq_order_remapping = self.get_seq_order_for_epoch(
10801085
epoch=epoch, num_seqs=len(seq_order), get_seq_len=lambda i: self._get_seq_length(seq_order[i]))
10811086

0 commit comments

Comments
 (0)