Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Updated image src in aria
This commit is contained in:
parent b7b9ae39e8
commit 1ccdce6bc1
@@ -52,6 +52,9 @@ if is_torch_available():
 if is_vision_available():
     from PIL import Image
 
+# Used to be https://aria-vl.github.io/static/images/view.jpg but it was removed, llava-vl has the same image
+IMAGE_OF_VIEW_URL = "https://llava-vl.github.io/static/images/view.jpg"
+
 
 class AriaVisionText2TextModelTester:
     def __init__(
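Note: the tests fetch the view image by streaming the URL through requests and handing the raw file object to Pillow. A minimal standalone sketch of that pattern, assuming network access and that requests and Pillow are installed:

import requests
from PIL import Image

# Same photo that used to live at aria-vl.github.io; llava-vl hosts an identical copy.
IMAGE_OF_VIEW_URL = "https://llava-vl.github.io/static/images/view.jpg"

# Stream the HTTP response and open the raw file-like object directly with PIL,
# the same pattern the updated tests use via the module-level constant.
raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
print(raw_image.size)  # (width, height) of the downloaded image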
@@ -265,20 +268,19 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
         model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
 
         prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
-        image_file = "https://aria-vl.github.io/static/images/view.jpg"
-        raw_image = Image.open(requests.get(image_file, stream=True).raw)
+        raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
         inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
 
-        EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
+        EXPECTED_INPUT_IDS = torch.tensor(
+            [[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]],
+        ) # fmt: skip
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
 
         output = model.generate(**inputs, max_new_tokens=20)
         EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
 
-        self.assertEqual(
-            self.processor.decode(output[0], skip_special_tokens=True),
-            EXPECTED_DECODED_TEXT,
-        )
+        decoded_output = self.processor.decode(output[0], skip_special_tokens=True)
+        self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
 
     @slow
     @require_torch_large_accelerator
@@ -291,7 +293,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
         processor = AutoProcessor.from_pretrained(model_id)
 
         prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
-        image_file = "https://aria-vl.github.io/static/images/view.jpg"
+        image_file = IMAGE_OF_VIEW_URL
         raw_image = Image.open(requests.get(image_file, stream=True).raw)
         inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
 
@@ -317,7 +319,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
             "USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
             "USER: <image>\nWhat is this? ASSISTANT:",
         ]
-        image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
+        image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
         image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
 
         inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
@@ -342,7 +344,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
             "USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
             "USER: <image>\nWhat is this?\nASSISTANT:",
         ]
-        image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
+        image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
         image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
 
         inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
@@ -373,7 +375,7 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
             "USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
             "USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
         ]
-        image1 = Image.open(requests.get("https://aria-vl.github.io/static/images/view.jpg", stream=True).raw)
+        image1 = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
         image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
 
         inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True)
@@ -382,10 +384,8 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
 
         EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
 
-        self.assertEqual(
-            processor.batch_decode(output, skip_special_tokens=True),
-            EXPECTED_DECODED_TEXT,
-        )
+        decoded_output = processor.batch_decode(output, skip_special_tokens=True)
+        self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
 
     @slow
     @require_torch_large_accelerator