mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Pass along seed to DistributedSampler (#11406)
* Pass along seed to DistributedSampler * Add seed to DistributedLengthGroupedSampler
This commit is contained in:
parent
b24ead87e1
commit
ab2cabb964
@ -547,6 +547,7 @@ class Trainer:
|
|||||||
rank=self.args.process_index,
|
rank=self.args.process_index,
|
||||||
lengths=lengths,
|
lengths=lengths,
|
||||||
model_input_name=model_input_name,
|
model_input_name=model_input_name,
|
||||||
|
seed=self.args.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -562,10 +563,14 @@ class Trainer:
|
|||||||
batch_size=self.args.per_device_train_batch_size,
|
batch_size=self.args.per_device_train_batch_size,
|
||||||
num_replicas=self.args.world_size,
|
num_replicas=self.args.world_size,
|
||||||
rank=self.args.process_index,
|
rank=self.args.process_index,
|
||||||
|
seed=self.args.seed,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return DistributedSampler(
|
return DistributedSampler(
|
||||||
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
|
self.train_dataset,
|
||||||
|
num_replicas=self.args.world_size,
|
||||||
|
rank=self.args.process_index,
|
||||||
|
seed=self.args.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
Loading…
Reference in New Issue
Block a user