From 2fdb5e74cce41aef7d168df1dc2cc9fec348a127 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 20 Sep 2024 15:43:27 +0100
Subject: [PATCH] VLM generate: tests can't generate image/video tokens (#33623)

---
 tests/generation/test_utils.py                 | 28 +++++++++++++------
 .../models/musicgen/test_modeling_musicgen.py  |  4 +--
 .../test_modeling_musicgen_melody.py           |  4 +--
 tests/models/whisper/test_modeling_whisper.py  |  4 +--
 4 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 26ece9c25d0..08b40e71cf1 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -132,7 +132,7 @@ class GenerationTesterMixin:
 
         return config, input_ids, attention_mask, inputs_dict
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {
             "bad_words_ids": [[1, 0]],
             "repetition_penalty": 1.2,
@@ -146,6 +146,17 @@ class GenerationTesterMixin:
                     "temperature": 0.7,
                 }
             )
+        # TODO (joao, raushan): see this comment for a long-term fix
+        # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
+        # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
+        # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
+        if config is not None:
+            image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+            video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
+            if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([image_token_index])
+            if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([video_token_index])
 
         return logits_processor_kwargs
 
@@ -211,7 +222,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -246,7 +257,7 @@ class GenerationTesterMixin:
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -281,7 +292,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -316,7 +327,7 @@ class GenerationTesterMixin:
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -350,7 +361,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -385,7 +396,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -424,7 +435,7 @@ class GenerationTesterMixin:
             "top_k": 5,
         }
 
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -2052,6 +2063,7 @@ class GenerationTesterMixin:
         )
         self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
 
+    @is_flaky()  # assisted generation tests are flaky (minor fp ops differences)
     def test_assisted_decoding_with_num_logits_to_keep(self):
         for model_class in self.all_generative_model_classes:
             if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index e143b8ac3c8..a385a18b91c 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -300,7 +300,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
@@ -1485,7 +1485,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
 
         return output_generate
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index 28c2bf2f168..e8584e238d3 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -303,7 +303,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
@@ -1469,7 +1469,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
 
         return output_generate
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 4bcf4252a60..bf0746a2927 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -411,9 +411,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
 
         return False
 
-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         # Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
-        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
+        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample, config=config)
         logits_processor_kwargs["temperature"] = 0.0
         return logits_processor_kwargs
 
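Illustrative note (not part of the patch): the band-aid above works because every id appended to
`bad_words_ids` is turned into a `NoBadWordsLogitsProcessor` inside `generate`, which pushes those
ids to -inf so a randomly initialized test VLM can never sample its image/video placeholder token.
A minimal sketch of that mechanism follows; the token id and vocab size are made-up values for
illustration, not taken from the patch.

# Sketch only: shows the logits-processor mechanism the test fix relies on.
# `image_token_index` and `vocab_size` below are illustrative assumptions.
import torch
from transformers import NoBadWordsLogitsProcessor

image_token_index = 32000  # hypothetical image placeholder token id
vocab_size = 32064         # hypothetical text vocab size
processor = NoBadWordsLogitsProcessor(bad_words_ids=[[image_token_index]], eos_token_id=None)

input_ids = torch.tensor([[1, 2, 3]])  # dummy prompt
scores = torch.zeros(1, vocab_size)    # dummy next-token logits
filtered = processor(input_ids, scores)
assert filtered[0, image_token_index] == float("-inf")  # the image token can no longer be generated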