fix: center_crop occasionally outputs off-by-one dimension matrix (#30934)

If required padding for a crop larger than input image is odd-numbered,
the padding would be rounded down instead of rounded up, causing the
output dimension to be one smaller than it should be.
This commit is contained in:
Matthew Beckers 2024-05-21 13:56:52 +01:00 committed by GitHub
parent daf281f44f
commit 3b09d3f05f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 2 deletions

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import warnings import warnings
from math import ceil
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
@ -483,9 +484,9 @@ def center_crop(
new_image = np.zeros_like(image, shape=new_shape) new_image = np.zeros_like(image, shape=new_shape)
# If the image is too small, pad it with zeros # If the image is too small, pad it with zeros
top_pad = (new_height - orig_height) // 2 top_pad = ceil((new_height - orig_height) / 2)
bottom_pad = top_pad + orig_height bottom_pad = top_pad + orig_height
left_pad = (new_width - orig_width) // 2 left_pad = ceil((new_width - orig_width) / 2)
right_pad = left_pad + orig_width right_pad = left_pad + orig_width
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

View File

@ -369,6 +369,10 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(cropped_image.shape, (300, 260, 3)) self.assertEqual(cropped_image.shape, (300, 260, 3))
self.assertTrue(np.allclose(cropped_image, expected_image)) self.assertTrue(np.allclose(cropped_image, expected_image))
# Test that odd numbered padding requirement still leads to correct output dimensions
cropped_image = center_crop(image, (300, 259), data_format="channels_last")
self.assertEqual(cropped_image.shape, (300, 259, 3))
# Test image with 4 channels is cropped correctly # Test image with 4 channels is cropped correctly
image = np.random.randint(0, 256, (224, 224, 4)) image = np.random.randint(0, 256, (224, 224, 4))
expected_image = image[52:172, 82:142, :] expected_image = image[52:172, 82:142, :]