Add vision requirement to image transforms (#20712)

* Add require_vision decorator

* Fixup

* Use requires_backends

* Add requires_backends to utils functions
This commit is contained in:
amyeroberts 2022-12-12 17:43:45 +00:00 committed by GitHub
parent fd2bed7f9f
commit b58beebe72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 19 deletions

View File

@ -18,25 +18,27 @@ from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import ExplicitEnum, TensorType
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
from transformers.image_utils import (
ChannelDimension,
get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
to_numpy_array,
)
from transformers.utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
from transformers.utils.import_utils import (
is_flax_available,
is_tf_available,
is_torch_available,
is_vision_available,
requires_backends,
)
if is_vision_available():
import PIL
from .image_utils import (
ChannelDimension,
PILImageResampling,
get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
is_jax_tensor,
is_tf_tensor,
is_torch_tensor,
to_numpy_array,
)
from .image_utils import PILImageResampling
if is_torch_available():
import torch
@ -116,9 +118,9 @@ def rescale(
def to_pil_image(
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None,
) -> PIL.Image.Image:
) -> "PIL.Image.Image":
"""
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
needed.
@ -133,6 +135,8 @@ def to_pil_image(
Returns:
`PIL.Image.Image`: The converted image.
"""
requires_backends(to_pil_image, ["vision"])
if isinstance(image, PIL.Image.Image):
return image
@ -226,7 +230,7 @@ def get_resize_output_image_size(
def resize(
image,
size: Tuple[int, int],
resample=PILImageResampling.BILINEAR,
resample: "PILImageResampling" = None,
reducing_gap: Optional[int] = None,
data_format: Optional[ChannelDimension] = None,
return_numpy: bool = True,
@ -253,6 +257,10 @@ def resize(
Returns:
`np.ndarray`: The resized image.
"""
requires_backends(resize, ["vision"])
resample = resample if resample is not None else PILImageResampling.BILINEAR
if not len(size) == 2:
raise ValueError("size must have 2 elements")
@ -303,6 +311,8 @@ def normalize(
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input.
"""
requires_backends(normalize, ["vision"])
if isinstance(image, PIL.Image.Image):
warnings.warn(
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
@ -372,6 +382,8 @@ def center_crop(
Returns:
`np.ndarray`: The cropped image.
"""
requires_backends(center_crop, ["vision"])
if isinstance(image, PIL.Image.Image):
warnings.warn(
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",

View File

@ -28,6 +28,7 @@ from .utils import (
is_torch_available,
is_torch_tensor,
is_vision_available,
requires_backends,
to_numpy,
)
from .utils.constants import ( # noqa: F401
@ -64,7 +65,8 @@ class ChannelDimension(ExplicitEnum):
def is_valid_image(img):
return (
isinstance(img, (PIL.Image.Image, np.ndarray))
(is_vision_available() and isinstance(img, PIL.Image.Image))
or isinstance(img, np.ndarray)
or is_torch_tensor(img)
or is_tf_tensor(img)
or is_jax_tensor(img)
@ -90,7 +92,10 @@ def is_batched(img):
def to_numpy_array(img) -> np.ndarray:
if isinstance(img, PIL.Image.Image):
if not is_valid_image(img):
raise ValueError(f"Invalid image type: {type(img)}")
if is_vision_available() and isinstance(img, PIL.Image.Image):
return np.array(img)
return to_numpy(img)
@ -215,6 +220,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
Returns:
`PIL.Image.Image`: A PIL Image.
"""
requires_backends(load_image, ["vision"])
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file