mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Move convert_to_rgb to image_transforms module (#20784)
* Move convert_to_rgb to image_transforms module * Fix tests
This commit is contained in:
parent
4bc723f87d
commit
491e951875
@ -20,6 +20,7 @@ import numpy as np
|
||||
|
||||
from transformers.image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
get_channel_dimension_axis,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
@ -687,3 +688,22 @@ def pad(
|
||||
|
||||
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
|
||||
return image
|
||||
|
||||
|
||||
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
|
||||
def convert_to_rgb(image: ImageInput) -> ImageInput:
|
||||
"""
|
||||
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
|
||||
as is.
|
||||
|
||||
Args:
|
||||
image (Image):
|
||||
The image to convert.
|
||||
"""
|
||||
requires_backends(convert_to_rgb, ["vision"])
|
||||
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Image processor class for BiT."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
center_crop,
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
normalize,
|
||||
rescale,
|
||||
@ -41,20 +42,6 @@ if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
||||
"""
|
||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image to convert.
|
||||
"""
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
class BitImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a BiT image processor.
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Image processor class for Chinese-CLIP."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
center_crop,
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
normalize,
|
||||
rescale,
|
||||
@ -41,20 +42,6 @@ if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
||||
"""
|
||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image to convert.
|
||||
"""
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
class ChineseCLIPImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a Chinese-CLIP image processor.
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Image processor class for CLIP."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
center_crop,
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
normalize,
|
||||
rescale,
|
||||
@ -41,20 +42,6 @@ if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
||||
"""
|
||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image to convert.
|
||||
"""
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
class CLIPImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a CLIP image processor.
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Image processor class for ViT hybrid."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
center_crop,
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
normalize,
|
||||
rescale,
|
||||
@ -41,21 +42,6 @@ if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
# Copied from transformers.models.bit.image_processing_bit.convert_to_rgb
|
||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
||||
"""
|
||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image to convert.
|
||||
"""
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
class ViTHybridImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a ViT Hybrid image processor.
|
||||
|
@ -37,6 +37,7 @@ if is_vision_available():
|
||||
from transformers.image_transforms import (
|
||||
center_crop,
|
||||
center_to_corners_format,
|
||||
convert_to_rgb,
|
||||
corners_to_center_format,
|
||||
get_resize_output_image_size,
|
||||
id_to_rgb,
|
||||
@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
|
||||
)
|
||||
|
||||
@require_vision
|
||||
def test_convert_to_rgb(self):
|
||||
# Test that an RGBA image is converted to RGB
|
||||
image = np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.uint8)
|
||||
pil_image = PIL.Image.fromarray(image)
|
||||
self.assertEqual(pil_image.mode, "RGBA")
|
||||
self.assertEqual(pil_image.size, (2, 1))
|
||||
|
||||
# For the moment, numpy images are returned as is
|
||||
rgb_image = convert_to_rgb(image)
|
||||
self.assertEqual(rgb_image.shape, (1, 2, 4))
|
||||
self.assertTrue(np.allclose(rgb_image, image))
|
||||
|
||||
# And PIL images are converted
|
||||
rgb_image = convert_to_rgb(pil_image)
|
||||
self.assertEqual(rgb_image.mode, "RGB")
|
||||
self.assertEqual(rgb_image.size, (2, 1))
|
||||
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[1, 2, 3], [5, 6, 7]]], dtype=np.uint8)))
|
||||
|
||||
# Test that a grayscale image is converted to RGB
|
||||
image = np.array([[0, 255]], dtype=np.uint8)
|
||||
pil_image = PIL.Image.fromarray(image)
|
||||
self.assertEqual(pil_image.mode, "L")
|
||||
self.assertEqual(pil_image.size, (2, 1))
|
||||
rgb_image = convert_to_rgb(pil_image)
|
||||
self.assertEqual(rgb_image.mode, "RGB")
|
||||
self.assertEqual(rgb_image.size, (2, 1))
|
||||
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))
|
||||
|
Loading…
Reference in New Issue
Block a user