Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Input data format (#25464)
* Add copied from statements for image processors
* Move out rescale and normalize to base image processor
* Remove rescale and normalize from vit (post rebase)
* Update docstrings and tidy up
* PR comments
* Add input_data_format as preprocess argument
* Resolve tests and tidy up
* Remove num_channels argument
* Update doc strings -> default ints not in code formatting
parent a6609caf4e
commit 6bca43bb90
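For readers of this diff, the net effect is that every `preprocess`/transform method now accepts an `input_data_format` argument describing the layout of the incoming arrays, instead of always inferring it. A minimal usage sketch under stated assumptions (the checkpoint name is illustrative, and the keyword pass-through relies on the `preprocess` signatures added below):

```python
import numpy as np
from transformers import CLIPImageProcessor

# An image already in (height, width, num_channels) layout, e.g. decoded with PIL or OpenCV.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# input_data_format describes the layout of the *input* array; data_format controls the output.
# Leaving input_data_format unset falls back to inferring the layout from the first image.
inputs = processor(
    images=image,
    input_data_format="channels_last",
    data_format="channels_first",
    return_tensors="np",
)
print(inputs["pixel_values"].shape)  # e.g. (1, 3, 224, 224) after resize + center crop
```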
@ -521,7 +521,12 @@ class BaseImageProcessor(ImageProcessingMixin):
|
||||
raise NotImplementedError("Each image processor must implement its own preprocess method")
|
||||
|
||||
def rescale(
|
||||
self, image: np.ndarray, scale: float, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
|
||||
self,
|
||||
image: np.ndarray,
|
||||
scale: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale an image by a scale factor. image = image * scale.
|
||||
@ -536,11 +541,16 @@ class BaseImageProcessor(ImageProcessingMixin):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The rescaled image.
|
||||
"""
|
||||
return rescale(image, scale=scale, data_format=data_format, **kwargs)
|
||||
return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
|
||||
|
||||
def normalize(
|
||||
self,
|
||||
@ -548,6 +558,7 @@ class BaseImageProcessor(ImageProcessingMixin):
|
||||
mean: Union[float, Iterable[float]],
|
||||
std: Union[float, Iterable[float]],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -565,17 +576,25 @@ class BaseImageProcessor(ImageProcessingMixin):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The normalized image.
|
||||
"""
|
||||
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
|
||||
return normalize(
|
||||
image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
|
||||
def center_crop(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -588,12 +607,26 @@ class BaseImageProcessor(ImageProcessingMixin):
|
||||
size (`Dict[str, int]`):
|
||||
Size of the output image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
|
||||
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
|
||||
return center_crop(
|
||||
image,
|
||||
size=(size["height"], size["width"]),
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"}, {"longest_edge"})
|
||||
|
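When `input_data_format` is left unset, the methods above fall back to `infer_channel_dimension_format`, which guesses the layout from the array shape; that guess is ambiguous for images whose height or width happens to match the channel count, which is the main motivation for letting callers pass the format explicitly. A rough, illustrative version of that kind of heuristic (not the library's actual implementation):

```python
import numpy as np

def guess_channel_dimension(image: np.ndarray, num_channels=(1, 3)) -> str:
    """Toy heuristic: pick channels_first or channels_last from the shape alone."""
    if image.ndim == 3:
        first, last = image.shape[0], image.shape[-1]
    elif image.ndim == 4:  # batched input
        first, last = image.shape[1], image.shape[-1]
    else:
        raise ValueError(f"Unsupported number of dimensions: {image.ndim}")
    if first in num_channels:
        return "channels_first"
    if last in num_channels:
        return "channels_last"
    raise ValueError("Unable to infer channel dimension format")

print(guess_channel_dimension(np.zeros((3, 224, 224))))  # channels_first
print(guess_channel_dimension(np.zeros((224, 224, 3))))  # channels_last
# A (3, 224, 3) crop would be ambiguous; passing input_data_format avoids relying on the guess.
```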
@ -27,6 +27,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -145,6 +146,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -159,12 +161,19 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=True, param_name="size")
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
|
||||
return resize(
|
||||
image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs
|
||||
image,
|
||||
size=(size["height"], size["width"]),
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reduce_label(self, label: ImageInput) -> np.ndarray:
|
||||
@ -189,21 +198,22 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
if do_reduce_labels:
|
||||
image = self.reduce_label(image)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample)
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image=image, size=crop_size)
|
||||
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor)
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std)
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
|
||||
return image
|
||||
|
||||
@ -221,10 +231,13 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
image = self._preprocess(
|
||||
image,
|
||||
do_reduce_labels=False,
|
||||
@ -238,9 +251,10 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def _preprocess_segmentation_map(
|
||||
@ -252,6 +266,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
do_center_crop: bool = None,
|
||||
crop_size: Dict[str, int] = None,
|
||||
do_reduce_labels: bool = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""Preprocesses a single segmentation map."""
|
||||
# All transformations expect numpy arrays.
|
||||
@ -260,8 +275,11 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
if segmentation_map.ndim == 2:
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
added_dimension = True
|
||||
input_data_format = ChannelDimension.FIRST
|
||||
else:
|
||||
added_dimension = False
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
||||
segmentation_map = self._preprocess(
|
||||
image=segmentation_map,
|
||||
do_reduce_labels=do_reduce_labels,
|
||||
@ -272,6 +290,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
crop_size=crop_size,
|
||||
do_normalize=False,
|
||||
do_rescale=False,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
# Remove extra axis if added
|
||||
if added_dimension:
|
||||
@ -301,6 +320,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
do_reduce_labels: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -344,8 +364,15 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -403,6 +430,7 @@ class BeitImageProcessor(BaseImageProcessor):
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
|
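The BEiT changes above also cover the one case where the format is forced rather than inferred: a 2D segmentation map gets a leading channel axis and is processed as `channels_first`, then the extra axis is removed afterwards. A compressed, illustrative sketch of that bookkeeping (not the library code):

```python
import numpy as np

def preprocess_segmentation_map(segmentation_map: np.ndarray) -> np.ndarray:
    added_dimension = segmentation_map.ndim == 2
    if added_dimension:
        # 2D maps carry no channel axis: add one and declare the layout explicitly.
        segmentation_map = segmentation_map[None, ...]  # (1, H, W) -> channels_first
    # ... resize / center crop here with input_data_format=ChannelDimension.FIRST ...
    if added_dimension:
        segmentation_map = np.squeeze(segmentation_map, axis=0)
    return segmentation_map.astype(np.int64)

print(preprocess_segmentation_map(np.zeros((16, 16))).shape)  # (16, 16)
```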
@ -31,6 +31,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -125,6 +126,7 @@ class BitImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -140,12 +142,23 @@ class BitImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size:
|
||||
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
|
||||
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -163,6 +176,7 @@ class BitImageProcessor(BaseImageProcessor):
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -205,9 +219,15 @@ class BitImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: defaults to the channel dimension format of the input image.
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -250,19 +270,36 @@ class BitImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
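Several of the resize methods in this diff key off `size["shortest_edge"]`: the shorter side is scaled to that value and the longer side follows to preserve the aspect ratio, which is exactly why `get_resize_output_image_size` now needs `input_data_format` to read height and width from the right axes. A small illustrative calculation (not the library helper itself):

```python
def shortest_edge_size(height: int, width: int, shortest_edge: int) -> tuple:
    """Scale so the shorter side equals `shortest_edge` while keeping the aspect ratio."""
    short, long = (height, width) if height <= width else (width, height)
    new_long = round(long * shortest_edge / short)
    return (shortest_edge, new_long) if height <= width else (new_long, shortest_edge)

print(shortest_edge_size(480, 640, 224))  # (224, 299)
print(shortest_edge_size(640, 480, 224))  # (299, 224)
```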
@ -26,6 +26,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -111,6 +112,7 @@ class BlipImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -128,6 +130,13 @@ class BlipImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -136,7 +145,14 @@ class BlipImageProcessor(BaseImageProcessor):
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -152,6 +168,7 @@ class BlipImageProcessor(BaseImageProcessor):
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -190,8 +207,15 @@ class BlipImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
@ -229,16 +253,31 @@ class BlipImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||
|
||||
|
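A recurring change in these `preprocess` methods is the final call `to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)`: once the input layout is known, it is reused for the output conversion rather than re-inferred. An illustrative version of that conversion step (assumed behaviour, not the library code):

```python
import numpy as np

def convert_layout(image: np.ndarray, target: str, source: str) -> np.ndarray:
    """Move the channel axis between channels_first and channels_last layouts."""
    if source == target:
        return image
    if source == "channels_last" and target == "channels_first":
        return np.moveaxis(image, -1, 0)
    if source == "channels_first" and target == "channels_last":
        return np.moveaxis(image, 0, -1)
    raise ValueError(f"Unsupported conversion: {source} -> {target}")

hwc = np.zeros((224, 224, 3))
print(convert_layout(hwc, target="channels_first", source="channels_last").shape)  # (3, 224, 224)
```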
@ -50,7 +50,9 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
||||
|
||||
|
||||
# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
|
||||
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
def make_pixel_mask(
|
||||
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
|
||||
@ -60,33 +62,40 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
|
||||
output_size (`Tuple[int, int]`):
|
||||
Output size of the mask.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
mask = np.zeros(output_size, dtype=np.int64)
|
||||
mask[:input_height, :input_width] = 1
|
||||
return mask
|
||||
|
||||
|
||||
# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
|
||||
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
|
||||
def get_max_height_width(
|
||||
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all images in a batch.
|
||||
"""
|
||||
input_channel_dimension = infer_channel_dimension_format(images[0])
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if input_channel_dimension == ChannelDimension.FIRST:
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
||||
elif input_channel_dimension == ChannelDimension.LAST:
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
||||
else:
|
||||
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
|
||||
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
# Copied from transformers.models.vilt.image_processing_vilt.get_resize_output_image_size
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32
|
||||
input_image: np.ndarray,
|
||||
shorter: int = 800,
|
||||
longer: int = 1333,
|
||||
size_divisor: int = 32,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
input_height, input_width = get_image_size(input_image)
|
||||
input_height, input_width = get_image_size(input_image, input_data_format)
|
||||
min_size, max_size = shorter, longer
|
||||
|
||||
scale = min_size / min(input_height, input_width)
|
||||
@ -122,7 +131,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
Resize the shorter side of the input to `size["shortest_edge"]`. The longer side will be limited to under
|
||||
`int((1333 / 800) * size["shortest_edge"])` while preserving the aspect ratio. Only has an effect if
|
||||
`do_resize` is set to `True`. Can be overridden by the `size` parameter in the `preprocess` method.
|
||||
size_divisor (`int`, *optional*, defaults to `32`):
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
|
||||
is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
@ -197,6 +206,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
size_divisor: int = 32,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -217,20 +227,32 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
|
||||
shorter = size["shortest_edge"]
|
||||
longer = int(1333 / 800 * shorter)
|
||||
output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def center_crop(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -244,9 +266,18 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
Size of the output image in the form `{"height": h, "width": w}`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input
|
||||
image.
|
||||
"""
|
||||
output_size = size["shortest_edge"]
|
||||
return center_crop(image, size=(output_size, output_size), data_format=data_format, **kwargs)
|
||||
return center_crop(
|
||||
image,
|
||||
size=(output_size, output_size),
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
|
||||
def _pad_image(
|
||||
@ -255,18 +286,24 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
output_size: Tuple[int, int],
|
||||
constant_values: Union[float, Iterable[float]] = 0,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad an image with zeros to the given size.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = output_size
|
||||
|
||||
pad_bottom = output_height - input_height
|
||||
pad_right = output_width - input_width
|
||||
padding = ((0, pad_bottom), (0, pad_right))
|
||||
padded_image = pad(
|
||||
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
|
||||
image,
|
||||
padding,
|
||||
mode=PaddingMode.CONSTANT,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return padded_image
|
||||
|
||||
@ -278,6 +315,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
return_pixel_mask: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
|
||||
@ -299,17 +337,28 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
pad_size = get_max_height_width(images)
|
||||
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
||||
|
||||
padded_images = [
|
||||
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
|
||||
self._pad_image(
|
||||
image,
|
||||
pad_size,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": padded_images}
|
||||
|
||||
if return_pixel_mask:
|
||||
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
|
||||
masks = [
|
||||
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data["pixel_mask"] = masks
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
@ -330,6 +379,7 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
do_center_crop: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -374,8 +424,15 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size_divisor = size_divisor if size_divisor is not None else self.size_divisor
|
||||
@ -414,22 +471,41 @@ class BridgeTowerImageProcessor(BaseImageProcessor):
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images
|
||||
self.resize(
|
||||
image=image,
|
||||
size=size,
|
||||
size_divisor=size_divisor,
|
||||
resample=resample,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors)
|
||||
encoded_outputs = self.pad(
|
||||
images, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=data_format
|
||||
)
|
||||
else:
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||
|
||||
|
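The BridgeTower padding path above shows why the free helpers (`get_max_height_width`, `_pad_image`, `make_pixel_mask`) now take the format explicitly: they all need to know which axes are height and width before any per-image inference has happened. A toy end-to-end sketch of that flow in plain NumPy, for channels-first inputs only (illustrative, not the library helpers):

```python
import numpy as np

def pad_batch(images):
    """Pad channels-first arrays to the batch's max height/width and build pixel masks."""
    sizes = [img.shape[1:3] for img in images]  # (height, width) per image
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    padded, masks = [], []
    for img, (h, w) in zip(images, sizes):
        # Pad on the bottom and right with zeros, matching _pad_image above.
        padded.append(np.pad(img, ((0, 0), (0, max_h - h), (0, max_w - w))))
        mask = np.zeros((max_h, max_w), dtype=np.int64)
        mask[:h, :w] = 1  # 1 = real pixel, 0 = padding
        masks.append(mask)
    return np.stack(padded), np.stack(masks)

pixel_values, pixel_mask = pad_batch([np.ones((3, 100, 120)), np.ones((3, 90, 140))])
print(pixel_values.shape, pixel_mask.shape)  # (2, 3, 100, 140) (2, 100, 140)
```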
@ -31,6 +31,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -124,6 +125,7 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -139,12 +141,22 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input
|
||||
image.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=(size["height"], size["width"]), default_to_square=False
|
||||
image, size=(size["height"], size["width"]), default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -162,6 +174,7 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -204,9 +217,15 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: defaults to the channel dimension format of the input image.
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -249,19 +268,36 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -31,6 +31,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -124,6 +125,7 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -139,12 +141,23 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size:
|
||||
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
|
||||
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -162,6 +175,7 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -204,9 +218,15 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: defaults to the channel dimension format of the input image.
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -249,19 +269,36 @@ class CLIPImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -124,7 +124,10 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, in
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray, size: Union[int, Tuple[int, int], List[int]], max_size: Optional[int] = None
|
||||
input_image: np.ndarray,
|
||||
size: Union[int, Tuple[int, int], List[int]],
|
||||
max_size: Optional[int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Computes the output image size given the input image size and the desired output size. If the desired output size
|
||||
@ -138,8 +141,10 @@ def get_resize_output_image_size(
|
||||
The desired output size.
|
||||
max_size (`int`, *optional*):
|
||||
The maximum allowed output size.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||
"""
|
||||
image_size = get_image_size(input_image)
|
||||
image_size = get_image_size(input_image, input_data_format)
|
||||
if isinstance(size, (list, tuple)):
|
||||
return size
|
||||
|
||||
@ -209,23 +214,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
||||
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
|
||||
def get_max_height_width(
|
||||
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all images in a batch.
|
||||
"""
|
||||
input_channel_dimension = infer_channel_dimension_format(images[0])
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if input_channel_dimension == ChannelDimension.FIRST:
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
||||
elif input_channel_dimension == ChannelDimension.LAST:
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
||||
else:
|
||||
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
|
||||
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
||||
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
def make_pixel_mask(
|
||||
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
|
||||
@ -235,7 +245,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
|
||||
output_size (`Tuple[int, int]`):
|
||||
Output size of the mask.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
mask = np.zeros(output_size, dtype=np.int64)
|
||||
mask[:input_height, :input_width] = 1
|
||||
return mask
|
||||
@ -277,11 +287,16 @@ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndar
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->ConditionalDetr
|
||||
def prepare_coco_detection_annotation(image, target, return_segmentation_masks: bool = False):
|
||||
def prepare_coco_detection_annotation(
|
||||
image,
|
||||
target,
|
||||
return_segmentation_masks: bool = False,
|
||||
input_data_format: Optional[Union[ChannelDimension, str]] = None,
|
||||
):
|
||||
"""
|
||||
Convert the target in COCO format into the format expected by ConditionalDetr.
|
||||
"""
|
||||
image_height, image_width = get_image_size(image)
|
||||
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
||||
|
||||
image_id = target["image_id"]
|
||||
image_id = np.asarray([image_id], dtype=np.int64)
|
||||
@ -366,12 +381,16 @@ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->ConditionalDetr
|
||||
def prepare_coco_panoptic_annotation(
|
||||
image: np.ndarray, target: Dict, masks_path: Union[str, pathlib.Path], return_masks: bool = True
|
||||
image: np.ndarray,
|
||||
target: Dict,
|
||||
masks_path: Union[str, pathlib.Path],
|
||||
return_masks: bool = True,
|
||||
input_data_format: Union[ChannelDimension, str] = None,
) -> Dict:
"""
Prepare a coco panoptic annotation for ConditionalDetr.
"""
image_height, image_width = get_image_size(image)
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
annotation_path = pathlib.Path(masks_path) / target["file_name"]

new_target = {}
@ -842,6 +861,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
format: Optional[AnnotionFormat] = None,
return_segmentation_masks: bool = None,
masks_path: Optional[Union[str, pathlib.Path]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Dict:
"""
Prepare an annotation for feeding into ConditionalDetr model.
@ -850,11 +870,17 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):

if format == AnnotionFormat.COCO_DETECTION:
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
target = prepare_coco_detection_annotation(image, target, return_segmentation_masks)
target = prepare_coco_detection_annotation(
image, target, return_segmentation_masks, input_data_format=input_data_format
)
elif format == AnnotionFormat.COCO_PANOPTIC:
return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
target = prepare_coco_panoptic_annotation(
image, target, masks_path=masks_path, return_masks=return_segmentation_masks
image,
target,
masks_path=masks_path,
return_masks=return_segmentation_masks,
input_data_format=input_data_format,
)
else:
raise ValueError(f"Format {format} is not supported.")
@ -892,11 +918,26 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
int, smaller edge of the image will be matched to this number.

Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary containing the size to resize to. Can contain the keys `shortest_edge` and `longest_edge` or
`height` and `width`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if "max_size" in kwargs:
logger.warning_once(
@ -908,7 +949,9 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
max_size = None
size = get_size_dict(size, max_size=max_size, default_to_square=False)
if "shortest_edge" in size and "longest_edge" in size:
size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
size = get_resize_output_image_size(
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
)
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
@ -916,7 +959,9 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
f" {size.keys()}."
)
image = resize(image, size=size, resample=resample, data_format=data_format)
image = resize(
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
)
return image

# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
@ -935,7 +980,11 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):

# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
def rescale(
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
self,
image: np.ndarray,
rescale_factor: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Rescale the image by the given factor. image = image * rescale_factor.
@ -950,8 +999,13 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format)
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)

# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
@ -968,18 +1022,24 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size

pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image

@ -991,6 +1051,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
@ -1012,17 +1073,28 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images)
pad_size = get_max_height_width(images, input_data_format=input_data_format)

padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
data = {"pixel_values": padded_images}

if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks

return BatchFeature(data=data, tensor_type=return_tensors)
@ -1046,6 +1118,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
format: Optional[Union[str, AnnotionFormat]] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
@ -1091,8 +1164,17 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
Format of the annotations.
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
Type of tensors to return. If `None`, will return the list of images.
data_format (`str` or `ChannelDimension`, *optional*, defaults to self.data_format):
The channel dimension format of the image. If not provided, it will be the same as the input image.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
if "pad_and_return_pixel_mask" in kwargs:
logger.warning_once(
@ -1177,13 +1259,22 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
if annotations is not None:
prepared_images = []
prepared_annotations = []
for image, target in zip(images, annotations):
target = self.prepare_annotation(
image, target, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path
image,
target,
format,
return_segmentation_masks=return_segmentation_masks,
masks_path=masks_path,
input_data_format=input_data_format,
)
prepared_images.append(image)
prepared_annotations.append(target)
@ -1196,33 +1287,49 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
if annotations is not None:
resized_images, resized_annotations = [], []
for image, target in zip(images, annotations):
orig_size = get_image_size(image)
resized_image = self.resize(image, size=size, max_size=max_size, resample=resample)
resized_annotation = self.resize_annotation(target, orig_size, get_image_size(resized_image))
orig_size = get_image_size(image, input_data_format)
resized_image = self.resize(
image, size=size, max_size=max_size, resample=resample, input_data_format=input_data_format
)
resized_annotation = self.resize_annotation(
target, orig_size, get_image_size(resized_image, input_data_format)
)
resized_images.append(resized_image)
resized_annotations.append(resized_annotation)
images = resized_images
annotations = resized_annotations
del resized_images, resized_annotations
else:
images = [self.resize(image, size=size, resample=resample) for image in images]
images = [
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_rescale:
images = [self.rescale(image, rescale_factor) for image in images]
images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

if do_normalize:
images = [self.normalize(image, image_mean, image_std) for image in images]
images = [
self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
]
if annotations is not None:
annotations = [
self.normalize_annotation(annotation, get_image_size(image))
self.normalize_annotation(
annotation, get_image_size(image, input_data_format), input_data_format=input_data_format
)
for annotation, image in zip(annotations, images)
]

if do_pad:
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
data = self.pad(images, return_pixel_mask=True, data_format=data_format)
data = self.pad(
images, return_pixel_mask=True, data_format=data_format, input_data_format=input_data_format
)
else:
images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]
data = {"pixel_values": images}

encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
@ -31,6 +31,7 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -118,6 +119,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
crop_pct: float,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -137,6 +139,9 @@ class ConvNextImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input
image.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
@ -146,14 +151,34 @@ class ConvNextImageProcessor(BaseImageProcessor):
if shortest_edge < 384:
# maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
resize_shortest_edge = int(shortest_edge / crop_pct)
resize_size = get_resize_output_image_size(image, size=resize_shortest_edge, default_to_square=False)
image = resize(image=image, size=resize_size, resample=resample, data_format=data_format, **kwargs)
resize_size = get_resize_output_image_size(
image, size=resize_shortest_edge, default_to_square=False, input_data_format=input_data_format
)
image = resize(
image=image,
size=resize_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
# then crop to (shortest_edge, shortest_edge)
return center_crop(image=image, size=(shortest_edge, shortest_edge), data_format=data_format, **kwargs)
return center_crop(
image=image,
size=(shortest_edge, shortest_edge),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
else:
# warping (no cropping) when evaluated at 384 or larger
return resize(
image, size=(shortest_edge, shortest_edge), resample=resample, data_format=data_format, **kwargs
image,
size=(shortest_edge, shortest_edge),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
@ -170,6 +195,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
@ -209,8 +235,15 @@ class ConvNextImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
crop_pct = crop_pct if crop_pct is not None else self.crop_pct
@ -247,16 +280,33 @@ class ConvNextImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]
images = [
self.resize(
image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
)
for image in images
]

if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
@ -123,7 +123,10 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, in

# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
def get_resize_output_image_size(
input_image: np.ndarray, size: Union[int, Tuple[int, int], List[int]], max_size: Optional[int] = None
input_image: np.ndarray,
size: Union[int, Tuple[int, int], List[int]],
max_size: Optional[int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
"""
Computes the output image size given the input image size and the desired output size. If the desired output size
@ -137,8 +140,10 @@ def get_resize_output_image_size(
The desired output size.
max_size (`int`, *optional*):
The maximum allowed output size.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
"""
image_size = get_image_size(input_image)
image_size = get_image_size(input_image, input_data_format)
if isinstance(size, (list, tuple)):
return size

@ -208,23 +213,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:


# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
input_channel_dimension = infer_channel_dimension_format(images[0])
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])

if input_channel_dimension == ChannelDimension.FIRST:
if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST:
elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width)


# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@ -234,7 +244,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
@ -276,11 +286,16 @@ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndar


# Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->DeformableDetr
def prepare_coco_detection_annotation(image, target, return_segmentation_masks: bool = False):
def prepare_coco_detection_annotation(
image,
target,
return_segmentation_masks: bool = False,
input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
"""
Convert the target in COCO format into the format expected by DeformableDetr.
"""
image_height, image_width = get_image_size(image)
image_height, image_width = get_image_size(image, channel_dim=input_data_format)

image_id = target["image_id"]
image_id = np.asarray([image_id], dtype=np.int64)
@ -365,12 +380,16 @@ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:

# Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->DeformableDetr
def prepare_coco_panoptic_annotation(
image: np.ndarray, target: Dict, masks_path: Union[str, pathlib.Path], return_masks: bool = True
image: np.ndarray,
target: Dict,
masks_path: Union[str, pathlib.Path],
return_masks: bool = True,
input_data_format: Union[ChannelDimension, str] = None,
) -> Dict:
"""
Prepare a coco panoptic annotation for DeformableDetr.
"""
image_height, image_width = get_image_size(image)
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
annotation_path = pathlib.Path(masks_path) / target["file_name"]

new_target = {}
@ -840,6 +859,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
format: Optional[AnnotionFormat] = None,
return_segmentation_masks: bool = None,
masks_path: Optional[Union[str, pathlib.Path]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Dict:
"""
Prepare an annotation for feeding into DeformableDetr model.
@ -848,11 +868,17 @@ class DeformableDetrImageProcessor(BaseImageProcessor):

if format == AnnotionFormat.COCO_DETECTION:
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
target = prepare_coco_detection_annotation(image, target, return_segmentation_masks)
target = prepare_coco_detection_annotation(
image, target, return_segmentation_masks, input_data_format=input_data_format
)
elif format == AnnotionFormat.COCO_PANOPTIC:
return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
target = prepare_coco_panoptic_annotation(
image, target, masks_path=masks_path, return_masks=return_segmentation_masks
image,
target,
masks_path=masks_path,
return_masks=return_segmentation_masks,
input_data_format=input_data_format,
)
else:
raise ValueError(f"Format {format} is not supported.")
@ -890,11 +916,26 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
int, smaller edge of the image will be matched to this number.

Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary containing the size to resize to. Can contain the keys `shortest_edge` and `longest_edge` or
`height` and `width`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if "max_size" in kwargs:
logger.warning_once(
@ -906,7 +947,9 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
max_size = None
size = get_size_dict(size, max_size=max_size, default_to_square=False)
if "shortest_edge" in size and "longest_edge" in size:
size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
size = get_resize_output_image_size(
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
)
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
@ -914,7 +957,9 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
f" {size.keys()}."
)
image = resize(image, size=size, resample=resample, data_format=data_format)
image = resize(
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
)
return image

# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
@ -933,7 +978,11 @@ class DeformableDetrImageProcessor(BaseImageProcessor):

# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
def rescale(
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
self,
image: np.ndarray,
rescale_factor: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Rescale the image by the given factor. image = image * rescale_factor.
@ -948,8 +997,13 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format)
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)

# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
@ -966,18 +1020,24 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size

pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image

@ -989,6 +1049,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
@ -1010,17 +1071,28 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images)
pad_size = get_max_height_width(images, input_data_format=input_data_format)

padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
data = {"pixel_values": padded_images}

if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks

return BatchFeature(data=data, tensor_type=return_tensors)
@ -1044,6 +1116,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
format: Optional[Union[str, AnnotionFormat]] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
@ -1089,8 +1162,17 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
Format of the annotations.
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
Type of tensors to return. If `None`, will return the list of images.
data_format (`str` or `ChannelDimension`, *optional*, defaults to self.data_format):
The channel dimension format of the image. If not provided, it will be the same as the input image.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
if "pad_and_return_pixel_mask" in kwargs:
logger.warning_once(
@ -1175,13 +1257,22 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
if annotations is not None:
prepared_images = []
prepared_annotations = []
for image, target in zip(images, annotations):
target = self.prepare_annotation(
image, target, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path
image,
target,
format,
return_segmentation_masks=return_segmentation_masks,
masks_path=masks_path,
input_data_format=input_data_format,
)
prepared_images.append(image)
prepared_annotations.append(target)
@ -1194,33 +1285,49 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
if annotations is not None:
resized_images, resized_annotations = [], []
for image, target in zip(images, annotations):
orig_size = get_image_size(image)
resized_image = self.resize(image, size=size, max_size=max_size, resample=resample)
resized_annotation = self.resize_annotation(target, orig_size, get_image_size(resized_image))
orig_size = get_image_size(image, input_data_format)
resized_image = self.resize(
image, size=size, max_size=max_size, resample=resample, input_data_format=input_data_format
)
resized_annotation = self.resize_annotation(
target, orig_size, get_image_size(resized_image, input_data_format)
)
resized_images.append(resized_image)
resized_annotations.append(resized_annotation)
images = resized_images
annotations = resized_annotations
del resized_images, resized_annotations
else:
images = [self.resize(image, size=size, resample=resample) for image in images]
images = [
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_rescale:
images = [self.rescale(image, rescale_factor) for image in images]
images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

if do_normalize:
images = [self.normalize(image, image_mean, image_std) for image in images]
images = [
self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
]
if annotations is not None:
annotations = [
self.normalize_annotation(annotation, get_image_size(image))
self.normalize_annotation(
annotation, get_image_size(image, input_data_format), input_data_format=input_data_format
)
for annotation, image in zip(annotations, images)
]

if do_pad:
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
data = self.pad(images, return_pixel_mask=True, data_format=data_format)
data = self.pad(
images, return_pixel_mask=True, data_format=data_format, input_data_format=input_data_format
)
else:
images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]
data = {"pixel_values": images}

encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
@ -26,6 +26,7 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -114,6 +115,7 @@ class DeiTImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -131,6 +133,13 @@ class DeiTImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

Returns:
`np.ndarray`: The resized image.
@ -139,7 +148,14 @@ class DeiTImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
self,
@ -156,6 +172,7 @@ class DeiTImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
@ -197,6 +214,12 @@ class DeiTImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
@ -235,19 +258,36 @@ class DeiTImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]

if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
@ -115,7 +115,10 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, in
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray, size: Union[int, Tuple[int, int], List[int]], max_size: Optional[int] = None
|
||||
input_image: np.ndarray,
|
||||
size: Union[int, Tuple[int, int], List[int]],
|
||||
max_size: Optional[int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Computes the output image size given the input image size and the desired output size. If the desired output size
|
||||
@ -129,8 +132,10 @@ def get_resize_output_image_size(
|
||||
The desired output size.
|
||||
max_size (`int`, *optional*):
|
||||
The maximum allowed output size.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||
"""
|
||||
image_size = get_image_size(input_image)
|
||||
image_size = get_image_size(input_image, input_data_format)
|
||||
if isinstance(size, (list, tuple)):
|
||||
return size
|
||||
|
||||
@ -200,23 +205,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
||||
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
|
||||
def get_max_height_width(
|
||||
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all images in a batch.
|
||||
"""
|
||||
input_channel_dimension = infer_channel_dimension_format(images[0])
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if input_channel_dimension == ChannelDimension.FIRST:
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
||||
elif input_channel_dimension == ChannelDimension.LAST:
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
||||
else:
|
||||
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
|
||||
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
||||
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
def make_pixel_mask(
|
||||
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
|
||||
@ -226,7 +236,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
|
||||
output_size (`Tuple[int, int]`):
|
||||
Output size of the mask.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
mask = np.zeros(output_size, dtype=np.int64)
|
||||
mask[:input_height, :input_width] = 1
|
||||
return mask
|
||||
@ -268,11 +278,16 @@ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndar
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation with DETR->DETA
|
||||
def prepare_coco_detection_annotation(image, target, return_segmentation_masks: bool = False):
|
||||
def prepare_coco_detection_annotation(
|
||||
image,
|
||||
target,
|
||||
return_segmentation_masks: bool = False,
|
||||
input_data_format: Optional[Union[ChannelDimension, str]] = None,
|
||||
):
|
||||
"""
|
||||
Convert the target in COCO format into the format expected by DETA.
|
||||
"""
|
||||
image_height, image_width = get_image_size(image)
|
||||
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
||||
|
||||
image_id = target["image_id"]
|
||||
image_id = np.asarray([image_id], dtype=np.int64)
|
||||
@ -357,12 +372,16 @@ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->DETA
|
||||
def prepare_coco_panoptic_annotation(
|
||||
image: np.ndarray, target: Dict, masks_path: Union[str, pathlib.Path], return_masks: bool = True
|
||||
image: np.ndarray,
|
||||
target: Dict,
|
||||
masks_path: Union[str, pathlib.Path],
|
||||
return_masks: bool = True,
|
||||
input_data_format: Union[ChannelDimension, str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Prepare a coco panoptic annotation for DETA.
|
||||
"""
|
||||
image_height, image_width = get_image_size(image)
|
||||
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
||||
annotation_path = pathlib.Path(masks_path) / target["file_name"]
|
||||
|
||||
new_target = {}
|
||||
@ -522,6 +541,7 @@ class DetaImageProcessor(BaseImageProcessor):
|
||||
format: Optional[AnnotionFormat] = None,
|
||||
return_segmentation_masks: bool = None,
|
||||
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Prepare an annotation for feeding into DETA model.
|
||||
@ -530,11 +550,17 @@ class DetaImageProcessor(BaseImageProcessor):
|
||||
|
||||
if format == AnnotionFormat.COCO_DETECTION:
|
||||
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
|
||||
target = prepare_coco_detection_annotation(image, target, return_segmentation_masks)
|
||||
target = prepare_coco_detection_annotation(
|
||||
image, target, return_segmentation_masks, input_data_format=input_data_format
|
||||
)
|
||||
elif format == AnnotionFormat.COCO_PANOPTIC:
|
||||
return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
|
||||
target = prepare_coco_panoptic_annotation(
|
||||
image, target, masks_path=masks_path, return_masks=return_segmentation_masks
|
||||
image,
|
||||
target,
|
||||
masks_path=masks_path,
|
||||
return_masks=return_segmentation_masks,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Format {format} is not supported.")
|
||||
@ -571,15 +597,32 @@ class DetaImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
|
||||
int, smaller edge of the image will be matched to this number.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
The desired output size. Can contain keys `shortest_edge` and `longest_edge` or `height` and `width`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image.
|
||||
data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input
|
||||
image.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" in size and "longest_edge" in size:
|
||||
size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
|
||||
size = get_resize_output_image_size(
|
||||
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
|
||||
)
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
@ -587,7 +630,9 @@ class DetaImageProcessor(BaseImageProcessor):
|
||||
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
|
||||
f" {size.keys()}."
|
||||
)
|
||||
image = resize(image, size=size, resample=resample, data_format=data_format)
|
||||
image = resize(
|
||||
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format
|
||||
)
|
||||
return image
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
|
||||
@ -606,7 +651,11 @@ class DetaImageProcessor(BaseImageProcessor):
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
|
||||
def rescale(
|
||||
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
rescale_factor: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale the image by the given factor. image = image * rescale_factor.
|
||||
@ -621,8 +670,13 @@ class DetaImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
|
||||
one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
"""
|
||||
return rescale(image, rescale_factor, data_format=data_format)
|
||||
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
|
||||
def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
@ -639,18 +693,24 @@ class DetaImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size

pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image

@ -662,6 +722,7 @@ class DetaImageProcessor(BaseImageProcessor):
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
@ -683,17 +744,28 @@ class DetaImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images)
pad_size = get_max_height_width(images, input_data_format=input_data_format)

padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
data = {"pixel_values": padded_images}

if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks

return BatchFeature(data=data, tensor_type=return_tensors)
@ -716,6 +788,7 @@ class DetaImageProcessor(BaseImageProcessor):
format: Optional[Union[str, AnnotionFormat]] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
@ -761,8 +834,17 @@ class DetaImageProcessor(BaseImageProcessor):
Format of the annotations.
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
Type of tensors to return. If `None`, will return the list of images.
data_format (`str` or `ChannelDimension`, *optional*, defaults to self.data_format):
The channel dimension format of the image. If not provided, it will be the same as the input image.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
if "pad_and_return_pixel_mask" in kwargs:
logger.warning_once(
@ -839,13 +921,22 @@ class DetaImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
if annotations is not None:
prepared_images = []
prepared_annotations = []
for image, target in zip(images, annotations):
target = self.prepare_annotation(
image, target, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path
image,
target,
format,
return_segmentation_masks=return_segmentation_masks,
masks_path=masks_path,
input_data_format=input_data_format,
)
prepared_images.append(image)
prepared_annotations.append(target)
@ -858,33 +949,47 @@ class DetaImageProcessor(BaseImageProcessor):
if annotations is not None:
resized_images, resized_annotations = [], []
for image, target in zip(images, annotations):
orig_size = get_image_size(image)
resized_image = self.resize(image, size=size, resample=resample)
resized_annotation = self.resize_annotation(target, orig_size, get_image_size(resized_image))
orig_size = get_image_size(image, input_data_format)
resized_image = self.resize(
image, size=size, resample=resample, input_data_format=input_data_format
)
resized_annotation = self.resize_annotation(
target, orig_size, get_image_size(resized_image, input_data_format)
)
resized_images.append(resized_image)
resized_annotations.append(resized_annotation)
images = resized_images
annotations = resized_annotations
del resized_images, resized_annotations
else:
images = [self.resize(image, size=size, resample=resample) for image in images]
images = [
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_rescale:
images = [self.rescale(image, rescale_factor) for image in images]
images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

if do_normalize:
images = [self.normalize(image, image_mean, image_std) for image in images]
images = [
self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
]
if annotations is not None:
annotations = [
self.normalize_annotation(annotation, get_image_size(image))
self.normalize_annotation(annotation, get_image_size(image, input_data_format))
for annotation, image in zip(annotations, images)
]

if do_pad:
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
data = self.pad(images, return_pixel_mask=True, data_format=data_format)
data = self.pad(
images, return_pixel_mask=True, data_format=data_format, input_data_format=input_data_format
)
else:
images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]
data = {"pixel_values": images}

encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
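A minimal usage sketch of the argument threaded through the hunks above, assuming a default-constructed DetaImageProcessor and dummy channels-last arrays (the constructor call and shapes are illustrative, not part of the commit); every processor touched by this change accepts the same keyword, so per-image format inference can be skipped when the caller already knows the layout.

import numpy as np
from transformers import DetaImageProcessor
from transformers.image_utils import ChannelDimension

processor = DetaImageProcessor()
images = [np.zeros((480, 640, 3), dtype=np.uint8)]  # explicitly channels-last

# Inference would also work for this shape, but passing the format skips the guess and
# is needed for ambiguous inputs (tiny crops, single-channel arrays, etc.).
inputs = processor(images=images, return_tensors="np", input_data_format=ChannelDimension.LAST)
print(inputs["pixel_values"].shape)  # (1, 3, resized_height, resized_width)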
@ -121,7 +121,10 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, in


def get_resize_output_image_size(
input_image: np.ndarray, size: Union[int, Tuple[int, int], List[int]], max_size: Optional[int] = None
input_image: np.ndarray,
size: Union[int, Tuple[int, int], List[int]],
max_size: Optional[int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
"""
Computes the output image size given the input image size and the desired output size. If the desired output size
@ -135,8 +138,10 @@ def get_resize_output_image_size(
The desired output size.
max_size (`int`, *optional*):
The maximum allowed output size.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
"""
image_size = get_image_size(input_image)
image_size = get_image_size(input_image, input_data_format)
if isinstance(size, (list, tuple)):
return size

@ -203,23 +208,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:


# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
input_channel_dimension = infer_channel_dimension_format(images[0])
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])

if input_channel_dimension == ChannelDimension.FIRST:
if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST:
elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width)


# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@ -229,7 +239,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
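A standalone sketch of what the two helpers above compute, in plain numpy and for channels-first arrays only (the shapes are made up for illustration): the batch's padded size is the per-axis maximum, and the pixel mask marks where the original image sits inside that padded canvas.

import numpy as np

images = [np.zeros((3, 480, 640)), np.zeros((3, 600, 400))]
max_height = max(img.shape[1] for img in images)   # 600
max_width = max(img.shape[2] for img in images)    # 640

mask = np.zeros((max_height, max_width), dtype=np.int64)
mask[:480, :640] = 1  # valid pixels of the first image; everything else is padding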
@ -271,11 +281,16 @@ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndar


# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
def prepare_coco_detection_annotation(image, target, return_segmentation_masks: bool = False):
def prepare_coco_detection_annotation(
image,
target,
return_segmentation_masks: bool = False,
input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
"""
Convert the target in COCO format into the format expected by DETR.
"""
image_height, image_width = get_image_size(image)
image_height, image_width = get_image_size(image, channel_dim=input_data_format)

image_id = target["image_id"]
image_id = np.asarray([image_id], dtype=np.int64)
@ -358,12 +373,16 @@ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:


def prepare_coco_panoptic_annotation(
image: np.ndarray, target: Dict, masks_path: Union[str, pathlib.Path], return_masks: bool = True
image: np.ndarray,
target: Dict,
masks_path: Union[str, pathlib.Path],
return_masks: bool = True,
input_data_format: Union[ChannelDimension, str] = None,
) -> Dict:
"""
Prepare a coco panoptic annotation for DETR.
"""
image_height, image_width = get_image_size(image)
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
annotation_path = pathlib.Path(masks_path) / target["file_name"]

new_target = {}
@ -822,6 +841,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
format: Optional[AnnotionFormat] = None,
|
||||
return_segmentation_masks: bool = None,
|
||||
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Prepare an annotation for feeding into DETR model.
|
||||
@ -830,11 +850,17 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
|
||||
if format == AnnotionFormat.COCO_DETECTION:
|
||||
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
|
||||
target = prepare_coco_detection_annotation(image, target, return_segmentation_masks)
|
||||
target = prepare_coco_detection_annotation(
|
||||
image, target, return_segmentation_masks, input_data_format=input_data_format
|
||||
)
|
||||
elif format == AnnotionFormat.COCO_PANOPTIC:
|
||||
return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
|
||||
target = prepare_coco_panoptic_annotation(
|
||||
image, target, masks_path=masks_path, return_masks=return_segmentation_masks
|
||||
image,
|
||||
target,
|
||||
masks_path=masks_path,
|
||||
return_masks=return_segmentation_masks,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Format {format} is not supported.")
|
||||
@ -867,11 +893,26 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
|
||||
int, smaller edge of the image will be matched to this number.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary containing the size to resize to. Can contain the keys `shortest_edge` and `longest_edge` or
|
||||
`height` and `width`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
if "max_size" in kwargs:
|
||||
logger.warning_once(
|
||||
@ -883,7 +924,9 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
max_size = None
|
||||
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
||||
if "shortest_edge" in size and "longest_edge" in size:
|
||||
size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
|
||||
size = get_resize_output_image_size(
|
||||
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
|
||||
)
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
@ -891,7 +934,9 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
|
||||
f" {size.keys()}."
|
||||
)
|
||||
image = resize(image, size=size, resample=resample, data_format=data_format)
|
||||
image = resize(
|
||||
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
return image
|
||||
|
||||
def resize_annotation(
|
||||
@ -909,7 +954,11 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
|
||||
# TODO (Amy) - update to use `rescale_factor` instead of `scale`
|
||||
def rescale(
|
||||
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
rescale_factor: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale the image by the given factor. image = image * rescale_factor.
|
||||
@ -924,8 +973,13 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
|
||||
one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
"""
|
||||
return rescale(image, rescale_factor, data_format=data_format)
|
||||
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
|
||||
|
||||
def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
|
||||
"""
|
||||
@ -940,18 +994,24 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
output_size: Tuple[int, int],
|
||||
constant_values: Union[float, Iterable[float]] = 0,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad an image with zeros to the given size.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = output_size
|
||||
|
||||
pad_bottom = output_height - input_height
|
||||
pad_right = output_width - input_width
|
||||
padding = ((0, pad_bottom), (0, pad_right))
|
||||
padded_image = pad(
|
||||
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
|
||||
image,
|
||||
padding,
|
||||
mode=PaddingMode.CONSTANT,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return padded_image
|
||||
|
||||
@ -962,6 +1022,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
return_pixel_mask: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
|
||||
@ -983,17 +1044,28 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
pad_size = get_max_height_width(images)
|
||||
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
||||
|
||||
padded_images = [
|
||||
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
|
||||
self._pad_image(
|
||||
image,
|
||||
pad_size,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": padded_images}
|
||||
|
||||
if return_pixel_mask:
|
||||
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
|
||||
masks = [
|
||||
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data["pixel_mask"] = masks
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
@ -1016,6 +1088,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
format: Optional[Union[str, AnnotionFormat]] = None,
|
||||
return_tensors: Optional[Union[TensorType, str]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -1061,8 +1134,17 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
Format of the annotations.
|
||||
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
|
||||
Type of tensors to return. If `None`, will return the list of images.
|
||||
data_format (`str` or `ChannelDimension`, *optional*, defaults to self.data_format):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
if "pad_and_return_pixel_mask" in kwargs:
|
||||
logger.warning_once(
|
||||
@ -1147,13 +1229,22 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
|
||||
if annotations is not None:
|
||||
prepared_images = []
|
||||
prepared_annotations = []
|
||||
for image, target in zip(images, annotations):
|
||||
target = self.prepare_annotation(
|
||||
image, target, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path
|
||||
image,
|
||||
target,
|
||||
format,
|
||||
return_segmentation_masks=return_segmentation_masks,
|
||||
masks_path=masks_path,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
prepared_images.append(image)
|
||||
prepared_annotations.append(target)
|
||||
@ -1166,33 +1257,49 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
if annotations is not None:
|
||||
resized_images, resized_annotations = [], []
|
||||
for image, target in zip(images, annotations):
|
||||
orig_size = get_image_size(image)
|
||||
resized_image = self.resize(image, size=size, max_size=max_size, resample=resample)
|
||||
resized_annotation = self.resize_annotation(target, orig_size, get_image_size(resized_image))
|
||||
orig_size = get_image_size(image, input_data_format)
|
||||
resized_image = self.resize(
|
||||
image, size=size, max_size=max_size, resample=resample, input_data_format=input_data_format
|
||||
)
|
||||
resized_annotation = self.resize_annotation(
|
||||
target, orig_size, get_image_size(resized_image, input_data_format)
|
||||
)
|
||||
resized_images.append(resized_image)
|
||||
resized_annotations.append(resized_annotation)
|
||||
images = resized_images
|
||||
annotations = resized_annotations
|
||||
del resized_images, resized_annotations
|
||||
else:
|
||||
images = [self.resize(image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image, rescale_factor) for image in images]
|
||||
images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image, image_mean, image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
if annotations is not None:
|
||||
annotations = [
|
||||
self.normalize_annotation(annotation, get_image_size(image))
|
||||
self.normalize_annotation(
|
||||
annotation, get_image_size(image, input_data_format), input_data_format=input_data_format
|
||||
)
|
||||
for annotation, image in zip(annotations, images)
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
# Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
|
||||
data = self.pad(images, return_pixel_mask=True, data_format=data_format)
|
||||
data = self.pad(
|
||||
images, return_pixel_mask=True, data_format=data_format, input_data_format=input_data_format
|
||||
)
|
||||
else:
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": images}
|
||||
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -32,6 +32,7 @@ from ...image_utils import (
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -122,7 +123,11 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
|
||||
def align_long_axis(
|
||||
self, image: np.ndarray, size: Dict[str, int], data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Align the long axis of the image to the longest axis of the specified size.
|
||||
@ -132,11 +137,15 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
The image to be aligned.
|
||||
size (`Dict[str, int]`):
|
||||
The size `{"height": h, "width": w}` to align the long axis to.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The data format of the output image. If unset, the same format as the input image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The aligned image.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = size["height"], size["width"]
|
||||
|
||||
if (output_width < output_height and input_width > input_height) or (
|
||||
@ -145,7 +154,7 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
image = np.rot90(image, 3)
|
||||
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
|
||||
return image
|
||||
|
||||
@ -155,6 +164,7 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
random_padding: bool = False,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad the image to the specified size.
|
||||
@ -168,9 +178,11 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
Whether to use random padding or not.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The data format of the output image. If unset, the same format as the input image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
output_height, output_width = size["height"], size["width"]
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
|
||||
delta_width = output_width - input_width
|
||||
delta_height = output_height - input_height
|
||||
@ -186,7 +198,7 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
pad_right = delta_width - pad_left
|
||||
|
||||
padding = ((pad_top, pad_bottom), (pad_left, pad_right))
|
||||
return pad(image, padding, data_format=data_format)
|
||||
return pad(image, padding, data_format=data_format, input_data_format=input_data_format)
|
||||
|
||||
def pad(self, *args, **kwargs):
|
||||
logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")
|
||||
@ -198,6 +210,7 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -213,8 +226,10 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
The resampling filter to use.
|
||||
data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
|
||||
The data format of the output image. If unset, the same format as the input image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = size["height"], size["width"]
|
||||
|
||||
# We always resize to the smallest of either the input or output size.
|
||||
@ -230,7 +245,13 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
height = int(input_height * width / input_width)
|
||||
|
||||
return resize(
|
||||
image, size=(height, width), resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs
|
||||
image,
|
||||
size=(height, width),
|
||||
resample=resample,
|
||||
reducing_gap=2.0,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def resize(
|
||||
@ -239,6 +260,7 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -254,11 +276,22 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
shortest_edge = min(size["height"], size["width"])
|
||||
output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
|
||||
resized_image = resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
resized_image = resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
return resized_image
|
||||
|
||||
def preprocess(
|
||||
@ -278,6 +311,7 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -327,6 +361,12 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: defaults to the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -367,25 +407,45 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_align_long_axis:
|
||||
images = [self.align_long_axis(image, size=size) for image in images]
|
||||
images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images]
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_thumbnail:
|
||||
images = [self.thumbnail(image=image, size=size) for image in images]
|
||||
images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images]
|
||||
|
||||
if do_pad:
|
||||
images = [self.pad_image(image=image, size=size, random_padding=random_padding) for image in images]
|
||||
images = [
|
||||
self.pad_image(
|
||||
image=image, size=size, random_padding=random_padding, input_data_format=input_data_format
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
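Why an explicit format can matter, as a numpy-only sketch (the inference heuristic itself lives in transformers.image_utils and is not shown in this diff): for some shapes the layout cannot be recovered from the array alone, which is exactly the case the new keyword covers.

import numpy as np

tiny = np.zeros((3, 3, 3))          # channels-first or channels-last? shape alone cannot tell
batch = np.zeros((2, 224, 224, 3))  # unambiguous channels-last

# With this change, callers resolve the ambiguity once per batch, e.g.:
#     processor(images=[tiny], input_data_format="channels_last")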
@ -28,6 +28,7 @@ from ...image_utils import (
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_torch_available,
|
||||
is_torch_tensor,
|
||||
make_list_of_images,
|
||||
@ -48,7 +49,11 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray, output_size: Union[int, Iterable[int]], keep_aspect_ratio: bool, multiple: int
|
||||
input_image: np.ndarray,
|
||||
output_size: Union[int, Iterable[int]],
|
||||
keep_aspect_ratio: bool,
|
||||
multiple: int,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
|
||||
x = round(val / multiple) * multiple
|
||||
@ -63,7 +68,7 @@ def get_resize_output_image_size(
|
||||
|
||||
output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
|
||||
|
||||
input_height, input_width = get_image_size(input_image)
|
||||
input_height, input_width = get_image_size(input_image, input_data_format)
|
||||
output_height, output_width = output_size
|
||||
|
||||
# determine new height and width
|
||||
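A worked sketch of the rounding step shown above (the min_val/max_val clamping branches fall outside this hunk, so only the snap-to-multiple arithmetic is reproduced): with ensure_multiple_of=32, a raw target of 518 pixels snaps to 512 and 530 snaps to 544.

def snap_to_multiple(val, multiple):
    # mirrors `x = round(val / multiple) * multiple` from constraint_to_multiple_of
    return round(val / multiple) * multiple

print(snap_to_multiple(518, 32))  # 512
print(snap_to_multiple(530, 32))  # 544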
@ -97,7 +102,7 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
|
||||
be overridden by `keep_aspect_ratio` in `preprocess`.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to `1`):
|
||||
ensure_multiple_of (`int`, *optional*, defaults to 1):
|
||||
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden
|
||||
by `ensure_multiple_of` in `preprocess`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
@ -156,6 +161,7 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
ensure_multiple_of: int = 1,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -170,7 +176,7 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
Target size of the output image.
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to `1`):
|
||||
ensure_multiple_of (`int`, *optional*, defaults to 1):
|
||||
The image is resized to a size that is a multiple of this value.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size
|
||||
@ -179,6 +185,8 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
@ -188,8 +196,16 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
output_size=(size["height"], size["width"]),
|
||||
keep_aspect_ratio=keep_aspect_ratio,
|
||||
multiple=ensure_multiple_of,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -206,6 +222,7 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -249,6 +266,12 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -282,16 +305,31 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -30,6 +30,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_batched,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -116,6 +117,7 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -133,6 +135,8 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -140,13 +144,17 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
|
||||
size = get_size_dict(size)
|
||||
|
||||
if "shortest_edge" in size:
|
||||
size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
|
||||
size = get_resize_output_image_size(
|
||||
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
# size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError(f"Size must contain 'height' and 'width' keys or 'shortest_edge' key. Got {size.keys()}")
|
||||
return resize(image, size=size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -163,6 +171,7 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -205,6 +214,12 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
@ -241,19 +256,36 @@ class EfficientFormerImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -26,6 +26,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -123,6 +124,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.NEAREST,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -140,6 +142,13 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -148,7 +157,14 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def rescale(
|
||||
self,
|
||||
@ -156,6 +172,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
scale: Union[int, float],
|
||||
offset: bool = True,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -177,8 +194,12 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
Whether to scale the image in both negative and positive directions.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
|
||||
rescaled_image = rescale(
|
||||
image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
|
||||
if offset:
|
||||
rescaled_image = rescaled_image - 1
|
||||
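The arithmetic of the offset branch above, kept independent of the library defaults (the 1/127.5 scale here is an illustrative choice, not the processor's configured rescale_factor): rescaling then subtracting 1 maps [0, 255] onto [-1, 1].

scale = 1 / 127.5
assert abs(0 * scale - 1 - (-1.0)) < 1e-9    # pixel 0   -> -1.0
assert abs(255 * scale - 1 - 1.0) < 1e-9     # pixel 255 -> +1.0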
@ -202,6 +223,7 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
include_top: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -247,6 +269,12 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
@ -287,22 +315,44 @@ class EfficientNetImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor, offset=rescale_offset) for image in images]
|
||||
images = [
|
||||
self.rescale(
|
||||
image=image, scale=rescale_factor, offset=rescale_offset, input_data_format=input_data_format
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if include_top:
|
||||
images = [self.normalize(image=image, mean=[0, 0, 0], std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=0, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -29,6 +29,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -338,6 +339,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -355,6 +357,13 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -363,7 +372,14 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def map_pixels(self, image: np.ndarray) -> np.ndarray:
|
||||
return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS
|
||||
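map_pixels (shown just above) squeezes [0, 1] pixels into the open interval (eps, 1 - eps) before the codebook step; the 0.1 used below is only the commonly cited value, the actual LOGIT_LAPLACE_EPS constant is defined elsewhere in the file.

eps = 0.1
for x in (0.0, 0.5, 1.0):
    print((1 - 2 * eps) * x + eps)  # 0.1, 0.5, 0.9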
@ -383,6 +399,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_map_pixels: bool = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[ChannelDimension] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
if do_resize and size is None or resample is None:
|
||||
@ -397,23 +414,27 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample)
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image=image, size=crop_size)
|
||||
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor)
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std)
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
|
||||
if do_map_pixels:
|
||||
image = self.map_pixels(image)
|
||||
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
@ -452,6 +473,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
codebook_image_std: Optional[Iterable[float]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -533,6 +555,12 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -615,6 +643,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
image_std=image_std,
|
||||
do_map_pixels=False,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
@ -636,6 +665,7 @@ class FlavaImageProcessor(BaseImageProcessor):
|
||||
image_std=codebook_image_std,
|
||||
do_map_pixels=codebook_do_map_pixels,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
|
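A usage sketch for the FLAVA changes above, assuming a processor built with its defaults (the random image and shapes are illustrative only):

```python
import numpy as np
from transformers import FlavaImageProcessor

processor = FlavaImageProcessor()
# Channels-last uint8 image; the explicit format skips per-image channel inference.
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
batch = processor(image, return_tensors="np", input_data_format="channels_last")
# With the default data_format, pixel_values come back channels-first.
print(batch["pixel_values"].shape)
```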
@ -25,6 +25,7 @@ from ...image_utils import (
ChannelDimension,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -75,6 +76,7 @@ class GLPNImageProcessor(BaseImageProcessor):
size_divisor: int,
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -95,15 +97,27 @@ class GLPNImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not set, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

Returns:
`np.ndarray`: The resized image.
"""
height, width = get_image_size(image)
height, width = get_image_size(image, channel_dim=input_data_format)
# Rounds the height and width down to the closest multiple of size_divisor
new_h = height // size_divisor * size_divisor
new_w = width // size_divisor * size_divisor
image = resize(image, (new_h, new_w), resample=resample, data_format=data_format, **kwargs)
image = resize(
image,
(new_h, new_w),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
return image

def preprocess(
@ -115,6 +129,7 @@ class GLPNImageProcessor(BaseImageProcessor):
do_rescale: Optional[bool] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
@ -144,6 +159,12 @@ class GLPNImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
@ -161,13 +182,22 @@ class GLPNImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(img) for img in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image, size_divisor=size_divisor, resample=resample) for image in images]
images = [
self.resize(image, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_rescale:
images = [self.rescale(image, scale=1 / 255) for image in images]
images = [self.rescale(image, scale=1 / 255, input_data_format=input_data_format) for image in images]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

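A sketch of how the GLPN resize above behaves end to end, assuming default preprocessing options and `size_divisor=32` (shapes are illustrative):

```python
import numpy as np
from transformers import GLPNImageProcessor

processor = GLPNImageProcessor(size_divisor=32)
image = np.random.randint(0, 256, (3, 481, 641), dtype=np.uint8)  # channels-first input
out = processor(image, return_tensors="np", input_data_format="channels_first")

# The resize shown in the diff rounds each side down to a multiple of size_divisor:
# 481 // 32 * 32 == 480 and 641 // 32 * 32 == 640.
print(out["pixel_values"].shape)
```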
@ -24,6 +24,7 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -107,6 +108,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -124,6 +126,13 @@ class ImageGPTImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

Returns:
`np.ndarray`: The resized image.
@ -132,12 +141,20 @@ class ImageGPTImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def normalize(
self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Normalizes an images' pixel values to between [-1, 1].
@ -147,8 +164,10 @@ class ImageGPTImageProcessor(BaseImageProcessor):
Image to normalize.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
image = rescale(image=image, scale=1 / 127.5, data_format=data_format)
image = rescale(image=image, scale=1 / 127.5, data_format=data_format, input_data_format=input_data_format)
image = image - 1
return image

@ -163,6 +182,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
@ -197,6 +217,12 @@ class ImageGPTImageProcessor(BaseImageProcessor):
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Only has an effect if `do_color_quantize` is set to `False`.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
@ -224,14 +250,21 @@ class ImageGPTImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [self.normalize(image=image) for image in images]
images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]

if do_color_quantize:
images = [to_channel_dimension_format(image, ChannelDimension.LAST) for image in images]
images = [to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) for image in images]
# color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
images = np.array(images)
images = color_quantize(images, clusters).reshape(images.shape[:-1])
@ -243,7 +276,10 @@ class ImageGPTImageProcessor(BaseImageProcessor):
# We need to convert back to a list of images to keep consistent behaviour across processors.
images = list(images)
else:
images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]

data = {"input_ids": images}
return BatchFeature(data=data, tensor_type=return_tensors)

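The ImageGPT `normalize` shown above is just the affine map x / 127.5 - 1; a worked check of the arithmetic:

```python
import numpy as np

# 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0, so uint8 pixels land in [-1, 1].
pixels = np.array([0, 64, 128, 255], dtype=np.float32)
normalized = pixels / 127.5 - 1
print(normalized)  # approximately [-1.0, -0.498, 0.004, 1.0]
```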
@ -24,6 +24,7 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -50,12 +51,17 @@ def normalize_box(box, width, height):
]


def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str] = None):
def apply_tesseract(
image: np.ndarray,
lang: Optional[str],
tesseract_config: Optional[str] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
tesseract_config = tesseract_config if tesseract_config is not None else ""

# apply OCR
pil_image = to_pil_image(image)
pil_image = to_pil_image(image, input_data_format=input_data_format)
image_width, image_height = pil_image.size
data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
@ -138,6 +144,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -155,6 +162,13 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

Returns:
`np.ndarray`: The resized image.
@ -163,7 +177,14 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
self,
@ -176,6 +197,7 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
tesseract_config: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
@ -233,21 +255,30 @@ class LayoutLMv2ImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if apply_ocr:
requires_backends(self, "pytesseract")
words_batch = []
boxes_batch = []
for image in images:
words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
words, boxes = apply_tesseract(image, ocr_lang, tesseract_config, input_data_format=input_data_format)
words_batch.append(words)
boxes_batch.append(boxes)

if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

# flip color channels from RGB to BGR (as Detectron2 requires this)
images = [flip_channel_order(image) for image in images]
images = [to_channel_dimension_format(image, data_format) for image in images]
images = [flip_channel_order(image, input_data_format=input_data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

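A usage sketch for the LayoutLMv2 path above, assuming `apply_ocr=False` so the pytesseract dependency is not needed and only the resize plus RGB-to-BGR flip are applied (the page array is illustrative):

```python
import numpy as np
from transformers import LayoutLMv2ImageProcessor

processor = LayoutLMv2ImageProcessor(apply_ocr=False)
page = np.random.randint(0, 256, (3, 1000, 750), dtype=np.uint8)  # channels-first scan
encoding = processor(page, return_tensors="np", input_data_format="channels_first")
print(encoding["pixel_values"].shape)  # resized to the processor's default size
```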
@ -26,6 +26,7 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -52,11 +53,16 @@ def normalize_box(box, width, height):
]


def apply_tesseract(image: np.ndarray, lang: Optional[str], tesseract_config: Optional[str]):
def apply_tesseract(
image: np.ndarray,
lang: Optional[str],
tesseract_config: Optional[str],
input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""

# apply OCR
pil_image = to_pil_image(image)
pil_image = to_pil_image(image, input_data_format=input_data_format)
image_width, image_height = pil_image.size
data = pytesseract.image_to_data(pil_image, lang=lang, output_type="dict", config=tesseract_config)
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
@ -164,6 +170,7 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -181,6 +188,13 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

Returns:
`np.ndarray`: The resized image.
@ -189,7 +203,14 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
self,
@ -207,6 +228,7 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
tesseract_config: Optional[str] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
@ -252,6 +274,12 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
@ -286,26 +314,41 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

# Tesseract OCR to get words + normalized bounding boxes
if apply_ocr:
requires_backends(self, "pytesseract")
words_batch = []
boxes_batch = []
for image in images:
words, boxes = apply_tesseract(image, ocr_lang, tesseract_config)
words, boxes = apply_tesseract(image, ocr_lang, tesseract_config, input_data_format=input_data_format)
words_batch.append(words)
boxes_batch.append(boxes)

if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)

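For context on the OCR path touched above: the `normalize_box` helper whose hunk header appears in both LayoutLM diffs maps pixel boxes onto the 0-1000 grid these models expect. A rough re-statement of that scaling (not part of this diff, reproduced here only as an illustration):

```python
def normalize_box(box, width, height):
    # Scale (x0, y0, x1, y1) pixel coordinates to a resolution-independent 0-1000 range.
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]

print(normalize_box((75, 100, 150, 200), width=750, height=1000))  # [100, 100, 200, 200]
```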
@ -30,6 +30,7 @@ from ...image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
valid_images,
@ -119,6 +120,7 @@ class LevitImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -143,19 +145,28 @@ class LevitImageProcessor(BaseImageProcessor):
Resampling filter to use when resiizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size_dict = get_size_dict(size, default_to_square=False)
# size_dict is a dict with either keys "height" and "width" or "shortest_edge"
if "shortest_edge" in size:
shortest_edge = int((256 / 224) * size["shortest_edge"])
output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
output_size = get_resize_output_image_size(
image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
)
size_dict = {"height": output_size[0], "width": output_size[1]}
if "height" not in size_dict or "width" not in size_dict:
raise ValueError(
f"Size dict must have keys 'height' and 'width' or 'shortest_edge'. Got {size_dict.keys()}"
)
return resize(
image, size=(size_dict["height"], size_dict["width"]), resample=resample, data_format=data_format, **kwargs
image,
size=(size_dict["height"], size_dict["width"]),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
@ -173,6 +184,7 @@ class LevitImageProcessor(BaseImageProcessor):
image_std: Optional[Union[float, Iterable[float]]] = None,
return_tensors: Optional[TensorType] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
@ -217,6 +229,12 @@ class LevitImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
@ -255,19 +273,27 @@ class LevitImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image, size, resample) for image in images]
images = [self.resize(image, size, resample, input_data_format=input_data_format) for image in images]

if do_center_crop:
images = [self.center_crop(image, crop_size) for image in images]
images = [self.center_crop(image, crop_size, input_data_format=input_data_format) for image in images]

if do_rescale:
images = [self.rescale(image, rescale_factor) for image in images]
images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

if do_normalize:
images = [self.normalize(image, image_mean, image_std) for image in images]
images = [
self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

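The LeViT `shortest_edge` branch above scales the requested edge by 256 / 224 before the later crop, echoing the usual "resize to 256, crop to 224" ImageNet recipe; a quick check of the arithmetic it performs:

```python
# For the default shortest_edge of 224, the image is first resized so its
# shortest side is int((256 / 224) * 224) == 256.
shortest_edge = 224
scaled = int((256 / 224) * shortest_edge)
print(scaled)  # 256
```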
@ -66,23 +66,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:


# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
input_channel_dimension = infer_channel_dimension_format(images[0])
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])

if input_channel_dimension == ChannelDimension.FIRST:
if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST:
elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width)


# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@ -92,7 +97,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask

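A minimal NumPy sketch of what the two helpers above compute for a channels-first batch: the pad size is the per-axis maximum, and each mask marks the valid (un-padded) region with ones (the shapes below are illustrative):

```python
import numpy as np

images = [np.zeros((3, 480, 640)), np.zeros((3, 512, 600))]
pad_h = max(img.shape[1] for img in images)  # 512
pad_w = max(img.shape[2] for img in images)  # 640

mask = np.zeros((pad_h, pad_w), dtype=np.int64)
mask[:480, :640] = 1  # valid region of the first image; the rest stays 0 (padding)
```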
@ -297,6 +302,7 @@ def get_mask2former_resize_output_image_size(
max_size: Optional[int] = None,
size_divisor: int = 0,
default_to_square: bool = True,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple:
"""
Computes the output size given the desired size.
@ -310,14 +316,18 @@ def get_mask2former_resize_output_image_size(
Whether to default to square if no size is provided.
max_size (`int`, *optional*):
The maximum size of the output image.
size_divisible (`int`, *optional*, defaults to `0`):
size_divisible (`int`, *optional*, defaults to 0):
If size_divisible is given, the output image size will be divisible by the number.

Returns:
`Tuple[int, int]`: The output size.
"""
output_size = get_resize_output_image_size(
input_image=image, size=size, default_to_square=default_to_square, max_size=max_size
input_image=image,
size=size,
default_to_square=default_to_square,
max_size=max_size,
input_data_format=input_data_format,
)

if size_divisor > 0:
@ -450,11 +460,27 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
size_divisor: int = 0,
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format=None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
int, smaller edge of the image will be matched to this number.

Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
The size of the output image.
size_divisor (`int`, *optional*, defaults to 0):
If size_divisor is given, the output image size will be divisible by the number.
resample (`PILImageResampling` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if "max_size" in kwargs:
warnings.warn(
@ -482,13 +508,20 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
max_size=max_size,
size_divisor=size_divisor,
default_to_square=False,
input_data_format=input_data_format,
)
image = resize(
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
)
image = resize(image, size=size, resample=resample, data_format=data_format)
return image

# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
def rescale(
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
self,
image: np.ndarray,
rescale_factor: float,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Rescale the image by the given factor. image = image * rescale_factor.
@ -503,8 +536,13 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format)
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)

# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
def convert_segmentation_map_to_binary_masks(
@ -538,13 +576,16 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_resize:
image = self.resize(image, size=size, size_divisor=size_divisor, resample=resample)
image = self.resize(
image, size=size, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format
)
if do_rescale:
image = self.rescale(image, rescale_factor=rescale_factor)
image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(image, mean=image_mean, std=image_std)
image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
return image

def _preprocess_image(
@ -560,10 +601,13 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
image=image,
do_resize=do_resize,
@ -575,9 +619,10 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image

def _preprocess_mask(
@ -586,14 +631,19 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
do_resize: bool = None,
size: Dict[str, int] = None,
size_divisor: int = 0,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map)
# TODO: (Amy)
# Remork segmentation map processing to include reducing labels and resizing which doesn't
# drop segment IDs > 255.
@ -605,6 +655,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
size_divisor=size_divisor,
do_rescale=False,
do_normalize=False,
input_data_format=input_data_format,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
@ -629,6 +680,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
if "pad_and_return_pixel_mask" in kwargs:
@ -691,17 +743,26 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]

if segmentation_maps is not None:
segmentation_maps = [
self._preprocess_mask(segmentation_map, do_resize, size, size_divisor)
self._preprocess_mask(
segmentation_map, do_resize, size, size_divisor, input_data_format=input_data_format
)
for segmentation_map in segmentation_maps
]
encoded_inputs = self.encode_inputs(
images, segmentation_maps, instance_id_to_semantic_id, ignore_index, reduce_labels, return_tensors
images,
segmentation_maps,
instance_id_to_semantic_id,
ignore_index,
reduce_labels,
return_tensors,
input_data_format=input_data_format,
)
return encoded_inputs

@ -712,18 +773,24 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size

pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image

@ -735,6 +802,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
@ -756,17 +824,28 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images)
pad_size = get_max_height_width(images, input_data_format=input_data_format)

padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
data = {"pixel_values": padded_images}

if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
masks = [
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
for image in images
]
data["pixel_mask"] = masks

return BatchFeature(data=data, tensor_type=return_tensors)
@ -779,6 +858,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
@ -815,6 +895,9 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.

input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.

Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:

@ -831,7 +914,13 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
reduce_labels = self.reduce_labels if reduce_labels is None else reduce_labels

pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors)

if input_data_format is None:
input_data_format = infer_channel_dimension_format(pixel_values_list[0])

encoded_inputs = self.pad(
pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
)

if segmentation_maps is not None:
mask_labels = []

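The `_preprocess_mask` changes in the Mask2Former diff above and the MaskFormer diff below share the same trick: a plain (height, width) segmentation map temporarily gains a channel axis so it can flow through the channels-first resize path, and the axis is dropped again afterwards. A bare NumPy sketch of that flow (the shapes are illustrative, and the resize step is elided):

```python
import numpy as np

segmentation_map = np.zeros((480, 640), dtype=np.int64)
added_channel_dim = segmentation_map.ndim == 2
if added_channel_dim:
    # (1, 480, 640), treated as ChannelDimension.FIRST by the shared resize path.
    segmentation_map = segmentation_map[None, ...]

# ... resize with do_rescale=False and do_normalize=False ...

if added_channel_dim:
    segmentation_map = segmentation_map.squeeze(0)  # back to (height, width)
```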
@ -70,23 +70,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
||||
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
|
||||
def get_max_height_width(
|
||||
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all images in a batch.
|
||||
"""
|
||||
input_channel_dimension = infer_channel_dimension_format(images[0])
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if input_channel_dimension == ChannelDimension.FIRST:
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
||||
elif input_channel_dimension == ChannelDimension.LAST:
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
||||
else:
|
||||
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
|
||||
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
||||
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
def make_pixel_mask(
|
||||
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
|
||||
@ -96,7 +101,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
|
||||
output_size (`Tuple[int, int]`):
|
||||
Output size of the mask.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
mask = np.zeros(output_size, dtype=np.int64)
|
||||
mask[:input_height, :input_width] = 1
|
||||
return mask
|
||||
@ -299,6 +304,7 @@ def get_maskformer_resize_output_image_size(
|
||||
max_size: Optional[int] = None,
|
||||
size_divisor: int = 0,
|
||||
default_to_square: bool = True,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Computes the output size given the desired size.
|
||||
@ -312,14 +318,18 @@ def get_maskformer_resize_output_image_size(
|
||||
Whether to default to square if no size is provided.
|
||||
max_size (`int`, *optional*):
|
||||
The maximum size of the output image.
|
||||
size_divisible (`int`, *optional*, defaults to `0`):
|
||||
size_divisible (`int`, *optional*, defaults to 0):
|
||||
If size_divisible is given, the output image size will be divisible by the number.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: The output size.
|
||||
"""
|
||||
output_size = get_resize_output_image_size(
|
||||
input_image=image, size=size, default_to_square=default_to_square, max_size=max_size
|
||||
input_image=image,
|
||||
size=size,
|
||||
default_to_square=default_to_square,
|
||||
max_size=max_size,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
if size_divisor > 0:
|
||||
@ -458,11 +468,27 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
size_divisor: int = 0,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format=None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
|
||||
int, smaller edge of the image will be matched to this number.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
The size of the output image.
|
||||
size_divisor (`int`, *optional*, defaults to 0):
|
||||
If size_divisor is given, the output image size will be divisible by the number.
|
||||
resample (`PILImageResampling` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
if "max_size" in kwargs:
|
||||
warnings.warn(
|
||||
@ -490,13 +516,20 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
max_size=max_size,
|
||||
size_divisor=size_divisor,
|
||||
default_to_square=False,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
image = resize(
|
||||
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
image = resize(image, size=size, resample=resample, data_format=data_format)
|
||||
return image
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
|
||||
def rescale(
|
||||
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
rescale_factor: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale the image by the given factor. image = image * rescale_factor.
|
||||
@ -511,8 +544,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
|
||||
one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
"""
|
||||
return rescale(image, rescale_factor, data_format=data_format)
|
||||
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
|
||||
|
||||
def convert_segmentation_map_to_binary_masks(
|
||||
self,
|
||||
@ -545,13 +583,16 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
if do_resize:
|
||||
image = self.resize(image, size=size, size_divisor=size_divisor, resample=resample)
|
||||
image = self.resize(
|
||||
image, size=size, size_divisor=size_divisor, resample=resample, input_data_format=input_data_format
|
||||
)
|
||||
if do_rescale:
|
||||
image = self.rescale(image, rescale_factor=rescale_factor)
|
||||
image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
|
||||
if do_normalize:
|
||||
image = self.normalize(image, mean=image_mean, std=image_std)
|
||||
image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
return image
|
||||
|
||||
def _preprocess_image(
|
||||
@ -567,10 +608,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
image = self._preprocess(
|
||||
image=image,
|
||||
do_resize=do_resize,
|
||||
@ -582,9 +626,10 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def _preprocess_mask(
@ -593,14 +638,19 @@ class MaskFormerImageProcessor(BaseImageProcessor):
do_resize: bool = None,
size: Dict[str, int] = None,
size_divisor: int = 0,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
# TODO: (Amy)
# Rework segmentation map processing to include reducing labels and resizing which doesn't
# drop segment IDs > 255.
@ -612,6 +662,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
size_divisor=size_divisor,
do_rescale=False,
do_normalize=False,
input_data_format=input_data_format,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
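For intuition, a rough stand-alone sketch (plain NumPy, mirroring the logic above rather than calling the library) of how a 2D segmentation map gains and later loses a temporary channel dimension:

    import numpy as np

    segmentation_map = np.random.randint(0, 10, size=(16, 16))  # (height, width), no channel axis

    added_channel_dim = segmentation_map.ndim == 2
    if added_channel_dim:
        # Temporarily add a leading channel axis so resize-style transforms can treat it
        # as a channels-first image; input_data_format is then known to be FIRST.
        segmentation_map = segmentation_map[None, ...]

    # ... resize / other transforms would run here ...

    if added_channel_dim:
        segmentation_map = segmentation_map.squeeze(0)  # back to (height, width)
    print(segmentation_map.shape)
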
@ -636,6 +687,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
do_reduce_labels: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
if "pad_and_return_pixel_mask" in kwargs:
|
||||
@ -708,17 +760,26 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = [
|
||||
self._preprocess_mask(segmentation_map, do_resize, size, size_divisor)
|
||||
self._preprocess_mask(
|
||||
segmentation_map, do_resize, size, size_divisor, input_data_format=input_data_format
|
||||
)
|
||||
for segmentation_map in segmentation_maps
|
||||
]
|
||||
encoded_inputs = self.encode_inputs(
|
||||
images, segmentation_maps, instance_id_to_semantic_id, ignore_index, do_reduce_labels, return_tensors
|
||||
images,
|
||||
segmentation_maps,
|
||||
instance_id_to_semantic_id,
|
||||
ignore_index,
|
||||
do_reduce_labels,
|
||||
return_tensors,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return encoded_inputs
|
||||
|
||||
@ -729,18 +790,24 @@ class MaskFormerImageProcessor(BaseImageProcessor):
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size

pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image

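As a sanity check on the arithmetic above, a small stand-alone sketch (plain NumPy, not the library's `pad` helper) of bottom/right zero-padding to a target size:

    import numpy as np

    def pad_bottom_right(image: np.ndarray, output_size: tuple) -> np.ndarray:
        # Assumes a channels-last (height, width, num_channels) array for simplicity.
        input_height, input_width = image.shape[0], image.shape[1]
        output_height, output_width = output_size
        pad_bottom = output_height - input_height
        pad_right = output_width - input_width
        return np.pad(image, ((0, pad_bottom), (0, pad_right), (0, 0)), mode="constant")

    padded = pad_bottom_right(np.ones((2, 3, 3)), (4, 5))
    print(padded.shape)  # (4, 5, 3)
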
@ -752,6 +819,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
return_pixel_mask: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
|
||||
@ -773,17 +841,28 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
pad_size = get_max_height_width(images)
|
||||
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
||||
|
||||
padded_images = [
|
||||
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
|
||||
self._pad_image(
|
||||
image,
|
||||
pad_size,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": padded_images}
|
||||
|
||||
if return_pixel_mask:
|
||||
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
|
||||
masks = [
|
||||
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data["pixel_mask"] = masks
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
@ -796,6 +875,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
ignore_index: Optional[int] = None,
|
||||
reduce_labels: bool = False,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
|
||||
@ -848,12 +928,18 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels
|
||||
|
||||
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
||||
encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors)
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(pixel_values_list[0])
|
||||
|
||||
encoded_inputs = self.pad(
|
||||
pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
if segmentation_maps is not None:
|
||||
mask_labels = []
|
||||
class_labels = []
|
||||
pad_size = get_max_height_width(pixel_values_list)
|
||||
pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
|
||||
# Convert to list of binary masks and labels
|
||||
for idx, segmentation_map in enumerate(segmentation_maps):
|
||||
segmentation_map = to_numpy_array(segmentation_map)
|
||||
@ -869,7 +955,13 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||
# this will be removed in the future
|
||||
masks = [mask[None, ...] for mask in masks]
|
||||
masks = [
|
||||
self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks
|
||||
self._pad_image(
|
||||
image=mask,
|
||||
output_size=pad_size,
|
||||
constant_values=ignore_index,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
for mask in masks
|
||||
]
|
||||
masks = np.concatenate(masks, axis=0)
|
||||
mask_labels.append(torch.from_numpy(masks))
|
||||
|
@ -30,6 +30,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -118,6 +119,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
@ -133,12 +135,23 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
output_size = get_resize_output_image_size(
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
|
||||
self,
|
||||
@ -155,6 +168,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -197,6 +211,12 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -234,19 +254,36 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images]
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]

if do_center_crop:
images = [self.center_crop(image=image, size=crop_size) for image in images]
images = [
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
]

if do_rescale:
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

images = [to_channel_dimension_format(image, data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

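A hedged usage sketch of the new argument at the processor level; the model class is one of the processors touched by this diff, and the array shape and printed size are illustrative only:

    import numpy as np
    from transformers import MobileNetV1ImageProcessor

    processor = MobileNetV1ImageProcessor()
    # A channels-first float image; without the hint the processor would try to infer the layout.
    image = np.random.rand(3, 224, 224).astype(np.float32)

    inputs = processor(image, input_data_format="channels_first", return_tensors="np")
    print(inputs["pixel_values"].shape)  # e.g. (1, 3, 224, 224) with the default resize/crop settings
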
@ -30,6 +30,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -122,6 +123,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -137,12 +139,23 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size:
|
||||
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
|
||||
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -159,6 +172,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -201,6 +215,12 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -238,19 +258,36 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -29,6 +29,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -114,6 +115,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -129,15 +131,29 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size:
|
||||
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
|
||||
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def flip_channel_order(
|
||||
self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Flip the color channels from RGB to BGR or vice versa.
|
||||
@ -147,8 +163,10 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
The image, represented as a numpy array.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
return flip_channel_order(image, data_format=data_format)
|
||||
return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -163,6 +181,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
do_flip_channel_order: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -199,6 +218,12 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
@ -234,20 +259,34 @@ class MobileViTImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
# the pretrained checkpoints assume images are BGR, not RGB
|
||||
if do_flip_channel_order:
|
||||
images = [self.flip_channel_order(image=image) for image in images]
|
||||
images = [self.flip_channel_order(image=image, input_data_format=input_data_format) for image in images]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -67,23 +67,28 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:


# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
def get_max_height_width(
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
input_channel_dimension = infer_channel_dimension_format(images[0])
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])

if input_channel_dimension == ChannelDimension.FIRST:
if input_data_format == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST:
elif input_data_format == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
return (max_height, max_width)


# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

@ -93,7 +98,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
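To make the behaviour of the two helpers above concrete, a small sketch with made-up channels-first arrays (plain NumPy, mirroring rather than calling the library functions):

    import numpy as np

    # Two channels-first images of different spatial sizes.
    images = [np.zeros((3, 10, 12)), np.zeros((3, 8, 20))]

    # With input_data_format known to be channels-first, the spatial maxima are taken
    # over dimensions 1 and 2: pad_size == (10, 20).
    pad_size = (max(img.shape[1] for img in images), max(img.shape[2] for img in images))

    # A pixel mask marks the valid (unpadded) region of each image after padding.
    mask = np.zeros(pad_size, dtype=np.int64)
    mask[: images[0].shape[1], : images[0].shape[2]] = 1
    print(pad_size, mask.sum())  # (10, 20) 120
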
@ -295,6 +300,7 @@ def get_oneformer_resize_output_image_size(
|
||||
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||
max_size: Optional[int] = None,
|
||||
default_to_square: bool = True,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Computes the output size given the desired size.
|
||||
@ -304,16 +310,20 @@ def get_oneformer_resize_output_image_size(
|
||||
The input image.
|
||||
size (`int`, `Tuple[int, int]`, `List[int]`, `Tuple[int]`):
|
||||
The size of the output image.
|
||||
default_to_square (`bool`, *optional*, defaults to `True`):
|
||||
Whether to default to square if no size is provided.
|
||||
max_size (`int`, *optional*):
|
||||
The maximum size of the output image.
|
||||
default_to_square (`bool`, *optional*, defaults to `True`):
|
||||
Whether to default to square if no size is provided.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: The output size.
|
||||
"""
|
||||
output_size = get_resize_output_image_size(
|
||||
input_image=image, size=size, default_to_square=default_to_square, max_size=max_size
|
||||
input_image=image,
|
||||
size=size,
|
||||
default_to_square=default_to_square,
|
||||
max_size=max_size,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return output_size
|
||||
|
||||
@ -442,6 +452,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format=None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -469,17 +480,20 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
f" {size.keys()}."
|
||||
)
|
||||
size = get_oneformer_resize_output_image_size(
|
||||
image=image,
|
||||
size=size,
|
||||
max_size=max_size,
|
||||
default_to_square=False,
|
||||
image=image, size=size, max_size=max_size, default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
image = resize(
|
||||
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format
|
||||
)
|
||||
image = resize(image, size=size, resample=resample, data_format=data_format)
|
||||
return image
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
|
||||
def rescale(
|
||||
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
rescale_factor: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale the image by the given factor. image = image * rescale_factor.
|
||||
@ -494,8 +508,13 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. If unset, it is inferred from the input image. Can be
one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
"""
return rescale(image, rescale_factor, data_format=data_format)
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)

# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
|
||||
def convert_segmentation_map_to_binary_masks(
|
||||
@ -528,13 +547,14 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
if do_resize:
|
||||
image = self.resize(image, size=size, resample=resample)
|
||||
image = self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
if do_rescale:
|
||||
image = self.rescale(image, rescale_factor=rescale_factor)
|
||||
image = self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
|
||||
if do_normalize:
|
||||
image = self.normalize(image, mean=image_mean, std=image_std)
|
||||
image = self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
return image
|
||||
|
||||
def _preprocess_image(
|
||||
@ -549,10 +569,13 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
image = self._preprocess(
|
||||
image=image,
|
||||
do_resize=do_resize,
|
||||
@ -563,9 +586,10 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def _preprocess_mask(
|
||||
@ -573,14 +597,19 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
segmentation_map: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single mask."""
|
||||
segmentation_map = to_numpy_array(segmentation_map)
|
||||
# Add channel dimension if missing - needed for certain transformations
|
||||
added_channel_dim = False
|
||||
if segmentation_map.ndim == 2:
|
||||
added_channel_dim = True
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
input_data_format = ChannelDimension.FIRST
|
||||
else:
|
||||
added_channel_dim = False
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
||||
# TODO: (Amy)
|
||||
# Rework segmentation map processing to include reducing labels and resizing which doesn't
|
||||
# drop segment IDs > 255.
|
||||
@ -591,6 +620,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
size=size,
|
||||
do_rescale=False,
|
||||
do_normalize=False,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
# Remove extra channel dimension if added for processing
|
||||
if added_channel_dim:
|
||||
@ -615,6 +645,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
do_reduce_labels: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
if "pad_and_return_pixel_mask" in kwargs:
|
||||
@ -691,13 +722,15 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if segmentation_maps is not None:
|
||||
segmentation_maps = [
|
||||
self._preprocess_mask(segmentation_map, do_resize, size) for segmentation_map in segmentation_maps
|
||||
self._preprocess_mask(segmentation_map, do_resize, size, input_data_format=input_data_format)
|
||||
for segmentation_map in segmentation_maps
|
||||
]
|
||||
encoded_inputs = self.encode_inputs(
|
||||
images,
|
||||
@ -707,6 +740,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
ignore_index,
|
||||
do_reduce_labels,
|
||||
return_tensors,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return encoded_inputs
|
||||
|
||||
@ -717,18 +751,24 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
output_size: Tuple[int, int],
|
||||
constant_values: Union[float, Iterable[float]] = 0,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad an image with zeros to the given size.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = output_size
|
||||
|
||||
pad_bottom = output_height - input_height
|
||||
pad_right = output_width - input_width
|
||||
padding = ((0, pad_bottom), (0, pad_right))
|
||||
padded_image = pad(
|
||||
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
|
||||
image,
|
||||
padding,
|
||||
mode=PaddingMode.CONSTANT,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return padded_image
|
||||
|
||||
@ -740,6 +780,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
return_pixel_mask: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
|
||||
@ -761,17 +802,28 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
pad_size = get_max_height_width(images)
|
||||
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
||||
|
||||
padded_images = [
|
||||
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
|
||||
self._pad_image(
|
||||
image,
|
||||
pad_size,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": padded_images}
|
||||
|
||||
if return_pixel_mask:
|
||||
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
|
||||
masks = [
|
||||
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data["pixel_mask"] = masks
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
@ -882,6 +934,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
ignore_index: Optional[int] = None,
|
||||
reduce_labels: bool = False,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
|
||||
@ -921,6 +974,10 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
|
||||
objects.
|
||||
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input
|
||||
image.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
@ -938,8 +995,14 @@ class OneFormerImageProcessor(BaseImageProcessor):
|
||||
ignore_index = self.ignore_index if ignore_index is None else ignore_index
|
||||
reduce_labels = self.do_reduce_labels if reduce_labels is None else reduce_labels
|
||||
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
|
||||
pad_size = get_max_height_width(pixel_values_list)
|
||||
encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors)
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(pixel_values_list[0])
|
||||
|
||||
pad_size = get_max_height_width(pixel_values_list, input_data_format=input_data_format)
|
||||
encoded_inputs = self.pad(
|
||||
pixel_values_list, return_tensors=return_tensors, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
annotations = None
|
||||
if segmentation_maps is not None:
|
||||
|
@ -33,6 +33,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -169,36 +170,79 @@ class OwlViTImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to a certain size.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
The size to resize the image to. Must contain height and width keys.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
The resampling filter to use when resizing the input.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError("size dictionary must contain height and width keys")
|
||||
|
||||
return resize(image, (size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
(size["height"], size["width"]),
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def center_crop(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
crop_size: Dict[str, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Center crop an image to a certain size.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to center crop.
|
||||
crop_size (`Dict[str, int]`):
|
||||
The size to center crop the image to. Must contain height and width keys.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
crop_size = get_size_dict(crop_size, default_to_square=True)
|
||||
if "height" not in crop_size or "width" not in crop_size:
|
||||
raise ValueError("crop_size dictionary must contain height and width keys")
|
||||
|
||||
return center_crop(image, (crop_size["height"], crop_size["width"]), data_format=data_format, **kwargs)
|
||||
return center_crop(
|
||||
image,
|
||||
(crop_size["height"], crop_size["width"]),
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
|
||||
def rescale(
|
||||
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
rescale_factor: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale the image by the given factor. image = image * rescale_factor.
|
||||
@ -213,8 +257,13 @@ class OwlViTImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. If unset, it is inferred from the input image. Can be
|
||||
one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
"""
|
||||
return rescale(image, rescale_factor, data_format=data_format)
|
||||
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -231,6 +280,7 @@ class OwlViTImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[TensorType, str]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -277,6 +327,12 @@ class OwlViTImageProcessor(BaseImageProcessor):
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: defaults to the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -312,19 +368,36 @@ class OwlViTImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image, crop_size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image, crop_size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image, rescale_factor=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||
return encoded_inputs
|
||||
|
||||
|
@ -27,6 +27,7 @@ from ...image_utils import (
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -117,6 +118,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
crop_size: Dict[str, int],
|
||||
size: Optional[int] = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -135,16 +137,24 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
Size of the image after resizing. If not provided, the self.size attribute will be used.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = self.size if size is None else size
|
||||
size = get_size_dict(size)
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
height, width = get_image_size(image)
|
||||
height, width = get_image_size(image, channel_dim=input_data_format)
|
||||
min_dim = min(height, width)
|
||||
cropped_height = (size["height"] / crop_size["height"]) * min_dim
|
||||
cropped_width = (size["width"] / crop_size["width"]) * min_dim
|
||||
return center_crop(image, size=(cropped_height, cropped_width), data_format=data_format, **kwargs)
|
||||
return center_crop(
|
||||
image,
|
||||
size=(cropped_height, cropped_width),
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
|
||||
def resize(
|
||||
@ -153,6 +163,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -170,6 +181,13 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -178,7 +196,14 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -195,6 +220,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -235,6 +261,12 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
@ -272,19 +304,36 @@ class PerceiverImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image, crop_size, size=size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image, crop_size, size=size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -157,7 +157,9 @@ def render_text(


# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
def render_header(image: np.ndarray, header: str, **kwargs):
def render_header(
image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
):
"""
Renders the input text as a header on the input image.

@ -176,7 +178,7 @@ def render_header(image: np.ndarray, header: str, **kwargs):
requires_backends(render_header, "vision")

# Convert to PIL image if necessary
image = to_pil_image(image)
image = to_pil_image(image, input_data_format=input_data_format)

header_image = render_text(header, **kwargs)
new_width = max(header_image.width, image.width)
@ -236,7 +238,14 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
self.max_patches = max_patches
|
||||
self.is_vqa = is_vqa
|
||||
|
||||
def extract_flattened_patches(self, image: np.ndarray, max_patches: int, patch_size: dict, **kwargs) -> np.ndarray:
|
||||
def extract_flattened_patches(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
max_patches: int,
|
||||
patch_size: dict,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Extract flattened patches from an image.
|
||||
|
||||
@ -256,11 +265,11 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
_check_torch_version()
|
||||
|
||||
# convert to torch
|
||||
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
|
||||
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
patch_height, patch_width = patch_size["height"], patch_size["width"]
|
||||
image_height, image_width = get_image_size(image)
|
||||
image_height, image_width = get_image_size(image, ChannelDimension.FIRST)
|
||||
|
||||
# maximize scale s.t.
|
||||
scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
|
||||
@ -312,7 +321,11 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
return result
|
||||
|
||||
def normalize(
self, image: np.ndarray, data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Normalize an image. image = (image - image_mean) / image_std.
@ -323,6 +336,11 @@ class Pix2StructImageProcessor(BaseImageProcessor):
Args:
image (`np.ndarray`):
Image to normalize.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if image.dtype == np.uint8:
image = image.astype(np.float32)
@ -332,7 +350,14 @@ class Pix2StructImageProcessor(BaseImageProcessor):
std = np.std(image)
adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))

return normalize(image, mean=mean, std=adjusted_stddev, **kwargs)
return normalize(
image,
mean=mean,
std=adjusted_stddev,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)

def preprocess(
|
||||
self,
|
||||
@ -344,6 +369,7 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
patch_size: Optional[Dict[str, int]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
@ -374,6 +400,17 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
@ -399,6 +436,10 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if is_vqa:
|
||||
if header_text is None:
|
||||
raise ValueError("A header text must be provided for VQA models.")
|
||||
@ -414,11 +455,13 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image) for image in images]
|
||||
images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]
|
||||
|
||||
# convert to torch tensor and permute
|
||||
images = [
|
||||
self.extract_flattened_patches(image=image, max_patches=max_patches, patch_size=patch_size)
|
||||
self.extract_flattened_patches(
|
||||
image=image, max_patches=max_patches, patch_size=patch_size, input_data_format=input_data_format
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
|
@ -30,6 +30,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -137,6 +138,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
crop_pct: Optional[float] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -166,6 +168,8 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size and ("height" not in size or "width" not in size):
|
||||
@ -181,16 +185,27 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
else:
|
||||
raise ValueError("Invalid size for resize: {}".format(size))
|
||||
|
||||
output_size = get_resize_output_image_size(image, size=scale_size, default_to_square=False)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=scale_size, default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
else:
|
||||
if "shortest_edge" in size:
|
||||
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
elif "height" in size and "width" in size:
|
||||
output_size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError("Invalid size for resize: {}".format(size))
|
||||
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -208,6 +223,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -250,6 +266,12 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
crop_pct = crop_pct if crop_pct is not None else self.crop_pct
|
||||
@ -289,19 +311,38 @@ class PoolFormerImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, crop_pct=crop_pct, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(
|
||||
image=image, size=size, crop_pct=crop_pct, resample=resample, input_data_format=input_data_format
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
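For illustration, a rough usage sketch of the PoolFormer path above with the layout stated up front (default processor settings assumed, image contents fabricated):

import numpy as np
from transformers import PoolFormerImageProcessor

processor = PoolFormerImageProcessor()  # default resize / crop / rescale / normalize settings
image = np.random.randint(0, 256, (3, 480, 640), dtype=np.uint8)  # channels-first input

# input_data_format is forwarded to resize, center_crop, rescale and normalize above,
# so no per-image inference is needed.
inputs = processor(image, input_data_format="channels_first", return_tensors="np")
print(inputs["pixel_values"].shape)  # expected (1, 3, crop_height, crop_width)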
|
||||
|
@ -26,6 +26,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -100,6 +101,7 @@ class PvtImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -117,6 +119,13 @@ class PvtImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -125,7 +134,14 @@ class PvtImageProcessor(BaseImageProcessor):
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -140,6 +156,7 @@ class PvtImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -178,6 +195,12 @@ class PvtImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
@ -207,16 +230,31 @@ class PvtImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
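One motivating edge case for the new argument, sketched with a hypothetical tiny image: when height or width equals the channel count, inference cannot tell the layouts apart, so stating the layout explicitly (as the Pvt preprocess above now allows) is the only reliable option.

import numpy as np
from transformers.image_utils import ChannelDimension, infer_channel_dimension_format

tiny = np.zeros((3, 3, 3), dtype=np.uint8)  # could be (C, H, W) or (H, W, C); the array cannot say
print(infer_channel_dimension_format(tiny))  # inference has to pick one convention

# Passing e.g. input_data_format=ChannelDimension.LAST to PvtImageProcessor.preprocess
# removes that guesswork for such shapes.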
|
||||
|
@ -143,6 +143,7 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
image: np.ndarray,
|
||||
pad_size: Dict[str, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -156,14 +157,22 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
|
||||
`data_format` of the `image` will be used.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
output_height, output_width = pad_size["height"], pad_size["width"]
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
|
||||
pad_width = output_width - input_width
|
||||
pad_height = output_height - input_height
|
||||
|
||||
padded_image = pad(image, ((0, pad_height), (0, pad_width)), data_format=data_format, **kwargs)
|
||||
padded_image = pad(
|
||||
image,
|
||||
((0, pad_height), (0, pad_width)),
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
return padded_image
|
||||
|
||||
def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
|
||||
@ -183,6 +192,7 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -202,15 +212,28 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "longest_edge" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
|
||||
input_size = get_image_size(image)
|
||||
input_size = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])
|
||||
return resize(image, size=(output_height, output_width), resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=(output_height, output_width),
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -228,6 +251,7 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -272,6 +296,12 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -314,23 +344,40 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
original_sizes = [get_image_size(image) for image in images]
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
original_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
reshaped_input_sizes = [get_image_size(image) for image in images]
|
||||
reshaped_input_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
images = [self.pad_image(image=image, pad_size=pad_size) for image in images]
|
||||
images = [
|
||||
self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
encoded_outputs = BatchFeature(
|
||||
data={
|
||||
"pixel_values": images,
|
||||
@ -517,6 +564,7 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||
device: Optional["torch.device"] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
return_tensors: str = "pt",
|
||||
):
|
||||
"""
|
||||
@ -539,6 +587,8 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
device (`torch.device`, *optional*, defaults to None):
|
||||
Device to use for the computation. If `None`, the CPU will be used.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
return_tensors (`str`, *optional*, defaults to `pt`):
|
||||
If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
|
||||
"""
|
||||
@ -549,6 +599,7 @@ class SamImageProcessor(BaseImageProcessor):
|
||||
overlap_ratio,
|
||||
points_per_crop,
|
||||
crop_n_points_downscale_factor,
|
||||
input_data_format,
|
||||
)
|
||||
if return_tensors == "pt":
|
||||
if device is None:
|
||||
@ -855,6 +906,7 @@ def _generate_crop_boxes(
|
||||
overlap_ratio: float = 512 / 1500,
|
||||
points_per_crop: Optional[int] = 32,
|
||||
crop_n_points_downscale_factor: Optional[List[int]] = 1,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
|
||||
@ -874,12 +926,14 @@ def _generate_crop_boxes(
|
||||
Number of points to sample per crop.
|
||||
crop_n_points_downscale_factor (`int`, *optional*):
|
||||
The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
|
||||
if isinstance(image, list):
|
||||
raise ValueError("Only one image is allowed for crop generation.")
|
||||
image = to_numpy_array(image)
|
||||
original_size = get_image_size(image)
|
||||
original_size = get_image_size(image, input_data_format)
|
||||
|
||||
points_grid = []
|
||||
for i in range(crop_n_layers + 1):
|
||||
@ -889,7 +943,7 @@ def _generate_crop_boxes(
|
||||
crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
|
||||
|
||||
cropped_images, point_grid_per_crop = _generate_crop_images(
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
|
||||
)
|
||||
crop_boxes = np.array(crop_boxes)
|
||||
crop_boxes = crop_boxes.astype(np.float32)
|
||||
@ -935,7 +989,9 @@ def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
|
||||
return crop_boxes, layer_idxs
|
||||
|
||||
|
||||
def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_size, original_size):
|
||||
def _generate_crop_images(
|
||||
crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
|
||||
):
|
||||
"""
|
||||
Takes as input bounding boxes that are used to crop the image. Based on the crops, the corresponding points are
|
||||
also passed.
|
||||
@ -945,7 +1001,7 @@ def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_siz
|
||||
for i, crop_box in enumerate(crop_boxes):
|
||||
left, top, right, bottom = crop_box
|
||||
|
||||
channel_dim = infer_channel_dimension_format(image)
|
||||
channel_dim = infer_channel_dimension_format(image, input_data_format)
|
||||
if channel_dim == ChannelDimension.LAST:
|
||||
cropped_im = image[top:bottom, left:right, :]
|
||||
else:
|
||||
@ -953,7 +1009,7 @@ def _generate_crop_images(crop_boxes, image, points_grid, layer_idxs, target_siz
|
||||
|
||||
cropped_images.append(cropped_im)
|
||||
|
||||
cropped_im_size = get_image_size(cropped_im)
|
||||
cropped_im_size = get_image_size(cropped_im, channel_dim)
|
||||
points_scale = np.array(cropped_im_size)[None, ::-1]
|
||||
|
||||
points = points_grid[layer_idxs[i]] * points_scale
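For illustration, a checkpoint-free sketch of the SAM preprocess path above; the recorded sizes use the same channel-dimension hint as the transforms (contents fabricated, defaults assumed):

import numpy as np
from transformers import SamImageProcessor

processor = SamImageProcessor()
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # channels-last

outputs = processor(image, input_data_format="channels_last", return_tensors="np")
# original_sizes / reshaped_input_sizes are read with get_image_size(..., channel_dim=input_data_format) above.
print(outputs["original_sizes"], outputs["reshaped_input_sizes"])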
|
||||
|
@ -27,6 +27,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -135,6 +136,7 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -152,6 +154,13 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -160,7 +169,14 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
|
||||
def reduce_label(self, label: ImageInput) -> np.ndarray:
|
||||
@ -183,18 +199,19 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
rescale_factor: Optional[float] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
if do_reduce_labels:
|
||||
image = self.reduce_label(image)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample)
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor)
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std)
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
|
||||
return image
|
||||
|
||||
@ -210,10 +227,13 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
image = self._preprocess(
|
||||
image=image,
|
||||
do_reduce_labels=False,
|
||||
@ -225,9 +245,10 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def _preprocess_mask(
|
||||
@ -236,14 +257,19 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
do_reduce_labels: bool = None,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single mask."""
|
||||
segmentation_map = to_numpy_array(segmentation_map)
|
||||
# Add channel dimension if missing - needed for certain transformations
|
||||
added_channel_dim = False
|
||||
if segmentation_map.ndim == 2:
|
||||
added_channel_dim = True
|
||||
segmentation_map = segmentation_map[None, ...]
|
||||
input_data_format = ChannelDimension.FIRST
|
||||
else:
|
||||
added_channel_dim = False
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
|
||||
# reduce zero label if needed
|
||||
segmentation_map = self._preprocess(
|
||||
image=segmentation_map,
|
||||
@ -253,6 +279,7 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
size=size,
|
||||
do_rescale=False,
|
||||
do_normalize=False,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
# Remove extra channel dimension if added for processing
|
||||
if added_channel_dim:
|
||||
@ -284,6 +311,7 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
do_reduce_labels: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -326,6 +354,12 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
@ -374,6 +408,7 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
@ -387,6 +422,7 @@ class SegformerImageProcessor(BaseImageProcessor):
|
||||
do_reduce_labels=do_reduce_labels,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for segmentation_map in segmentation_maps
|
||||
]
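For illustration, a sketch of the SegFormer path above, which threads the same hint through both the image branch and the single-channel segmentation-map branch (random data, default settings assumed):

import numpy as np
from transformers import SegformerImageProcessor

processor = SegformerImageProcessor()
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
mask = np.random.randint(0, 150, (512, 512), dtype=np.uint8)  # 2D map; a channel dim is added internally

encoding = processor(image, segmentation_maps=mask, input_data_format="channels_last", return_tensors="np")
print(encoding["pixel_values"].shape, encoding["labels"].shape)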
|
||||
|
@ -20,7 +20,14 @@ import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from ...image_transforms import get_image_size, pad, to_channel_dimension_format
|
||||
from ...image_utils import ChannelDimension, ImageInput, make_list_of_images, to_numpy_array, valid_images
|
||||
from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
@ -57,7 +64,13 @@ class Swin2SRImageProcessor(BaseImageProcessor):
|
||||
self.do_pad = do_pad
|
||||
self.pad_size = pad_size
|
||||
|
||||
def pad(self, image: np.ndarray, size: int, data_format: Optional[Union[str, ChannelDimension]] = None):
|
||||
def pad(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: int,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pad an image to make the height and width divisible by `size`.
|
||||
|
||||
@ -71,15 +84,26 @@ class Swin2SRImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The padded image.
|
||||
"""
|
||||
old_height, old_width = get_image_size(image)
|
||||
old_height, old_width = get_image_size(image, input_data_format)
|
||||
pad_height = (old_height // size + 1) * size - old_height
|
||||
pad_width = (old_width // size + 1) * size - old_width
|
||||
|
||||
return pad(image, ((0, pad_height), (0, pad_width)), mode="symmetric", data_format=data_format)
|
||||
return pad(
|
||||
image,
|
||||
((0, pad_height), (0, pad_width)),
|
||||
mode="symmetric",
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -90,6 +114,7 @@ class Swin2SRImageProcessor(BaseImageProcessor):
|
||||
pad_size: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -104,12 +129,13 @@ class Swin2SRImageProcessor(BaseImageProcessor):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Whether to pad the image to make the height and width divisible by `window_size`.
|
||||
pad_size (`int`, *optional*, defaults to `32`):
|
||||
pad_size (`int`, *optional*, defaults to 32):
|
||||
The size of the sliding window for the local attention.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type
|
||||
`tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
@ -118,6 +144,12 @@ class Swin2SRImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
@ -138,13 +170,22 @@ class Swin2SRImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
images = [self.pad(image, size=pad_size) for image in images]
|
||||
images = [self.pad(image, size=pad_size, input_data_format=input_data_format) for image in images]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
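For illustration, a sketch of the Swin2SR path above, which only rescales and pads, so the hint mainly controls how get_image_size and pad read the array (default settings assumed, contents fabricated):

import numpy as np
from transformers import Swin2SRImageProcessor

processor = Swin2SRImageProcessor()
image = np.random.randint(0, 256, (3, 100, 125), dtype=np.uint8)  # deliberately channels-first

out = processor(image, input_data_format="channels_first", return_tensors="np")
# Height and width are padded up to the next multiple of pad_size by the symmetric pad above.
print(out["pixel_values"].shape)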
|
||||
|
@ -29,6 +29,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_valid_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -155,6 +156,7 @@ class TvltImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -171,15 +173,26 @@ class TvltImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" in size:
|
||||
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
elif "height" in size and "width" in size:
|
||||
output_size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _preprocess_image(
|
||||
self,
|
||||
@ -195,6 +208,7 @@ class TvltImageProcessor(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
if do_resize and size is None or resample is None:
|
||||
@ -212,18 +226,21 @@ class TvltImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample)
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image, size=crop_size)
|
||||
image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor)
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std)
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
@ -244,6 +261,7 @@ class TvltImageProcessor(BaseImageProcessor):
|
||||
is_mixed: bool = False,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -291,6 +309,12 @@ class TvltImageProcessor(BaseImageProcessor):
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the inferred channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
@ -361,6 +385,7 @@ class TvltImageProcessor(BaseImageProcessor):
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for img in video
|
||||
]
|
||||
|
@ -30,6 +30,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_valid_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -134,6 +135,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -150,15 +152,26 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" in size:
|
||||
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
elif "height" in size and "width" in size:
|
||||
output_size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _preprocess_image(
|
||||
self,
|
||||
@ -174,6 +187,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
if do_resize and size is None or resample is None:
|
||||
@ -191,19 +205,22 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample)
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image, size=crop_size)
|
||||
image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor)
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std)
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
@ -221,6 +238,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -262,6 +280,12 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the inferred channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
@ -300,6 +324,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for img in video
|
||||
]
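For illustration, a sketch of the VideoMAE path above, which applies the per-frame pipeline to every frame and shares one layout hint across the clip (fabricated frames, default settings assumed):

import numpy as np
from transformers import VideoMAEImageProcessor

processor = VideoMAEImageProcessor()
video = [np.random.randint(0, 256, (360, 480, 3), dtype=np.uint8) for _ in range(16)]

batch = processor(video, input_data_format="channels_last", return_tensors="np")
print(batch["pixel_values"].shape)  # roughly (1, 16, 3, 224, 224) with the default size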
|
||||
|
@ -49,7 +49,9 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
||||
return [max(values_i) for values_i in zip(*values)]
|
||||
|
||||
|
||||
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
def make_pixel_mask(
|
||||
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
|
||||
@ -59,31 +61,38 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
|
||||
output_size (`Tuple[int, int]`):
|
||||
Output size of the mask.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
mask = np.zeros(output_size, dtype=np.int64)
|
||||
mask[:input_height, :input_width] = 1
|
||||
return mask
|
||||
|
||||
|
||||
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
|
||||
def get_max_height_width(
|
||||
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all images in a batch.
|
||||
"""
|
||||
input_channel_dimension = infer_channel_dimension_format(images[0])
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if input_channel_dimension == ChannelDimension.FIRST:
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
||||
elif input_channel_dimension == ChannelDimension.LAST:
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
||||
else:
|
||||
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
|
||||
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray, shorter: int = 800, longer: int = 1333, size_divisor: int = 32
|
||||
input_image: np.ndarray,
|
||||
shorter: int = 800,
|
||||
longer: int = 1333,
|
||||
size_divisor: int = 32,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
input_height, input_width = get_image_size(input_image)
|
||||
input_height, input_width = get_image_size(input_image, input_data_format)
|
||||
min_size, max_size = shorter, longer
|
||||
|
||||
scale = min_size / min(input_height, input_width)
|
||||
@ -200,6 +209,7 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
size_divisor: int = 32,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -220,14 +230,25 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
|
||||
shorter = size["shortest_edge"]
|
||||
longer = int(1333 / 800 * shorter)
|
||||
output_size = get_resize_output_image_size(image, shorter=shorter, longer=longer, size_divisor=size_divisor)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
|
||||
def _pad_image(
|
||||
@ -236,18 +257,24 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
output_size: Tuple[int, int],
|
||||
constant_values: Union[float, Iterable[float]] = 0,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad an image with zeros to the given size.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = output_size
|
||||
|
||||
pad_bottom = output_height - input_height
|
||||
pad_right = output_width - input_width
|
||||
padding = ((0, pad_bottom), (0, pad_right))
|
||||
padded_image = pad(
|
||||
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
|
||||
image,
|
||||
padding,
|
||||
mode=PaddingMode.CONSTANT,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return padded_image
|
||||
|
||||
@ -259,6 +286,7 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
return_pixel_mask: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Pads a batch of images to the bottom and right of the image with zeros to the size of the largest height and width
|
||||
@ -280,17 +308,28 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
pad_size = get_max_height_width(images)
|
||||
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
||||
|
||||
padded_images = [
|
||||
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
|
||||
self._pad_image(
|
||||
image,
|
||||
pad_size,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": padded_images}
|
||||
|
||||
if return_pixel_mask:
|
||||
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
|
||||
masks = [
|
||||
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data["pixel_mask"] = masks
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
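A toy illustration of the `pixel_mask` returned here: ones mark real pixels, zeros mark the area added by padding (the 4 x 5 pad size and the 2 x 3 image are made-up values, not defaults):

import numpy as np

mask = np.zeros((4, 5), dtype=np.int64)  # pad_size = largest height/width in the batch
mask[:2, :3] = 1                         # this image was 2 x 3 before padding
print(int(mask.sum()))  # 6 valid pixels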
|
||||
@ -310,6 +349,7 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
do_pad: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -353,6 +393,12 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size_divisor = size_divisor if size_divisor is not None else self.size_divisor
|
||||
@ -387,21 +433,42 @@ class ViltImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(image=image, size=size, size_divisor=size_divisor, resample=resample) for image in images
|
||||
self.resize(
|
||||
image=image,
|
||||
size=size,
|
||||
size_divisor=size_divisor,
|
||||
resample=resample,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
encoded_outputs = self.pad(images, return_pixel_mask=True, return_tensors=return_tensors)
|
||||
encoded_outputs = self.pad(
|
||||
images, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=data_format
|
||||
)
|
||||
else:
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
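The `infer_channel_dimension_format(images[0])` fallback above is the reason `input_data_format` exists at all: shape-based inference is ambiguous for some arrays. A standalone sketch of the failure mode (the heuristic below is a simplification, not the library's exact logic):

import numpy as np

def guess_height_width(image):
    # naive heuristic: treat axis 0 as channels only if it looks like 1 or 3 channels
    if image.shape[0] in (1, 3):
        return image.shape[1], image.shape[2]
    return image.shape[0], image.shape[1]

rgba = np.zeros((4, 64, 48))        # a 4-channel, channels-first image
print(guess_height_width(rgba))     # (4, 64) -- wrong, the heuristic assumed channels-last
# passing input_data_format="channels_first" lets the processor skip the guess entirely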
|
||||
|
||||
|
@ -26,6 +26,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -99,6 +100,7 @@ class ViTImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -116,6 +118,13 @@ class ViTImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@ -124,7 +133,14 @@ class ViTImageProcessor(BaseImageProcessor):
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -139,6 +155,7 @@ class ViTImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -177,6 +194,12 @@ class ViTImageProcessor(BaseImageProcessor):
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
@ -206,16 +229,31 @@ class ViTImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size_dict, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
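End-to-end, the new argument is used like this (a sketch assuming a transformers build that includes this change; the default ViT configuration resizes to 224 x 224):

import numpy as np
from transformers import ViTImageProcessor

processor = ViTImageProcessor()                                   # default settings
image = np.random.randint(0, 256, (3, 96, 96), dtype=np.uint8)    # channels-first input
batch = processor(image, input_data_format="channels_first", return_tensors="np")
print(batch["pixel_values"].shape)  # (1, 3, 224, 224)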
|
||||
|
@ -31,6 +31,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -125,6 +126,7 @@ class ViTHybridImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -140,12 +142,23 @@ class ViTHybridImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" not in size:
|
||||
raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
|
||||
output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=False)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
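The `shortest_edge` path above keeps the aspect ratio while the shorter side hits the target; roughly (a standalone sketch, ignoring the library's exact rounding rules):

def shortest_edge_size(height, width, shortest_edge=224):
    scale = shortest_edge / min(height, width)
    return round(height * scale), round(width * scale)

print(shortest_edge_size(480, 640))  # (224, 299)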
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
@ -163,6 +176,7 @@ class ViTHybridImageProcessor(BaseImageProcessor):
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -208,6 +222,12 @@ class ViTHybridImageProcessor(BaseImageProcessor):
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: defaults to the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
@ -250,19 +270,36 @@ class ViTHybridImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image=image, size=crop_size) for image in images]
|
||||
images = [
|
||||
self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
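Alongside resizing, this processor also center-crops; the geometry is just index arithmetic (toy channels-last sketch, not the library routine, which also handles crops larger than the image):

import numpy as np

def center_crop(image, crop_height, crop_width):
    height, width = image.shape[:2]
    top, left = (height - crop_height) // 2, (width - crop_width) // 2
    return image[top : top + crop_height, left : left + crop_width]

print(center_crop(np.zeros((256, 256, 3)), 224, 224).shape)  # (224, 224, 3)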
|
||||
|
@ -33,6 +33,7 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_valid_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
@ -141,6 +142,7 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@ -157,15 +159,26 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
if "shortest_edge" in size:
|
||||
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
|
||||
output_size = get_resize_output_image_size(
|
||||
image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
||||
)
|
||||
elif "height" in size and "width" in size:
|
||||
output_size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
|
||||
def rescale(
|
||||
@ -174,6 +187,7 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
scale: Union[int, float],
|
||||
offset: bool = True,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -195,8 +209,12 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
Whether to scale the image in both negative and positive directions.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
|
||||
rescaled_image = rescale(
|
||||
image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
|
||||
if offset:
|
||||
rescaled_image = rescaled_image - 1
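A quick check of the offset behaviour above: with an assumed scale of 1/127.5 and offset=True, pixel values in [0, 255] land in [-1, 1]:

import numpy as np

pixels = np.array([0.0, 127.5, 255.0])
print(pixels * (1 / 127.5) - 1)  # [-1.  0.  1.]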
|
||||
@ -218,6 +236,7 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Preprocesses a single image."""
|
||||
if do_resize and size is None or resample is None:
|
||||
@ -238,19 +257,22 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample)
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image, size=crop_size)
|
||||
image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, offset=offset)
|
||||
image = self.rescale(image=image, scale=rescale_factor, offset=offset, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std)
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
@ -269,6 +291,7 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@ -312,6 +335,12 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the inferred channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
@ -352,6 +381,7 @@ class VivitImageProcessor(BaseImageProcessor):
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for img in video
|
||||
]
|
||||
|
@ -88,18 +88,21 @@ SUPPORTED_ANNOTATION_FORMATS = (AnnotionFormat.COCO_DETECTION, AnnotionFormat.CO
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
|
||||
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
|
||||
def get_max_height_width(
|
||||
images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all images in a batch.
|
||||
"""
|
||||
input_channel_dimension = infer_channel_dimension_format(images[0])
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if input_channel_dimension == ChannelDimension.FIRST:
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
||||
elif input_channel_dimension == ChannelDimension.LAST:
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
||||
else:
|
||||
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
|
||||
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
@ -137,7 +140,10 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, in
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray, size: Union[int, Tuple[int, int], List[int]], max_size: Optional[int] = None
|
||||
input_image: np.ndarray,
|
||||
size: Union[int, Tuple[int, int], List[int]],
|
||||
max_size: Optional[int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Computes the output image size given the input image size and the desired output size. If the desired output size
|
||||
@ -151,8 +157,10 @@ def get_resize_output_image_size(
|
||||
The desired output size.
|
||||
max_size (`int`, *optional*):
|
||||
The maximum allowed output size.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||
"""
|
||||
image_size = get_image_size(input_image)
|
||||
image_size = get_image_size(input_image, input_data_format)
|
||||
if isinstance(size, (list, tuple)):
|
||||
return size
|
||||
|
||||
@ -222,7 +230,9 @@ def max_across_indices(values: Iterable[Any]) -> List[Any]:
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
|
||||
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
def make_pixel_mask(
|
||||
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
|
||||
@ -232,7 +242,7 @@ def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarr
|
||||
output_size (`Tuple[int, int]`):
|
||||
Output size of the mask.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
mask = np.zeros(output_size, dtype=np.int64)
|
||||
mask[:input_height, :input_width] = 1
|
||||
return mask
|
||||
@ -274,11 +284,16 @@ def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndar
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_detection_annotation
|
||||
def prepare_coco_detection_annotation(image, target, return_segmentation_masks: bool = False):
|
||||
def prepare_coco_detection_annotation(
|
||||
image,
|
||||
target,
|
||||
return_segmentation_masks: bool = False,
|
||||
input_data_format: Optional[Union[ChannelDimension, str]] = None,
|
||||
):
|
||||
"""
|
||||
Convert the target in COCO format into the format expected by DETR.
|
||||
"""
|
||||
image_height, image_width = get_image_size(image)
|
||||
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
||||
|
||||
image_id = target["image_id"]
|
||||
image_id = np.asarray([image_id], dtype=np.int64)
|
||||
@ -363,12 +378,16 @@ def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.prepare_coco_panoptic_annotation with DETR->YOLOS
|
||||
def prepare_coco_panoptic_annotation(
|
||||
image: np.ndarray, target: Dict, masks_path: Union[str, pathlib.Path], return_masks: bool = True
|
||||
image: np.ndarray,
|
||||
target: Dict,
|
||||
masks_path: Union[str, pathlib.Path],
|
||||
return_masks: bool = True,
|
||||
input_data_format: Union[ChannelDimension, str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Prepare a coco panoptic annotation for YOLOS.
|
||||
"""
|
||||
image_height, image_width = get_image_size(image)
|
||||
image_height, image_width = get_image_size(image, channel_dim=input_data_format)
|
||||
annotation_path = pathlib.Path(masks_path) / target["file_name"]
|
||||
|
||||
new_target = {}
|
||||
@ -751,6 +770,7 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
format: Optional[AnnotionFormat] = None,
|
||||
return_segmentation_masks: bool = None,
|
||||
masks_path: Optional[Union[str, pathlib.Path]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Prepare an annotation for feeding into DETR model.
|
||||
@ -759,11 +779,17 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
|
||||
if format == AnnotionFormat.COCO_DETECTION:
|
||||
return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
|
||||
target = prepare_coco_detection_annotation(image, target, return_segmentation_masks)
|
||||
target = prepare_coco_detection_annotation(
|
||||
image, target, return_segmentation_masks, input_data_format=input_data_format
|
||||
)
|
||||
elif format == AnnotionFormat.COCO_PANOPTIC:
|
||||
return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
|
||||
target = prepare_coco_panoptic_annotation(
|
||||
image, target, masks_path=masks_path, return_masks=return_segmentation_masks
|
||||
image,
|
||||
target,
|
||||
masks_path=masks_path,
|
||||
return_masks=return_segmentation_masks,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Format {format} is not supported.")
|
||||
@ -801,11 +827,26 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize the image to the given size. Size can be `min_size` (scalar) or a `(height, width)` tuple. If size is an
int, the smaller edge of the image will be matched to this number.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary containing the size to resize to. Can contain the keys `shortest_edge` and `longest_edge` or
|
||||
`height` and `width`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
if "max_size" in kwargs:
|
||||
logger.warning_once(
|
||||
@ -817,7 +858,9 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
max_size = None
|
||||
size = get_size_dict(size, max_size=max_size, default_to_square=False)
|
||||
if "shortest_edge" in size and "longest_edge" in size:
|
||||
size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
|
||||
size = get_resize_output_image_size(
|
||||
image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
|
||||
)
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
@ -825,7 +868,9 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
|
||||
f" {size.keys()}."
|
||||
)
|
||||
image = resize(image, size=size, resample=resample, data_format=data_format)
|
||||
image = resize(
|
||||
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
return image
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
|
||||
@ -844,7 +889,11 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
|
||||
def rescale(
|
||||
self, image: np.ndarray, rescale_factor: float, data_format: Optional[Union[str, ChannelDimension]] = None
|
||||
self,
|
||||
image: np.ndarray,
|
||||
rescale_factor: float,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescale the image by the given factor. image = image * rescale_factor.
|
||||
@ -859,8 +908,13 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. If unset, is inferred from the input image. Can be
|
||||
one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
"""
|
||||
return rescale(image, rescale_factor, data_format=data_format)
|
||||
return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
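For reference, rescaling is just a multiplication; with an assumed factor of 1/255 a uint8 image ends up in [0, 1]:

import numpy as np

print((np.array([0, 128, 255], dtype=np.uint8) * (1 / 255)).round(3))  # [0.    0.502 1.   ]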
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
|
||||
def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
|
||||
@ -877,28 +931,36 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
output_size: Tuple[int, int],
|
||||
constant_values: Union[float, Iterable[float]] = 0,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pad an image with zeros to the given size.
|
||||
"""
|
||||
input_height, input_width = get_image_size(image)
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
output_height, output_width = output_size
|
||||
|
||||
pad_bottom = output_height - input_height
|
||||
pad_right = output_width - input_width
|
||||
padding = ((0, pad_bottom), (0, pad_right))
|
||||
padded_image = pad(
|
||||
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
|
||||
image,
|
||||
padding,
|
||||
mode=PaddingMode.CONSTANT,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
return padded_image
|
||||
|
||||
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
|
||||
def pad(
|
||||
self,
|
||||
images: List[np.ndarray],
|
||||
constant_values: Union[float, Iterable[float]] = 0,
|
||||
return_pixel_mask: bool = False,
|
||||
return_pixel_mask: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
|
||||
@ -920,17 +982,28 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
pad_size = get_max_height_width(images)
|
||||
pad_size = get_max_height_width(images, input_data_format=input_data_format)
|
||||
|
||||
padded_images = [
|
||||
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
|
||||
self._pad_image(
|
||||
image,
|
||||
pad_size,
|
||||
constant_values=constant_values,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": padded_images}
|
||||
|
||||
if return_pixel_mask:
|
||||
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
|
||||
masks = [
|
||||
make_pixel_mask(image=image, output_size=pad_size, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data["pixel_mask"] = masks
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
@ -953,6 +1026,7 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
format: Optional[Union[str, AnnotionFormat]] = None,
|
||||
return_tensors: Optional[Union[TensorType, str]] = None,
|
||||
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -1000,6 +1074,12 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
Type of tensors to return. If `None`, will return the list of images.
|
||||
data_format (`str` or `ChannelDimension`, *optional*, defaults to self.data_format):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
if "pad_and_return_pixel_mask" in kwargs:
|
||||
logger.warning_once(
|
||||
@ -1084,13 +1164,22 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
# All transformations expect numpy arrays
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
|
||||
if annotations is not None:
|
||||
prepared_images = []
|
||||
prepared_annotations = []
|
||||
for image, target in zip(images, annotations):
|
||||
target = self.prepare_annotation(
|
||||
image, target, format, return_segmentation_masks=return_segmentation_masks, masks_path=masks_path
|
||||
image,
|
||||
target,
|
||||
format,
|
||||
return_segmentation_masks=return_segmentation_masks,
|
||||
masks_path=masks_path,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
prepared_images.append(image)
|
||||
prepared_annotations.append(target)
|
||||
@ -1103,22 +1192,31 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
if annotations is not None:
|
||||
resized_images, resized_annotations = [], []
|
||||
for image, target in zip(images, annotations):
|
||||
orig_size = get_image_size(image)
|
||||
resized_image = self.resize(image, size=size, max_size=max_size, resample=resample)
|
||||
resized_annotation = self.resize_annotation(target, orig_size, get_image_size(resized_image))
|
||||
orig_size = get_image_size(image, input_data_format)
|
||||
resized_image = self.resize(
|
||||
image, size=size, max_size=max_size, resample=resample, input_data_format=input_data_format
|
||||
)
|
||||
resized_annotation = self.resize_annotation(
|
||||
target, orig_size, get_image_size(resized_image, input_data_format)
|
||||
)
|
||||
resized_images.append(resized_image)
|
||||
resized_annotations.append(resized_annotation)
|
||||
images = resized_images
|
||||
annotations = resized_annotations
|
||||
del resized_images, resized_annotations
|
||||
else:
|
||||
images = [self.resize(image, size=size, resample=resample) for image in images]
|
||||
images = [
|
||||
self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [self.rescale(image, rescale_factor) for image in images]
|
||||
images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
|
||||
|
||||
if do_normalize:
|
||||
images = [self.normalize(image, image_mean, image_std) for image in images]
|
||||
images = [
|
||||
self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
|
||||
]
|
||||
if annotations is not None:
|
||||
annotations = [
|
||||
self.normalize_annotation(annotation, get_image_size(image))
|
||||
@ -1126,9 +1224,12 @@ class YolosImageProcessor(BaseImageProcessor):
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
data = self.pad(images, data_format=data_format)
|
||||
data = self.pad(images, data_format=data_format, input_data_format=input_data_format)
|
||||
else:
|
||||
images = [to_channel_dimension_format(image, data_format) for image in images]
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
data = {"pixel_values": images}
|
||||
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
@ -70,7 +70,7 @@ class BlipImageProcessingTester(unittest.TestCase):
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return 3, self.size["height"], self.size["width"]
|
||||
return self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
@ -135,3 +135,11 @@ class BlipImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.Tes
|
||||
@unittest.skip("BlipImageProcessor does not support 4 channels yet") # FIXME Amy
|
||||
def test_call_pytorch(self):
|
||||
return super().test_call_torch()
|
||||
|
||||
@unittest.skip("BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||
def test_call_pil(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@ -26,6 +26,10 @@ if is_vision_available():
|
||||
from transformers import ChineseCLIPImageProcessor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
|
||||
class ChineseCLIPImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
@ -120,6 +124,10 @@ class ChineseCLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
|
||||
@unittest.skip("ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
@ -152,3 +160,7 @@ class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unitt
|
||||
@unittest.skip("ChineseCLIPImageProcessor does not support 4 channels yet") # FIXME Amy
|
||||
def test_call_pytorch(self):
|
||||
return super().test_call_torch()
|
||||
|
||||
@unittest.skip("ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
@ -337,6 +337,11 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
def test_call_numpy(self):
|
||||
self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
self.image_processing_class.num_channels = 4
|
||||
self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
|
||||
self.image_processing_class.num_channels = 3
|
||||
|
||||
def test_call_pytorch(self):
|
||||
self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True})
|
||||
|
||||
|
@ -144,3 +144,18 @@ class GLPNImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertTrue(tuple(encoded_images.shape) == (1, *expected_output_image_shape))
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processing_class.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input (GLPNImageProcessor doesn't support batching)
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertTrue(tuple(encoded_images.shape) == (1, *expected_output_image_shape))
|
||||
self.image_processing_class.num_channels = 3
|
||||
|
@ -198,6 +198,10 @@ class ImageGPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
|
||||
|
||||
@unittest.skip("ImageGPT assumes clusters for 3 channels")
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
# Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
|
@ -222,6 +222,40 @@ class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
(self.image_processor_tester.batch_size, max_patch, expected_hidden_dim),
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Initialize image_processor
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
expected_hidden_dim = (
|
||||
(self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"])
|
||||
* self.image_processor_tester.num_channels
|
||||
) + 2
|
||||
|
||||
for max_patch in self.image_processor_tester.max_patches:
|
||||
# Test not batched input
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_first"
|
||||
).flattened_patches
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(1, max_patch, expected_hidden_dim),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processor(
|
||||
image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_first"
|
||||
).flattened_patches
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(self.image_processor_tester.batch_size, max_patch, expected_hidden_dim),
|
||||
)
|
||||
self.image_processor_tester.num_channels = 3
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processor
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
@ -318,3 +352,7 @@ class Pix2StructImageProcessingTestFourChannels(ImageProcessingTestMixin, unitte
|
||||
@unittest.skip("Pix2StructImageProcessor does not support 4 channels yet") # FIXME Amy
|
||||
def test_call_pytorch(self):
|
||||
return super().test_call_torch()
|
||||
|
||||
@unittest.skip("Pix2StructImageProcessor does treat numpy and PIL 4 channel images consistently") # FIXME Amy
|
||||
def test_call_numpy_4_channels(self):
|
||||
return super().test_call_torch()
|
||||
|
@ -147,6 +147,24 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Swin2SRImageProcessor does not support batched input
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(
|
||||
image_inputs[0], return_tensors="pt", input_data_format="channels_first"
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
self.image_processor_tester.num_channels = 3
|
||||
|
||||
# Swin2SRImageProcessor does not support batched input
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
|
@ -217,6 +217,47 @@ class TvltImageProcessorTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Initialize image_processor
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
video_inputs = prepare_video_inputs(self.image_processor_tester, equal_resolution=False, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, list)
|
||||
self.assertIsInstance(video[0], np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processor(
|
||||
video_inputs[0], return_tensors="pt", input_data_format="channels_first", image_mean=0, image_std=1
|
||||
).pixel_values
|
||||
self.assertEqual(
|
||||
encoded_videos.shape,
|
||||
(
|
||||
1,
|
||||
self.image_processor_tester.num_frames,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processor(
|
||||
video_inputs, return_tensors="pt", input_data_format="channels_first", image_mean=0, image_std=1
|
||||
).pixel_values
|
||||
self.assertEqual(
|
||||
encoded_videos.shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_frames,
|
||||
self.image_processor_tester.num_channels,
|
||||
self.image_processor_tester.crop_size["height"],
|
||||
self.image_processor_tester.crop_size["width"],
|
||||
),
|
||||
)
|
||||
self.image_processor_tester.num_channels = 3
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processor
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
|
@ -165,6 +165,33 @@ class VideoMAEImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
tuple(encoded_videos.shape), (self.image_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=False, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, list)
|
||||
self.assertIsInstance(video[0], np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(
|
||||
video_inputs[0], return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first"
|
||||
).pixel_values
|
||||
expected_output_video_shape = self.image_processor_tester.expected_output_image_shape([encoded_videos[0]])
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(
|
||||
video_inputs, return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first"
|
||||
).pixel_values
|
||||
expected_output_video_shape = self.image_processor_tester.expected_output_image_shape(encoded_videos)
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape), (self.image_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
self.image_processor_tester.num_channels = 3
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
|
@ -179,6 +179,33 @@ class VivitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
tuple(encoded_videos.shape), (self.image_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=False, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, list)
|
||||
self.assertIsInstance(video[0], np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(
|
||||
video_inputs[0], return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first"
|
||||
).pixel_values
|
||||
expected_output_video_shape = self.image_processor_tester.expected_output_image_shape([encoded_videos[0]])
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(
|
||||
video_inputs, return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first"
|
||||
).pixel_values
|
||||
expected_output_video_shape = self.image_processor_tester.expected_output_image_shape(encoded_videos)
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape), (self.image_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
self.image_processor_tester.num_channels = 3
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
|
@ -252,3 +252,36 @@ class ImageProcessingTestMixin:
|
||||
tuple(encoded_images.shape),
|
||||
(self.image_processor_tester.batch_size, *expected_output_image_shape),
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Test that can process images which have an arbitrary number of channels
|
||||
# Initialize image_processing
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0],
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_first",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processor(
|
||||
image_inputs,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_first",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
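Why the test passes `input_data_format="channels_first"` together with `image_mean=0, image_std=1`: with 4 channels the format can no longer be inferred from the shape, and the default 3-value ImageNet mean/std would not fit a 4-channel axis. A minimal sketch of the broadcasting point (illustrative values, not the library's normalize code):

import numpy as np

image = np.zeros((4, 8, 8))            # 4 channels, channels-first
mean3 = np.array([0.5, 0.5, 0.5])      # per-channel stats defined for 3 channels only
try:
    _ = image - mean3[:, None, None]   # 3 values cannot broadcast over 4 channels
except ValueError as err:
    print("broadcast error:", err)
print(((image - 0) / 1).shape)         # scalar mean/std work for any channel count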