Add Idefics2/3 and SmolVLM Fast image processors + improvements for fast image processors (#38157)

* add working idefics2 fast and improvements for fast nested images processing
* add fast image processors idefics 3 and smolvlm
* cleanup tests
* fic doc idefics2
* PR review and fix issues after merge
* Force providing disable_grouping to group_images_by_shape
* simplify group_images_by_shape
* fix modular
* Fix nits after review

This commit is contained in:
parent 1a96127e46
commit d29482cc91
@@ -162,7 +162,7 @@ To load and run a model using Flash Attention-2, simply change the code snippet

```diff
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",
).to(device)
```

@@ -184,7 +184,7 @@ Quantizing a model is as simple as passing a `quantization_config` to the model

+ )
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
+    torch_dtype=torch.float16,
+    quantization_config=quantization_config,
).to(device)
```
@@ -218,7 +218,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h

[[autodoc]] Idefics2ImageProcessor
    - preprocess

+## Idefics2ImageProcessorFast
+
+[[autodoc]] Idefics2ImageProcessorFast
+    - preprocess
+
## Idefics2Processor

[[autodoc]] Idefics2Processor
    - __call__
@@ -80,6 +80,9 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)

[[autodoc]] Idefics3ImageProcessor
    - preprocess

+## Idefics3ImageProcessorFast
+
+[[autodoc]] Idefics3ImageProcessorFast
+    - preprocess
+
## Idefics3Processor

[[autodoc]] Idefics3Processor
@@ -32,7 +32,7 @@ SmolVLM2 is an adaptation of the Idefics3 model with two main differences:

Input images are processed either by upsampling (if resizing is enabled) or at their original resolution. The resizing behavior depends on two parameters: `do_resize` and `size`.

Videos should not be upsampled.

If `do_resize` is set to `True`, the model resizes images so that the longest edge is 4*512 pixels by default.
The default resizing behavior can be customized by passing a dictionary to the `size` parameter. For example, `{"longest_edge": 4 * 512}` is the default, but you can change it to a different value if needed.
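
A minimal sketch of overriding that default through the image processor (the checkpoint name is an assumption for illustration; any SmolVLM checkpoint works the same way):

```python
from PIL import Image
import requests
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Resize so the longest edge is 3 * 512 px instead of the 4 * 512 default.
inputs = processor(image, size={"longest_edge": 3 * 512}, return_tensors="pt")
```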

@@ -192,11 +192,14 @@ print(generated_texts[0])

[[autodoc]] SmolVLMForConditionalGeneration
    - forward

## SmolVLMImageProcessor

[[autodoc]] SmolVLMImageProcessor
    - preprocess

+## SmolVLMImageProcessorFast
+
+[[autodoc]] SmolVLMImageProcessorFast
+    - preprocess
+
## SmolVLMVideoProcessor

[[autodoc]] SmolVLMVideoProcessor
    - preprocess

@@ -396,7 +396,7 @@ def add_fast_image_processor_file(

    content_header = get_fast_image_processing_content_header(content_base_file)
    content_base_file = (
-        f"@auto_docstring(\n"
+        f"@auto_docstring\n"
        f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
        "    # This generated class can be used as a starting point for the fast image processor.\n"
        "    # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"

@@ -184,6 +184,7 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
    data_format: Optional[ChannelDimension]
    input_data_format: Optional[Union[str, ChannelDimension]]
    device: Optional["torch.device"]
+    disable_grouping: Optional[bool]


@auto_docstring

@@ -480,18 +481,35 @@ class BaseImageProcessorFast(BaseImageProcessor):
    ) -> list["torch.Tensor"]:
        """
        Prepare the input images for processing.

        Args:
            images (`ImageInput`):
                The input images to process.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the images to RGB.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The input data format of the images.
            device (`torch.device`, *optional*):
                The device to put the processed images on.

        Returns:
            List[`torch.Tensor`]: The processed images.
        """
        # Get structured images (potentially nested)
        images = self._prepare_images_structure(images)
-        process_image_fn = partial(
-            self._process_image,
-            do_convert_rgb=do_convert_rgb,
-            input_data_format=input_data_format,
-            device=device,
-        )
-        # todo: yoni - check if we can parallelize this efficiently
-        processed_images = []
-        for image in images:
-            processed_images.append(process_image_fn(image))
+        process_image_partial = partial(
+            self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+        )
+
+        # Check if we have nested structure, assuming the nesting is consistent
+        has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))
+
+        if has_nested_structure:
+            processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
+        else:
+            processed_images = [process_image_partial(img) for img in images]

        return processed_images

@@ -621,11 +639,12 @@ class BaseImageProcessorFast(BaseImageProcessor):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -635,7 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
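
Since `disable_grouping` is now part of `DefaultFastImageProcessorKwargs`, callers can pass it straight through `preprocess` on any fast image processor. A minimal sketch (the processor instance and images are assumptions):

```python
# Force per-image processing, e.g. to compare against shape-grouped batching.
outputs = image_processor(images, disable_grouping=True, return_tensors="pt")
```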

@@ -656,47 +675,3 @@ class BaseImageProcessorFast(BaseImageProcessor):
        encoder_dict.pop("_valid_processor_keys", None)
        encoder_dict.pop("_valid_kwargs_names", None)
        return encoder_dict
-
-
-class SemanticSegmentationMixin:
-    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
-        """
-        Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
-
-        Args:
-            outputs ([`MobileNetV2ForSemanticSegmentation`]):
-                Raw outputs of the model.
-            target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
-                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
-                predictions will not be resized.
-
-        Returns:
-            semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
-            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
-            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
-        """
-        logits = outputs.logits
-
-        # Resize logits and compute semantic segmentation maps
-        if target_sizes is not None:
-            if len(logits) != len(target_sizes):
-                raise ValueError(
-                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
-                )
-
-            # if is_torch_tensor(target_sizes):
-            #     target_sizes = target_sizes.numpy()
-
-            semantic_segmentation = []
-
-            for idx in range(len(logits)):
-                resized_logits = torch.nn.functional.interpolate(
-                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
-                )
-                semantic_map = resized_logits[0].argmax(dim=0)
-                semantic_segmentation.append(semantic_map)
-        else:
-            semantic_segmentation = logits.argmax(dim=1)
-            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
-        return semantic_segmentation

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from collections import defaultdict
from collections.abc import Collection, Iterable
from math import ceil
from typing import Optional, Union

@@ -841,37 +842,128 @@ def _cast_tensor_to_float(x):
    return x.float()


+def _group_images_by_shape(nested_images, is_nested: bool = False):
+    """Helper function to flatten a single level of nested image structures and group by shape."""
+    grouped_images = defaultdict(list)
+    grouped_images_index = {}
+    nested_images = [nested_images] if not is_nested else nested_images
+    for i, sublist in enumerate(nested_images):
+        for j, image in enumerate(sublist):
+            key = (i, j) if is_nested else j
+            shape = image.shape[1:]
+            grouped_images[shape].append(image)
+            grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
+
+    return grouped_images, grouped_images_index
+
+
+def _reconstruct_nested_structure(indices, processed_images):
+    """Helper function to reconstruct a single level nested structure."""
+    # Find the maximum outer index
+    max_outer_idx = max(idx[0] for idx in indices.keys())
+
+    # Create the outer list
+    result = [None] * (max_outer_idx + 1)
+
+    # Group indices by outer index
+    nested_indices = defaultdict(list)
+    for i, j in indices.keys():
+        nested_indices[i].append(j)
+
+    for i in range(max_outer_idx + 1):
+        if i in nested_indices:
+            inner_max_idx = max(nested_indices[i])
+            inner_list = [None] * (inner_max_idx + 1)
+            for j in range(inner_max_idx + 1):
+                if (i, j) in indices:
+                    shape, idx = indices[(i, j)]
+                    inner_list[j] = processed_images[shape][idx]
+            result[i] = inner_list
+
+    return result
+
+
def group_images_by_shape(
-    images: list["torch.Tensor"],
-) -> tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[int, tuple[tuple[int, int], int]]]:
+    images: Union[list["torch.Tensor"], "torch.Tensor"],
+    disable_grouping: bool,
+    is_nested: bool = False,
+) -> tuple[
+    dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
+]:
    """
    Groups images by shape.
    Returns a dictionary with the shape as key and a list of images with that shape as value,
    and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.

+    The function supports both flat lists of tensors and nested structures.
+    The input must be either all flat or all nested, not a mix of both.
+
+    Args:
+        images (Union[list["torch.Tensor"], "torch.Tensor"]):
+            A list of images or a single tensor
+        disable_grouping (bool):
+            Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise.
+            This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
+        is_nested (bool, *optional*, defaults to False):
+            Whether the images are nested.
+
+    Returns:
+        tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
+            - A dictionary with shape as key and list of images with that shape as value
+            - A dictionary mapping original indices to (shape, index) tuples
    """
-    grouped_images = {}
-    grouped_images_index = {}
-    for i, image in enumerate(images):
-        shape = image.shape[1:]
-        if shape not in grouped_images:
-            grouped_images[shape] = []
-        grouped_images[shape].append(image)
-        grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
-    # stack images with the same shape
-    grouped_images = {shape: torch.stack(images, dim=0) for shape, images in grouped_images.items()}
+    # If disable grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise.
+    if disable_grouping is None:
+        device = images[0][0].device if is_nested else images[0].device
+        disable_grouping = device == "cpu"
+
+    if disable_grouping:
+        if is_nested:
+            return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
+                (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
+            }
+        else:
+            return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}
+
+    # Handle single level nested structure
+    grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)
+
+    # Stack images with the same shape
+    grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}

    return grouped_images, grouped_images_index


def reorder_images(
-    processed_images: dict[tuple[int, int], "torch.Tensor"], grouped_images_index: dict[int, tuple[int, int]]
-) -> list["torch.Tensor"]:
+    processed_images: dict[tuple[int, int], "torch.Tensor"],
+    grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
+    is_nested: bool = False,
+) -> Union[list["torch.Tensor"], "torch.Tensor"]:
    """
-    Reconstructs a list of images in the original order.
+    Reconstructs images in the original order, preserving the original structure (nested or not).
+    The input structure is either all flat or all nested.
+
+    Args:
+        processed_images (dict[tuple[int, int], "torch.Tensor"]):
+            Dictionary mapping shapes to batched processed images.
+        grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]):
+            Dictionary mapping original indices to (shape, index) tuples.
+        is_nested (bool, *optional*, defaults to False):
+            Whether the images are nested. Cannot be inferred from the input, as some processing functions output
+            nested images even with non-nested inputs, e.g. functions splitting images into patches. We thus can't
+            deduce is_nested from the input.
+
+    Returns:
+        Union[list["torch.Tensor"], "torch.Tensor"]:
+            Images in the original structure.
    """
-    return [
-        processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
-        for i in range(len(grouped_images_index))
-    ]
+    if not is_nested:
+        return [
+            processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
+            for i in range(len(grouped_images_index))
+        ]
+
+    return _reconstruct_nested_structure(grouped_images_index, processed_images)
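
A round-trip sketch of how these two helpers compose (shapes are made up; the import path follows the `...image_processing_utils_fast` imports used elsewhere in this PR):

```python
import torch
from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images

# Two 224x224 images and one 336x336 image, channels-first.
images = [torch.rand(3, 224, 224), torch.rand(3, 224, 224), torch.rand(3, 336, 336)]

# With grouping enabled, the two 224x224 images are stacked into a single
# (2, 3, 224, 224) tensor so per-shape ops can run batched.
grouped, index = group_images_by_shape(images, disable_grouping=False)

# Apply any batched transform per group (identity here), then restore order.
processed = {shape: batch for shape, batch in grouped.items()}
restored = reorder_images(processed, index)
assert [im.shape for im in restored] == [im.shape for im in images]
```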


class NumpyToTensor:

@@ -95,8 +95,8 @@ else:
        ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
        ("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
        ("idefics", ("IdeficsImageProcessor",)),
-        ("idefics2", ("Idefics2ImageProcessor",)),
-        ("idefics3", ("Idefics3ImageProcessor",)),
+        ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
+        ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
        ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("imagegpt", ("ImageGPTImageProcessor",)),
        ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),

@@ -148,6 +148,7 @@ else:
        ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
        ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
        ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
+        ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
        ("superglue", ("SuperGlueImageProcessor",)),
        ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),

@@ -27,13 +27,7 @@ from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
-from ...utils import (
-    FEATURE_EXTRACTOR_NAME,
-    PROCESSOR_NAME,
-    VIDEO_PROCESSOR_NAME,
-    cached_file,
-    logging,
-)
+from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (

@@ -118,6 +112,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("shieldgemma2", "ShieldGemma2Processor"),
        ("siglip", "SiglipProcessor"),
        ("siglip2", "Siglip2Processor"),
+        ("smolvlm", "SmolVLMProcessor"),
        ("speech_to_text", "Speech2TextProcessor"),
        ("speech_to_text_2", "Speech2Text2Processor"),
        ("speecht5", "SpeechT5Processor"),

@@ -105,6 +105,7 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:

@@ -112,7 +113,7 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
            images = self.reduce_label(images)

        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -122,7 +123,7 @@ class BeitImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -223,6 +223,7 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
        images: list["torch.Tensor"],
        constant_values: Union[float, Iterable[float]] = 0,
        return_pixel_mask: bool = True,
+        disable_grouping: Optional[bool] = False,
    ) -> tuple:
        """
        Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width

@@ -235,6 +236,8 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
                The value to use for the padding if `mode` is `"constant"`.
            return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether to return a pixel mask.
+            disable_grouping (`bool`, *optional*, defaults to `False`):
+                Whether to disable grouping of images by size.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.

@@ -245,7 +248,7 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
        """
        pad_size = get_max_height_width(images)

-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_masks_grouped = {}
        for shape, stacked_images in grouped_images.items():

@@ -283,11 +286,12 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -299,7 +303,7 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -314,7 +318,9 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):

        data = {}
        if do_pad:
-            processed_images, processed_masks = self.pad(processed_images, return_pixel_mask=True)
+            processed_images, processed_masks = self.pad(
+                processed_images, return_pixel_mask=True, disable_grouping=disable_grouping
+            )
            processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
            data["pixel_mask"] = processed_masks

@@ -19,11 +19,7 @@ from typing import Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
-from ...image_transforms import (
-    get_resize_output_image_size,
-    resize,
-    to_channel_dimension_format,
-)
+from ...image_transforms import get_resize_output_image_size, resize, to_channel_dimension_format
from ...image_utils import (
    ChannelDimension,
    ImageInput,

@@ -153,10 +153,11 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -168,7 +169,7 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -17,17 +17,8 @@
from typing import TYPE_CHECKING, Optional, Union

from ...image_processing_base import BatchFeature
-from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
-    group_images_by_shape,
-    reorder_images,
-)
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    PILImageResampling,
-    SizeDict,
-)
+from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling, SizeDict
from ...utils import (
    TensorType,
    auto_docstring,

@@ -85,10 +76,11 @@ class DepthProImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        # Group images by size for batched scaling
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize

@@ -16,19 +16,9 @@

from typing import Optional, Union

-from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
-    BatchFeature,
-    DefaultFastImageProcessorKwargs,
-)
+from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import group_images_by_shape, reorder_images
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    ImageInput,
-    PILImageResampling,
-    SizeDict,
-)
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
    TensorType,

@@ -230,11 +220,12 @@ class DonutImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_align_long_axis:

@@ -254,7 +245,7 @@ class DonutImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -176,18 +176,19 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
-        return_tensors: Optional[Union[str, TensorType]],
        keep_aspect_ratio: bool,
        ensure_multiple_of: Optional[int],
        do_pad: bool,
        size_divisor: Optional[int],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        if do_reduce_labels:
            images = self.reduce_label(images)

        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -203,7 +204,7 @@ class DPTImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -224,18 +224,19 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
-        return_tensors: Optional[Union[str, TensorType]],
        keep_aspect_ratio: bool,
        ensure_multiple_of: Optional[int],
        do_pad: bool,
        size_divisor: Optional[int],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        if do_reduce_labels:
            images = self.reduce_label(images)

        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -251,7 +252,7 @@ class DPTImageProcessorFast(BeitImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -17,19 +17,9 @@
from functools import lru_cache
from typing import Optional, Union

-from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
-    BatchFeature,
-    DefaultFastImageProcessorKwargs,
-)
+from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import group_images_by_shape, reorder_images
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    ImageInput,
-    PILImageResampling,
-    SizeDict,
-)
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
    TensorType,

@@ -181,11 +171,12 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
        include_top: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -195,7 +186,7 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -367,10 +367,11 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
        do_map_pixels: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> "torch.Tensor":
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -380,7 +381,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -432,6 +433,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
        codebook_do_normalize: Optional[bool],
        codebook_image_mean: Optional[Union[float, list[float]]],
        codebook_image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:

@@ -448,6 +450,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
            do_map_pixels=False,
            image_mean=image_mean,
            image_std=image_std,
+            disable_grouping=disable_grouping,
            return_tensors=return_tensors,
        )
        data = {

@@ -468,6 +471,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
            do_map_pixels=codebook_do_map_pixels,
            image_mean=codebook_image_mean,
            image_std=codebook_image_std,
+            disable_grouping=disable_grouping,
            return_tensors=return_tensors,
        )
        data["codebook_pixel_values"] = codebook_processed_images

@@ -25,12 +25,7 @@ from ...image_processing_utils_fast import (
    group_images_by_shape,
    reorder_images,
)
-from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
-    ImageInput,
-    SizeDict,
-)
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, SizeDict
from ...processing_utils import Unpack
from ...utils import (
    TensorType,

@@ -205,12 +200,13 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        # Group images by size for batched processing
        processed_images_grouped = {}
        num_crops_grouped = {}
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        for shape_images, stacked_images in grouped_images.items():
            if do_pan_and_scan:
                pas_images, num_crops = self._process_images_for_pan_and_scan(

@@ -224,7 +220,9 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
                stacked_images = [stacked_images] + pas_images
                # Group images by size for batched resizing (this will typically group thumbnails together and cropped patches together)
                processed_image_patches_grouped = {}
-                grouped_image_patches, grouped_image_patches_index = group_images_by_shape(stacked_images)
+                grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+                    stacked_images, disable_grouping=disable_grouping
+                )
                for shape, stacked_image_patches in grouped_image_patches.items():
                    stacked_image_patches = self.resize(
                        image=stacked_image_patches,

@@ -254,7 +252,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize

@@ -23,13 +23,7 @@ from ...image_processing_utils_fast import (
    group_images_by_shape,
    reorder_images,
)
-from ...image_utils import (
-    OPENAI_CLIP_MEAN,
-    OPENAI_CLIP_STD,
-    ImageInput,
-    PILImageResampling,
-    SizeDict,
-)
+from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
    TensorType,

@@ -177,10 +171,11 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        if crop_to_patches:
-            grouped_images, grouped_images_index = group_images_by_shape(images)
+            grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
            processed_images_grouped = {}
            num_patches = {}
            for shape, stacked_images in grouped_images.items():

@@ -200,7 +195,7 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
            num_patches = [1] * len(images)

        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:

@@ -210,7 +205,7 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:

@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
    from .configuration_idefics2 import *
    from .image_processing_idefics2 import *
+    from .image_processing_idefics2_fast import *
    from .modeling_idefics2 import *
    from .processing_idefics2 import *
else:

@@ -0,0 +1,318 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Union

import torch

from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    BatchFeature,
    DefaultFastImageProcessorKwargs,
    SizeDict,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ImageInput,
    PILImageResampling,
    make_nested_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
from .image_processing_idefics2 import convert_to_rgb


if is_torchvision_available():
    from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)


def get_resize_output_image_size(image: "torch.Tensor", size: SizeDict) -> tuple[int, int]:
    """
    Get the output size of the image after resizing given a dictionary specifying the max and min sizes.

    Args:
        image (`torch.Tensor`):
            Image to resize.
        size (`SizeDict`):
            Size of the output image containing the keys "shortest_edge" and "longest_edge".

    Returns:
        The output size of the image after resizing.
    """
    height, width = image.size()[-2:]

    min_len = size.shortest_edge
    max_len = size.longest_edge
    aspect_ratio = width / height

    if width >= height and width > max_len:
        width = max_len
        height = int(width / aspect_ratio)
    elif height > width and height > max_len:
        height = max_len
        width = int(height * aspect_ratio)
    height = max(height, min_len)
    width = max(width, min_len)
    return height, width
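
# Worked example (values assumed for illustration): with the class defaults below,
# size = {"shortest_edge": 378, "longest_edge": 980}, a 2000x1000 (height x width)
# image has height > width and height > 980, so height becomes 980 and width
# becomes int(980 * (1000 / 2000)) = 490; both exceed 378, giving (980, 490).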


def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    image_sizes = []
    for images in images_list:
        for image in images:
            image_sizes.append(image.size()[-2:])

    max_height = max(size[0] for size in image_sizes)
    max_width = max(size[1] for size in image_sizes)
    return (max_height, max_width)


def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor":
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

    Args:
        image (`torch.Tensor`):
            Image to make the pixel mask for.
        output_size (`Tuple[int, int]`):
            Output size of the mask.
    """
    input_height, input_width = image.size()[-2:]
    mask = torch.zeros(output_size, dtype=torch.int64, device=image.device)
    mask[:input_height, :input_width] = 1
    return mask


class Idefics2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_image_splitting (`bool`, *optional*, defaults to `False`):
        Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image.
    do_pad (`bool`, *optional*, defaults to `True`):
        Whether to pad images to the largest height and width in the batch.
    """

    do_image_splitting: Optional[bool]
    do_pad: Optional[bool]


@auto_docstring
class Idefics2ImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_pad = True
    do_convert_rgb = True
    do_image_splitting = False
    size = {"shortest_edge": 378, "longest_edge": 980}
    model_input_names = ["pixel_values", "pixel_attention_mask"]
    valid_kwargs = Idefics2FastImageProcessorKwargs

    def convert_to_rgb(self, image: ImageInput) -> ImageInput:
        """
        Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
        as is.
        """
        return convert_to_rgb(image)

    def resize(
        self, image: torch.Tensor, size: SizeDict, interpolation: Optional["F.InterpolationMode"] = None, **kwargs
    ) -> torch.Tensor:
        """
        Resize an image using torchvision's functional resize.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR

        if size.shortest_edge and size.longest_edge:
            new_size = get_resize_output_image_size(image, size)
        elif size.height and size.width:
            new_size = (size.height, size.width)
        else:
            raise ValueError("Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys.")

        image = F.resize(image, size=new_size, interpolation=interpolation, **kwargs)
        return image

    def _prepare_images_structure(
        self,
        images: ImageInput,
    ) -> ImageInput:
        """
        Prepare a nested images structure for processing.
        """
        return make_nested_list_of_images(images)

    def split_images(
        self,
        images: "torch.Tensor",
    ) -> list["torch.Tensor"]:
        """
        Split a batch of images into 4 equal sub-images, and concatenate that sequence with the original image.
        """
        height, width = images.size()[-2:]

        mid_width = width // 2
        mid_height = height // 2

        batch_split_images = [
            images[..., :mid_height, :mid_width],
            images[..., :mid_height, mid_width:],
            images[..., mid_height:, :mid_width],
            images[..., mid_height:, mid_width:],
            images,
        ]

        # transpose the batch dimension to the first dimension
        batch_split_images = [[image[i] for image in batch_split_images] for i in range(len(batch_split_images[0]))]
        return batch_split_images
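
    # For example (shapes assumed for illustration): a stacked batch of 2 images of
    # shape (3, 100, 100) yields, per input image, four (3, 50, 50) quadrants plus
    # the original (3, 100, 100) image, i.e. a list of 2 lists of 5 tensors.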

    def pad(
        self, image: "torch.Tensor", padded_size: tuple[int, int], fill: int = 0
    ) -> tuple["torch.Tensor", "torch.Tensor"]:
        """
        Pad an image to the specified size and create the corresponding pixel mask.
        """
        original_size = image.shape[-2:]
        padding_bottom = padded_size[0] - original_size[0]
        padding_right = padded_size[1] - original_size[1]

        if padding_bottom < 0 or padding_right < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                f"original size. Got padded size: {padded_size}, original size: {original_size}."
            )

        # Only pad if necessary
        if original_size != padded_size:
            # torchvision's pad takes a 4-element tuple for 2D padding: (left, top, right, bottom)
            padding = (0, 0, padding_right, padding_bottom)
            # Use constant padding to match slow implementation
            image = F.pad(image, padding, fill=fill, padding_mode="constant")

        # Create pixel mask to match the slow implementation
        pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
        pixel_mask[: original_size[0], : original_size[1]] = 1

        return image, pixel_mask

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Idefics2FastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: list[list["torch.Tensor"]],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: Optional[bool],
        do_image_splitting: Optional[bool],
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Process a batch of images for the model.
        """
        grouped_images, grouped_images_index = group_images_by_shape(
            images, is_nested=True, disable_grouping=disable_grouping
        )
        split_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_image_splitting:
                stacked_images = self.split_images(stacked_images)
            split_images_grouped[shape] = stacked_images
        split_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
        if do_image_splitting:
            # flatten the doubly nested list to a nested list
            for i, group_images in enumerate(split_images):
                split_images[i] = [image for sublist in group_images for image in sublist]

        # Group images by size for further processing
        grouped_images, grouped_images_index = group_images_by_shape(
            split_images, is_nested=True, disable_grouping=disable_grouping
        )
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(
            resized_images, is_nested=True, disable_grouping=disable_grouping
        )
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)

        if do_pad:
            # Get max images per batch
            max_num_images = max(len(images_) for images_ in processed_images)
            max_height, max_width = get_max_height_width(processed_images)

            processed_images_padded = torch.zeros(
                len(processed_images),
                max_num_images,
                *(processed_images[0][0].shape[0], max_height, max_width),
                device=processed_images[0][0].device,
            )
            pixel_attention_masks = torch.zeros(
                len(processed_images),
                max_num_images,
                *(max_height, max_width),
                device=processed_images[0][0].device,
            )
            for i, images in enumerate(processed_images):
                for j, image in enumerate(images):
                    processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
                        image, (max_height, max_width)
                    )
            processed_images = processed_images_padded

        if do_pad:
            data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
        elif return_tensors == "pt":
            data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
        else:
            data = {"pixel_values": processed_images}

        return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["Idefics2ImageProcessorFast"]
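
A minimal usage sketch of the new processor (the checkpoint name comes from the docs above; `use_fast=True` selects the fast class now that both are registered in the auto mapping):

```python
from transformers import AutoImageProcessor

# Loads Idefics2ImageProcessorFast when use_fast=True.
image_processor = AutoImageProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", use_fast=True)
```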

@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
    from .configuration_idefics3 import *
    from .image_processing_idefics3 import *
+    from .image_processing_idefics3_fast import *
    from .modeling_idefics3 import *
    from .processing_idefics3 import *
else:

@@ -770,6 +770,7 @@ class Idefics3ImageProcessor(BaseImageProcessor):
                split_image_array, rows, cols = self.split_image(
                    image,
                    max_image_size=max_image_size,
+                    resample=resample,
                    input_data_format=input_data_format,
                )
                split_image_arrays.extend(split_image_array)
@ -0,0 +1,507 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
BatchFeature,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
SizeDict,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
make_nested_list_of_images,
|
||||
)
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
|
||||
|
||||
|


def _resize_output_size_rescale_to_max_len(
    height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
) -> tuple[int, int]:
    """
    Get the output size of the image after resizing, rescaling the longest edge to `max_len` while preserving the
    aspect ratio.

    Args:
        height (`int`):
            Height of the input image.
        width (`int`):
            Width of the input image.
        min_len (`int`, *optional*, defaults to 1):
            Minimum size of the output image.
        max_len (`int`, *optional*, defaults to the maximum size of the image):
            Maximum size of the output image.

    Returns:
        The output size of the image after resizing.
    """
    max_len = max(height, width) if max_len is None else max_len
    aspect_ratio = width / height

    if width >= height:
        width = max_len
        height = int(width / aspect_ratio)
        if height % 2 != 0:
            height += 1
    elif height > width:
        height = max_len
        width = int(height * aspect_ratio)
        if width % 2 != 0:
            width += 1

    # Avoid resizing to a size smaller than min_len
    height = max(height, min_len)
    width = max(width, min_len)
    return height, width


def _resize_output_size_scale_below_upper_bound(
    height: int, width: int, max_len: Optional[int] = None
) -> tuple[int, int]:
    """
    Get the output size of the image after scaling it down so that neither side exceeds `max_len`.

    Args:
        height (`int`):
            Height of the input image.
        width (`int`):
            Width of the input image.
        max_len (`int`, *optional*, defaults to the maximum size of the image):
            Maximum allowed length of either dimension.

    Returns:
        The output size of the image after resizing.
    """
    max_len = max(height, width) if max_len is None else max_len

    aspect_ratio = width / height
    if width >= height and width > max_len:
        width = max_len
        height = int(width / aspect_ratio)
    elif height > width and height > max_len:
        height = max_len
        width = int(height * aspect_ratio)

    # Avoid resizing to a size smaller than 1
    height = max(height, 1)
    width = max(width, 1)
    return height, width


def get_resize_output_image_size(
    image,
    resolution_max_side: int,
) -> tuple[int, int]:
    """
    Get the output size of the image after resizing its longest edge to `resolution_max_side`, capped at
    `MAX_IMAGE_SIZE`.

    Args:
        image (`torch.Tensor`):
            Image to resize.
        resolution_max_side (`int`):
            The longest edge of the image will be resized to this value. The shortest edge will be resized to keep
            the input aspect ratio.

    Returns:
        The output size of the image after resizing.
    """
    height, width = image.size()[-2:]

    # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
    height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
    # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
    height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
    return height, width
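A quick worked example of the two-pass size computation above (plain arithmetic, not part of the diff):

```python
# A 2000x3000 (height x width) image with resolution_max_side = 4 * 364 = 1456.
height, width = 2000, 3000
aspect_ratio = width / height       # 1.5
width = 1456                        # longest edge rescaled to max_len
height = int(width / aspect_ratio)  # 970, already even, so no +1 correction
assert (height, width) == (970, 1456)
# The second pass is a no-op here: both edges are already below MAX_IMAGE_SIZE = 4096.
```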


def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    image_sizes = []
    for images in images_list:
        for image in images:
            image_sizes.append(image.size()[-2:])

    max_height = max(size[0] for size in image_sizes)
    max_width = max(size[1] for size in image_sizes)
    return (max_height, max_width)


def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor":
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

    Args:
        image (`torch.Tensor`):
            Image to make the pixel mask for.
        output_size (`Tuple[int, int]`):
            Output size of the mask.
    """
    input_height, input_width = image.size()[-2:]
    mask = torch.zeros(output_size, dtype=torch.int64, device=image.device)
    mask[:input_height, :input_width] = 1
    return mask


class Idefics3FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_pad (`bool`, *optional*):
        Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
        number of patches in the batch. Padding will be applied to the bottom and right with zeros.
    do_image_splitting (`bool`, *optional*, defaults to `True`):
        Whether to split the image into sub-images concatenated with the original image. They are split into patches
        such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
    max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
        Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key
        "longest_edge".
    return_row_col_info (`bool`, *optional*, defaults to `False`):
        Whether to return the row and column information of the images.
    """

    do_pad: Optional[bool]
    do_image_splitting: Optional[bool]
    max_image_size: Optional[dict[str, int]]
    return_row_col_info: Optional[bool]


@auto_docstring
class Idefics3ImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.LANCZOS
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"longest_edge": 4 * 364}
    max_image_size = {"longest_edge": 364}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    do_image_splitting = True
    do_pad = True
    return_row_col_info = False
    valid_kwargs = Idefics3FastImageProcessorKwargs

    def _prepare_images_structure(
        self,
        images: ImageInput,
    ) -> ImageInput:
        """
        Prepare a nested images structure for processing.
        """
        return make_nested_list_of_images(images)

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image. The longest edge of the image is resized to size.longest_edge, with the shortest edge
        resized to keep the input aspect ratio. Can also be used with size.height and size.width.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Size of the output image.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Whether to use antialiasing when resizing the image.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if interpolation == F.InterpolationMode.LANCZOS:
            logger.warning_once(
                "You have used the fast image processor with LANCZOS resampling, which is not yet supported for "
                "torch.Tensor. BICUBIC resampling will be used as an alternative. Please fall back to the slow "
                "image processor if you want full consistency with the original model."
            )
            interpolation = F.InterpolationMode.BICUBIC

        if size.longest_edge:
            size = get_resize_output_image_size(image, resolution_max_side=size.longest_edge)
        elif size.height and size.width:
            size = (size.height, size.width)
        else:
            raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")

        return F.resize(image, size, interpolation=interpolation, antialias=antialias)

    def split_images(
        self,
        images: torch.Tensor,
        max_image_size: dict[str, int],
        interpolation: Optional["F.InterpolationMode"] = None,
    ):
        """
        Split an image into squares of side max_image_size and the original image resized to max_image_size.
        That means that a single image becomes a sequence of images.
        This is a "trick" to spend more compute on each image with no changes in the vision encoder.
        1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while
           preserving the aspect ratio.
        2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
           sub-images, each of the same size (image_size, image_size). Typically, 364x364.
        3) Return the list of the crops and the original image, in addition to the number of splits for the height
           and the width.

        Args:
            images (`torch.Tensor`):
                Images to split.
            max_image_size (`Dict[str, int]`):
                Maximum size of the output image. If the image is larger than this size, it will be split into
                patches of this size, and the original image will be concatenated with the patches, resized to
                max_size.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
        """
        batch_size, num_channels, height, width = images.size()
        height_dim, width_dim = 2, 3

        max_height = max_width = max_image_size["longest_edge"]

        frames = []
        if height > max_height or width > max_width:
            # Calculate the number of splits
            num_splits_h = math.ceil(height / max_height)
            num_splits_w = math.ceil(width / max_width)

            # Split the images by height, then by width
            frames = (
                images.unfold(height_dim, size=max_height, step=max_height)
                .unfold(width_dim, size=max_width, step=max_width)
                .contiguous()
                .view(batch_size, num_channels, -1, max_height, max_width)
                .permute(0, 2, 1, 3, 4)
            )  # batch_size x n_frames x num_channels x height x width

            # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
            global_image_height, global_image_width = max_height, max_width
            images = self.resize(
                images, SizeDict(height=global_image_height, width=global_image_width), interpolation=interpolation
            )

            frames = torch.cat((frames, images.unsqueeze(1)), dim=1)
        else:
            num_splits_h, num_splits_w = 0, 0
            frames = images.unsqueeze(1)

        num_splits_h = [num_splits_h] * batch_size
        num_splits_w = [num_splits_w] * batch_size

        return frames, num_splits_h, num_splits_w
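The unfold-based splitting above produces `ceil(h / 364) * ceil(w / 364)` crops plus one resized global image. A quick check of the counts (not part of the diff):

```python
import math

# A 728x1092 image, already a multiple of 364 on both sides as produced by
# resize_for_vision_encoder, with max_image_size = {"longest_edge": 364}.
height, width = 728, 1092
num_splits_h = math.ceil(height / 364)        # 2
num_splits_w = math.ceil(width / 364)         # 3
num_frames = num_splits_h * num_splits_w + 1  # 6 crops + 1 global image
assert (num_splits_h, num_splits_w, num_frames) == (2, 3, 7)
```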

    def resize_for_vision_encoder(
        self,
        image: torch.Tensor,
        vision_encoder_max_size: int,
        interpolation: Optional["F.InterpolationMode"] = None,
    ):
        """
        Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.

        Args:
            image (`torch.Tensor`):
                Images to resize.
            vision_encoder_max_size (`int`):
                Patch size of the vision encoder; both output dimensions are rounded up to a multiple of this value.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
        """
        height, width = image.size()[-2:]

        aspect_ratio = width / height
        if width >= height:
            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
            height = int(width / aspect_ratio)
            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
        elif height > width:
            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
            width = int(height * aspect_ratio)
            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
        new_size = SizeDict(height=height, width=width)
        return self.resize(image, size=new_size, interpolation=interpolation)
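A worked example for the rounding logic above (not part of the diff); note how the ceiling on both sides can round the shorter side up a full multiple:

```python
import math

# A 500x300 (height x width) image with vision_encoder_max_size = 364.
height, width = 500, 300
aspect_ratio = width / height            # 0.6
height = math.ceil(height / 364) * 364   # 728 (height > width branch)
width = int(height * aspect_ratio)       # 436
width = math.ceil(width / 364) * 364     # 728
assert (height, width) == (728, 728)
```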

    def pad(
        self,
        image: torch.Tensor,
        padded_size: tuple[int, int],
        fill: int = 0,
        return_pixel_mask: bool = True,
    ):
        original_size = image.shape[-2:]
        padding_bottom = padded_size[0] - original_size[0]
        padding_right = padded_size[1] - original_size[1]

        if padding_bottom < 0 or padding_right < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                f"original size. Got padded size: {padded_size}, original size: {original_size}."
            )

        # Only pad if necessary
        if original_size != padded_size:
            padding = (0, 0, padding_right, padding_bottom)
            image = F.pad(image, padding, fill=fill, padding_mode="constant")

        # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
        pixel_mask = None
        if return_pixel_mask:
            pixel_mask = torch.zeros_like(image[..., 0, :, :], dtype=torch.int64)
            pixel_mask[: original_size[0], : original_size[1]] = 1

        return image, pixel_mask

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Idefics3FastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: list[list["torch.Tensor"]],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: Optional[bool],
        do_image_splitting: Optional[bool],
        max_image_size: Optional[dict[str, int]],
        return_row_col_info: Optional[bool],
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Process a batch of images for the model.
        """
        grouped_images, grouped_images_index = group_images_by_shape(
            images, is_nested=True, disable_grouping=disable_grouping
        )
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)

        grouped_images, grouped_images_index = group_images_by_shape(
            resized_images, is_nested=True, disable_grouping=disable_grouping
        )
        split_images_grouped = {}
        if do_image_splitting:
            rows_grouped = {}
            cols_grouped = {}
            for shape, stacked_images in grouped_images.items():
                stacked_images = self.resize_for_vision_encoder(
                    stacked_images, max_image_size["longest_edge"], interpolation=interpolation
                )
                stacked_images, rows, cols = self.split_images(
                    stacked_images, max_image_size=max_image_size, interpolation=interpolation
                )
                split_images_grouped[shape] = stacked_images
                rows_grouped[shape] = rows
                cols_grouped[shape] = cols
            processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
            rows = reorder_images(rows_grouped, grouped_images_index, is_nested=True)
            cols = reorder_images(cols_grouped, grouped_images_index, is_nested=True)
            # Flatten the doubly nested list to a nested list
            for i, group_images in enumerate(processed_images):
                processed_images[i] = [image for sublist in group_images for image in sublist]
        else:
            for shape, stacked_images in grouped_images.items():
                # We square the images to max_image_size
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=max_image_size["longest_edge"], width=max_image_size["longest_edge"]),
                    interpolation=interpolation,
                )
                split_images_grouped[shape] = stacked_images
            processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
            rows = [[0] * len(images) for images in processed_images]
            cols = [[0] * len(images) for images in processed_images]

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(
            processed_images, is_nested=True, disable_grouping=disable_grouping
        )
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)

        if do_pad:
            # Get max images per batch
            max_num_images = max(len(images_) for images_ in processed_images)
            max_height, max_width = get_max_height_width(processed_images)

            processed_images_padded = torch.zeros(
                len(processed_images),
                max_num_images,
                *(processed_images[0][0].shape[0], max_height, max_width),
                device=processed_images[0][0].device,
            )
            pixel_attention_masks = torch.zeros(
                len(processed_images),
                max_num_images,
                *(max_height, max_width),
                device=processed_images[0][0].device,
            )
            for i, images in enumerate(processed_images):
                for j, image in enumerate(images):
                    processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
                        image, (max_height, max_width)
                    )
            processed_images = processed_images_padded

        if do_pad:
            data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
        elif return_tensors == "pt":
            data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
        else:
            data = {"pixel_values": processed_images}
        # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
        encoding = BatchFeature(data=data, tensor_type=return_tensors)

        if return_row_col_info:
            encoding["rows"] = rows
            encoding["cols"] = cols

        return encoding

    def to_dict(self):
        encoder_dict = super().to_dict()
        encoder_dict.pop("_valid_processor_keys", None)
        encoder_dict.pop("return_row_col_info", None)
        return encoder_dict


__all__ = ["Idefics3ImageProcessorFast"]
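A usage sketch showing the `return_row_col_info` flag consumed by the processor when building text prompts (not part of the diff; checkpoint id and file name are assumptions):

```python
from PIL import Image

from transformers import Idefics3ImageProcessorFast

processor = Idefics3ImageProcessorFast.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
image = Image.open("document.png")  # hypothetical local file
out = processor(images=image, return_tensors="pt", return_row_col_info=True)
# rows/cols hold the per-image split counts used when laying out per-patch image tokens
print(out["pixel_values"].shape, out["rows"], out["cols"])
```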
@ -91,6 +91,7 @@ class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):
        apply_ocr: bool,
        ocr_lang: Optional[str],
        tesseract_config: Optional[str],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
@ -111,7 +112,7 @@ class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):
            boxes_batch.append(boxes)

        # Group images by size for batched resizing
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
@ -121,7 +122,7 @@ class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+       grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # flip color channels from RGB to BGR (as Detectron2 requires this)
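The recurring change across these files is that `group_images_by_shape` must now be told explicitly whether grouping is disabled. The pattern it supports, schematically (a simplified sketch, not the library implementation; the real helpers also handle nested lists and index bookkeeping for `reorder_images`):

```python
import torch

def group_by_shape_sketch(images: list[torch.Tensor], disable_grouping: bool) -> dict:
    """Stack same-shaped images so per-shape loops can run batched kernels."""
    if disable_grouping:
        # One singleton group per image: skips stacking, which can be faster
        # on CPU where the batched-kernel win does not pay for the copies.
        return {(i,): image.unsqueeze(0) for i, image in enumerate(images)}
    groups: dict = {}
    for image in images:
        groups.setdefault(tuple(image.shape), []).append(image)
    return {shape: torch.stack(group) for shape, group in groups.items()}
```

Making the argument mandatory at every call site matches the PR description ("Force providing disable_grouping"), presumably so a forgotten flag fails at call time instead of silently falling back to a default.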
@ -16,11 +16,7 @@

from typing import Optional, Union

-from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
-    BatchFeature,
-    DefaultFastImageProcessorKwargs,
-)
+from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
@ -106,6 +102,7 @@ class LayoutLMv3ImageProcessorFast(BaseImageProcessorFast):
        ocr_lang: Optional[str],
        tesseract_config: Optional[str],
        return_tensors: Optional[Union[str, TensorType]],
+       disable_grouping: Optional[bool],
        **kwargs,
    ) -> BatchFeature:
        # Tesseract OCR to get words + normalized bounding boxes
@ -125,7 +122,7 @@ class LayoutLMv3ImageProcessorFast(BaseImageProcessorFast):
            boxes_batch.append(boxes)

        # Group images by size for batched resizing
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
@ -135,7 +132,7 @@ class LayoutLMv3ImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+       grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
@ -393,13 +393,14 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
        possible_resolutions = torch.tensor(possible_resolutions, device=images[0].device)
        # process images by batch, grouped by shape
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        grouped_processed_images = {}
        grouped_aspect_ratios = {}
        for shape, stacked_images in grouped_images.items():
@ -16,9 +16,7 @@

from typing import Optional, Union

-from ...image_processing_utils import (
-    BatchFeature,
-)
+from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
@ -147,10 +145,11 @@ class LlavaImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        # Group images by size for batched resizing
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_pad:
@ -162,7 +161,7 @@ class LlavaImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for batched resizing
        # Needed in case do_pad is False, or padding returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(padded_images)
+       grouped_images, grouped_images_index = group_images_by_shape(padded_images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
@ -172,7 +171,7 @@ class LlavaImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+       grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
@ -242,7 +242,9 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: bool,
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        processed_images = []
        image_sizes = []
@ -271,7 +273,9 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):

            # Group images by size for batched processing
            processed_image_patches_grouped = {}
-           grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
+           grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+               image_patches, disable_grouping=disable_grouping
+           )
            for shape, stacked_image_patches in grouped_image_patches.items():
                if do_resize:
                    stacked_image_patches = self.resize(
@ -248,7 +248,9 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
        image_std: Optional[Union[float, list[float]]],
        do_pad: bool,
        batch_num_images: list[int],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        processed_images = []
        image_sizes = []
@ -287,7 +289,9 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):

            # Group images by size for batched processing
            processed_image_patches_grouped = {}
-           grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
+           grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+               image_patches, disable_grouping=disable_grouping
+           )
            for shape, stacked_image_patches in grouped_image_patches.items():
                if do_resize:
                    stacked_image_patches = self.resize(
@ -169,7 +169,9 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
        image_std: Optional[Union[float, list[float]]],
        do_pad: bool,
        batch_num_images: list[int],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        processed_images = []
        image_sizes = []
@ -208,7 +210,9 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):

            # Group images by size for batched processing
            processed_image_patches_grouped = {}
-           grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
+           grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
+               image_patches, disable_grouping=disable_grouping
+           )
            for shape, stacked_image_patches in grouped_image_patches.items():
                if do_resize:
                    stacked_image_patches = self.resize(
@ -96,11 +96,12 @@ class PerceiverImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
@ -112,7 +113,7 @@ class PerceiverImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+       grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
@ -23,11 +23,7 @@ from ...image_processing_utils_fast import (
    group_images_by_shape,
    reorder_images,
)
-from ...image_utils import (
-    ImageInput,
-    PILImageResampling,
-    SizeDict,
-)
+from ...image_utils import ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
@ -38,9 +34,7 @@ from ...utils import (
    is_vision_available,
    logging,
)
-from .image_processing_pixtral import (
-    get_resize_output_image_size,
-)
+from .image_processing_pixtral import get_resize_output_image_size


logger = logging.get_logger(__name__)
@ -164,12 +158,13 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
    ) -> BatchFeature:
        patch_size = get_size_dict(patch_size, default_to_square=True)
        patch_size = SizeDict(**patch_size)
        # Group images by size for batched resizing
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
@ -181,7 +176,7 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+       grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        batch_image_sizes = [grouped_images_index[i][0] for i in range(len(grouped_images_index))]
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
@ -16,11 +16,7 @@

from typing import Optional, Union

-from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
-    BatchFeature,
-    DefaultFastImageProcessorKwargs,
-)
+from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import (
    ChannelDimension,
    get_resize_output_image_size,
@ -225,11 +221,12 @@ class PoolFormerImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+       disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
@ -241,7 +238,7 @@ class PoolFormerImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+       grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
@ -139,6 +139,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
        do_convert_rgb: bool,
        input_data_format: Optional[Union[str, ChannelDimension]],
        device: Optional[Union[str, torch.device]],
+       disable_grouping: Optional[bool],
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
@ -191,7 +192,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
            resized_height, resized_width = height, width

        # Group images by size for batched resizing
-       grouped_images, grouped_images_index = group_images_by_shape(images)
+       grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
@ -210,7 +211,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
-       grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+       grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
@ -270,6 +271,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
+       disable_grouping: Optional[bool] = False,
        **kwargs,
    ):
        r"""
@ -360,6 +362,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                do_convert_rgb=do_convert_rgb,
                input_data_format=input_data_format,
                device=device,
+               disable_grouping=disable_grouping,
            )
            pixel_values.extend(patches)
            vision_grid_thws.append(image_grid_thw)
@ -393,6 +396,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
                do_convert_rgb=do_convert_rgb,
                input_data_format=input_data_format,
                device=device,
+               disable_grouping=disable_grouping,
            )
            pixel_values_videos.extend(patches)
            vision_grid_thws_videos.append(video_grid_thw)
@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure

if TYPE_CHECKING:
    from .configuration_smolvlm import *
    from .image_processing_smolvlm import *
+   from .image_processing_smolvlm_fast import *
    from .modeling_smolvlm import *
    from .processing_smolvlm import *
else:
@ -767,6 +767,7 @@ class SmolVLMImageProcessor(BaseImageProcessor):
                split_image_array, rows, cols = self.split_image(
                    image,
                    max_image_size=max_image_size,
                    resample=resample,
                    input_data_format=input_data_format,
                )
                split_image_arrays.extend(split_image_array)
src/transformers/models/smolvlm/image_processing_smolvlm_fast.py (new file, 497 lines)
@ -0,0 +1,497 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_smolvlm.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
# Written by Orr Zohar
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Union

import torch

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    SizeDict,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ImageInput,
    PILImageResampling,
    make_nested_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring, is_torchvision_available, logging


if is_torchvision_available():
    from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)


class SmolVLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_pad (`bool`, *optional*):
        Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
        number of patches in the batch. Padding will be applied to the bottom and right with zeros.
    do_image_splitting (`bool`, *optional*, defaults to `True`):
        Whether to split the image into sub-images concatenated with the original image. They are split into patches
        such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
    max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
        Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key
        "longest_edge".
    return_row_col_info (`bool`, *optional*, defaults to `False`):
        Whether to return the row and column information of the images.
    """

    do_pad: Optional[bool]
    do_image_splitting: Optional[bool]
    max_image_size: Optional[dict[str, int]]
    return_row_col_info: Optional[bool]


MAX_IMAGE_SIZE = 4096  # 4k resolution as absolute maximum


def _resize_output_size_rescale_to_max_len(
    height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
) -> tuple[int, int]:
    """
    Get the output size of the image after resizing, rescaling the longest edge to `max_len` while preserving the
    aspect ratio.

    Args:
        height (`int`):
            Height of the input image.
        width (`int`):
            Width of the input image.
        min_len (`int`, *optional*, defaults to 1):
            Minimum size of the output image.
        max_len (`int`, *optional*, defaults to the maximum size of the image):
            Maximum size of the output image.

    Returns:
        The output size of the image after resizing.
    """
    max_len = max(height, width) if max_len is None else max_len
    aspect_ratio = width / height

    if width >= height:
        width = max_len
        height = int(width / aspect_ratio)
        if height % 2 != 0:
            height += 1
    elif height > width:
        height = max_len
        width = int(height * aspect_ratio)
        if width % 2 != 0:
            width += 1

    # Avoid resizing to a size smaller than min_len
    height = max(height, min_len)
    width = max(width, min_len)
    return height, width


def _resize_output_size_scale_below_upper_bound(
    height: int, width: int, max_len: Optional[int] = None
) -> tuple[int, int]:
    """
    Get the output size of the image after scaling it down so that neither side exceeds `max_len`.

    Args:
        height (`int`):
            Height of the input image.
        width (`int`):
            Width of the input image.
        max_len (`int`, *optional*, defaults to the maximum size of the image):
            Maximum allowed length of either dimension.

    Returns:
        The output size of the image after resizing.
    """
    max_len = max(height, width) if max_len is None else max_len

    aspect_ratio = width / height
    if width >= height and width > max_len:
        width = max_len
        height = int(width / aspect_ratio)
    elif height > width and height > max_len:
        height = max_len
        width = int(height * aspect_ratio)

    # Avoid resizing to a size smaller than 1
    height = max(height, 1)
    width = max(width, 1)
    return height, width


def get_resize_output_image_size(
    image,
    resolution_max_side: int,
) -> tuple[int, int]:
    """
    Get the output size of the image after resizing its longest edge to `resolution_max_side`, capped at
    `MAX_IMAGE_SIZE`.

    Args:
        image (`torch.Tensor`):
            Image to resize.
        resolution_max_side (`int`):
            The longest edge of the image will be resized to this value. The shortest edge will be resized to keep
            the input aspect ratio.

    Returns:
        The output size of the image after resizing.
    """
    height, width = image.size()[-2:]

    # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
    height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
    # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
    height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
    return height, width


def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    image_sizes = []
    for images in images_list:
        for image in images:
            image_sizes.append(image.size()[-2:])

    max_height = max(size[0] for size in image_sizes)
    max_width = max(size[1] for size in image_sizes)
    return (max_height, max_width)


@auto_docstring
class SmolVLMImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.LANCZOS
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"longest_edge": 4 * 364}
    max_image_size = {"longest_edge": 364}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    do_image_splitting = True
    do_pad = True
    return_row_col_info = False
    valid_kwargs = SmolVLMFastImageProcessorKwargs

    def _prepare_images_structure(
        self,
        images: ImageInput,
    ) -> ImageInput:
        """
        Prepare a nested images structure for processing.
        """
        return make_nested_list_of_images(images)

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"] = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image. The longest edge of the image is resized to size.longest_edge, with the shortest edge
        resized to keep the input aspect ratio. Can also be used with size.height and size.width.

        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`SizeDict`):
                Size of the output image.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Whether to use antialiasing when resizing the image.
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if interpolation == F.InterpolationMode.LANCZOS:
            logger.warning_once(
                "You have used the fast image processor with LANCZOS resampling, which is not yet supported for "
                "torch.Tensor. BICUBIC resampling will be used as an alternative. Please fall back to the slow "
                "image processor if you want full consistency with the original model."
            )
            interpolation = F.InterpolationMode.BICUBIC

        if size.longest_edge:
            size = get_resize_output_image_size(image, resolution_max_side=size.longest_edge)
        elif size.height and size.width:
            size = (size.height, size.width)
        else:
            raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")

        return F.resize(image, size, interpolation=interpolation, antialias=antialias)

    def split_images(
        self,
        images: torch.Tensor,
        max_image_size: dict[str, int],
        interpolation: Optional["F.InterpolationMode"] = None,
    ):
        """
        Split an image into squares of side max_image_size and the original image resized to max_image_size.
        That means that a single image becomes a sequence of images.
        This is a "trick" to spend more compute on each image with no changes in the vision encoder.
        1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while
           preserving the aspect ratio.
        2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
           sub-images, each of the same size (image_size, image_size). Typically, 364x364.
        3) Return the list of the crops and the original image, in addition to the number of splits for the height
           and the width.

        Args:
            images (`torch.Tensor`):
                Images to split.
            max_image_size (`Dict[str, int]`):
                Maximum size of the output image. If the image is larger than this size, it will be split into
                patches of this size, and the original image will be concatenated with the patches, resized to
                max_size.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
        """
        batch_size, num_channels, height, width = images.size()
        height_dim, width_dim = 2, 3

        max_height = max_width = max_image_size["longest_edge"]

        frames = []
        if height > max_height or width > max_width:
            # Calculate the number of splits
            num_splits_h = math.ceil(height / max_height)
            num_splits_w = math.ceil(width / max_width)

            # Split the images by height, then by width
            frames = (
                images.unfold(height_dim, size=max_height, step=max_height)
                .unfold(width_dim, size=max_width, step=max_width)
                .contiguous()
                .view(batch_size, num_channels, -1, max_height, max_width)
                .permute(0, 2, 1, 3, 4)
            )  # batch_size x n_frames x num_channels x height x width

            # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
            global_image_height, global_image_width = max_height, max_width
            images = self.resize(
                images, SizeDict(height=global_image_height, width=global_image_width), interpolation=interpolation
            )

            frames = torch.cat((frames, images.unsqueeze(1)), dim=1)
        else:
            num_splits_h, num_splits_w = 0, 0
            frames = images.unsqueeze(1)

        num_splits_h = [num_splits_h] * batch_size
        num_splits_w = [num_splits_w] * batch_size

        return frames, num_splits_h, num_splits_w

    def resize_for_vision_encoder(
        self,
        image: torch.Tensor,
        vision_encoder_max_size: int,
        interpolation: Optional["F.InterpolationMode"] = None,
    ):
        """
        Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.

        Args:
            image (`torch.Tensor`):
                Images to resize.
            vision_encoder_max_size (`int`):
                Patch size of the vision encoder; both output dimensions are rounded up to a multiple of this value.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
        """
        height, width = image.size()[-2:]

        aspect_ratio = width / height
        if width >= height:
            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
            height = int(width / aspect_ratio)
            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
        elif height > width:
            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
            width = int(height * aspect_ratio)
            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
        new_size = SizeDict(height=height, width=width)
        return self.resize(image, size=new_size, interpolation=interpolation)

    def pad(
        self,
        image: torch.Tensor,
        padded_size: tuple[int, int],
        fill: int = 0,
        return_pixel_mask: bool = True,
    ):
        original_size = image.shape[-2:]
        padding_bottom = padded_size[0] - original_size[0]
        padding_right = padded_size[1] - original_size[1]

        if padding_bottom < 0 or padding_right < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                f"original size. Got padded size: {padded_size}, original size: {original_size}."
            )

        # Only pad if necessary
        if original_size != padded_size:
            padding = (0, 0, padding_right, padding_bottom)
            image = F.pad(image, padding, fill=fill, padding_mode="constant")

        # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
        pixel_mask = None
        if return_pixel_mask:
            pixel_mask = torch.zeros_like(image[..., 0, :, :], dtype=torch.int64)
            pixel_mask[: original_size[0], : original_size[1]] = 1

        return image, pixel_mask

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[SmolVLMFastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: list[list["torch.Tensor"]],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: Optional[bool],
        do_image_splitting: Optional[bool],
        max_image_size: Optional[dict[str, int]],
        return_row_col_info: Optional[bool],
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Process a batch of images for the model.
        """
        grouped_images, grouped_images_index = group_images_by_shape(
            images, is_nested=True, disable_grouping=disable_grouping
        )
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)

        grouped_images, grouped_images_index = group_images_by_shape(
            resized_images, is_nested=True, disable_grouping=disable_grouping
        )
        split_images_grouped = {}
        if do_image_splitting:
            rows_grouped = {}
            cols_grouped = {}
            for shape, stacked_images in grouped_images.items():
                stacked_images = self.resize_for_vision_encoder(
                    stacked_images, max_image_size["longest_edge"], interpolation=interpolation
                )
                stacked_images, rows, cols = self.split_images(
                    stacked_images, max_image_size=max_image_size, interpolation=interpolation
                )
                split_images_grouped[shape] = stacked_images
                rows_grouped[shape] = rows
                cols_grouped[shape] = cols
            processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
            rows = reorder_images(rows_grouped, grouped_images_index, is_nested=True)
            cols = reorder_images(cols_grouped, grouped_images_index, is_nested=True)
            # Flatten the doubly nested list to a nested list
            for i, group_images in enumerate(processed_images):
                processed_images[i] = [image for sublist in group_images for image in sublist]
        else:
            for shape, stacked_images in grouped_images.items():
                # We square the images to max_image_size
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=max_image_size["longest_edge"], width=max_image_size["longest_edge"]),
                    interpolation=interpolation,
                )
                split_images_grouped[shape] = stacked_images
            processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
            rows = [[0] * len(images) for images in processed_images]
            cols = [[0] * len(images) for images in processed_images]

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(
            processed_images, is_nested=True, disable_grouping=disable_grouping
        )
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)

        if do_pad:
            # Get max images per batch
            max_num_images = max(len(images_) for images_ in processed_images)
            max_height, max_width = get_max_height_width(processed_images)

            processed_images_padded = torch.zeros(
                len(processed_images),
                max_num_images,
                *(processed_images[0][0].shape[0], max_height, max_width),
                device=processed_images[0][0].device,
            )
            pixel_attention_masks = torch.zeros(
                len(processed_images),
                max_num_images,
                *(max_height, max_width),
                device=processed_images[0][0].device,
            )
            for i, images in enumerate(processed_images):
                for j, image in enumerate(images):
                    processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
                        image, (max_height, max_width)
                    )
            processed_images = processed_images_padded

        if do_pad:
            data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
        elif return_tensors == "pt":
            data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
        else:
            data = {"pixel_values": processed_images}
        # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
        encoding = BatchFeature(data=data, tensor_type=return_tensors)

        if return_row_col_info:
            encoding["rows"] = rows
            encoding["cols"] = cols

        return encoding
||||
def to_dict(self):
|
||||
encoder_dict = super().to_dict()
|
||||
encoder_dict.pop("_valid_processor_keys", None)
|
||||
encoder_dict.pop("return_row_col_info", None)
|
||||
return encoder_dict
|
||||
|
||||
|
||||
__all__ = ["SmolVLMImageProcessorFast"]
|
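The fast processor above is a drop-in replacement for the slow one. A minimal usage sketch, assuming torchvision is installed so that `use_fast=True` resolves to `SmolVLMImageProcessorFast` (the checkpoint name here is only an illustrative placeholder, not part of this diff):

```python
import torch
from transformers import AutoImageProcessor

# Assumed checkpoint name, for illustration only.
processor = AutoImageProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True)

# One sample with a single image; tensors are CHW uint8.
images = [torch.randint(0, 256, (3, 512, 768), dtype=torch.uint8)]
out = processor(images, return_tensors="pt", return_row_col_info=True)

# pixel_values is padded to the max number of images/patches per sample when do_pad is set;
# rows/cols report the split grid produced by do_image_splitting.
print(out.pixel_values.shape, out.pixel_attention_mask.shape, out.rows, out.cols)
```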
@ -34,7 +34,12 @@ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from ...utils import (
+    LossKwargs,
+    auto_docstring,
+    can_return_tuple,
+    logging,
+)
from ..auto import AutoModel
from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig

@ -25,6 +25,7 @@ from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
+from ..idefics3.image_processing_idefics3_fast import Idefics3ImageProcessorFast
from ..idefics3.modeling_idefics3 import (
    Idefics3BaseModelOutputWithPast,
    Idefics3ForConditionalGeneration,
@ -160,6 +161,10 @@ class SmolVLMImageProcessor(Idefics3ImageProcessor):
    pass


+class SmolVLMImageProcessorFast(Idefics3ImageProcessorFast):
+    pass
+
+
class SmolVLMBaseModelOutputWithPast(Idefics3BaseModelOutputWithPast):
    pass

@ -396,6 +401,7 @@ __all__ = [
    "SmolVLMVisionConfig",
    "SmolVLMConfig",
    "SmolVLMImageProcessor",
+    "SmolVLMImageProcessorFast",
    "SmolVLMForConditionalGeneration",
    "SmolVLMPreTrainedModel",
    "SmolVLMModel",
@ -100,11 +100,11 @@ class Swin2SRImageProcessorFast(BaseImageProcessorFast):
        rescale_factor: float,
        do_pad: bool,
        pad_size: int,
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        interpolation: Optional["F.InterpolationMode"],
        **kwargs,
    ) -> BatchFeature:
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_image_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_rescale:
@ -93,6 +93,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
@ -102,7 +103,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
        This method overrides the base class method to include padding and pixel mask generation.
        """
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}

        for shape, stacked_images in grouped_images.items():
@ -112,7 +113,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}

        for shape, stacked_images in grouped_images.items():
@ -127,7 +128,9 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
        # Handle padding if required
        data = {}
        if do_pad:
-            pixel_values, pixel_mask = self._pad_batch(processed_images, return_tensors)
+            pixel_values, pixel_mask = self._pad_batch(
+                processed_images, return_tensors, disable_grouping=disable_grouping
+            )
            data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
        else:
            # If no padding, just return the processed images
@ -195,6 +198,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
        self,
        images: list["torch.Tensor"],
        return_tensors: Optional[Union[str, TensorType]],
+        disable_grouping: Optional[bool],
    ) -> tuple:
        """
        Pad a batch of images to the same size based on the maximum dimensions.
@ -210,7 +214,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
        max_size = get_max_height_width(images)

        # Group images by shape before padding
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        processed_images = {}
        processed_masks = {}

@ -202,10 +202,12 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
        image_std: Optional[Union[float, list[float]]] = None,
        do_pad: Optional[bool] = None,
        size_divisibility: Optional[int] = None,
+        disable_grouping: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
-        grouped_images, grouped_images_index = group_images_by_shape(images)
-        grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape in grouped_images:
            stacked_images = grouped_images[shape]

@ -188,11 +188,12 @@ class ZoeDepthImageProcessorFast(BaseImageProcessorFast):
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_rescale:
@ -204,6 +204,15 @@ class ImageProcessorArgs:
        "shape": None,
    }

+    disable_grouping = {
+        "description": """
+    Whether to disable grouping of images by size to process them individually and not in batches.
+    If None, will be set to True if the images are on CPU, and False otherwise. This choice is based on
+    empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
+    """,
+        "shape": None,
+    }
+

class ModelArgs:
    labels = {
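The CPU/accelerator default described in this new `disable_grouping` docstring comes down to a small device check before grouping. A minimal sketch of the idea, assuming torch tensor inputs (the helper name `_resolve_disable_grouping` is illustrative, not part of this commit):

```python
import torch

def _resolve_disable_grouping(images: list["torch.Tensor"], disable_grouping: "bool | None") -> bool:
    # Stacking same-shaped images into one batched tensor pays off on accelerators,
    # where a single kernel call processes the whole group; on CPU the extra
    # stacking and reordering work tends to be slower than per-image processing.
    if disable_grouping is None:
        return images[0].device.type == "cpu"
    return disable_grouping
```

This is also why the signatures patched above all gain an explicit `disable_grouping` parameter: the resolved value has to be threaded through every `group_images_by_shape` call.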
@ -298,11 +298,10 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
        image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")

-        self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
+        self._assert_slow_fast_tensors_equivalence(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values)
+        self._assert_slow_fast_tensors_equivalence(
+            image_encoding_slow.labels.float(), image_encoding_fast.labels.float()
        )
-        self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))

    def test_slow_fast_equivalence_batched(self):
        if not self.test_slow_image_processor or not self.test_fast_image_processor:
@ -324,7 +323,5 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")

-        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.labels.float(), encoding_fast.labels.float())
@ -19,14 +19,11 @@ from typing import Optional, Union
import requests

from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


-if is_torch_available():
-    import torch
-
if is_vision_available():
    from PIL import Image

@ -124,10 +121,6 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "size_divisor"))

-    def _assertEquivalence(self, a, b):
-        self.assertTrue(torch.allclose(a, b, atol=1e-1))
-        self.assertLessEqual(torch.mean(torch.abs(a - b)).item(), 1e-3)
-
    @require_vision
    @require_torch
    def test_slow_fast_equivalence(self):
@ -146,8 +139,8 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

-        self._assertEquivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
-        self._assertEquivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())

    @require_vision
    @require_torch
@ -170,5 +163,5 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

-        self._assertEquivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
-        self._assertEquivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())
@ -418,15 +418,8 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_fast = image_processor_fast(
            dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True
        )
-        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

-        self.assertTrue(
-            torch.allclose(encoding_slow.codebook_pixel_values, encoding_fast.codebook_pixel_values, atol=1e-1)
-        )
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.codebook_pixel_values - encoding_fast.codebook_pixel_values)).item(),
-            1e-3,
+        self._assert_slow_fast_tensors_equivalence(
+            encoding_slow.codebook_pixel_values, encoding_fast.codebook_pixel_values
        )
@ -286,7 +286,4 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

        torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
-        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@ -125,10 +125,7 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

        torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
-        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

    def test_slow_fast_equivalence_batched_crop_to_patches(self):
        # Prepare image inputs so that we have two groups of images with equal resolution with a group of images with
@ -144,10 +141,7 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

        torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
-        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

    def test_crop_to_patches(self):
        # test slow image processor
@ -1,4 +1,5 @@
-# Copyright 2024 HuggingFace Inc.
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin

@ -28,6 +28,8 @@ if is_vision_available():

    from transformers import Idefics2ImageProcessor

+if is_torchvision_available():
+    from transformers import Idefics2ImageProcessorFast
+
if is_torch_available():
    import torch
@ -88,10 +90,6 @@ class Idefics2ImageProcessingTester:
        }

    def get_expected_values(self, image_inputs, batched=False):
-        """
-        This function computes the expected height and width when providing images to BridgeTowerImageProcessor,
-        assuming do_resize is set to True with a scalar size and size_divisor.
-        """
        if not batched:
            shortest_edge = self.size["shortest_edge"]
            longest_edge = self.size["longest_edge"]
@ -142,11 +140,6 @@ class Idefics2ImageProcessingTester:
        numpify=False,
        torchify=False,
    ):
-        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
-        or a list of PyTorch tensors if one specifies torchify=True.
-
-        One can specify whether the images are of the same resolution or not.
-        """
        assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"

        batch_size = batch_size if batch_size is not None else self.batch_size
@ -162,23 +155,19 @@ class Idefics2ImageProcessingTester:
                if equal_resolution:
                    width = height = max_resolution
                else:
-                    # To avoid getting image width/height 0
-                    if size_divisor is not None:
-                        # If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
-                        min_resolution = max(size_divisor, min_resolution)
                    width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
                images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
            images_list.append(images)

        if not numpify and not torchify:
            # PIL expects the channel dimension as last dimension
            images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list]

        if torchify:
            images_list = [[torch.from_numpy(image) for image in images] for images in images_list]

        if numpify:
            # Numpy images are typically in channels last format
            images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list]

        return images_list
@ -188,6 +177,7 @@ class Idefics2ImageProcessingTester:
@require_vision
class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = Idefics2ImageProcessor if is_vision_available() else None
+    fast_image_processing_class = Idefics2ImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
@ -198,22 +188,23 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "do_rescale"))
-        self.assertTrue(hasattr(image_processing, "rescale_factor"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
-        self.assertTrue(hasattr(image_processing, "do_pad"))
-        self.assertTrue(hasattr(image_processing, "do_image_splitting"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "rescale_factor"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
+            self.assertTrue(hasattr(image_processing, "do_pad"))
+            self.assertTrue(hasattr(image_processing, "do_image_splitting"))

    def test_call_numpy(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random numpy tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
            for sample_images in image_inputs:
@ -238,7 +229,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
            image_processor_dict = self.image_processor_dict
            image_processor_dict["image_mean"] = [0.5, 0.5, 0.5, 0.5]
            image_processor_dict["image_std"] = [0.5, 0.5, 0.5, 0.5]
-            image_processing = self.image_processing_class(**image_processor_dict)
+            image_processing = image_processing_class(**image_processor_dict)
            # create random numpy tensors
            self.image_processor_tester.num_channels = 4
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
@ -266,7 +257,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    def test_call_pil(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PIL images
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
            for images in image_inputs:
@ -288,7 +279,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    def test_call_pytorch(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PyTorch tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
@ -308,3 +299,104 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
            tuple(encoded_images.shape),
            (self.image_processor_tester.batch_size, *expected_output_image_shape),
        )
+
+    def test_image_splitting(self):
+        for image_processing_class in self.image_processor_list:
+            image_processor_dict = self.image_processor_dict.copy()
+            image_processor_dict["do_image_splitting"] = True
+            image_processing = image_processing_class(**image_processor_dict)
+
+            image_inputs = self.image_processor_tester.prepare_image_inputs(
+                equal_resolution=True, torchify=True, num_images=1
+            )
+
+            result = image_processing(image_inputs[0], return_tensors="pt")
+            self.assertEqual(result.pixel_values.shape[1], 5)
+
+            image_processor_dict["do_image_splitting"] = False
+            image_processing = image_processing_class(**image_processor_dict)
+
+            result = image_processing(image_inputs[0], return_tensors="pt")
+            if len(result.pixel_values.shape) == 5:
+                self.assertEqual(result.pixel_values.shape[1], 1)
+            else:
+                self.assertEqual(result.pixel_values.shape[1], self.image_processor_tester.num_channels)
+
+    def test_pixel_attention_mask(self):
+        for image_processing_class in self.image_processor_list:
+            image_processor_dict = self.image_processor_dict.copy()
+            image_processor_dict["do_pad"] = True
+            image_processing = image_processing_class(**image_processor_dict)
+
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+
+            result = image_processing(image_inputs, return_tensors="pt")
+            self.assertIn("pixel_attention_mask", result)
+
+            self.assertEqual(result.pixel_attention_mask.shape[-2:], result.pixel_values.shape[-2:])
+
+            image_processor_dict["do_pad"] = False
+            image_processor_dict["do_image_splitting"] = False
+            image_processing = image_processing_class(**image_processor_dict)
+
+            equal_size_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+
+            result = image_processing(equal_size_inputs, return_tensors="pt")
+            self.assertNotIn("pixel_attention_mask", result)
+
+    def test_convert_rgb(self):
+        for image_processing_class in self.image_processor_list:
+            rgba_image = Image.new("RGBA", (100, 100), (255, 0, 0, 128))
+
+            # Test with do_convert_rgb=True - this should work for all processors
+            image_processor_dict = self.image_processor_dict.copy()
+            image_processor_dict["do_convert_rgb"] = True
+            image_processing = image_processing_class(**image_processor_dict)
+
+            result = image_processing([rgba_image], return_tensors="pt")
+            self.assertIsNotNone(result.pixel_values)
+            rgb_image = rgba_image.convert("RGB")
+
+            image_processor_dict["do_convert_rgb"] = False
+            image_processing = image_processing_class(**image_processor_dict)
+
+            # Use the RGB image instead of RGBA when do_convert_rgb=False
+            result = image_processing([rgb_image], return_tensors="pt")
+            self.assertIsNotNone(result.pixel_values)
+
+            # Additional test: verifying proper handling of regular RGB images
+            rgb_image = Image.new("RGB", (100, 100), (255, 0, 0))
+            result = image_processing([rgb_image], return_tensors="pt")
+            self.assertIsNotNone(result.pixel_values)
+
+    def test_slow_fast_equivalence_batched(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
+            self.skipTest(
+                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
+            )
+
+        dummy_images = self.image_processor_tester.prepare_image_inputs(
+            equal_resolution=False, num_images=5, torchify=True
+        )
+        # pop some images to have non-homogeneous batches:
+        indices_to_pop = [i if np.random.random() < 0.5 else None for i in range(len(dummy_images))]
+        for i in indices_to_pop:
+            if i is not None:
+                dummy_images[i].pop()
+
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+
+        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
+
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+        self._assert_slow_fast_tensors_equivalence(
+            encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
+        )
@ -1,3 +1,4 @@
+# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -16,10 +17,11 @@
import unittest

import numpy as np
import requests

from transformers.image_utils import PILImageResampling
from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin

@ -29,6 +31,9 @@ if is_vision_available():

    from transformers import Idefics3ImageProcessor

+if is_torchvision_available():
+    from transformers import Idefics3ImageProcessorFast
+

if is_torch_available():
    import torch
@ -164,6 +169,7 @@ class Idefics3ImageProcessingTester:
@require_vision
class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = Idefics3ImageProcessor if is_vision_available() else None
+    fast_image_processing_class = Idefics3ImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
@ -174,25 +180,26 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "resample"))
-        self.assertTrue(hasattr(image_processing, "do_image_splitting"))
-        self.assertTrue(hasattr(image_processing, "max_image_size"))
-        self.assertTrue(hasattr(image_processing, "do_rescale"))
-        self.assertTrue(hasattr(image_processing, "rescale_factor"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
-        self.assertTrue(hasattr(image_processing, "do_pad"))
-        self.assertTrue(hasattr(image_processing, "do_image_splitting"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "resample"))
+            self.assertTrue(hasattr(image_processing, "do_image_splitting"))
+            self.assertTrue(hasattr(image_processing, "max_image_size"))
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "rescale_factor"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
+            self.assertTrue(hasattr(image_processing, "do_pad"))
+            self.assertTrue(hasattr(image_processing, "do_image_splitting"))

    def test_call_numpy(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random numpy tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
            for sample_images in image_inputs:
@ -216,7 +223,7 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
            image_processor_dict = self.image_processor_dict
-            image_processing = self.image_processing_class(**image_processor_dict)
+            image_processing = image_processing_class(**image_processor_dict)
            # create random numpy tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)

@ -239,7 +246,7 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    def test_call_pil(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PIL images
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
            for images in image_inputs:
@ -261,7 +268,7 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    def test_call_pytorch(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PyTorch tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
@ -281,3 +288,73 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
            tuple(encoded_images.shape),
            (self.image_processor_tester.batch_size, *expected_output_image_shape),
        )
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        dummy_image = Image.open(
+            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
+        )
+        dummy_image = dummy_image.resize((100, 150))
+        image_processor_slow = self.image_processing_class(
+            **self.image_processor_dict, resample=PILImageResampling.BICUBIC
+        )
+        image_processor_fast = self.fast_image_processing_class(
+            **self.image_processor_dict, resample=PILImageResampling.BICUBIC
+        )
+
+        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt", return_row_col_info=True)
+        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt", return_row_col_info=True)
+
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+        self._assert_slow_fast_tensors_equivalence(
+            encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
+        )
+        self.assertEqual(encoding_slow.rows, encoding_fast.rows)
+        self.assertEqual(encoding_slow.cols, encoding_fast.cols)
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence_batched(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
+            self.skipTest(
+                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
+            )
+
+        dummy_images = self.image_processor_tester.prepare_image_inputs(
+            equal_resolution=False, num_images=5, torchify=True
+        )
+        # pop some images to have non-homogeneous batches:
+        indices_to_pop = [i if np.random.random() < 0.5 else None for i in range(len(dummy_images))]
+        for i in indices_to_pop:
+            if i is not None:
+                dummy_images[i].pop()
+
+        image_processor_slow = self.image_processing_class(
+            **self.image_processor_dict, resample=PILImageResampling.BICUBIC
+        )
+        image_processor_fast = self.fast_image_processing_class(
+            **self.image_processor_dict, resample=PILImageResampling.BICUBIC
+        )
+
+        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt", return_row_col_info=True)
+        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt", return_row_col_info=True)
+
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=3e-1)
+        self._assert_slow_fast_tensors_equivalence(
+            encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
+        )
+        self.assertEqual(encoding_slow.rows, encoding_fast.rows)
+        self.assertEqual(encoding_slow.cols, encoding_fast.cols)
@ -15,9 +15,21 @@
import unittest

import requests
+from packaging import version

-from transformers.testing_utils import require_pytesseract, require_torch, require_vision
-from transformers.utils import is_pytesseract_available, is_torch_available, is_torchvision_available
+from transformers.testing_utils import (
+    require_pytesseract,
+    require_torch,
+    require_torch_accelerator,
+    require_vision,
+    slow,
+    torch_device,
+)
+from transformers.utils import (
+    is_pytesseract_available,
+    is_torch_available,
+    is_torchvision_available,
+)

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs

@ -157,16 +169,8 @@ class LayoutLMv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)

        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
-        self.assertTrue(
-            torch.allclose(
-                encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255, atol=1e-1
-            )
-        )
-        self.assertLessEqual(
-            torch.mean(
-                torch.abs(encoding_slow.pixel_values.float() - encoding_fast.pixel_values.float()) / 255
-            ).item(),
-            1e-3,
+        self._assert_slow_fast_tensors_equivalence(
+            encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255
        )

    @require_vision
@ -190,14 +194,28 @@ class LayoutLMv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

-        self.assertTrue(
-            torch.allclose(
-                encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255, atol=1e-1
-            )
+        self._assert_slow_fast_tensors_equivalence(
+            encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255
        )
-        self.assertLessEqual(
-            torch.mean(
-                torch.abs(encoding_slow.pixel_values.float() - encoding_fast.pixel_values.float()) / 255
-            ).item(),
-            1e-3,
-        )
+
+    # Overriding as we can't use torch.testing.assert_close on int8 tensors
+    @slow
+    @require_torch_accelerator
+    @require_vision
+    def test_can_compile_fast_image_processor(self):
+        if self.fast_image_processing_class is None:
+            self.skipTest("Skipping compilation test as fast image processor is not defined")
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        torch.compiler.reset()
+        input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
+        image_processor = self.fast_image_processing_class(**self.image_processor_dict)
+        output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
+
+        image_processor = torch.compile(image_processor, mode="reduce-overhead")
+        output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
+
+        self._assert_slow_fast_tensors_equivalence(
+            output_eager.pixel_values.float() / 255, output_compiled.pixel_values.float() / 255
+        )
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import time
import unittest

import numpy as np
@ -214,29 +213,6 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
        self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)

-    @require_vision
-    @require_torch
-    def test_fast_is_faster_than_slow(self):
-        if not self.test_slow_image_processor or not self.test_fast_image_processor:
-            self.skipTest(reason="Skipping speed test")
-
-        if self.image_processing_class is None or self.fast_image_processing_class is None:
-            self.skipTest(reason="Skipping speed test as one of the image processors is not defined")
-
-        def measure_time(image_processor, image):
-            start = time.time()
-            _ = image_processor(image, return_tensors="pt")
-            return time.time() - start
-
-        image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
-        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
-        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
-
-        fast_time = measure_time(image_processor_fast, image_inputs_list)
-        slow_time = measure_time(image_processor_slow, image_inputs_list)
-
-        self.assertLessEqual(fast_time, slow_time)
-
    @require_vision
    @require_torch
    def test_slow_fast_equivalence(self):
@ -255,9 +231,7 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):

        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
-        torch.testing.assert_close(
-            encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], rtol=100, atol=1e-1
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0])

    @require_vision
    @require_torch
@ -282,14 +256,8 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

        for i in range(len(encoding_slow.pixel_values)):
-            self.assertTrue(
-                torch.allclose(encoding_slow.pixel_values[i][0], encoding_fast.pixel_values[i][0], atol=1e-1)
-            )
-            self.assertLessEqual(
-                torch.mean(torch.abs(encoding_slow.pixel_values[i][0] - encoding_fast.pixel_values[i][0])).item(), 1e-3
-            )
-        torch.testing.assert_close(
-            encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], rtol=100, atol=1e-1
+            self._assert_slow_fast_tensors_equivalence(
+                encoding_slow.pixel_values[i][0], encoding_fast.pixel_values[i][0]
            )

    @slow
@ -309,8 +277,8 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        image_processor = torch.compile(image_processor, mode="reduce-overhead")
        output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")

-        torch.testing.assert_close(
-            output_eager.pixel_values[0][0], output_compiled.pixel_values[0][0], rtol=1e-4, atol=1e-4
+        self._assert_slow_fast_tensors_equivalence(
+            output_eager.pixel_values[0][0], output_compiled.pixel_values[0][0], atol=1e-4, rtol=1e-4, mean_atol=1e-5
        )

    @unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
@ -362,6 +362,4 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")

-        torch.testing.assert_close(
-            encoding_slow.pixel_values, encoding_fast.pixel_values, rtol=100, atol=1e-2
-        )  # @yoni bit weird that we have such diffs
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@ -18,14 +18,11 @@ import unittest
import requests

from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


-if is_torch_available():
-    import torch
-
if is_vision_available():
    from PIL import Image

@ -150,7 +147,6 @@ class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    def test_call_numpy_4_channels(self):
        pass

-    # increase mean tolerance to 1e-3 -> 2e-3
    # Ignore copy
    def test_slow_fast_equivalence(self):
        if not self.test_slow_image_processor or not self.test_fast_image_processor:
@ -167,10 +163,7 @@ class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):

        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
-        torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1)
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

-    # increase mean tolerance to 1e-3 -> 2e-3
    # Ignore copy
@ -193,7 +186,4 @@ class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

-        torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1)
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3
-        )
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@ -1,3 +1,4 @@
+# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -16,10 +17,11 @@
import unittest

import numpy as np
import requests

from transformers.image_utils import PILImageResampling
from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_vision_available
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin

@ -29,6 +31,9 @@ if is_vision_available():

    from transformers import SmolVLMImageProcessor

+if is_torchvision_available():
+    from transformers import SmolVLMImageProcessorFast
+

if is_torch_available():
    import torch
@ -164,6 +169,7 @@ class SmolVLMImageProcessingTester:
@require_vision
class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = SmolVLMImageProcessor if is_vision_available() else None
+    fast_image_processing_class = SmolVLMImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
@ -174,25 +180,26 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "resample"))
-        self.assertTrue(hasattr(image_processing, "do_image_splitting"))
-        self.assertTrue(hasattr(image_processing, "max_image_size"))
-        self.assertTrue(hasattr(image_processing, "do_rescale"))
-        self.assertTrue(hasattr(image_processing, "rescale_factor"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
-        self.assertTrue(hasattr(image_processing, "do_pad"))
-        self.assertTrue(hasattr(image_processing, "do_image_splitting"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "resample"))
+            self.assertTrue(hasattr(image_processing, "do_image_splitting"))
+            self.assertTrue(hasattr(image_processing, "max_image_size"))
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "rescale_factor"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
+            self.assertTrue(hasattr(image_processing, "do_pad"))
+            self.assertTrue(hasattr(image_processing, "do_image_splitting"))

    def test_call_numpy(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random numpy tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
            for sample_images in image_inputs:
@ -216,7 +223,7 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
            image_processor_dict = self.image_processor_dict
-            image_processing = self.image_processing_class(**image_processor_dict)
+            image_processing = image_processing_class(**image_processor_dict)
            # create random numpy tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)

@ -239,7 +246,7 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    def test_call_pil(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PIL images
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
            for images in image_inputs:
@ -261,7 +268,7 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    def test_call_pytorch(self):
        for image_processing_class in self.image_processor_list:
            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
+            image_processing = image_processing_class(**self.image_processor_dict)
            # create random PyTorch tensors
            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)

@ -281,3 +288,73 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
            tuple(encoded_images.shape),
            (self.image_processor_tester.batch_size, *expected_output_image_shape),
        )
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        dummy_image = Image.open(
+            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
+        )
+        dummy_image = dummy_image.resize((100, 150))
+        image_processor_slow = self.image_processing_class(
+            **self.image_processor_dict, resample=PILImageResampling.BICUBIC
+        )
+        image_processor_fast = self.fast_image_processing_class(
+            **self.image_processor_dict, resample=PILImageResampling.BICUBIC
+        )
+
+        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt", return_row_col_info=True)
+        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt", return_row_col_info=True)
+
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+        self._assert_slow_fast_tensors_equivalence(
+            encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
+        )
+        self.assertEqual(encoding_slow.rows, encoding_fast.rows)
+        self.assertEqual(encoding_slow.cols, encoding_fast.cols)
+
+    @require_vision
+    @require_torch
+    def test_slow_fast_equivalence_batched(self):
+        if not self.test_slow_image_processor or not self.test_fast_image_processor:
+            self.skipTest(reason="Skipping slow/fast equivalence test")
+
+        if self.image_processing_class is None or self.fast_image_processing_class is None:
+            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
+
+        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
+            self.skipTest(
+                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
+            )
+
+        dummy_images = self.image_processor_tester.prepare_image_inputs(
+            equal_resolution=False, num_images=5, torchify=True
+        )
# pop some images to have non homogenous batches:
|
||||
indices_to_pop = [i if np.random.random() < 0.5 else None for i in range(len(dummy_images))]
|
||||
for i in indices_to_pop:
|
||||
if i is not None:
|
||||
dummy_images[i].pop()
|
||||
|
||||
image_processor_slow = self.image_processing_class(
|
||||
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
|
||||
)
|
||||
image_processor_fast = self.fast_image_processing_class(
|
||||
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
|
||||
)
|
||||
|
||||
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt", return_row_col_info=True)
|
||||
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt", return_row_col_info=True)
|
||||
|
||||
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=3e-1)
|
||||
self._assert_slow_fast_tensors_equivalence(
|
||||
encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
|
||||
)
|
||||
self.assertEqual(encoding_slow.rows, encoding_fast.rows)
|
||||
self.assertEqual(encoding_slow.cols, encoding_fast.cols)
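The popping loop above deliberately produces nested image lists of unequal lengths ("non-homogeneous" batches), exercising the shape-grouping path that the fast processors use to batch work (utilities like `group_images_by_shape`). A minimal sketch of the grouping idea, assuming plain (C, H, W) tensors; this is an illustration, not the transformers implementation:

```python
import torch

def group_images_by_shape(images: list[torch.Tensor]):
    # Bucket variable-sized images by (height, width) so that same-shaped
    # ones can be stacked and processed as a single batched tensor.
    grouped: dict[tuple[int, int], list[torch.Tensor]] = {}
    index = []  # (shape, position within its group), one entry per input
    for image in images:
        shape = tuple(image.shape[-2:])
        grouped.setdefault(shape, []).append(image)
        index.append((shape, len(grouped[shape]) - 1))
    stacked = {shape: torch.stack(group) for shape, group in grouped.items()}
    return stacked, index

def reorder_images(processed: dict, index) -> list[torch.Tensor]:
    # Undo the grouping so outputs line up with the original input order.
    return [processed[shape][position] for shape, position in index]
```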

@ -197,7 +197,7 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

encoded_slow = image_processor_slow(image_inputs, return_tensors="pt").pixel_values
encoded_fast = image_processor_fast(image_inputs, return_tensors="pt").pixel_values
encoded_slow = image_processor_slow(image_inputs, return_tensors="pt")
encoded_fast = image_processor_fast(image_inputs, return_tensors="pt")

self.assertTrue(torch.allclose(encoded_slow, encoded_fast, atol=1e-1))
self._assert_slow_fast_tensors_equivalence(encoded_slow.pixel_values, encoded_fast.pixel_values)

@ -312,10 +312,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):

encoding_slow = image_processor_slow(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

def test_slow_fast_equivalence_batched(self):
# this only checks on equal resolution, since the slow processor doesn't work otherwise
@ -338,10 +335,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")

self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

@slow
@require_torch_accelerator

@ -162,6 +162,10 @@ class ImageProcessingTestMixin:

self.image_processor_list = image_processor_list

def _assert_slow_fast_tensors_equivalence(self, slow_tensor, fast_tensor, atol=1e-1, rtol=1e-3, mean_atol=5e-3):
torch.testing.assert_close(slow_tensor, fast_tensor, atol=atol, rtol=rtol)
self.assertLessEqual(torch.mean(torch.abs(slow_tensor - fast_tensor)).item(), mean_atol)
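The new shared helper checks equivalence at two levels: `torch.testing.assert_close` enforces the elementwise bound |slow - fast| <= atol + rtol * |fast|, while the mean-absolute-difference check catches systematic drift that a generous per-element tolerance would let through. A standalone illustration (not part of this diff):

```python
import torch

slow = torch.zeros(1000)
fast = torch.full((1000,), 0.09)  # every element off by the same 0.09

# Passes the elementwise check, since 0.09 <= atol (1e-1) + rtol * |fast| ...
torch.testing.assert_close(slow, fast, atol=1e-1, rtol=1e-3)
# ... but the mean check flags the systematic offset (0.09 > 5e-3).
assert torch.mean(torch.abs(slow - fast)).item() <= 5e-3, "mean drift too large"
```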

@require_vision
@require_torch
def test_slow_fast_equivalence(self):
@ -179,10 +183,7 @@ class ImageProcessingTestMixin:

encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-3)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 5e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

@require_vision
@require_torch
@ -205,10 +206,7 @@ class ImageProcessingTestMixin:
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-3)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 5e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

@require_vision
@require_torch
@ -577,8 +575,10 @@ class ImageProcessingTestMixin:

image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")

torch.testing.assert_close(output_eager.pixel_values, output_compiled.pixel_values, rtol=1e-4, atol=1e-4)
print(output_eager.pixel_values.dtype, output_compiled.pixel_values.dtype)
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
)
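The compile test reuses the same helper with much tighter tolerances (atol=1e-4, mean_atol=1e-5), since eager and `torch.compile`d runs of identical kernels should agree almost exactly. A self-contained sketch of the pattern; the pipeline function is a hypothetical stand-in, not taken from this diff:

```python
import torch

def normalize(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a fast image processor's tensor pipeline.
    return (x * 2.0 + 1.0) / 3.0

x = torch.rand(3, 224, 224)
output_eager = normalize(x)
compiled = torch.compile(normalize, mode="reduce-overhead")
output_compiled = compiled(x)

# Eager and compiled outputs should match within tight tolerances.
torch.testing.assert_close(output_eager, output_compiled, rtol=1e-4, atol=1e-4)
```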
class AnnotationFormatTestMixin: