Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)

Generate: move prepare_inputs_for_generation in encoder-decoder llms (#34048)

This commit is contained in: parent fd70464fa7, commit 37ac078535

@@ -387,13 +387,14 @@ class GenerationMixin:
input_ids = input_ids[:, cache_position]

# 3. Prepare base model inputs
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs["input_ids"] = None
if inputs_embeds is not None and not self.config.is_encoder_decoder and cache_position[0] == 0:
model_inputs[input_ids_key] = None
model_inputs["inputs_embeds"] = inputs_embeds
else:
# `clone` calls in this function ensure a consistent stride. See #32227
model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
model_inputs["inputs_embeds"] = None

# 4. Create missing `position_ids` on the fly

@@ -421,8 +422,8 @@ class GenerationMixin:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
batch_size, sequence_length = model_inputs[input_ids_key].shape
device = model_inputs[input_ids_key].device

# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class).

@@ -455,6 +456,8 @@ class GenerationMixin:
if key not in model_inputs:
model_inputs[key] = value

# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
model_inputs.pop("labels", None)
return model_inputs

def _prepare_model_inputs(

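The three hunks above replace the hard-coded "input_ids" key with a computed `input_ids_key`, so the shared GenerationMixin.prepare_inputs_for_generation covers encoder-decoder models ("decoder_input_ids") as well as decoder-only ones ("input_ids"). Below is a minimal, self-contained sketch of that key-selection step, assuming a toy config object; the helper name `prepare_base_model_inputs` is illustrative and not part of the transformers API.

from types import SimpleNamespace

import torch


def prepare_base_model_inputs(config, input_ids, inputs_embeds=None, cache_position=None):
    # Illustrative sketch, not part of this commit or of the transformers API.
    model_inputs = {}
    # Encoder-decoder models feed the generated tokens to the decoder, so the prepared
    # ids go under "decoder_input_ids"; decoder-only models keep "input_ids".
    input_ids_key = "decoder_input_ids" if config.is_encoder_decoder else "input_ids"

    # `inputs_embeds` are only honored on the first step, and only for decoder-only models.
    if inputs_embeds is not None and not config.is_encoder_decoder and cache_position[0] == 0:
        model_inputs[input_ids_key] = None
        model_inputs["inputs_embeds"] = inputs_embeds
    else:
        # clone with a fixed memory format for a consistent stride (see #32227)
        model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
        model_inputs["inputs_embeds"] = None
    return model_inputs


ids = torch.tensor([[1, 2, 3]])
enc_dec = SimpleNamespace(is_encoder_decoder=True)
dec_only = SimpleNamespace(is_encoder_decoder=False)
print(list(prepare_base_model_inputs(enc_dec, ids)))   # ['decoder_input_ids', 'inputs_embeds']
print(list(prepare_base_model_inputs(dec_only, ids)))  # ['input_ids', 'inputs_embeds']
# --- end of illustrative sketch ---
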
@@ -1682,45 +1682,6 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

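The Bart method deleted above, and the near-identical copies removed from the models that follow, all implemented the same cache-aware trimming of `decoder_input_ids` before returning the forward kwargs; that logic now lives once in GenerationMixin. A standalone sketch of the trimming rule (the function name `trim_decoder_input_ids` is hypothetical, used only for illustration):

import torch


def trim_decoder_input_ids(decoder_input_ids: torch.Tensor, past_length: int) -> torch.Tensor:
    # Illustrative sketch, not part of this commit.
    # Some generation methods already pass only the last token; otherwise drop
    # every position that is already covered by the key/value cache.
    if decoder_input_ids.shape[1] > past_length:
        remove_prefix_length = past_length
    else:
        # Default to old behavior: keep only the final token
        remove_prefix_length = decoder_input_ids.shape[1] - 1
    return decoder_input_ids[:, remove_prefix_length:]


# With 4 cached positions, only the 5th token is fed to the next decoder step.
print(trim_decoder_input_ids(torch.tensor([[5, 6, 7, 8, 9]]), past_length=4))  # tensor([[9]])
# --- end of illustrative sketch ---
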
@@ -2561,45 +2561,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

@@ -1333,43 +1333,6 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMi
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()

@@ -1285,43 +1285,6 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, Ge
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()

@@ -915,6 +915,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
# Overwrite -- hardcoded key return (`is_decoder=True`)

input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:

@@ -1191,7 +1191,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin):
@add_end_docstrings(FSMT_GENERATION_EXAMPLE)
def forward(
self,
input_ids: torch.LongTensor,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,

@@ -1263,30 +1263,6 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)

@@ -2437,36 +2437,6 @@ class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin):
encoder_global_attentions=outputs.encoder_global_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"global_attention_mask": global_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

@@ -2085,42 +2085,6 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)

def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

@@ -1621,43 +1621,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()

@@ -16,7 +16,7 @@

import copy
import math
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

@@ -1438,43 +1438,6 @@ class MarianMTModel(MarianPreTrainedModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids: torch.LongTensor,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
**kwargs,
) -> Dict:
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

@@ -1647,43 +1647,6 @@ class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)

@@ -1820,45 +1820,6 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)

# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

@@ -1475,43 +1475,6 @@ class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

@@ -1762,44 +1762,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin):
total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
return total_router_logits, total_expert_indexes

# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()

@@ -1390,43 +1390,6 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

@@ -1588,37 +1588,6 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin)
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

@@ -16,7 +16,7 @@

import copy
import math
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint

@@ -1372,43 +1372,6 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids: torch.LongTensor,
past_key_values: Optional[List[torch.FloatTensor]] = None,
attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
**kwargs, # TODO: Check if this is needed. It is unused?
) -> Dict[str, Any]:
# cut decoder_input_ids if past is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)

@@ -2018,35 +2018,6 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMi

return loss

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."

if past_key_values:
decoder_input_ids = decoder_input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

@@ -1172,6 +1172,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
n_docs=None,
**kwargs,
):
# Overwritten -- `do_marginalize` is explicitly set in the output

if past_key_values is not None:
# if past is defined use only last decoder_input_ids
decoder_input_ids = decoder_input_ids[:, -1:]

@@ -1702,45 +1702,6 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
total_expert_indexes.append(expert_indexes)
return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)

def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

output_router_logits = kwargs.get("output_router_logits", True)

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
"output_router_logits": output_router_logits,
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

@@ -1791,44 +1791,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)

def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

@@ -1302,45 +1302,6 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
encoder_attentions=encoder_outputs.attentions,
)

# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
decoder_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"decoder_attention_mask": decoder_attention_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

@@ -1519,6 +1519,8 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
use_cache=True,
**kwargs,
):
# Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

empty_past_kv = past_key_values is None

# Omit tokens covered by past_key_values

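In contrast to the seq2seq models above, BlipText, RagToken and Zamba keep their model-specific overrides; the hunks only add a comment stating why (a hardcoded `is_decoder=True` return, an explicitly set `do_marginalize`, the unique `HybridMambaAttentionDynamicCache`). A hedged sketch of that general override pattern, with hypothetical class and key names, reusing a shared base implementation and then attaching one model-specific field:

class BaseMixin:
    # Stand-in for shared, mixin-provided input preparation (illustrative only).
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids, **kwargs}


class ModelWithCustomCache(BaseMixin):
    def prepare_inputs_for_generation(self, input_ids, custom_cache=None, **kwargs):
        # Overwritten -- reuse the shared logic, then attach the model-specific cache object.
        model_inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
        model_inputs["custom_cache"] = custom_cache
        return model_inputs


print(ModelWithCustomCache().prepare_inputs_for_generation([1, 2, 3], custom_cache="state"))
# --- end of illustrative sketch ---
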
@@ -3841,6 +3841,38 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(model_inputs["input_ids"] is not None)
self.assertTrue(model_inputs["inputs_embeds"] is None)

def test_prepare_inputs_for_generation_encoder_decoder_llm(self):
"""
Same as `test_prepare_inputs_for_generation_decoder_llm` but for encoder-decoder models. Main difference: we
should look for `decoder_input_ids`, instead of `input_ids`.
"""
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
model = model.to(torch_device)

# 1. Sanity check: the model's `prepare_inputs_for_generation` comes from `GenerationMixin`
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation))

# 2. If we pass input ids by themselves, we should get back the same input ids -- with the encoder-decoder key
decoder_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids)
self.assertTrue(torch.all(model_inputs["decoder_input_ids"] == decoder_input_ids))

# 3. If we pass the attention mask too, we will get back the attention mask. Encoder-decoder models usually
# don't use `position_ids`
decoder_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
model_inputs = model.prepare_inputs_for_generation(
decoder_input_ids, decoder_attention_mask=decoder_attention_mask
)
self.assertTrue(torch.all(model_inputs["decoder_attention_mask"] == decoder_attention_mask))
self.assertTrue("position_ids" not in model_inputs)

# 4. `use_cache` (and other kwargs, like the encoder outputs) are forwarded
self.assertFalse("use_cache" in model_inputs) # From the previous input, there is no `use_cache`
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, use_cache=True, encoder_outputs="foo")
self.assertTrue(model_inputs["use_cache"] is True)
self.assertTrue(model_inputs["encoder_outputs"] == "foo")
# See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here.

def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)

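The new test exercises the shared method directly; the sketch below does the same outside unittest, contrasting the key returned for a seq2seq checkpoint with a decoder-only one. The seq2seq checkpoint is the one used in the test; "hf-internal-testing/tiny-random-gpt2" is an assumed decoder-only counterpart, and the exact set of returned keys may vary by model and library version.

import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Illustrative sketch, not part of this commit.
ids = torch.tensor([[1, 2, 3]])

seq2seq = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
print(sorted(seq2seq.prepare_inputs_for_generation(ids)))  # includes "decoder_input_ids"

causal = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
print(sorted(causal.prepare_inputs_for_generation(ids)))   # includes "input_ids"
# --- end of illustrative sketch ---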