Add vision requirement to image transforms (#20712)

* Add require_vision decorator

* Fixup

* Use requires_backends

* Add requires_backends to utils functions
This commit is contained in:
amyeroberts 2022-12-12 17:43:45 +00:00 committed by GitHub
parent fd2bed7f9f
commit b58beebe72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 19 deletions

View File

@ -18,25 +18,27 @@ from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
from transformers.utils import ExplicitEnum, TensorType from transformers.image_utils import (
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available ChannelDimension,
get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
to_numpy_array,
)
from transformers.utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
from transformers.utils.import_utils import (
is_flax_available,
is_tf_available,
is_torch_available,
is_vision_available,
requires_backends,
)
if is_vision_available(): if is_vision_available():
import PIL import PIL
from .image_utils import ( from .image_utils import PILImageResampling
ChannelDimension,
PILImageResampling,
get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
is_jax_tensor,
is_tf_tensor,
is_torch_tensor,
to_numpy_array,
)
if is_torch_available(): if is_torch_available():
import torch import torch
@ -116,9 +118,9 @@ def rescale(
def to_pil_image( def to_pil_image(
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"], image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None, do_rescale: Optional[bool] = None,
) -> PIL.Image.Image: ) -> "PIL.Image.Image":
""" """
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
needed. needed.
@ -133,6 +135,8 @@ def to_pil_image(
Returns: Returns:
`PIL.Image.Image`: The converted image. `PIL.Image.Image`: The converted image.
""" """
requires_backends(to_pil_image, ["vision"])
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
return image return image
@ -226,7 +230,7 @@ def get_resize_output_image_size(
def resize( def resize(
image, image,
size: Tuple[int, int], size: Tuple[int, int],
resample=PILImageResampling.BILINEAR, resample: "PILImageResampling" = None,
reducing_gap: Optional[int] = None, reducing_gap: Optional[int] = None,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
return_numpy: bool = True, return_numpy: bool = True,
@ -253,6 +257,10 @@ def resize(
Returns: Returns:
`np.ndarray`: The resized image. `np.ndarray`: The resized image.
""" """
requires_backends(resize, ["vision"])
resample = resample if resample is not None else PILImageResampling.BILINEAR
if not len(size) == 2: if not len(size) == 2:
raise ValueError("size must have 2 elements") raise ValueError("size must have 2 elements")
@ -303,6 +311,8 @@ def normalize(
data_format (`ChannelDimension`, *optional*): data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input. The channel dimension format of the output image. If unset, will use the inferred format from the input.
""" """
requires_backends(normalize, ["vision"])
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
warnings.warn( warnings.warn(
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.", "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
@ -372,6 +382,8 @@ def center_crop(
Returns: Returns:
`np.ndarray`: The cropped image. `np.ndarray`: The cropped image.
""" """
requires_backends(center_crop, ["vision"])
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):
warnings.warn( warnings.warn(
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.", "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",

View File

@ -28,6 +28,7 @@ from .utils import (
is_torch_available, is_torch_available,
is_torch_tensor, is_torch_tensor,
is_vision_available, is_vision_available,
requires_backends,
to_numpy, to_numpy,
) )
from .utils.constants import ( # noqa: F401 from .utils.constants import ( # noqa: F401
@ -64,7 +65,8 @@ class ChannelDimension(ExplicitEnum):
def is_valid_image(img): def is_valid_image(img):
return ( return (
isinstance(img, (PIL.Image.Image, np.ndarray)) (is_vision_available() and isinstance(img, PIL.Image.Image))
or isinstance(img, np.ndarray)
or is_torch_tensor(img) or is_torch_tensor(img)
or is_tf_tensor(img) or is_tf_tensor(img)
or is_jax_tensor(img) or is_jax_tensor(img)
@ -90,7 +92,10 @@ def is_batched(img):
def to_numpy_array(img) -> np.ndarray: def to_numpy_array(img) -> np.ndarray:
if isinstance(img, PIL.Image.Image): if not is_valid_image(img):
raise ValueError(f"Invalid image type: {type(img)}")
if is_vision_available() and isinstance(img, PIL.Image.Image):
return np.array(img) return np.array(img)
return to_numpy(img) return to_numpy(img)
@ -215,6 +220,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
Returns: Returns:
`PIL.Image.Image`: A PIL Image. `PIL.Image.Image`: A PIL Image.
""" """
requires_backends(load_image, ["vision"])
if isinstance(image, str): if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"): if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file # We need to actually check for a real protocol, otherwise it's impossible to use a local file