@@ -89,7 +89,9 @@ def __init__(self, *args, **kwargs):
8989
9090 def __iter__ (self ):
9191 if self .generator is None :
92- self .generator = torch .Generator (device = torch .get_default_device ())
92+ self .generator = torch .Generator (
93+ device = torch .get_default_device () if hasattr (torch , "get_default_device" ) else "cpu"
94+ )
9395 self .generator .manual_seed (self .initial_seed )
9496
9597 # Allow `self.epoch` to modify the seed of the generator
@@ -1156,13 +1158,19 @@ def prepare_data_loader(
11561158 data_source = sampler .data_source ,
11571159 replacement = sampler .replacement ,
11581160 num_samples = sampler ._num_samples ,
1159- generator = getattr (sampler , "generator" , torch .Generator (device = torch .get_default_device ())),
1161+ generator = getattr (
1162+ sampler ,
1163+ "generator" ,
1164+ torch .Generator (device = torch .get_default_device () if hasattr (torch , "get_default_device" ) else "cpu" ),
1165+ ),
11601166 data_seed = data_seed ,
11611167 )
11621168
11631169 if isinstance (dataloader .sampler , RandomSampler ) and state .distributed_type == DistributedType .XLA :
11641170 # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
1165- generator = torch .Generator (device = torch .get_default_device ()).manual_seed (42 )
1171+ generator = torch .Generator (
1172+ device = torch .get_default_device () if hasattr (torch , "get_default_device" ) else "cpu"
1173+ ).manual_seed (42 )
11661174 dataloader .generator = generator
11671175 dataloader .sampler .generator = generator
11681176 # No change if no multiprocess
@@ -1181,7 +1189,9 @@ def prepare_data_loader(
11811189 else :
11821190 if not use_seedable_sampler and hasattr (sampler , "generator" ):
11831191 if sampler .generator is None :
1184- sampler .generator = torch .Generator (device = torch .get_default_device ())
1192+ sampler .generator = torch .Generator (
1193+ device = torch .get_default_device () if hasattr (torch , "get_default_device" ) else "cpu"
1194+ )
11851195 synchronized_generator = sampler .generator
11861196 batch_sampler = dataloader .sampler if sampler_is_batch_sampler else dataloader .batch_sampler
11871197 new_batch_sampler = BatchSamplerShard (
0 commit comments