add edge case test for padding & unpadding

cyr0930 2025-05-08 15:01:36 +00:00 committed by jaycha
parent d463b2ed93
commit ccbd094d69
3 changed files with 134 additions and 7 deletions

src/transformers/models/llava_next/modeling_llava_next.py

@@ -111,6 +111,31 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     return num_patches
 
 
+def reshape_image(image_feature, num_patch_height, num_patch_width, height, width):
+    """
+    Reshape the image feature tensor to recover height and width dimensions.
+
+    Args:
+        image_feature (`torch.Tensor`):
+            The image feature tensor to reshape.
+        num_patch_height (`int`):
+            The number of patches in the height dimension.
+        num_patch_width (`int`):
+            The number of patches in the width dimension.
+        height (`int`):
+            The height of each patch.
+        width (`int`):
+            The width of each patch.
+
+    Returns:
+        `torch.Tensor`: The reshaped image feature tensor.
+    """
+    image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+    return image_feature
+
+
 def unpad_image(tensor, original_size):
     """
     Unpads a PyTorch tensor of a padded and resized image.
@@ -436,9 +461,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
                     " visual encoder that does not have CLS."
                 )
-                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = reshape_image(image_feature, num_patch_height, num_patch_width, height, width)
                 image_feature = unpad_image(image_feature, image_sizes[image_idx])
                 if image_newline is not None:
                     image_feature = torch.cat(
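
The new reshape_image helper factors out the view/permute/flatten sequence that both models previously inlined. A minimal standalone sketch of its shape contract, with made-up sizes (plain Python copied from the diff above, not an import from transformers):

    import torch

    def reshape_image(image_feature, num_patch_height, num_patch_width, height, width):
        # (num_patch_height * num_patch_width * height * width, embed_dim) ->
        # (embed_dim, num_patch_height * height, num_patch_width * width)
        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
        return image_feature

    # a 2x3 grid of 24x24 patches with 8-dim features becomes a channel-first feature map
    flat = torch.randn(2 * 3 * 24 * 24, 8)
    out = reshape_image(flat, num_patch_height=2, num_patch_width=3, height=24, width=24)
    assert out.shape == (8, 2 * 24, 3 * 24)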

src/transformers/models/llava_onevision/modeling_llava_onevision.py

@@ -109,6 +109,32 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
     return num_patches
 
 
+# Copied from transformers.models.llava_next.modeling_llava_next.reshape_image
+def reshape_image(image_feature, num_patch_height, num_patch_width, height, width):
+    """
+    Reshape the image feature tensor to recover height and width dimensions.
+
+    Args:
+        image_feature (`torch.Tensor`):
+            The image feature tensor to reshape.
+        num_patch_height (`int`):
+            The number of patches in the height dimension.
+        num_patch_width (`int`):
+            The number of patches in the width dimension.
+        height (`int`):
+            The height of each patch.
+        width (`int`):
+            The width of each patch.
+
+    Returns:
+        `torch.Tensor`: The reshaped image feature tensor.
+    """
+    image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+    return image_feature
+
+
 # Copied from transformers.models.llava_next.modeling_llava_next.unpad_image
 def unpad_image(tensor, original_size):
     """
@@ -428,9 +454,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
                     self.config.image_grid_pinpoints,
                     self.config.vision_config.image_size,
                 )
-                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
-                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = reshape_image(image_feature, num_patch_height, num_patch_width, height, width)
                 image_feature = unpad_image(image_feature, image_sizes[image_idx])
                 max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))
                 channels, curr_height, curr_width = image_feature.shape
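
After reshape_image, unpad_image crops away the rows or columns that the processor's _pad_for_patching step added, based on the original image size. Its body is elided in this diff; the sketch below is a simplified reconstruction of that cropping logic (an assumption about the helper's internals, omitting its careful rounding), shown on a tall 503x316 original like the one in the new test:

    import torch

    def unpad_image_sketch(tensor, original_size):
        # tensor: (embed_dim, current_height, current_width); original_size: (orig_h, orig_w)
        original_height, original_width = original_size
        current_height, current_width = tensor.shape[1:]
        if original_width / original_height > current_width / current_height:
            # relatively wider original: padding was added top/bottom, so crop rows
            new_height = int(original_height * current_width / original_width)
            padding = (current_height - new_height) // 2
            return tensor[:, padding : current_height - padding, :]
        # relatively taller original: padding was added left/right, so crop columns
        new_width = int(original_width * current_height / original_height)
        padding = (current_width - new_width) // 2
        return tensor[:, :, padding : current_width - padding]

    feat = torch.randn(8, 48, 48)                      # square padded feature map
    print(unpad_image_sketch(feat, (503, 316)).shape)  # torch.Size([8, 48, 30])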

tests/models/llava_next/test_modeling_llava_next.py

@@ -26,6 +26,8 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
+from transformers.image_utils import ChannelDimension
+from transformers.models.llava_next.image_processing_llava_next import select_best_resolution
 from transformers.testing_utils import (
     cleanup,
     require_bitsandbytes,
@@ -48,7 +50,12 @@ from ...test_modeling_common import (
 if is_torch_available():
     import torch
 
-    from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image
+    from transformers.models.llava_next.modeling_llava_next import (
+        get_anyres_image_grid_shape,
+        image_size_to_num_patches,
+        reshape_image,
+        unpad_image,
+    )
 
 if is_vision_available():
@@ -301,6 +308,79 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         unpadded_tensor = unpad_image(pixel_values, original_size)
         self.assertEqual(unpadded_tensor.shape[1:], original_size)
 
+    def test_compare_padding_unpadding(self):
+        # prepare inputs
+        processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+        image = torch.randint(0, 2, (3, 503, 316)).numpy()  # edge case image size
+        prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
+        inputs = processor(images=[image], text=[prompt], return_tensors="pt", padding=True).to(torch_device)
+        input_ids = inputs["input_ids"]
+        pixel_values = inputs["pixel_values"]
+        image_sizes = inputs["image_sizes"]
+
+        # compute padding size
+        patch_size = 14
+        image_processor = processor.image_processor
+        image_grid_pinpoints = image_processor.image_grid_pinpoints
+        input_data_format = ChannelDimension.FIRST
+        resample = image_processor.resample
+        best_resolution = select_best_resolution(image.shape[1:], image_grid_pinpoints)
+        resized_image = image_processor._resize_for_patching(
+            image, best_resolution, resample=resample, input_data_format=input_data_format
+        )
+        padded_image = image_processor._pad_for_patching(
+            resized_image, best_resolution, input_data_format=input_data_format
+        )
+        pad_h = padded_image.shape[1] - resized_image.shape[1]
+        pad_h, r_h = divmod(pad_h, 2)
+        pad_top, pad_bottom = pad_h, pad_h + r_h
+        num_patch_pad_top = pad_top // patch_size
+        num_patch_pad_bottom = pad_bottom // patch_size
+
+        # prepare model config
+        config = self.model_tester.get_config()
+        config.image_grid_pinpoints = image_grid_pinpoints
+        config.vision_config.image_size = image_processor.size["shortest_edge"]
+        config.vision_config.patch_size = patch_size
+
+        for model_class in self.all_model_classes:
+            # prepare model
+            model = model_class(config).to(torch_device)
+
+            # compute image features
+            image_features = model.get_image_features(
+                pixel_values,
+                image_sizes,
+                vision_feature_layer=config.vision_feature_layer,
+                vision_feature_select_strategy=config.vision_feature_select_strategy,
+            )
+            _, feature_lens = model.pack_image_features(
+                image_features,
+                image_sizes,
+                vision_feature_select_strategy=config.vision_feature_select_strategy,
+                image_newline=model.image_newline,
+            )
+
+            # verify number of image tokens and number of unpadded features are the same
+            n_image_tokens = (input_ids == processor.image_token_id).sum()
+            self.assertEqual(n_image_tokens, feature_lens)
+
+            # compute unpadding size
+            image_idx = 0
+            image_feature = image_features[image_idx][1:]
+            height = width = config.vision_config.image_size // config.vision_config.patch_size
+            num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                image_sizes[image_idx],
+                config.image_grid_pinpoints,
+                config.vision_config.image_size,
+            )
+            image_feature = reshape_image(image_feature, num_patch_height, num_patch_width, height, width)
+            current_height = image_feature.shape[1]
+            image_feature = unpad_image(image_feature, image_sizes[image_idx])
+            new_height = image_feature.shape[1]
+
+            # verify that the padding size and unpadding size are the same
+            self.assertEqual(num_patch_pad_top + num_patch_pad_bottom, current_height - new_height)
+
     @parameterized.expand(
         [
             (-1,),
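
The subtlety the test targets is the asymmetric pad split: divmod puts any odd leftover pixel on the bottom, so the top and bottom pads can each cover a different number of whole patches. A worked example with hypothetical numbers (not derived from the 503x316 test image):

    # 37 rows of total vertical padding at patch size 14 (illustrative values only)
    pad_total = 37
    half, rem = divmod(pad_total, 2)
    pad_top, pad_bottom = half, half + rem           # 18, 19
    patch_size = 14
    num_patch_pad_top = pad_top // patch_size        # 1 whole padded patch row on top
    num_patch_pad_bottom = pad_bottom // patch_size  # 1 on the bottom
    assert num_patch_pad_top + num_patch_pad_bottom == 2

The final assertEqual then checks that unpad_image removes exactly this many whole feature rows, i.e. that current_height - new_height matches num_patch_pad_top + num_patch_pad_bottom.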