Mirror of https://github.com/huggingface/transformers.git
Bug fix - flip_channel_order for channels first images (#23701)

commit c608b8fc93 (parent 0b3d092f63)
@@ -742,3 +742,32 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:

```python
    image = image.convert("RGB")
    return image


def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    """
    Flips the channel order of the image.

    If the image is in RGB format, it will be converted to BGR and vice versa.

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
    """

    input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        image = image[::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)
    return image
```
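For context, and not part of the diff itself: a minimal standalone numpy sketch contrasting the two indexing expressions this commit is about. The shared helper above reverses the leading axis of a channels-first array (`image[::-1, ...]`), which is what actually swaps RGB and BGR; the per-model copies removed further down used `image[:, ::-1, ...]`, which reverses the height axis and leaves the channel order untouched.

```python
import numpy as np

# Toy (num_channels, height, width) image: channel c holds the constant value c.
chw = np.stack([np.full((2, 4), c) for c in range(3)])  # shape (3, 2, 4)

fixed = chw[::-1, ...]     # new behaviour: channel axis reversed, RGB -> BGR
buggy = chw[:, ::-1, ...]  # old per-model behaviour: height axis reversed instead

print(fixed[:, 0, 0])  # [2 1 0] -> channels flipped
print(buggy[:, 0, 0])  # [0 1 2] -> channels unchanged (the image was flipped vertically)
```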
@@ -19,12 +19,11 @@ from typing import Dict, Optional, Union

```python
import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format, to_pil_image
from ...image_transforms import flip_channel_order, resize, to_channel_dimension_format, to_pil_image
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,
```
@@ -85,20 +84,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op

```python
    return words, normalized_boxes


def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        image = image[:, ::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)
    return image


class LayoutLMv2ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a LayoutLMv2 image processor.
```
@@ -26,7 +26,6 @@ from ...image_utils import (

```python
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,
```
@@ -86,20 +85,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op

```python
    return words, normalized_boxes


def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        image = image[:, ::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)
    return image


class LayoutLMv3ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a LayoutLMv3 image processor.
```
@@ -356,7 +341,6 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):

```python
        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        # flip color channels from RGB to BGR (as Detectron2 requires this)
        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
```
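The comment in the hunk above explains why the flip exists at all: Detectron2 expects BGR input. A hedged usage sketch of the new shared helper, mirroring the format strings used in the test added at the bottom of this diff and assuming it is imported from `transformers.image_transforms` as this commit sets up:

```python
import numpy as np

from transformers.image_transforms import flip_channel_order

# Channels-first RGB image (3, height, width): R=10, G=20, B=30 everywhere.
rgb_chw = np.stack([np.full((2, 2), v, dtype=np.uint8) for v in (10, 20, 30)])

bgr_chw = flip_channel_order(rgb_chw)                               # stays (3, 2, 2)
bgr_hwc = flip_channel_order(rgb_chw, data_format="channels_last")  # flips and moves channels last

print(bgr_chw[:, 0, 0])  # [30 20 10] -> BGR order
print(bgr_hwc.shape)     # (2, 2, 3)
```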
@@ -19,12 +19,18 @@ from typing import Dict, List, Optional, Tuple, Union

```python
import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, get_resize_output_image_size, rescale, resize, to_channel_dimension_format
from ...image_transforms import (
    center_crop,
    flip_channel_order,
    get_resize_output_image_size,
    rescale,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,
```
@@ -42,34 +48,6 @@ if is_torch_available():

```python
logger = logging.get_logger(__name__)


def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension]) -> np.ndarray:
    """
    Flip the color channels from RGB to BGR or vice versa.

    Args:
        image (`np.ndarray`):
            The image, represented as a numpy array.
        data_format (`ChannelDimension`, *`optional`*):
            The channel dimension format of the image. If not provided, it will be the same as the input image.

    Returns:
        `np.ndarray`: The image with the flipped color channels.
    """
    input_data_format = infer_channel_dimension_format(image)

    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        image = image[:, ::-1, ...]
    else:
        raise ValueError(f"Invalid input channel dimension format: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)

    return image


class MobileViTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a MobileViT image processor.
```
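All three local copies removed in this diff, like the shared replacement, rely on `infer_channel_dimension_format` to decide which axis to reverse. That function is not shown here; the following is only a simplified sketch of the idea, not the real implementation in `image_utils`, which may differ in details such as the channel counts it accepts or 4D handling:

```python
import numpy as np


def infer_channel_dimension_format_sketch(image: np.ndarray, channel_counts=(1, 3)) -> str:
    """Guess whether a 3D array is (C, H, W) or (H, W, C) from its outer axes (illustrative only)."""
    if image.ndim != 3:
        raise ValueError(f"Expected a 3D array, got {image.ndim} dimensions")
    if image.shape[0] in channel_counts:
        return "channels_first"
    if image.shape[-1] in channel_counts:
        return "channels_last"
    raise ValueError("Unable to infer channel dimension format")


print(infer_channel_dimension_format_sketch(np.zeros((3, 224, 224))))  # channels_first
print(infer_channel_dimension_format_sketch(np.zeros((224, 224, 3))))  # channels_last
```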
@@ -39,6 +39,7 @@ if is_vision_available():

```python
        center_to_corners_format,
        convert_to_rgb,
        corners_to_center_format,
        flip_channel_order,
        get_resize_output_image_size,
        id_to_rgb,
        normalize,
```
@@ -520,3 +521,41 @@ class ImageTransformsTester(unittest.TestCase):

```python
        self.assertEqual(rgb_image.mode, "RGB")
        self.assertEqual(rgb_image.size, (2, 1))
        self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))

    def test_flip_channel_order(self):
        # fmt: off
        img_channels_first = np.array([
            [[ 0,  1,  2,  3],
             [ 4,  5,  6,  7]],

            [[ 8,  9, 10, 11],
             [12, 13, 14, 15]],

            [[16, 17, 18, 19],
             [20, 21, 22, 23]],
        ])
        # fmt: on
        img_channels_last = np.moveaxis(img_channels_first, 0, -1)
        # fmt: off
        flipped_img_channels_first = np.array([
            [[16, 17, 18, 19],
             [20, 21, 22, 23]],

            [[ 8,  9, 10, 11],
             [12, 13, 14, 15]],

            [[ 0,  1,  2,  3],
             [ 4,  5,  6,  7]],
        ])
        # fmt: on
        flipped_img_channels_last = np.moveaxis(flipped_img_channels_first, 0, -1)

        self.assertTrue(np.allclose(flip_channel_order(img_channels_first), flipped_img_channels_first))
        self.assertTrue(
            np.allclose(flip_channel_order(img_channels_first, "channels_last"), flipped_img_channels_last)
        )

        self.assertTrue(np.allclose(flip_channel_order(img_channels_last), flipped_img_channels_last))
        self.assertTrue(
            np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
        )
```
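To exercise just the new test, assuming the suite lives at `tests/test_image_transforms.py` (the file path is not shown in this excerpt), something along the lines of `python -m pytest tests/test_image_transforms.py -k test_flip_channel_order` should do.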