@@ -357,6 +357,49 @@ def __init__(self, config):
357357 assert set (data_out ) == set (range (num_seqs ))
358358
359359
def test_horovod_partition_combined_dataset_sampling():
    """
    Check ``CombinedDataset`` sampling under a simulated Horovod "partition" setup.

    A single sub-dataset with ``num_seqs`` entries is wrapped in a
    ``CombinedDataset`` that samples ``sampling_size`` sequences from it.
    The dataset is then iterated once per simulated Horovod rank, and the
    sequences collected over all ranks are checked.

    The test replaces ``horovod.get_ctx`` and sets ``use_horovod`` on the global
    config. Both are undone in a ``finally`` clause so the simulated multi-GPU
    setup does not stay active for code that runs afterwards in this process.
    """
    num_seqs = 10
    sampling_size = 12
    dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]

    from returnn.datasets.meta import CombinedDataset

    dataset = MapDatasetWrapper(FromListDataset(data_list=dummy_data))
    combined_dataset = CombinedDataset(
        datasets={"dataset": dataset},
        data_map={("dataset", "data"): "data"},
        sampling_sizes={"dataset": sampling_size},
        data_dims={"data": (1, 1)},
        seq_ordering="random",
    )

    from returnn.config import get_global_config
    from returnn.tf import horovod

    global_config = get_global_config(auto_create=True)
    horovod_size = 3

    class DummyHorovodContext(horovod.HorovodContext):
        """Stand-in context that only carries rank, size and config."""

        # Intentionally does not call the base ``__init__``; it just sets the
        # three attributes directly.
        # NOTE(review): presumably this avoids needing a real Horovod setup - confirm.
        def __init__(self, config, rank):
            self._rank = rank
            self._size = horovod_size
            self._config = config

    # Remember the real function so it can be put back after the test.
    original_get_ctx = horovod.get_ctx
    data_out = []
    try:
        global_config.set("use_horovod", True)
        global_config.set("horovod_dataset_distribution", "partition")

        for rank in range(horovod_size):
            # Simulating a multi-gpu setup.
            # ``rank`` is bound as a default value so each replacement keeps its
            # own rank instead of reading the loop variable at call time.
            def get_dummy_ctx(config=None, _rank=rank):
                return DummyHorovodContext(config or global_config, _rank)

            horovod.get_ctx = get_dummy_ctx
            combined_dataset.init_seq_order(epoch=None)
            seq_idx = 0
            while combined_dataset.is_less_than_num_seqs(seq_idx):
                combined_dataset.load_seqs(seq_idx, seq_idx + 1)
                data = combined_dataset.get_data(seq_idx, "data")
                data_out.extend(data.tolist())
                seq_idx += 1
    finally:
        horovod.get_ctx = original_get_ctx
        # The flag was switched on unconditionally above, so switch it off again.
        # NOTE(review): this assumes it was not enabled before the test - confirm.
        global_config.set("use_horovod", False)

    # We sample 12 values from range(10) "in order", so 0 and 1 should appear
    # twice, all other values once. This e.g. would not be the case if the
    # sub-dataset is partitioned before sampling,
    # see Dataset.disable_horovod_partition.
    assert len(data_out) == sampling_size
    assert set(data_out) == set(range(num_seqs))
    assert data_out.count(0) == 2
    assert data_out.count(1) == 2
401+
402+
360403if __name__ == "__main__" :
361404 better_exchook .install ()
362405 if len (sys .argv ) <= 1 :
0 commit comments