mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
dose this fix it?
This commit is contained in:
parent
0c9f6de0fd
commit
501aead20b
@ -954,8 +954,7 @@ def can_return_tuple(func):
|
|||||||
return_dict = self.config.use_return_dict if hasattr(self, "config") else True
|
return_dict = self.config.use_return_dict if hasattr(self, "config") else True
|
||||||
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
|
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
|
||||||
output = func(self, *args, **kwargs)
|
output = func(self, *args, **kwargs)
|
||||||
|
if return_dict is False:
|
||||||
if "return_dict" in kwargs and return_dict is False:
|
|
||||||
output = output.to_tuple()
|
output = output.to_tuple()
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -981,11 +980,7 @@ def check_model_inputs(func):
|
|||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
use_cache = kwargs.get("use_cache", self.config.use_cache)
|
use_cache = kwargs.get("use_cache", self.config.use_cache)
|
||||||
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
|
return_dict = kwargs.pop("return_dict", self.config.use_return_dict)
|
||||||
|
|
||||||
kwargs.setdefault("use_cache", use_cache)
|
kwargs.setdefault("use_cache", use_cache)
|
||||||
kwargs["return_dict"] = kwargs.pop("return_dict", return_dict)
|
|
||||||
|
|
||||||
# Use inspect to bind args/kwargs to parameter names
|
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
bound = sig.bind_partial(self, *args, **kwargs)
|
bound = sig.bind_partial(self, *args, **kwargs)
|
||||||
bound.apply_defaults()
|
bound.apply_defaults()
|
||||||
@ -996,10 +991,7 @@ def check_model_inputs(func):
|
|||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
# TODO (arthur): should we init the cache here if not provided?
|
|
||||||
# if not isinstance(past_key_values, (type(None), Cache)):
|
|
||||||
# raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
|
||||||
#
|
|
||||||
hooks = []
|
hooks = []
|
||||||
collected_outputs = defaultdict(tuple)
|
collected_outputs = defaultdict(tuple)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user