Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Generate: return past_key_values (#25086)

parent 441c3e0dd2
commit a6c82d4567
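
This change makes `generate` return the model's `past_key_values` cache on the generate output classes when `return_dict_in_generate=True` and `use_cache=True`, so a later call can resume generation from the cache. A minimal sketch of the pattern this enables, mirroring the new test added below -- the checkpoint name and token counts are illustrative, not part of the commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any decoder-only checkpoint with a standard cache
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("An example prompt", return_tensors="pt")

    # First call: generate 3 tokens and keep the cache on the returned output object
    first = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)

    # Second call: resume from the returned cache, extending the attention mask to the new length
    attention_mask = torch.nn.functional.pad(
        inputs["attention_mask"],
        (0, first.sequences.shape[-1] - inputs["attention_mask"].shape[1]),
        value=1,
    )
    resumed = model.generate(
        input_ids=first.sequences,
        attention_mask=attention_mask,
        past_key_values=first.past_key_values,
        do_sample=False,
        max_new_tokens=1,
        return_dict_in_generate=True,
    )
    # `resumed.sequences` should match a single 4-token `generate` call, as the new test asserts
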
@@ -104,12 +104,20 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
         hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
     scores: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -140,6 +148,13 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
         decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
@@ -149,6 +164,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
     decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -169,15 +185,23 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
         hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
-            passed or when `config.output_hidden_states=True`):
-            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
-            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+            passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples
+            (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length,
+            hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
     scores: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -211,6 +235,13 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
         decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
@@ -220,6 +251,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
     decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -243,12 +275,20 @@ class SampleDecoderOnlyOutput(ModelOutput):
         hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
     scores: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -283,6 +323,13 @@ class SampleEncoderDecoderOutput(ModelOutput):
         decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
@@ -292,6 +339,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
     decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -319,6 +367,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
         hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
@@ -327,6 +382,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
     beam_indices: Optional[torch.LongTensor] = None
     attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -366,6 +422,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
         decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
@@ -377,6 +440,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
     decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -404,6 +468,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
         hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
@@ -412,6 +483,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
     beam_indices: Optional[torch.LongTensor] = None
     attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 @dataclass
@@ -450,6 +522,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
         decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
+            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
     """

     sequences: torch.LongTensor = None
@@ -461,6 +540,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
     decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


 GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
@@ -2148,8 +2228,8 @@ class GenerationMixin:
                         items.append(item.repeat_interleave(1, dim=0))
                     else:
                         items.append(item.repeat_interleave(top_k, dim=0))
-                new_key_values.append(items)
-            model_kwargs["past_key_values"] = new_key_values
+                new_key_values.append(tuple(items))
+            model_kwargs["past_key_values"] = tuple(new_key_values)

             if sequential:
                 all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
@@ -2330,6 +2410,17 @@ class GenerationMixin:
             streamer.end()

         if return_dict_in_generate:
+            # Contrastive search works by forward looking at the next token, so we need to exclude it from
+            # `past_key_values` to be consistent with the other decoding methods
+            if model_kwargs.get("past_key_values") is not None:
+                past_key_values = []
+                for layer in model_kwargs["past_key_values"]:
+                    layer_past_key_values = []
+                    for item in layer:
+                        layer_past_key_values.append(item[..., :-1, :])
+                    past_key_values.append(tuple(layer_past_key_values))
+                model_kwargs["past_key_values"] = tuple(past_key_values)
+
             if self.config.is_encoder_decoder:
                 return ContrastiveSearchEncoderDecoderOutput(
                     sequences=input_ids,
@@ -2339,6 +2430,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return ContrastiveSearchDecoderOnlyOutput(
@@ -2346,6 +2438,7 @@ class GenerationMixin:
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return input_ids
@@ -2598,6 +2691,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return GreedySearchDecoderOnlyOutput(
@@ -2605,6 +2699,7 @@ class GenerationMixin:
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return input_ids
@@ -2880,6 +2975,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return SampleDecoderOnlyOutput(
@@ -2887,6 +2983,7 @@ class GenerationMixin:
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return input_ids
@@ -3201,6 +3298,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return BeamSearchDecoderOnlyOutput(
@@ -3210,6 +3308,7 @@ class GenerationMixin:
                     beam_indices=sequence_outputs["beam_indices"],
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return sequence_outputs["sequences"]
@@ -3530,6 +3629,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return BeamSampleDecoderOnlyOutput(
@@ -3539,6 +3639,7 @@ class GenerationMixin:
                     beam_indices=sequence_outputs["beam_indices"],
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return sequence_outputs["sequences"]
@@ -3909,6 +4010,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return BeamSearchDecoderOnlyOutput(
@@ -3918,6 +4020,7 @@ class GenerationMixin:
                     beam_indices=sequence_outputs["beam_indices"],
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return sequence_outputs["sequences"]
@@ -4244,6 +4347,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return BeamSearchDecoderOnlyOutput(
@@ -4253,6 +4357,7 @@ class GenerationMixin:
                     beam_indices=sequence_outputs["beam_indices"],
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return sequence_outputs["sequences"]
@@ -4672,6 +4777,7 @@ class GenerationMixin:
                     decoder_attentions=decoder_attentions,
                     cross_attentions=cross_attentions,
                     decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
                 return GreedySearchDecoderOnlyOutput(
@@ -4679,6 +4785,7 @@ class GenerationMixin:
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
             return input_ids

@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
             outputs_from_embeds_wo_ids[:, 1:].tolist(),
         )

+    def test_generate_continue_from_past_key_values(self):
+        # Tests that we can continue generating from past key values, returned from a previous `generate` call
+        for model_class in self.all_generative_model_classes:
+            # won't fix: old models with unique inputs/caches/others
+            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
+                return
+            # may fix in the future: needs modeling or test input preparation fixes for compatibility
+            if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
+                return
+
+            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # If it doesn't support cache, pass the test
+            if not hasattr(config, "use_cache"):
+                return
+
+            # Let's make it always:
+            # 1. use cache (for obvious reasons)
+            # 2. generate to max length (achieved by setting the eos token to an invalid value); otherwise, an early
+            #    EOS would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+            #    continuation would force it to generate beyond an EOS token)
+            # 3. ignore `token_type_ids` for simplicity
+            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+            #    active by default on some models
+            config.use_cache = True
+            if "token_type_ids" in inputs:
+                del inputs["token_type_ids"]
+
+            model = model_class(config).to(torch_device)
+            model.eval()
+            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+            model.generation_config.forced_eos_token_id = None
+
+            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
+            outputs = model(**inputs)
+            if "past_key_values" not in outputs:
+                return
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
+
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+            # inputs may need to be tweaked across `generate` calls (like the attention mask).
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
+
+            # Continue from the tokens generated above, preparing the inputs accordingly
+            inputs["past_key_values"] = outputs_cached.past_key_values
+            new_attention_len = outputs_cached.sequences.shape[-1]
+            if config.is_encoder_decoder:
+                inputs["decoder_input_ids"] = outputs_cached.sequences
+                if "decoder_attention_mask" in inputs:
+                    inputs["decoder_attention_mask"] = torch.nn.functional.pad(
+                        inputs["decoder_attention_mask"],
+                        (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
+                        mode="constant",
+                        value=1,
+                    )
+            else:
+                inputs["input_ids"] = outputs_cached.sequences
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = torch.nn.functional.pad(
+                        inputs["attention_mask"],
+                        (0, new_attention_len - inputs["attention_mask"].shape[1]),
+                        mode="constant",
+                        value=1,
+                    )
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
+
+            # The two sets of generated text and past kv should be equal to each other
+            self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
+            for layer_idx in range(len(outputs_cached.past_key_values)):
+                for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs.past_key_values[layer_idx][kv_idx],
+                            outputs_cached.past_key_values[layer_idx][kv_idx],
+                        )
+                    )
+
     def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
         batch_size, seq_length = input_ids.shape
         num_sequences_in_output = batch_size * num_return_sequences
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
                 use_cache=use_cache,
             )

+        # Past Key Value States -- a few notes here:
+        # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
+        # 2. Some old models still return `output.past_key_values` even without `use_cache=True`
+        # 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
+        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
+        has_standard_cache = not any(
+            model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
+        )
+        if use_cache and has_standard_cache:
+            past_key_values = output.past_key_values
+            past_sequence_length = output.sequences.shape[-1] - 1
+            self._check_past_key_values_for_generate(
+                num_sequences_in_output,
+                past_key_values,
+                seq_length=past_sequence_length,
+                config=config,
+            )
+
     def _check_scores(self, batch_size, scores, length, config):
         expected_shape = (batch_size, config.vocab_size)
         self.assertIsInstance(scores, tuple)
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
             [encoder_expected_shape] * len(hidden_states),
         )

+    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
+        self.assertIsInstance(past_key_values, tuple)
+        self.assertListEqual(
+            [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
+            [True] * len(past_key_values),
+        )
+
+        # (batch, head, seq_length, head_features)
+        expected_shape = (
+            batch_size * num_beam_groups,
+            config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
+            seq_length,
+            config.hidden_size // config.num_attention_heads,
+        )
+        # check shape key, value
+        self.assertListEqual(
+            [layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
+            [expected_shape] * len(past_key_values),
+        )
+        self.assertListEqual(
+            [layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
+            [expected_shape] * len(past_key_values),
+        )
+
     def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
         # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
         # set to same device. we don't care what device.
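
A note on the contrastive search hunk above: contrastive search looks one token ahead, so its cache holds one extra position that the other decoding methods never computed, and the new code trims it before returning. A standalone sketch of that trimming on a legacy-format cache (the helper name is hypothetical, for illustration only):

    def trim_last_position(past_key_values):
        # Legacy cache format: tuple of layers, each a tuple of key/value tensors
        # of shape (batch_size, num_heads, sequence_length, embed_size_per_head).
        # Slicing `[..., :-1, :]` drops the last sequence position from each tensor.
        return tuple(
            tuple(item[..., :-1, :] for item in layer)
            for layer in past_key_values
        )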