Fix Mistral3 tests (#36797)

* fix processor tests

* fix modeling tests

* fix test processor chat template

* revert modeling test changes
This commit is contained in:
Yoni Gozlan 2025-03-18 13:08:12 -04:00 committed by GitHub
parent db1d4c5a0b
commit 30580f035b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 67 additions and 5 deletions

View File

@ -54,7 +54,7 @@ Here is how you can use the `image-text-to-text` pipeline to perform inference w
... },
... ]
>>> pipe = pipeline("image-text-to-text", model="../mistral3_weights", torch_dtype=torch.bfloat16)
>>> pipe = pipeline("image-text-to-text", model="mistralai/Mistral-Small-3.1-24B-Instruct-2503", torch_dtype=torch.bfloat16)
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
>>> outputs[0]["generated_text"]
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'

View File

@ -51,17 +51,75 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
processor = PixtralProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
processor = self.processor_class.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
processor.save_pretrained(self.tmpdirname)
def get_processor(self):
return self.processor_class.from_pretrained(self.tmpdirname)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_chat_template(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
def test_chat_template_accepts_processing_kwargs(self):
# override to use slow image processor to return numpy arrays
processor = self.processor_class.from_pretrained(self.tmpdirname, use_fast=False)
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
truncation=True,
max_length=50,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
truncation=True,
max_length=5,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
do_rescale=True,
rescale_factor=-1,
return_tensors="np",
)
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
def test_chat_template(self):
processor = self.processor_class.from_pretrained(self.tmpdirname, use_fast=False)
expected_prompt = "<s>[SYSTEM_PROMPT][/SYSTEM_PROMPT][INST][IMG]What is shown in this image?[/INST]"
messages = [
{
"role": "system",
"content": "",
},
{
"role": "user",
"content": [
@ -81,6 +139,10 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_token_index = 10
messages = [
{
"role": "system",
"content": "",
},
{
"role": "user",
"content": [