Generate: return past_key_values (#25086)

This commit is contained in:
Joao Gante 2023-11-02 15:39:21 +00:00 committed by GitHub
parent 441c3e0dd2
commit a6c82d4567
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 233 additions and 5 deletions

View File

@ -104,12 +104,20 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -140,6 +148,13 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
@ -149,6 +164,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -169,15 +185,23 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples
(one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length,
hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -211,6 +235,13 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
@ -220,6 +251,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -243,12 +275,20 @@ class SampleDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -283,6 +323,13 @@ class SampleEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
@ -292,6 +339,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -319,6 +367,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
@ -327,6 +382,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
beam_indices: Optional[torch.LongTensor] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -366,6 +422,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
@ -377,6 +440,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -404,6 +468,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
@ -412,6 +483,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
beam_indices: Optional[torch.LongTensor] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
@dataclass
@ -450,6 +522,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
@ -461,6 +540,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
@ -2148,8 +2228,8 @@ class GenerationMixin:
items.append(item.repeat_interleave(1, dim=0))
else:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(items)
model_kwargs["past_key_values"] = new_key_values
new_key_values.append(tuple(items))
model_kwargs["past_key_values"] = tuple(new_key_values)
if sequential:
all_outputs = {key: [] for key in outputs} # defined in first loop iteration
@ -2330,6 +2410,17 @@ class GenerationMixin:
streamer.end()
if return_dict_in_generate:
# Contrastive search works by forward looking at the next token, so we need to exclude it from
# `past_key_values` to be consistent with the other decoding methods
if model_kwargs.get("past_key_values") is not None:
past_key_values = []
for layer in model_kwargs["past_key_values"]:
layer_past_key_values = []
for item in layer:
layer_past_key_values.append(item[..., :-1, :])
past_key_values.append(tuple(layer_past_key_values))
model_kwargs["past_key_values"] = tuple(past_key_values)
if self.config.is_encoder_decoder:
return ContrastiveSearchEncoderDecoderOutput(
sequences=input_ids,
@ -2339,6 +2430,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return ContrastiveSearchDecoderOnlyOutput(
@ -2346,6 +2438,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@ -2598,6 +2691,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GreedySearchDecoderOnlyOutput(
@ -2605,6 +2699,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@ -2880,6 +2975,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return SampleDecoderOnlyOutput(
@ -2887,6 +2983,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@ -3201,6 +3298,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSearchDecoderOnlyOutput(
@ -3210,6 +3308,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@ -3530,6 +3629,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSampleDecoderOnlyOutput(
@ -3539,6 +3639,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@ -3909,6 +4010,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSearchDecoderOnlyOutput(
@ -3918,6 +4020,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@ -4244,6 +4347,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return BeamSearchDecoderOnlyOutput(
@ -4253,6 +4357,7 @@ class GenerationMixin:
beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return sequence_outputs["sequences"]
@ -4672,6 +4777,7 @@ class GenerationMixin:
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return GreedySearchDecoderOnlyOutput(
@ -4679,6 +4785,7 @@ class GenerationMixin:
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids

View File

@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
outputs_from_embeds_wo_ids[:, 1:].tolist(),
)
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
# won't fix: old models with unique inputs/caches/others
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
return
# may fix in the future: needs modeling or test input preparation fixes for compatibility
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
return
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
return
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
config.use_cache = True
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
model = model_class(config).to(torch_device)
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
return
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[-1]
if config.is_encoder_decoder:
inputs["decoder_input_ids"] = outputs_cached.sequences
if "decoder_attention_mask" in inputs:
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
inputs["decoder_attention_mask"],
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
mode="constant",
value=1,
)
else:
inputs["input_ids"] = outputs_cached.sequences
if "attention_mask" in inputs:
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"],
(0, new_attention_len - inputs["attention_mask"].shape[1]),
mode="constant",
value=1,
)
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
# The two sets of generated text and past kv should be equal to each other
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences
@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
use_cache=use_cache,
)
# Past Key Value States -- two notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
if use_cache and has_standard_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
def _check_scores(self, batch_size, scores, length, config):
expected_shape = (batch_size, config.vocab_size)
self.assertIsInstance(scores, tuple)
@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
[encoder_expected_shape] * len(hidden_states),
)
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, tuple)
self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
[True] * len(past_key_values),
)
# (batch, head, seq_length, head_features)
expected_shape = (
batch_size * num_beam_groups,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
seq_length,
config.hidden_size // config.num_attention_heads,
)
# check shape key, value
self.assertListEqual(
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
self.assertListEqual(
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
# set to same device. we don't care what device.