transformers.image_transforms.normalize wrong types (#35773)

transformers.image_transforms.normalize documents and checks for the wrong type for std and mean arguments

Co-authored-by: Louis Groux <louis.cal.groux@gmail.com>
This commit is contained in:
CalOmnie 2025-01-20 16:00:46 +01:00 committed by GitHub
parent 3998fa8aab
commit a142f16131
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,7 +15,7 @@
import warnings
from math import ceil
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
@ -357,8 +357,8 @@ def resize(
def normalize(
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
mean: Union[float, Sequence[float]],
std: Union[float, Sequence[float]],
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
@ -370,9 +370,9 @@ def normalize(
Args:
image (`np.ndarray`):
The image to normalize.
mean (`float` or `Iterable[float]`):
mean (`float` or `Sequence[float]`):
The mean to use for normalization.
std (`float` or `Iterable[float]`):
std (`float` or `Sequence[float]`):
The standard deviation to use for normalization.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input.
@ -393,14 +393,14 @@ def normalize(
if not np.issubdtype(image.dtype, np.floating):
image = image.astype(np.float32)
if isinstance(mean, Iterable):
if isinstance(mean, Sequence):
if len(mean) != num_channels:
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
else:
mean = [mean] * num_channels
mean = np.array(mean, dtype=image.dtype)
if isinstance(std, Iterable):
if isinstance(std, Sequence):
if len(std) != num_channels:
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
else: