Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Idefics: enable generation tests (#34062)
* add idefics
* conflicts after merging main
* enable tests but need to fix some
* fix tests
* no print
* fix/skip some slow tests
* continue, not skip
* rebasing broke something, this is the fix
This commit is contained in:
Parent: dd4216b766
Commit: 23874f5948
@@ -726,14 +726,23 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
     elif mask_length_diff > 0:
         model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)

     # Handle cross attention models
     if "cross_attention_mask" in model_kwargs:
-        # Mllama case is special and has another mask for cross attention model
+        # Mllama case
         cross_mask = model_kwargs["cross_attention_mask"]
         if mask_length_diff < 0:
             model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
         elif mask_length_diff > 0:
             new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
             model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
+    elif "image_attention_mask" in model_kwargs:
+        # IDEFICS case
+        cross_mask = model_kwargs["image_attention_mask"]
+        if mask_length_diff < 0:
+            model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
+        elif mask_length_diff > 0:
+            new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
+            model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)

     return model_kwargs
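The IDEFICS branch above grows a 3-D `image_attention_mask` of shape `(batch, seq_len, num_images)` the same way the Mllama branch grows its 4-D cross-attention mask: repeat the mask of the last position once per newly appended token. A minimal standalone sketch of that pattern (tensor names and shapes are illustrative, not taken from the diff):

```python
import torch

# Toy image attention mask: (batch=2, seq_len=5, num_images=3)
image_attention_mask = torch.ones(2, 5, 3, dtype=torch.long)

# Suppose assisted decoding appended 4 candidate tokens to the text sequence.
mask_length_diff = 4

# Repeat the last text position's mask for each new token and append it,
# mirroring the `cross_mask[:, -1:, :].repeat(...)` line in the hunk above.
new_mask = image_attention_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
extended = torch.cat([image_attention_mask, new_mask], dim=1)

print(extended.shape)  # torch.Size([2, 9, 3])
```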
@@ -2005,6 +2005,7 @@ class GenerationMixin:
         # generating the first new token or not, and we only want to use the embeddings for the first new token)
         if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
             model_kwargs["use_cache"] = True
+            generation_config.use_cache = True
         else:
             model_kwargs["use_cache"] = generation_config.use_cache
@@ -4299,7 +4300,8 @@ class GenerationMixin:
                         newly_added_length,
                         is_decoder_attention=True,
                     )
-                else:
+                # some (V)LLMs have hard requirement on SDPA and thus never return attn
+                elif outputs.attentions[0] is not None:
                     decoder_attentions = _split_model_outputs(
                         decoder_attentions,
                         outputs.attentions,
@@ -28,12 +28,12 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss

-from ... import PreTrainedModel
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import ModelOutput
-from ...modeling_utils import PretrainedConfig
+from ...modeling_utils import PretrainedConfig, PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
@@ -622,11 +622,9 @@ class IdeficsAttention(nn.Module):
             query_states = self.q_layer_norm(query_states)
             key_states = self.k_layer_norm(key_states)

+        causal_mask = attention_mask
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -638,13 +636,13 @@ class IdeficsAttention(nn.Module):
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False

         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
-            attn_mask=attention_mask,
+            attn_mask=causal_mask,
             dropout_p=self.dropout if self.training else 0.0,
             is_causal=is_causal,
         )
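The `attn_mask`/`is_causal` swap above follows the usual SDPA contract: pass either an explicit mask or `is_causal=True`, not both, and compute the flag outside the call so the branch stays visible to torch.compile. A hedged, self-contained illustration of that dispatch (shapes are made up for the example):

```python
import torch
import torch.nn.functional as F

bsz, n_heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, kv_len, head_dim)
v = torch.randn(bsz, n_heads, kv_len, head_dim)

# No padding in this toy batch, so no explicit 4-D mask was built.
causal_mask = None

# Fall back to the fused causal kernels only when there is no explicit mask
# and more than one query position, mirroring the condition in the hunk above.
is_causal = causal_mask is None and q_len > 1

out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=is_causal)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```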
@@ -1490,7 +1488,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
         return causal_mask


-class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
+class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
     _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
@@ -1670,6 +1668,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         position_ids=None,
         pixel_values=None,
         image_hidden_states=None,
+        image_attention_mask=None,
         use_cache=None,
         cache_position=None,
         **kwargs,
@@ -1678,6 +1677,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         if past_key_values is not None:
             if input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
+            if image_attention_mask is not None:
+                image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]

         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
@@ -1696,7 +1697,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
                 model_inputs["perceiver_embeddings"] = image_hidden_states
             else:
                 model_inputs["image_encoder_embeddings"] = image_hidden_states
+            pixel_values = None
         else:
             model_inputs["pixel_values"] = pixel_values

         model_inputs.update(
             {
@@ -1706,21 +1708,13 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
                 "cache_position": cache_position,
                 "position_ids": position_ids,
                 "attention_mask": attention_mask,
                 "pixel_values": pixel_values,
-                "image_attention_mask": kwargs.get("image_attention_mask", None),
+                "image_attention_mask": image_attention_mask,
                 "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
             }
         )

         return model_inputs

-    @staticmethod
-    def _expand_inputs_for_generation(
-        *args,
-        **model_kwargs,
-    ):
-        return expand_inputs_for_generation(*args, **model_kwargs)
-
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
@@ -1738,7 +1732,10 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         if "image_attention_mask" in model_kwargs:
             image_attention_mask = model_kwargs["image_attention_mask"]
             last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-            model_kwargs["image_attention_mask"] = last_mask
+            if model_kwargs.get("use_cache", True):
+                model_kwargs["image_attention_mask"] = last_mask
+            else:
+                model_kwargs["image_attention_mask"] = torch.cat([image_attention_mask, last_mask], dim=1)

         # Get the precomputed image_hidden_states
         model_kwargs["image_hidden_states"] = outputs.image_hidden_states
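The new `use_cache` branch above keeps only the last row of the image attention mask when the KV cache is active (the model only sees the new token next step) and appends it otherwise (the model re-reads the whole sequence). A small sketch of the two behaviours, with illustrative shapes:

```python
import torch

image_attention_mask = torch.ones(2, 6, 3, dtype=torch.long)  # (batch, seq_len, num_images)
last_mask = image_attention_mask[:, -1, :].unsqueeze(1)       # (batch, 1, num_images)

use_cache = True
if use_cache:
    # Only the newly generated position is fed to the model on the next step.
    next_mask = last_mask
else:
    # Without a cache the full sequence is re-run, so the mask keeps growing.
    next_mask = torch.cat([image_attention_mask, last_mask], dim=1)

print(next_mask.shape)  # torch.Size([2, 1, 3]) with cache, torch.Size([2, 7, 3]) without
```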
@@ -1427,6 +1427,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1657,35 +1658,19 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
         past_key_values=None,
         attention_mask=None,
         inputs_embeds=None,
         cache_position=None,
+        pixel_values=None,
+        pixel_attention_mask=None,
+        image_hidden_states=None,
+        num_logits_to_keep=None,
         **kwargs,
     ):
-        past_length = 0
         # Omit tokens covered by past_key_values
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         if past_key_values is not None:
-            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
-            past_length = past_key_values.get_seq_length()
-            max_cache_length = past_key_values.get_max_cache_shape()
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            #     input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            #     input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and past_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:
+                input_ids = input_ids[:, cache_position]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
@@ -1696,21 +1681,22 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_length == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+        # but IDEFICS requires both ids and embeds to be present
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep

-        image_hidden_states = kwargs.get("image_hidden_states", None)
         if image_hidden_states is not None:
             pixel_values = None
             pixel_attention_mask = None
         else:
-            pixel_values = kwargs.get("pixel_values", None)
-            pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+            pixel_values = pixel_values
+            pixel_attention_mask = pixel_attention_mask
         model_inputs.update(
             {
                 "position_ids": position_ids,
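The rewritten `prepare_inputs_for_generation` bodies rely on `cache_position` to pick the tokens the model has not processed yet, instead of recomputing a `past_length` by hand. A hedged sketch of the indexing (the tensors below are stand-ins, not the model's real inputs):

```python
import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15]])  # full sequence kept by `generate`
cache_position = torch.tensor([4])                # only position 4 is not in the KV cache yet

# Decoding step with a cache: keep just the unprocessed token(s).
step_input_ids = input_ids[:, cache_position]
print(step_input_ids)  # tensor([[15]])

# First step with `inputs_embeds`: `input_ids` is cropped to the same window instead.
first_step_ids = input_ids[:, -cache_position.shape[0]:]
print(first_step_ids)  # tensor([[15]])
```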
@@ -24,7 +24,8 @@ from torch.nn import CrossEntropyLoss

 from ... import PreTrainedModel
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...utils import (
@@ -953,6 +954,8 @@ class Idefics3Model(Idefics3PreTrainedModel):

         past_seen_tokens = 0
         if use_cache:
+            if past_key_values is None:
+                past_key_values = DynamicCache()
             past_seen_tokens = past_key_values.get_seq_length()

         if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
@@ -1019,6 +1022,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1040,7 +1044,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
     """The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
     IDEFICS3_START_DOCSTRING,
 )
-class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
+class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
@@ -1245,35 +1249,19 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
         past_key_values=None,
         attention_mask=None,
         inputs_embeds=None,
         cache_position=None,
+        pixel_values=None,
+        pixel_attention_mask=None,
+        image_hidden_states=None,
+        num_logits_to_keep=None,
         **kwargs,
     ):
-        past_length = 0
         # Omit tokens covered by past_key_values
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         if past_key_values is not None:
-            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
-            past_length = past_key_values.get_seq_length()
-            max_cache_length = past_key_values.get_max_cache_shape()
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            #     input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            #     input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and past_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:
+                input_ids = input_ids[:, cache_position]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
@@ -1284,21 +1272,22 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_length == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+        # but IDEFICS requires both ids and embeds to be present
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep

-        image_hidden_states = kwargs.get("image_hidden_states", None)
         if image_hidden_states is not None:
             pixel_values = None
             pixel_attention_mask = None
         else:
-            pixel_values = kwargs.get("pixel_values", None)
-            pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+            pixel_values = pixel_values
+            pixel_attention_mask = pixel_attention_mask
         model_inputs.update(
             {
                 "position_ids": position_ids,
@@ -153,7 +153,11 @@ class GenerationTesterMixin:
         # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
         # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
         if config is not None:
-            image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+            image_token_index = (
+                config.image_token_index
+                if getattr(config, "image_token_index", None) is not None
+                else getattr(config, "image_token_id", None)
+            )
             video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
             if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
                 logits_processor_kwargs["bad_words_ids"].append([image_token_index])
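The tester change above falls back to `image_token_id` when a config has no `image_token_index`, then blacklists that id during test generation. A minimal sketch of how such an id ends up as a `bad_words_ids` entry for `generate` (the config object and its attribute values here are hypothetical):

```python
from types import SimpleNamespace

config = SimpleNamespace(image_token_index=None, image_token_id=99, vocab_size=100)

logits_processor_kwargs = {"bad_words_ids": []}
image_token = (
    config.image_token_index
    if getattr(config, "image_token_index", None) is not None
    else getattr(config, "image_token_id", None)
)
if image_token is not None and image_token < config.vocab_size:
    # Passing `generate(..., bad_words_ids=[[99]])` then keeps the model from
    # ever sampling the image placeholder token during the test.
    logits_processor_kwargs["bad_words_ids"].append([image_token])

print(logits_processor_kwargs)  # {'bad_words_ids': [[99]]}
```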
@@ -1496,13 +1500,14 @@ class GenerationTesterMixin:
         if "past_key_values" not in outputs:
             self.skipTest(reason="This model doesn't return `past_key_values`")

+        text_config = config.get_text_config()
         num_hidden_layers = (
-            getattr(config, "decoder_layers", None)
-            or getattr(config, "num_decoder_layers", None)
-            or config.num_hidden_layers
+            getattr(text_config, "decoder_layers", None)
+            or getattr(text_config, "num_decoder_layers", None)
+            or text_config.num_hidden_layers
         )
-        num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
-        embed_dim = getattr(config, "d_model", config.hidden_size)
+        num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
+        embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
         per_head_embed_dim = embed_dim // num_attention_heads

         past_kv = outputs["past_key_values"]
@@ -14,8 +14,10 @@
 # limitations under the License.
 """Testing suite for the PyTorch Idefics model."""

+import inspect
 import unittest

+import pytest
 from parameterized import parameterized

 from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
@@ -31,6 +33,7 @@ from transformers.testing_utils import (
 )
 from transformers.utils import cached_property

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
 from ...test_pipeline_mixin import PipelineTesterMixin
@@ -318,6 +321,12 @@ class IdeficsModelTester:
     def test_eager_matches_sdpa_inference(self, torch_dtype: str):
         self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")

+    @require_torch_sdpa
+    @slow
+    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+    def test_eager_matches_sdpa_generate(self):
+        self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
+

 @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
 @require_torch
@@ -580,8 +589,9 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)

 @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
 @require_torch
-class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
+class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
+    all_generative_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()

     def setUp(self):
         self.model_tester = IdeficsModelTester(
@@ -590,6 +600,182 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
         )
         self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)

+    @pytest.mark.generate
+    def test_left_padding_compatibility(self):
+        """Overwrite because IDEFICS needs image attention mask to be also padded"""
+        # NOTE: left-padding results in small numerical differences. This is expected.
+        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
+
+        def _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature):
+            model_kwargs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "image_attention_mask": image_attention_mask,
+            }
+            if "position_ids" in signature:
+                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                model_kwargs["position_ids"] = position_ids
+            if "cache_position" in signature:
+                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+                model_kwargs["cache_position"] = cache_position
+            return model_kwargs
+
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+            input_ids = inputs_dict.pop("input_ids")
+            attention_mask = inputs_dict.pop("attention_mask")
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids)
+            image_attention_mask = inputs_dict.pop("image_attention_mask", None)
+
+            model = model_class(config).to(torch_device).eval()
+            signature = inspect.signature(model.forward).parameters.keys()
+
+            # no cache as some models require special cache classes to be init outside forward
+            model.generation_config.use_cache = False
+
+            # Without padding
+            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature)
+            next_logits_wo_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
+
+            # With left-padding (length 32)
+            # can hardcode pad_token to be 0 as we'll do attn masking anyway
+            pad_token_id = (
+                config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
+            )
+            pad_size = (input_ids.shape[0], 32)
+            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
+            padded_input_ids = torch.cat((padding, input_ids), dim=1)
+            padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+
+            pad_size_img = (input_ids.shape[0], 32, image_attention_mask.shape[-1])
+            extra_img_mask = torch.zeros(pad_size_img, dtype=image_attention_mask.dtype, device=torch_device)
+            padded_image_attention_mask = torch.cat([extra_img_mask, image_attention_mask], dim=1)
+            model_kwargs = _prepare_model_kwargs(
+                padded_input_ids, padded_attention_mask, padded_image_attention_mask, signature
+            )
+            next_logits_with_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
+
+            # They should result in very similar logits
+            self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
+
+    @pytest.mark.generate
+    def test_generate_continue_from_past_key_values(self):
+        """Overwrite because IDEFICS needs image attention mask to be also processed"""
+
+        # Tests that we can continue generating from past key values, returned from a previous `generate` call
+        for model_class in self.all_generative_model_classes:
+            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Let's make it always:
+            # 1. use cache (for obvious reasons)
+            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
+            #    would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+            #    continuation would force it to generate beyond an EOS token)
+            # 3. ignore `token_type_ids` for simplicity
+            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+            #    active by default on some models
+            # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
+            #    we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
+            #    repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
+            #    with cache, what is considered a prompt is different in the two cases.
+
+            model = model_class(config).to(torch_device)
+            model.eval()
+            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+            model.generation_config.forced_eos_token_id = None
+            model.generation_config.encoder_no_repeat_ngram_size = 0
+            model.generation_config.use_cache = True
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
+
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+            # inputs may need to be tweaked across `generate` calls (like the attention mask).
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
+
+            # Continue from the tokens generated above, preparing the inputs accordingly
+            inputs["past_key_values"] = outputs_cached.past_key_values
+            new_attention_len = outputs_cached.sequences.shape[-1]
+            inputs["input_ids"] = outputs_cached.sequences
+            if "attention_mask" in inputs:
+                inputs["attention_mask"] = torch.nn.functional.pad(
+                    inputs["attention_mask"],
+                    (0, new_attention_len - inputs["attention_mask"].shape[1]),
+                    mode="constant",
+                    value=1,
+                )
+            if "image_attention_mask" in inputs:
+                inputs["image_attention_mask"] = inputs["image_attention_mask"][:, -1:, :]
+
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
+
+            # The two sets of generated text and past kv should be equal to each other
+            self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
+            for layer_idx in range(len(outputs_cached.past_key_values)):
+                for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs.past_key_values[layer_idx][kv_idx],
+                            outputs_cached.past_key_values[layer_idx][kv_idx],
+                        )
+                    )
+
+    @pytest.mark.generate
+    def test_generate_without_input_ids(self):
+        """Overwrite because IDEFICS needs image attention mask to be also processed and requires image at input always."""
+
+        config, input_dict = self.prepare_config_and_inputs_for_generate()
+        pixel_values = input_dict["pixel_values"]
+        image_attention_mask = input_dict["image_attention_mask"][:, -1:, :]
+
+        # hack in case they are equal, otherwise the attn mask will be [0]
+        if config.bos_token_id == config.pad_token_id:
+            config.pad_token_id = None
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+
+            output_ids_generate = model.generate(
+                pixel_values=pixel_values,
+                image_attention_mask=image_attention_mask,
+                do_sample=False,
+                max_new_tokens=self.max_new_tokens,
+                remove_invalid_values=True,
+            )
+            self.assertIsNotNone(output_ids_generate)
+
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        """
+        Overwrite from generation tests because Idefics has only SDPA layers.
+        Do not skip because we still want generation tests to run. Rather we can remove checks for shape.
+        """
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip(reason="We only test the model that takes in multiple images")
+    def test_custom_4d_attention_mask(self):
+        pass
+
+    @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
+    def test_generate_compile_fullgraph(self):
+        pass
+
     @unittest.skip(reason="We only test the model that takes in multiple images")
     def test_model(self):
         pass
@@ -19,6 +19,7 @@ import gc
 import unittest
 from io import BytesIO

+import pytest
 import requests

 from transformers import (
@@ -96,7 +97,7 @@ class Idefics2VisionText2TextModelTester:
             "pad_token_id": 0,  # None in the original configuration_mistral, we set it to the unk_token_id
             "bos_token_id": 1,
             "eos_token_id": 2,
-            "image_token_id": 32_001,
+            "image_token_id": 99,
             "tie_word_embeddings": False,
             "rope_theta": 10000.0,
             "sliding_window": 32,
@@ -334,6 +335,7 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     """

     all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_pruning = False
     test_resize_embeddings = True
@@ -356,6 +358,72 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     def test_flash_attn_2_inference_padding_right(self):
         pass

+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip(
+        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
+    )
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        pass
+
+    @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
+    def test_flash_attn_2_fp32_ln(self):
+        pass
+
+    @pytest.mark.generate
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        # overwrite because IDEFICS needs ids and embeds at the input to be not None
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            # Ignore:
+            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
+            #    which would cause a mismatch),
+            config.pad_token_id = config.eos_token_id = -1
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            input_ids = inputs_dict.pop("input_ids")
+
+            # Traditional way of generating text
+            outputs_from_ids = model.generate(
+                input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+            )
+            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+
+            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+            outputs_from_embeds = model.generate(
+                input_ids,
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
+
+            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+            # same, but the logits will almost surely be different)
+            random_embeds = torch.rand_like(inputs_embeds)
+            outputs_from_rand_embeds = model.generate(
+                input_ids,
+                inputs_embeds=random_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            for i in range(len(outputs_from_rand_embeds.scores)):
+                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
+
     # We need to override as we need to prepare such that the image token is the last token
     def test_resize_tokens_embeddings(self):
         (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
@@ -19,6 +19,7 @@ import gc
 import unittest
 from io import BytesIO

+import pytest
 import requests

 from transformers import (
@@ -321,6 +322,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     """

    all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_pruning = False
     test_resize_embeddings = True
@@ -343,6 +345,72 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     def test_flash_attn_2_inference_padding_right(self):
         pass

+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip(
+        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
+    )
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        pass
+
+    @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
+    def test_flash_attn_2_fp32_ln(self):
+        pass
+
+    @pytest.mark.generate
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        # overwrite because IDEFICS needs ids and embeds at the input to be not None
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            # Ignore:
+            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
+            #    which would cause a mismatch),
+            config.pad_token_id = config.eos_token_id = -1
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            input_ids = inputs_dict.pop("input_ids")
+
+            # Traditional way of generating text
+            outputs_from_ids = model.generate(
+                input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+            )
+            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+
+            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+            outputs_from_embeds = model.generate(
+                input_ids,
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
+
+            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+            # same, but the logits will almost surely be different)
+            random_embeds = torch.rand_like(inputs_embeds)
+            outputs_from_rand_embeds = model.generate(
+                input_ids,
+                inputs_embeds=random_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            for i in range(len(outputs_from_rand_embeds.scores)):
+                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
+
     # We need to override as we need to prepare such that the image token is the last token
     def test_resize_tokens_embeddings(self):
         (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
@@ -4768,7 +4768,7 @@ class ModelTesterMixin:

         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         # TODO: to change it in the future with other relevant auto classes
-        fa2_model = AutoModelForCausalLM.from_config(
+        fa2_model = model_class._from_config(
             config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
         ).to(torch_device)
@@ -4789,7 +4789,7 @@ class ModelTesterMixin:
         with tempfile.TemporaryDirectory() as tmpdirname:
             fa2_model.save_pretrained(tmpdirname)

-            model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
+            model_from_pretrained = model_class.from_pretrained(tmpdirname)

             self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")