dose this fix it?

This commit is contained in:
Arthur 2025-07-01 17:28:48 +02:00
parent 0c9f6de0fd
commit 501aead20b

View File

@ -954,8 +954,7 @@ def can_return_tuple(func):
return_dict = self.config.use_return_dict if hasattr(self, "config") else True return_dict = self.config.use_return_dict if hasattr(self, "config") else True
return_dict = kwargs.pop("return_dict", self.config.use_return_dict) return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
output = func(self, *args, **kwargs) output = func(self, *args, **kwargs)
if return_dict is False:
if "return_dict" in kwargs and return_dict is False:
output = output.to_tuple() output = output.to_tuple()
return output return output
@ -981,11 +980,7 @@ def check_model_inputs(func):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
use_cache = kwargs.get("use_cache", self.config.use_cache) use_cache = kwargs.get("use_cache", self.config.use_cache)
return_dict = kwargs.pop("return_dict", self.config.use_return_dict) return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
kwargs.setdefault("use_cache", use_cache) kwargs.setdefault("use_cache", use_cache)
kwargs["return_dict"] = kwargs.pop("return_dict", return_dict)
# Use inspect to bind args/kwargs to parameter names
sig = inspect.signature(func) sig = inspect.signature(func)
bound = sig.bind_partial(self, *args, **kwargs) bound = sig.bind_partial(self, *args, **kwargs)
bound.apply_defaults() bound.apply_defaults()
@ -996,10 +991,7 @@ def check_model_inputs(func):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
) )
use_cache = False use_cache = False
# TODO (arthur): should we init the cache here if not provided?
# if not isinstance(past_key_values, (type(None), Cache)):
# raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
#
hooks = [] hooks = []
collected_outputs = defaultdict(tuple) collected_outputs = defaultdict(tuple)