Also add the changes needed to modeling_utils

Arthur 2025-06-30 11:39:29 +02:00
parent 37b4ef022e
commit 7f113b43cc


@@ -33,7 +33,7 @@ from contextlib import contextmanager
 from enum import Enum
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union, Dict, Tuple
 from zipfile import is_zipfile

 import torch
@@ -2006,6 +2006,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     # In practice, it means that they support attention interface functions, fully pass the kwargs
     # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
     _supports_attention_backend = False
+    _can_record_outputs = None

     @property
     def dummy_inputs(self) -> dict[str, torch.Tensor]:
@@ -2056,6 +2057,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
         self._no_split_modules = self._no_split_modules or []
+        _param_to_record = {}
+        for module in self.modules():
+            if hasattr(module, "return_hooks"):
+                _param_to_record.update({module.return_hooks[0]: (module, module.return_hooks[1])})
+        self._can_record_outputs: Dict[str, Tuple[nn.Module, int]] = _param_to_record

     def post_init(self):
         """