mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
nits and fixes
This commit is contained in:
parent
abf9d39d12
commit
eb6747bca9
@ -2057,11 +2057,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
|
||||
|
||||
self._no_split_modules = self._no_split_modules or []
|
||||
_param_to_record = {}
|
||||
for module in self.modules():
|
||||
if hasattr(module, "return_hooks"):
|
||||
_param_to_record.update({module.return_hooks[0]: (module, module.return_hooks[1])})
|
||||
self._can_record_outputs: Dict[str, Tuple[nn.Module, int]] = _param_to_record
|
||||
|
||||
def post_init(self):
|
||||
"""
|
||||
|
@ -20,7 +20,7 @@ import json
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import OrderedDict, UserDict
|
||||
from collections import OrderedDict, UserDict, defaultdict
|
||||
from collections.abc import Iterable, MutableMapping
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields, is_dataclass
|
||||
@ -930,6 +930,7 @@ def can_return_tuple(func):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
return_dict = self.config.use_return_dict if hasattr(self, "config") else True
|
||||
if "return_dict" in kwargs:
|
||||
return_dict = kwargs.get("return_dict", self.config.use_return_dict)
|
||||
kwargs["return_dict"] = True
|
||||
@ -975,16 +976,18 @@ def check_model_inputs(func):
|
||||
# and all of the rest that is general transformers checking
|
||||
|
||||
hooks = []
|
||||
collected_outputs = {}
|
||||
collected_outputs = defaultdict(list)
|
||||
|
||||
def make_capture_fn(key, index):
|
||||
def capture_fn(module, input, output):
|
||||
collected_outputs[key].append(output[index])
|
||||
if output[index] is not None:
|
||||
collected_outputs[key].append(output[index])
|
||||
|
||||
return capture_fn
|
||||
|
||||
capture_flags = self._can_record_outputs.keys()
|
||||
recordable_keys = {f"output_{k}": kwargs.get(f"output_{k}", False) for k in capture_flags}
|
||||
capture_flags = self._can_record_outputs
|
||||
all_args.update(**all_args["kwargs"])
|
||||
recordable_keys = {f"output_{k}": all_args.get(f"output_{k}", False) for k in capture_flags}
|
||||
if any(recordable_keys.values()):
|
||||
for (
|
||||
_,
|
||||
@ -1002,7 +1005,7 @@ def check_model_inputs(func):
|
||||
|
||||
for key in collected_outputs:
|
||||
outputs[key] = collected_outputs[key]
|
||||
|
||||
print(collected_outputs)
|
||||
if return_dict is False:
|
||||
outputs = outputs.to_tuple()
|
||||
return outputs
|
||||
|
Loading…
Reference in New Issue
Block a user