mirror of https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00

commit 752cd27ad9
parent ccbd094d69

    compute unpadding size from original size
@@ -136,7 +136,7 @@ def reshape_image(image_feature, num_patch_height, num_patch_width, height, width):
     return image_feature
 
 
-def unpad_image(tensor, original_size):
+def unpad_image(tensor, original_size, patch_size):
     """
     Unpads a PyTorch tensor of a padded and resized image.
 
@@ -145,6 +145,8 @@ def unpad_image(tensor, original_size):
             The image tensor, assumed to be of shape (num_channels, height, width).
         original_size (`tuple`):
             The original size of the image (height, width).
+        patch_size (`int`):
+            The size of each patch.
 
     Returns:
         `torch.Tensor`: The unpadded image tensor.
@@ -156,7 +158,8 @@ def unpad_image(tensor, original_size):
         )
         original_size = original_size.tolist()
     original_height, original_width = original_size
-    current_height, current_width = tensor.shape[1:]
+    num_patches_height, num_patches_width = tensor.shape[1:]
+    current_height, current_width = num_patches_height * patch_size, num_patches_width * patch_size
 
     original_aspect_ratio = original_width / original_height
     current_aspect_ratio = current_width / current_height
@@ -165,12 +168,14 @@ def unpad_image(tensor, original_size):
         scale_factor = current_width / original_width
         new_height = min(math.ceil(original_height * scale_factor), current_height)
         padding, r = divmod(current_height - new_height, 2)
-        unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
+        num_padding_top, num_padding_bottom = padding // patch_size, (padding + r) // patch_size
+        unpadded_tensor = tensor[:, num_padding_top : num_patches_height - num_padding_bottom, :]
     else:
         scale_factor = current_height / original_height
         new_width = min(math.ceil(original_width * scale_factor), current_width)
         padding, r = divmod(current_width - new_width, 2)
-        unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
+        num_padding_left, num_padding_right = padding // patch_size, (padding + r) // patch_size
+        unpadded_tensor = tensor[:, :, num_padding_left : num_patches_width - num_padding_right]
 
     return unpadded_tensor
 
 
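The rewrite above changes the unit of unpadding from pixels to whole patches: the tensor handed to unpad_image is now a patch-feature grid rather than a pixel image, the pixel-space padding is computed exactly as before, and floor division by patch_size converts it into the number of full padding-patch rows or columns to drop. Below is a minimal standalone sanity check of that arithmetic; it re-implements the new logic locally (it is not an import of the transformers function) and omits the tensor-to-list handling of original_size shown above.

import math

import torch


def unpad_image(tensor, original_size, patch_size):
    # tensor: (num_channels, num_patches_height, num_patches_width)
    original_height, original_width = original_size
    num_patches_height, num_patches_width = tensor.shape[1:]
    # Convert the patch grid back to pixel dimensions.
    current_height, current_width = num_patches_height * patch_size, num_patches_width * patch_size

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height
    if original_aspect_ratio > current_aspect_ratio:
        # Image is wider than the grid: padding sits on top and bottom.
        scale_factor = current_width / original_width
        new_height = min(math.ceil(original_height * scale_factor), current_height)
        padding, r = divmod(current_height - new_height, 2)
        num_padding_top, num_padding_bottom = padding // patch_size, (padding + r) // patch_size
        unpadded_tensor = tensor[:, num_padding_top : num_patches_height - num_padding_bottom, :]
    else:
        # Image is taller than the grid: padding sits on left and right.
        scale_factor = current_height / original_height
        new_width = min(math.ceil(original_width * scale_factor), current_width)
        padding, r = divmod(current_width - new_width, 2)
        num_padding_left, num_padding_right = padding // patch_size, (padding + r) // patch_size
        unpadded_tensor = tensor[:, :, num_padding_left : num_patches_width - num_padding_right]
    return unpadded_tensor


# A 24x24 grid of 14-pixel patches (336x336 pixels) holding a 500x1000 image:
# 84 px of vertical padding per side is exactly 6 patch rows top and bottom.
features = torch.zeros(4096, 24, 24)
print(unpad_image(features, (500, 1000), 14).shape)  # torch.Size([4096, 12, 24])

Note that the floor division errs on the side of keeping content: a patch row that is only partially padding is retained rather than dropped.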
@@ -462,7 +467,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
                     )
 
                     image_feature = reshape_image(image_feature, num_patch_height, num_patch_width, height, width)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                    image_feature = unpad_image(
+                        image_feature, image_sizes[image_idx], self.config.vision_config.patch_size
+                    )
                     if image_newline is not None:
                         image_feature = torch.cat(
                             (
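For reference, patch_size at this call site comes from the vision tower's configuration, which is what lets the model convert between the feature grid and pixel coordinates. A quick, hypothetical way to inspect that value (assuming a default-constructed config; the actual number depends on the checkpoint's vision tower):

from transformers import LlavaNextConfig

config = LlavaNextConfig()
# Edge length in pixels of one ViT patch; with the default CLIP-ViT-L/14-336
# vision config this should be 14, i.e. a 336x336 tile yields a 24x24 grid.
print(config.vision_config.patch_size)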
@@ -179,46 +179,43 @@ class LlavaNextProcessor(ProcessorMixin):
         return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
 
     def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
-        image_grid_pinpoints = self.image_processor.image_grid_pinpoints
-
-        height_best_resolution, width_best_resolution = select_best_resolution(
-            [orig_height, orig_width], image_grid_pinpoints
-        )
-        scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
-
         patches_height = height // self.patch_size
         patches_width = width // self.patch_size
-        unpadded_features, newline_features = self._get_unpadded_features(
-            orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
-        )
+        unpadded_features, newline_features = self._get_unpadded_features(orig_height, orig_width)
         # The base patch covers the entire image (+1 for the CLS)
         base_features = patches_height * patches_width + self.num_additional_image_tokens
         num_image_tokens = unpadded_features + newline_features + base_features
         return num_image_tokens
 
-    def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
+    def _get_unpadded_features(self, height, width):
         """
         Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
         because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
         patches an image is divided into and get the number of features from that.
         """
-        current_height = patches_height * scale_height
-        current_width = patches_width * scale_width
+        image_grid_pinpoints = self.image_processor.image_grid_pinpoints
+        current_height, current_width = select_best_resolution([height, width], image_grid_pinpoints)
+        num_patches_height = current_height // self.patch_size
+        num_patches_width = current_width // self.patch_size
 
         original_aspect_ratio = width / height
         current_aspect_ratio = current_width / current_height
         if original_aspect_ratio > current_aspect_ratio:
             scale_factor = current_width / width
             new_height = min(math.ceil(height * scale_factor), current_height)
-            current_height = new_height
+            padding, r = divmod(current_height - new_height, 2)
+            num_padding_patches = padding // self.patch_size + (padding + r) // self.patch_size
+            num_patches_height -= num_padding_patches
         else:
             scale_factor = current_height / height
             new_width = min(math.ceil(width * scale_factor), current_width)
-            current_width = new_width
+            padding, r = divmod(current_width - new_width, 2)
+            num_padding_patches = padding // self.patch_size + (padding + r) // self.patch_size
+            num_patches_width -= num_padding_patches
 
-        unpadded_features = current_height * current_width
-        newline_features = current_height
-        return (unpadded_features, newline_features)
+        unpadded_features = num_patches_height * num_patches_width
+        newline_features = num_patches_height
+        return unpadded_features, newline_features
 
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
     def batch_decode(self, *args, **kwargs):
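With this change _get_unpadded_features mirrors unpad_image step for step: it derives the patch grid from the best-fit pinpoint resolution and subtracts whole padding patches, instead of rescaling pixel counts. A standalone sketch of the new token accounting follows; the best resolution is passed in directly rather than computed via select_best_resolution, and patch_size=14 is an assumption for illustration.

import math


def unpadded_feature_count(height, width, best_height, best_width, patch_size=14):
    # Patch grid implied by the chosen grid-pinpoint resolution.
    num_patches_height = best_height // patch_size
    num_patches_width = best_width // patch_size

    original_aspect_ratio = width / height
    current_aspect_ratio = best_width / best_height
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = best_width / width
        new_height = min(math.ceil(height * scale_factor), best_height)
        padding, r = divmod(best_height - new_height, 2)
        num_patches_height -= padding // patch_size + (padding + r) // patch_size
    else:
        scale_factor = best_height / height
        new_width = min(math.ceil(width * scale_factor), best_width)
        padding, r = divmod(best_width - new_width, 2)
        num_patches_width -= padding // patch_size + (padding + r) // patch_size

    # One feature per surviving patch, plus one newline token per row.
    return num_patches_height * num_patches_width, num_patches_height


# Same 500x1000 image as the unpad_image sketch, fit to a 336x336 pinpoint:
# 12 of the 24 patch rows survive, giving 288 features plus 12 newline tokens.
print(unpadded_feature_count(500, 1000, 336, 336))  # (288, 12)

Counting in surviving patches rather than rescaled pixels is what keeps the processor's token count consistent with the number of features the model actually produces after unpad_image.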
@@ -300,12 +300,12 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
 
         # Test case width is padded
         pixel_values = floats_tensor([3, 400, 601])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
+        unpadded_tensor = unpad_image(pixel_values, original_size, 1)
         self.assertEqual(unpadded_tensor.shape[1:], original_size)
 
         # Test case height is padded
         pixel_values = floats_tensor([3, 503, 400])
-        unpadded_tensor = unpad_image(pixel_values, original_size)
+        unpadded_tensor = unpad_image(pixel_values, original_size, 1)
         self.assertEqual(unpadded_tensor.shape[1:], original_size)
 
     def test_compare_padding_unpadding(self):
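Passing patch_size=1 keeps these existing shape assertions meaningful: with one-pixel patches, padding // patch_size is just padding, so the patch-level slicing degenerates to the old pixel-level behavior. A quick check against the standalone unpad_image sketch from earlier (the test's original_size is defined outside this hunk; (400, 400) is assumed here for illustration):

import torch

original_size = (400, 400)  # assumed; the real value is set earlier in the test
unpadded = unpad_image(torch.zeros(3, 400, 601), original_size, 1)
print(unpadded.shape[1:])  # torch.Size([400, 400])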
@@ -331,11 +331,9 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         padded_image = image_processor._pad_for_patching(
             resized_image, best_resolution, input_data_format=input_data_format
         )
-        pad_h = padded_image.shape[1] - resized_image.shape[1]
-        pad_h, r_h = divmod(pad_h, 2)
-        pad_top, pad_bottom = pad_h, pad_h + r_h
-        num_patch_pad_top = pad_top // patch_size
-        num_patch_pad_bottom = pad_bottom // patch_size
+        padding_h, r_h = divmod(padded_image.shape[1] - resized_image.shape[1], 2)
+        num_padding_top = padding_h // patch_size
+        num_padding_bottom = (padding_h + r_h) // patch_size
 
         # prepare model config
         config = self.model_tester.get_config()
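The condensed version computes the same top/bottom patch counts as the removed five-line version; a small illustrative check with a made-up 7-pixel padding difference and a patch size of 2:

# Old path: split 7 px of padding into 3 px on top and 3 + 1 px on the bottom.
pad_h, r_h = divmod(7, 2)
pad_top, pad_bottom = pad_h, pad_h + r_h
old = (pad_top // 2, pad_bottom // 2)

# New path: same divmod, with the split folded into the floor divisions.
padding_h, r_h = divmod(7, 2)
new = (padding_h // 2, (padding_h + r_h) // 2)

assert old == new == (1, 2)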
@@ -375,11 +373,11 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
             )
             image_feature = reshape_image(image_feature, num_patch_height, num_patch_width, height, width)
             current_height = image_feature.shape[1]
-            image_feature = unpad_image(image_feature, image_sizes[image_idx])
+            image_feature = unpad_image(image_feature, image_sizes[image_idx], config.vision_config.patch_size)
             new_height = image_feature.shape[1]
 
             # verify that the padding size and unpadding size are the same
-            self.assertEqual(num_patch_pad_top + num_patch_pad_bottom, current_height - new_height)
+            self.assertEqual(num_padding_top + num_padding_bottom, current_height - new_height)
 
         @parameterized.expand(
             [
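Tying the assertion back to the earlier numbers (this reuses the standalone unpad_image sketch from above): a 24-row patch grid holding a 500x1000 original loses 6 padding patch rows per side, so the feature height drops from 24 to 12 and the removed padding equals the height difference, which is exactly the invariant the test encodes.

import torch

features = torch.zeros(4096, 24, 24)
unpadded = unpad_image(features, (500, 1000), 14)
assert 6 + 6 == features.shape[1] - unpadded.shape[1]  # padding rows == height lost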