transformers/tests/models/llama4/test_modeling_llama4.py

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Llama4 model."""
import unittest
from transformers import is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch_large_accelerator,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import (
Llama4ForConditionalGeneration,
Llama4Processor,
)
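

# These integration tests are opt-in: `@slow` skips them unless RUN_SLOW=1 is set
# in the environment, `@require_torch_large_accelerator` skips them on machines
# without a sufficiently large accelerator, and `@require_read_token` skips them
# unless a Hugging Face token with access to the gated meta-llama repos is
# available.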
@slow
@require_torch_large_accelerator
@require_read_token
class Llama4IntegrationTest(unittest.TestCase):
    model_id = "meta-llama/Llama-4-Scout-17B-16E"

    @classmethod
    def setUpClass(cls):
        cls.model = Llama4ForConditionalGeneration.from_pretrained(
            cls.model_id,
            device_map="auto",
            torch_dtype=torch.float32,
            attn_implementation="eager",
        )
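
    # Note on the load above: `device_map="auto"` asks accelerate to place (and,
    # if needed, shard) the checkpoint across all visible devices, and
    # `attn_implementation="eager"` avoids backend-specific fused attention
    # kernels, which can help keep greedy outputs comparable across accelerators.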

    def setUp(self):
        self.processor = Llama4Processor.from_pretrained(self.model_id, padding_side="left")
        url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
        self.messages_1 = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": url},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        self.messages_2 = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": url},
                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    {"type": "text", "text": "Are these images identical?"},
                ],
            },
        ]
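
    # The processor's `apply_chat_template(..., tokenize=True, return_dict=True)`
    # is expected to return a batch with `input_ids` and `attention_mask`, plus
    # image tensors (e.g. `pixel_values`) for conversations that contain image
    # turns; the exact image keys are an assumption here, not part of this test.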

    def test_model_17b_16e_fp16(self):
        EXPECTED_TEXT = [
            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'
        ]  # fmt: skip
        inputs = self.processor.apply_chat_template(
            self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
        ).to(device=torch_device, dtype=self.model.dtype)

        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
        print(output_text)
        self.assertEqual(output_text, EXPECTED_TEXT)
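
    # With `do_sample=False`, generation is greedy and therefore deterministic
    # for a fixed device/dtype; exact-match assertions like the one above can
    # still need per-device expectations (e.g. CUDA vs. XPU), since kernel
    # differences may change low-probability tokens.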

    def test_model_17b_16e_batch(self):
        inputs = self.processor.apply_chat_template(
            [self.messages_1, self.messages_2],
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
        ).to(device=torch_device, dtype=torch.float32)

        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
        EXPECTED_TEXTS = [
            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white',
            'system\n\nYou are a helpful assistant.user\n\nAre these images identical?assistant\n\nNo, these images are not identical. The first image shows a cow standing on a beach with a blue sky and a white cloud in the background.'
        ]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)
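

# A minimal sketch of how these tests are typically run locally, assuming a
# transformers source checkout, the RUN_SLOW=1 opt-in, and a Hugging Face token
# with access to the gated meta-llama checkpoints:
#
#   RUN_SLOW=1 python -m pytest tests/models/llama4/test_modeling_llama4.py -v
#
# Direct invocation also works:
if __name__ == "__main__":
    unittest.main()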