[bug] fix llava processor to calculate unpadding size correctly (#37988)

* fix llava processor to calculate unpad size correctly

* repo consistency

* Revert "repo consistency" & "setUp in llava family"

This reverts commit 26a50af8db.

* add edge case test for padding & unpadding

* compute unpadding size from original size

* make test config explicit

* Revert "compute unpadding size from original size"

This reverts commit 752cd27ad9.

* Revert "add edge case test for padding & unpadding"

This reverts commit ccbd094d69.

* revert unpad logic

* remove irrelevant tests

* model test

* remove processor from model test

---------

Co-authored-by: jaycha <jaycha@ncsoft.com>
This commit is contained in:
youngrok cha 2025-05-13 22:49:09 +09:00 committed by GitHub
parent 67b3d45eb6
commit a5cc7a67d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 158 additions and 64 deletions

View File

@ -20,7 +20,6 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
@ -133,14 +132,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor

View File

@ -304,14 +304,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor

View File

@ -18,7 +18,6 @@ from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.llava_next.modeling_llava_next import (

View File

@ -644,7 +644,7 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=size["height"],
patch_size=size_tuple[0],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,

View File

@ -284,14 +284,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor

View File

@ -50,7 +50,7 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
if is_vision_available():
@ -298,18 +298,27 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
def test_unpad_image(self):
original_size = (400, 400)
def test_odd_sized_image(self):
# prepare model configuration
config = self.model_tester.get_config()
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# prepare input
num_image_tokens = 24
pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
input_ids[:, :num_image_tokens] = config.image_token_index
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"image_sizes": torch.tensor([[13, 16]]), # odd-sized image
"input_ids": input_ids,
"attention_mask": attention_mask,
}
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# forward with odd-sized image input
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model(**inputs_dict)
@parameterized.expand(
[

View File

@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import shutil
import tempfile
import unittest
import torch
from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
from transformers import LlamaTokenizerFast, LlavaNextProcessor
from transformers.testing_utils import (
require_vision,
)
@ -52,6 +54,10 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def get_image_processor(self, **kwargs):
return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
@staticmethod
def prepare_processor_dict():
return {
@ -73,13 +79,16 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def test_image_token_filling(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
processor = self.processor_class.from_pretrained(self.tmpdirname)
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
processor.image_processor.crop_size = {"height": 336, "width": 336}
processor.image_processor.size = {"shortest_edge": 336}
processor.image_processor.image_grid_pinpoints = [[672, 336]]
# Important to check with non square image
image = torch.randint(0, 2, (3, 500, 316))
expected_image_tokens = 1526
image_token_index = 32000
image = torch.randint(0, 2, (3, 503, 316))
expected_image_tokens = 1525
image_token_index = processor.image_token_id
messages = [
{

View File

@ -49,8 +49,6 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_next_video.modeling_llava_next_video import unpad_image
if is_vision_available():
from PIL import Image
@ -314,18 +312,27 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
def test_unpad_image(self):
original_size = (400, 400)
def test_odd_sized_image(self):
# prepare model configuration
config = self.model_tester.get_config()
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# prepare input
num_image_tokens = 24
pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
input_ids[:, :num_image_tokens] = config.image_token_index
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"image_sizes": torch.tensor([[13, 16]]), # odd-sized image
"input_ids": input_ids,
"attention_mask": attention_mask,
}
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# forward with odd-sized image input
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model(**inputs_dict)
@parameterized.expand(
[

View File

@ -17,6 +17,8 @@ import shutil
import tempfile
import unittest
import torch
from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
from transformers.testing_utils import require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@ -63,6 +65,10 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def get_video_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
@classmethod
def prepare_processor_dict(cls):
return {
@ -84,6 +90,31 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
processor.image_processor.crop_size = {"height": 336, "width": 336}
processor.image_processor.size = {"shortest_edge": 336}
processor.image_processor.image_grid_pinpoints = [[672, 336]]
# Important to check with non square image
image = torch.randint(0, 2, (3, 503, 316))
expected_image_tokens = 1525
image_token_index = processor.image_token_id
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)

View File

@ -49,8 +49,6 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_onevision.modeling_llava_onevision import unpad_image
if is_vision_available():
from PIL import Image
@ -268,18 +266,27 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_unpad_image(self):
original_size = (400, 400)
def test_odd_sized_image(self):
# prepare model configuration
config = self.model_tester.get_config()
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# prepare input
num_image_tokens = 10
pixel_values = floats_tensor([1, 2, 3, config.vision_config.image_size, config.vision_config.image_size])
input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
input_ids[:, :num_image_tokens] = config.image_token_index
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"image_sizes": torch.tensor([[13, 16]]), # odd-sized image
"input_ids": input_ids,
"attention_mask": attention_mask,
}
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# forward with odd-sized image input
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model(**inputs_dict)
@parameterized.expand(
[

View File

@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import shutil
import tempfile
import unittest
import torch
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@ -90,3 +93,33 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
# so we check if the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
processor.image_processor.crop_size = {"height": 336, "width": 336}
processor.image_processor.size = {"shortest_edge": 336}
processor.image_processor.image_grid_pinpoints = [[672, 336]]
processor.num_image_tokens = (processor.image_processor.size["shortest_edge"] // processor.patch_size) ** 2
# Important to check with non square image
image = torch.randint(0, 2, (3, 503, 316))
expected_image_tokens = 1525
image_token_index = processor.image_token_id
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)