# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Llava-NeXT model."""

import copy
import unittest

import requests
from huggingface_hub import hf_hub_download
from parameterized import parameterized

from transformers import (
    AutoProcessor,
    LlavaNextConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextModel,
    is_torch_available,
    is_vision_available,
)
from transformers.testing_utils import (
    cleanup,
    require_bitsandbytes,
    require_torch,
    slow,
    torch_device,
)
from transformers.utils import check_torch_load_is_safe

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
)


if is_torch_available():
    import torch

    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches


if is_vision_available():
    from PIL import Image


class LlavaNextVisionText2TextModelTester:
    """Builds a small `LlavaNextConfig` and matching random inputs for the common model tests."""

    def __init__(
        self,
        parent,
        ignore_index=-100,
        image_token_index=0,
        projector_hidden_act="gelu",
        seq_length=7,
        vision_feature_select_strategy="default",
        vision_feature_layer=-1,
        text_config=None,
        is_training=True,
        vision_config=None,
    ):
        """
        Store the test hyper-parameters.

        `text_config` and `vision_config` default to `None`, in which case a fresh dict of small default
        values is created for this instance (a dict literal in the signature would be shared, and possibly
        mutated, across every tester instance).
        """
        if text_config is None:
            text_config = {
                "model_type": "llama",
                "seq_length": 7,
                "is_training": True,
                "use_input_mask": True,
                "use_token_type_ids": False,
                "use_labels": True,
                "vocab_size": 99,
                "hidden_size": 32,
                "num_hidden_layers": 2,
                "num_attention_heads": 4,
                "intermediate_size": 37,
                "hidden_act": "gelu",
                "hidden_dropout_prob": 0.1,
                "attention_probs_dropout_prob": 0.1,
                "max_position_embeddings": 580,
                "type_vocab_size": 16,
                "type_sequence_label_size": 2,
                "initializer_range": 0.02,
                "num_labels": 3,
                "num_choices": 4,
                "pad_token_id": 1,
            }
        if vision_config is None:
            vision_config = {
                "image_size": 8,
                "patch_size": 4,
                "num_channels": 3,
                "is_training": True,
                "hidden_size": 32,
                "projection_dim": 32,
                "num_hidden_layers": 2,
                "num_attention_heads": 4,
                "intermediate_size": 37,
                "dropout": 0.1,
                "attention_dropout": 0.1,
                "initializer_range": 0.02,
            }

        self.parent = parent
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.text_config = text_config
        self.vision_config = vision_config
        self.pad_token_id = text_config["pad_token_id"]

        self.num_hidden_layers = text_config["num_hidden_layers"]
        self.vocab_size = text_config["vocab_size"]
        self.hidden_size = text_config["hidden_size"]
        self.num_attention_heads = text_config["num_attention_heads"]
        self.is_training = is_training

        self.batch_size = 3
        self.num_channels = 3
        self.image_size = 30
        self.image_grid_pinpoints = [[16, 16]]
        self.num_image_tokens = 24
        # the image tokens are prepended to the text tokens (see `prepare_config_and_inputs_for_common`)
        self.seq_length = seq_length + self.num_image_tokens
        self.encoder_seq_length = self.seq_length

    def get_config(self):
        """Return a `LlavaNextConfig` built from the stored hyper-parameters."""
        return LlavaNextConfig(
            text_config=self.text_config,
            vision_config=self.vision_config,
            ignore_index=self.ignore_index,
            image_token_index=self.image_token_index,
            projector_hidden_act=self.projector_hidden_act,
            vision_feature_select_strategy=self.vision_feature_select_strategy,
            vision_feature_layer=self.vision_feature_layer,
            image_grid_pinpoints=self.image_grid_pinpoints,
            image_seq_length=self.num_image_tokens,
        )

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values)` where `pixel_values` is a random 5-D float tensor."""
        pixel_values = floats_tensor(
            [
                self.batch_size,
                5,
                self.vision_config["num_channels"],
                self.vision_config["image_size"],
                self.vision_config["image_size"],
            ]
        )
        config = self.get_config()

        return config, pixel_values

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` with pixel values, image sizes, input ids and attention mask."""
        config, pixel_values = self.prepare_config_and_inputs()

        # sample ids from [2, vocab_size) and remap any accidental image-token id to the pad id, so that
        # the only image tokens are the `num_image_tokens` ones written at the start of every sequence
        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
        input_ids[input_ids == config.image_token_index] = self.pad_token_id
        input_ids[:, : self.num_image_tokens] = config.image_token_index

        inputs_dict = {
            "pixel_values": pixel_values,
            "image_sizes": torch.tensor(
                [[self.vision_config["image_size"], self.vision_config["image_size"]]] * self.batch_size
            ),
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_torch
class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """
    Model tester for `LlavaNextForConditionalGeneration`.
    """

    all_model_classes = (
        (
            LlavaNextModel,
            LlavaNextForConditionalGeneration,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = {"image-text-to-text": LlavaNextForConditionalGeneration} if is_torch_available() else {}
    test_pruning = False
    test_head_masking = False
    _is_composite = True

    def setUp(self):
        self.model_tester = LlavaNextVisionText2TextModelTester(self)
        common_properties = ["image_token_index", "vision_feature_layer", "image_seq_length"]
        self.config_tester = ConfigTester(
            self, config_class=LlavaNextConfig, has_text_modality=False, common_properties=common_properties
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_initialization(self):
        """Check that every trainable parameter except `image_newline` has a mean of 0.0 or 1.0."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if "image_newline" in name:
                    continue
                elif param.requires_grad:
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    # while some other models require pixel_values to be present
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            torch.testing.assert_close(out_embeds, out_ids)

    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs throw an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        Also we need to test multi-image cases when one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            curr_input_dict = copy.deepcopy(input_dict)  # the dict is modified in place further down
            _ = model(**curr_input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in text
            curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
            curr_input_dict["image_sizes"] = curr_input_dict["image_sizes"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**curr_input_dict)

            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
            input_ids = curr_input_dict["input_ids"][:1]
            pixel_values = curr_input_dict["pixel_values"][:1]
            image_sizes = curr_input_dict["image_sizes"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

    def test_odd_sized_image(self):
        """Run a forward pass where `image_sizes` reports a non-square, odd-sized image."""
        # prepare model configuration
        config = self.model_tester.get_config()

        # prepare input
        num_image_tokens = 24
        pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
        input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
        input_ids[:, :num_image_tokens] = config.image_token_index
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
        inputs_dict = {
            "pixel_values": pixel_values,
            "image_sizes": torch.tensor([[13, 16]]),  # odd-sized image
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

        # forward with odd-sized image input
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model(**inputs_dict)

    @parameterized.expand(
        [
            (-1,),
            ([-1],),
            ([-1, -2],),
        ],
    )
    def test_vision_feature_layers(self, vision_feature_layer):
        """
        Test that we can use either one vision feature layer, or a list of
        vision feature layers.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.vision_feature_layer = vision_feature_layer

        num_feature_layers = 1 if isinstance(vision_feature_layer, int) else len(vision_feature_layer)
        hidden_size = config.vision_config.hidden_size
        expected_features = hidden_size * num_feature_layers

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            # We should have the right number of input features,
            # and should be able to run a forward pass without exploding
            base_model = getattr(model, "model", model)
            # unittest assertion instead of a bare `assert`, which is stripped when Python runs with -O
            self.assertEqual(base_model.multi_modal_projector.linear_1.in_features, expected_features)
            model(**input_dict)

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
    def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
        pass


@require_torch
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
    """
    Slow integration tests that run released checkpoints on real images.

    NOTE(review): the `EXPECTED_DECODED_TEXT` strings in this class are kept verbatim; confirm the whitespace
    they contain where special tokens are skipped during decoding against an actual run.
    """

    def setUp(self):
        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
        self.image = Image.open(requests.get(url, stream=True).raw)

        # the `<image>` placeholder marks where the processor puts the image tokens
        self.prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test(self):
        """Check processor outputs against the original implementation, then check generation."""
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            load_in_4bit=True,
        )

        inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt").to(torch_device)

        # verify inputs against original implementation
        filepath = hf_hub_download(
            repo_id="nielsr/test-image",
            filename="llava_1_6_input_ids.pt",
            repo_type="dataset",
        )
        check_torch_load_is_safe()
        original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
        # the original ids use -200 for the image, while the HF impl expands image tokens
        # `image_seq_length` times: drop both kinds of placeholder and compare the remaining tokens
        original_input_ids = original_input_ids[original_input_ids != -200]
        observed_input_ids = inputs.input_ids[inputs.input_ids != model.config.image_token_index]
        # boolean-mask indexing returns 1-D tensors, so compare the whole sequences; indexing with `[0]`
        # here would only compare the very first token
        self.assertListEqual(original_input_ids.tolist(), observed_input_ids.tolist())

        filepath = hf_hub_download(
            repo_id="nielsr/test-image",
            filename="llava_1_6_pixel_values.pt",
            repo_type="dataset",
        )
        check_torch_load_is_safe()
        original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
        self.assertTrue(
            torch.allclose(
                original_pixel_values, inputs.pixel_values.to(device="cpu", dtype=original_pixel_values.dtype)
            )
        )

        # verify generation
        output = model.generate(**inputs, max_new_tokens=100)
        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L'  # fmt: skip

        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_batch(self):
        """Generate for a batch of two different images that share the same prompt."""
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf", load_in_4bit=True
        )
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        cats_image = Image.open(requests.get(url, stream=True).raw)

        inputs = self.processor(
            images=[self.image, cats_image],
            text=[self.prompt, self.prompt],
            return_tensors="pt",
            padding=True,
        ).to(torch_device)

        # it should not matter whether two images are the same size or not
        output = model.generate(**inputs, max_new_tokens=20)

        EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush']  # fmt: skip
        self.assertEqual(
            self.processor.batch_decode(output, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_unk_token(self):
        """Check that a prompt containing an `<unk>` token does not break forward or generation."""
        # related to (#29835)
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            load_in_4bit=True,
        )

        # NOTE(review): the exact position of `<unk>` inside the sentence is an assumption - confirm
        prompt_with_unk = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]"
        inputs = self.processor(images=self.image, text=prompt_with_unk, return_tensors="pt")

        # verify single forward pass (only checks that it runs; the result is not inspected)
        inputs = inputs.to(torch_device)
        with torch.no_grad():
            model(**inputs)

        # verify generation
        output = model.generate(**inputs, max_new_tokens=40)
        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart'  # fmt: skip

        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_batch_different_resolutions(self):
        """Check right-padding of `pixel_values` and generation for images with different resolutions."""
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            load_in_4bit=True,
        )

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
        cats_image = Image.open(requests.get(url, stream=True).raw)
        lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

        inputs = self.processor(
            images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
        ).to(torch_device)
        pixel_values = inputs["pixel_values"]

        # verify pixel values are padded correctly with 0 when one image has more num_patches than the other
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=model.config.image_grid_pinpoints,
                patch_size=model.config.vision_config.image_size,
            )
            for imsize in inputs["image_sizes"]
        ]
        for pix_val, num_patch in zip(pixel_values, image_num_patches):
            self.assertTrue(torch.all(pix_val[num_patch:] == 0))  # pad on the right
            for i in range(num_patch):
                self.assertFalse(torch.all(pix_val[i : i + 1] == 0))  # no padding expected in any of patches

        # verify generation
        output = model.generate(**inputs, max_new_tokens=50)
        EXPECTED_DECODED_TEXT = "[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given"
        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_batch_matches_single(self):
        """Check that the first sample of a batched generation decodes to the same text as a single run."""
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            load_in_4bit=True,
        )

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
        cats_image = Image.open(requests.get(url, stream=True).raw)
        lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

        inputs_batched = self.processor(
            images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
        ).to(torch_device)

        inputs_single = self.processor(images=lowres_img, text=self.prompt, return_tensors="pt", padding=True).to(
            torch_device
        )

        # verify generation
        output_batched = model.generate(**inputs_batched, max_new_tokens=50)
        output_single = model.generate(**inputs_single, max_new_tokens=50)
        self.assertEqual(
            self.processor.decode(output_batched[0], skip_special_tokens=True),
            self.processor.decode(output_single[0], skip_special_tokens=True),
        )

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_full_vision_state_selection(self):
        """Generate after switching the vision feature selection strategy to "full"."""
        model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            load_in_4bit=True,
        )
        # test that changing `strategy` won't error out
        # NOTE(review): this sets an attribute on the model object, not on `model.config`; confirm that the
        # forward pass actually reads it, otherwise this test still runs with the default strategy
        model.vision_feature_select_strategy = "full"

        inputs = self.processor(self.prompt, self.image, return_tensors="pt").to(model.device)

        # verify generation
        output = model.generate(**inputs, max_new_tokens=30)
        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes'  # fmt: skip

        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    def test_granite_vision(self):
        """
        Check the expected output of a granite vision model, which leverages
        multiple vision feature layers and a visual encoder with no CLS (siglip).
        """
        granite_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
        model = LlavaNextForConditionalGeneration.from_pretrained(granite_model_path)
        self.processor = AutoProcessor.from_pretrained(granite_model_path)
        # the `<image>` placeholder sits on its own line between the user tag and the question
        prompt = "<|user|>\n<image>\nWhat is shown in this image?\n<|assistant|>\n"
        inputs = self.processor(prompt, self.image, return_tensors="pt").to(model.device)

        # verify generation
        output = model.generate(**inputs, max_new_tokens=30)
        EXPECTED_DECODED_TEXT = "<|user|>\n\nWhat is shown in this image?\n<|assistant|>\nThe image displays a radar chart comparing the performance of various machine learning models."  # fmt: skip
        self.assertEqual(
            self.processor.decode(output[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )