Move convert_to_rgb to image_transforms module (#20784)

* Move convert_to_rgb to image_transforms module

* Fix tests
This commit is contained in:
amyeroberts 2022-12-15 18:47:04 +00:00 committed by GitHub
parent 4bc723f87d
commit 491e951875
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 58 additions and 61 deletions

View File

@ -20,6 +20,7 @@ import numpy as np
from transformers.image_utils import (
ChannelDimension,
ImageInput,
get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
@ -687,3 +688,22 @@ def pad(
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
return image
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
def convert_to_rgb(image: ImageInput) -> ImageInput:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
Args:
image (Image):
The image to convert.
"""
requires_backends(convert_to_rgb, ["vision"])
if not isinstance(image, PIL.Image.Image):
return image
image = image.convert("RGB")
return image

View File

@ -14,7 +14,7 @@
# limitations under the License.
"""Image processor class for BiT."""
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
convert_to_rgb,
get_resize_output_image_size,
normalize,
rescale,
@ -41,20 +42,6 @@ if is_vision_available():
import PIL
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class BitImageProcessor(BaseImageProcessor):
r"""
Constructs a BiT image processor.

View File

@ -14,7 +14,7 @@
# limitations under the License.
"""Image processor class for Chinese-CLIP."""
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
convert_to_rgb,
get_resize_output_image_size,
normalize,
rescale,
@ -41,20 +42,6 @@ if is_vision_available():
import PIL
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class ChineseCLIPImageProcessor(BaseImageProcessor):
r"""
Constructs a Chinese-CLIP image processor.

View File

@ -14,7 +14,7 @@
# limitations under the License.
"""Image processor class for CLIP."""
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
convert_to_rgb,
get_resize_output_image_size,
normalize,
rescale,
@ -41,20 +42,6 @@ if is_vision_available():
import PIL
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class CLIPImageProcessor(BaseImageProcessor):
r"""
Constructs a CLIP image processor.

View File

@ -14,7 +14,7 @@
# limitations under the License.
"""Image processor class for ViT hybrid."""
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np
@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
convert_to_rgb,
get_resize_output_image_size,
normalize,
rescale,
@ -41,21 +42,6 @@ if is_vision_available():
import PIL
# Copied from transformers.models.bit.image_processing_bit.convert_to_rgb
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class ViTHybridImageProcessor(BaseImageProcessor):
r"""
Constructs a ViT Hybrid image processor.

View File

@ -37,6 +37,7 @@ if is_vision_available():
from transformers.image_transforms import (
center_crop,
center_to_corners_format,
convert_to_rgb,
corners_to_center_format,
get_resize_output_image_size,
id_to_rgb,
@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue(
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
)
@require_vision
def test_convert_to_rgb(self):
# Test that an RGBA image is converted to RGB
image = np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.uint8)
pil_image = PIL.Image.fromarray(image)
self.assertEqual(pil_image.mode, "RGBA")
self.assertEqual(pil_image.size, (2, 1))
# For the moment, numpy images are returned as is
rgb_image = convert_to_rgb(image)
self.assertEqual(rgb_image.shape, (1, 2, 4))
self.assertTrue(np.allclose(rgb_image, image))
# And PIL images are converted
rgb_image = convert_to_rgb(pil_image)
self.assertEqual(rgb_image.mode, "RGB")
self.assertEqual(rgb_image.size, (2, 1))
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[1, 2, 3], [5, 6, 7]]], dtype=np.uint8)))
# Test that a grayscale image is converted to RGB
image = np.array([[0, 255]], dtype=np.uint8)
pil_image = PIL.Image.fromarray(image)
self.assertEqual(pil_image.mode, "L")
self.assertEqual(pil_image.size, (2, 1))
rgb_image = convert_to_rgb(pil_image)
self.assertEqual(rgb_image.mode, "RGB")
self.assertEqual(rgb_image.size, (2, 1))
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))