nits and fixes

This commit is contained in:
Arthur 2025-06-30 12:03:41 +02:00
parent abf9d39d12
commit eb6747bca9
2 changed files with 9 additions and 11 deletions

View File

@ -2057,11 +2057,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
self._no_split_modules = self._no_split_modules or []
_param_to_record = {}
for module in self.modules():
if hasattr(module, "return_hooks"):
_param_to_record.update({module.return_hooks[0]: (module, module.return_hooks[1])})
self._can_record_outputs: Dict[str, Tuple[nn.Module, int]] = _param_to_record
def post_init(self):
"""

View File

@ -20,7 +20,7 @@ import json
import os
import tempfile
import warnings
from collections import OrderedDict, UserDict
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Iterable, MutableMapping
from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
@ -930,6 +930,7 @@ def can_return_tuple(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
return_dict = self.config.use_return_dict if hasattr(self, "config") else True
if "return_dict" in kwargs:
return_dict = kwargs.get("return_dict", self.config.use_return_dict)
kwargs["return_dict"] = True
@ -975,16 +976,18 @@ def check_model_inputs(func):
# and all of the rest that is general transformers checking
hooks = []
collected_outputs = {}
collected_outputs = defaultdict(list)
def make_capture_fn(key, index):
def capture_fn(module, input, output):
collected_outputs[key].append(output[index])
if output[index] is not None:
collected_outputs[key].append(output[index])
return capture_fn
capture_flags = self._can_record_outputs.keys()
recordable_keys = {f"output_{k}": kwargs.get(f"output_{k}", False) for k in capture_flags}
capture_flags = self._can_record_outputs
all_args.update(**all_args["kwargs"])
recordable_keys = {f"output_{k}": all_args.get(f"output_{k}", False) for k in capture_flags}
if any(recordable_keys.values()):
for (
_,
@ -1002,7 +1005,7 @@ def check_model_inputs(func):
for key in collected_outputs:
outputs[key] = collected_outputs[key]
print(collected_outputs)
if return_dict is False:
outputs = outputs.to_tuple()
return outputs