diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 83a489bb13f..f9ab6fce6cf 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -857,7 +857,7 @@ class GenerationMixin:
                     self,
                     unconditional_ids=negative_prompt_ids,
                     unconditional_attention_mask=negative_prompt_attention_mask,
-                    use_cache=model_kwargs["use_cache"],
+                    use_cache=generation_config.use_cache,
                 )
             )
         if generation_config.sequence_bias is not None:
@@ -2004,10 +2004,7 @@ class GenerationMixin:
         # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
         # generating the first new token or not, and we only want to use the embeddings for the first new token)
         if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
-            model_kwargs["use_cache"] = True
             generation_config.use_cache = True
-        else:
-            model_kwargs["use_cache"] = generation_config.use_cache

         if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
@@ -2116,6 +2113,9 @@ class GenerationMixin:
             generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
         )

+        # Set model_kwargs `use_cache` so we can use it later in forward runs
+        model_kwargs["use_cache"] = generation_config.use_cache
+
         # 10. go into different generation modes
         if generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 81159ee1c0c..4dd5f36a93e 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1663,19 +1663,31 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past_key_values=None,
         attention_mask=None,
         position_ids=None,
+        inputs_embeds=None,
+        past_key_values=None,
+        cache_position=None,
         pixel_values=None,
         image_hidden_states=None,
         image_attention_mask=None,
         use_cache=None,
-        cache_position=None,
         **kwargs,
     ):
+        model_inputs = {}
+        if image_hidden_states is not None:
+            if self.config.use_resampler:
+                model_inputs["perceiver_embeddings"] = image_hidden_states
+            else:
+                model_inputs["image_encoder_embeddings"] = image_hidden_states
+        else:
+            model_inputs["pixel_values"] = pixel_values
+
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         if past_key_values is not None:
-            if input_ids.shape[1] != cache_position.shape[0]:
+            if inputs_embeds is not None:
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
             if image_attention_mask is not None:
                 image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
@@ -1690,19 +1702,17 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
                 # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

-        model_inputs = {}
-        image_hidden_states = kwargs.pop("image_hidden_states", None)
-        if image_hidden_states is not None:
-            if self.config.use_resampler:
-                model_inputs["perceiver_embeddings"] = image_hidden_states
-            else:
-                model_inputs["image_encoder_embeddings"] = image_hidden_states
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
         else:
-            model_inputs["pixel_values"] = pixel_values
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs.update(
+                {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+            )

         model_inputs.update(
             {
-                "input_ids": input_ids,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "cache_position": cache_position,
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 62b6ca22293..250c47c3a7e 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -772,6 +772,12 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
     def test_custom_4d_attention_mask(self):
         pass

+    @unittest.skip(
+        reason="IDEFICS has specific requirements for working with inputs embeds, such as also passing the ids and pixels"
+    )
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        pass
+
     @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
     def test_generate_compile_fullgraph(self):
         pass
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index f87e87607c2..4071fcbb232 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -539,6 +539,31 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
         # Check that the model can still do a forward pass successfully (every parameter should be resized)
         model(**self._prepare_for_class(inputs_dict, model_class))

+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # overwrite because IDEFICS needs both ids and embeds to be non-None at the input
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
+
+            wte = model.get_input_embeddings()
+
+            input_ids = inputs["input_ids"]
+            # some models infer position ids/attn mask differently by checking for the pad token in input ids,
+            # so let's make sure no padding is left in the input ids
+            not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
+            input_ids[input_ids == pad_token_id] = not_pad_token_id
+            del inputs["input_ids"]
+            inputs_embeds = wte(input_ids)
+            out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
+            out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+

 @require_torch
 class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py
index 44e06b07c54..f0366e7b539 100644
--- a/tests/models/idefics3/test_modeling_idefics3.py
+++ b/tests/models/idefics3/test_modeling_idefics3.py
@@ -526,6 +526,31 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
         # Check that the model can still do a forward pass successfully (every parameter should be resized)
         model(**self._prepare_for_class(inputs_dict, model_class))

+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # overwrite because IDEFICS needs both ids and embeds to be non-None at the input
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
+
+            wte = model.get_input_embeddings()
+
+            input_ids = inputs["input_ids"]
+            # some models infer position ids/attn mask differently by checking for the pad token in input ids,
+            # so let's make sure no padding is left in the input ids
+            not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
+            input_ids[input_ids == pad_token_id] = not_pad_token_id
+            del inputs["input_ids"]
+            inputs_embeds = wte(input_ids)
+            out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
+            out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+

 @require_torch
 class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py
index 396a4179388..22cbffcfdb6 100644
--- a/tests/models/kosmos2/test_modeling_kosmos2.py
+++ b/tests/models/kosmos2/test_modeling_kosmos2.py
@@ -428,6 +428,12 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+    @unittest.skip(
+        "KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten"
+    )
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         model_name = "microsoft/kosmos-2-patch14-224"
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 38c1f5ff177..da33bbb48c5 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3000,8 +3000,11 @@ class ModelTesterMixin:
     def test_inputs_embeds_matches_input_ids_with_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-        for model_class in self.all_generative_model_classes:
-            if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in [
+                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
+                *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
+            ]:
                 continue
             model = model_class(config)
             model.to(torch_device)
@@ -3018,6 +3021,13 @@ class ModelTesterMixin:
             inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
             pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1

+            # VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
+            # embeds already
+            if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
+                inputs.pop("pixel_values", None)
+                inputs.pop("pixel_values_videos", None)
+                inputs.pop("pixel_values_images", None)
+
             wte = model.get_input_embeddings()
             if not self.is_encoder_decoder:
                 input_ids = inputs["input_ids"]
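For reference, here is a minimal sketch (not part of the patch) of the behaviour the new `test_inputs_embeds_matches_input_ids_with_generate` test asserts: generating from `inputs_embeds` while also passing the matching `input_ids` should produce the same tokens as generating from `input_ids` alone. The checkpoint and prompt below are placeholders chosen for illustration, not values used anywhere in the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any decoder-only causal LM should behave the same way.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

input_ids = tokenizer("An image of two cats", return_tensors="pt").input_ids
# Build the same embeddings the model would compute internally from these ids.
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    out_ids = model.generate(input_ids=input_ids, do_sample=False, max_new_tokens=5)
    out_embeds = model.generate(
        input_ids=input_ids, inputs_embeds=inputs_embeds, do_sample=False, max_new_tokens=5
    )

# Both calls should return identical sequences: the prompt ids followed by the same new tokens.
assert torch.equal(out_ids, out_embeds)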