Enable passing number of channels when inferring data format (#25412)

This commit is contained in:
amyeroberts 2023-08-09 17:41:21 +01:00 committed by GitHub
parent cb3c821cb7
commit 944ddce8bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 3 deletions

View File

@ -144,17 +144,24 @@ def to_numpy_array(img) -> np.ndarray:
return to_numpy(img)
def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
def infer_channel_dimension_format(
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
"""
Infers the channel dimension format of `image`.
Args:
image (`np.ndarray`):
The image to infer the channel dimension of.
num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
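The number of channels of the image. May be a single value or a tuple of accepted values; a dimension whose size matches one of them is treated as the channel dimension.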
Returns:
The channel dimension of the image.
"""
num_channels = num_channels if num_channels is not None else (1, 3)
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
if image.ndim == 3:
first_dim, last_dim = 0, 2
elif image.ndim == 4:
@ -162,9 +169,9 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
else:
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
if image.shape[first_dim] in (1, 3):
if image.shape[first_dim] in num_channels:
return ChannelDimension.FIRST
elif image.shape[last_dim] in (1, 3):
elif image.shape[last_dim] in num_channels:
return ChannelDimension.LAST
raise ValueError("Unable to infer channel dimension format")

View File

@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase):
with pytest.raises(ValueError):
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
# But if we explicitly pass `num_channels=50`, the channel dimension can be inferred
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
self.assertEqual(inferred_dim, ChannelDimension.LAST)
# Test we correctly identify the channel dimension
image = np.random.randint(0, 256, (3, 4, 5))
inferred_dim = infer_channel_dimension_format(image)