Fix bugs in mllama image processing (#36156)
* fix: handle input_channel_dim == channels_last
* fix: default PIL images to channels_last
* Apply suggestions from code review
* fixup from review batch
* test: add 1x1 PIL image to ambiguous channel test
* fix(mllama): avoid 0 dimension for image with impractical aspect ratio

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
parent 1c4b62b219
commit d8663cb8c5
@@ -93,7 +93,7 @@ def get_image_size_fit_to_canvas(
     canvas_height and canvas_width, while ensuring that the image dimensions are not smaller than
     tile_size. If the image is larger than the canvas, the returned size will fit within the canvas.
     If the image already fits within the canvas, the size remains unchanged.
-    The aspect ratio of the original image is preserved.
+    The aspect ratio of the original image is preserved as much as possible.
 
     Args:
         image_height (`int`):
@@ -120,10 +120,12 @@ def get_image_size_fit_to_canvas(
 
     if scale_w < scale_h:
         new_width = target_width
-        new_height = min(math.floor(image_height * scale_w), target_height)
+        # minimum height is 1 to avoid invalid height of 0
+        new_height = min(math.floor(image_height * scale_w) or 1, target_height)
     else:
         new_height = target_height
-        new_width = min(math.floor(image_width * scale_h), target_width)
+        # minimum width is 1 to avoid invalid width of 0
+        new_width = min(math.floor(image_width * scale_h) or 1, target_width)
 
     return new_height, new_width
 
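The `or 1` guard above is what the new test for impractical aspect ratios exercises. A minimal sketch of the failure it prevents (the helper and the 560x560 canvas below are illustrative, not the processor's actual code or defaults): when one side of the image is extremely long relative to the other, flooring the scaled short side can yield 0, which is an invalid resize target.

    import math

    def fit_to_canvas(image_height, image_width, canvas_height, canvas_width):
        # Simplified version of the scaling logic, for illustration only.
        scale_h = canvas_height / image_height
        scale_w = canvas_width / image_width
        if scale_w < scale_h:
            new_width = canvas_width
            # Without `or 1`, math.floor(...) can return 0 for very wide, very short images.
            new_height = min(math.floor(image_height * scale_w) or 1, canvas_height)
        else:
            new_height = canvas_height
            # Without `or 1`, math.floor(...) can return 0 for very tall, very narrow images.
            new_width = min(math.floor(image_width * scale_h) or 1, canvas_width)
        return new_height, new_width

    # A 9999999x1 image scaled onto an illustrative 560x560 canvas:
    # 1 * (560 / 9999999) floors to 0, so the guard clamps the height to 1.
    print(fit_to_canvas(1, 9999999, 560, 560))  # (1, 560)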
@@ -695,8 +697,6 @@ class MllamaImageProcessor(BaseImageProcessor):
         if self.do_convert_rgb:
             images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
 
-        images_list = [[to_numpy_array(image) for image in images] for images in images_list]
-
         batch_images = []
         batch_aspect_ratios = []
 
@@ -707,6 +707,13 @@ class MllamaImageProcessor(BaseImageProcessor):
 
             # iterate over images in a batch sample
             for image in images:
+                # default PIL images to channels_last
+                if input_data_format is None and isinstance(image, PIL.Image.Image):
+                    input_data_format = ChannelDimension.LAST
+
+                # convert to numpy array for processing
+                image = to_numpy_array(image)
+
                 # convert images to channels first format for faster processing
                 # LAST is slower for `pad` and not supported by `split_to_tiles`
                 data_format = ChannelDimension.FIRST
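For context, a small standalone illustration of the ambiguity this default resolves (plain PIL/NumPy, not the processor's internals): for degenerate sizes such as 1x1 or Nx1, the array shape alone cannot distinguish channels-first from channels-last, but a PIL image always decodes to height x width x channels, so channels_last is the safe assumption before converting it to NumPy.

    import numpy as np
    from PIL import Image

    # A white 100x1 RGB image decodes to shape (height, width, channels) = (1, 100, 3).
    pil_image = Image.new("RGB", (100, 1), color="white")
    array = np.asarray(pil_image)
    print(array.shape)  # (1, 100, 3)

    # The shape alone is ambiguous: (1, 100, 3) could be a 1-channel 100x3 image
    # (channels-first) or a 1x100 RGB image (channels-last). Because the source was
    # a PIL image, channels-last is known to be correct, which is why the processor
    # now sets input_data_format = ChannelDimension.LAST for PIL inputs when the
    # caller did not specify it.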
@@ -735,7 +742,7 @@ class MllamaImageProcessor(BaseImageProcessor):
                 image = self.rescale(
                     image=image,
                     scale=rescale_factor,
-                    input_data_format=input_data_format,
+                    input_data_format=data_format,
                     data_format=data_format,
                 )
 

@@ -744,7 +751,7 @@ class MllamaImageProcessor(BaseImageProcessor):
                     image=image,
                     mean=image_mean,
                     std=image_std,
-                    input_data_format=input_data_format,
+                    input_data_format=data_format,
                     data_format=data_format,
                 )
 
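The change in these two hunks follows from the earlier one: by the time rescale and normalize run, the image has already been converted to channels-first (data_format), so input_data_format must describe that current layout rather than the layout the image arrived in. A hedged NumPy sketch of the distinction (illustrative shapes and values, not the processor's code):

    import numpy as np

    hwc = np.ones((2, 4, 3))            # image as it arrived: channels_last (H, W, C)
    chw = np.moveaxis(hwc, -1, 0)       # already converted to channels_first (C, H, W)

    mean = np.array([0.5, 0.5, 0.5])
    # Correct: the array is channels_first now, so the channel axis is axis 0.
    normalized = chw - mean[:, None, None]
    # Describing `chw` as channels_last instead would point per-channel ops at the
    # wrong axis (width, of size 4), which is the bug this fix avoids.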
@@ -224,6 +224,36 @@ class MllamaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
             tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
         )
 
+    def test_call_channels_last(self):
+        # Initialize image_processing
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+
+        # a white 1x1 pixel RGB image
+        image_inputs = [[np.full(shape=(1, 1, 3), fill_value=1.0, dtype=float)]]
+        encoded_images = image_processing(
+            image_inputs, return_tensors="pt", input_data_format="channels_last"
+        ).pixel_values
+        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
+    def test_ambiguous_channel_pil_image(self):
+        # Initialize image_processing
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+
+        image_inputs = [[Image.new("RGB", (1, 1))], [Image.new("RGB", (100, 1))]]
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+        self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape))
+
+    def test_resize_impractical_aspect_ratio(self):
+        # Initialize image_processing
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+        # Ensure that no error is raised even if the aspect ratio is impractical
+        image_inputs = [[Image.new("RGB", (9999999, 1))]]
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
+
     def test_call_pytorch(self):
         # Initialize image_processing
         image_processing = self.image_processing_class(**self.image_processor_dict)