Idefics: fix position ids (#33907)

* fix position ids

* fix labels also

* fix copies

* oops, not that one

* dont deprecate
Raushan Turganbay 2024-10-11 10:28:34 +02:00 committed by GitHub
parent 7d97cca8dd
commit be9aeba581
11 changed files with 79 additions and 69 deletions

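For context, the position-id pattern that the diffs below converge on can be reproduced on a toy, left-padded batch. This is an illustrative sketch only (the tensors are made up); it mirrors the cumsum/masked_fill/slicing logic in the rewritten prepare_inputs_for_generation further down.

# Illustrative sketch only: toy tensors mirroring the position-id logic in the diffs below.
import torch

# batch of 2, left-padded: 0 = padding, 1 = real token
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# create position_ids on the fly for batch generation (cumsum over the mask)
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# row 0 -> [1, 1, 0, 1, 2], row 1 -> [0, 1, 2, 3, 4]

# with a cache, only the unprocessed tokens are fed in, so keep the trailing positions
num_new_tokens = 1  # one token per step during decoding
step_position_ids = position_ids[:, -num_new_tokens:]
print(step_position_ids)  # tensor([[2], [4]])

# cloning in a contiguous memory format keeps the stride stable across steps,
# which avoids recapturing CUDA graphs under torch.compile's "reduce-overhead" mode
step_position_ids = step_position_ids.clone(memory_format=torch.contiguous_format)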
View File

@@ -183,51 +183,6 @@ def expand_inputs_for_generation(
     return input_ids, model_kwargs
 
 
-def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
-    token_type_ids = kwargs.get("token_type_ids", None)
-    cache_position = kwargs.get("cache_position", None)
-    # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-    if past_key_values is not None:
-        if input_ids.shape[1] != cache_position.shape[0]:
-            input_ids = input_ids[:, cache_position]
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
-
-    attention_mask = kwargs.get("attention_mask", None)
-    position_ids = kwargs.get("position_ids", None)
-
-    if attention_mask is not None and position_ids is None:
-        # create position_ids on the fly for batch generation
-        position_ids = attention_mask.long().cumsum(-1) - 1
-        position_ids.masked_fill_(attention_mask == 0, 1)
-        if past_key_values:
-            position_ids = position_ids[:, -1].unsqueeze(-1)
-
-            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-            position_ids = position_ids.clone(memory_format=torch.contiguous_format)
-
-    pixel_values = kwargs.get("pixel_values", None)
-    image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
-    perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
-    image_attention_mask = kwargs.get("image_attention_mask", None)
-    interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False)
-
-    return {
-        "input_ids": input_ids,
-        "past_key_values": past_key_values,
-        "use_cache": kwargs.get("use_cache"),
-        "cache_position": cache_position,
-        "position_ids": position_ids,
-        "attention_mask": attention_mask,
-        "token_type_ids": token_type_ids,
-        "pixel_values": pixel_values,
-        "image_encoder_embeddings": image_encoder_embeddings,
-        "perceiver_embeddings": perceiver_embeddings,
-        "image_attention_mask": image_attention_mask,
-        "interpolate_pos_encoding": interpolate_pos_encoding,
-    }
-
-
 def freeze_model(model, module_exceptions=[]):
     mapping = {
         "LayerNorm": nn.LayerNorm,
@@ -1210,11 +1165,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             position_ids = position_ids[:, -seq_length:]
         elif position_ids is None:
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = cache_position.unsqueeze(0)
 
         if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
             raise ValueError(
@@ -1684,7 +1637,9 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
             labels = labels.to(logits.device)
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:].to(logits.device)
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
             else:
@@ -1707,19 +1662,57 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
             image_hidden_states=outputs.image_hidden_states,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        image_hidden_states = kwargs.pop("image_hidden_states", None)
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        pixel_values=None,
+        image_hidden_states=None,
+        use_cache=None,
+        cache_position=None,
+        **kwargs,
+    ):
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        if past_key_values is not None:
+            if input_ids.shape[1] != cache_position.shape[0]:
+                input_ids = input_ids[:, cache_position]
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+        model_inputs = {}
         if image_hidden_states is not None:
             if self.config.use_resampler:
-                kwargs["perceiver_embeddings"] = image_hidden_states
+                model_inputs["perceiver_embeddings"] = image_hidden_states
             else:
-                kwargs["image_encoder_embeddings"] = image_hidden_states
-            kwargs["pixel_values"] = None
-        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
-        unwanted_kwargs = ["token_type_ids"]
-        for kwarg in unwanted_kwargs:
-            inputs.pop(kwarg, None)
-        return inputs
+                model_inputs["image_encoder_embeddings"] = image_hidden_states
+            pixel_values = None
+
+        model_inputs.update(
+            {
+                "input_ids": input_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "cache_position": cache_position,
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "pixel_values": pixel_values,
+                "image_attention_mask": kwargs.get("image_attention_mask", None),
+                "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
+            }
+        )
+        return model_inputs
 
     @staticmethod
     def _expand_inputs_for_generation(

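As a usage-level illustration of the rewritten IdeficsForVisionText2Text.prepare_inputs_for_generation above: once image features have been computed on the first forward pass, they are fed back either as perceiver_embeddings (resampler models) or image_encoder_embeddings, and pixel_values is dropped so the vision encoder is not run again. The sketch below uses a hypothetical config stub and made-up shapes, not the real model class.

# Hedged sketch: hypothetical config stub and made-up tensor shapes, mirroring the
# image_hidden_states routing in the new prepare_inputs_for_generation above.
from types import SimpleNamespace

import torch

config = SimpleNamespace(use_resampler=True)      # stand-in for the real model config
image_hidden_states = torch.randn(1, 64, 1024)    # features cached from the first forward pass
pixel_values = torch.randn(1, 1, 3, 224, 224)     # raw images from the original call

model_inputs = {}
if image_hidden_states is not None:
    if config.use_resampler:
        model_inputs["perceiver_embeddings"] = image_hidden_states
    else:
        model_inputs["image_encoder_embeddings"] = image_hidden_states
    # precomputed features are reused, so the raw pixels are no longer needed
    pixel_values = None

model_inputs["pixel_values"] = pixel_values
print(sorted(model_inputs))  # ['perceiver_embeddings', 'pixel_values']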
View File

@@ -1626,7 +1626,9 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin):
             labels = labels.to(logits.device)
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:].to(logits.device)
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
             else:

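The "fix labels also" part repeats one pattern in the hunks above and below: shift logits/labels so that tokens < n predict n, and crop the 2D attention mask to the logits' length so a longer mask (for example when PrefixTuning in peft prepends virtual tokens) still lines up. A minimal sketch with made-up sizes:

# Minimal sketch of the shift-and-crop loss masking used in these hunks; sizes are made up.
import torch

batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# the attention mask can be longer than the logits, e.g. when prefix tuning
# prepends virtual tokens; crop it to the last (seq_len - 1) positions
prefix_len = 3
attention_mask = torch.ones(batch, prefix_len + seq_len, dtype=torch.long)

shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()

loss = torch.nn.functional.cross_entropy(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
print(shift_logits.shape, shift_labels.shape)  # torch.Size([10, 11]) torch.Size([10])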
View File

@@ -1213,7 +1213,9 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
             labels = labels.to(logits.device)
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:].to(logits.device)
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
             else:

View File

@@ -546,7 +546,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
             else:

View File

@@ -923,7 +923,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
             else:

View File

@@ -1004,7 +1004,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
             else:

View File

@@ -519,7 +519,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
             else:

View File

@@ -676,7 +676,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
             else:

View File

@@ -532,7 +532,8 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
             shift_labels = labels[..., 1:]
             if attention_mask is not None:
                 # we use the input attention mask to shift the logits and labels, because it is 2D.
-                shift_attention_mask = attention_mask[..., 1:]
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                 shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
             else:

View File

@@ -656,7 +656,9 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
             else:

View File

@@ -539,7 +539,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
             else: