Mirror of https://github.com/huggingface/transformers.git
also add the changes needed to modeling utils
parent 37b4ef022e
commit 7f113b43cc
@@ -33,7 +33,7 @@ from contextlib import contextmanager
 from enum import Enum
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union, Dict, Tuple
 from zipfile import is_zipfile
 
 import torch
@@ -2006,6 +2006,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     # In practice, it means that they support attention interface functions, fully pass the kwargs
     # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
     _supports_attention_backend = False
+    _can_record_outputs = None
 
     @property
     def dummy_inputs(self) -> dict[str, torch.Tensor]:
@@ -2056,6 +2057,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
 
         self._no_split_modules = self._no_split_modules or []
+        _param_to_record = {}
+        for module in self.modules():
+            if hasattr(module, "return_hooks"):
+                _param_to_record.update({module.return_hooks[0]: (module, module.return_hooks[1])})
+        self._can_record_outputs: Dict[str, Tuple[nn.Module, int]] = _param_to_record
 
     def post_init(self):
         """
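
For context, a minimal self-contained sketch of what the registration loop added above does: any submodule exposing a `return_hooks` attribute of the form `(output_name, output_index)` gets collected into the `_can_record_outputs` mapping at construction time. The `ToyAttention` and `RecordingModel` names and the `("attn_weights", 1)` hook value are hypothetical illustrations; only the `return_hooks` convention is taken from the diff.

from typing import Dict, Tuple

import torch.nn as nn


class ToyAttention(nn.Module):
    # Hypothetical submodule advertising which of its outputs can be recorded:
    # record output index 1 under the key "attn_weights".
    return_hooks = ("attn_weights", 1)


class RecordingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()
        # Same loop as in the diff: collect every submodule exposing
        # `return_hooks` into a name -> (module, output_index) mapping.
        _param_to_record = {}
        for module in self.modules():
            if hasattr(module, "return_hooks"):
                _param_to_record.update({module.return_hooks[0]: (module, module.return_hooks[1])})
        self._can_record_outputs: Dict[str, Tuple[nn.Module, int]] = _param_to_record


model = RecordingModel()
print(model._can_record_outputs)  # {'attn_weights': (ToyAttention(), 1)}

With this scheme each submodule declares its own recordable output, and the model only has to scan `self.modules()` once during initialization.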