Remove flakiness in VLMs (#36242)

* fix

* nit

* no logits processor needed

* two more tests on assisted decoding
Raushan Turganbay 2025-02-18 11:41:07 +01:00, committed by GitHub
parent fdcfdbfd22
commit e6cc410d5b

@@ -113,6 +113,10 @@ from unittest.mock import patch
 from transformers.utils import is_sklearn_available
 
 
+# TODO: raushan remove this when VLMs start accepting input embeds
+VLM_CLASS_NAMES = ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3", "gotocr2", "qwen2vl", "qwen2_5_vl"]
+
+
 class GenerationTesterMixin:
     input_name = "input_ids"
     model_tester = None
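
Note: throughout the diff, membership in `VLM_CLASS_NAMES` is decided by a case-insensitive substring match against the model class name. A minimal, self-contained sketch of that check (the model class here is just a stand-in):

    VLM_CLASS_NAMES = ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3", "gotocr2", "qwen2vl", "qwen2_5_vl"]

    class LlavaForConditionalGeneration:  # stand-in for a real transformers model class
        ...

    model_class = LlavaForConditionalGeneration
    is_vlm = any(model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES)
    print(is_vlm)  # True: "llava" is a substring of "llavaforconditionalgeneration"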
@@ -1258,6 +1262,7 @@ class GenerationTesterMixin:
                 "blip2",  # overridden `generate()`
                 "instructblip",
                 "instructblipvideo",
+                *VLM_CLASS_NAMES,  # shouldn't suggest image tokens
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1411,7 +1416,8 @@ class GenerationTesterMixin:
             "return_dict_in_generate": True,
             "use_cache": True,
         }
 
-        output_assisted = model.generate(**generation_kwargs, **inputs_dict)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
+        output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
 
         self._check_generate_outputs(output_assisted, config, use_cache=True)
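
This `_get_logits_processor_kwargs(config=...)` call is what de-flakes the test: with randomly initialized VLMs, nothing otherwise stops `generate` from sampling an image placeholder token mid-sequence, which the modeling code then rejects. Passing the config lets the helper ban such tokens via `bad_words_ids`. A rough sketch of that idea, assuming a `config.image_token_index` attribute; this is an illustration, not the library's exact implementation:

    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
        # Generation kwargs shared across tests to keep decoding well behaved (sketch).
        logits_processor_kwargs = {"bad_words_ids": [[1, 0]], "remove_invalid_values": True}
        # Band-aid for VLMs: forbid the image placeholder token so a randomly
        # initialized model cannot emit it mid-generation and crash the test.
        if config is not None and getattr(config, "image_token_index", None) is not None:
            logits_processor_kwargs["bad_words_ids"].append([config.image_token_index])
        return logits_processor_kwargs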
@@ -1690,8 +1696,7 @@ class GenerationTesterMixin:
         # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
         # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
         pixel_values_is_mutually_exclusive = any(
-            model_name in model_class.__name__.lower()
-            for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3", "gotocr2"]
+            model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
         )
         if pixel_values_is_mutually_exclusive:
             inputs_dict.pop("pixel_values", None)
@@ -1699,7 +1704,7 @@ class GenerationTesterMixin:
             inputs_dict.pop("pixel_values_images", None)
         # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
         has_complex_embeds_computation = any(
-            model_name in model_class.__name__.lower() for model_name in ["moshi", "qwen2vl", "qwen2_5_vl"]
+            model_name in model_class.__name__.lower() for model_name in ["moshi"]
         )
         # 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
         # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
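
Side effect of the two hunks above: `qwen2vl` and `qwen2_5_vl` leave the `has_complex_embeds_computation` skip list because, once their pixel inputs are popped via `VLM_CLASS_NAMES`, the `input_ids` vs `inputs_embeds` comparison runs on text only and no longer needs a model-specific escape hatch. A sketch of the text-only equivalence being checked, with hypothetical `model`/`input_ids` arguments:

    import torch

    def check_ids_vs_embeds(model, input_ids):
        # Generate once from token ids and once from the equivalent embeddings.
        inputs_embeds = model.get_input_embeddings()(input_ids)
        out_ids = model.generate(input_ids=input_ids, do_sample=False, max_new_tokens=5)
        out_embeds = model.generate(inputs_embeds=inputs_embeds, do_sample=False, max_new_tokens=5)
        # `generate` returns prompt + continuation for ids, continuation only for embeds.
        return torch.equal(out_ids[:, input_ids.shape[1]:], out_embeds)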
@@ -1769,8 +1774,7 @@ class GenerationTesterMixin:
         # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
         # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
         pixel_values_is_mutually_exclusive = any(
-            model_name in model_class.__name__.lower()
-            for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
+            model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
         )
         if pixel_values_is_mutually_exclusive:
             inputs_dict.pop("pixel_values", None)
@ -1929,8 +1933,7 @@ class GenerationTesterMixin:
self.skipTest(reason="This model doesn't return `past_key_values`")
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower()
for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
@@ -2311,11 +2314,14 @@ class GenerationTesterMixin:
             "return_dict_in_generate": True,
             "output_scores": True,
         }
+        logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
 
         # Setting logits_to_keep at 0 keeps all logits (old behavior)
-        with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0)
+        with_all_logits = model.generate(
+            **generation_kwargs, **inputs_dict, **logits_processor_kwargs, logits_to_keep=0
+        )
 
         # By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
-        without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
+        without_all_logits = model.generate(**inputs_dict, **generation_kwargs, **logits_processor_kwargs)
 
         self._check_similar_generate_outputs(with_all_logits, without_all_logits)
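
With both calls now sharing the same logits-processor kwargs, the only remaining difference between the two runs is `logits_to_keep`, so their outputs should match up to numerical noise. Roughly what a similarity check like `_check_similar_generate_outputs` has to verify; the exact tolerances here are assumptions:

    import torch

    def check_similar_outputs(left, right, atol=1e-5, rtol=1e-5):
        # Identical generated token ids...
        assert torch.equal(left.sequences, right.sequences)
        # ...and numerically close per-step scores.
        for step_left, step_right in zip(left.scores, right.scores):
            torch.testing.assert_close(step_left, step_right, atol=atol, rtol=rtol)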