mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-22 22:09:23 +06:00
to_pil - don't rescale if int and in range 0-255 (#22158)
* Don't rescale if in and in range 0-255 * Raise value error if int values too large * Update tests/test_image_transforms.py * Update tests/test_image_transforms.py
This commit is contained in:
parent
3b22bfbc6a
commit
c6318c3788
@ -156,12 +156,20 @@ def to_pil_image(
|
|||||||
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
|
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
|
||||||
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
||||||
|
|
||||||
# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
|
# PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
|
||||||
if do_rescale is None:
|
if do_rescale is None:
|
||||||
if np.all(0 <= image) and np.all(image <= 1):
|
if image.dtype == np.uint8:
|
||||||
do_rescale = True
|
|
||||||
elif np.allclose(image, image.astype(int)):
|
|
||||||
do_rescale = False
|
do_rescale = False
|
||||||
|
elif np.allclose(image, image.astype(int)):
|
||||||
|
if np.all(0 <= image) and np.all(image <= 255):
|
||||||
|
do_rescale = False
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The image to be converted to a PIL image contains values outside the range [0, 255], "
|
||||||
|
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
|
||||||
|
)
|
||||||
|
elif np.all(0 <= image) and np.all(image <= 1):
|
||||||
|
do_rescale = True
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The image to be converted to a PIL image contains values outside the range [0, 1], "
|
"The image to be converted to a PIL image contains values outside the range [0, 1], "
|
||||||
|
@ -101,6 +101,27 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
to_pil_image(image)
|
to_pil_image(image)
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
def test_to_pil_image_from_mask(self):
|
||||||
|
# Make sure binary mask remains a binary mask
|
||||||
|
image = np.random.randint(0, 2, (3, 4, 5)).astype(np.uint8)
|
||||||
|
pil_image = to_pil_image(image)
|
||||||
|
self.assertIsInstance(pil_image, PIL.Image.Image)
|
||||||
|
self.assertEqual(pil_image.size, (5, 4))
|
||||||
|
|
||||||
|
np_img = np.asarray(pil_image)
|
||||||
|
self.assertTrue(np_img.min() == 0)
|
||||||
|
self.assertTrue(np_img.max() == 1)
|
||||||
|
|
||||||
|
image = np.random.randint(0, 2, (3, 4, 5)).astype(np.float32)
|
||||||
|
pil_image = to_pil_image(image)
|
||||||
|
self.assertIsInstance(pil_image, PIL.Image.Image)
|
||||||
|
self.assertEqual(pil_image.size, (5, 4))
|
||||||
|
|
||||||
|
np_img = np.asarray(pil_image)
|
||||||
|
self.assertTrue(np_img.min() == 0)
|
||||||
|
self.assertTrue(np_img.max() == 1)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_to_pil_image_from_tensorflow(self):
|
def test_to_pil_image_from_tensorflow(self):
|
||||||
# channels_first
|
# channels_first
|
||||||
@ -222,7 +243,7 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertIsInstance(resized_image, np.ndarray)
|
self.assertIsInstance(resized_image, np.ndarray)
|
||||||
self.assertEqual(resized_image.shape, (30, 40, 3))
|
self.assertEqual(resized_image.shape, (30, 40, 3))
|
||||||
|
|
||||||
# Check PIL.Image.Image is return if return_numpy=False
|
# Check PIL.Image.Image is returned if return_numpy=False
|
||||||
resized_image = resize(image, (30, 40), return_numpy=False)
|
resized_image = resize(image, (30, 40), return_numpy=False)
|
||||||
self.assertIsInstance(resized_image, PIL.Image.Image)
|
self.assertIsInstance(resized_image, PIL.Image.Image)
|
||||||
# PIL size is in (width, height) order
|
# PIL size is in (width, height) order
|
||||||
|
Loading…
Reference in New Issue
Block a user