Add ViTImageProcessorFast to tests (#31424)

* Add ViTImageProcessor to tests

* Correct data format

* Review comments
This commit is contained in:
amyeroberts 2024-06-25 13:36:58 +01:00 committed by GitHub
parent aab0829790
commit 0f67ba1d74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 230 additions and 87 deletions

View File

@ -151,6 +151,11 @@ class BaseImageProcessor(ImageProcessingMixin):
**kwargs, **kwargs,
) )
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
return encoder_dict
VALID_SIZE_DICT_KEYS = ( VALID_SIZE_DICT_KEYS = (
{"height", "width"}, {"height", "width"},

View File

@ -61,3 +61,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
def get_transforms(self, **kwargs) -> "Compose": def get_transforms(self, **kwargs) -> "Compose":
self._validate_params(**kwargs) self._validate_params(**kwargs)
return self._build_transforms(**kwargs) return self._build_transforms(**kwargs)
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_transform_params", None)
return encoder_dict

View File

@ -399,7 +399,7 @@ class AutoImageProcessor:
kwargs["token"] = use_auth_token kwargs["token"] = use_auth_token
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
use_fast = kwargs.pop("use_fast", False) use_fast = kwargs.pop("use_fast", None)
trust_remote_code = kwargs.pop("trust_remote_code", None) trust_remote_code = kwargs.pop("trust_remote_code", None)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
@ -430,10 +430,11 @@ class AutoImageProcessor:
if image_processor_class is not None: if image_processor_class is not None:
# Update class name to reflect the use_fast option. If class is not found, None is returned. # Update class name to reflect the use_fast option. If class is not found, None is returned.
if use_fast and not image_processor_class.endswith("Fast"): if use_fast is not None:
image_processor_class += "Fast" if use_fast and not image_processor_class.endswith("Fast"):
elif not use_fast and image_processor_class.endswith("Fast"): image_processor_class += "Fast"
image_processor_class = image_processor_class[:-4] elif not use_fast and image_processor_class.endswith("Fast"):
image_processor_class = image_processor_class[:-4]
image_processor_class = image_processor_class_from_name(image_processor_class) image_processor_class = image_processor_class_from_name(image_processor_class)
has_remote_code = image_processor_auto_map is not None has_remote_code = image_processor_auto_map is not None

View File

@ -772,7 +772,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
ignore_index, ignore_index,
do_reduce_labels, do_reduce_labels,
return_tensors, return_tensors,
input_data_format=input_data_format, input_data_format=data_format,
) )
return encoded_inputs return encoded_inputs

View File

@ -772,7 +772,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
ignore_index, ignore_index,
do_reduce_labels, do_reduce_labels,
return_tensors, return_tensors,
input_data_format=input_data_format, input_data_format=data_format,
) )
return encoded_inputs return encoded_inputs

View File

@ -772,7 +772,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
ignore_index, ignore_index,
do_reduce_labels, do_reduce_labels,
return_tensors, return_tensors,
input_data_format=input_data_format, input_data_format=data_format,
) )
return encoded_inputs return encoded_inputs

View File

@ -114,7 +114,6 @@ class ViTImageProcessorFast(BaseImageProcessorFast):
self.rescale_factor = rescale_factor self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self._transform_settings = {}
def _build_transforms( def _build_transforms(
self, self,
@ -285,5 +284,5 @@ class ViTImageProcessorFast(BaseImageProcessorFast):
) )
transformed_images = [transforms(image) for image in images] transformed_images = [transforms(image) for image in images]
data = {"pixel_values": torch.vstack(transformed_images)} data = {"pixel_values": torch.stack(transformed_images, dim=0)}
return BatchFeature(data, tensor_type=return_tensors) return BatchFeature(data, tensor_type=return_tensors)

View File

@ -17,6 +17,8 @@
import unittest import unittest
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
@ -84,6 +86,8 @@ class BridgeTowerImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
scale = size / min(w, h) scale = size / min(w, h)

View File

