Bug fix - flip_channel_order for channels first images (#23701)

Bug fix - flip_channel_order for channels_first
This commit is contained in:
amyeroberts 2023-05-31 17:12:27 +01:00 committed by GitHub
parent 0b3d092f63
commit c608b8fc93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 77 additions and 62 deletions

View File

@ -742,3 +742,32 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
image = image.convert("RGB")
return image
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
"""
Flips the channel order of the image.
If the image is in RGB format, it will be converted to BGR and vice versa.
Args:
image (`np.ndarray`):
The image to flip.
data_format (`ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
"""
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[::-1, ...]
else:
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image

View File

@ -19,12 +19,11 @@ from typing import Dict, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format, to_pil_image
from ...image_transforms import flip_channel_order, resize, to_channel_dimension_format, to_pil_image
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -85,20 +84,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op
return words, normalized_boxes
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[:, ::-1, ...]
else:
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
class LayoutLMv2ImageProcessor(BaseImageProcessor):
r"""
Constructs a LayoutLMv2 image processor.

View File

@ -26,7 +26,6 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -86,20 +85,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op
return words, normalized_boxes
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[:, ::-1, ...]
else:
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
class LayoutLMv3ImageProcessor(BaseImageProcessor):
r"""
Constructs a LayoutLMv3 image processor.
@ -356,7 +341,6 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
# flip color channels from RGB to BGR (as Detectron2 requires this)
images = [to_channel_dimension_format(image, data_format) for image in images]
data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

View File

@ -19,12 +19,18 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, get_resize_output_image_size, rescale, resize, to_channel_dimension_format
from ...image_transforms import (
center_crop,
flip_channel_order,
get_resize_output_image_size,
rescale,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -42,34 +48,6 @@ if is_torch_available():
logger = logging.get_logger(__name__)
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension]) -> np.ndarray:
"""
Flip the color channels from RGB to BGR or vice versa.
Args:
image (`np.ndarray`):
The image, represented as a numpy array.
data_format (`ChannelDimension`, *`optional`*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
Returns:
`np.ndarray`: The image with the flipped color channels.
"""
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
image = image[:, ::-1, ...]
else:
raise ValueError(f"Invalid input channel dimension format: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
class MobileViTImageProcessor(BaseImageProcessor):
r"""
Constructs a MobileViT image processor.

View File

@ -39,6 +39,7 @@ if is_vision_available():
center_to_corners_format,
convert_to_rgb,
corners_to_center_format,
flip_channel_order,
get_resize_output_image_size,
id_to_rgb,
normalize,
@ -520,3 +521,41 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(rgb_image.mode, "RGB")
self.assertEqual(rgb_image.size, (2, 1))
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))
def test_flip_channel_order(self):
# fmt: off
img_channels_first = np.array([
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]],
])
# fmt: on
img_channels_last = np.moveaxis(img_channels_first, 0, -1)
# fmt: off
flipped_img_channels_first = np.array([
[[16, 17, 18, 19],
[20, 21, 22, 23]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
])
# fmt: on
flipped_img_channels_last = np.moveaxis(flipped_img_channels_first, 0, -1)
self.assertTrue(np.allclose(flip_channel_order(img_channels_first), flipped_img_channels_first))
self.assertTrue(
np.allclose(flip_channel_order(img_channels_first, "channels_last"), flipped_img_channels_last)
)
self.assertTrue(np.allclose(flip_channel_order(img_channels_last), flipped_img_channels_last))
self.assertTrue(
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
)