diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index 369ddc8d4c0..0893792e2da 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -742,3 +742,32 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
 
     image = image.convert("RGB")
     return image
+
+
+def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
+    """
+    Flips the channel order of the image.
+
+    If the image is in RGB format, it will be converted to BGR and vice versa.
+
+    Args:
+        image (`np.ndarray`):
+            The image to flip.
+        data_format (`ChannelDimension`, *optional*):
+            The channel dimension format for the output image. Can be one of:
+                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            If unset, will use same as the input image.
+    """
+
+    input_data_format = infer_channel_dimension_format(image)
+    if input_data_format == ChannelDimension.LAST:
+        image = image[..., ::-1]
+    elif input_data_format == ChannelDimension.FIRST:
+        image = image[::-1, ...]
+    else:
+        raise ValueError(f"Unsupported channel dimension: {input_data_format}")
+
+    if data_format is not None:
+        image = to_channel_dimension_format(image, data_format)
+    return image
diff --git a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
index 6ad0968b612..c396b7e4032 100644
--- a/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/image_processing_layoutlmv2.py
@@ -19,12 +19,11 @@ from typing import Dict, Optional, Union
 import numpy as np
 
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
-from ...image_transforms import resize, to_channel_dimension_format, to_pil_image
+from ...image_transforms import flip_channel_order, resize, to_channel_dimension_format, to_pil_image
 from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -85,20 +84,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op
     return words, normalized_boxes
 
 
-def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
-    input_data_format = infer_channel_dimension_format(image)
-    if input_data_format == ChannelDimension.LAST:
-        image = image[..., ::-1]
-    elif input_data_format == ChannelDimension.FIRST:
-        image = image[:, ::-1, ...]
-    else:
-        raise ValueError(f"Unsupported channel dimension: {input_data_format}")
-
-    if data_format is not None:
-        image = to_channel_dimension_format(image, data_format)
-    return image
-
-
 class LayoutLMv2ImageProcessor(BaseImageProcessor):
     r"""
     Constructs a LayoutLMv2 image processor.
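(Not part of the diff.) A minimal usage sketch of the relocated helper, assuming a checkout with the hunks above applied, so that `flip_channel_order` is importable from `transformers.image_transforms` and `ChannelDimension` from `transformers.image_utils`:

import numpy as np

from transformers.image_transforms import flip_channel_order
from transformers.image_utils import ChannelDimension

# A 2x2 "RGB" image in channels-last (height, width, num_channels) layout.
rgb = np.arange(12).reshape(2, 2, 3)

# Reverse the channel axis: RGB becomes BGR, layout and shape are unchanged.
bgr = flip_channel_order(rgb)
assert bgr.shape == (2, 2, 3)
assert np.array_equal(bgr[..., 0], rgb[..., 2])

# Flip and convert the output to channels-first in one call via data_format.
bgr_first = flip_channel_order(rgb, data_format=ChannelDimension.FIRST)
assert bgr_first.shape == (3, 2, 2)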
diff --git a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py
index 7152abf06c4..5b67bd94dba 100644
--- a/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py
+++ b/src/transformers/models/layoutlmv3/image_processing_layoutlmv3.py
@@ -26,7 +26,6 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -86,20 +85,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op
     return words, normalized_boxes
 
 
-def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
-    input_data_format = infer_channel_dimension_format(image)
-    if input_data_format == ChannelDimension.LAST:
-        image = image[..., ::-1]
-    elif input_data_format == ChannelDimension.FIRST:
-        image = image[:, ::-1, ...]
-    else:
-        raise ValueError(f"Unsupported channel dimension: {input_data_format}")
-
-    if data_format is not None:
-        image = to_channel_dimension_format(image, data_format)
-    return image
-
-
 class LayoutLMv3ImageProcessor(BaseImageProcessor):
     r"""
     Constructs a LayoutLMv3 image processor.
@@ -356,7 +341,6 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
         if do_normalize:
             images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
 
-        # flip color channels from RGB to BGR (as Detectron2 requires this)
         images = [to_channel_dimension_format(image, data_format) for image in images]
 
         data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py
index e121c2ae7ba..fe0cc6ab22e 100644
--- a/src/transformers/models/mobilevit/image_processing_mobilevit.py
+++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py
@@ -19,12 +19,18 @@ from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
-from ...image_transforms import center_crop, get_resize_output_image_size, rescale, resize, to_channel_dimension_format
+from ...image_transforms import (
+    center_crop,
+    flip_channel_order,
+    get_resize_output_image_size,
+    rescale,
+    resize,
+    to_channel_dimension_format,
+)
 from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
-    infer_channel_dimension_format,
     make_list_of_images,
     to_numpy_array,
     valid_images,
@@ -42,34 +48,6 @@ if is_torch_available():
 logger = logging.get_logger(__name__)
 
 
-def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension]) -> np.ndarray:
-    """
-    Flip the color channels from RGB to BGR or vice versa.
-
-    Args:
-        image (`np.ndarray`):
-            The image, represented as a numpy array.
-        data_format (`ChannelDimension`, *`optional`*):
-            The channel dimension format of the image. If not provided, it will be the same as the input image.
-
-    Returns:
-        `np.ndarray`: The image with the flipped color channels.
-    """
-    input_data_format = infer_channel_dimension_format(image)
-
-    if input_data_format == ChannelDimension.LAST:
-        image = image[..., ::-1]
-    elif input_data_format == ChannelDimension.FIRST:
-        image = image[:, ::-1, ...]
-    else:
-        raise ValueError(f"Invalid input channel dimension format: {input_data_format}")
-
-    if data_format is not None:
-        image = to_channel_dimension_format(image, data_format)
-
-    return image
-
-
 class MobileViTImageProcessor(BaseImageProcessor):
     r"""
     Constructs a MobileViT image processor.
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
index cb1524ac12c..70db390394c 100644
--- a/tests/test_image_transforms.py
+++ b/tests/test_image_transforms.py
@@ -39,6 +39,7 @@ if is_vision_available():
         center_to_corners_format,
         convert_to_rgb,
         corners_to_center_format,
+        flip_channel_order,
         get_resize_output_image_size,
         id_to_rgb,
         normalize,
@@ -520,3 +521,41 @@ class ImageTransformsTester(unittest.TestCase):
         self.assertEqual(rgb_image.mode, "RGB")
         self.assertEqual(rgb_image.size, (2, 1))
         self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))
+
+    def test_flip_channel_order(self):
+        # fmt: off
+        img_channels_first = np.array([
+            [[ 0,  1,  2,  3],
+             [ 4,  5,  6,  7]],
+
+            [[ 8,  9, 10, 11],
+             [12, 13, 14, 15]],
+
+            [[16, 17, 18, 19],
+             [20, 21, 22, 23]],
+        ])
+        # fmt: on
+        img_channels_last = np.moveaxis(img_channels_first, 0, -1)
+        # fmt: off
+        flipped_img_channels_first = np.array([
+            [[16, 17, 18, 19],
+             [20, 21, 22, 23]],
+
+            [[ 8,  9, 10, 11],
+             [12, 13, 14, 15]],
+
+            [[ 0,  1,  2,  3],
+             [ 4,  5,  6,  7]],
+        ])
+        # fmt: on
+        flipped_img_channels_last = np.moveaxis(flipped_img_channels_first, 0, -1)
+
+        self.assertTrue(np.allclose(flip_channel_order(img_channels_first), flipped_img_channels_first))
+        self.assertTrue(
+            np.allclose(flip_channel_order(img_channels_first, "channels_last"), flipped_img_channels_last)
+        )
+
+        self.assertTrue(np.allclose(flip_channel_order(img_channels_last), flipped_img_channels_last))
+        self.assertTrue(
+            np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
+        )
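(Also not part of the diff.) Assuming a development install with pytest and the vision extras available, the new test can be run in isolation with:

python -m pytest tests/test_image_transforms.py -k test_flip_channel_order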