mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add gather_use_object arguments (#31514)
* add gather_use_object arguments * fix name and pass the CI test for Seq2SeqTrainer * make style * make it to functools * fix typo * add accelerate version: * adding warning * Update src/transformers/trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * make style * Update src/transformers/training_args.py * check function move to initial part * add test for eval_use_gather_object --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
82a1fc7256
commit
cb298978ad
@ -4605,6 +4605,11 @@ class Trainer:
|
||||
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
||||
self.gather_function = self.accelerator.gather_for_metrics
|
||||
|
||||
if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
|
||||
self.gather_function = functools.partial(
|
||||
self.gather_function, use_gather_object=self.args.eval_use_gather_object
|
||||
)
|
||||
|
||||
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
|
||||
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||
|
@ -773,8 +773,11 @@ class TrainingArguments:
|
||||
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
|
||||
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
|
||||
|
||||
eval_on_start(`bool`, *optional*, defaults to `False`):
|
||||
eval_on_start (`bool`, *optional*, defaults to `False`):
|
||||
Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly.
|
||||
|
||||
eval_use_gather_object (`bool`, *optional*, defaults to `False`):
|
||||
Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices.
|
||||
"""
|
||||
|
||||
framework = "pt"
|
||||
@ -1465,6 +1468,13 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
|
||||
eval_use_gather_object: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||
for field in _VALID_DICT_FIELDS:
|
||||
@ -1992,6 +2002,12 @@ class TrainingArguments:
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
|
||||
raise ValueError(
|
||||
"--eval_use_gather_object requires Accelerate to be version of `accelerate` < 0.30.0."
|
||||
"This is not supported and we recommend you to update your version."
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
self_as_dict = asdict(self)
|
||||
|
||||
|
@ -132,6 +132,7 @@ if is_torch_available():
|
||||
|
||||
# for version specific tests in TrainerIntegrationTest
|
||||
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
|
||||
require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
|
||||
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
|
||||
if is_accelerate_available():
|
||||
from accelerate import Accelerator
|
||||
@ -3565,6 +3566,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertIn("torch_dtype", args_dict)
|
||||
self.assertEqual(args_dict["torch_dtype"], dtype)
|
||||
|
||||
@require_accelerate_version_min_0_30
|
||||
def test_eval_use_gather_object(self):
|
||||
train_dataset = RegressionDataset()
|
||||
eval_dataset = RegressionDataset()
|
||||
model = RegressionDictModel()
|
||||
args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||
trainer.train()
|
||||
_ = trainer.evaluate()
|
||||
_ = trainer.predict(eval_dataset)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
Loading…
Reference in New Issue
Block a user