mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 23:38:59 +06:00
Fix data_seed unused (#33731)
* fixing data_seed unused * fix accelerate version needed * fix style * update the fix following accelerate fix
This commit is contained in:
parent
b2f09fb90f
commit
a37a06a20b
@ -417,6 +417,7 @@ class Trainer:
|
|||||||
self.args = args
|
self.args = args
|
||||||
# Seed must be set before instantiating the model when using model
|
# Seed must be set before instantiating the model when using model
|
||||||
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
|
||||||
|
|
||||||
self.hp_name = None
|
self.hp_name = None
|
||||||
self.deepspeed = None
|
self.deepspeed = None
|
||||||
self.is_in_train = False
|
self.is_in_train = False
|
||||||
@ -4864,6 +4865,9 @@ class Trainer:
|
|||||||
even_batches=accelerator_config.pop("even_batches"),
|
even_batches=accelerator_config.pop("even_batches"),
|
||||||
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
|
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
|
||||||
)
|
)
|
||||||
|
if is_accelerate_available("1.1.0"):
|
||||||
|
dataloader_config.data_seed = self.args.data_seed
|
||||||
|
|
||||||
non_blocking = accelerator_config.pop("non_blocking")
|
non_blocking = accelerator_config.pop("non_blocking")
|
||||||
if not is_accelerate_available("0.30.0"):
|
if not is_accelerate_available("0.30.0"):
|
||||||
if non_blocking:
|
if non_blocking:
|
||||||
|
@ -2078,6 +2078,13 @@ class TrainingArguments:
|
|||||||
"This is not supported and we recommend you to update your version."
|
"This is not supported and we recommend you to update your version."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.data_seed is not None:
|
||||||
|
if not is_accelerate_available("1.1.0"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"data_seed requires Accelerate version `accelerate` >= 1.1.0. "
|
||||||
|
"This is not supported and we recommend you to update your version."
|
||||||
|
)
|
||||||
|
|
||||||
if self.include_inputs_for_metrics:
|
if self.include_inputs_for_metrics:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."
|
"Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."
|
||||||
|
Loading…
Reference in New Issue
Block a user