Pass along seed to DistributedSampler (#11406)

* Pass along seed to DistributedSampler

* Add seed to DistributedLengthGroupedSampler
This commit is contained in:
Sylvain Gugger 2021-04-26 10:26:52 -04:00 committed by GitHub
parent b24ead87e1
commit ab2cabb964
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -547,6 +547,7 @@ class Trainer:
rank=self.args.process_index,
lengths=lengths,
model_input_name=model_input_name,
seed=self.args.seed,
)
else:
@ -562,10 +563,14 @@ class Trainer:
batch_size=self.args.per_device_train_batch_size,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
else:
return DistributedSampler(
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
def get_train_dataloader(self) -> DataLoader: