[mllama] Allow pixel_values with inputs_embeds (#38334)

* Allow pixel_values and inputs_embeds at the same time

* remove unnecessary overwritten tests
This commit is contained in:
Cory Cornelius 2025-05-27 09:33:56 -07:00 committed by GitHub
parent 0f5a8243c4
commit 9c50576860
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 0 additions and 48 deletions

View File

@ -1699,11 +1699,6 @@ class MllamaModel(MllamaPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

View File

@ -285,49 +285,6 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_config(self):
self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_resize_embeddings_results_in_successful_loss(self):
# resizing embeddings should result in successful loss computation
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()