mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-08 07:10:06 +06:00
Bug fix - flip_channel_order for channels first images (#23701)
Bug fix - flip_channel_order for channels_first
This commit is contained in:
parent
0b3d092f63
commit
c608b8fc93
@ -742,3 +742,32 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
|
|||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    """
    Reverses the order of the channels of an image.

    An image stored as RGB comes back as BGR, and an image stored as BGR comes back as RGB.

    Args:
        image (`np.ndarray`):
            The image whose channels should be reversed.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, the output keeps the channel dimension format of the input image.

    Returns:
        `np.ndarray`: The image with its channel axis reversed, converted to `data_format` when one is given.

    Raises:
        ValueError: If the channel dimension format inferred for `image` is neither `ChannelDimension.FIRST` nor
            `ChannelDimension.LAST`.
    """
    input_data_format = infer_channel_dimension_format(image)

    # Locate the axis that holds the channels: the leading one for channels-first, the trailing one otherwise.
    if input_data_format == ChannelDimension.FIRST:
        channel_axis = 0
    elif input_data_format == ChannelDimension.LAST:
        channel_axis = -1
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    # Reverse only along the channel axis; every other axis is left untouched.
    flipped = np.flip(image, axis=channel_axis)

    if data_format is None:
        return flipped
    return to_channel_dimension_format(flipped, data_format)
|
||||||
|
@ -19,12 +19,11 @@ from typing import Dict, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import resize, to_channel_dimension_format, to_pil_image
|
from ...image_transforms import flip_channel_order, resize, to_channel_dimension_format, to_pil_image
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
ChannelDimension,
|
ChannelDimension,
|
||||||
ImageInput,
|
ImageInput,
|
||||||
PILImageResampling,
|
PILImageResampling,
|
||||||
infer_channel_dimension_format,
|
|
||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
@ -85,20 +84,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op
|
|||||||
return words, normalized_boxes
|
return words, normalized_boxes
|
||||||
|
|
||||||
|
|
||||||
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    """
    Flips the channel order of the image.

    If the image is in RGB format, it will be converted to BGR and vice versa.

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. If unset, will use same as the input image.

    Returns:
        `np.ndarray`: The image with its channels in reversed order.

    Raises:
        ValueError: If the inferred channel dimension format is neither `ChannelDimension.FIRST` nor
            `ChannelDimension.LAST`.
    """
    input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        # For a (num_channels, height, width) image the channels are on axis 0. Reversing axis 1 (as
        # `image[:, ::-1, ...]` did) mirrors the image vertically and leaves the channel order unchanged.
        image = image[::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)
    return image
|
|
||||||
|
|
||||||
|
|
||||||
class LayoutLMv2ImageProcessor(BaseImageProcessor):
|
class LayoutLMv2ImageProcessor(BaseImageProcessor):
|
||||||
r"""
|
r"""
|
||||||
Constructs a LayoutLMv2 image processor.
|
Constructs a LayoutLMv2 image processor.
|
||||||
|
@ -26,7 +26,6 @@ from ...image_utils import (
|
|||||||
ChannelDimension,
|
ChannelDimension,
|
||||||
ImageInput,
|
ImageInput,
|
||||||
PILImageResampling,
|
PILImageResampling,
|
||||||
infer_channel_dimension_format,
|
|
||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
@ -86,20 +85,6 @@ def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Op
|
|||||||
return words, normalized_boxes
|
return words, normalized_boxes
|
||||||
|
|
||||||
|
|
||||||
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    """
    Flips the channel order of the image.

    If the image is in RGB format, it will be converted to BGR and vice versa.

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. If unset, will use same as the input image.

    Returns:
        `np.ndarray`: The image with its channels in reversed order.

    Raises:
        ValueError: If the inferred channel dimension format is neither `ChannelDimension.FIRST` nor
            `ChannelDimension.LAST`.
    """
    input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        # For a (num_channels, height, width) image the channels are on axis 0. Reversing axis 1 (as
        # `image[:, ::-1, ...]` did) mirrors the image vertically and leaves the channel order unchanged.
        image = image[::-1, ...]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)
    return image
|
|
||||||
|
|
||||||
|
|
||||||
class LayoutLMv3ImageProcessor(BaseImageProcessor):
|
class LayoutLMv3ImageProcessor(BaseImageProcessor):
|
||||||
r"""
|
r"""
|
||||||
Constructs a LayoutLMv3 image processor.
|
Constructs a LayoutLMv3 image processor.
|
||||||
@ -356,7 +341,6 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
|
|||||||
if do_normalize:
|
if do_normalize:
|
||||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||||
|
|
||||||
# flip color channels from RGB to BGR (as Detectron2 requires this)
|
|
||||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||||
|
|
||||||
data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||||
|
@ -19,12 +19,18 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import center_crop, get_resize_output_image_size, rescale, resize, to_channel_dimension_format
|
from ...image_transforms import (
|
||||||
|
center_crop,
|
||||||
|
flip_channel_order,
|
||||||
|
get_resize_output_image_size,
|
||||||
|
rescale,
|
||||||
|
resize,
|
||||||
|
to_channel_dimension_format,
|
||||||
|
)
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
ChannelDimension,
|
ChannelDimension,
|
||||||
ImageInput,
|
ImageInput,
|
||||||
PILImageResampling,
|
PILImageResampling,
|
||||||
infer_channel_dimension_format,
|
|
||||||
make_list_of_images,
|
make_list_of_images,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
@ -42,34 +48,6 @@ if is_torch_available():
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
    """
    Flip the color channels from RGB to BGR or vice versa.

    Args:
        image (`np.ndarray`):
            The image, represented as a numpy array.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If not provided, it will be the same as the input
            image.

    Returns:
        `np.ndarray`: The image with the flipped color channels.

    Raises:
        ValueError: If the inferred channel dimension format is neither `ChannelDimension.FIRST` nor
            `ChannelDimension.LAST`.
    """
    input_data_format = infer_channel_dimension_format(image)

    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        # For a (num_channels, height, width) image the channels are on axis 0. Reversing axis 1 (as
        # `image[:, ::-1, ...]` did) mirrors the image vertically and leaves the channel order unchanged.
        image = image[::-1, ...]
    else:
        raise ValueError(f"Invalid input channel dimension format: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format)

    return image
|
|
||||||
|
|
||||||
|
|
||||||
class MobileViTImageProcessor(BaseImageProcessor):
|
class MobileViTImageProcessor(BaseImageProcessor):
|
||||||
r"""
|
r"""
|
||||||
Constructs a MobileViT image processor.
|
Constructs a MobileViT image processor.
|
||||||
|
@ -39,6 +39,7 @@ if is_vision_available():
|
|||||||
center_to_corners_format,
|
center_to_corners_format,
|
||||||
convert_to_rgb,
|
convert_to_rgb,
|
||||||
corners_to_center_format,
|
corners_to_center_format,
|
||||||
|
flip_channel_order,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
id_to_rgb,
|
id_to_rgb,
|
||||||
normalize,
|
normalize,
|
||||||
@ -520,3 +521,41 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertEqual(rgb_image.mode, "RGB")
|
self.assertEqual(rgb_image.mode, "RGB")
|
||||||
self.assertEqual(rgb_image.size, (2, 1))
|
self.assertEqual(rgb_image.size, (2, 1))
|
||||||
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))
|
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))
|
||||||
|
|
||||||
|
def test_flip_channel_order(self):
    """Check that `flip_channel_order` reverses the channel axis for both channel layouts."""
    # Three distinct 2x4 planes, one per channel, so any mix-up of axes shows up in the values.
    # fmt: off
    plane_a = [[ 0,  1,  2,  3],
               [ 4,  5,  6,  7]]
    plane_b = [[ 8,  9, 10, 11],
               [12, 13, 14, 15]]
    plane_c = [[16, 17, 18, 19],
               [20, 21, 22, 23]]
    # fmt: on

    # (num_channels, height, width) fixtures: original channel order and its reverse.
    img_channels_first = np.array([plane_a, plane_b, plane_c])
    flipped_img_channels_first = np.array([plane_c, plane_b, plane_a])

    # (height, width, num_channels) views of the same data.
    img_channels_last = np.transpose(img_channels_first, (1, 2, 0))
    flipped_img_channels_last = np.transpose(flipped_img_channels_first, (1, 2, 0))

    # Channels-first input, with and without converting the output layout.
    result = flip_channel_order(img_channels_first)
    self.assertTrue(np.allclose(result, flipped_img_channels_first))
    result = flip_channel_order(img_channels_first, "channels_last")
    self.assertTrue(np.allclose(result, flipped_img_channels_last))

    # Channels-last input, with and without converting the output layout.
    result = flip_channel_order(img_channels_last)
    self.assertTrue(np.allclose(result, flipped_img_channels_last))
    result = flip_channel_order(img_channels_last, "channels_first")
    self.assertTrue(np.allclose(result, flipped_img_channels_first))
|
||||||
|
Loading…
Reference in New Issue
Block a user