From a6c82d45670918acf1a9331abf89be210c72ca4f Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 2 Nov 2023 15:39:21 +0000
Subject: [PATCH] Generate: return `past_key_values` (#25086)

---
 src/transformers/generation/utils.py | 117 ++++++++++++++++++++++++--
 tests/generation/test_utils.py       | 121 +++++++++++++++++++++++++++
 2 files changed, 233 insertions(+), 5 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index df4239d0542..69cbc373e5f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -104,12 +104,20 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
         hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
             passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format; confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
     scores: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
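For orientation (not part of the patch): a minimal sketch of what the new field exposes, assuming a decoder-only checkpoint such as `gpt2` that follows the standard cache format described in the docstring above.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(
    **inputs, do_sample=False, max_new_tokens=4, use_cache=True, return_dict_in_generate=True
)

# One element per decoder layer, each a (key, value) pair of tensors with shape
# (batch_size, num_heads, sequence_length, embed_size_per_head)
print(len(outputs.past_key_values))  # 12 layers for gpt2
key, value = outputs.past_key_values[0]
print(key.shape, value.shape)        # (1, 12, seq_len, 64) for gpt2
```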
""" sequences: torch.LongTensor = None @@ -149,6 +164,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput): decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -169,15 +185,23 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is - passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples + (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, + hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. """ sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -211,6 +235,13 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. 
""" sequences: torch.LongTensor = None @@ -220,6 +251,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -243,12 +275,20 @@ class SampleDecoderOnlyOutput(ModelOutput): hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. """ sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -283,6 +323,13 @@ class SampleEncoderDecoderOutput(ModelOutput): decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. 
""" sequences: torch.LongTensor = None @@ -292,6 +339,7 @@ class SampleEncoderDecoderOutput(ModelOutput): decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -319,6 +367,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput): hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. """ sequences: torch.LongTensor = None @@ -327,6 +382,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput): beam_indices: Optional[torch.LongTensor] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -366,6 +422,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. 
""" sequences: torch.LongTensor = None @@ -377,6 +440,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -404,6 +468,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput): hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. """ sequences: torch.LongTensor = None @@ -412,6 +483,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput): beam_indices: Optional[torch.LongTensor] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @dataclass @@ -450,6 +522,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. 
""" sequences: torch.LongTensor = None @@ -461,6 +540,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] @@ -2148,8 +2228,8 @@ class GenerationMixin: items.append(item.repeat_interleave(1, dim=0)) else: items.append(item.repeat_interleave(top_k, dim=0)) - new_key_values.append(items) - model_kwargs["past_key_values"] = new_key_values + new_key_values.append(tuple(items)) + model_kwargs["past_key_values"] = tuple(new_key_values) if sequential: all_outputs = {key: [] for key in outputs} # defined in first loop iteration @@ -2330,6 +2410,17 @@ class GenerationMixin: streamer.end() if return_dict_in_generate: + # Contrastive search works by forward looking at the next token, so we need to exclude it from + # `past_key_values` to be consistent with the other decoding methods + if model_kwargs.get("past_key_values") is not None: + past_key_values = [] + for layer in model_kwargs["past_key_values"]: + layer_past_key_values = [] + for item in layer: + layer_past_key_values.append(item[..., :-1, :]) + past_key_values.append(tuple(layer_past_key_values)) + model_kwargs["past_key_values"] = tuple(past_key_values) + if self.config.is_encoder_decoder: return ContrastiveSearchEncoderDecoderOutput( sequences=input_ids, @@ -2339,6 +2430,7 @@ class GenerationMixin: decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return ContrastiveSearchDecoderOnlyOutput( @@ -2346,6 +2438,7 @@ class GenerationMixin: scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return input_ids @@ -2598,6 +2691,7 @@ class GenerationMixin: decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return GreedySearchDecoderOnlyOutput( @@ -2605,6 +2699,7 @@ class GenerationMixin: scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return input_ids @@ -2880,6 +2975,7 @@ class GenerationMixin: decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return SampleDecoderOnlyOutput( @@ -2887,6 +2983,7 @@ class GenerationMixin: scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return input_ids @@ -3201,6 +3298,7 @@ class GenerationMixin: decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return BeamSearchDecoderOnlyOutput( @@ -3210,6 +3308,7 @@ class GenerationMixin: beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), ) else: return sequence_outputs["sequences"] @@ 
@@ -3530,6 +3629,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return BeamSampleDecoderOnlyOutput(
@@ -3539,6 +3639,7 @@ class GenerationMixin:
                     beam_indices=sequence_outputs["beam_indices"],
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return sequence_outputs["sequences"]
@@ -3909,6 +4010,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return BeamSearchDecoderOnlyOutput(
@@ -3918,6 +4020,7 @@ class GenerationMixin:
                     beam_indices=sequence_outputs["beam_indices"],
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return sequence_outputs["sequences"]
@@ -4244,6 +4347,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return BeamSearchDecoderOnlyOutput(
@@ -4253,6 +4357,7 @@ class GenerationMixin:
                     beam_indices=sequence_outputs["beam_indices"],
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return sequence_outputs["sequences"]
@@ -4672,6 +4777,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return GreedySearchDecoderOnlyOutput(
@@ -4679,6 +4785,7 @@ class GenerationMixin:
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return input_ids
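Taken together, the changes above let a follow-up `generate` call resume from a previous call's cache, which the new test below exercises. A rough usage sketch (decoder-only model assumed; `gpt2` is a stand-in, and input preparation such as the attention mask may differ across models):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("An increasing sequence: one,", return_tensors="pt")

# First call: generate a few tokens and keep the returned cache
first = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)

# Second call: pass the sequence so far plus the cache; the attention mask
# must be extended to cover the tokens generated by the first call
continued = model.generate(
    input_ids=first.sequences,
    attention_mask=torch.ones_like(first.sequences),
    past_key_values=first.past_key_values,
    do_sample=False,
    max_new_tokens=1,
    return_dict_in_generate=True,
)
print(tokenizer.decode(continued.sequences[0]))
```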
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 6468973d675..7e2f242c6fd 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
                 outputs_from_embeds_wo_ids[:, 1:].tolist(),
             )

+    def test_generate_continue_from_past_key_values(self):
+        # Tests that we can continue generating from past key values, returned from a previous `generate` call
+        for model_class in self.all_generative_model_classes:
+            # won't fix: old models with unique inputs/caches/others
+            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
+                return
+            # may fix in the future: needs modeling or test input preparation fixes for compatibility
+            if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
+                return
+
+            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # If it doesn't support cache, pass the test
+            if not hasattr(config, "use_cache"):
+                return
+
+            # Let's make it always:
+            # 1. use cache (for obvious reasons)
+            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value);
+            #    otherwise the test would be flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+            #    continuation would force it to generate beyond an EOS token)
+            # 3. ignore `token_type_ids` for simplicity
+            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+            #    active by default on some models
+            config.use_cache = True
+            if "token_type_ids" in inputs:
+                del inputs["token_type_ids"]
+
+            model = model_class(config).to(torch_device)
+            model.eval()
+            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+            model.generation_config.forced_eos_token_id = None
+
+            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
+            outputs = model(**inputs)
+            if "past_key_values" not in outputs:
+                return
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
+
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+            # inputs may need to be tweaked across `generate` calls (like the attention mask).
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
+
+            # Continue from the tokens generated above, preparing the inputs accordingly
+            inputs["past_key_values"] = outputs_cached.past_key_values
+            new_attention_len = outputs_cached.sequences.shape[-1]
+            if config.is_encoder_decoder:
+                inputs["decoder_input_ids"] = outputs_cached.sequences
+                if "decoder_attention_mask" in inputs:
+                    inputs["decoder_attention_mask"] = torch.nn.functional.pad(
+                        inputs["decoder_attention_mask"],
+                        (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
+                        mode="constant",
+                        value=1,
+                    )
+            else:
+                inputs["input_ids"] = outputs_cached.sequences
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = torch.nn.functional.pad(
+                        inputs["attention_mask"],
+                        (0, new_attention_len - inputs["attention_mask"].shape[1]),
+                        mode="constant",
+                        value=1,
+                    )
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
+
+            # The two sets of generated text and past kv should be equal to each other
+            self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
+            for layer_idx in range(len(outputs_cached.past_key_values)):
+                for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs.past_key_values[layer_idx][kv_idx],
+                            outputs_cached.past_key_values[layer_idx][kv_idx],
+                        )
+                    )
+
     def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
         batch_size, seq_length = input_ids.shape
         num_sequences_in_output = batch_size * num_return_sequences
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
                 use_cache=use_cache,
             )

+        # Past Key Value States -- three notes here:
+        # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
+        # 2. Some old models still return `output.past_key_values` even without `use_cache=True`
+        # 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
+        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
+        has_standard_cache = not any(
+            model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
+        )
+        if use_cache and has_standard_cache:
+            past_key_values = output.past_key_values
+            past_sequence_length = output.sequences.shape[-1] - 1
+            self._check_past_key_values_for_generate(
+                num_sequences_in_output,
+                past_key_values,
+                seq_length=past_sequence_length,
+                config=config,
+            )
+
     def _check_scores(self, batch_size, scores, length, config):
         expected_shape = (batch_size, config.vocab_size)
         self.assertIsInstance(scores, tuple)
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
             [encoder_expected_shape] * len(hidden_states),
         )

+    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
+        self.assertIsInstance(past_key_values, tuple)
+        self.assertListEqual(
+            [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
+            [True] * len(past_key_values),
+        )
+
+        # (batch, head, seq_length, head_features)
+        expected_shape = (
+            batch_size * num_beam_groups,
+            config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
+            seq_length,
+            config.hidden_size // config.num_attention_heads,
+        )
+        # check the shape of the key and value tensors
+        self.assertListEqual(
+            [layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
+            [expected_shape] * len(past_key_values),
+        )
+        self.assertListEqual(
+            [layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
+            [expected_shape] * len(past_key_values),
+        )
+
     def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
         # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
         # set to same device. we don't care what device.
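For reference, the shape that `_check_past_key_values_for_generate` expects can be computed by hand. A small sketch with hypothetical config values (the `num_key_value_heads` branch accounts for grouped-query attention models):

```python
# Hypothetical values, for illustration only
batch_size, num_beam_groups = 2, 1
num_attention_heads, num_key_value_heads = 16, 4  # GQA: fewer key/value heads
hidden_size, seq_length = 1024, 7

expected_shape = (
    batch_size * num_beam_groups,        # one cache row per sequence in the batch
    num_key_value_heads,                 # falls back to num_attention_heads without GQA
    seq_length,                          # length of the latest forward pass inputs (hence the "-1" above)
    hidden_size // num_attention_heads,  # head dim is derived from the attention head count
)
print(expected_shape)  # (2, 4, 7, 64)
```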