Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
VLM generate: tests can't generate image/video tokens (#33623)
This commit is contained in:
parent
653eb40425
commit
2fdb5e74cc
tests/generation/test_utils.py

@@ -132,7 +132,7 @@ class GenerationTesterMixin:
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {
             "bad_words_ids": [[1, 0]],
             "repetition_penalty": 1.2,
@@ -146,6 +146,17 @@ class GenerationTesterMixin:
                     "temperature": 0.7,
                 }
             )
+        # TODO (joao, raushan): see this comment for a long-term fix
+        # https://github.com/huggingface/transformers/pull/33593#issuecomment-2361824264)
+        # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
+        # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
+        if config is not None:
+            image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
+            video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
+            if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([image_token_index])
+            if video_token_index is not None and video_token_index < config.get_text_config().vocab_size:
+                logits_processor_kwargs["bad_words_ids"].append([video_token_index])

         return logits_processor_kwargs

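To read the new guard outside diff notation: a placeholder token is only banned when the config defines it and its id actually falls inside the text vocabulary. A minimal standalone sketch of the same logic; `banned_multimodal_ids` is a hypothetical helper name for illustration, not part of the PR:

def banned_multimodal_ids(config):
    """Collect image/video placeholder token ids that fit in the text vocab."""
    vocab_size = config.get_text_config().vocab_size
    banned = []
    for attr in ("image_token_index", "video_token_index"):
        token_index = getattr(config, attr, None)  # None when the config has no such field
        if token_index is not None and token_index < vocab_size:
            banned.append([token_index])  # bad_words_ids expects a list of id sequences
    return banned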
@@ -211,7 +222,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
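Each generation helper below receives the same one-line update; only the decoding strategy each helper exercises differs. At generation time, `bad_words_ids` becomes a `NoBadWordsLogitsProcessor`, which pushes the banned ids' logits to -inf at every step so they can never be sampled. A minimal sketch of that effect, assuming `gpt2` as a stand-in checkpoint and an arbitrary id in place of `config.image_token_index`:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in; the tests use tiny random models
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")

placeholder_id = 50_000  # arbitrary stand-in for config.image_token_index
output = model.generate(
    **inputs,
    bad_words_ids=[[placeholder_id]],  # this token can now never be generated
    do_sample=False,
    max_new_tokens=8,
)
print(tokenizer.decode(output[0]))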
@@ -246,7 +257,7 @@ class GenerationTesterMixin:
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -281,7 +292,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -316,7 +327,7 @@ class GenerationTesterMixin:
         use_cache=True,
     ):
         torch.manual_seed(0)
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -350,7 +361,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -385,7 +396,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
         use_cache=True,
     ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -424,7 +435,7 @@ class GenerationTesterMixin:
             "top_k": 5,
         }

-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,
@@ -2052,6 +2063,7 @@ class GenerationTesterMixin:
         )
         self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

+    @is_flaky()  # assisted generation tests are flaky (minor fp ops differences)
     def test_assisted_decoding_with_num_logits_to_keep(self):
         for model_class in self.all_generative_model_classes:
             if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
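`is_flaky` comes from `transformers.testing_utils` and reruns a failing test a few times before letting the failure surface. A minimal sketch of what such a retry decorator does (illustrative only, not the actual implementation):

import functools

def is_flaky(max_attempts: int = 5):
    """Retry a test a few times before failing (sketch, not the real helper)."""
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return test_fn(*args, **kwargs)
                except AssertionError:
                    if attempt == max_attempts - 1:
                        raise  # out of retries: report the real failure
        return wrapper
    return decorator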
tests/models/musicgen/test_modeling_musicgen.py

@@ -300,7 +300,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

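The Musicgen testers (and the MusicgenMelody ones below) keep their override returning an empty dict: their decoders emit audio codebook tokens, so text-oriented processors such as `bad_words_ids` and `repetition_penalty` don't apply. The override only needs to accept the new `config` argument to stay signature-compatible with the base class, which is now always called with `config=model.config`. The pattern, sketched with a hypothetical tester class:

class MyAudioModelTest(GenerationTesterMixin):
    # `config` is accepted but unused: it keeps the signature in sync with
    # the base class now that every call site passes config=model.config.
    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
        return {}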
@@ -1485,7 +1485,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,

         return output_generate

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

tests/models/musicgen_melody/test_modeling_musicgen_melody.py

@@ -303,7 +303,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
         return config, input_ids, attention_mask, inputs_dict

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

@@ -1469,7 +1469,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin

         return output_generate

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs

tests/models/whisper/test_modeling_whisper.py

@@ -411,9 +411,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin

         return False

-    def _get_logits_processor_kwargs(self, do_sample=False):
+    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         # Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
-        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
+        logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample, config=config)
         logits_processor_kwargs["temperature"] = 0.0
         return logits_processor_kwargs

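Whisper shows the cooperative variant of the same override: forward everything, including the new `config` argument, to the base implementation so any guards it adds still apply, then layer the model-specific tweak on top. Sketched with a hypothetical tester class:

class WhisperLikeTest(GenerationTesterMixin):
    def _get_logits_processor_kwargs(self, do_sample=False, config=None):
        # Forward `config` so the base class can keep applying its guards.
        kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample, config=config)
        kwargs["temperature"] = 0.0  # Whisper needs this to be able to do beam search
        return kwargs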