compute unpadding size from original size

cyr0930 2025-05-09 06:05:32 +00:00 committed by jaycha
parent ccbd094d69
commit 752cd27ad9
3 changed files with 34 additions and 32 deletions

View File

@@ -136,7 +136,7 @@ def reshape_image(image_feature, num_patch_height, num_patch_width, height, width
     return image_feature
 
 
-def unpad_image(tensor, original_size):
+def unpad_image(tensor, original_size, patch_size):
     """
     Unpads a PyTorch tensor of a padded and resized image.
@@ -145,6 +145,8 @@ def unpad_image(tensor, original_size):
             The image tensor, assumed to be of shape (num_channels, height, width).
         original_size (`tuple`):
             The original size of the image (height, width).
+        patch_size (`int`):
+            The size of each patch.
 
     Returns:
         `torch.Tensor`: The unpadded image tensor.
@ -156,7 +158,8 @@ def unpad_image(tensor, original_size):
)
original_size = original_size.tolist()
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
num_patches_height, num_patches_width = tensor.shape[1:]
current_height, current_width = num_patches_height * patch_size, num_patches_width * patch_size
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
@@ -165,12 +168,14 @@ def unpad_image(tensor, original_size):
     if original_aspect_ratio > current_aspect_ratio:
         scale_factor = current_width / original_width
         new_height = min(math.ceil(original_height * scale_factor), current_height)
         padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        num_padding_top, num_padding_bottom = padding // patch_size, (padding + r) // patch_size
+        unpadded_tensor = tensor[:, num_padding_top : num_patches_height - num_padding_bottom, :]
     else:
         scale_factor = current_height / original_height
         new_width = min(math.ceil(original_width * scale_factor), current_width)
         padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        num_padding_left, num_padding_right = padding // patch_size, (padding + r) // patch_size
+        unpadded_tensor = tensor[:, :, num_padding_left : num_patches_width - num_padding_right]
 
     return unpadded_tensor
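For intuition, here is a minimal standalone sketch of the new patch-level unpadding. The name unpad_patches and the toy sizes are illustrative, not part of the diff, but the body mirrors the updated unpad_image above:

    import math

    import torch

    def unpad_patches(tensor, original_size, patch_size):
        # tensor holds one feature per ViT patch: (num_channels, num_patches_height, num_patches_width)
        original_height, original_width = original_size
        num_patches_height, num_patches_width = tensor.shape[1:]
        # Recover the canvas size in pixels, since the padding was applied in pixel space.
        current_height = num_patches_height * patch_size
        current_width = num_patches_width * patch_size
        if original_width / original_height > current_width / current_height:
            # The image was letterboxed top/bottom: compute the pixel padding from the
            # original size, then trim only the patch rows that are entirely padding.
            scale_factor = current_width / original_width
            new_height = min(math.ceil(original_height * scale_factor), current_height)
            padding, r = divmod(current_height - new_height, 2)
            top, bottom = padding // patch_size, (padding + r) // patch_size
            return tensor[:, top : num_patches_height - bottom, :]
        # Otherwise the padding is on the left/right.
        scale_factor = current_height / original_height
        new_width = min(math.ceil(original_width * scale_factor), current_width)
        padding, r = divmod(current_width - new_width, 2)
        left, right = padding // patch_size, (padding + r) // patch_size
        return tensor[:, :, left : num_patches_width - right]

    # A 24x48 patch grid (336x672 pixels at patch_size=14) holding a 100x600 image:
    # 112 px of padding above and below -> 8 patch rows trimmed from each side.
    features = torch.zeros(1, 24, 48)
    print(unpad_patches(features, (100, 600), 14).shape)  # torch.Size([1, 8, 48])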
@@ -462,7 +467,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
                 )
                 image_feature = reshape_image(image_feature, num_patch_height, num_patch_width, height, width)
-                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                image_feature = unpad_image(
+                    image_feature, image_sizes[image_idx], self.config.vision_config.patch_size
+                )
                 if image_newline is not None:
                     image_feature = torch.cat(
                         (

View File

@@ -179,46 +179,43 @@ class LlavaNextProcessor(ProcessorMixin):
         return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
 
     def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
-        image_grid_pinpoints = self.image_processor.image_grid_pinpoints
-        height_best_resolution, width_best_resolution = select_best_resolution(
-            [orig_height, orig_width], image_grid_pinpoints
-        )
-        scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
         patches_height = height // self.patch_size
         patches_width = width // self.patch_size
-        unpadded_features, newline_features = self._get_unpadded_features(
-            orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
-        )
+        unpadded_features, newline_features = self._get_unpadded_features(orig_height, orig_width)
         # The base patch covers the entire image (+1 for the CLS)
         base_features = patches_height * patches_width + self.num_additional_image_tokens
         num_image_tokens = unpadded_features + newline_features + base_features
         return num_image_tokens
 
-    def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
+    def _get_unpadded_features(self, height, width):
         """
         Get the number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
         because it divides each image into patches depending on its resolution. Therefore we need to calculate how many
         patches an image is divided into and get the number of features from that.
         """
-        current_height = patches_height * scale_height
-        current_width = patches_width * scale_width
+        image_grid_pinpoints = self.image_processor.image_grid_pinpoints
+        current_height, current_width = select_best_resolution([height, width], image_grid_pinpoints)
+        num_patches_height = current_height // self.patch_size
+        num_patches_width = current_width // self.patch_size
 
         original_aspect_ratio = width / height
         current_aspect_ratio = current_width / current_height
         if original_aspect_ratio > current_aspect_ratio:
             scale_factor = current_width / width
             new_height = min(math.ceil(height * scale_factor), current_height)
-            current_height = new_height
+            padding, r = divmod(current_height - new_height, 2)
+            num_padding_patches = padding // self.patch_size + (padding + r) // self.patch_size
+            num_patches_height -= num_padding_patches
         else:
             scale_factor = current_height / height
             new_width = min(math.ceil(width * scale_factor), current_width)
-            current_width = new_width
+            padding, r = divmod(current_width - new_width, 2)
+            num_padding_patches = padding // self.patch_size + (padding + r) // self.patch_size
+            num_patches_width -= num_padding_patches
 
-        unpadded_features = current_height * current_width
-        newline_features = current_height
-        return (unpadded_features, newline_features)
+        unpadded_features = num_patches_height * num_patches_width
+        newline_features = num_patches_height
+        return unpadded_features, newline_features
 
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
     def batch_decode(self, *args, **kwargs):
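As a sanity check on the reworked token counting, a rough free-function version of the new _get_unpadded_features can be run by hand. The helper name, the 14-pixel patch size, and the grid pinpoints below are illustrative defaults (they match LLaVA-NeXT's usual configuration, not anything pinned by this diff), and select_best_resolution is assumed importable from transformers.image_processing_utils as in the processor module:

    import math

    from transformers.image_processing_utils import select_best_resolution

    def unpadded_feature_count(height, width, patch_size=14,
                               grid_pinpoints=([336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008])):
        # Best padded canvas for this image, in pixels (height, width).
        current_height, current_width = select_best_resolution([height, width], grid_pinpoints)
        num_patches_height = current_height // patch_size
        num_patches_width = current_width // patch_size
        if width / height > current_width / current_height:
            # Letterboxed top/bottom: drop the fully padded patch rows.
            scale_factor = current_width / width
            new_height = min(math.ceil(height * scale_factor), current_height)
            padding, r = divmod(current_height - new_height, 2)
            num_patches_height -= padding // patch_size + (padding + r) // patch_size
        else:
            # Letterboxed left/right: drop the fully padded patch columns.
            scale_factor = current_height / height
            new_width = min(math.ceil(width * scale_factor), current_width)
            padding, r = divmod(current_width - new_width, 2)
            num_patches_width -= padding // patch_size + (padding + r) // patch_size
        # One feature per surviving patch, plus one newline feature per surviving row.
        return num_patches_height * num_patches_width, num_patches_height

    # The 100x600 example from above lands on a 336x672 canvas (24x48 patches),
    # keeps 8 of the 24 rows, and yields 8 * 48 = 384 features plus 8 newlines.
    print(unpadded_feature_count(100, 600))  # (384, 8)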

View File

@@ -300,12 +300,12 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
         # Test case width is padded
         pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
+        unpadded_tensor = unpad_image(pixel_values, original_size, 1)
         self.assertEqual(unpadded_tensor.shape[1:], original_size)
 
         # Test case height is padded
         pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
+        unpadded_tensor = unpad_image(pixel_values, original_size, 1)
         self.assertEqual(unpadded_tensor.shape[1:], original_size)
 
     def test_compare_padding_unpadding(self):
def test_compare_padding_unpadding(self):
@@ -331,11 +331,9 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
             padded_image = image_processor._pad_for_patching(
                 resized_image, best_resolution, input_data_format=input_data_format
             )
-            pad_h = padded_image.shape[1] - resized_image.shape[1]
-            pad_h, r_h = divmod(pad_h, 2)
-            pad_top, pad_bottom = pad_h, pad_h + r_h
-            num_patch_pad_top = pad_top // patch_size
-            num_patch_pad_bottom = pad_bottom // patch_size
+            padding_h, r_h = divmod(padded_image.shape[1] - resized_image.shape[1], 2)
+            num_padding_top = padding_h // patch_size
+            num_padding_bottom = (padding_h + r_h) // patch_size
 
             # prepare model config
             config = self.model_tester.get_config()
@@ -375,11 +373,11 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
                 )
                 image_feature = reshape_image(image_feature, num_patch_height, num_patch_width, height, width)
                 current_height = image_feature.shape[1]
-                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                image_feature = unpad_image(image_feature, image_sizes[image_idx], config.vision_config.patch_size)
                 new_height = image_feature.shape[1]
 
                 # verify that the padding size and unpadding size are the same
-                self.assertEqual(num_patch_pad_top + num_patch_pad_bottom, current_height - new_height)
+                self.assertEqual(num_padding_top + num_padding_bottom, current_height - new_height)
 
     @parameterized.expand(
         [
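Why the test's invariant needed this rework at all: judging from the removed lines, the old unpad_image did all of its arithmetic in the units of tensor.shape, which for a patch-grid feature tensor means rounding at patch granularity; that could trim a patch row that still contained image content. The new code measures the padding in pixels from the original size and trims only rows that are entirely padding. A hand-check with illustrative numbers (a 110x600 image on the same 336x672 canvas, patch_size=14):

    import math

    patch_size = 14
    orig_height, orig_width = 110, 600
    num_patches_height, num_patches_width = 24, 48  # a 336x672-pixel canvas

    # Old behaviour: every quantity lives in patch units, so the rounding
    # happens at patch granularity.
    scale_factor = num_patches_width / orig_width                                 # 0.08
    new_height = min(math.ceil(orig_height * scale_factor), num_patches_height)   # ceil(8.8) = 9
    padding, r = divmod(num_patches_height - new_height, 2)                       # (7, 1)
    old_rows = num_patches_height - padding - (padding + r)                       # 24 - 7 - 8 = 9

    # New behaviour: measure the padding in pixels, then trim whole rows only.
    canvas_height = num_patches_height * patch_size                               # 336
    canvas_width = num_patches_width * patch_size                                 # 672
    scale_factor = canvas_width / orig_width                                      # 1.12
    new_height = min(math.ceil(orig_height * scale_factor), canvas_height)        # ceil(123.2) = 124
    padding, r = divmod(canvas_height - new_height, 2)                            # (106, 0)
    new_rows = num_patches_height - padding // patch_size - (padding + r) // patch_size  # 24 - 7 - 7 = 10

    print(old_rows, new_rows)  # 9 10 -- the new code keeps the partially covered boundary rows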