mirror of https://github.com/huggingface/transformers.git
Cache: don't show warning in forward passes when past_key_values is None (#33541)

This commit is contained in:
parent f3b3810fe6
commit 80b774eb29
@@ -120,7 +120,7 @@ To enable quantization of the key-value cache, one needs to indicate `cache_impl
 Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`~QuantizedCacheConfig`] class.
 One has to indicate which quantization backend to use in the [`~QuantizedCacheConfig`], the default is `quanto`.
 
 It is recommended to set `axis-key/axis-value` parameters in the cache config to `0` if you're using the `quanto` backend and to `1` if you're using the `HQQ` backend. For other config values, please use the defaults unless you're running out of memory. In that case, you may consider decreasing the residual length.
 
 <Tip warning={true}>
 
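The context above comes from the quantized-cache docs. Roughly, usage looks like this (a minimal sketch, not part of this commit; the model id and argument values are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

# Quantization arguments go through the generation config, as a dict or a
# QuantizedCacheConfig; axis_key/axis_value=0 follows the `quanto` advice above.
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4, "axis_key": 0, "axis_value": 0},
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```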
@@ -308,7 +308,7 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac
 
 ### Encoder-Decoder Cache
 
 The [`~EncoderDecoderCache`] is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of past key/values required for these complex models. Cool thing about Encoder-Decoder Cache is that you can set different cache types for the encoder and for the decoder, depending on your use case. Currently this cache is only supported in [Whisper](./model_doc/whisper) models but we will be adding more models soon.
 
 In terms of usage, there is nothing special to be done and calling `generate()` or `forward()` will handle everything for you.
 
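As the context notes, no cache-specific setup is needed for encoder-decoder models. A minimal sketch (not part of this commit; random features are used only to exercise the call path):

```python
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Dummy log-mel features of shape (batch, num_mel_bins, frames); generate()
# builds and manages the EncoderDecoderCache (self- and cross-attention) itself.
input_features = torch.randn(1, model.config.num_mel_bins, 3000)
generated_ids = model.generate(input_features, max_new_tokens=10)
print(generated_ids.shape)
```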
@@ -379,7 +379,7 @@ Sometimes you would want to first fill-in cache object with key/values for certa
 >>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
 >>> tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 >>> # Init StaticCache with big enough max-length (1024 tokens for the below example)
 >>> # You can also init a DynamicCache, if that suits you better
 >>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
 
@@ -394,10 +394,35 @@ Sometimes you would want to first fill-in cache object with key/values for certa
 >>> for prompt in prompts:
 ...     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
 ...     past_key_values = copy.deepcopy(prompt_cache)
 ...     outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
 ...     response = tokenizer.batch_decode(outputs)[0]
 ...     responses.append(response)
 
 >>> print(responses)
 ['<s> You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTitle: The Ultimate Guide to Travelling: Tips, Tricks, and', '<s> You are a helpful assistant. What is the capital of France?\n\nYes, the capital of France is Paris.</s>']
 ```
+
+## Legacy cache format
+
+Prior to the introduction of the `Cache` object, the cache of LLMs used to be a tuple of tuples of tensors. The legacy
+format has a dynamic size, growing as we generate text -- very similar to `DynamicCache`. If your project depend on
+this legacy format, you can seamlessly convert it to a `DynamicCache` and back.
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+
+>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
+>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
+
+>>> # `return_dict_in_generate=True` is required to return the cache. `return_legacy_cache` forces the returned cache
+>>> # to be of the legacy type
+>>> generation_outputs = model.generate(**inputs, return_dict_in_generate=True, return_legacy_cache=True, max_new_tokens=5)
+
+>>> # We can convert a legacy cache to a DynamicCache -- and the other way around. This is helpful if you have custom
+>>> # logic to manipulate a cache in a specific format.
+>>> cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
+>>> legacy_format_cache = cache.to_legacy_cache()
+```
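To make the added passage concrete: the legacy format is a per-layer tuple of (key, value) tensor pairs. A small round-trip sketch, assuming the usual [batch, num_heads, seq_len, head_dim] layout:

```python
import torch
from transformers import DynamicCache

# One (key, value) pair per decoder layer, each [batch, num_heads, seq_len, head_dim]
legacy_cache = tuple(
    (torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy_cache)  # tuple of tuples -> Cache
assert cache.get_seq_length() == 5
roundtrip = cache.to_legacy_cache()                   # Cache -> tuple of tuples
assert torch.equal(roundtrip[0][0], legacy_cache[0][0])
```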
@@ -687,14 +687,18 @@ class BloomModel(BloomPreTrainedModel):
             inputs_embeds = self.word_embeddings(input_ids)
 
         # kept for BC (non `Cache` `past_key_values` inputs)
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "Using `past_key_values` as a tuple is deprecated and will be removed in v4.45. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         batch_size, seq_length, _ = inputs_embeds.shape
         past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
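This Bloom hunk is the template repeated across all models below: `past_key_values=None` now silently gets a fresh `DynamicCache`, and the deprecation warning fires only for a genuine tuple-of-tuples. A standalone sketch of the new control flow (hypothetical helper; `print` stands in for `logger.warning_once`):

```python
from transformers import Cache, DynamicCache

def normalize_cache(past_key_values, use_cache=True):
    # Mirrors the new BC block: only a real legacy tuple triggers the warning.
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        if past_key_values is None:
            past_key_values = DynamicCache()  # plain forward pass: no warning
        else:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            print("deprecation warning fires here (legacy tuple of tuples)")
    return past_key_values, return_legacy_cache

cache, legacy = normalize_cache(None)  # silent: fresh DynamicCache
assert isinstance(cache, DynamicCache) and legacy
```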
@@ -765,9 +769,9 @@ class BloomModel(BloomPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
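The return-side counterpart: converting back to the legacy format is now keyed on `return_legacy_cache` rather than `use_cache`, so callers get back the format they passed in. Sketched in isolation (all names hypothetical):

```python
from transformers import DynamicCache

use_cache, return_legacy_cache = True, True  # as set by the BC block above
next_decoder_cache = DynamicCache()          # accumulated by the decoder layers

next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
    next_cache = next_cache.to_legacy_cache()  # legacy callers get a tuple back
assert isinstance(next_cache, tuple)
```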
@@ -526,14 +526,18 @@ class CodeGenModel(CodeGenPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
 
-        use_legacy_cache = False
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            if not self.training:
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                 )
 
         seq_length = inputs_embeds.shape[1]
@@ -608,9 +612,9 @@ class CodeGenModel(CodeGenPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
@@ -910,16 +910,19 @@ class CohereModel(CoherePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1059,16 +1059,19 @@ class DbrxModel(DbrxPreTrainedModel):
 
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1031,17 +1031,21 @@ class FalconModel(FalconPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        # Compute alibi tensor: check build_alibi_tensor documentation
-        use_legacy_cache = False
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            if not self.training:
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                 )
 
+        # Compute alibi tensor: check build_alibi_tensor documentation
         alibi = None
         past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
         batch_size, seq_length, _ = inputs_embeds.shape
@@ -1126,9 +1130,9 @@ class FalconModel(FalconPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
@@ -476,12 +476,19 @@ class GemmaModel(LlamaModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False  # noqa: F841
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True  # noqa: F841
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -828,12 +828,19 @@ class GemmaModel(GemmaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        return_legacy_cache = False  # noqa: F841
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
-            return_legacy_cache = True  # noqa: F841
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -856,15 +863,6 @@ class GemmaModel(GemmaPreTrainedModel):
         # See https://github.com/huggingface/transformers/pull/29402
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
-            return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -417,14 +417,19 @@ class GitEncoder(nn.Module):
                 )
                 use_cache = False
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -463,9 +468,9 @@ class GitEncoder(nn.Module):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
@@ -741,14 +741,18 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            if not self.training:
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                 )
 
         seq_length = inputs_embeds.shape[1]
@@ -822,9 +826,9 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
@@ -943,14 +943,18 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_in(input_ids)
 
-        use_legacy_cache = False
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            if not self.training:
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                 )
 
         seq_length = inputs_embeds.shape[1]
@@ -1021,9 +1025,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
@@ -663,14 +663,18 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_in(input_ids)
 
-        use_legacy_cache = False
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            if not self.training:
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                 )
 
         seq_length = inputs_embeds.shape[1]
@@ -725,9 +729,9 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
@@ -813,14 +813,18 @@ class GPTJModel(GPTJPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
 
-        use_legacy_cache = False
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            if not self.training:
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                 )
 
         seq_length = inputs_embeds.shape[1]
@@ -917,9 +921,9 @@ class GPTJModel(GPTJPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
@@ -834,14 +834,19 @@ class GraniteModel(GranitePreTrainedModel):
 
         inputs_embeds = inputs_embeds * self.embedding_multiplier
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1239,15 +1239,19 @@ class IdeficsModel(IdeficsPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
-            if not self.training:
-                logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
-                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-                )
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         batch_size, seq_length, _ = inputs_embeds.shape
         past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1345,11 +1345,19 @@ class Idefics2Model(Idefics2PreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         past_seen_tokens = 0
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if use_cache:
-            if not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
-                return_legacy_cache = True
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_seen_tokens = past_key_values.get_seq_length()
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
+        past_seen_tokens = past_key_values.get_seq_length()
 
         if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
@@ -1033,12 +1033,19 @@ class JetMoeModel(JetMoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -944,16 +944,19 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -762,14 +762,19 @@ class MistralModel(MistralPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1018,14 +1018,19 @@ class MixtralModel(MixtralPreTrainedModel):
                 )
                 use_cache = False
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1095,9 +1100,9 @@ class MixtralModel(MixtralPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
@@ -866,16 +866,19 @@ class OlmoModel(OlmoPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1007,14 +1007,19 @@ class OlmoeModel(OlmoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -685,14 +685,19 @@ class PersimmonModel(PersimmonPreTrainedModel):
                 )
                 use_cache = False
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -761,9 +766,9 @@ class PersimmonModel(PersimmonPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -976,14 +976,19 @@ class PhiModel(PhiPreTrainedModel):
                 )
                 use_cache = False
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1053,9 +1058,10 @@ class PhiModel(PhiPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
@@ -1003,14 +1003,19 @@ class Phi3Model(Phi3PreTrainedModel):
                 )
                 use_cache = False
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1074,9 +1079,10 @@ class Phi3Model(Phi3PreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
@@ -915,14 +915,19 @@ class Qwen2Model(Qwen2PreTrainedModel):
                 )
                 use_cache = False
 
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -991,9 +996,9 @@ class Qwen2Model(Qwen2PreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -1079,14 +1079,19 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
             )
             use_cache = False

-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1161,9 +1166,9 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(
@@ -960,14 +960,19 @@ class StableLmModel(StableLmPreTrainedModel):
             )
             use_cache = False

-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -1036,9 +1041,9 @@ class StableLmModel(StableLmPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -889,14 +889,19 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
             )
             use_cache = False

-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
@@ -966,9 +971,9 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
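Taken together, the two hunks mean a plain forward pass with caching enabled no longer logs a spurious deprecation warning, while legacy-tuple callers keep their warning and their round-trip format. A quick way to verify, assuming one of the affected checkpoints is available (`Qwen/Qwen2-0.5B` is used here purely as an example):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
inputs = tokenizer("Hello", return_tensors="pt")

with torch.no_grad():
    # No `past_key_values`: an empty DynamicCache is created internally and
    # no warning is logged. The cache still comes back in the legacy tuple
    # format, for backward compatibility, because no `Cache` was passed in.
    out = model(**inputs, use_cache=True)
    print(type(out.past_key_values))  # tuple

    # Passing a `Cache` instance skips the legacy branch entirely.
    out = model(**inputs, past_key_values=DynamicCache(), use_cache=True)
    print(type(out.past_key_values))  # DynamicCache

    # A legacy tuple of tuples still works but logs the deprecation warning,
    # and the updated cache is returned in the same legacy format.
    legacy = out.past_key_values.to_legacy_cache()
    next_token = out.logits[:, -1:].argmax(-1)
    out = model(input_ids=next_token, past_key_values=legacy, use_cache=True)
    print(type(out.past_key_values))  # tuple again
```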