VLM generate: tests can't generate image/video tokens (#33623)

Joao Gante 2024-09-20 15:43:27 +01:00 committed by GitHub
parent 653eb40425
commit 2fdb5e74cc
4 changed files with 26 additions and 14 deletions

tests/generation/test_utils.py

@@ -132,7 +132,7 @@ class GenerationTesterMixin:
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {
             "bad_words_ids": [[1, 0]],
             "repetition_penalty": 1.2,
@@ -146,6 +146,17 @@ class GenerationTesterMixin:
                     "temperature": 0.7,
                 }
             )
+        # TODO (joao, raushan): see this comment for a long-term fix
+        # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
+        # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
+        # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
+        if config is not None:
+            image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+            video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
+            if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([image_token_index])
+            if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([video_token_index])
         return logits_processor_kwargs
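
(Not part of the diff) Two notes on the band-aid above. The `< config.get_text_config().vocab_size` guard skips placeholder ids that live outside the language model's vocabulary, where biasing them would index past the logits tensor. The ban itself works because `bad_words_ids` is turned into a `NoBadWordsLogitsProcessor`, which pins the banned id's logit to -inf at every decoding step, so neither greedy search nor sampling can ever emit it. A minimal sketch, assuming a recent transformers version; the token id and vocab size are made up:

import torch
from transformers import NoBadWordsLogitsProcessor

image_token_index = 7  # hypothetical placeholder id, standing in for config.image_token_index
processor = NoBadWordsLogitsProcessor(bad_words_ids=[[image_token_index]], eos_token_id=None)

input_ids = torch.tensor([[2, 5, 3]])  # any prefix; single-token bans apply unconditionally
scores = torch.zeros((1, 10))          # dummy logits over a 10-token vocabulary
filtered = processor(input_ids, scores)
assert filtered[0, image_token_index] == float("-inf")  # the placeholder can never be picked
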
@@ -211,7 +222,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -246,7 +257,7 @@ class GenerationTesterMixin:
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -281,7 +292,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -316,7 +327,7 @@ class GenerationTesterMixin:
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -350,7 +361,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -385,7 +396,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -424,7 +435,7 @@ class GenerationTesterMixin:
             "top_k": 5,
         }
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -2052,6 +2063,7 @@ class GenerationTesterMixin:
         )
         self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

+    @is_flaky()  # assisted generation tests are flaky (minor fp ops differences)
     def test_assisted_decoding_with_num_logits_to_keep(self):
         for model_class in self.all_generative_model_classes:
             if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):

tests/models/musicgen/test_modeling_musicgen.py

@@ -300,7 +300,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
@@ -1485,7 +1485,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         return output_generate

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

tests/models/musicgen_melody/test_modeling_musicgen_melody.py

@@ -303,7 +303,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
@@ -1469,7 +1469,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         return output_generate

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

tests/models/whisper/test_modeling_whisper.py

@@ -411,9 +411,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         return False

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         # Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
-        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
+        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample, config=config)
         logits_processor_kwargs["temperature"] = 0.0
         return logits_processor_kwargs
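
(Not part of the diff) End to end, each test now threads `model.config` into the hook, so VLM placeholder ids land in `bad_words_ids` and are suppressed during `generate`. A hedged sketch of that plumbing with a stand-in config; attribute values are made up, real tests use `model.config`:

from types import SimpleNamespace

text_config = SimpleNamespace(vocab_size=32064)  # stand-in text config
config = SimpleNamespace(
    image_token_index=32000,                     # hypothetical placeholder id
    get_text_config=lambda: text_config,
)

bad_words_ids = [[1, 0]]  # the mixin's default entry
image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
    bad_words_ids.append([image_token_index])
print(bad_words_ids)  # [[1, 0], [32000]]
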