# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Llama4 model."""

import unittest

from transformers import is_torch_available
from transformers.testing_utils import (
    require_read_token,
    require_torch_large_accelerator,
    slow,
    torch_device,
)


if is_torch_available():
    import torch

    from transformers import (
        Llama4ForConditionalGeneration,
        Llama4Processor,
    )


@slow
@require_torch_large_accelerator
@require_read_token
class Llama4IntegrationTest(unittest.TestCase):
    model_id = "meta-llama/Llama-4-Scout-17B-16E"

    @classmethod
    def setUpClass(cls):
        # The checkpoint is large, so load it once and share it across all tests
        # in the class.
        cls.model = Llama4ForConditionalGeneration.from_pretrained(
            cls.model_id,
            device_map="auto",
            torch_dtype=torch.float32,
            attn_implementation="eager",
        )

    def setUp(self):
        # Left padding keeps generated continuations aligned when decoding a batch.
        self.processor = Llama4Processor.from_pretrained(self.model_id, padding_side="left")
        url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
        # Single-image conversation.
        self.messages_1 = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": url},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        # Two-image conversation, used to exercise batched multi-image inputs.
        self.messages_2 = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": url},
                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    {"type": "text", "text": "Are these images identical?"},
                ],
            },
        ]

    def test_model_17b_16e_fp16(self):
        EXPECTED_TEXT = ['system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white']  # fmt: skip

        inputs = self.processor.apply_chat_template(
            self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
        ).to(device=torch_device, dtype=self.model.dtype)

        # Greedy decoding so the output is deterministic and comparable to the
        # expected transcript above.
        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
        self.assertEqual(output_text, EXPECTED_TEXT)

    def test_model_17b_16e_batch(self):
        inputs = self.processor.apply_chat_template(
            [self.messages_1, self.messages_2],
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
        ).to(device=torch_device, dtype=torch.float32)

        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
        EXPECTED_TEXTS = [
            'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white',
            'system\n\nYou are a helpful assistant.user\n\nAre these images identical?assistant\n\nNo, these images are not identical. The first image shows a cow standing on a beach with a blue sky and a white cloud in the background.',
        ]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)
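

# A minimal sketch of how these integration tests can be run locally. The exact
# path is an assumption (the file is presumed to live at
# tests/models/llama4/test_modeling_llama4.py, the usual transformers layout).
# `@slow` tests are skipped unless RUN_SLOW=1 is set in the environment, and
# `@require_read_token` needs an authenticated Hugging Face token (e.g. via
# `huggingface-cli login` or the HF_TOKEN environment variable), since the
# Llama 4 checkpoints are gated:
#
#   RUN_SLOW=1 python -m pytest tests/models/llama4/test_modeling_llama4.py -k Llama4IntegrationTest
#
# With those set, the module can also be executed directly:
if __name__ == "__main__":
    unittest.main()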