diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 9b63e42946a..23190ebe851 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -113,6 +113,10 @@ from unittest.mock import patch
 from transformers.utils import is_sklearn_available
 
 
+# TODO: raushan remove this when VLMs start accepting input embeds
+VLM_CLASS_NAMES = ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3", "gotocr2", "qwen2vl", "qwen2_5_vl"]
+
+
 class GenerationTesterMixin:
     input_name = "input_ids"
     model_tester = None
@@ -1258,6 +1262,7 @@ class GenerationTesterMixin:
                 "blip2",  # overridden `generate()`
                 "instructblip",
                 "instructblipvideo",
+                *VLM_CLASS_NAMES,  # shouldn't suggest image tokens
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1411,7 +1416,8 @@ class GenerationTesterMixin:
                 "return_dict_in_generate": True,
                 "use_cache": True,
             }
-            output_assisted = model.generate(**generation_kwargs, **inputs_dict)
+            logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
+            output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
 
             self._check_generate_outputs(output_assisted, config, use_cache=True)
 
@@ -1690,8 +1696,7 @@ class GenerationTesterMixin:
             #   exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
             #   checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
             pixel_values_is_mutually_exclusive = any(
-                model_name in model_class.__name__.lower()
-                for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3", "gotocr2"]
+                model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
             )
             if pixel_values_is_mutually_exclusive:
                 inputs_dict.pop("pixel_values", None)
@@ -1699,7 +1704,7 @@ class GenerationTesterMixin:
                 inputs_dict.pop("pixel_values_images", None)
             # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
             has_complex_embeds_computation = any(
-                model_name in model_class.__name__.lower() for model_name in ["moshi", "qwen2vl", "qwen2_5_vl"]
+                model_name in model_class.__name__.lower() for model_name in ["moshi"]
             )
             # 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
             #   we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
@@ -1769,8 +1774,7 @@ class GenerationTesterMixin:
             #   exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
             #   checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
             pixel_values_is_mutually_exclusive = any(
-                model_name in model_class.__name__.lower()
-                for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
+                model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
             )
             if pixel_values_is_mutually_exclusive:
                 inputs_dict.pop("pixel_values", None)
@@ -1929,8 +1933,7 @@ class GenerationTesterMixin:
                 self.skipTest(reason="This model doesn't return `past_key_values`")
 
             pixel_values_is_mutually_exclusive = any(
-                model_name in model_class.__name__.lower()
-                for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
+                model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
             )
             if pixel_values_is_mutually_exclusive:
                 inputs_dict.pop("pixel_values", None)
@@ -2311,11 +2314,14 @@ class GenerationTesterMixin:
             "return_dict_in_generate": True,
             "output_scores": True,
         }
+        logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
 
         # Setting logits_to_keep at 0 keeps all logits (old behavior)
-        with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0)
+        with_all_logits = model.generate(
+            **generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0
+        )
         # By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
-        without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
+        without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs)
 
         self._check_similar_generate_outputs(with_all_logits, without_all_logits)