Idefics: enable generation tests (#34062)

* add idefics

* conflicts after merging main

* enable tests but need to fix some

* fix tests

* no print

* fix/skip some slow tests

* continue instead of skipping

* rebasing broke something, this is the fix
Raushan Turganbay 2024-10-15 11:17:14 +02:00 committed by GitHub
parent dd4216b766
commit 23874f5948
10 changed files with 406 additions and 96 deletions


@@ -726,14 +726,23 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
elif mask_length_diff > 0:
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
# Handle cross attention models
if "cross_attention_mask" in model_kwargs:
# Mllama case is special and has another mask for cross attention model
# Mllama case
cross_mask = model_kwargs["cross_attention_mask"]
if mask_length_diff < 0:
model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
elif "image_attention_mask" in model_kwargs:
# IDEFICS case
cross_mask = model_kwargs["image_attention_mask"]
if mask_length_diff < 0:
model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
return model_kwargs
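For context, here is a minimal standalone sketch of the IDEFICS branch above (tensor shapes are invented for illustration, not taken from the model): when the text sequence grows by mask_length_diff tokens, the last row of the 3D image attention mask is repeated once per new token, so the mask keeps tracking which images each position may attend to.

import torch

batch_size, seq_len, num_images = 2, 5, 3  # illustrative sizes only
image_attention_mask = torch.ones(batch_size, seq_len, num_images, dtype=torch.long)

mask_length_diff = 2  # two new tokens were appended to the text sequence
if mask_length_diff < 0:
    image_attention_mask = image_attention_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
    # repeat the mask row of the last position once for every newly generated token
    new_rows = image_attention_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
    image_attention_mask = torch.cat([image_attention_mask, new_rows], dim=1)

print(image_attention_mask.shape)  # torch.Size([2, 7, 3])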


@@ -2005,6 +2005,7 @@ class GenerationMixin:
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
generation_config.use_cache = True
else:
model_kwargs["use_cache"] = generation_config.use_cache
@@ -4299,7 +4300,8 @@ class GenerationMixin:
newly_added_length,
is_decoder_attention=True,
)
else:
# some (V)LLMs have a hard requirement on SDPA and thus never return attn
elif outputs.attentions[0] is not None:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,


@@ -28,12 +28,12 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...modeling_utils import PretrainedConfig, PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
@@ -622,11 +622,9 @@ class IdeficsAttention(nn.Module):
query_states = self.q_layer_norm(query_states)
key_states = self.k_layer_norm(key_states)
causal_mask = attention_mask
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -638,13 +636,13 @@ class IdeficsAttention(nn.Module):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
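The is_causal dispatch above can be illustrated with a self-contained sketch (all shapes are made up): when no explicit mask is supplied and the query spans more than one position, SDPA is asked to build the causal mask itself; otherwise the prepared 4D mask is passed as attn_mask.

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, kv_len, head_dim = 1, 4, 6, 6, 8  # illustrative sizes
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, kv_len, head_dim)
value = torch.randn(bsz, num_heads, kv_len, head_dim)

causal_mask = None  # would be a (bsz, 1, q_len, kv_len) mask when provided
is_causal = causal_mask is None and q_len > 1  # let SDPA build the causal mask itself

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=causal_mask, dropout_p=0.0, is_causal=is_causal
)
print(attn_output.shape)  # torch.Size([1, 4, 6, 8])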
@@ -1490,7 +1488,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
return causal_mask
class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
@@ -1670,6 +1668,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
position_ids=None,
pixel_values=None,
image_hidden_states=None,
image_attention_mask=None,
use_cache=None,
cache_position=None,
**kwargs,
@@ -1678,6 +1677,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
if past_key_values is not None:
if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if image_attention_mask is not None:
image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
@@ -1696,7 +1697,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
model_inputs["perceiver_embeddings"] = image_hidden_states
else:
model_inputs["image_encoder_embeddings"] = image_hidden_states
pixel_values = None
else:
model_inputs["pixel_values"] = pixel_values
model_inputs.update(
{
@@ -1706,21 +1708,13 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
"cache_position": cache_position,
"position_ids": position_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_attention_mask": kwargs.get("image_attention_mask", None),
"image_attention_mask": image_attention_mask,
"interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
}
)
return model_inputs
@staticmethod
def _expand_inputs_for_generation(
*args,
**model_kwargs,
):
return expand_inputs_for_generation(*args, **model_kwargs)
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
@@ -1738,7 +1732,10 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
if "image_attention_mask" in model_kwargs:
image_attention_mask = model_kwargs["image_attention_mask"]
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
model_kwargs["image_attention_mask"] = last_mask
if model_kwargs.get("use_cache", True):
model_kwargs["image_attention_mask"] = last_mask
else:
model_kwargs["image_attention_mask"] = torch.cat([image_attention_mask, last_mask], dim=1)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
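A hedged sketch of the new branch above, outside the model (shapes are illustrative): with caching enabled only the mask row for the current step is kept, while without caching the mask has to keep growing alongside the full sequence.

import torch

# (batch, seq_len, num_images) image attention mask after the last forward pass
image_attention_mask = torch.ones(2, 5, 3, dtype=torch.long)
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)  # (batch, 1, num_images)

use_cache = True
if use_cache:
    next_image_attention_mask = last_mask  # only the current step is needed
else:
    # no cache: the mask must cover the whole (growing) sequence
    next_image_attention_mask = torch.cat([image_attention_mask, last_mask], dim=1)
print(next_image_attention_mask.shape)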


@@ -1427,6 +1427,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1657,35 +1658,19 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
pixel_values=None,
pixel_attention_mask=None,
image_hidden_states=None,
num_logits_to_keep=None,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = past_key_values.get_seq_length()
max_cache_length = past_key_values.get_max_cache_shape()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and past_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -1696,21 +1681,22 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
image_hidden_states = kwargs.get("image_hidden_states", None)
if image_hidden_states is not None:
pixel_values = None
pixel_attention_mask = None
else:
pixel_values = kwargs.get("pixel_values", None)
pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
pixel_values = pixel_values
pixel_attention_mask = pixel_attention_mask
model_inputs.update(
{
"position_ids": position_ids,

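The cache_position based slicing used above can be sketched in isolation (the tensors below are invented for demonstration): on the first step with inputs_embeds only the trailing positions tracked by cache_position are kept, and on later decoding steps the ids are indexed with cache_position directly.

import torch

input_ids = torch.tensor([[5, 6, 7, 8, 9]])
inputs_embeds = torch.randn(1, 5, 16)
cache_position = torch.arange(5)  # prefill: every position is new

if inputs_embeds is not None:
    input_ids = input_ids[:, -cache_position.shape[0]:]  # keep the uncached tail
elif input_ids.shape[1] != cache_position.shape[0]:
    input_ids = input_ids[:, cache_position]

# a later decoding step: only one position is new
cache_position = torch.tensor([5])
next_ids = torch.tensor([[5, 6, 7, 8, 9, 10]])
print(next_ids[:, cache_position])  # tensor([[10]])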

@@ -24,7 +24,8 @@ from torch.nn import CrossEntropyLoss
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...utils import (
@@ -953,6 +954,8 @@ class Idefics3Model(Idefics3PreTrainedModel):
past_seen_tokens = 0
if use_cache:
if past_key_values is None:
past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
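A minimal sketch of the lazy cache initialization added above, assuming the DynamicCache class exported by transformers: a freshly created cache reports zero previously seen tokens.

from transformers import DynamicCache

use_cache = True
past_key_values = None

past_seen_tokens = 0
if use_cache:
    if past_key_values is None:
        past_key_values = DynamicCache()  # created on the fly when the caller passed no cache
    past_seen_tokens = past_key_values.get_seq_length()

print(past_seen_tokens)  # 0 for an empty cache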
@@ -1019,6 +1022,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1040,7 +1044,7 @@
"""The Idefics3 Model with a language modeling head. It is made up of a SigLIP vision encoder, with a language modeling head on top.""",
IDEFICS3_START_DOCSTRING,
)
class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
@@ -1245,35 +1249,19 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
pixel_values=None,
pixel_attention_mask=None,
image_hidden_states=None,
num_logits_to_keep=None,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = past_key_values.get_seq_length()
max_cache_length = past_key_values.get_max_cache_shape()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and past_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
@@ -1284,21 +1272,22 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else:
model_inputs = {"input_ids": input_ids}
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
image_hidden_states = kwargs.get("image_hidden_states", None)
if image_hidden_states is not None:
pixel_values = None
pixel_attention_mask = None
else:
pixel_values = kwargs.get("pixel_values", None)
pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
pixel_values = pixel_values
pixel_attention_mask = pixel_attention_mask
model_inputs.update(
{
"position_ids": position_ids,


@@ -153,7 +153,11 @@ class GenerationTesterMixin:
# This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
# to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
if config is not None:
image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
image_token_index = (
config.image_token_index
if getattr(config, "image_token_index", None) is not None
else getattr(config, "image_token_id", None)
)
video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
logits_processor_kwargs["bad_words_ids"].append([image_token_index])
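For illustration, a self-contained sketch of the fallback lookup above (the config class and token ids are invented): the image token id is resolved from whichever attribute the config exposes and is then excluded from generation through bad_words_ids.

class DummyConfig:  # hypothetical stand-in for a model config
    image_token_id = 99  # some configs expose `image_token_index` instead
    vocab_size = 1000

config = DummyConfig()
logits_processor_kwargs = {"bad_words_ids": []}

image_token_index = (
    config.image_token_index
    if getattr(config, "image_token_index", None) is not None
    else getattr(config, "image_token_id", None)
)
if image_token_index is not None and image_token_index < config.vocab_size:
    logits_processor_kwargs["bad_words_ids"].append([image_token_index])

print(logits_processor_kwargs)  # {'bad_words_ids': [[99]]}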
@@ -1496,13 +1500,14 @@ class GenerationTesterMixin:
if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")
text_config = config.get_text_config()
num_hidden_layers = (
getattr(config, "decoder_layers", None)
or getattr(config, "num_decoder_layers", None)
or config.num_hidden_layers
getattr(text_config, "decoder_layers", None)
or getattr(text_config, "num_decoder_layers", None)
or text_config.num_hidden_layers
)
num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
embed_dim = getattr(config, "d_model", config.hidden_size)
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads
past_kv = outputs["past_key_values"]
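A hedged sketch of the shape bookkeeping above, with invented config values: the cache holds one (key, value) pair per decoder layer, and each tensor is expected to have shape (batch_size, num_attention_heads, seq_len, embed_dim // num_attention_heads).

class DummyTextConfig:  # hypothetical text sub-config
    num_hidden_layers = 2
    num_attention_heads = 4
    hidden_size = 32

text_config = DummyTextConfig()
num_hidden_layers = (
    getattr(text_config, "decoder_layers", None)
    or getattr(text_config, "num_decoder_layers", None)
    or text_config.num_hidden_layers
)
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads

batch_size, seq_len = 2, 7  # illustrative
expected_shape = (batch_size, num_attention_heads, seq_len, per_head_embed_dim)
print(num_hidden_layers, expected_shape)  # 2 (2, 4, 7, 8)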


@@ -14,8 +14,10 @@
# limitations under the License.
"""Testing suite for the PyTorch Idefics model."""
import inspect
import unittest
import pytest
from parameterized import parameterized
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
@@ -31,6 +33,7 @@ from transformers.testing_utils import (
)
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -318,6 +321,12 @@ class IdeficsModelTester:
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
@@ -580,8 +589,9 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch
class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
all_generative_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
def setUp(self):
self.model_tester = IdeficsModelTester(
@@ -590,6 +600,182 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
)
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
@pytest.mark.generate
def test_left_padding_compatibility(self):
"""Overwrite because IDEFICS needs the image attention mask to be padded as well"""
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
def _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature):
model_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"image_attention_mask": image_attention_mask,
}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
input_ids = inputs_dict.pop("input_ids")
attention_mask = inputs_dict.pop("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
image_attention_mask = inputs_dict.pop("image_attention_mask", None)
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
pad_size_img = (input_ids.shape[0], 32, image_attention_mask.shape[-1])
extra_img_mask = torch.zeros(pad_size_img, dtype=image_attention_mask.dtype, device=torch_device)
padded_image_attention_mask = torch.cat([extra_img_mask, image_attention_mask], dim=1)
model_kwargs = _prepare_model_kwargs(
padded_input_ids, padded_attention_mask, padded_image_attention_mask, signature
)
next_logits_with_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
"""Overwrite because IDEFICS needs the image attention mask to be processed as well"""
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
# with cache, what is considered a prompt is different in the two cases.
model = model_class(config).to(torch_device)
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.generation_config.encoder_no_repeat_ngram_size = 0
model.generation_config.use_cache = True
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[-1]
inputs["input_ids"] = outputs_cached.sequences
if "attention_mask" in inputs:
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"],
(0, new_attention_len - inputs["attention_mask"].shape[1]),
mode="constant",
value=1,
)
if "image_attention_mask" in inputs:
inputs["image_attention_mask"] = inputs["image_attention_mask"][:, -1:, :]
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
# The two sets of generated text and past kv should be equal to each other
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
@pytest.mark.generate
def test_generate_without_input_ids(self):
"""Overwrite because IDEFICS needs the image attention mask to be processed as well and always requires an image as input."""
config, input_dict = self.prepare_config_and_inputs_for_generate()
pixel_values = input_dict["pixel_values"]
image_attention_mask = input_dict["image_attention_mask"][:, -1:, :]
# hack in case they are equal, otherwise the attn mask will be [0]
if config.bos_token_id == config.pad_token_id:
config.pad_token_id = None
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
pixel_values=pixel_values,
image_attention_mask=image_attention_mask,
do_sample=False,
max_new_tokens=self.max_new_tokens,
remove_invalid_values=True,
)
self.assertIsNotNone(output_ids_generate)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
"""
Overwritten from the generation tests because Idefics only has SDPA layers.
We do not skip the test because we still want the generation tests to run; instead we drop the shape checks.
"""
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(reason="We only test the model that takes in multiple images")
def test_custom_4d_attention_mask(self):
pass
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_compile_fullgraph(self):
pass
@unittest.skip(reason="We only test the model that takes in multiple images")
def test_model(self):
pass


@@ -19,6 +19,7 @@ import gc
import unittest
from io import BytesIO
import pytest
import requests
from transformers import (
@@ -96,7 +97,7 @@ class Idefics2VisionText2TextModelTester:
"pad_token_id": 0, # None in the original configuration_mistral, we set it to the unk_token_id
"bos_token_id": 1,
"eos_token_id": 2,
"image_token_id": 32_001,
"image_token_id": 99,
"tie_word_embeddings": False,
"rope_theta": 10000.0,
"sliding_window": 32,
@@ -334,6 +335,7 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
"""
all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
@@ -356,6 +358,72 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_flash_attn_2_inference_padding_right(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason="FlashAttention only supports fp16 and bf16 data types")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# overwrite because IDEFICS needs both ids and embeds to be non-None at the input
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the
# input_ids, which would cause a mismatch).
config.pad_token_id = config.eos_token_id = -1
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
input_ids = inputs_dict.pop("input_ids")
# Traditional way of generating text
outputs_from_ids = model.generate(
input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()


@@ -19,6 +19,7 @@ import gc
import unittest
from io import BytesIO
import pytest
import requests
from transformers import (
@@ -321,6 +322,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
"""
all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
@@ -343,6 +345,72 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_flash_attn_2_inference_padding_right(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason="FlashAttention only supports fp16 and bf16 data types")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# overwrite because IDEFICS needs both ids and embeds to be non-None at the input
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the
# input_ids, which would cause a mismatch).
config.pad_token_id = config.eos_token_id = -1
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
input_ids = inputs_dict.pop("input_ids")
# Traditional way of generating text
outputs_from_ids = model.generate(
input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()


@@ -4768,7 +4768,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes
fa2_model = AutoModelForCausalLM.from_config(
fa2_model = model_class._from_config(
config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(torch_device)
@@ -4789,7 +4789,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
fa2_model.save_pretrained(tmpdirname)
model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
model_from_pretrained = model_class.from_pretrained(tmpdirname)
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")