add gather_use_object arguments (#31514)

* add gather_use_object arguments

* fix name and pass the CI test for Seq2SeqTrainer

* make style

* make it to functools

* fix typo

* add accelerate version check

* adding warning

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* make style

* Update src/transformers/training_args.py

* move check function to initial part

* add test for eval_use_gather_object

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Sangbum Daniel Choi 2024-06-28 21:50:27 +09:00 committed by GitHub
parent 82a1fc7256
commit cb298978ad
3 changed files with 34 additions and 1 deletion

src/transformers/trainer.py

@@ -4605,6 +4605,11 @@ class Trainer:
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
    self.gather_function = functools.partial(
        self.gather_function, use_gather_object=self.args.eval_use_gather_object
    )
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
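
The change above feature-detects the new keyword rather than pinning a hard requirement: the partial is only created when the installed `gather_for_metrics` actually accepts `use_gather_object`. A minimal sketch of that pattern, with a stub standing in for Accelerate's real method:

```python
import functools
import inspect


def gather_for_metrics(data, use_gather_object=False):
    # Stub standing in for `Accelerator.gather_for_metrics`; the real method
    # only gained the `use_gather_object` keyword in accelerate 0.30.0.
    return data


gather_function = gather_for_metrics
# Bind the keyword only if the installed signature accepts it, so older
# accelerate releases keep working unchanged.
if "use_gather_object" in inspect.signature(gather_function).parameters:
    gather_function = functools.partial(gather_function, use_gather_object=True)

print(gather_function([1, 2, 3]))  # -> [1, 2, 3]
```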

src/transformers/training_args.py

@@ -773,8 +773,11 @@ class TrainingArguments:
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
eval_on_start (`bool`, *optional*, defaults to `False`):
Whether to perform an evaluation step (sanity check) before training to ensure the validation steps work correctly.
eval_use_gather_object (`bool`, *optional*, defaults to `False`):
Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices.
"""
framework = "pt"
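
For context, `use_gather_object=True` tells Accelerate to gather through `gather_object`, which collects arbitrary picklable Python objects across processes instead of only tensors. A sketch of the call, assuming accelerate >= 0.30.0 and a distributed launch:

```python
from accelerate import Accelerator

accelerator = Accelerator()

# A nested, non-tensor prediction structure that a plain tensor `gather`
# could not handle.
predictions = {"texts": ["foo", "bar"], "scores": [0.1, 0.9]}

# Each process contributes its own `predictions`; all of them come back.
gathered = accelerator.gather_for_metrics(predictions, use_gather_object=True)
```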
@@ -1465,6 +1468,13 @@ class TrainingArguments:
},
)
eval_use_gather_object: Optional[bool] = field(
    default=False,
    metadata={
        "help": "Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices."
    },
)

def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:
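
A short usage sketch of the new argument (the output directory is illustrative):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",  # illustrative path
    eval_use_gather_object=True,  # requires accelerate >= 0.30.0
)
```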
@@ -1992,6 +2002,12 @@ class TrainingArguments:
FutureWarning,
)
if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
    raise ValueError(
        "--eval_use_gather_object requires `accelerate` to be version 0.30.0 or newer. "
        "Please update `accelerate` to use this feature."
    )

def __str__(self):
self_as_dict = asdict(self)

tests/trainer/test_trainer.py

@@ -132,6 +132,7 @@ if is_torch_available():
# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
if is_accelerate_available():
from accelerate import Accelerator
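
The new helper reuses the pattern on the line above it: `functools.partial` pins the `min_version` argument of the `require_accelerate` skip decorator (from `transformers.testing_utils`), keeping version-gated tests terse. A sketch:

```python
from functools import partial

from transformers.testing_utils import require_accelerate

# Pin the minimum accelerate version once; reuse it on any test that needs it.
require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")


@require_accelerate_version_min_0_30
def test_needs_accelerate_0_30():
    ...
```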
@@ -3565,6 +3566,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertIn("torch_dtype", args_dict)
self.assertEqual(args_dict["torch_dtype"], dtype)
@require_accelerate_version_min_0_30
def test_eval_use_gather_object(self):
    train_dataset = RegressionDataset()
    eval_dataset = RegressionDataset()
    model = RegressionDictModel()
    args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
    trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
    trainer.train()
    _ = trainer.evaluate()
    _ = trainer.predict(eval_dataset)

@require_torch
@is_staging_test