# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch VideoLlava model."""

import gc
import unittest

import numpy as np
import requests
from huggingface_hub import hf_hub_download

from transformers import (
    VideoLlavaConfig,
    VideoLlavaForConditionalGeneration,
    VideoLlavaProcessor,
    is_torch_available,
    is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image


class VideoLlavaVisionText2TextModelTester:
    def __init__(
        self,
        parent,
        ignore_index=-100,
        image_token_index=0,
        video_token_index=1,
        projector_hidden_act="gelu",
        seq_length=13,
        num_frames=8,
        vision_feature_select_strategy="default",
        vision_feature_layer=-1,
        text_config={
            "model_type": "llama",
            "seq_length": 13,
            "is_training": True,
            "use_input_mask": True,
            "use_token_type_ids": False,
            "use_labels": True,
            "vocab_size": 99,
            "hidden_size": 32,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "intermediate_size": 37,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 2048,  # we need it high because videos are 8 frames
            "type_vocab_size": 16,
            "type_sequence_label_size": 2,
            "initializer_range": 0.02,
            "num_labels": 3,
            "num_choices": 4,
            "pad_token_id": 0,
        },
        is_training=True,
        vision_config={
            "model_type": "clip_vision_model",
            "batch_size": 12,
            "image_size": 30,
            "patch_size": 2,
            "num_channels": 3,
            "is_training": True,
            "hidden_size": 32,
            "projection_dim": 32,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "intermediate_size": 37,
            "dropout": 0.1,
            "attention_dropout": 0.1,
            "initializer_range": 0.02,
        },
    ):
        self.parent = parent
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.video_token_index = video_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.text_config = text_config
        self.vision_config = vision_config
        self.seq_length = seq_length
        self.num_frames = num_frames
        self.num_hidden_layers = text_config["num_hidden_layers"]
        self.vocab_size = text_config["vocab_size"]
        self.hidden_size = text_config["hidden_size"]
        self.num_attention_heads = text_config["num_attention_heads"]
        self.is_training = is_training
        self.batch_size = 5
        self.num_channels = 3
        self.image_size = 224
        self.encoder_seq_length = 2044

    def get_config(self):
        return VideoLlavaConfig(
            text_config=self.text_config,
            vision_config=self.vision_config,
            ignore_index=self.ignore_index,
            image_token_index=self.image_token_index,
            video_token_index=self.video_token_index,
            projector_hidden_act=self.projector_hidden_act,
            vision_feature_select_strategy=self.vision_feature_select_strategy,
            vision_feature_layer=self.vision_feature_layer,
        )

    def prepare_config_and_inputs(self):
        pixel_values_videos = floats_tensor(
            [
                self.batch_size,
                self.num_frames,
                self.vision_config["num_channels"],
                self.vision_config["image_size"],
                self.vision_config["image_size"],
            ]
        )
        pixel_values_images = floats_tensor(
            [
                self.batch_size,
                self.vision_config["num_channels"],
                self.vision_config["image_size"],
                self.vision_config["image_size"],
            ]
        )
        config = self.get_config()

        return config, pixel_values_images, pixel_values_videos

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values_images, pixel_values_videos = config_and_inputs
        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
        attention_mask = input_ids.ne(1).to(torch_device)

        # we are giving videos and images. Need to pass in both image and video tokens,
        # and also need to make sure no other special tokens are set
        input_ids[(input_ids == 0) | (input_ids == 1)] = 3
        input_ids[:, 0] = config.video_token_index
        input_ids[:, 1:2] = config.image_token_index
        inputs_dict = {
            "pixel_values_videos": pixel_values_videos,
            "pixel_values_images": pixel_values_images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def prepare_config_and_inputs_for_batched_test(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, _, pixel_values_videos = config_and_inputs
        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
        attention_mask = input_ids.ne(1).to(torch_device)

        # make sure no other special tokens are set
        input_ids[(input_ids == 0) | (input_ids == 1)] = 3
        input_ids[:, 0] = config.video_token_index
        inputs_dict = {
            "pixel_values_videos": pixel_values_videos,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_torch
class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """
    Model tester for `VideoLlavaForConditionalGeneration`.
""" all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else () fx_compatible = False test_pruning = False test_resize_embeddings = True test_head_masking = False def setUp(self): self.model_tester = VideoLlavaVisionText2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=VideoLlavaConfig, has_text_modality=False) @unittest.skip( reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) def test_training_gradient_checkpointing(self): pass @unittest.skip( reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) def test_training_gradient_checkpointing_use_reentrant(self): pass @unittest.skip( reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" ) def test_training_gradient_checkpointing_use_reentrant_false(self): pass @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`") def test_sdpa_can_compile_dynamic(self): pass @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`") def test_sdpa_can_dispatch_on_flash(self): pass def test_mixed_input(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device).eval() # test that the forward does not fail with torch.no_grad(): _ = model(**inputs) # if we remove some images from inputs leaving only one # image number mismatch error should raise inputs["pixel_values_images"] = inputs["pixel_values_images"][:1] with self.assertRaises(ValueError): _ = model(**inputs) def test_video_only_input(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device).eval() # replace video_token with dummy id which is not video token id # error that video-tokens and num-of-video-inputs mismatch will be raised inputs["input_ids"][:, 1:2] = 2 with self.assertRaises(ValueError): _ = model(**inputs) inputs["pixel_values_images"] = None _ = model(**inputs) def test_image_only_input(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config).to(torch_device).eval() # set dummy id, which is not image token id, same as above inputs["input_ids"][:, :1] = 2 with self.assertRaises(ValueError): _ = model(**inputs) inputs["pixel_values_videos"] = None _ = model(**inputs) def test_batching_equivalence(self): def recursive_check(batched_object, single_row_object, model_name, key): if isinstance(batched_object, (list, tuple)): for batched_object_value, single_row_object_value in zip(batched_object, single_row_object): recursive_check(batched_object_value, single_row_object_value, model_name, key) # do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects elif batched_object is None or not isinstance(batched_object, torch.Tensor): return elif batched_object.dim() == 0: return else: batched_row = batched_object[:1] self.assertFalse( torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}" ) self.assertFalse( torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}" ) self.assertFalse( torch.isnan(single_row_object).any(), 
f"Single row output has `nan` in {model_name} for key={key}" ) self.assertFalse( torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}" ) self.assertTrue( (torch.max(torch.abs(batched_row - single_row_object))) <= 1e-03, msg=( f"Batched and Single row outputs are not equal in {model_name} for key={key}. " f"Difference={torch.max(torch.abs(batched_row - single_row_object))}." ), ) config, batched_input = self.model_tester.prepare_config_and_inputs_for_batched_test() for model_class in self.all_model_classes: config.output_hidden_states = True model_name = model_class.__name__ batched_input_prepared = self._prepare_for_class(batched_input, model_class) model = model_class(config).to(torch_device).eval() single_row_input = {} for key, value in batched_input_prepared.items(): single_row_input[key] = value[:1] with torch.no_grad(): model_batched_output = model(**batched_input_prepared) model_row_output = model(**single_row_input) for key in model_batched_output: recursive_check(model_batched_output[key], model_row_output[key], model_name, key) # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config) model.to(torch_device) model.eval() inputs = self._prepare_for_class(inputs_dict, model_class) input_ids = inputs["input_ids"] del inputs["input_ids"] del inputs["pixel_values_images"] del inputs["pixel_values_videos"] wte = model.get_input_embeddings() inputs["inputs_embeds"] = wte(input_ids) with torch.no_grad(): model(**inputs) # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs # while some other models require pixel_values to be present def test_inputs_embeds_matches_input_ids(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config) model.to(torch_device) model.eval() inputs = self._prepare_for_class(inputs_dict, model_class) input_ids = inputs["input_ids"] del inputs["input_ids"] del inputs["pixel_values_images"] del inputs["pixel_values_videos"] inputs_embeds = model.get_input_embeddings()(input_ids) with torch.no_grad(): out_ids = model(input_ids=input_ids, **inputs)[0] out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] self.assertTrue(torch.allclose(out_embeds, out_ids)) @require_torch class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase): def setUp(self): self.processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf") def tearDown(self): gc.collect() torch.cuda.empty_cache() @slow @require_bitsandbytes def test_small_model_integration_test(self): # Let' s make sure we test the preprocessing to replace what is used model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True) prompt = "USER: