diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 71c3ee43af2..affc7b725e8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4605,6 +4605,11 @@ class Trainer:
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
 
+        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
+            self.gather_function = functools.partial(
+                self.gather_function, use_gather_object=self.args.eval_use_gather_object
+            )
+
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 41a9607e312..5eff032774e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -773,8 +773,11 @@ class TrainingArguments:
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
             summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
 
-        eval_on_start(`bool`, *optional*, defaults to `False`):
+        eval_on_start (`bool`, *optional*, defaults to `False`):
             Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly.
+
+        eval_use_gather_object (`bool`, *optional*, defaults to `False`):
+            Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices.
     """
 
     framework = "pt"
@@ -1465,6 +1468,13 @@ class TrainingArguments:
         },
     )
 
+    eval_use_gather_object: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices."
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
@@ -1992,6 +2002,12 @@ class TrainingArguments:
                 FutureWarning,
             )
 
+        if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
+            raise ValueError(
+                "--eval_use_gather_object requires `accelerate` >= 0.30.0 to be installed. "
+                "Please update your `accelerate` version."
+            )
+
     def __str__(self):
         self_as_dict = asdict(self)
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 2c6793e39ca..26fa4624674 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -132,6 +132,7 @@ if is_torch_available():
 
 # for version specific tests in TrainerIntegrationTest
 require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
+require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
 GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
 if is_accelerate_available():
     from accelerate import Accelerator
@@ -3565,6 +3566,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertIn("torch_dtype", args_dict)
         self.assertEqual(args_dict["torch_dtype"], dtype)
 
+    @require_accelerate_version_min_0_30
+    def test_eval_use_gather_object(self):
+        train_dataset = RegressionDataset()
+        eval_dataset = RegressionDataset()
+        model = RegressionDictModel()
+        args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+        _ = trainer.evaluate()
+        _ = trainer.predict(eval_dataset)
+
 
 @require_torch
 @is_staging_test
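For context, here is a minimal usage sketch of the new flag outside the test suite. It is not part of the diff: `ToyDataset` and `ToyRegressor` are hypothetical stand-ins for the `RegressionDataset`/`RegressionDictModel` helpers used in the test above, and the flag only takes effect when the installed `accelerate` (>= 0.30.0) exposes `use_gather_object` on `gather_for_metrics`.

```python
# Hypothetical sketch: a model whose eval outputs are nested dict objects,
# evaluated with eval_use_gather_object=True (requires this patch and accelerate >= 0.30.0).
import torch
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class ToyDataset(Dataset):
    def __init__(self, n=64):
        self.x = torch.randn(n, 1)
        self.y = 2.0 * self.x + 0.1 * torch.randn(n, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return {"input_x": self.x[i], "labels": self.y[i]}


class ToyRegressor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, input_x, labels=None):
        preds = self.linear(input_x)
        loss = torch.nn.functional.mse_loss(preds, labels) if labels is not None else None
        # Returning a plain dict means prediction outputs are nested objects,
        # which is the case use_gather_object=True is meant to handle across devices.
        return {"loss": loss, "logits": preds}


args = TrainingArguments(
    output_dir="./toy-regression",
    report_to="none",
    eval_use_gather_object=True,  # new flag introduced by this diff
)
trainer = Trainer(
    model=ToyRegressor(),
    args=args,
    train_dataset=ToyDataset(),
    eval_dataset=ToyDataset(),
)
trainer.train()
metrics = trainer.evaluate()
```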