Update the test image URL in the Aria tests

This commit is contained in:
remi-or 2025-06-27 10:20:11 -05:00
parent b7b9ae39e8
commit 1ccdce6bc1

View File

@ -52,6 +52,9 @@ if is_torch_available():
if is_vision_available():
from PIL import Image
# This image used to live at https://aria-vl.github.io/static/images/view.jpg, but it was removed; llava-vl hosts the same image.
IMAGE_OF_VIEW_URL = "https://llava-vl.github.io/static/images/view.jpg"
class AriaVisionText2TextModelTester:
def __init__(
@ -265,20 +268,19 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
image_file = "https://aria-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
EXPECTED_INPUT_IDS = torch.tensor(
[[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]],
) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
decoded_output = self.processor.decode(output[0], skip_special_tokens=True)
self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
@slow
@require_torch_large_accelerator
@ -291,7 +293,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained(model_id)
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
image_file = "https://aria-vl.github.io/static/images/view.jpg"
image_file = IMAGE_OF_VIEW_URL
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
@ -317,7 +319,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
"USER: <image>\nWhat is this? ASSISTANT:",
]
image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
@ -342,7 +344,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
@ -373,7 +375,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True)
@ -382,10 +384,8 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
self.assertEqual(
processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
decoded_output = processor.batch_decode(output, skip_special_tokens=True)
self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
@slow
@require_torch_large_accelerator