[llava] one pixel is missing from padding when length is odd (#37819)

* [fix] one pixel should be added when length is odd

* [fix] add vision_aspect_ratio args & typo

* [fix] style

* [fix] do not fix fast file directly

* [fix] convert using modular

* remove duplicate code

* match unpad logic with pad logic

* test odd-sized images for llava & aria

* test unpad odd-sized padding for llava family

* fix style

* add kwarg to onevision modular

* move vision_aspect_ratio from image_processor to processor (llava_onevision)
youngrok cha 2025-05-06 20:11:26 +09:00 committed by GitHub
parent 9981214d32
commit acded47fe7
17 changed files with 234 additions and 133 deletions
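
The gist of the fix, before the per-file diffs: the slow and fast image processors split the centering pad as (paste, paste) on each axis, so whenever the target/source difference was odd one pixel went missing, and unpad_image made the mirror-image error when slicing the padding back off. Both sides now use divmod and push the remainder to the trailing edge. A minimal sketch of the new padding arithmetic (plain NumPy, made-up sizes; pad_to_target is not a helper from the repo):

import numpy as np

def pad_to_target(image: np.ndarray, target_height: int, target_width: int) -> np.ndarray:
    """Center-pad an (H, W, C) image, sending any odd remainder to the bottom/right."""
    height, width = image.shape[:2]
    paste_y, r_y = divmod(target_height - height, 2)
    paste_x, r_x = divmod(target_width - width, 2)
    # The old code padded ((paste_y, paste_y), (paste_x, paste_x)), which comes up
    # one pixel short whenever target - source is odd.
    return np.pad(image, ((paste_y, paste_y + r_y), (paste_x, paste_x + r_x), (0, 0)))

image = np.zeros((400, 400, 3))
print(pad_to_target(image, 503, 601).shape)  # (503, 601, 3), exactly the target size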

View File

@ -466,7 +466,12 @@ setup(
package_data={"": ["**/*.cu", "**/*.cpp", "**/*.cuh", "**/*.h", "**/*.pyx", "py.typed"]},
zip_safe=False,
extras_require=extras,
entry_points={"console_scripts": ["transformers=transformers.commands.transformers_cli:main", "transformers-cli=transformers.commands.transformers_cli:main_cli"]},
entry_points={
"console_scripts": [
"transformers=transformers.commands.transformers_cli:main",
"transformers-cli=transformers.commands.transformers_cli:main_cli",
]
},
python_requires=">=3.9.0",
install_requires=list(install_requires),
classifiers=[

View File

@ -18,12 +18,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
@ -71,23 +70,6 @@ def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> Li
return patches
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class AriaImageProcessor(BaseImageProcessor):
"""
A vision processor for the Aria model that handles image preprocessing.
@ -375,7 +357,7 @@ class AriaImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@ -389,12 +371,12 @@ class AriaImageProcessor(BaseImageProcessor):
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
@ -20,7 +19,7 @@ import numpy as np
from ...activations import ACT2FN
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
@ -461,23 +460,6 @@ class AriaProjector(nn.Module):
return out
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class AriaImageProcessor(BaseImageProcessor):
"""
A vision processor for the Aria model that handles image preprocessing.
@ -765,7 +747,7 @@ class AriaImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@ -779,12 +761,12 @@ class AriaImageProcessor(BaseImageProcessor):
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

View File

@ -14,12 +14,17 @@
# limitations under the License.
"""Image processor class for LLaVa-NeXT."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from ...image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_patch_output_size,
get_size_dict,
select_best_resolution,
)
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
@ -99,23 +104,6 @@ def expand_to_square(image: np.array, background_color, input_data_format) -> np
return result
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class LlavaNextImageProcessor(BaseImageProcessor):
r"""
Constructs a LLaVa-NeXT image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
@ -429,7 +417,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@ -443,12 +431,12 @@ class LlavaNextImageProcessor(BaseImageProcessor):
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

View File

@ -102,8 +102,8 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaNextFastImageProcessorKwargs]) -> BatchFeature:
@ -164,10 +164,10 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
return padded_image
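
The fast (torchvision-backed) processors express the same fix through F.pad, whose 4-element padding list is ordered [left, top, right, bottom], so the r_x / r_y remainders land on the right and bottom edges. A quick sketch with made-up sizes, assuming torchvision's functional pad:

import torch
from torchvision.transforms import functional as F

image = torch.zeros(3, 400, 400)                            # (C, H, W)
target_height, target_width = 503, 601                      # both differences are odd
paste_x, r_x = divmod(target_width - image.shape[-1], 2)    # 100, 1
paste_y, r_y = divmod(target_height - image.shape[-2], 2)   # 51, 1
padded = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
print(padded.shape)                                         # torch.Size([3, 503, 601])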

View File

@ -139,14 +139,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor
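
unpad_image now recomputes the resized extent with the same ceil/min rule the image processors use and splits the leftover with divmod, so the slice removes exactly the rows or columns that padding added. A worked example of the arithmetic with hypothetical sizes (a 400x400 original that was padded out to a width of 601):

import math

original_height, original_width = 400, 400
current_height, current_width = 400, 601    # width is the padded dimension here

scale_factor = current_height / original_height                            # 1.0
new_width = min(math.ceil(original_width * scale_factor), current_width)   # 400
padding, r = divmod(current_width - new_width, 2)                          # 100, 1

# tensor[:, :, padding : current_width - (padding + r)] keeps exactly new_width
# columns, mirroring the (paste_x, paste_x + r_x) split applied when padding.
assert current_width - (padding + r) - padding == new_width   # 601 - 101 - 100 == 400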

View File

@ -262,14 +262,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor

View File

@ -14,12 +14,17 @@
# limitations under the License.
"""Image processor class for LLaVa-Onevision."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from ...image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_patch_output_size,
get_size_dict,
select_best_resolution,
)
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
@ -99,24 +104,6 @@ def expand_to_square(image: np.array, background_color, input_data_format) -> np
return result
# Copied from transformers.models.llava_next.image_processing_llava_next._get_patch_output_size
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class LlavaOnevisionImageProcessor(BaseImageProcessor):
r"""
Constructs a LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
@ -151,8 +138,8 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
@ -321,7 +308,7 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@ -336,12 +323,12 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

View File

@ -84,8 +84,8 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
@ -146,10 +146,10 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
return padded_image

View File

@ -140,14 +140,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor

View File

@ -70,6 +70,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
Special token used to denote image location.
video_token (`str`, *optional*, defaults to `"<video>"`):
Special token used to denote video location.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
"""
attributes = ["image_processor", "tokenizer", "video_processor"]
@ -79,6 +81,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
"vision_feature_select_strategy",
"image_token",
"video_token",
"vision_aspect_ratio",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
@ -94,6 +97,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
chat_template=None,
image_token="<image>",
video_token="<video>",
vision_aspect_ratio="anyres_max_9",
**kwargs,
):
self.num_image_tokens = num_image_tokens
@ -110,6 +114,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
if getattr(tokenizer, "video_token_id", None)
else tokenizer.convert_tokens_to_ids(self.video_token)
)
self.vision_aspect_ratio = vision_aspect_ratio
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
def __call__(
@ -264,7 +269,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * patches_height**2))
max_num_patches = int(self.vision_aspect_ratio.strip("anyres_max_"))
ratio = math.sqrt(current_height * current_width / (max_num_patches * patches_height**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(current_width // ratio)
newline_features = int(current_height // ratio)
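
With vision_aspect_ratio living on the processor, the maximum patch budget is read from the string instead of being hard-coded to 9. Note that str.strip removes a set of characters from both ends, which works here only because the numeric suffix shares no characters with "anyres_max_". A rough sketch of the budget computation (grid sizes are made up for illustration):

import math

vision_aspect_ratio = "anyres_max_9"
max_num_patches = int(vision_aspect_ratio.strip("anyres_max_"))   # 9

patches_height = 24                        # patches per side of one vision tile
current_height, current_width = 80, 120    # hypothetical unpadded feature grid

ratio = math.sqrt(current_height * current_width / (max_num_patches * patches_height**2))
if ratio > 1.1:
    # Shrink the grid so the feature count stays within max_num_patches tiles.
    current_height = int(current_height // ratio)   # 58
    current_width = int(current_width // ratio)     # 88

print(current_height * current_width)   # 5104, under the 9 * 24**2 = 5184 budget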

View File

@ -17,7 +17,7 @@ import unittest
import numpy as np
from transformers.image_utils import PILImageResampling
from transformers.image_utils import ChannelDimension, PILImageResampling
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
@ -264,3 +264,41 @@ class AriaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)
def test_pad_for_patching(self):
for image_processing_class in self.image_processor_list:
if image_processing_class == self.fast_image_processing_class:
numpify = False
torchify = True
input_data_format = image_processing_class.data_format
else:
numpify = True
torchify = False
input_data_format = ChannelDimension.LAST
image_processing = image_processing_class(**self.image_processor_dict)
# Create odd-sized images
image_input = self.image_processor_tester.prepare_image_inputs(
batch_size=1,
max_resolution=400,
num_images=1,
equal_resolution=True,
numpify=numpify,
torchify=torchify,
)[0][0]
self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
# Test odd-width
image_shape = (400, 601)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)
# Test odd-height
image_shape = (503, 400)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)

View File

@ -16,7 +16,7 @@ import unittest
import numpy as np
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension
from transformers.models.llava_next.image_processing_llava_next import select_best_resolution
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@ -230,3 +230,38 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Image processor should return same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())
def test_pad_for_patching(self):
for image_processing_class in self.image_processor_list:
if image_processing_class == self.fast_image_processing_class:
numpify = False
torchify = True
input_data_format = image_processing_class.data_format
else:
numpify = True
torchify = False
input_data_format = ChannelDimension.LAST
image_processing = image_processing_class(**self.image_processor_dict)
# Create odd-sized images
image_input = self.image_processor_tester.prepare_image_inputs(
equal_resolution=True,
numpify=numpify,
torchify=torchify,
)[0]
self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
# Test odd-width
image_shape = (400, 601)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)
# Test odd-height
image_shape = (503, 400)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)

View File

@ -48,7 +48,7 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image
if is_vision_available():
@ -288,6 +288,19 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
def test_unpad_image(self):
original_size = (400, 400)
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
@parameterized.expand(
[
(-1,),

View File

@ -47,6 +47,8 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_next_video.modeling_llava_next_video import unpad_image
if is_vision_available():
from PIL import Image
@ -302,6 +304,19 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
def test_unpad_image(self):
original_size = (400, 400)
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
@parameterized.expand(
[
(-1,),

View File

@ -16,7 +16,7 @@ import unittest
import numpy as np
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@ -305,3 +305,38 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
) # FIXME yoni
def test_can_compile_fast_image_processor(self):
pass
def test_pad_for_patching(self):
for image_processing_class in self.image_processor_list:
if image_processing_class == self.fast_image_processing_class:
numpify = False
torchify = True
input_data_format = image_processing_class.data_format
else:
numpify = True
torchify = False
input_data_format = ChannelDimension.LAST
image_processing = image_processing_class(**self.image_processor_dict)
# Create odd-sized images
image_input = self.image_processor_tester.prepare_image_inputs(
equal_resolution=True,
numpify=numpify,
torchify=torchify,
)[0]
self.assertIn(image_input.shape, [(3, 400, 400), (400, 400, 3)])
# Test odd-width
image_shape = (400, 601)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)
# Test odd-height
image_shape = (503, 400)
encoded_images = image_processing._pad_for_patching(image_input, image_shape, input_data_format)
encoded_image_shape = (
encoded_images.shape[:-1] if input_data_format == ChannelDimension.LAST else encoded_images.shape[1:]
)
self.assertEqual(encoded_image_shape, image_shape)

View File

@ -48,6 +48,8 @@ from ...test_modeling_common import (
if is_torch_available():
import torch
from transformers.models.llava_onevision.modeling_llava_onevision import unpad_image
if is_vision_available():
from PIL import Image
@ -258,6 +260,19 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_unpad_image(self):
original_size = (400, 400)
# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
@parameterized.expand(
[
(-1,),