diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8d6d4bdb480..9c4ddd268a0 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -417,6 +417,7 @@ class Trainer: self.args = args # Seed must be set before instantiating the model when using model enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.hp_name = None self.deepspeed = None self.is_in_train = False @@ -4864,6 +4865,9 @@ class Trainer: even_batches=accelerator_config.pop("even_batches"), use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"), ) + if is_accelerate_available("1.1.0"): + dataloader_config.data_seed = self.args.data_seed + non_blocking = accelerator_config.pop("non_blocking") if not is_accelerate_available("0.30.0"): if non_blocking: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 26da84dfe23..a2d83b2915e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2078,6 +2078,13 @@ class TrainingArguments: "This is not supported and we recommend you to update your version." ) + if self.data_seed is not None: + if not is_accelerate_available("1.1.0"): + raise NotImplementedError( + "data_seed requires Accelerate version `accelerate` >= 1.1.0. " + "This is not supported and we recommend you to update your version." + ) + if self.include_inputs_for_metrics: logger.warning( "Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."