diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index a4c8f79ae92..1e4d7a47024 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -726,14 +726,23 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
     elif mask_length_diff > 0:
         model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
 
+    # Handle cross attention models
     if "cross_attention_mask" in model_kwargs:
-        # Mllama case is special and has another mask for cross attention model
+        # Mllama case
         cross_mask = model_kwargs["cross_attention_mask"]
         if mask_length_diff < 0:
             model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
         elif mask_length_diff > 0:
             new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
             model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
+    elif "image_attention_mask" in model_kwargs:
+        # IDEFICS case
+        cross_mask = model_kwargs["image_attention_mask"]
+        if mask_length_diff < 0:
+            model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
+        elif mask_length_diff > 0:
+            new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
+            model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
 
     return model_kwargs
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 68b8b598ec0..09be2f6bc22 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2005,6 +2005,7 @@ class GenerationMixin:
         # generating the first new token or not, and we only want to use the embeddings for the first new token)
         if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
             model_kwargs["use_cache"] = True
+            generation_config.use_cache = True
         else:
             model_kwargs["use_cache"] = generation_config.use_cache
 
@@ -4299,7 +4300,8 @@ class GenerationMixin:
                             newly_added_length,
                             is_decoder_attention=True,
                         )
-                    else:
+                    # some (V)LLMs have a hard requirement on SDPA and thus never return attn
+                    elif outputs.attentions[0] is not None:
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.attentions,
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 02de8d61ae2..81159ee1c0c 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -28,12 +28,12 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from ... import PreTrainedModel
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import ModelOutput
-from ...modeling_utils import PretrainedConfig
+from ...modeling_utils import PretrainedConfig, PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
     add_start_docstrings,
@@ -622,11 +622,9 @@ class IdeficsAttention(nn.Module):
             query_states = self.q_layer_norm(query_states)
             key_states = self.k_layer_norm(key_states)
 
+        causal_mask = attention_mask
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -638,13 +636,13 @@ class IdeficsAttention(nn.Module):
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-        is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False
+        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
-            attn_mask=attention_mask,
+            attn_mask=causal_mask,
             dropout_p=self.dropout if self.training else 0.0,
             is_causal=is_causal,
         )
@@ -1490,7 +1488,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
         return causal_mask
 
 
-class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
+class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
     _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
 
@@ -1670,6 +1668,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         position_ids=None,
         pixel_values=None,
         image_hidden_states=None,
+        image_attention_mask=None,
         use_cache=None,
         cache_position=None,
         **kwargs,
@@ -1678,6 +1677,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         if past_key_values is not None:
             if input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
+            if image_attention_mask is not None:
+                image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
 
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
@@ -1696,7 +1697,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
                 model_inputs["perceiver_embeddings"] = image_hidden_states
             else:
                 model_inputs["image_encoder_embeddings"] = image_hidden_states
-            pixel_values = None
+        else:
+            model_inputs["pixel_values"] = pixel_values
 
         model_inputs.update(
             {
@@ -1706,21 +1708,13 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "cache_position": cache_position,
                 "position_ids": position_ids,
                 "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "image_attention_mask": kwargs.get("image_attention_mask", None),
+                "image_attention_mask": image_attention_mask,
                 "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
             }
         )
         return model_inputs
 
-    @staticmethod
-    def _expand_inputs_for_generation(
-        *args,
-        **model_kwargs,
-    ):
-        return expand_inputs_for_generation(*args, **model_kwargs)
-
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
@@ -1738,7 +1732,10 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         if "image_attention_mask" in model_kwargs:
             image_attention_mask = model_kwargs["image_attention_mask"]
             last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
-            model_kwargs["image_attention_mask"] = last_mask
+            if model_kwargs.get("use_cache", True):
+                model_kwargs["image_attention_mask"] = last_mask
+            else:
+                model_kwargs["image_attention_mask"] = torch.cat([image_attention_mask, last_mask], dim=1)
 
         # Get the precomputed image_hidden_states
         model_kwargs["image_hidden_states"] = outputs.image_hidden_states
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index b53d0722587..d34e0acde4c 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -1427,6 +1427,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1657,35 +1658,19 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
         past_key_values=None,
         attention_mask=None,
         inputs_embeds=None,
+        cache_position=None,
+        pixel_values=None,
+        pixel_attention_mask=None,
+        image_hidden_states=None,
         num_logits_to_keep=None,
         **kwargs,
     ):
-        past_length = 0
-        # Omit tokens covered by past_key_values
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         if past_key_values is not None:
-            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
-            past_length = past_key_values.get_seq_length()
-            max_cache_length = past_key_values.get_max_cache_shape()
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            #     input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            #     input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and past_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:
+                input_ids = input_ids[:, cache_position]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
@@ -1696,21 +1681,22 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_length == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+        # but IDEFICS requires both ids and embeds to be present
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if num_logits_to_keep is not None:
             model_inputs["num_logits_to_keep"] = num_logits_to_keep
 
-        image_hidden_states = kwargs.get("image_hidden_states", None)
         if image_hidden_states is not None:
             pixel_values = None
             pixel_attention_mask = None
         else:
-            pixel_values = kwargs.get("pixel_values", None)
-            pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+            pixel_values = pixel_values
+            pixel_attention_mask = pixel_attention_mask
 
         model_inputs.update(
             {
                 "position_ids": position_ids,
diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index 757391175ea..e653fd3d2a6 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -24,7 +24,8 @@ from torch.nn import CrossEntropyLoss
 
 from ... import PreTrainedModel
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...utils import (
@@ -953,6 +954,8 @@ class Idefics3Model(Idefics3PreTrainedModel):
 
         past_seen_tokens = 0
         if use_cache:
+            if past_key_values is None:
+                past_key_values = DynamicCache()
             past_seen_tokens = past_key_values.get_seq_length()
 
         if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
@@ -1019,6 +1022,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1040,7 +1044,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
     """The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
""", IDEFICS3_START_DOCSTRING, ) -class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel): +class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3 @@ -1245,35 +1249,19 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel): past_key_values=None, attention_mask=None, inputs_embeds=None, + cache_position=None, + pixel_values=None, + pixel_attention_mask=None, + image_hidden_states=None, num_logits_to_keep=None, **kwargs, ): - past_length = 0 - # Omit tokens covered by past_key_values + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens if past_key_values is not None: - # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = past_key_values.get_seq_length() - max_cache_length = past_key_values.get_max_cache_shape() - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and past_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: + input_ids = input_ids[:, cache_position] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1284,21 +1272,22 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel): position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_length == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + # but IDEFICS requires noth ids and embeds to be present + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. 
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
         if num_logits_to_keep is not None:
             model_inputs["num_logits_to_keep"] = num_logits_to_keep
 
-        image_hidden_states = kwargs.get("image_hidden_states", None)
         if image_hidden_states is not None:
             pixel_values = None
             pixel_attention_mask = None
         else:
-            pixel_values = kwargs.get("pixel_values", None)
-            pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+            pixel_values = pixel_values
+            pixel_attention_mask = pixel_attention_mask
 
         model_inputs.update(
             {
                 "position_ids": position_ids,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 5d92e8ce216..a1bc5265667 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -153,7 +153,11 @@ class GenerationTesterMixin:
         # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
         # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
         if config is not None:
-            image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+            image_token_index = (
+                config.image_token_index
+                if getattr(config, "image_token_index", None) is not None
+                else getattr(config, "image_token_id", None)
+            )
             video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
             if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
                 logits_processor_kwargs["bad_words_ids"].append([image_token_index])
@@ -1496,13 +1500,14 @@ class GenerationTesterMixin:
         if "past_key_values" not in outputs:
             self.skipTest(reason="This model doesn't return `past_key_values`")
 
+        text_config = config.get_text_config()
         num_hidden_layers = (
-            getattr(config, "decoder_layers", None)
-            or getattr(config, "num_decoder_layers", None)
-            or config.num_hidden_layers
+            getattr(text_config, "decoder_layers", None)
+            or getattr(text_config, "num_decoder_layers", None)
+            or text_config.num_hidden_layers
         )
-        num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
-        embed_dim = getattr(config, "d_model", config.hidden_size)
+        num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
+        embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
         per_head_embed_dim = embed_dim // num_attention_heads
 
         past_kv = outputs["past_key_values"]
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index a49bce8d878..62b6ca22293 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -14,8 +14,10 @@
 # limitations under the License.
"""Testing suite for the PyTorch Idefics model.""" +import inspect import unittest +import pytest from parameterized import parameterized from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available @@ -31,6 +33,7 @@ from transformers.testing_utils import ( ) from transformers.utils import cached_property +from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_pipeline_mixin import PipelineTesterMixin @@ -318,6 +321,12 @@ class IdeficsModelTester: def test_eager_matches_sdpa_inference(self, torch_dtype: str): self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test") + @require_torch_sdpa + @slow + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + def test_eager_matches_sdpa_generate(self): + self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test") + @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch @@ -580,8 +589,9 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch -class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase): +class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase): all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else () + all_generative_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else () def setUp(self): self.model_tester = IdeficsModelTester( @@ -590,6 +600,182 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase): ) self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37) + @pytest.mark.generate + def test_left_padding_compatibility(self): + """Overwrite because IDEFICS needs image attention mask to be also padded""" + # NOTE: left-padding results in small numerical differences. This is expected. 
+        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
+
+        def _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature):
+            model_kwargs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "image_attention_mask": image_attention_mask,
+            }
+            if "position_ids" in signature:
+                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                model_kwargs["position_ids"] = position_ids
+            if "cache_position" in signature:
+                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+                model_kwargs["cache_position"] = cache_position
+            return model_kwargs
+
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+            input_ids = inputs_dict.pop("input_ids")
+            attention_mask = inputs_dict.pop("attention_mask")
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids)
+            image_attention_mask = inputs_dict.pop("image_attention_mask", None)
+
+            model = model_class(config).to(torch_device).eval()
+            signature = inspect.signature(model.forward).parameters.keys()
+
+            # no cache as some models require special cache classes to be init outside forward
+            model.generation_config.use_cache = False
+
+            # Without padding
+            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature)
+            next_logits_wo_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
+
+            # With left-padding (length 32)
+            # can hardcode pad_token to be 0 as we'll do attn masking anyway
+            pad_token_id = (
+                config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
+            )
+            pad_size = (input_ids.shape[0], 32)
+            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
+            padded_input_ids = torch.cat((padding, input_ids), dim=1)
+            padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+
+            pad_size_img = (input_ids.shape[0], 32, image_attention_mask.shape[-1])
+            extra_img_mask = torch.zeros(pad_size_img, dtype=image_attention_mask.dtype, device=torch_device)
+            padded_image_attention_mask = torch.cat([extra_img_mask, image_attention_mask], dim=1)
+            model_kwargs = _prepare_model_kwargs(
+                padded_input_ids, padded_attention_mask, padded_image_attention_mask, signature
+            )
+            next_logits_with_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
+
+            # They should result in very similar logits
+            self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
+
+    @pytest.mark.generate
+    def test_generate_continue_from_past_key_values(self):
+        """Overwrite because IDEFICS needs image attention mask to be also processed"""
+
+        # Tests that we can continue generating from past key values, returned from a previous `generate` call
+        for model_class in self.all_generative_model_classes:
+            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Let's make it always:
+            # 1. use cache (for obvious reasons)
+            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
+            #    would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
+            #    continuation would force it to generate beyond an EOS token)
+            # 3. ignore `token_type_ids` for simplicity
+            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
+            #    active by default on some models
+            # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
+            #    we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
+            #    repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
+            #    with cache, what is considered a prompt is different in the two cases.
+
+            model = model_class(config).to(torch_device)
+            model.eval()
+            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+            model.generation_config.forced_eos_token_id = None
+            model.generation_config.encoder_no_repeat_ngram_size = 0
+            model.generation_config.use_cache = True
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
+
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+            # inputs may need to be tweaked across `generate` calls (like the attention mask).
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
+
+            # Continue from the tokens generated above, preparing the inputs accordingly
+            inputs["past_key_values"] = outputs_cached.past_key_values
+            new_attention_len = outputs_cached.sequences.shape[-1]
+            inputs["input_ids"] = outputs_cached.sequences
+            if "attention_mask" in inputs:
+                inputs["attention_mask"] = torch.nn.functional.pad(
+                    inputs["attention_mask"],
+                    (0, new_attention_len - inputs["attention_mask"].shape[1]),
+                    mode="constant",
+                    value=1,
+                )
+            if "image_attention_mask" in inputs:
+                inputs["image_attention_mask"] = inputs["image_attention_mask"][:, -1:, :]
+
+            outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
+
+            # The two sets of generated text and past kv should be equal to each other
+            self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
+            for layer_idx in range(len(outputs_cached.past_key_values)):
+                for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs.past_key_values[layer_idx][kv_idx],
+                            outputs_cached.past_key_values[layer_idx][kv_idx],
+                        )
+                    )
+
+    @pytest.mark.generate
+    def test_generate_without_input_ids(self):
+        """Overwrite because IDEFICS needs image attention mask to be also processed and requires image at input always."""
+
+        config, input_dict = self.prepare_config_and_inputs_for_generate()
+        pixel_values = input_dict["pixel_values"]
+        image_attention_mask = input_dict["image_attention_mask"][:, -1:, :]
+
+        # hack in case they are equal, otherwise the attn mask will be [0]
+        if config.bos_token_id == config.pad_token_id:
+            config.pad_token_id = None
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config).to(torch_device)
+            model.eval()
+
+            output_ids_generate = model.generate(
+                pixel_values=pixel_values,
+                image_attention_mask=image_attention_mask,
+                do_sample=False,
+                max_new_tokens=self.max_new_tokens,
+                remove_invalid_values=True,
+            )
+            self.assertIsNotNone(output_ids_generate)
+
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        """
+        Overwrite from generation tests because Idefics has only SDPA layers.
+        Do not skip because we still want generation tests to run. Rather we can remove checks for shape.
+        """
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip(reason="We only test the model that takes in multiple images")
+    def test_custom_4d_attention_mask(self):
+        pass
+
+    @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
+    def test_generate_compile_fullgraph(self):
+        pass
+
     @unittest.skip(reason="We only test the model that takes in multiple images")
     def test_model(self):
         pass
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index e02c5b4c9f0..f87e87607c2 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -19,6 +19,7 @@ import gc
 import unittest
 from io import BytesIO
 
+import pytest
 import requests
 
 from transformers import (
@@ -96,7 +97,7 @@ class Idefics2VisionText2TextModelTester:
             "pad_token_id": 0,  # None in the original configuration_mistral, we set it to the unk_token_id
             "bos_token_id": 1,
             "eos_token_id": 2,
-            "image_token_id": 32_001,
+            "image_token_id": 99,
             "tie_word_embeddings": False,
             "rope_theta": 10000.0,
             "sliding_window": 32,
@@ -334,6 +335,7 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     """
 
     all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_pruning = False
     test_resize_embeddings = True
@@ -356,6 +358,72 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip(
+        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
+    )
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        pass
+
+    @unittest.skip(reason="FlashAttention only supports fp16 and bf16 data types")
+    def test_flash_attn_2_fp32_ln(self):
+        pass
+
+    @pytest.mark.generate
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        # overwrite because IDEFICS needs ids and embeds at the input to be not None
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            # Ignore:
+            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
+            #    which would cause a mismatch),
+            config.pad_token_id = config.eos_token_id = -1
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            input_ids = inputs_dict.pop("input_ids")
+
+            # Traditional way of generating text
+            outputs_from_ids = model.generate(
+                input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+            )
+            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+
+            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+            outputs_from_embeds = model.generate(
+                input_ids,
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
+
+            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+            # same, but the logits will almost surely be different)
+            random_embeds = torch.rand_like(inputs_embeds)
+            outputs_from_rand_embeds = model.generate(
+                input_ids,
+                inputs_embeds=random_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            for i in range(len(outputs_from_rand_embeds.scores)):
+                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
+
     # We need to override as we need to prepare such that the image token is the last token
     def test_resize_tokens_embeddings(self):
         (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py
index 550bb2785e0..44e06b07c54 100644
--- a/tests/models/idefics3/test_modeling_idefics3.py
+++ b/tests/models/idefics3/test_modeling_idefics3.py
@@ -19,6 +19,7 @@ import gc
 import unittest
 from io import BytesIO
 
+import pytest
 import requests
 
 from transformers import (
@@ -321,6 +322,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     """
 
     all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
+    all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_pruning = False
     test_resize_embeddings = True
@@ -343,6 +345,72 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     def test_flash_attn_2_inference_padding_right(self):
         pass
 
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip(
+        reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
+    )
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+        pass
+
+    @unittest.skip(reason="FlashAttention only supports fp16 and bf16 data types")
+    def test_flash_attn_2_fp32_ln(self):
+        pass
+
+    @pytest.mark.generate
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        # overwrite because IDEFICS needs ids and embeds at the input to be not None
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            # Ignore:
+            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
+            #    which would cause a mismatch),
+            config.pad_token_id = config.eos_token_id = -1
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            input_ids = inputs_dict.pop("input_ids")
+
+            # Traditional way of generating text
+            outputs_from_ids = model.generate(
+                input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+            )
+            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+
+            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+            outputs_from_embeds = model.generate(
+                input_ids,
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
+
+            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+            # same, but the logits will almost surely be different)
+            random_embeds = torch.rand_like(inputs_embeds)
+            outputs_from_rand_embeds = model.generate(
+                input_ids,
+                inputs_embeds=random_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            for i in range(len(outputs_from_rand_embeds.scores)):
+                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
+
     # We need to override as we need to prepare such that the image token is the last token
     def test_resize_tokens_embeddings(self):
         (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index fa4a35391ba..38c1f5ff177 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4768,7 +4768,7 @@ class ModelTesterMixin:
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
 
             # TODO: to change it in the future with other relevant auto classes
-            fa2_model = AutoModelForCausalLM.from_config(
+            fa2_model = model_class._from_config(
                 config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
             ).to(torch_device)
 
@@ -4789,7 +4789,7 @@ class ModelTesterMixin:
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 fa2_model.save_pretrained(tmpdirname)
-                model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
+                model_from_pretrained = model_class.from_pretrained(tmpdirname)
 
                 self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")
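
Illustrative sketch (not part of the patch): the snippet below mirrors the new `elif "image_attention_mask" in model_kwargs` branch added to `_prepare_attention_mask` above, to show how the IDEFICS cross-attention mask is expected to be trimmed or extended when candidate/assisted generation changes the sequence length. The `(batch, seq_len, num_images)` layout is inferred from the slicing and `repeat` calls in that hunk, and the helper name `resize_image_attention_mask` is hypothetical.

import torch


def resize_image_attention_mask(image_attention_mask: torch.Tensor, new_length: int) -> torch.Tensor:
    # Same resizing logic as the IDEFICS branch in `_prepare_attention_mask`.
    mask_length_diff = new_length - image_attention_mask.shape[1]
    if mask_length_diff < 0:
        # Drop trailing positions, mirroring `cross_mask[:, :mask_length_diff]`.
        return image_attention_mask[:, :mask_length_diff]
    if mask_length_diff > 0:
        # Repeat the last row once per newly added position and append it.
        new_mask = image_attention_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
        return torch.cat([image_attention_mask, new_mask], dim=1)
    return image_attention_mask


mask = torch.ones(2, 5, 3, dtype=torch.long)  # batch=2, seq_len=5, num_images=3
print(resize_image_attention_mask(mask, 7).shape)  # torch.Size([2, 7, 3])
print(resize_image_attention_mask(mask, 4).shape)  # torch.Size([2, 4, 3])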