mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
Fix data_seed unused (#33731)
* fixing data_seed unused * fix accelerate version needed * fix style * update the fix following accelerate fix
This commit is contained in:
parent
b2f09fb90f
commit
a37a06a20b
@ -417,6 +417,7 @@ class Trainer:
|
||||
self.args = args
|
||||
# Seed must be set before instantiating the model when using model
|
||||
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
||||
|
||||
self.hp_name = None
|
||||
self.deepspeed = None
|
||||
self.is_in_train = False
|
||||
@ -4864,6 +4865,9 @@ class Trainer:
|
||||
even_batches=accelerator_config.pop("even_batches"),
|
||||
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
|
||||
)
|
||||
if is_accelerate_available("1.1.0"):
|
||||
dataloader_config.data_seed = self.args.data_seed
|
||||
|
||||
non_blocking = accelerator_config.pop("non_blocking")
|
||||
if not is_accelerate_available("0.30.0"):
|
||||
if non_blocking:
|
||||
|
@ -2078,6 +2078,13 @@ class TrainingArguments:
|
||||
"This is not supported and we recommend you to update your version."
|
||||
)
|
||||
|
||||
if self.data_seed is not None:
|
||||
if not is_accelerate_available("1.1.0"):
|
||||
raise NotImplementedError(
|
||||
"data_seed requires Accelerate version `accelerate` >= 1.1.0. "
|
||||
"This is not supported and we recommend you to update your version."
|
||||
)
|
||||
|
||||
if self.include_inputs_for_metrics:
|
||||
logger.warning(
|
||||
"Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."
|
||||
|
Loading…
Reference in New Issue
Block a user