mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Don't accidentally mutate the base_model_tp_plan (#36677)
* Don't accidentally mutate the base_model_tp_plan * Co-authored by: Joao Gante <joaofranciscocardosogante@gmail.com> * Trigger tests * Marking grad accum test as slow * Add a flaky decorator * Add a flaky decorator * Use cyril's codeblock * Don't copy() when it's None * Use cyril's new codeblock * make fixup
This commit is contained in:
parent
071a161d3e
commit
c7eb95581a
@ -1895,9 +1895,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
if self.base_model is self:
|
||||
self._pp_plan = self.config.base_model_pp_plan
|
||||
|
||||
self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {}
|
||||
self._pp_plan = (
|
||||
self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
)
|
||||
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
|
||||
else:
|
||||
self._tp_plan = self._tp_plan or {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
|
||||
|
@ -2305,6 +2305,7 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
@is_flaky
|
||||
def test_assisted_decoding_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
|
@ -803,6 +803,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||
|
||||
@slow
|
||||
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
|
||||
set_seed(42)
|
||||
import datasets
|
||||
|
Loading…
Reference in New Issue
Block a user