transformers/tests/test_image_transforms.py
amyeroberts 4c7e8d0900
Add object detection + segmentation transforms (#20003)
* Add transforms for object detection

* Update src/transformers/image_transforms.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Better var names & docstring

* Remove unused var desc in docstring

* Update src/transformers/image_transforms.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2022-11-15 12:50:03 +00:00

292 lines
12 KiB
Python

# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from parameterized import parameterized
from transformers.testing_utils import require_flax, require_tf, require_torch, require_vision
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax
if is_vision_available():
import PIL.Image
from transformers.image_transforms import (
center_crop,
center_to_corners_format,
corners_to_center_format,
get_resize_output_image_size,
id_to_rgb,
normalize,
resize,
rgb_to_id,
to_channel_dimension_format,
to_pil_image,
)
def get_random_image(
    height: int, width: int, num_channels: int = 3, channels_first: bool = True
) -> np.ndarray:
    """
    Build a random `uint8` image array with integer values drawn from `[0, 256)`.

    Args:
        height: Number of rows of the generated image.
        width: Number of columns of the generated image.
        num_channels: Number of channels of the generated image.
        channels_first: If `True` the result has shape `(num_channels, height, width)`, otherwise
            `(height, width, num_channels)`.

    Returns:
        The randomly filled `np.uint8` array.
    """
    if channels_first:
        shape = (num_channels, height, width)
    else:
        shape = (height, width, num_channels)
    return np.random.randint(0, 256, shape, dtype=np.uint8)
@require_vision
class ImageTransformsTester(unittest.TestCase):
    """Unit tests for the helpers imported from `transformers.image_transforms`."""

    def _assert_pil_image_with_size(self, pil_image, expected_size):
        """Assert that `pil_image` is a `PIL.Image.Image` whose `size` equals `expected_size`.

        PIL reports `size` in (width, height) order.
        """
        self.assertIsInstance(pil_image, PIL.Image.Image)
        self.assertEqual(pil_image.size, expected_size)

    def _assert_allclose(self, actual, expected):
        """Tolerance-based array comparison that reports the mismatch on failure.

        Uses the default tolerances of `np.allclose` so the check is as strict as
        `assertTrue(np.allclose(actual, expected))`, but fails with a diagnostic message
        instead of a bare "False is not true".
        """
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)

    @parameterized.expand(
        [
            ("numpy_float_channels_first", (3, 4, 5), np.float32),
            ("numpy_float_channels_last", (4, 5, 3), np.float32),
            ("numpy_int_channels_first", (3, 4, 5), np.int32),
            ("numpy_uint_channels_first", (3, 4, 5), np.uint8),
        ]
    )
    @require_vision
    def test_to_pil_image(self, name, image_shape, dtype):
        """`to_pil_image` accepts numpy arrays of several dtypes and channel layouts."""
        image = np.random.randint(0, 256, image_shape).astype(dtype)
        self._assert_pil_image_with_size(to_pil_image(image), (5, 4))

    @require_tf
    def test_to_pil_image_from_tensorflow(self):
        """`to_pil_image` accepts TensorFlow tensors in both channel layouts."""
        # channels_first, then channels_last
        for shape in ((3, 4, 5), (4, 5, 3)):
            with self.subTest(shape=shape):
                image = tf.random.uniform(shape)
                self._assert_pil_image_with_size(to_pil_image(image), (5, 4))

    @require_torch
    def test_to_pil_image_from_torch(self):
        """`to_pil_image` accepts torch tensors in both channel layouts."""
        # channels first, then channels last
        for shape in ((3, 4, 5), (4, 5, 3)):
            with self.subTest(shape=shape):
                image = torch.rand(shape)
                self._assert_pil_image_with_size(to_pil_image(image), (5, 4))

    @require_flax
    def test_to_pil_image_from_jax(self):
        """`to_pil_image` accepts jax arrays in both channel layouts."""
        key = jax.random.PRNGKey(0)
        # channel first, then channel last
        for shape in ((3, 4, 5), (4, 5, 3)):
            with self.subTest(shape=shape):
                image = jax.random.uniform(key, shape)
                self._assert_pil_image_with_size(to_pil_image(image), (5, 4))

    def test_to_channel_dimension_format(self):
        """`to_channel_dimension_format` only moves the channel axis when the layout differs."""
        # Test that function doesn't reorder if channel dim matches the input.
        image = np.random.rand(3, 4, 5)
        image = to_channel_dimension_format(image, "channels_first")
        self.assertEqual(image.shape, (3, 4, 5))

        image = np.random.rand(4, 5, 3)
        image = to_channel_dimension_format(image, "channels_last")
        self.assertEqual(image.shape, (4, 5, 3))

        # Test that function reorders if channel dim doesn't match the input.
        image = np.random.rand(3, 4, 5)
        image = to_channel_dimension_format(image, "channels_last")
        self.assertEqual(image.shape, (4, 5, 3))

        image = np.random.rand(4, 5, 3)
        image = to_channel_dimension_format(image, "channels_first")
        self.assertEqual(image.shape, (3, 4, 5))

    def test_get_resize_output_image_size(self):
        """Check the (height, width) returned by `get_resize_output_image_size` for each kind of `size`."""
        image = np.random.randint(0, 256, (3, 224, 224))

        # Test the output size defaults to (x, x) if an int is given.
        self.assertEqual(get_resize_output_image_size(image, 10), (10, 10))
        self.assertEqual(get_resize_output_image_size(image, [10]), (10, 10))
        self.assertEqual(get_resize_output_image_size(image, (10,)), (10, 10))

        # Test the output size is the same as the input if a two element tuple/list is given.
        self.assertEqual(get_resize_output_image_size(image, (10, 20)), (10, 20))
        self.assertEqual(get_resize_output_image_size(image, [10, 20]), (10, 20))
        self.assertEqual(get_resize_output_image_size(image, (10, 20), default_to_square=True), (10, 20))
        # To match pytorch behaviour, max_size is only relevant if size is an int
        self.assertEqual(get_resize_output_image_size(image, (10, 20), max_size=5), (10, 20))

        # Test output size = (int(size * height / width), size) if size is an int and height > width
        image = np.random.randint(0, 256, (3, 50, 40))
        self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False), (25, 20))

        # Test output size = (size, int(size * width / height)) if size is an int and height <= width
        image = np.random.randint(0, 256, (3, 40, 50))
        self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False), (20, 25))

        # Test size is resized if longer size > max_size
        image = np.random.randint(0, 256, (3, 50, 40))
        self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))

    def test_resize(self):
        """Check the type, shape and channel order of `resize` outputs."""
        image = np.random.randint(0, 256, (3, 224, 224))

        # Check the channel order is the same by default
        resized_image = resize(image, (30, 40))
        self.assertIsInstance(resized_image, np.ndarray)
        self.assertEqual(resized_image.shape, (3, 30, 40))

        # Check channel order is changed if specified
        resized_image = resize(image, (30, 40), data_format="channels_last")
        self.assertIsInstance(resized_image, np.ndarray)
        self.assertEqual(resized_image.shape, (30, 40, 3))

        # Check PIL.Image.Image is returned if return_numpy=False
        resized_image = resize(image, (30, 40), return_numpy=False)
        # PIL size is in (width, height) order
        self._assert_pil_image_with_size(resized_image, (40, 30))

    def test_normalize(self):
        """Check input validation and the numerical result of `normalize`."""
        image = np.random.randint(0, 256, (224, 224, 3)) / 255

        # Test that exception is raised if inputs are incorrect
        # Not a numpy array image
        with self.assertRaises(ValueError):
            normalize(5, 5, 5)

        # Number of mean values != number of channels
        with self.assertRaises(ValueError):
            normalize(image, mean=(0.5, 0.6), std=1)

        # Number of std values != number of channels
        with self.assertRaises(ValueError):
            normalize(image, mean=1, std=(0.5, 0.6))

        # Test result is correct - output data format is channels_first and normalization
        # correctly computed
        mean = (0.5, 0.6, 0.7)
        std = (0.1, 0.2, 0.3)
        expected_image = ((image - mean) / std).transpose((2, 0, 1))

        normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first")
        self.assertIsInstance(normalized_image, np.ndarray)
        self.assertEqual(normalized_image.shape, (3, 224, 224))
        self._assert_allclose(normalized_image, expected_image)

    def test_center_crop(self):
        """Check input validation, cropping and zero padding of `center_crop`."""
        image = np.random.randint(0, 256, (3, 224, 224))

        # Test that exception is raised if inputs are incorrect
        with self.assertRaises(ValueError):
            center_crop(image, 10)

        # Test result is correct - output data format is channels_last and center crop
        # correctly computed
        expected_image = image[:, 52:172, 82:142].transpose(1, 2, 0)
        cropped_image = center_crop(image, (120, 60), data_format="channels_last")
        self.assertIsInstance(cropped_image, np.ndarray)
        self.assertEqual(cropped_image.shape, (120, 60, 3))
        self._assert_allclose(cropped_image, expected_image)

        # Test that image is padded with zeros if crop size is larger than image size
        expected_image = np.zeros((300, 260, 3))
        expected_image[38:262, 18:242, :] = image.transpose((1, 2, 0))
        cropped_image = center_crop(image, (300, 260), data_format="channels_last")
        self.assertIsInstance(cropped_image, np.ndarray)
        self.assertEqual(cropped_image.shape, (300, 260, 3))
        self._assert_allclose(cropped_image, expected_image)

    def test_center_to_corners_format(self):
        """Check `center_to_corners_format` on known boxes and against its inverse."""
        bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
        expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
        self._assert_allclose(center_to_corners_format(bbox_center), expected)

        # Check that the function and inverse function are inverse of each other
        self._assert_allclose(corners_to_center_format(center_to_corners_format(bbox_center)), bbox_center)

    def test_corners_to_center_format(self):
        """Check `corners_to_center_format` on known boxes and against its inverse."""
        bbox_corners = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
        expected = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
        self._assert_allclose(corners_to_center_format(bbox_corners), expected)

        # Check that the function and inverse function are inverse of each other
        self._assert_allclose(center_to_corners_format(corners_to_center_format(bbox_corners)), bbox_corners)

    def test_rgb_to_id(self):
        """Check `rgb_to_id` for a list colour and for an array of colours."""
        # test list input
        rgb = [125, 4, 255]
        self.assertEqual(rgb_to_id(rgb), 16712829)

        # test numpy array input
        color = np.array(
            [
                [
                    [213, 54, 165],
                    [88, 207, 39],
                    [156, 108, 128],
                ],
                [
                    [183, 194, 46],
                    [137, 58, 88],
                    [114, 131, 233],
                ],
            ]
        )
        expected = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
        # Ids are integers: compare exactly. A relative tolerance of 1e-5 on values around 1e7
        # would accept errors of roughly 100, i.e. a wrong value in the lowest byte.
        np.testing.assert_array_equal(rgb_to_id(color), expected)

    def test_id_to_rgb(self):
        """Check `id_to_rgb` for an int id and for an array of ids."""
        # test int input
        self.assertEqual(id_to_rgb(16712829), [125, 4, 255])

        # test array input
        id_array = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
        color = np.array(
            [
                [
                    [213, 54, 165],
                    [88, 207, 39],
                    [156, 108, 128],
                ],
                [
                    [183, 194, 46],
                    [137, 58, 88],
                    [114, 131, 233],
                ],
            ]
        )
        # Colour components are integers: compare exactly.
        np.testing.assert_array_equal(id_to_rgb(id_array), color)