@ -18,6 +18,8 @@ import json
import pathlib import pathlib
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision, slow from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available from transformers.utils import is_torch_available, is_vision_available
@ -87,6 +89,8 @@ class ConditionalDetrImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -18,6 +18,8 @@ import json
import pathlib import pathlib
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision, slow from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available from transformers.utils import is_torch_available, is_vision_available
@ -87,6 +89,8 @@ class DeformableDetrImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -17,6 +17,8 @@ import json
import pathlib import pathlib
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision, slow from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available from transformers.utils import is_torch_available, is_vision_available
@ -86,6 +88,8 @@ class DetrImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -66,6 +66,8 @@ class GLPNImageProcessingTester(unittest.TestCase):
def expected_output_image_shape(self, images): def expected_output_image_shape(self, images):
if isinstance(images[0], Image.Image): if isinstance(images[0], Image.Image):
width, height = images[0].size width, height = images[0].size
elif isinstance(images[0], np.ndarray):
height, width = images[0].shape[0], images[0].shape[1]
else: else:
height, width = images[0].shape[1], images[0].shape[2] height, width = images[0].shape[1], images[0].shape[2]

View File

@ -18,6 +18,8 @@ import json
import pathlib import pathlib
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision, slow from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available from transformers.utils import is_torch_available, is_vision_available
@ -93,6 +95,8 @@ class GroundingDinoImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -16,6 +16,8 @@
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_torchvision, require_vision from transformers.testing_utils import require_torch, require_torchvision, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
@ -75,6 +77,8 @@ class IdeficsImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
scale = size / min(w, h) scale = size / min(w, h)

View File

@ -99,6 +99,8 @@ class Idefics2ImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
@ -176,6 +178,10 @@ class Idefics2ImageProcessingTester(unittest.TestCase):
if torchify: if torchify:
images_list = [[torch.from_numpy(image) for image in images] for images in images_list] images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
if numpify:
# Numpy images are typically in channels last format
images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list]
return images_list return images_list
@ -206,66 +212,100 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(image_processing, "do_image_splitting")) self.assertTrue(hasattr(image_processing, "do_image_splitting"))
def test_call_numpy(self): def test_call_numpy(self):
# Initialize image_processing for image_processing_class in self.image_processor_list:
image_processing = self.image_processing_class(**self.image_processor_dict) # Initialize image_processing
# create random numpy tensors image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) # create random numpy tensors
for sample_images in image_inputs: image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for image in sample_images: for sample_images in image_inputs:
self.assertIsInstance(image, np.ndarray) for image in sample_images:
self.assertIsInstance(image, np.ndarray)
# Test not batched input # Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched # Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual( self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
) )
def test_call_numpy_4_channels(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processor_dict = self.image_processor_dict
image_processor_dict["image_mean"] = [0.5, 0.5, 0.5, 0.5]
image_processor_dict["image_std"] = [0.5, 0.5, 0.5, 0.5]
image_processing = self.image_processing_class(**image_processor_dict)
# create random numpy tensors
self.image_processor_tester.num_channels = 4
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
for image in sample_images:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(
image_inputs[0], input_data_format="channels_last", return_tensors="pt"
).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(
image_inputs, input_data_format="channels_last", return_tensors="pt"
).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pil(self): def test_call_pil(self):
# Initialize image_processing for image_processing_class in self.image_processor_list:
image_processing = self.image_processing_class(**self.image_processor_dict) # Initialize image_processing
# create random PIL images image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) # create random PIL images
for images in image_inputs: image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in images: for images in image_inputs:
self.assertIsInstance(image, Image.Image) for image in images:
self.assertIsInstance(image, Image.Image)
# Test not batched input # Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched # Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual( self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
) )
def test_call_pytorch(self): def test_call_pytorch(self):
# Initialize image_processing for image_processing_class in self.image_processor_list:
image_processing = self.image_processing_class(**self.image_processor_dict) # Initialize image_processing
# create random PyTorch tensors image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) # create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for images in image_inputs: for images in image_inputs:
for image in images: for image in images:
self.assertIsInstance(image, torch.Tensor) self.assertIsInstance(image, torch.Tensor)
# Test not batched input # Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched # Test batched
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual( self.assertEqual(
tuple(encoded_images.shape), tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape), (self.image_processor_tester.batch_size, *expected_output_image_shape),
) )

View File

@ -98,6 +98,8 @@ class Mask2FormerImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -98,6 +98,8 @@ class MaskFormerImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -106,6 +106,8 @@ class OneFormerImageProcessorTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -143,6 +143,8 @@ class OneFormerProcessorTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
if w < h: if w < h:

View File

@ -232,7 +232,7 @@ class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
for max_patch in self.image_processor_tester.max_patches: for max_patch in self.image_processor_tester.max_patches:
# Test not batched input # Test not batched input
encoded_images = image_processor( encoded_images = image_processor(
image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_first" image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_last"
).flattened_patches ).flattened_patches
self.assertEqual( self.assertEqual(
encoded_images.shape, encoded_images.shape,
@ -241,7 +241,7 @@ class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
# Test batched # Test batched
encoded_images = image_processor( encoded_images = image_processor(
image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_first" image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_last"
).flattened_patches ).flattened_patches
self.assertEqual( self.assertEqual(
encoded_images.shape, encoded_images.shape,

View File

@ -72,6 +72,8 @@ class Swin2SRImageProcessingTester(unittest.TestCase):
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
input_width, input_height = img.size input_width, input_height = img.size
elif isinstance(img, np.ndarray):
input_height, input_width = img.shape[-3:-1]
else: else:
input_height, input_width = img.shape[-2:] input_height, input_width = img.shape[-2:]
@ -160,7 +162,7 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Test not batched input # Test not batched input
encoded_images = image_processing( encoded_images = image_processing(
image_inputs[0], return_tensors="pt", input_data_format="channels_first" image_inputs[0], return_tensors="pt", input_data_format="channels_last"
).pixel_values ).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

View File

@ -285,7 +285,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
encoded_images = image_processor( encoded_images = image_processor(
image_inputs[0], image_inputs[0],
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values_images ).pixel_values_images
@ -296,7 +296,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
encoded_images = image_processor( encoded_images = image_processor(
image_inputs, image_inputs,
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values_images ).pixel_values_images

View File

@ -16,6 +16,8 @@
import unittest import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
@ -78,6 +80,8 @@ class ViltImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
w, h = image.size w, h = image.size
elif isinstance(image, np.ndarray):
h, w = image.shape[0], image.shape[1]
else: else:
h, w = image.shape[1], image.shape[2] h, w = image.shape[1], image.shape[2]
scale = size / min(w, h) scale = size / min(w, h)

View File

@ -17,7 +17,7 @@
import unittest import unittest
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available(): if is_vision_available():
from transformers import ViTImageProcessor from transformers import ViTImageProcessor
if is_torchvision_available():
from transformers import ViTImageProcessorFast
class ViTImageProcessingTester(unittest.TestCase): class ViTImageProcessingTester(unittest.TestCase):
def __init__( def __init__(
@ -82,6 +85,7 @@ class ViTImageProcessingTester(unittest.TestCase):
@require_vision @require_vision
class ViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): class ViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ViTImageProcessor if is_vision_available() else None image_processing_class = ViTImageProcessor if is_vision_available() else None
fast_image_processing_class = ViTImageProcessorFast if is_torchvision_available() else None
def setUp(self): def setUp(self):
super().setUp() super().setUp()

View File

@ -18,6 +18,7 @@ import json
import pathlib import pathlib
import unittest import unittest
import numpy as np
from parameterized import parameterized from parameterized import parameterized
from transformers.testing_utils import require_torch, require_vision, slow from transformers.testing_utils import require_torch, require_vision, slow
@ -89,6 +90,8 @@ class YolosImageProcessingTester(unittest.TestCase):
image = image_inputs[0] image = image_inputs[0]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
width, height = image.size width, height = image.size
elif isinstance(image, np.ndarray):
height, width = image.shape[0], image.shape[1]
else: else:
height, width = image.shape[1], image.shape[2] height, width = image.shape[1], image.shape[2]

View File

@ -18,7 +18,9 @@ import json
import os import os
import pathlib import pathlib
import tempfile import tempfile
import time
import numpy as np
import requests import requests
from transformers import AutoImageProcessor, BatchFeature from transformers import AutoImageProcessor, BatchFeature
@ -28,7 +30,6 @@ from transformers.utils import is_torch_available, is_vision_available
if is_torch_available(): if is_torch_available():
import numpy as np
import torch import torch
if is_vision_available(): if is_vision_available():
@ -72,6 +73,10 @@ def prepare_image_inputs(
if torchify: if torchify:
image_inputs = [torch.from_numpy(image) for image in image_inputs] image_inputs = [torch.from_numpy(image) for image in image_inputs]
if numpify:
# Numpy images are typically in channels last format
image_inputs = [image.transpose(1, 2, 0) for image in image_inputs]
return image_inputs return image_inputs
@ -167,33 +172,28 @@ class ImageProcessingTestMixin:
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-3)) self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-2))
@require_vision @require_vision
@require_torch @require_torch
def test_fast_is_faster_than_slow(self): def test_fast_is_faster_than_slow(self):
import time
def measure_time(self, image_processor, dummy_image):
start = time.time()
_ = image_processor(dummy_image, return_tensors="pt")
return time.time() - start
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
if not self.test_slow_image_processor or not self.test_fast_image_processor: if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest("Skipping speed test") self.skipTest("Skipping speed test")
if self.image_processing_class is None or self.fast_image_processing_class is None: if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest("Skipping speed test as one of the image processors is not defined") self.skipTest("Skipping speed test as one of the image processors is not defined")
image_processor_slow = self.image_processing_class(**self.image_processor_dict) def measure_time(image_processor, image):
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) start = time.time()
_ = image_processor(image, return_tensors="pt")
return time.time() - start
slow_time = self.measure_time(image_processor_slow, dummy_image) dummy_images = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8)
fast_time = self.measure_time(image_processor_fast, dummy_image) image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class()
fast_time = measure_time(image_processor_fast, dummy_images)
slow_time = measure_time(image_processor_slow, dummy_images)
self.assertLessEqual(fast_time, slow_time) self.assertLessEqual(fast_time, slow_time)
@ -238,6 +238,52 @@ class ImageProcessingTestMixin:
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict()) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
def test_save_load_fast_slow(self):
"Test that we can load a fast image processor from a slow one and vice-versa."
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined")
image_processor_dict = self.image_processor_tester.prepare_image_processor_dict()
image_processor_slow_0 = self.image_processing_class(**image_processor_dict)
# Load fast image processor from slow one
with tempfile.TemporaryDirectory() as tmpdirname:
image_processor_slow_0.save_pretrained(tmpdirname)
image_processor_fast_0 = self.fast_image_processing_class.from_pretrained(tmpdirname)
image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict)
# Load slow image processor from fast one
with tempfile.TemporaryDirectory() as tmpdirname:
image_processor_fast_1.save_pretrained(tmpdirname)
image_processor_slow_1 = self.image_processing_class.from_pretrained(tmpdirname)
self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
def test_save_load_fast_slow_auto(self):
"Test that we can load a fast image processor from a slow one and vice-versa using AutoImageProcessor."
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined")
image_processor_dict = self.image_processor_tester.prepare_image_processor_dict()
image_processor_slow_0 = self.image_processing_class(**image_processor_dict)
# Load fast image processor from slow one
with tempfile.TemporaryDirectory() as tmpdirname:
image_processor_slow_0.save_pretrained(tmpdirname)
image_processor_fast_0 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=True)
image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict)
# Load slow image processor from fast one
with tempfile.TemporaryDirectory() as tmpdirname:
image_processor_fast_1.save_pretrained(tmpdirname)
image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False)
self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
def test_init_without_params(self): def test_init_without_params(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list:
image_processor = image_processing_class() image_processor = image_processing_class()
@ -358,7 +404,7 @@ class ImageProcessingTestMixin:
encoded_images = image_processor( encoded_images = image_processor(
image_inputs[0], image_inputs[0],
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values ).pixel_values
@ -369,7 +415,7 @@ class ImageProcessingTestMixin:
encoded_images = image_processor( encoded_images = image_processor(
image_inputs, image_inputs,
return_tensors="pt", return_tensors="pt",
input_data_format="channels_first", input_data_format="channels_last",
image_mean=0, image_mean=0,
image_std=1, image_std=1,
).pixel_values ).pixel_values