transformers.image_transforms.normalize wrong types (#35773)

transformers.image_transforms.normalize documents and checks for the wrong type for std and mean arguments

Co-authored-by: Louis Groux <louis.cal.groux@gmail.com>
This commit is contained in:
CalOmnie 2025-01-20 16:00:46 +01:00 committed by GitHub
parent 3998fa8aab
commit a142f16131
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,7 +15,7 @@
import warnings import warnings
from math import ceil from math import ceil
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
@ -357,8 +357,8 @@ def resize(
def normalize( def normalize(
image: np.ndarray, image: np.ndarray,
mean: Union[float, Iterable[float]], mean: Union[float, Sequence[float]],
std: Union[float, Iterable[float]], std: Union[float, Sequence[float]],
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
@ -370,9 +370,9 @@ def normalize(
Args: Args:
image (`np.ndarray`): image (`np.ndarray`):
The image to normalize. The image to normalize.
mean (`float` or `Iterable[float]`): mean (`float` or `Sequence[float]`):
The mean to use for normalization. The mean to use for normalization.
std (`float` or `Iterable[float]`): std (`float` or `Sequence[float]`):
The standard deviation to use for normalization. The standard deviation to use for normalization.
data_format (`ChannelDimension`, *optional*): data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input. The channel dimension format of the output image. If unset, will use the inferred format from the input.
@ -393,14 +393,14 @@ def normalize(
if not np.issubdtype(image.dtype, np.floating): if not np.issubdtype(image.dtype, np.floating):
image = image.astype(np.float32) image = image.astype(np.float32)
if isinstance(mean, Iterable): if isinstance(mean, Sequence):
if len(mean) != num_channels: if len(mean) != num_channels:
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
else: else:
mean = [mean] * num_channels mean = [mean] * num_channels
mean = np.array(mean, dtype=image.dtype) mean = np.array(mean, dtype=image.dtype)
if isinstance(std, Iterable): if isinstance(std, Sequence):
if len(std) != num_channels: if len(std) != num_channels:
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
else: else: