Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
🚨🚨🚨 Fix rescale ViVit Efficientnet (#25174)
* Fix rescaling bug
* Add tests
* Update integration tests
* Fix up
* Update src/transformers/image_transforms.py
* Update test - new possible order in list
This commit is contained in:
parent 03f98f9683
commit 05cda5df34
src/transformers/image_transforms.py

```diff
@@ -110,10 +110,11 @@ def rescale(
     if not isinstance(image, np.ndarray):
         raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
 
-    image = image.astype(dtype)
-
     rescaled_image = image * scale
     if data_format is not None:
         rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
+
+    rescaled_image = rescaled_image.astype(dtype)
+
     return rescaled_image
```
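The net effect: `rescale` now multiplies first and casts to the target `dtype` last, so the scaled values are computed in numpy's promoted precision and converted once on the way out. A minimal runnable sketch of the new control flow (channel-format handling omitted; an illustration, not the library source):

```python
import numpy as np

def rescale(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
    """Multiply first, cast last: the product keeps numpy's promoted
    precision and is only converted to `dtype` on the way out."""
    if not isinstance(image, np.ndarray):
        raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
    rescaled_image = image * scale  # uint8 * float promotes to float64
    rescaled_image = rescaled_image.astype(dtype)  # single cast at the end
    return rescaled_image

image = np.arange(0, 256, dtype=np.uint8)
print(rescale(image, 1 / 255).max())  # 1.0
```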
src/transformers/models/efficientnet/image_processing_efficientnet.py

```diff
@@ -153,7 +153,13 @@ class EfficientNetImageProcessor(BaseImageProcessor):
         **kwargs,
     ):
         """
-        Rescale an image by a scale factor. image = image * scale.
+        Rescale an image by a scale factor.
+
+        If offset is True, the image is rescaled between [-1, 1].
+            image = image * scale * 2 - 1
+
+        If offset is False, the image is rescaled between [0, 1].
+            image = image * scale
 
         Args:
             image (`np.ndarray`):
@@ -165,13 +171,12 @@ class EfficientNetImageProcessor(BaseImageProcessor):
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
         """
+        scale = scale * 2 if offset else scale
+        rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+
         if offset:
-            rescaled_image = (image - 127.5) * scale
-            if data_format is not None:
-                rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
-            rescaled_image = rescaled_image.astype(np.float32)
-        else:
-            rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+            rescaled_image = rescaled_image - 1
 
         return rescaled_image
 
     def preprocess(
```
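With a `rescale_factor` of 1/255 (the processor's usual default), the new offset path spans the documented [-1, 1] range, whereas the removed `(image - 127.5) * scale` expression only reached [-0.5, 0.5]. A quick numpy check:

```python
import numpy as np

image = np.arange(0, 256, dtype=np.uint8)
scale = 1 / 255

new = image * (scale * 2) - 1   # offset=True after the fix
old = (image - 127.5) * scale   # removed expression
plain = image * scale           # offset=False

print(new.min(), new.max())      # -1.0 1.0
print(old.min(), old.max())      # -0.5 0.5
print(plain.min(), plain.max())  # 0.0 1.0
```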
src/transformers/models/vivit/image_processing_vivit.py

```diff
@@ -167,6 +167,7 @@ class VivitImageProcessor(BaseImageProcessor):
             raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
         return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
 
+    # Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
     def rescale(
         self,
         image: np.ndarray,
@@ -178,23 +179,29 @@ class VivitImageProcessor(BaseImageProcessor):
         """
         Rescale an image by a scale factor.
 
-        If offset is `True`, image scaled between [-1, 1]: image = (image - 127.5) * scale. If offset is `False`, image
-        scaled between [0, 1]: image = image * scale
+        If offset is True, the image is rescaled between [-1, 1].
+            image = image * scale * 2 - 1
+
+        If offset is False, the image is rescaled between [0, 1].
+            image = image * scale
 
         Args:
             image (`np.ndarray`):
                 Image to rescale.
             scale (`int` or `float`):
                 Scale to apply to the image.
             offset (`bool`, *optional*):
                 Whether to scale the image in both negative and positive directions.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
         """
-        image = image.astype(np.float32)
+        scale = scale * 2 if offset else scale
+        rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+
         if offset:
-            image = image - (scale / 2)
-        return rescale(image, scale=scale, data_format=data_format, **kwargs)
+            rescaled_image = rescaled_image - 1
+
+        return rescaled_image
 
     def _preprocess_image(
         self,
```
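The removed ViViT code subtracted `scale / 2` (a tiny constant, not half the pixel range), so even with `offset=True` it produced values in roughly [0, 1] rather than the documented [-1, 1]. A before/after comparison in plain numpy, using the `scale=1/255` value the new tests pass in:

```python
import numpy as np

image = np.arange(0, 256, dtype=np.uint8)
scale = 1 / 255

# Removed behavior: subtract scale / 2 (~0.002), then rescale.
old = (image.astype(np.float32) - scale / 2) * scale
print(old.min(), old.max())  # ~0.0 ~1.0, not the documented [-1, 1]

# New behavior, shared with EfficientNet: double the scale, subtract 1.
new = image * (scale * 2) - 1
print(new.min(), new.max())  # -1.0 1.0
```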
tests/models/efficientnet/test_image_processing_efficientnet.py

```diff
@@ -193,3 +193,17 @@ class EfficientNetImageProcessorTest(ImageProcessingSavingTestMixin, unittest.TestCase):
                 self.image_processor_tester.size["width"],
             ),
         )
+
+    def test_rescale(self):
+        # EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1
+        image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
+
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255)
+        expected_image = image.astype(np.float32) * (2 / 255.0) - 1
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
+        expected_image = image.astype(np.float32) / 255.0
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
```
tests/models/vivit/test_image_processing_vivit.py

```diff
@@ -212,3 +212,17 @@ class VivitImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
                 self.image_processor_tester.crop_size["width"],
             ),
         )
+
+    def test_rescale(self):
+        # ViVit optionally rescales between -1 and 1 instead of the usual 0 and 1
+        image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
+
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255)
+        expected_image = image.astype(np.float32) * (2 / 255.0) - 1
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
+        expected_image = image.astype(np.float32) / 255.0
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
```
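Both new tests can be run directly, e.g. `python -m pytest tests/models/efficientnet/test_image_processing_efficientnet.py -k test_rescale` (and likewise for the ViViT file); the paths assume the standard transformers test layout.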
tests/models/vivit/test_modeling_vivit.py

```diff
@@ -345,6 +345,6 @@ class VivitModelIntegrationTest(unittest.TestCase):
         self.assertEqual(outputs.logits.shape, expected_shape)
 
         # taken from original model
-        expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]).to(torch_device)
+        expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))
```
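The expected logits move because ViViT inputs are now genuinely rescaled to [-1, 1] before inference. A sketch of how such a slice is typically regenerated (the checkpoint name is an assumption based on the ViViT integration tests, and random frames stand in for the real Kinetics clip, so the printed values will not match the ones above):

```python
import numpy as np
import torch
from transformers import VivitForVideoClassification, VivitImageProcessor

checkpoint = "google/vivit-b-16x2-kinetics400"  # assumed checkpoint
processor = VivitImageProcessor.from_pretrained(checkpoint)
model = VivitForVideoClassification.from_pretrained(checkpoint)

# 32 random RGB frames as a stand-in for the test's real video clip
video = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(32)]
inputs = processor(video, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
print(logits[0, :5])
```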
tests/pipelines/test_pipelines_zero_shot_image_classification.py

```diff
@@ -85,6 +85,7 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
             [
                 [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
                 [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}, {"score": 0.333, "label": "b"}],
+                [{"score": 0.333, "label": "b"}, {"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}],
             ],
         )
```
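All three labels round to the same 0.333 score, so their ranking hinges on sub-rounding differences in the raw probabilities and can vary across environments; the test now accepts one more observed permutation. A toy illustration with made-up numbers:

```python
# Hypothetical raw scores that all round to 0.333
scores = {"a": 0.33335, "b": 0.33333, "c": 0.33332}

ranked = sorted(scores, key=scores.get, reverse=True)
print(ranked)  # ['a', 'b', 'c'] here; tiny numerical shifts reorder it
print([round(scores[label], 3) for label in ranked])  # [0.333, 0.333, 0.333]
```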