mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enable passing number of channels when inferring data format (#25412)
This commit is contained in:
parent
cb3c821cb7
commit
944ddce8bf
@ -144,17 +144,24 @@ def to_numpy_array(img) -> np.ndarray:
|
||||
return to_numpy(img)
|
||||
|
||||
|
||||
def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
|
||||
def infer_channel_dimension_format(
|
||||
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
|
||||
) -> ChannelDimension:
|
||||
"""
|
||||
Infers the channel dimension format of `image`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
The image to infer the channel dimension of.
|
||||
num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
|
||||
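The number of channels of the image. A single `int` is treated as the only accepted channel count; a tuple lists every accepted channel count.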
|
||||
|
||||
Returns:
|
||||
The channel dimension of the image.
|
||||
"""
|
||||
num_channels = num_channels if num_channels is not None else (1, 3)
|
||||
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
|
||||
|
||||
if image.ndim == 3:
|
||||
first_dim, last_dim = 0, 2
|
||||
elif image.ndim == 4:
|
||||
@ -162,9 +169,9 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
|
||||
else:
|
||||
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
||||
|
||||
if image.shape[first_dim] in (1, 3):
|
||||
if image.shape[first_dim] in num_channels:
|
||||
return ChannelDimension.FIRST
|
||||
elif image.shape[last_dim] in (1, 3):
|
||||
elif image.shape[last_dim] in num_channels:
|
||||
return ChannelDimension.LAST
|
||||
raise ValueError("Unable to infer channel dimension format")
|
||||
|
||||
|
@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
|
||||
|
||||
# But if we explicitly pass 50 as the number of channels, the format can be inferred
|
||||
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.LAST)
|
||||
|
||||
# Test we correctly identify the channel dimension
|
||||
image = np.random.randint(0, 256, (3, 4, 5))
|
||||
inferred_dim = infer_channel_dimension_format(image)
|
||||
|
Loading…
Reference in New Issue
Block a user