From c7eb95581a8da3e41ae98decfeb98f7be4d10d82 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 12 Mar 2025 18:59:13 +0000 Subject: [PATCH] Don't accidentally mutate the base_model_tp_plan (#36677) * Don't accidentally mutate the base_model_tp_plan * Co-authored by: Joao Gante * Trigger tests * Marking grad accum test as slow * Add a flaky decorator * Add a flaky decorator * Use cyril's codeblock * Don't copy() when it's None * Use cyril's new codeblock * make fixup --- src/transformers/modeling_utils.py | 12 +++++++++--- tests/generation/test_utils.py | 1 + tests/trainer/test_trainer.py | 1 + 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 61c86cffd04..a33f556f843 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1895,9 +1895,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config if self.base_model is self: - self._pp_plan = self.config.base_model_pp_plan - - self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {} + self._pp_plan = ( + self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None + ) + self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {} + else: + self._tp_plan = self._tp_plan or {} + for name, module in self.named_children(): + if plan := getattr(module, "_tp_plan", None): + self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()}) for name, module in self.named_children(): if plan := getattr(module, "_tp_plan", None): self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()}) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 00ac78a94d3..6d93c77d865 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2305,6 +2305,7 @@ class GenerationTesterMixin: self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) @pytest.mark.generate + @is_flaky def test_assisted_decoding_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index beee7fcb48a..1f6441758c9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -803,6 +803,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() self.check_trained_model(trainer.model, alternate_seed=True) + @slow def test_gradient_accumulation_loss_alignment_with_model_loss(self): set_seed(42) import datasets