Fix bugs in mllama image processing (#36156)

* fix: handle input_channel_dim == channels_last

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

* fix: default PIL images to channels_last

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

* Apply suggestions from code review

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* fixup from review batch

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

* test: add 1x1 PIL image to ambiguous channel test

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

* fix(mllama): avoid 0 dimension for image with impractical aspect ratio

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

---------

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Travis Johnson 2025-03-11 03:22:48 -06:00, committed by GitHub
commit d8663cb8c5 (parent 1c4b62b219)
2 changed files with 44 additions and 7 deletions

src/transformers/models/mllama/image_processing_mllama.py

@@ -93,7 +93,7 @@ def get_image_size_fit_to_canvas(
     canvas_height and canvas_width, while ensuring that the image dimensions are not smaller than
     tile_size. If the image is larger than the canvas, the returned size will fit within the canvas.
     If the image already fits within the canvas, the size remains unchanged.
-    The aspect ratio of the original image is preserved.
+    The aspect ratio of the original image is preserved as much as possible.
 
     Args:
         image_height (`int`):
@@ -120,10 +120,12 @@ def get_image_size_fit_to_canvas(
     if scale_w < scale_h:
         new_width = target_width
-        new_height = min(math.floor(image_height * scale_w), target_height)
+        # minimum height is 1 to avoid invalid height of 0
+        new_height = min(math.floor(image_height * scale_w) or 1, target_height)
     else:
         new_height = target_height
-        new_width = min(math.floor(image_width * scale_h), target_width)
+        # minimum width is 1 to avoid invalid width of 0
+        new_width = min(math.floor(image_width * scale_h) or 1, target_width)
 
     return new_height, new_width
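For intuition, here is a standalone sketch of the zero-dimension case this hunk guards against. The sizes are assumed for illustration (the real function derives its targets from tile_size and the chosen canvas), matching the 9999999x1 image used in the new test below:

import math

# assumed sizes: an extremely wide 9999999x1 image on a hypothetical 1120x1120 canvas
image_height, image_width = 1, 9999999
target_height, target_width = 1120, 1120

scale_h = target_height / image_height  # 1120.0
scale_w = target_width / image_width    # ~0.000112

# scale_w < scale_h, so the width is pinned and the height is rescaled
old_height = min(math.floor(image_height * scale_w), target_height)       # floors to 0 -> invalid
new_height = min(math.floor(image_height * scale_w) or 1, target_height)  # clamped to 1

print(old_height, new_height)  # 0 1

The `or 1` idiom works here because `math.floor` returns the falsy integer 0 exactly in the degenerate case, so only a zero result is replaced.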
@@ -695,8 +697,6 @@ class MllamaImageProcessor(BaseImageProcessor):
         if self.do_convert_rgb:
             images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
 
-        images_list = [[to_numpy_array(image) for image in images] for images in images_list]
-
         batch_images = []
         batch_aspect_ratios = []
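Note that the batch-level to_numpy_array conversion is removed here rather than fixed: the next hunk moves it inside the per-image loop, after the check that needs to see the original PIL.Image.Image objects.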
@@ -707,6 +707,13 @@ class MllamaImageProcessor(BaseImageProcessor):
             # iterate over images in a batch sample
             for image in images:
+                # default PIL images to channels_last
+                if input_data_format is None and isinstance(image, PIL.Image.Image):
+                    input_data_format = ChannelDimension.LAST
+
+                # convert to numpy array for processing
+                image = to_numpy_array(image)
+
                 # convert images to channels first format for faster processing
                 # LAST is slower for `pad` and not supported by `split_to_tiles`
                 data_format = ChannelDimension.FIRST
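A minimal sketch of why this default is needed, independent of the processor internals: for degenerate sizes a numpy array alone cannot tell you which axis holds the channels, but an image coming from PIL is always height x width x channels.

import numpy as np
from PIL import Image

# a 1x1 RGB image: the (1, 1, 3) array is ambiguous, since both the first
# and the last axis are plausible channel axes
print(np.asarray(Image.new("RGB", (1, 1))).shape)  # (1, 1, 3)

# a 100x1 RGB image: the leading 1 in (1, 100, 3) can be misread as a
# single-channel channels-first layout
print(np.asarray(Image.new("RGB", (100, 1))).shape)  # (1, 100, 3)

# PIL always yields height x width x channels, so pinning
# input_data_format to ChannelDimension.LAST before to_numpy_array is safe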
@@ -735,7 +742,7 @@ class MllamaImageProcessor(BaseImageProcessor):
                 image = self.rescale(
                     image=image,
                     scale=rescale_factor,
-                    input_data_format=input_data_format,
+                    input_data_format=data_format,
                     data_format=data_format,
                 )
@@ -744,7 +751,7 @@ class MllamaImageProcessor(BaseImageProcessor):
                 image = self.normalize(
                     image=image,
                     mean=image_mean,
                     std=image_std,
-                    input_data_format=input_data_format,
+                    input_data_format=data_format,
                     data_format=data_format,
                 )
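The input_data_format=data_format change in these two hunks matters because the working array was already converted to channels-first above; describing it with the caller's original layout would broadcast per-channel statistics along the wrong axis. A numpy sketch with assumed shapes and values:

import numpy as np

# assumed working copy after conversion: channels-first, and deliberately
# (3, 3, 3) so that the wrong broadcast is silent instead of an error
image = np.arange(27, dtype=np.float32).reshape(3, 3, 3)
mean = np.array([0.5, 0.4, 0.3], dtype=np.float32)

# layout declared correctly (channels first): stats broadcast along axis 0
correct = image - mean[:, None, None]

# layout misdeclared as channels_last: the same stats broadcast along the
# trailing axis, subtracting the wrong per-channel mean almost everywhere
wrong = image - mean

print(np.allclose(correct, wrong))  # False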

tests/models/mllama/test_image_processing_mllama.py

@@ -224,6 +224,36 @@ class MllamaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
             tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
         )
 
+    def test_call_channels_last(self):
+        # Initialize image_processing
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+
+        # a white 1x1 pixel RGB image
+        image_inputs = [[np.full(shape=(1, 1, 3), fill_value=1.0, dtype=float)]]
+        encoded_images = image_processing(
+            image_inputs, return_tensors="pt", input_data_format="channels_last"
+        ).pixel_values
+        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+    def test_ambiguous_channel_pil_image(self):
+        # Initialize image_processing
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+
+        image_inputs = [[Image.new("RGB", (1, 1))], [Image.new("RGB", (100, 1))]]
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+        self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
+
+    def test_resize_impractical_aspect_ratio(self):
+        # Initialize image_processing
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+
+        # Ensure that no error is raised even if the aspect ratio is impractical
+        image_inputs = [[Image.new("RGB", (9999999, 1))]]
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
     def test_call_pytorch(self):
         # Initialize image_processing
         image_processing = self.image_processing_class(**self.image_processor_dict)
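The three new tests line up with the three fixes in order: explicit channels_last numpy input, the PIL channels_last default for the ambiguous 1x1 and 100x1 images, and the zero-dimension clamp for the 9999999x1 aspect ratio.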