Update Granite Vision Model Path / Tests (#35998)

* Update granite vision model path

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>

* Enable granite vision test

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>

---------

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
This commit is contained in:
Alex Brooks 2025-02-03 12:06:03 -07:00 committed by GitHub
parent 9d2056f12b
commit e284c7e954
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 10 deletions

View File

@@ -31,13 +31,8 @@ Tips:
Sample inference:
```python
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import requests
# Note: These docs were written prior to the public model release,
# and this path is subject to change.
# Please see https://huggingface.co/ibm-granite for the current model list.
model_path = "ibm-granite/granite-3.1-2b-instruct-vision"
model_path = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_path)
model = LlavaNextForConditionalGeneration.from_pretrained(model_path).to("cuda")

View File

@@ -586,15 +586,13 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT,
)
@unittest.skip(reason="Granite multimodal [vision] models are not yet released")
@slow
def test_granite_vision(self):
"""
Check the expected output of a granite vision model, which leverages
multiple vision feature layers and a visual encoder with no CLS (siglip).
"""
# TODO @alex-jw-brooks - update the path and enable this test once the 2b model is released
granite_model_path = "llava-granite-2b"
granite_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
model = LlavaNextForConditionalGeneration.from_pretrained(granite_model_path)
self.processor = AutoProcessor.from_pretrained(granite_model_path)
prompt = "<|user|>\n<image>\nWhat is shown in this image?\n<|assistant|>\n"
@@ -602,7 +600,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# verify generation
output = model.generate(**inputs, max_new_tokens=30)
EXPECTED_DECODED_TEXT = "<|user|>\n\nWhat is shown in this image?\n<|assistant|>\nThe image depicts a diagram."
EXPECTED_DECODED_TEXT = "<|user|>\n\nWhat is shown in this image?\n<|assistant|>\nThe image displays a radar chart comparing the performance of various machine learning models." # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,