dose this fix it?

This commit is contained in:
Arthur 2025-07-01 17:28:48 +02:00
parent 0c9f6de0fd
commit 501aead20b

View File

@ -954,8 +954,7 @@ def can_return_tuple(func):
return_dict = self.config.use_return_dict if hasattr(self, "config") else True
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
output = func(self, *args, **kwargs)
if "return_dict" in kwargs and return_dict is False:
if return_dict is False:
output = output.to_tuple()
return output
@ -981,11 +980,7 @@ def check_model_inputs(func):
def wrapper(self, *args, **kwargs):
use_cache = kwargs.get("use_cache", self.config.use_cache)
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
kwargs.setdefault("use_cache", use_cache)
kwargs["return_dict"] = kwargs.pop("return_dict", return_dict)
# Use inspect to bind args/kwargs to parameter names
sig = inspect.signature(func)
bound = sig.bind_partial(self, *args, **kwargs)
bound.apply_defaults()
@ -996,10 +991,7 @@ def check_model_inputs(func):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
# TODO (arthur): should we init the cache here if not provided?
# if not isinstance(past_key_values, (type(None), Cache)):
# raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
#
hooks = []
collected_outputs = defaultdict(tuple)