[llava] one pixel is missing from padding when length is odd (#37819)

* [fix] one pixel should be added when length is odd

* [fix] add vision_aspect_ratio arg & fix typo

* [fix] style

* [fix] do not fix fast file directly

* [fix] convert using modular

* remove duplicate code

* match unpad logic with pad logic

* test odd-sized images for llava & aria

* test unpad odd-sized padding for llava family

* fix style

* add kwarg to onevision modular

* move vision_aspect_ratio from image_processor to processor
(llava_onevision)
Authored by youngrok cha on 2025-05-06 20:11:26 +09:00; committed by GitHub
parent 9981214d32
commit acded47fe7
17 changed files with 234 additions and 133 deletions
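
The change in a nutshell: every processor in the LLaVA family centered the resized image with (target - new) // 2 padding on each side, so whenever the difference was odd the padded result came up one pixel short of the target resolution. The fix computes the padding with divmod, gives the leftover pixel to the bottom/right edge, and updates unpad_image to trim that same extra pixel back off. A minimal standalone sketch of the arithmetic (illustration only, not code from the diff):

def old_pad_amounts(target, new):
    pad = (target - new) // 2          # floor division on both sides
    return pad, pad

def new_pad_amounts(target, new):
    pad, r = divmod(target - new, 2)   # remainder goes to the second side
    return pad, pad + r

# Target width 601, resized width 400: the old code yields (100, 100) and a
# 600-pixel-wide result; the new code yields (100, 101) and exactly 601.
assert 400 + sum(old_pad_amounts(601, 400)) == 600   # the bug
assert 400 + sum(new_pad_amounts(601, 400)) == 601   # the fix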

View File

@@ -466,7 +466,12 @@ setup(
package_data={"": ["**/*.cu", "**/*.cpp", "**/*.cuh", "**/*.h", "**/*.pyx", "py.typed"]},
zip_safe=False,
extras_require=extras,
entry_points={"console_scripts": ["transformers=transformers.commands.transformers_cli:main", "transformers-cli=transformers.commands.transformers_cli:main_cli"]},
entry_points={
"console_scripts": [
"transformers=transformers.commands.transformers_cli:main",
"transformers-cli=transformers.commands.transformers_cli:main_cli",
]
},
python_requires=">=3.9.0",
install_requires=list(install_requires),
classifiers=[

View File

@@ -18,12 +18,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
@@ -71,23 +70,6 @@ def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> Li
return patches
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class AriaImageProcessor(BaseImageProcessor):
"""
A vision processor for the Aria model that handles image preprocessing.
@@ -375,7 +357,7 @@ class AriaImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -389,12 +371,12 @@
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image
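
Worked numbers for the hunk above, using the same shapes as the new tests: fitting a 400x400 input into a (400, 601) target, get_patch_output_size keeps the image at 400x400, so the width alone must absorb 201 pixels of padding.

paste_x, r_x = divmod(601 - 400, 2)   # (100, 1)
paste_y, r_y = divmod(400 - 400, 2)   # (0, 0)
padding = ((paste_y, paste_y + r_y), (paste_x, paste_x + r_x))
assert padding == ((0, 0), (100, 101))   # old code gave ((0, 0), (100, 100)) -> 600 wide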

View File

@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
@@ -20,7 +19,7 @@ import numpy as np
from ...activations import ACT2FN
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
@@ -461,23 +460,6 @@ class AriaProjector(nn.Module):
return out
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class AriaImageProcessor(BaseImageProcessor):
"""
A vision processor for the Aria model that handles image preprocessing.
@@ -765,7 +747,7 @@ class AriaImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -779,12 +761,12 @@
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

View File

@@ -14,12 +14,17 @@
# limitations under the License.
"""Image processor class for LLaVa-NeXT."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from ...image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_patch_output_size,
get_size_dict,
select_best_resolution,
)
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
@@ -99,23 +104,6 @@ def expand_to_square(image: np.array, background_color, input_data_format) -> np
return result
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class LlavaNextImageProcessor(BaseImageProcessor):
r"""
Constructs a LLaVa-NeXT image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
@@ -429,7 +417,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -443,12 +431,12 @@
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

View File

@@ -164,10 +164,10 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
return padded_image
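
A detail worth noting for the fast path: torchvision's pad takes a four-element padding as (left, top, right, bottom), so placing the divmod remainders in the last two slots lands the extra pixels on the right and bottom edges, mirroring the slow numpy processors. A quick sanity check, assuming torchvision's v2 functional API:

import torch
from torchvision.transforms.v2 import functional as F

image = torch.zeros(3, 400, 400)
paste_x, r_x = divmod(601 - 400, 2)   # (100, 1)
paste_y, r_y = divmod(400 - 400, 2)   # (0, 0)
padded = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
assert tuple(padded.shape) == (3, 400, 601)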

View File

@@ -139,14 +139,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor
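
The unpad side now recomputes the resized extent with the same min(math.ceil(...)) expression that get_patch_output_size uses, then trims padding rows or columns from the leading edge and padding + r from the trailing edge, making it an exact inverse of the padding above. Tracing the else branch with a 400x400 original inside a 400x601 padded tensor:

import math

current_height, current_width = 400, 601
original_height, original_width = 400, 400

scale_factor = current_height / original_height                            # 1.0
new_width = min(math.ceil(original_width * scale_factor), current_width)   # 400
padding, r = divmod(current_width - new_width, 2)                          # (100, 1)
# The slice [padding : current_width - (padding + r)] keeps columns 100..499,
# i.e. exactly the 400 original columns between the (100, 101) pads.
assert current_width - (padding + r) - padding == new_width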

View File

@@ -262,14 +262,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor

View File

@@ -14,12 +14,17 @@
# limitations under the License.
"""Image processor class for LLaVa-Onevision."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from ...image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_patch_output_size,
get_size_dict,
select_best_resolution,
)
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
@@ -99,24 +104,6 @@ def expand_to_square(image: np.array, background_color, input_data_format) -> np
return result
# Copied from transformers.models.llava_next.image_processing_llava_next._get_patch_output_size
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class LlavaOnevisionImageProcessor(BaseImageProcessor):
r"""
Constructs a LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
@@ -321,7 +308,7 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@@ -336,12 +323,12 @@
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

View File

@@ -146,10 +146,10 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
return padded_image

View File

@@ -140,14 +140,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor

View File

@@ -70,6 +70,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
Special token used to denote image location.
video_token (`str`, *optional*, defaults to `"<video>"`):
Special token used to denote video location.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features.
"""
attributes = ["image_processor", "tokenizer", "video_processor"]
@@ -79,6 +81,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
"vision_feature_select_strategy",
"image_token",
"video_token",
"vision_aspect_ratio",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
@@ -94,6 +97,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
chat_template=None,
image_token="<image>",
video_token="<video>",
vision_aspect_ratio="anyres_max_9",
**kwargs,
):
self.num_image_tokens = num_image_tokens
@@ -110,6 +114,7 @@
if getattr(tokenizer, "video_token_id", None)
else tokenizer.convert_tokens_to_ids(self.video_token)
)
self.vision_aspect_ratio = vision_aspect_ratio
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
def __call__(
@@ -264,7 +269,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * patches_height**2))
max_num_patches = int(self.vision_aspect_ratio.strip("anyres_max_"))
ratio = math.sqrt(current_height * current_width / (max_num_patches * patches_height**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(current_width // ratio)
newline_features = int(current_height // ratio)
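
A side note on the parsing: str.strip removes characters drawn from the given set rather than a literal prefix, which is safe here only because no digit occurs in "anyres_max_":

assert "anyres_max_9".strip("anyres_max_") == "9"
assert "anyres_max_12".strip("anyres_max_") == "12"   # multi-digit values parse too
max_num_patches = int("anyres_max_9".strip("anyres_max_"))   # 9, no longer hard-coded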

View File

@@ -17,7 +17,7 @@ import unittest
import numpy as np
from transformers.image_utils import PILImageResampling
from transformers.image_utils import ChannelDimension, PILImageResampling
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
@@ -264,3 +264,41 @@ class AriaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)
def test_pad_for_patching(self):
for image_processing_class in self.image_processor_list:
if image_processing_class == self.fast_image_processing_class:
numpify = False
torchify = True
input_data_format = image_processing_class.data_format
else:
numpify = True
torchify = False
input_data_format = ChannelDimension.LAST
image_processing = image_processing_class(**self.image_processor_dict)
# Create odd-sized images
image_input = self.image_processor_tester.prepare_image_inputs(
batch_size=1,
max_resolution=400,
num_images=1,
equal_resolution=True,
numpify=numpify,
torchify=torchify,
)[0][0]
self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
# Test odd-width
image_shape = (400, 601)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)
# Test odd-height
image_shape = (503, 400)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)

View File

@@ -16,7 +16,7 @@ import unittest
import numpy as np
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension
from transformers.models.llava_next.image_processing_llava_next import select_best_resolution
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -230,3 +230,38 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Image processor should return same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())
def test_pad_for_patching(self):
for image_processing_class in self.image_processor_list:
if image_processing_class == self.fast_image_processing_class:
numpify = False
torchify = True
input_data_format = image_processing_class.data_format
else:
numpify = True
torchify = False
input_data_format = ChannelDimension.LAST
image_processing = image_processing_class(**self.image_processor_dict)
# Create odd-sized images
image_input = self.image_processor_tester.prepare_image_inputs(
equal_resolution=True,
numpify=numpify,
torchify=torchify,
)[0]
self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
# Test odd-width
image_shape = (400, 601)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)
# Test odd-height
image_shape = (503, 400)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)

View File

@@ -48,7 +48,7 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image
if is_vision_available():
@@ -288,6 +288,19 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
def test_unpad_image(self):
original_size = (400, 400)
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
@parameterized.expand(
[
(-1,),

View File

@@ -47,6 +47,8 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_next_video.modeling_llava_next_video import unpad_image
if is_vision_available():
from PIL import Image
@@ -302,6 +304,19 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
def test_unpad_image(self):
original_size = (400, 400)
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
@parameterized.expand(
[
(-1,),

View File

@@ -16,7 +16,7 @@ import unittest
import numpy as np
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@@ -305,3 +305,38 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
) # FIXME yoni
def test_can_compile_fast_image_processor(self):
pass
def test_pad_for_patching(self):
for image_processing_class in self.image_processor_list:
if image_processing_class == self.fast_image_processing_class:
numpify = False
torchify = True
input_data_format = image_processing_class.data_format
else:
numpify = True
torchify = False
input_data_format = ChannelDimension.LAST
image_processing = image_processing_class(**self.image_processor_dict)
# Create odd-sized images
image_input = self.image_processor_tester.prepare_image_inputs(
equal_resolution=True,
numpify=numpify,
torchify=torchify,
)[0]
self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
# Test odd-width
image_shape = (400, 601)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)
# Test odd-height
image_shape = (503, 400)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)

View File

@@ -48,6 +48,8 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_onevision.modeling_llava_onevision import unpad_image
if is_vision_available():
from PIL import Image
@@ -258,6 +260,19 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_unpad_image(self):
original_size = (400, 400)
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
@parameterized.expand(
[
(-1,),