diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index 2113224a9d0..718c1e96c98 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -20,7 +20,6 @@ from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from ...activations import ACT2FN
@@ -133,14 +132,14 @@ def unpad_image(tensor, original_size):
 
     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = min(math.ceil(original_height * scale_factor), current_height)
-        padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        new_height = int(round(original_height * scale_factor, 7))
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
     else:
         scale_factor = current_height / original_height
-        new_width = min(math.ceil(original_width * scale_factor), current_width)
-        padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        new_width = int(round(original_width * scale_factor, 7))
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
 
     return unpadded_tensor
 
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 484030b3e38..b956b6f1bd0 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -304,14 +304,14 @@ def unpad_image(tensor, original_size):
 
     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = min(math.ceil(original_height * scale_factor), current_height)
-        padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        new_height = int(round(original_height * scale_factor, 7))
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
     else:
         scale_factor = current_height / original_height
-        new_width = min(math.ceil(original_width * scale_factor), current_width)
-        padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        new_width = int(round(original_width * scale_factor, 7))
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
 
     return unpadded_tensor
 
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index b00fc91813c..88ef859ad29 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -18,7 +18,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 
 from transformers.models.llava_next.modeling_llava_next import (
diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
index 502e9ac74c4..a664cfa7b64 100644
--- a/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision.py
@@ -644,7 +644,7 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
                 image,
                 image_grid_pinpoints,
                 size=size_tuple,
-                patch_size=size["height"],
+                patch_size=size_tuple[0],
                 resample=resample,
                 data_format=input_data_format,
                 input_data_format=input_data_format,
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index a4271297d03..18475d381a0 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -284,14 +284,14 @@ def unpad_image(tensor, original_size):
 
     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
-        new_height = min(math.ceil(original_height * scale_factor), current_height)
-        padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        new_height = int(round(original_height * scale_factor, 7))
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
     else:
         scale_factor = current_height / original_height
-        new_width = min(math.ceil(original_width * scale_factor), current_width)
-        padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        new_width = int(round(original_width * scale_factor, 7))
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
 
     return unpadded_tensor
 
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index c8789f0ba38..93142f1da68 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -50,7 +50,7 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch
 
-    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image
+    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
 
 
 if is_vision_available():
@@ -298,18 +298,27 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
         image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
 
-    def test_unpad_image(self):
-        original_size = (400, 400)
+    def test_odd_sized_image(self):
+        # prepare model configuration
+        config = self.model_tester.get_config()
 
-        # Test case width is padded
-        pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+        # prepare input
+        num_image_tokens = 24
+        pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
+        input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
+        input_ids[:, :num_image_tokens] = config.image_token_index
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_sizes": torch.tensor([[13, 16]]),  # odd-sized image
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
 
-        # Test case height is padded
-        pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+        # forward with odd-sized image input
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model(**inputs_dict)
 
     @parameterized.expand(
         [
diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py
index a565212b49e..2adf527d782 100644
--- a/tests/models/llava_next/test_processor_llava_next.py
+++ b/tests/models/llava_next/test_processor_llava_next.py
@@ -11,13 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import json
+import shutil
 import tempfile
 import unittest
 
 import torch
 
-from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
+from transformers import LlamaTokenizerFast, LlavaNextProcessor
 from transformers.testing_utils import (
     require_vision,
 )
@@ -52,6 +54,10 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def get_image_processor(self, **kwargs):
         return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
 
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+
     @staticmethod
     def prepare_processor_dict():
         return {
@@ -73,13 +79,16 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
 
     def test_image_token_filling(self):
-        processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
         processor.patch_size = 14
         processor.vision_feature_select_strategy = "default"
+        processor.image_processor.crop_size = {"height": 336, "width": 336}
+        processor.image_processor.size = {"shortest_edge": 336}
+        processor.image_processor.image_grid_pinpoints = [[672, 336]]
         # Important to check with non square image
-        image = torch.randint(0, 2, (3, 500, 316))
-        expected_image_tokens = 1526
-        image_token_index = 32000
+        image = torch.randint(0, 2, (3, 503, 316))
+        expected_image_tokens = 1525
+        image_token_index = processor.image_token_id
 
         messages = [
             {
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index e68a1e4362e..577dd669860 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -49,8 +49,6 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch
 
-    from transformers.models.llava_next_video.modeling_llava_next_video import unpad_image
-
 
 if is_vision_available():
     from PIL import Image
@@ -314,18 +312,27 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
         image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
         _ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
 
-    def test_unpad_image(self):
-        original_size = (400, 400)
+    def test_odd_sized_image(self):
+        # prepare model configuration
+        config = self.model_tester.get_config()
 
-        # Test case width is padded
-        pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+        # prepare input
+        num_image_tokens = 24
+        pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
+        input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
+        input_ids[:, :num_image_tokens] = config.image_token_index
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_sizes": torch.tensor([[13, 16]]),  # odd-sized image
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
 
-        # Test case height is padded
-        pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+        # forward with odd-sized image input
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model(**inputs_dict)
 
     @parameterized.expand(
         [
diff --git a/tests/models/llava_next_video/test_processor_llava_next_video.py b/tests/models/llava_next_video/test_processor_llava_next_video.py
index f35cacb5fd2..49fa33ffc14 100644
--- a/tests/models/llava_next_video/test_processor_llava_next_video.py
+++ b/tests/models/llava_next_video/test_processor_llava_next_video.py
@@ -17,6 +17,8 @@ import shutil
 import tempfile
 import unittest
 
+import torch
+
 from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
 from transformers.testing_utils import require_vision
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -63,6 +65,10 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     def get_video_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
 
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+
     @classmethod
     def prepare_processor_dict(cls):
         return {
@@ -84,6 +90,31 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         processor_dict = self.prepare_processor_dict()
         self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
 
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+    def test_image_token_filling(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        processor.patch_size = 14
+        processor.vision_feature_select_strategy = "default"
+        processor.image_processor.crop_size = {"height": 336, "width": 336}
+        processor.image_processor.size = {"shortest_edge": 336}
+        processor.image_processor.image_grid_pinpoints = [[672, 336]]
+        # Important to check with non square image
+        image = torch.randint(0, 2, (3, 503, 316))
+        expected_image_tokens = 1525
+        image_token_index = processor.image_token_id
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+        inputs = processor(
+            text=[processor.apply_chat_template(messages)],
+            images=[image],
+            return_tensors="pt",
+        )
+        image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
+        self.assertEqual(expected_image_tokens, image_tokens)
diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py
index ba95c330dbd..fba739b9956 100644
--- a/tests/models/llava_onevision/test_modeling_llava_onevision.py
+++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py
@@ -49,8 +49,6 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch
 
-    from transformers.models.llava_onevision.modeling_llava_onevision import unpad_image
-
 
 if is_vision_available():
     from PIL import Image
@@ -268,18 +266,27 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         torch.testing.assert_close(out_embeds, out_ids)
 
-    def test_unpad_image(self):
-        original_size = (400, 400)
+    def test_odd_sized_image(self):
+        # prepare model configuration
+        config = self.model_tester.get_config()
 
-        # Test case width is padded
-        pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+        # prepare input
+        num_image_tokens = 10
+        pixel_values = floats_tensor([1, 2, 3, config.vision_config.image_size, config.vision_config.image_size])
+        input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
+        input_ids[:, :num_image_tokens] = config.image_token_index
+        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "image_sizes": torch.tensor([[13, 16]]),  # odd-sized image
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
 
-        # Test case height is padded
-        pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
-        self.assertEqual(unpadded_tensor.shape[1:], original_size)
+        # forward with odd-sized image input
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            model(**inputs_dict)
 
     @parameterized.expand(
         [
diff --git a/tests/models/llava_onevision/test_processor_llava_onevision.py b/tests/models/llava_onevision/test_processor_llava_onevision.py
index 419c6d0acfd..d4bd5f00025 100644
--- a/tests/models/llava_onevision/test_processor_llava_onevision.py
+++ b/tests/models/llava_onevision/test_processor_llava_onevision.py
@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import json
 import shutil
 import tempfile
 import unittest
 
+import torch
+
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
@@ -90,3 +93,33 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         # so we check if the same template is loaded
         processor_dict = self.prepare_processor_dict()
         self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
+    def test_image_token_filling(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        processor.patch_size = 14
+        processor.vision_feature_select_strategy = "default"
+        processor.image_processor.crop_size = {"height": 336, "width": 336}
+        processor.image_processor.size = {"shortest_edge": 336}
+        processor.image_processor.image_grid_pinpoints = [[672, 336]]
+        processor.num_image_tokens = (processor.image_processor.size["shortest_edge"] // processor.patch_size) ** 2
+        # Important to check with non square image
+        image = torch.randint(0, 2, (3, 503, 316))
+        expected_image_tokens = 1525
+        image_token_index = processor.image_token_id
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+        inputs = processor(
+            text=[processor.apply_chat_template(messages)],
+            images=[image],
+            return_tensors="pt",
+        )
+        image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
+        self.assertEqual(expected_image_tokens, image_tokens)
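Note on the unpad_image hunks above (not part of the patch itself): the snippet below is a minimal standalone sketch comparing the removed ("-") and added ("+") height computations on the (503, 400) -> (400, 400) case from the deleted test_unpad_image. All variable values here are illustrative assumptions; the real function also derives scale_factor from the aspect-ratio branch shown in the hunk context.

import math

# Illustrative values only: a 503 x 400 padded feature map whose original image was 400 x 400,
# i.e. the "height is padded" case exercised by the removed test_unpad_image.
current_height, current_width = 503, 400
original_height, original_width = 400, 400
scale_factor = current_width / original_width  # 1.0 for this example

# Removed ("-") computation: ceil the scaled height, then split the leftover with divmod
# so the slice can be asymmetric when the leftover padding is odd.
new_height = min(math.ceil(original_height * scale_factor), current_height)  # 400
padding, r = divmod(current_height - new_height, 2)                          # 51, 1
rows_kept_removed = (current_height - (padding + r)) - padding               # 400

# Added ("+") computation: round the scaled height and slice symmetrically.
new_height = int(round(original_height * scale_factor, 7))                   # 400
padding = (current_height - new_height) // 2                                 # 51
rows_kept_added = (current_height - padding) - padding                       # 401

print(rows_kept_removed, rows_kept_added)  # 400 401: the two versions differ by one row when the padding is odd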