Generate: Deprecate returning legacy cache by default; Handle use_cache=False (#32863)

Joao Gante 2024-08-22 20:01:52 +01:00 committed by GitHub
parent 09e6579d2d
commit a26de15139
4 changed files with 311 additions and 256 deletions
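
The user-facing effect, sketched below for an illustrative decoder-only checkpoint ("gpt2" is only an example): when `generate` builds the default `DynamicCache` itself, it still converts the returned cache to the legacy tuple-of-tuples format, but it now emits a one-time warning announcing the v4.47 change of default; passing `return_legacy_cache` explicitly opts in or out of the conversion.

```python
# Minimal sketch of the new behavior; the checkpoint name is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Default: the cache is still converted to the legacy tuple-of-tuples format,
# with a one-time warning announcing the v4.47 change of default.
out = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True)
print(type(out.past_key_values))  # <class 'tuple'>

# Opting out of the conversion returns the new `Cache` object (and no warning).
out = model.generate(
    **inputs, max_new_tokens=5, return_dict_in_generate=True, return_legacy_cache=False
)
print(type(out.past_key_values))  # DynamicCache
```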

src/transformers/generation/configuration_utils.py

@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The value balances the model confidence and the degeneration penalty in contrastive search decoding.
dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower or higher half of the model layers, respectively.
"low" means the first half of the layers (up to the first 20 layers), and "high" means the second half
(up to the last 20 layers).
If a list of integers, it must contain the indices of the layers to use as candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
> Parameters that control the cache
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past key/values attentions (if applicable to the model) to
speed up decoding.
cache_implementation (`str`, *optional*, defaults to `None`):
Cache class that should be used when generating.
cache_config (`CacheConfig` or `dict`, *optional*, defaults to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
Otherwise it can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, defaults to `True`):
Whether to return the cache in the legacy format (tuple of tuples) or the new `Cache` format when a
`DynamicCache` is used by default.
> Parameters for manipulation of the model output logits
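
Taken together, the new cache group on `GenerationConfig` can be exercised as follows (a sketch; the `cache_config` dict mirrors the quantized-cache settings used in the tests further down):

```python
from transformers import GenerationConfig

# A dict-valued `cache_config` is converted in `__init__` to the `CacheConfig`
# subclass matching `cache_implementation` (via NEEDS_CACHE_CONFIG).
generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 2},
    return_legacy_cache=False,
)
print(type(generation_config.cache_config).__name__)  # QuantizedCacheConfig
```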
@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin):
max_matching_ngram_size (`int`, *optional*, defaults to `None`):
The maximum ngram size to be considered for matching in the prompt. Defaults to 2 if not provided.
> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower or higher half of the model layers, respectively.
"low" means the first half of the layers (up to the first 20 layers), and "high" means the second half
(up to the last 20 layers).
If a list of integers, it must contain the indices of the layers to use as candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
> Parameters specific to the caching mechanism:
cache_implementation (`str`, *optional*, defaults to `None`):
Cache class that should be used when generating.
cache_config (`CacheConfig` or `dict`, *optional*, defaults to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
Otherwise it can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, defaults to `True`):
Whether to return the cache in the legacy format (tuple of tuples) or the new `Cache` format when a
`DynamicCache` is used by default.
> Wild card
generation_kwargs:
@ -352,7 +349,19 @@ class GenerationConfig(PushToHubMixin):
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
self.dola_layers = kwargs.pop("dola_layers", None)
# Parameters that control the cache
self.use_cache = kwargs.pop("use_cache", True)
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
@ -411,20 +420,6 @@ class GenerationConfig(PushToHubMixin):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
# DoLa generation
self.dola_layers = kwargs.pop("dola_layers", None)
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)
# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
@ -544,8 +539,9 @@ class GenerationConfig(PushToHubMixin):
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
if self.pad_token_id is not None and self.pad_token_id < 0:
warnings.warn(
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. "
"Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values."
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
"generating, if there is padding. Please set `pad_token_id` explicitly as "
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
)
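
For illustration, a short sketch of what triggers the reworded warning and the suggested fix:

```python
from transformers import GenerationConfig

config = GenerationConfig(pad_token_id=-1)  # warns at validation time (runs in __init__)
config.pad_token_id = 0                     # any valid, non-negative token id avoids the warning
```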
# Validation of attribute relations:
@ -675,6 +671,14 @@ class GenerationConfig(PushToHubMixin):
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# DoLa generation
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
warnings.warn(
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
UserWarning,
)
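
A sketch of a `generate` call that satisfies the new check (the checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative decoder-only checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# repetition_penalty >= 1.2 is the recommended setting for DoLa; lower values
# now trigger the UserWarning added above.
output = model.generate(
    **inputs, max_new_tokens=20, dola_layers="low", repetition_penalty=1.2
)
```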
# 4. check `num_return_sequences`
if self.num_return_sequences != 1:
@ -690,7 +694,7 @@ class GenerationConfig(PushToHubMixin):
f"({self.num_beams})."
)
# 5. check `cache_config`
# 5. check cache-related arguments
if self.cache_config is not None:
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
if cache_class is None:
@ -702,6 +706,20 @@ class GenerationConfig(PushToHubMixin):
if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate()
if self.use_cache is False:
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
# (otherwise a user might need to overwrite several parameters).
no_cache_warning = (
"You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
"have no effect."
)
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
if getattr(self, arg_name) is not None:
logger.warning_once(
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
)
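
For example (sketch), this configuration is now accepted with a warning rather than an error, so `use_cache=False` hot fixes keep working:

```python
from transformers import GenerationConfig

# `use_cache=False` wins: `cache_implementation` is kept but inert, and a
# warning (not an exception) points out that it will have no effect.
config = GenerationConfig(use_cache=False, cache_implementation="static")
```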
# 6. check watermarking arguments
if self.watermarking_config is not None:
@ -727,17 +745,6 @@ class GenerationConfig(PushToHubMixin):
"`generate()` (or a pipeline) directly."
)
# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
dola_decoding_wrong_parameter_msg = (
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
"which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
)
warnings.warn(
dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
UserWarning,
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],

src/transformers/generation/utils.py

@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput):
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have different cache formats; check
the model's documentation. Usually a [`~cache_utils.Cache`] instance.
"""
sequences: torch.LongTensor = None
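
With the new description, downstream code can check the field against the `Cache` API instead of assuming nested tuples (a sketch; the checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

out = model.generate(
    **inputs, max_new_tokens=5, return_dict_in_generate=True, return_legacy_cache=False
)
if isinstance(out.past_key_values, DynamicCache):
    print(out.past_key_values.get_seq_length())  # number of cached tokens
```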
@ -176,36 +172,32 @@ class GenerateEncoderDecoderOutput(ModelOutput):
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have different cache formats; check
the model's documentation. Usually a [`~cache_utils.Cache`] instance.
"""
sequences: torch.LongTensor = None
@ -228,33 +220,29 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput):
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have different cache formats; check
the model's documentation. Usually a [`~cache_utils.Cache`] instance.
"""
sequences: torch.LongTensor = None
@ -276,43 +264,39 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
sequence_length)`.
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have different cache formats; check
the model's documentation. Usually a [`~cache_utils.Cache`] instance.
"""
sequences: torch.LongTensor = None
@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
# Equivalent classes (kept for backward compatibility purposes)
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
@ -1501,6 +1486,121 @@ class GenerationMixin:
"""
return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()
def _prepare_cache_for_generation(
self,
generation_config: GenerationConfig,
model_kwargs: Dict,
assistant_model: "PreTrainedModel",
batch_size: int,
device: torch.device,
) -> None:
"""
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
instantiated, writes it to `model_kwargs`, under the name expected by the model.
"""
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
# Quick escape route 1: if the user specifies a cache, we only need to:
# a) check for conflicting `generate` arguments
# b) convert to the new cache format (if the user passes a legacy cache and model supports it)
user_defined_cache = model_kwargs.get(cache_name)
if user_defined_cache is not None:
if generation_config.cache_implementation is not None:
raise ValueError(
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
"Cache object) is unsupported. Please use only one of the two."
)
if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
model_kwargs[cache_name] = (
DynamicCache.from_legacy_cache(user_defined_cache)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
)
return
# Quick escape route 2: if the user specifies that no cache is to be used (conflicting arguments are
# handled in `generation_config.validate()`)
if generation_config.use_cache is False:
return
# Quick escape route 3: model that only supports legacy caches = nothing to prepare
if not self._supports_default_dynamic_cache():
if generation_config.cache_implementation is not None:
warnings.warn(
"This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
"ignored.",
UserWarning,
)
return
# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
# which is only supported in dynamic caches atm
if assistant_model is not None and generation_config.cache_implementation is not None:
logger.warning_once(
"An assistant model is provided, using a dynamic cache instead of a cache of type="
f"'{generation_config.cache_implementation}'."
)
generation_config.cache_implementation = None
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
max_cache_len=generation_config.max_length,
device=device,
model_kwargs=model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue and tag @zucchini-nlp."
)
cache_config = (
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
if cache_config.backend == "quanto" and not is_quanto_available():
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
"Please install it via with `pip install hqq`"
)
model_kwargs[cache_name] = cache_class(cache_config)
elif generation_config.cache_implementation == "offloaded":
model_kwargs[cache_name] = OffloadedCache()
# Use a DynamicCache() instance by default. This avoids back-and-forth conversions with the legacy
# format, which keeps copying the cache and thus uses much more memory
else:
model_kwargs[cache_name] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
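
In short, the helper resolves the cache in order: a user-passed cache object, then `use_cache=False`, then legacy-only models, then a named `cache_implementation`, and finally a default `DynamicCache`. A sketch of the two mutually exclusive entry points (checkpoint name illustrative; `"static"` requires model support):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Either name an implementation (raises a ValueError on models without static
# cache support)...
out = model.generate(**inputs, max_new_tokens=5, cache_implementation="static")

# ...or hand in a ready-made cache object. Passing both raises a ValueError
# (quick escape route 1 above).
out = model.generate(**inputs, max_new_tokens=5, past_key_values=DynamicCache())
```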
def _prepare_special_tokens(
self,
generation_config: GenerationConfig,
@ -1776,104 +1876,18 @@ class GenerationMixin:
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
use_dynamic_cache_by_default = False
if "mamba" in self.__class__.__name__.lower():
cache_name = "cache_params"
else:
cache_name = "past_key_values"
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
# which is only supported in dynamic caches atm
if (
assistant_model is not None
and generation_config.cache_implementation is not None
and self._supports_default_dynamic_cache()
):
logger.warning_once(
"An assistant model is provided, using a dynamic cache instead of a cache of type="
f"'{generation_config.cache_implementation}'."
)
generation_config.cache_implementation = None
if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
raise ValueError(
"Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
"may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` "
"input argument."
)
if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
raise ValueError(
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
max_cache_len=generation_config.max_length,
device=device,
model_kwargs=model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue."
)
cache_config = (
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
if cache_config.backend == "quanto" and not is_quanto_available():
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
"Please install it via with `pip install hqq`"
)
model_kwargs[cache_name] = cache_class(cache_config)
elif generation_config.cache_implementation == "offloaded":
model_kwargs[cache_name] = OffloadedCache()
# Use a DynamicCache() instance by default. This avoids back-and-forth conversions with the legacy
# format, which keeps copying the cache and thus uses much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get(cache_name, None)
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
if past is None:
model_kwargs[cache_name] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
use_dynamic_cache_by_default = True
elif isinstance(past, tuple):
model_kwargs[cache_name] = (
DynamicCache.from_legacy_cache(past)
if not requires_cross_attention_cache
else EncoderDecoderCache.from_legacy_cache(past)
)
use_dynamic_cache_by_default = True
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. determine generation mode
# 7. Prepare the cache.
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
# - different models have a different cache name expected by the model (default = "past_key_values")
# - `max_length`, prepared above, is used to determine the maximum cache length
# TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
user_defined_cache = model_kwargs.get(cache_name)
self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device)
# 8. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model)
if streamer is not None and (generation_config.num_beams > 1):
@ -1892,7 +1906,7 @@ class GenerationMixin:
UserWarning,
)
# 8. prepare distribution pre_processing samplers
# 9. prepare logits processors and stopping criteria
prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
@ -1904,8 +1918,6 @@ class GenerationMixin:
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# 9. prepare stopping criteria
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
@ -2138,11 +2150,34 @@ class GenerationMixin:
**model_kwargs,
)
# Convert to legacy cache if needed
if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
result.past_key_values = result.past_key_values.to_legacy_cache()
# Convert to legacy cache format if requested
if (
generation_config.return_legacy_cache is not False # Should check for `True` after v4.47
and not is_torchdynamo_compiling()
and hasattr(result, "past_key_values")
and hasattr(result.past_key_values, "to_legacy_cache")
and result.past_key_values.to_legacy_cache is not None
):
# handle BC (convert by default if the user hasn't passed a cache AND the cache is of the default type)
should_convert_cache = generation_config.return_legacy_cache
is_user_defined_cache = user_defined_cache is not None
is_default_cache_type = (
type(result.past_key_values) == DynamicCache # noqa E721
or (
isinstance(result.past_key_values, EncoderDecoderCache)
and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721
and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721
)
)
if not is_user_defined_cache and is_default_cache_type:
logger.warning_once(
"From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` "
"instance instead by default (as opposed to the legacy tuple of tuples format). If you want to "
"keep returning the legacy format, please set `return_legacy_cache=True`."
)
should_convert_cache = True
if should_convert_cache:
result.past_key_values = result.past_key_values.to_legacy_cache()
return result
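
The conversion relies on the legacy round-trip helpers of the cache classes, e.g. (a sketch with illustrative shapes):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
# one layer; (batch=1, heads=2, seq_len=3, head_dim=4) are illustrative shapes
cache.update(torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4), layer_idx=0)

legacy = cache.to_legacy_cache()                   # tuple of per-layer (key, value) tuples
restored = DynamicCache.from_legacy_cache(legacy)  # back to the `Cache` format
```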
def _has_unfinished_sequences(

src/transformers/models/rwkv/modeling_rwkv.py

@ -768,7 +768,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):
# only keep the last token of `input_ids` if the state is passed along.
if state is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
@ -780,6 +780,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
model_inputs = {"input_ids": input_ids}
model_inputs["state"] = state
model_inputs["use_cache"] = use_cache
return model_inputs
@add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)

tests/generation/test_utils.py

@ -194,6 +194,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@ -207,6 +208,7 @@ class GenerationTesterMixin:
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**logits_processor_kwargs,
**model_kwargs,
)
@ -224,6 +226,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
torch.manual_seed(0)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
@ -239,6 +242,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**logits_processor_kwargs,
**model_kwargs,
)
@ -256,6 +260,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@ -268,6 +273,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@ -286,6 +292,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
torch.manual_seed(0)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
@ -299,6 +306,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@ -317,6 +325,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@ -329,6 +338,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@ -348,6 +358,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
@ -361,6 +372,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
constraints=constraints,
use_cache=use_cache,
**beam_kwargs,
**logits_processor_kwargs,
**model_kwargs,
@ -378,6 +390,7 @@ class GenerationTesterMixin:
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
):
contrastive_search_kwargs = {
"penalty_alpha": 0.6,
@ -396,6 +409,7 @@ class GenerationTesterMixin:
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
use_cache=use_cache,
**logits_processor_kwargs,
**model_kwargs,
**contrastive_search_kwargs,
@ -419,7 +433,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
@ -430,6 +443,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@ -454,7 +468,6 @@ class GenerationTesterMixin:
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
@ -466,6 +479,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=True,
)
if model.config.is_encoder_decoder:
@ -495,7 +509,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
model=model,
@ -507,6 +520,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@ -545,9 +559,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_search_generate(
@ -560,6 +571,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
@ -589,7 +601,6 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
@ -602,6 +613,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=True,
)
if model.config.is_encoder_decoder:
@ -676,9 +688,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@ -692,6 +701,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@ -764,7 +774,6 @@ class GenerationTesterMixin:
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_diverse_beam_kwargs()
@ -778,6 +787,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
@ -857,9 +867,6 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
# Sample constraints
@ -882,6 +889,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=False,
)
if model.config.is_encoder_decoder:
@ -913,13 +921,12 @@ class GenerationTesterMixin:
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
# test old generation output for backwards compatibility
model = model_class(config).to(torch_device).eval()
output_generate = self._contrastive_generate(
model=model, input_ids=input_ids, attention_mask=attention_mask
model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
@ -940,7 +947,6 @@ class GenerationTesterMixin:
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
@ -953,6 +959,7 @@ class GenerationTesterMixin:
output_hidden_states=True,
output_attentions=self.has_attentions,
return_dict_in_generate=True,
use_cache=True,
)
if model.config.is_encoder_decoder:
@ -978,7 +985,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
@ -991,6 +997,7 @@ class GenerationTesterMixin:
low_memory=True,
max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
use_cache=True,
)
high_output = model.generate(
@ -1000,6 +1007,7 @@ class GenerationTesterMixin:
low_memory=False,
max_new_tokens=self.max_new_tokens,
attention_mask=attention_mask,
use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@ -1031,10 +1039,17 @@ class GenerationTesterMixin:
# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()
low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
low_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True
)
high_output = model.generate(
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
input_ids,
max_new_tokens=8,
num_beams=5,
early_stopping=True,
low_memory=False,
use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@ -1079,7 +1094,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@ -1098,6 +1112,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@ -1150,7 +1165,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@ -1169,6 +1183,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@ -1196,12 +1211,6 @@ class GenerationTesterMixin:
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, input_ids, attention_mask = self._get_input_ids_and_config()
# Some models don't support the cache and returning past_key_values
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
self.skipTest("DoLa is not supported for encoder-decoder models")
@ -1224,11 +1233,12 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": hasattr(config, "use_cache"), # Some models don't support the cache
}
generation_kwargs.update({"dola_layers": "low"})
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
def test_assisted_decoding_sample(self):
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
@ -1261,7 +1271,6 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"):
self.skipTest(reason="This model doesn't support caching")
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# Sets assisted generation arguments such that:
@ -1284,6 +1293,7 @@ class GenerationTesterMixin:
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
"use_cache": True,
}
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@ -1566,7 +1576,6 @@ class GenerationTesterMixin:
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
config.use_cache = True
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
@ -1574,6 +1583,7 @@ class GenerationTesterMixin:
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.generation_config.use_cache = True
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
@ -1631,7 +1641,6 @@ class GenerationTesterMixin:
self.skipTest(reason="This model does not support the new cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
@ -1640,6 +1649,7 @@ class GenerationTesterMixin:
"num_beams": num_beams,
"num_return_sequences": num_beams,
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}
# Sets seed before calling `generate` for the case with do_sample=True
@ -1701,7 +1711,6 @@ class GenerationTesterMixin:
if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
config.use_cache = True
config.is_decoder = True
batch_size, seq_length = input_ids.shape
max_new_tokens = 20
@ -1712,6 +1721,7 @@ class GenerationTesterMixin:
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}
max_cache_len = seq_length + max_new_tokens
@ -1740,7 +1750,6 @@ class GenerationTesterMixin:
self.skipTest(reason="This model does not support the quantized cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
@ -1750,6 +1759,7 @@ class GenerationTesterMixin:
# careful with group size, should be divisor of model's hidden size
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@ -1890,22 +1900,24 @@ class GenerationTesterMixin:
# Past Key Value States -- a few notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refactoring to match
# the standard cache format (e.g. gptbigcode)
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
if use_cache and has_standard_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
if has_standard_cache:
if use_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
elif use_cache is False:
self.assertTrue(output.past_key_values is None)
def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)
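
Outside the test harness, the expectation these changes encode is simply the following (sketch; the checkpoint name is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

out = model.generate(
    **inputs, max_new_tokens=5, return_dict_in_generate=True, use_cache=False
)
assert out.past_key_values is None  # no cache is prepared or returned
```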