From cb298978ade3f1edb0ffd02ee079a69f08917a2a Mon Sep 17 00:00:00 2001
From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>
Date: Fri, 28 Jun 2024 21:50:27 +0900
Subject: [PATCH] add gather_use_object arguments (#31514)

* add gather_use_object arguments

* fix name and pass the CI test for Seq2SeqTrainer

* make style

* make it to functools

* fix typo

* add accelerate version:

* adding warning

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* make style

* Update src/transformers/training_args.py

* check function move to initial part

* add test for eval_use_gather_object

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
---
 src/transformers/trainer.py       |  5 +++++
 src/transformers/training_args.py | 18 +++++++++++++++++-
 tests/trainer/test_trainer.py     | 12 ++++++++++++
 3 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 71c3ee43af2..affc7b725e8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4605,6 +4605,11 @@ class Trainer:
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
 
+        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
+            self.gather_function = functools.partial(
+                self.gather_function, use_gather_object=self.args.eval_use_gather_object
+            )
+
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 41a9607e312..5eff032774e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -773,8 +773,11 @@ class TrainingArguments:
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
             summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
 
-        eval_on_start(`bool`, *optional*, defaults to `False`):
+        eval_on_start (`bool`, *optional*, defaults to `False`):
             Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly.
+
+        eval_use_gather_object (`bool`, *optional*, defaults to `False`):
+            Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices.
     """
 
     framework = "pt"
@@ -1465,6 +1468,13 @@
         },
     )
 
+    eval_use_gather_object: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices."
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
@@ -1992,6 +2002,12 @@
                 FutureWarning,
             )
 
+        if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
+            raise ValueError(
+                "--eval_use_gather_object requires `accelerate` version 0.30.0 or higher. "
+                "Please update your `accelerate` version to use this feature."
+            )
+
     def __str__(self):
         self_as_dict = asdict(self)
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 2c6793e39ca..26fa4624674 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -132,6 +132,7 @@ if is_torch_available():
 
 # for version specific tests in TrainerIntegrationTest
 require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
+require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
 GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
 if is_accelerate_available():
     from accelerate import Accelerator
@@ -3565,6 +3566,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertIn("torch_dtype", args_dict)
         self.assertEqual(args_dict["torch_dtype"], dtype)
 
+    @require_accelerate_version_min_0_30
+    def test_eval_use_gather_object(self):
+        train_dataset = RegressionDataset()
+        eval_dataset = RegressionDataset()
+        model = RegressionDictModel()
+        args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+        _ = trainer.evaluate()
+        _ = trainer.predict(eval_dataset)
+
 
 @require_torch
 @is_staging_test
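
For context, a minimal usage sketch of the new flag (not part of the patch), mirroring the new test above. `RegressionDataset` and `RegressionDictModel` are the helpers defined in tests/trainer/test_trainer.py and stand in for any dataset/model whose eval outputs contain non-tensor objects; the sketch assumes `accelerate` >= 0.30.0 is installed.

# Minimal sketch, not part of the patch: enabling eval_use_gather_object.
# RegressionDataset and RegressionDictModel come from tests/trainer/test_trainer.py;
# substitute any dataset/model whose outputs include non-tensor objects.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./regression",
    report_to="none",
    eval_use_gather_object=True,  # gather eval outputs via accelerate's gather_object
)
trainer = Trainer(
    model=RegressionDictModel(),
    args=args,
    train_dataset=RegressionDataset(),
    eval_dataset=RegressionDataset(),
)
trainer.train()
metrics = trainer.evaluate()  # non-tensor outputs are gathered across processes

With the flag set, Trainer wraps `accelerator.gather_for_metrics` in `functools.partial(..., use_gather_object=True)`, so evaluation outputs are collected with accelerate's `gather_object` rather than the tensor-only all-gather.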