Mirror of https://github.com/huggingface/transformers.git

commit b58beebe72 (parent fd2bed7f9f)

Add vision requirement to image transforms (#20712)

* Add require_vision decorator
* Fixup
* Use requires_backends
* Add requires_backend to utils functions
src/transformers/image_transforms.py

@@ -18,25 +18,27 @@ from typing import Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 
-from transformers.utils import ExplicitEnum, TensorType
-from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
+from transformers.image_utils import (
+    ChannelDimension,
+    get_channel_dimension_axis,
+    get_image_size,
+    infer_channel_dimension_format,
+    to_numpy_array,
+)
+from transformers.utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
+from transformers.utils.import_utils import (
+    is_flax_available,
+    is_tf_available,
+    is_torch_available,
+    is_vision_available,
+    requires_backends,
+)
 
 
 if is_vision_available():
     import PIL
 
-    from .image_utils import (
-        ChannelDimension,
-        PILImageResampling,
-        get_channel_dimension_axis,
-        get_image_size,
-        infer_channel_dimension_format,
-        is_jax_tensor,
-        is_tf_tensor,
-        is_torch_tensor,
-        to_numpy_array,
-    )
-
+    from .image_utils import PILImageResampling
 
 if is_torch_available():
     import torch
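Note on the import reshuffle above: the PIL-free helpers now come from transformers.image_utils unconditionally, while PIL itself and PILImageResampling stay behind the is_vision_available() guard, and the tensor checks move to transformers.utils. A minimal, self-contained sketch of this soft-dependency pattern (pil_is_available and describe are illustrative names, not part of the library):

    # Sketch of the guarded-import pattern: PIL is imported only when installed,
    # and annotations that mention it are quoted so the module still imports
    # without Pillow. Only numpy is required to run this.
    import importlib.util
    from typing import Union

    import numpy as np


    def pil_is_available() -> bool:  # illustrative stand-in for is_vision_available()
        return importlib.util.find_spec("PIL") is not None


    if pil_is_available():
        import PIL.Image


    def describe(image: Union[np.ndarray, "PIL.Image.Image"]) -> str:
        # The quoted annotation is never evaluated at runtime, so defining this
        # function is safe even when the PIL import above was skipped.
        if pil_is_available() and isinstance(image, PIL.Image.Image):
            return f"PIL image of size {image.size}"
        return f"numpy array of shape {image.shape}"


    print(describe(np.zeros((2, 2, 3))))  # numpy array of shape (2, 2, 3)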
@@ -116,9 +118,9 @@ def rescale(
 
 
 def to_pil_image(
-    image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
+    image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
     do_rescale: Optional[bool] = None,
-) -> PIL.Image.Image:
+) -> "PIL.Image.Image":
     """
     Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
     needed.
@@ -133,6 +135,8 @@ def to_pil_image(
     Returns:
         `PIL.Image.Image`: The converted image.
     """
+    requires_backends(to_pil_image, ["vision"])
+
     if isinstance(image, PIL.Image.Image):
         return image
 
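The requires_backends(to_pil_image, ["vision"]) call added above turns a missing optional dependency into an immediate, readable error at call time. A rough sketch of what such a guard does (my_requires_backends and the backend mapping are hypothetical, not the transformers implementation):

    # Fail fast with a clear message when an optional backend is missing, instead
    # of hitting a NameError on PIL deep inside the function body.
    import importlib.util

    _BACKEND_MODULES = {"vision": "PIL"}  # assumed mapping, for illustration only


    def my_requires_backends(obj, backends):
        missing = [b for b in backends if importlib.util.find_spec(_BACKEND_MODULES.get(b, b)) is None]
        if missing:
            raise ImportError(f"{obj.__name__} requires the {', '.join(missing)} backend(s): pip install Pillow")


    def to_pil_image_sketch(image):
        my_requires_backends(to_pil_image_sketch, ["vision"])
        ...  # conversion would happen here once the backend check has passed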
@@ -226,7 +230,7 @@ def get_resize_output_image_size(
 def resize(
     image,
     size: Tuple[int, int],
-    resample=PILImageResampling.BILINEAR,
+    resample: "PILImageResampling" = None,
     reducing_gap: Optional[int] = None,
     data_format: Optional[ChannelDimension] = None,
     return_numpy: bool = True,
@@ -253,6 +257,10 @@ def resize(
     Returns:
         `np.ndarray`: The resized image.
     """
+    requires_backends(resize, ["vision"])
+
+    resample = resample if resample is not None else PILImageResampling.BILINEAR
+
     if not len(size) == 2:
         raise ValueError("size must have 2 elements")
 
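The resample default moves from PILImageResampling.BILINEAR in the signature to None resolved inside the body because default argument values are evaluated at import time, and PILImageResampling only exists when Pillow is installed. A small sketch of the deferred-default idiom (names are illustrative; assumes Pillow >= 9.1 for PIL.Image.Resampling when Pillow is present):

    # Defer the lookup of an optional-dependency default until the function is
    # actually called, after the backend check has passed.
    import importlib.util

    _HAS_PIL = importlib.util.find_spec("PIL") is not None
    if _HAS_PIL:
        import PIL.Image


    def resize_sketch(size, resample=None):
        if not _HAS_PIL:
            raise ImportError("resize_sketch requires Pillow (pip install Pillow).")
        # Safe to touch the optional symbol now; a module-level default would
        # have broken `import` on installs without Pillow.
        resample = resample if resample is not None else PIL.Image.Resampling.BILINEAR
        return size, resample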
@@ -303,6 +311,8 @@ def normalize(
         data_format (`ChannelDimension`, *optional*):
             The channel dimension format of the output image. If unset, will use the inferred format from the input.
     """
+    requires_backends(normalize, ["vision"])
+
     if isinstance(image, PIL.Image.Image):
         warnings.warn(
             "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
@@ -372,6 +382,8 @@ def center_crop(
     Returns:
         `np.ndarray`: The cropped image.
     """
+    requires_backends(center_crop, ["vision"])
+
     if isinstance(image, PIL.Image.Image):
         warnings.warn(
             "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
src/transformers/image_utils.py

@@ -28,6 +28,7 @@ from .utils import (
     is_torch_available,
     is_torch_tensor,
     is_vision_available,
+    requires_backends,
     to_numpy,
 )
 from .utils.constants import (  # noqa: F401
@@ -64,7 +65,8 @@ class ChannelDimension(ExplicitEnum):
 
 def is_valid_image(img):
     return (
-        isinstance(img, (PIL.Image.Image, np.ndarray))
+        (is_vision_available() and isinstance(img, PIL.Image.Image))
+        or isinstance(img, np.ndarray)
         or is_torch_tensor(img)
         or is_tf_tensor(img)
         or is_jax_tensor(img)
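With the change above, is_valid_image no longer touches PIL.Image.Image unless the vision backend is actually available, so numpy arrays and framework tensors still validate on a Pillow-free install. A minimal reproduction of the short-circuit, using a hypothetical availability flag and placeholder class:

    import numpy as np

    PIL_AVAILABLE = False  # pretend Pillow is not installed


    class FakePILImage:  # placeholder so the expression below stays well-formed
        pass


    def is_valid_image_sketch(img):
        # The availability flag short-circuits the PIL isinstance test entirely.
        return (PIL_AVAILABLE and isinstance(img, FakePILImage)) or isinstance(img, np.ndarray)


    print(is_valid_image_sketch(np.zeros((4, 4))))  # True, no Pillow needed
    print(is_valid_image_sketch("not an image"))    # False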
@@ -90,7 +92,10 @@ def is_batched(img):
 
 
 def to_numpy_array(img) -> np.ndarray:
-    if isinstance(img, PIL.Image.Image):
+    if not is_valid_image(img):
+        raise ValueError(f"Invalid image type: {type(img)}")
+
+    if is_vision_available() and isinstance(img, PIL.Image.Image):
         return np.array(img)
     return to_numpy(img)
 
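The rewritten to_numpy_array above validates its input first and raises a ValueError for unsupported types instead of assuming a PIL image. A standalone sketch of that behaviour, with the validity check reduced to numpy arrays only:

    import numpy as np


    def to_numpy_array_sketch(img):
        if not isinstance(img, np.ndarray):  # stand-in for the fuller is_valid_image check
            raise ValueError(f"Invalid image type: {type(img)}")
        return np.asarray(img)


    print(to_numpy_array_sketch(np.zeros((2, 2, 3))).shape)  # (2, 2, 3)
    try:
        to_numpy_array_sketch("not an image")
    except ValueError as err:
        print(err)  # Invalid image type: <class 'str'>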
@@ -215,6 +220,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
     Returns:
         `PIL.Image.Image`: A PIL Image.
     """
+    requires_backends(load_image, ["vision"])
     if isinstance(image, str):
         if image.startswith("http://") or image.startswith("https://"):
             # We need to actually check for a real protocol, otherwise it's impossible to use a local file
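Besides the new requires_backends(load_image, ["vision"]) guard, the surrounding context explains why the URL branch checks for an explicit protocol prefix: matching on "http" alone would misclassify local files whose names merely start with that string. A tiny illustration (looks_like_url and the file names are hypothetical):

    def looks_like_url(source: str) -> bool:
        # Only treat strings with a real protocol prefix as remote images.
        return source.startswith("http://") or source.startswith("https://")


    print(looks_like_url("https://example.com/cat.png"))  # True -> fetch over the network
    print(looks_like_url("http_snapshot_cat.png"))        # False -> treat as a local path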