mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Add ViTImageProcessorFast to tests (#31424)
* Add ViTImageProcessor to tests * Correct data format * Review comments
This commit is contained in:
parent
aab0829790
commit
0f67ba1d74
@ -151,6 +151,11 @@ class BaseImageProcessor(ImageProcessingMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
encoder_dict = super().to_dict()
|
||||
encoder_dict.pop("_valid_processor_keys", None)
|
||||
return encoder_dict
|
||||
|
||||
|
||||
VALID_SIZE_DICT_KEYS = (
|
||||
{"height", "width"},
|
||||
|
@ -61,3 +61,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
def get_transforms(self, **kwargs) -> "Compose":
|
||||
self._validate_params(**kwargs)
|
||||
return self._build_transforms(**kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
encoder_dict = super().to_dict()
|
||||
encoder_dict.pop("_transform_params", None)
|
||||
return encoder_dict
|
||||
|
@ -399,7 +399,7 @@ class AutoImageProcessor:
|
||||
kwargs["token"] = use_auth_token
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
use_fast = kwargs.pop("use_fast", False)
|
||||
use_fast = kwargs.pop("use_fast", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
kwargs["_from_auto"] = True
|
||||
|
||||
@ -430,6 +430,7 @@ class AutoImageProcessor:
|
||||
|
||||
if image_processor_class is not None:
|
||||
# Update class name to reflect the use_fast option. If class is not found, None is returned.
|
||||
if use_fast is not None:
|
||||
if use_fast and not image_processor_class.endswith("Fast"):
|
||||
image_processor_class += "Fast"
|
||||
elif not use_fast and image_processor_class.endswith("Fast"):
|
||||
|
@ -772,7 +772,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
||||
ignore_index,
|
||||
do_reduce_labels,
|
||||
return_tensors,
|
||||
input_data_format=input_data_format,
|
||||
input_data_format=data_format,
|
||||
)
|
||||
return encoded_inputs
|
||||
|
||||
|
@ -772,7 +772,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
ignore_index,
|
||||
do_reduce_labels,
|
||||
return_tensors,
|
||||
input_data_format=input_data_format,
|
||||
input_data_format=data_format,
|
||||
)
|
||||
return encoded_inputs
|
||||
|
||||
|
@ -772,7 +772,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
ignore_index,
|
||||
do_reduce_labels,
|
||||
return_tensors,
|
||||
input_data_format=input_data_format,
|
||||
input_data_format=data_format,
|
||||
)
|
||||
return encoded_inputs
|
||||
|
||||
|
@ -114,7 +114,6 @@ class ViTImageProcessorFast(BaseImageProcessorFast):
|
||||
self.rescale_factor = rescale_factor
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
self._transform_settings = {}
|
||||
|
||||
def _build_transforms(
|
||||
self,
|
||||
@ -285,5 +284,5 @@ class ViTImageProcessorFast(BaseImageProcessorFast):
|
||||
)
|
||||
transformed_images = [transforms(image) for image in images]
|
||||
|
||||
data = {"pixel_values": torch.vstack(transformed_images)}
|
||||
data = {"pixel_values": torch.stack(transformed_images, dim=0)}
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
@ -17,6 +17,8 @@
|
||||
import unittest
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
@ -84,6 +86,8 @@ class BridgeTowerImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = size / min(w, h)
|
||||
|
@ -18,6 +18,8 @@ import json
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
@ -87,6 +89,8 @@ class ConditionalDetrImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -18,6 +18,8 @@ import json
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
@ -87,6 +89,8 @@ class DeformableDetrImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -17,6 +17,8 @@ import json
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
@ -86,6 +88,8 @@ class DetrImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -66,6 +66,8 @@ class GLPNImageProcessingTester(unittest.TestCase):
|
||||
def expected_output_image_shape(self, images):
|
||||
if isinstance(images[0], Image.Image):
|
||||
width, height = images[0].size
|
||||
elif isinstance(images[0], np.ndarray):
|
||||
height, width = images[0].shape[0], images[0].shape[1]
|
||||
else:
|
||||
height, width = images[0].shape[1], images[0].shape[2]
|
||||
|
||||
|
@ -18,6 +18,8 @@ import json
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
@ -93,6 +95,8 @@ class GroundingDinoImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -16,6 +16,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_torchvision, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
@ -75,6 +77,8 @@ class IdeficsImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = size / min(w, h)
|
||||
|
@ -99,6 +99,8 @@ class Idefics2ImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
|
||||
@ -176,6 +178,10 @@ class Idefics2ImageProcessingTester(unittest.TestCase):
|
||||
if torchify:
|
||||
images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
|
||||
|
||||
if numpify:
|
||||
# Numpy images are typically in channels last format
|
||||
images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list]
|
||||
|
||||
return images_list
|
||||
|
||||
|
||||
@ -206,6 +212,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
|
||||
|
||||
def test_call_numpy(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
@ -226,7 +233,39 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processor_dict = self.image_processor_dict
|
||||
image_processor_dict["image_mean"] = [0.5, 0.5, 0.5, 0.5]
|
||||
image_processor_dict["image_std"] = [0.5, 0.5, 0.5, 0.5]
|
||||
image_processing = self.image_processing_class(**image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
|
||||
for sample_images in image_inputs:
|
||||
for image in sample_images:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(
|
||||
image_inputs[0], input_data_format="channels_last", return_tensors="pt"
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(
|
||||
image_inputs, input_data_format="channels_last", return_tensors="pt"
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
def test_call_pil(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
@ -248,6 +287,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
|
@ -98,6 +98,8 @@ class Mask2FormerImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -98,6 +98,8 @@ class MaskFormerImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -106,6 +106,8 @@ class OneFormerImageProcessorTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -143,6 +143,8 @@ class OneFormerProcessorTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
|
@ -232,7 +232,7 @@ class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
for max_patch in self.image_processor_tester.max_patches:
|
||||
# Test not batched input
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_first"
|
||||
image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_last"
|
||||
).flattened_patches
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
@ -241,7 +241,7 @@ class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processor(
|
||||
image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_first"
|
||||
image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_last"
|
||||
).flattened_patches
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
|
@ -72,6 +72,8 @@ class Swin2SRImageProcessingTester(unittest.TestCase):
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
input_width, input_height = img.size
|
||||
elif isinstance(img, np.ndarray):
|
||||
input_height, input_width = img.shape[-3:-1]
|
||||
else:
|
||||
input_height, input_width = img.shape[-2:]
|
||||
|
||||
@ -160,7 +162,7 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(
|
||||
image_inputs[0], return_tensors="pt", input_data_format="channels_first"
|
||||
image_inputs[0], return_tensors="pt", input_data_format="channels_last"
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
@ -285,7 +285,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0],
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_first",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values_images
|
||||
@ -296,7 +296,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
encoded_images = image_processor(
|
||||
image_inputs,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_first",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values_images
|
||||
|
@ -16,6 +16,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
@ -78,6 +80,8 @@ class ViltImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = size / min(w, h)
|
||||
|
@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
|
||||
if is_vision_available():
|
||||
from transformers import ViTImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import ViTImageProcessorFast
|
||||
|
||||
|
||||
class ViTImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
@ -82,6 +85,7 @@ class ViTImageProcessingTester(unittest.TestCase):
|
||||
@require_vision
|
||||
class ViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = ViTImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = ViTImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
@ -18,6 +18,7 @@ import json
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
@ -89,6 +90,8 @@ class YolosImageProcessingTester(unittest.TestCase):
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
width, height = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
height, width = image.shape[0], image.shape[1]
|
||||
else:
|
||||
height, width = image.shape[1], image.shape[2]
|
||||
|
||||
|
@ -18,7 +18,9 @@ import json
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from transformers import AutoImageProcessor, BatchFeature
|
||||
@ -28,7 +30,6 @@ from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
@ -72,6 +73,10 @@ def prepare_image_inputs(
|
||||
if torchify:
|
||||
image_inputs = [torch.from_numpy(image) for image in image_inputs]
|
||||
|
||||
if numpify:
|
||||
# Numpy images are typically in channels last format
|
||||
image_inputs = [image.transpose(1, 2, 0) for image in image_inputs]
|
||||
|
||||
return image_inputs
|
||||
|
||||
|
||||
@ -167,33 +172,28 @@ class ImageProcessingTestMixin:
|
||||
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
|
||||
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
|
||||
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-2))
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_fast_is_faster_than_slow(self):
|
||||
import time
|
||||
|
||||
def measure_time(self, image_processor, dummy_image):
|
||||
start = time.time()
|
||||
_ = image_processor(dummy_image, return_tensors="pt")
|
||||
return time.time() - start
|
||||
|
||||
dummy_image = Image.open(
|
||||
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
|
||||
)
|
||||
|
||||
if not self.test_slow_image_processor or not self.test_fast_image_processor:
|
||||
self.skipTest("Skipping speed test")
|
||||
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest("Skipping speed test as one of the image processors is not defined")
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
|
||||
def measure_time(image_processor, image):
|
||||
start = time.time()
|
||||
_ = image_processor(image, return_tensors="pt")
|
||||
return time.time() - start
|
||||
|
||||
slow_time = self.measure_time(image_processor_slow, dummy_image)
|
||||
fast_time = self.measure_time(image_processor_fast, dummy_image)
|
||||
dummy_images = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8)
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processor_fast = self.fast_image_processing_class()
|
||||
|
||||
fast_time = measure_time(image_processor_fast, dummy_images)
|
||||
slow_time = measure_time(image_processor_slow, dummy_images)
|
||||
|
||||
self.assertLessEqual(fast_time, slow_time)
|
||||
|
||||
@ -238,6 +238,52 @@ class ImageProcessingTestMixin:
|
||||
|
||||
self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())
|
||||
|
||||
def test_save_load_fast_slow(self):
|
||||
"Test that we can load a fast image processor from a slow one and vice-versa."
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined")
|
||||
|
||||
image_processor_dict = self.image_processor_tester.prepare_image_processor_dict()
|
||||
image_processor_slow_0 = self.image_processing_class(**image_processor_dict)
|
||||
|
||||
# Load fast image processor from slow one
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
image_processor_slow_0.save_pretrained(tmpdirname)
|
||||
image_processor_fast_0 = self.fast_image_processing_class.from_pretrained(tmpdirname)
|
||||
|
||||
image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict)
|
||||
|
||||
# Load slow image processor from fast one
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
image_processor_fast_1.save_pretrained(tmpdirname)
|
||||
image_processor_slow_1 = self.image_processing_class.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
|
||||
self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
|
||||
|
||||
def test_save_load_fast_slow_auto(self):
|
||||
"Test that we can load a fast image processor from a slow one and vice-versa using AutoImageProcessor."
|
||||
if self.image_processing_class is None or self.fast_image_processing_class is None:
|
||||
self.skipTest("Skipping slow/fast save/load test as one of the image processors is not defined")
|
||||
|
||||
image_processor_dict = self.image_processor_tester.prepare_image_processor_dict()
|
||||
image_processor_slow_0 = self.image_processing_class(**image_processor_dict)
|
||||
|
||||
# Load fast image processor from slow one
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
image_processor_slow_0.save_pretrained(tmpdirname)
|
||||
image_processor_fast_0 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=True)
|
||||
|
||||
image_processor_fast_1 = self.fast_image_processing_class(**image_processor_dict)
|
||||
|
||||
# Load slow image processor from fast one
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
image_processor_fast_1.save_pretrained(tmpdirname)
|
||||
image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False)
|
||||
|
||||
self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
|
||||
self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
|
||||
|
||||
def test_init_without_params(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class()
|
||||
@ -358,7 +404,7 @@ class ImageProcessingTestMixin:
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0],
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_first",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
@ -369,7 +415,7 @@ class ImageProcessingTestMixin:
|
||||
encoded_images = image_processor(
|
||||
image_inputs,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_first",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
|
Loading…
Reference in New Issue
Block a user