mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
dose this fix it?
This commit is contained in:
parent
0c9f6de0fd
commit
501aead20b
@ -954,8 +954,7 @@ def can_return_tuple(func):
|
||||
return_dict = self.config.use_return_dict if hasattr(self, "config") else True
|
||||
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
|
||||
output = func(self, *args, **kwargs)
|
||||
|
||||
if "return_dict" in kwargs and return_dict is False:
|
||||
if return_dict is False:
|
||||
output = output.to_tuple()
|
||||
return output
|
||||
|
||||
@ -981,11 +980,7 @@ def check_model_inputs(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
use_cache = kwargs.get("use_cache", self.config.use_cache)
|
||||
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
|
||||
|
||||
kwargs.setdefault("use_cache", use_cache)
|
||||
kwargs["return_dict"] = kwargs.pop("return_dict", return_dict)
|
||||
|
||||
# Use inspect to bind args/kwargs to parameter names
|
||||
sig = inspect.signature(func)
|
||||
bound = sig.bind_partial(self, *args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
@ -996,10 +991,7 @@ def check_model_inputs(func):
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
# TODO (arthur): should we init the cache here if not provided?
|
||||
# if not isinstance(past_key_values, (type(None), Cache)):
|
||||
# raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
||||
#
|
||||
|
||||
hooks = []
|
||||
collected_outputs = defaultdict(tuple)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user