mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Add vision requirement to image transforms (#20712)
* Add require_vision decorator * Fixup * Use requires_backends * Add requires_backend to utils functions
This commit is contained in:
parent
fd2bed7f9f
commit
b58beebe72
@ -18,25 +18,27 @@ from typing import Iterable, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.utils import ExplicitEnum, TensorType
|
from transformers.image_utils import (
|
||||||
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
|
ChannelDimension,
|
||||||
|
get_channel_dimension_axis,
|
||||||
|
get_image_size,
|
||||||
|
infer_channel_dimension_format,
|
||||||
|
to_numpy_array,
|
||||||
|
)
|
||||||
|
from transformers.utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
|
||||||
|
from transformers.utils.import_utils import (
|
||||||
|
is_flax_available,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
requires_backends,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
import PIL
|
import PIL
|
||||||
|
|
||||||
from .image_utils import (
|
from .image_utils import PILImageResampling
|
||||||
ChannelDimension,
|
|
||||||
PILImageResampling,
|
|
||||||
get_channel_dimension_axis,
|
|
||||||
get_image_size,
|
|
||||||
infer_channel_dimension_format,
|
|
||||||
is_jax_tensor,
|
|
||||||
is_tf_tensor,
|
|
||||||
is_torch_tensor,
|
|
||||||
to_numpy_array,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@ -116,9 +118,9 @@ def rescale(
|
|||||||
|
|
||||||
|
|
||||||
def to_pil_image(
|
def to_pil_image(
|
||||||
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
|
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
|
||||||
do_rescale: Optional[bool] = None,
|
do_rescale: Optional[bool] = None,
|
||||||
) -> PIL.Image.Image:
|
) -> "PIL.Image.Image":
|
||||||
"""
|
"""
|
||||||
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
||||||
needed.
|
needed.
|
||||||
@ -133,6 +135,8 @@ def to_pil_image(
|
|||||||
Returns:
|
Returns:
|
||||||
`PIL.Image.Image`: The converted image.
|
`PIL.Image.Image`: The converted image.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(to_pil_image, ["vision"])
|
||||||
|
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@ -226,7 +230,7 @@ def get_resize_output_image_size(
|
|||||||
def resize(
|
def resize(
|
||||||
image,
|
image,
|
||||||
size: Tuple[int, int],
|
size: Tuple[int, int],
|
||||||
resample=PILImageResampling.BILINEAR,
|
resample: "PILImageResampling" = None,
|
||||||
reducing_gap: Optional[int] = None,
|
reducing_gap: Optional[int] = None,
|
||||||
data_format: Optional[ChannelDimension] = None,
|
data_format: Optional[ChannelDimension] = None,
|
||||||
return_numpy: bool = True,
|
return_numpy: bool = True,
|
||||||
@ -253,6 +257,10 @@ def resize(
|
|||||||
Returns:
|
Returns:
|
||||||
`np.ndarray`: The resized image.
|
`np.ndarray`: The resized image.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(resize, ["vision"])
|
||||||
|
|
||||||
|
resample = resample if resample is not None else PILImageResampling.BILINEAR
|
||||||
|
|
||||||
if not len(size) == 2:
|
if not len(size) == 2:
|
||||||
raise ValueError("size must have 2 elements")
|
raise ValueError("size must have 2 elements")
|
||||||
|
|
||||||
@ -303,6 +311,8 @@ def normalize(
|
|||||||
data_format (`ChannelDimension`, *optional*):
|
data_format (`ChannelDimension`, *optional*):
|
||||||
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(normalize, ["vision"])
|
||||||
|
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
|
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
|
||||||
@ -372,6 +382,8 @@ def center_crop(
|
|||||||
Returns:
|
Returns:
|
||||||
`np.ndarray`: The cropped image.
|
`np.ndarray`: The cropped image.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(center_crop, ["vision"])
|
||||||
|
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
|
"PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.",
|
||||||
|
@ -28,6 +28,7 @@ from .utils import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
|
requires_backends,
|
||||||
to_numpy,
|
to_numpy,
|
||||||
)
|
)
|
||||||
from .utils.constants import ( # noqa: F401
|
from .utils.constants import ( # noqa: F401
|
||||||
@ -64,7 +65,8 @@ class ChannelDimension(ExplicitEnum):
|
|||||||
|
|
||||||
def is_valid_image(img):
|
def is_valid_image(img):
|
||||||
return (
|
return (
|
||||||
isinstance(img, (PIL.Image.Image, np.ndarray))
|
(is_vision_available() and isinstance(img, PIL.Image.Image))
|
||||||
|
or isinstance(img, np.ndarray)
|
||||||
or is_torch_tensor(img)
|
or is_torch_tensor(img)
|
||||||
or is_tf_tensor(img)
|
or is_tf_tensor(img)
|
||||||
or is_jax_tensor(img)
|
or is_jax_tensor(img)
|
||||||
@ -90,7 +92,10 @@ def is_batched(img):
|
|||||||
|
|
||||||
|
|
||||||
def to_numpy_array(img) -> np.ndarray:
|
def to_numpy_array(img) -> np.ndarray:
|
||||||
if isinstance(img, PIL.Image.Image):
|
if not is_valid_image(img):
|
||||||
|
raise ValueError(f"Invalid image type: {type(img)}")
|
||||||
|
|
||||||
|
if is_vision_available() and isinstance(img, PIL.Image.Image):
|
||||||
return np.array(img)
|
return np.array(img)
|
||||||
return to_numpy(img)
|
return to_numpy(img)
|
||||||
|
|
||||||
@ -215,6 +220,7 @@ def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
|
|||||||
Returns:
|
Returns:
|
||||||
`PIL.Image.Image`: A PIL Image.
|
`PIL.Image.Image`: A PIL Image.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(load_image, ["vision"])
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
if image.startswith("http://") or image.startswith("https://"):
|
if image.startswith("http://") or image.startswith("https://"):
|
||||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||||
|
Loading…
Reference in New Issue
Block a user