Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)

Generate: Deprecate returning legacy cache by default; Handle use_cache=False (#32863)

Parent: 09e6579d2d, commit: a26de15139
@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin):
|
||||
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
||||
penalty_alpha (`float`, *optional*):
|
||||
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
|
||||
dola_layers (`str` or `List[int]`, *optional*):
|
||||
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
|
||||
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
|
||||
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
|
||||
layers up to the last 20 layers.
|
||||
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
|
||||
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
|
||||
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
|
||||
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
|
||||
|
||||
> Parameters that control the cache
|
||||
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should use the past key/value attentions (if applicable to the model) to
|
||||
speed up decoding.
|
||||
cache_implementation (`str`, *optional*, default to `None`):
|
||||
Cache class that should be used when generating.
|
||||
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
||||
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
||||
it will be converted to its respective `CacheConfig` internally.
|
||||
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
|
||||
return_legacy_cache (`bool`, *optional*, default to `True`):
|
||||
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
|
||||
|
||||
> Parameters for manipulation of the model output logits
|
||||
|
||||
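The cache parameters documented above can be set on a `GenerationConfig` and passed to `generate`. A minimal sketch, assuming any decoder-only checkpoint (the model name and values below are illustrative, not part of this diff):

    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")          # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    generation_config = GenerationConfig(
        max_new_tokens=20,
        use_cache=True,              # reuse past key/values to speed up decoding
        cache_implementation=None,   # None -> default DynamicCache (when supported)
        return_legacy_cache=True,    # keep the tuple-of-tuples return format for now
    )

    inputs = tokenizer("The KV cache", return_tensors="pt")
    outputs = model.generate(**inputs, generation_config=generation_config, return_dict_in_generate=True)
    print(type(outputs.past_key_values))   # legacy tuple when return_legacy_cache=True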
@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
max_matching_ngram_size (`int`, *optional*, default to `None`):
|
||||
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
|
||||
|
||||
> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
|
||||
|
||||
dola_layers (`str` or `List[int]`, *optional*):
|
||||
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
|
||||
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
|
||||
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
|
||||
layers up to the last 20 layers.
|
||||
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
|
||||
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
|
||||
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
|
||||
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
|
||||
|
||||
> Parameters specific to the caching mechanism:
|
||||
|
||||
cache_implementation (`str`, *optional*, default to `None`):
|
||||
Cache class that should be used when generating.
|
||||
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
||||
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
||||
it will be converted to its respective `CacheConfig` internally.
|
||||
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
|
||||
return_legacy_cache (`bool`, *optional*, default to `True`):
|
||||
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
|
||||
|
||||
> Wild card
|
||||
|
||||
generation_kwargs:
|
||||
@ -352,7 +349,19 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.num_beams = kwargs.pop("num_beams", 1)
|
||||
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
||||
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
|
||||
self.dola_layers = kwargs.pop("dola_layers", None)
|
||||
|
||||
# Parameters that control the cache
|
||||
self.use_cache = kwargs.pop("use_cache", True)
|
||||
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
||||
self.cache_config = kwargs.pop("cache_config", None)
|
||||
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
|
||||
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
|
||||
if self.cache_config is None:
|
||||
self.cache_config = cache_config_class()
|
||||
elif isinstance(self.cache_config, dict):
|
||||
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
||||
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
|
||||
|
||||
# Parameters for manipulation of the model output logits
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
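As a sketch of the dict-to-`CacheConfig` conversion handled in the block above (the `backend`/`nbits` fields are assumptions about `QuantizedCacheConfig`, check the class for the exact arguments):

    from transformers import GenerationConfig

    # `cache_config` may be passed as a plain dict; __init__ looks up the matching config
    # class in NEEDS_CACHE_CONFIG for the chosen `cache_implementation` and calls .from_dict().
    gen_cfg = GenerationConfig(
        cache_implementation="quantized",
        cache_config={"backend": "quanto", "nbits": 4},  # assumed QuantizedCacheConfig fields
    )
    print(type(gen_cfg.cache_config).__name__)   # QuantizedCacheConfig, converted internally
    print(gen_cfg.return_legacy_cache)           # None unless set explicitly (new default in this PR)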
@ -411,20 +420,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
|
||||
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
|
||||
|
||||
# DoLa generation
|
||||
self.dola_layers = kwargs.pop("dola_layers", None)
|
||||
|
||||
# Cache implementation
|
||||
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
||||
self.cache_config = kwargs.pop("cache_config", None)
|
||||
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
|
||||
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
|
||||
if self.cache_config is None:
|
||||
self.cache_config = cache_config_class()
|
||||
elif isinstance(self.cache_config, dict):
|
||||
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
||||
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)
|
||||
|
||||
# Prompt lookup decoding
|
||||
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
||||
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
||||
@ -544,8 +539,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
||||
if self.pad_token_id is not None and self.pad_token_id < 0:
|
||||
warnings.warn(
|
||||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. "
|
||||
"Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values."
|
||||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
|
||||
"generating, if there is padding. Please set `pad_token_id` explicitly as "
|
||||
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
|
||||
)
|
||||
|
||||
# Validation of attribute relations:
|
||||
@ -675,6 +671,14 @@ class GenerationConfig(PushToHubMixin):
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
# DoLa generation
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
warnings.warn(
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
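For reference, a usage sketch that satisfies the DoLa check added above (checkpoint and prompt are placeholders):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("DoLa example:", return_tensors="pt")

    # repetition_penalty >= 1.2 avoids the UserWarning emitted by the validation above.
    outputs = model.generate(
        **inputs,
        dola_layers="low",        # or "high", or a list of candidate premature layer indices
        repetition_penalty=1.2,
        max_new_tokens=20,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))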
@ -690,7 +694,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 5. check `cache_config`
|
||||
# 5. check cache-related arguments
|
||||
if self.cache_config is not None:
|
||||
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
|
||||
if cache_class is None:
|
||||
@ -702,6 +706,20 @@ class GenerationConfig(PushToHubMixin):
|
||||
if not isinstance(self.cache_config, cache_class):
|
||||
self.cache_config = cache_class.from_dict(self.cache_config)
|
||||
self.cache_config.validate()
|
||||
if self.use_cache is False:
|
||||
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
|
||||
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
|
||||
# (otherwise a user might need to overwrite several parameters).
|
||||
no_cache_warning = (
|
||||
"You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
|
||||
"have no effect."
|
||||
)
|
||||
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
|
||||
if getattr(self, arg_name) is not None:
|
||||
logger.warning_once(
|
||||
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
|
||||
|
||||
)
|
||||
|
||||
# 6. check watermarking arguments
|
||||
if self.watermarking_config is not None:
|
||||
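A small sketch of the `use_cache=False` handling above: cache-related arguments are no longer an error, they are simply reported as having no effect (values are illustrative):

    from transformers import GenerationConfig

    gen_cfg = GenerationConfig(
        use_cache=False,                  # caching disabled entirely
        cache_implementation="static",    # will have no effect...
        return_legacy_cache=True,         # ...and neither will this
    )
    # validate() (also called from __init__) logs, once per offending argument, a warning of the
    # form "You have set `use_cache` to `False`, but cache_implementation is set to static.
    # cache_implementation will have no effect." instead of raising an error.
    gen_cfg.validate()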
@ -727,17 +745,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
"`generate()` (or a pipeline) directly."
|
||||
)
|
||||
|
||||
# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
dola_decoding_wrong_parameter_msg = (
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
|
||||
"which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
|
||||
)
|
||||
warnings.warn(
|
||||
dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
|
@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput):
|
||||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||
if all batches finished early due to the `eos_token_id`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
|
||||
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||
"""
|
||||
|
||||
sequences: torch.LongTensor = None
|
||||
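To illustrate the reworded `past_key_values` field above, a minimal sketch (illustrative checkpoint; whether you get a `Cache` instance or the legacy tuples depends on the model and on `return_legacy_cache`):

    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("Hello", return_tensors="pt")

    out = model.generate(
        **inputs,
        max_new_tokens=5,
        use_cache=True,
        return_dict_in_generate=True,   # required to get a GenerateDecoderOnlyOutput
    )
    pkv = out.past_key_values
    # Legacy format: tuple (per layer) of (key, value) tensors; new format: a Cache instance.
    print(isinstance(pkv, (DynamicCache, tuple)))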
@ -176,36 +172,32 @@ class GenerateEncoderDecoderOutput(ModelOutput):
|
||||
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||
if all batches finished early due to the `eos_token_id`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, sequence_length, hidden_size)`.
|
||||
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||
"""
|
||||
|
||||
sequences: torch.LongTensor = None
|
||||
@ -228,33 +220,29 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput):
|
||||
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||
if all batches finished early due to the `eos_token_id`.
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
|
||||
Final beam scores of the generated `sequences`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
|
||||
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||
`(batch_size*num_return_sequences, sequence_length)`.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
|
||||
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||
"""
|
||||
|
||||
sequences: torch.LongTensor = None
|
||||
@ -276,43 +264,39 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
|
||||
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||
if all batches finished early due to the `eos_token_id`.
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
|
||||
Final beam scores of the generated `sequences`.
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
|
||||
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||
`(batch_size*num_return_sequences, sequence_length)`.
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
|
||||
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
|
||||
sequence_length)`.
|
||||
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||
encoder_sequence_length, embed_size_per_head)`.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
|
||||
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||
"""
|
||||
|
||||
sequences: torch.LongTensor = None
|
||||
@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
|
||||
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||
|
||||
|
||||
# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
|
||||
# Equivalent classes (kept for retrocompatibility purposes)
|
||||
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
|
||||
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
|
||||
@ -1501,6 +1486,121 @@ class GenerationMixin:
|
||||
"""
|
||||
return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()
|
||||
|
||||
def _prepare_cache_for_generation(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
model_kwargs: Dict,
|
||||
assistant_model: "PreTrainedModel",
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
) -> bool:
|
||||
"""
|
||||
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
|
||||
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
||||
"""
|
||||
|
||||
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||
requires_cross_attention_cache = (
|
||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||
)
|
||||
|
||||
# Quick escape route 1: if the user specifies a cache, we only need to:
|
||||
# a) check for conflicting `generate` arguments
|
||||
# b) convert to the new cache format (if the user passes a legacy cache and model supports it)
|
||||
user_defined_cache = model_kwargs.get(cache_name)
|
||||
if user_defined_cache is not None:
|
||||
if generation_config.cache_implementation is not None:
|
||||
raise ValueError(
|
||||
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
|
||||
"Cache object) is unsupported. Please use only one of the two."
|
||||
)
|
||||
if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
|
||||
model_kwargs[cache_name] = (
|
||||
DynamicCache.from_legacy_cache(user_defined_cache)
|
||||
if not requires_cross_attention_cache
|
||||
else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
|
||||
)
|
||||
return
|
||||
|
||||
# Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
|
||||
# `generation_config.validate()`)
|
||||
if generation_config.use_cache is False:
|
||||
return
|
||||
|
||||
# Quick escape route 3: model that only supports legacy caches = nothing to prepare
|
||||
if not self._supports_default_dynamic_cache():
|
||||
if generation_config.cache_implementation is not None:
|
||||
warnings.warn(
|
||||
"This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
|
||||
f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
|
||||
"ignored.",
|
||||
UserWarning,
|
||||
)
|
||||
return
|
||||
|
||||
# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
|
||||
|
||||
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
|
||||
# which is only supported in dynamic caches atm
|
||||
if assistant_model is not None and generation_config.cache_implementation is not None:
|
||||
logger.warning_once(
|
||||
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
||||
f"'{generation_config.cache_implementation}'."
|
||||
)
|
||||
generation_config.cache_implementation = None
|
||||
|
||||
if generation_config.cache_implementation is not None:
|
||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
|
||||
raise ValueError(
|
||||
"This model does not support `cache_implementation='static'`. Please check the following "
|
||||
"issue: https://github.com/huggingface/transformers/issues/28981"
|
||||
)
|
||||
model_kwargs[cache_name] = self._get_cache(
|
||||
cache_implementation=generation_config.cache_implementation,
|
||||
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
|
||||
max_cache_len=generation_config.max_length,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
elif generation_config.cache_implementation == "quantized":
|
||||
if not self._supports_quantized_cache:
|
||||
raise ValueError(
|
||||
"This model does not support the quantized cache. If you want your model to support quantized "
|
||||
"cache, please open an issue and tag @zucchini-nlp."
|
||||
)
|
||||
|
||||
cache_config = (
|
||||
generation_config.cache_config
|
||||
if generation_config.cache_config is not None
|
||||
else QuantizedCacheConfig()
|
||||
)
|
||||
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
|
||||
|
||||
if cache_config.backend == "quanto" and not is_quanto_available():
|
||||
raise ImportError(
|
||||
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
|
||||
"Please install it via with `pip install quanto`"
|
||||
)
|
||||
elif cache_config.backend == "HQQ" and not is_hqq_available():
|
||||
raise ImportError(
|
||||
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
|
||||
"Please install it via with `pip install hqq`"
|
||||
)
|
||||
|
||||
model_kwargs[cache_name] = cache_class(cache_config)
|
||||
elif generation_config.cache_implementation == "offloaded":
|
||||
model_kwargs[cache_name] = OffloadedCache()
|
||||
|
||||
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
||||
# keeps copying the cache thus using much more memory
|
||||
else:
|
||||
model_kwargs[cache_name] = (
|
||||
DynamicCache()
|
||||
if not requires_cross_attention_cache
|
||||
else EncoderDecoderCache(DynamicCache(), DynamicCache())
|
||||
)
|
||||
|
||||
def _prepare_special_tokens(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
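The conversion in "quick escape route 1" above relies on the legacy converters of the `Cache` classes. A standalone sketch of that round trip (tensor shapes are arbitrary):

    import torch
    from transformers import DynamicCache

    # A legacy cache: tuple (one element per layer) of (key, value) tensors of shape
    # (batch_size, num_heads, seq_len, head_dim).
    legacy_cache = tuple((torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8)) for _ in range(2))

    # What `_prepare_cache_for_generation` does when the user passes a legacy tuple and the
    # model supports the new cache classes:
    cache = DynamicCache.from_legacy_cache(legacy_cache)
    print(cache.get_seq_length())   # 3

    # The inverse, used when converting back to the (deprecated) default return format:
    assert cache.to_legacy_cache()[0][0].shape == legacy_cache[0][0].shape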
@ -1776,104 +1876,18 @@ class GenerationMixin:
|
||||
inputs_tensor=inputs_tensor,
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
|
||||
use_dynamic_cache_by_default = False
|
||||
if "mamba" in self.__class__.__name__.lower():
|
||||
cache_name = "cache_params"
|
||||
else:
|
||||
cache_name = "past_key_values"
|
||||
|
||||
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
|
||||
# which is only supported in dynamic caches atm
|
||||
if (
|
||||
assistant_model is not None
|
||||
and generation_config.cache_implementation is not None
|
||||
and self._supports_default_dynamic_cache()
|
||||
):
|
||||
logger.warning_once(
|
||||
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
||||
f"'{generation_config.cache_implementation}'."
|
||||
)
|
||||
generation_config.cache_implementation = None
|
||||
|
||||
if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
|
||||
raise ValueError(
|
||||
"Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
|
||||
"may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` "
|
||||
"input argument."
|
||||
)
|
||||
if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
|
||||
raise ValueError(
|
||||
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
|
||||
"Cache object) is unsupported. Please use only one of the two."
|
||||
)
|
||||
elif generation_config.cache_implementation is not None:
|
||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
|
||||
raise ValueError(
|
||||
"This model does not support `cache_implementation='static'`. Please check the following "
|
||||
"issue: https://github.com/huggingface/transformers/issues/28981"
|
||||
)
|
||||
model_kwargs[cache_name] = self._get_cache(
|
||||
cache_implementation=generation_config.cache_implementation,
|
||||
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
|
||||
max_cache_len=generation_config.max_length,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
elif generation_config.cache_implementation == "quantized":
|
||||
if not self._supports_quantized_cache:
|
||||
raise ValueError(
|
||||
"This model does not support the quantized cache. If you want your model to support quantized "
|
||||
"cache, please open an issue."
|
||||
)
|
||||
|
||||
cache_config = (
|
||||
generation_config.cache_config
|
||||
if generation_config.cache_config is not None
|
||||
else QuantizedCacheConfig()
|
||||
)
|
||||
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
|
||||
|
||||
if cache_config.backend == "quanto" and not is_quanto_available():
|
||||
raise ImportError(
|
||||
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
|
||||
"Please install it via with `pip install quanto`"
|
||||
)
|
||||
elif cache_config.backend == "HQQ" and not is_hqq_available():
|
||||
raise ImportError(
|
||||
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
|
||||
"Please install it via with `pip install hqq`"
|
||||
)
|
||||
|
||||
model_kwargs[cache_name] = cache_class(cache_config)
|
||||
elif generation_config.cache_implementation == "offloaded":
|
||||
model_kwargs[cache_name] = OffloadedCache()
|
||||
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
||||
# keeps copying the cache thus using much more memory
|
||||
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
|
||||
past = model_kwargs.get(cache_name, None)
|
||||
requires_cross_attention_cache = (
|
||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||
)
|
||||
if past is None:
|
||||
model_kwargs[cache_name] = (
|
||||
DynamicCache()
|
||||
if not requires_cross_attention_cache
|
||||
else EncoderDecoderCache(DynamicCache(), DynamicCache())
|
||||
)
|
||||
use_dynamic_cache_by_default = True
|
||||
elif isinstance(past, tuple):
|
||||
model_kwargs[cache_name] = (
|
||||
DynamicCache.from_legacy_cache(past)
|
||||
if not requires_cross_attention_cache
|
||||
else EncoderDecoderCache.from_legacy_cache(past)
|
||||
)
|
||||
use_dynamic_cache_by_default = True
|
||||
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 7. determine generation mode
|
||||
# 7. Prepare the cache.
|
||||
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||
# TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
|
||||
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||
user_defined_cache = model_kwargs.get(cache_name)
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device)
|
||||
|
||||
# 8. determine generation mode
|
||||
generation_mode = generation_config.get_generation_mode(assistant_model)
|
||||
|
||||
if streamer is not None and (generation_config.num_beams > 1):
|
||||
@ -1892,7 +1906,7 @@ class GenerationMixin:
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 8. prepare distribution pre_processing samplers
|
||||
# 9. prepare logits processors and stopping criteria
|
||||
prepared_logits_processor = self._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_length,
|
||||
@ -1904,8 +1918,6 @@ class GenerationMixin:
|
||||
negative_prompt_ids=negative_prompt_ids,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 9. prepare stopping criteria
|
||||
prepared_stopping_criteria = self._get_stopping_criteria(
|
||||
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
|
||||
)
|
||||
@ -2138,11 +2150,34 @@ class GenerationMixin:
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Convert to legacy cache if needed
|
||||
if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
|
||||
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
|
||||
if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
|
||||
result.past_key_values = result.past_key_values.to_legacy_cache()
|
||||
# Convert to legacy cache format if requested
|
||||
if (
|
||||
generation_config.return_legacy_cache is not False # Should check for `True` after v4.47
|
||||
and not is_torchdynamo_compiling()
|
||||
and hasattr(result, "past_key_values")
|
||||
and hasattr(result.past_key_values, "to_legacy_cache")
|
||||
and result.past_key_values.to_legacy_cache is not None
|
||||
):
|
||||
# handle BC (convert by default if the user hasn't passed a cache AND the cache is of the default type)
|
||||
should_convert_cache = generation_config.return_legacy_cache
|
||||
is_user_defined_cache = user_defined_cache is not None
|
||||
is_default_cache_type = (
|
||||
type(result.past_key_values) == DynamicCache # noqa E721
|
||||
or (
|
||||
isinstance(result.past_key_values, EncoderDecoderCache)
|
||||
and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721
|
||||
and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721
|
||||
)
|
||||
)
|
||||
if not is_user_defined_cache and is_default_cache_type:
|
||||
logger.warning_once(
|
||||
"From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` "
|
||||
"instance instead by default (as opposed to the legacy tuple of tuples format). If you want to "
|
||||
"keep returning the legacy format, please set `return_legacy_cache=True`."
|
||||
)
|
||||
should_convert_cache = True
|
||||
if should_convert_cache:
|
||||
result.past_key_values = result.past_key_values.to_legacy_cache()
|
||||
return result
|
||||
|
||||
def _has_unfinished_sequences(
|
||||
|
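For callers, the deprecation above means the returned cache format is best pinned explicitly. A hedged sketch (illustrative checkpoint; the `Cache` return applies to models that support the new cache classes):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("Hi", return_tensors="pt")

    # Until v4.47 `generate` still converts to the legacy tuple format by default (and warns when
    # the user did not pass their own cache). Pin the behavior explicitly:
    out_new = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True,
                             return_legacy_cache=False)   # keep the Cache instance, if supported
    out_old = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True,
                             return_legacy_cache=True)    # legacy tuple of tuples
    print(type(out_new.past_key_values), type(out_old.past_key_values))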
@ -768,7 +768,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):
|
||||
# only last token for input_ids if the state is passed along.
|
||||
if state is not None:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
@ -780,6 +780,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs["state"] = state
|
||||
model_inputs["use_cache"] = use_cache
|
||||
return model_inputs
|
||||
|
||||
@add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
|
||||
|
@ -194,6 +194,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@ -207,6 +208,7 @@ class GenerationTesterMixin:
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
@ -224,6 +226,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
@ -239,6 +242,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
@ -256,6 +260,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@ -268,6 +273,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@ -286,6 +292,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
@ -299,6 +306,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@ -317,6 +325,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@ -329,6 +338,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@ -348,6 +358,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
@ -361,6 +372,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
constraints=constraints,
|
||||
use_cache=use_cache,
|
||||
**beam_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
@ -378,6 +390,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
use_cache=True,
|
||||
):
|
||||
contrastive_search_kwargs = {
|
||||
"penalty_alpha": 0.6,
|
||||
@ -396,6 +409,7 @@ class GenerationTesterMixin:
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
use_cache=use_cache,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**contrastive_search_kwargs,
|
||||
@ -419,7 +433,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
@ -430,6 +443,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -454,7 +468,6 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
@ -466,6 +479,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -495,7 +509,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
@ -507,6 +520,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -545,9 +559,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
output_generate = self._beam_search_generate(
|
||||
@ -560,6 +571,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
@ -589,7 +601,6 @@ class GenerationTesterMixin:
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._beam_search_generate(
|
||||
@ -602,6 +613,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -676,9 +688,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
@ -692,6 +701,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -764,7 +774,6 @@ class GenerationTesterMixin:
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
@ -778,6 +787,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
@ -857,9 +867,6 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# Sample constraints
|
||||
@ -882,6 +889,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -913,13 +921,12 @@ class GenerationTesterMixin:
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test old generation output for backwards compatibility
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._contrastive_generate(
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask
|
||||
model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
@ -940,7 +947,6 @@ class GenerationTesterMixin:
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
@ -953,6 +959,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -978,7 +985,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")

config.use_cache = True
config.is_decoder = True

# test output equality of low versus high memory
@ -991,6 +997,7 @@ class GenerationTesterMixin:
low_memory=True,
max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
use_cache=True,
)

high_output = model.generate(
@ -1000,6 +1007,7 @@ class GenerationTesterMixin:
low_memory=False,
max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@ -1031,10 +1039,17 @@ class GenerationTesterMixin:
# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()

low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
low_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True
)

high_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
input_ids,
max_new_tokens=8,
num_beams=5,
early_stopping=True,
low_memory=False,
use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
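The low-memory tests above now compare `low_memory=True` against the default path with the cache explicitly enabled. A short sketch of the same flags, reusing the illustrative `model` and `inputs` from the previous snippet:

    # low_memory=True runs beam search (and contrastive search) in sequential chunks
    # to reduce peak memory; with the cache on it should match the regular output
    low = model.generate(**inputs, num_beams=5, early_stopping=True, max_new_tokens=8, low_memory=True, use_cache=True)
    high = model.generate(**inputs, num_beams=5, early_stopping=True, max_new_tokens=8, low_memory=False, use_cache=True)
    assert low.tolist() == high.tolist()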
@ -1079,7 +1094,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@ -1098,6 +1112,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

@ -1150,7 +1165,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@ -1169,6 +1183,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}

output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
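The hunks above adjust the assisted-generation tests, which likewise force `use_cache=True`. A hedged sketch of assisted generation from the user side, assuming an illustrative pair of checkpoints that share a tokenizer:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2-large")   # illustrative main model
    model = AutoModelForCausalLM.from_pretrained("gpt2-large")
    assistant = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative smaller draft model

    inputs = tokenizer("Assisted generation drafts tokens with", return_tensors="pt")
    # the assistant proposes candidate tokens that the main model verifies in a single forward pass
    output = model.generate(**inputs, assistant_model=assistant, max_new_tokens=16, use_cache=True)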
@ -1196,12 +1211,6 @@ class GenerationTesterMixin:
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, input_ids, attention_mask = self._get_input_ids_and_config()

# Some models don't support the cache and returning past_key_values
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True

# Encoder-decoder models are not supported
if config.is_encoder_decoder:
self.skipTest("DoLa is not supported for encoder-decoder models")
@ -1224,11 +1233,12 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": hasattr(config, "use_cache"), # Some models don't support the cache
}
generation_kwargs.update({"dola_layers": "low"})
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))

def test_assisted_decoding_sample(self):
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
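The DoLa test now derives `use_cache` from whether the config exposes the attribute at all. For reference, the user-facing knob exercised here is `dola_layers`; a brief sketch, reusing the illustrative decoder-only `model` and `inputs` from the previous snippet:

    # DoLa contrasts logits from a "premature" layer with the final layer;
    # "low"/"high" pick a band of candidate layers, or pass explicit layer indices instead
    output = model.generate(
        **inputs,
        dola_layers="low",
        repetition_penalty=1.2,  # commonly recommended alongside DoLa to curb repetition
        max_new_tokens=16,
    )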
@ -1261,7 +1271,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@ -1284,6 +1293,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@ -1566,7 +1576,6 @@ class GenerationTesterMixin:
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
config.use_cache = True
if "token_type_ids" in inputs:
del inputs["token_type_ids"]

@ -1574,6 +1583,7 @@ class GenerationTesterMixin:
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.generation_config.use_cache = True

# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
@ -1631,7 +1641,6 @@ class GenerationTesterMixin:
self.skipTest(reason="This model does not support the new cache format")

config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True

model = model_class(config).to(torch_device).eval()
generation_kwargs = {
@ -1640,6 +1649,7 @@ class GenerationTesterMixin:
"num_beams": num_beams,
"num_return_sequences": num_beams,
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}

# Sets seed before calling `generate` for the case with do_sample=True
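This test targets the new `Cache` classes, and the commit's headline change is that `generate` will stop converting them back to the legacy tuple format by default. A hedged sketch of requesting the new format explicitly through the `return_legacy_cache` generation parameter, again reusing the illustrative `model` and `inputs`:

    # with return_dict_in_generate=True, generate() returns past_key_values;
    # return_legacy_cache=False keeps the new Cache object (e.g. DynamicCache)
    # for models already migrated to the new cache format
    out = model.generate(
        **inputs,
        max_new_tokens=8,
        use_cache=True,
        return_dict_in_generate=True,
        return_legacy_cache=False,
    )
    print(type(out.past_key_values))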
@ -1701,7 +1711,6 @@ class GenerationTesterMixin:
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

config.use_cache = True
config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_new_tokens = 20
@ -1712,6 +1721,7 @@ class GenerationTesterMixin:
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}

max_cache_len = seq_length + max_new_tokens
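In the static-cache test above, the public entry point is `cache_implementation="static"`, which pre-allocates the key/value tensors to a fixed length (useful with `torch.compile`). A brief sketch, same illustrative `model` and `inputs`, assuming the architecture supports the static cache (not all do):

    # a static cache is sized up front for prompt length + max_new_tokens
    output = model.generate(
        **inputs,
        max_new_tokens=20,
        cache_implementation="static",
        use_cache=True,
    )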
@ -1740,7 +1750,6 @@ class GenerationTesterMixin:
self.skipTest(reason="This model does not support the quantized cache format")

config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True

model = model_class(config).to(torch_device).eval()
@ -1750,6 +1759,7 @@ class GenerationTesterMixin:
# careful with group size, should be divisor of model's hidden size
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}

results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
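The quantized-cache test above passes a `cache_config` dict alongside the quantized cache implementation. A hedged sketch of the same knobs (assumes the optional quanto backend is installed, a checkpoint that supports the new cache format, and illustrative values where `q_group_size` divides the model's hidden size):

    output = model.generate(
        **inputs,
        max_new_tokens=16,
        cache_implementation="quantized",
        cache_config={"backend": "quanto", "nbits": 4, "q_group_size": 64, "residual_length": 128},
        use_cache=True,
    )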
@ -1890,22 +1900,24 @@ class GenerationTesterMixin:

# Past Key Value States -- a few notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refactoring to match the
# standard cache format (e.g. gptbigcode)
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
if use_cache and has_standard_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
if has_standard_cache:
if use_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
elif use_cache is False:
self.assertTrue(output.past_key_values is None)

def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)