Refactoring of ImageProcessorFast (#35069)

* add init and base image processing functions

* add add_fast_image_processor to transformers-cli

* add working fast image processor clip

* add fast image processor to doc, working tests

* remove "to be implemented" SigLip

* fix unprotected import

* fix unprotected vision import

* update ViTImageProcessorFast

* increase threshold for slow/fast equivalence

* add fast img blip

* add fast class in tests with cli

* improve cli

* add fast image processor convnext

* add LlavaPatchingMixin and fast image processor for llava_next and llava_onevision

* add device kwarg to ImagesKwargs for fast processing on cuda (see the usage sketch after this list)

* cleanup

* fix unprotected import

* group images by sizes and add batch processing

* Add batch equivalence tests, skip when center_crop is used

* cleanup

* update init and cli

* fix-copies

* refactor convnext, cleanup base

* fix

* remove patching mixins, add piped torchvision transforms for ViT

* fix unbatched processing

* fix f strings

* protect imports

* change llava onevision to class transforms (test)

* fix convnext

* improve formatting (following Pavel review)

* fix handling device arg

* improve cli

* fix

* fix inits

* Add distinction between preprocess and _preprocess, and support for arbitrary kwargs through valid_extra_kwargs

* uniformize qwen2_vl fast

* fix docstrings

* add fast image processor for llava

* remove min_pixels max_pixels from accepted size

* nit

* nit

* refactor fast image processors docstrings

* cleanup and remove fast class transforms

* update add fast image processor transformers cli

* cleanup docstring

* uniformize pixtral fast and make _process_image explicit

* fix prepare image structure llava next/onevision

* Use typed kwargs instead of explicit args

* nit fix import Unpack

* clearly separate pops and gets in base preprocess. Use explicit typed kwargs

* make qwen2_vl preprocess arguments hashable
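
A quick usage sketch of what the changes above enable (the checkpoint name is only an example, the `device` kwarg is the one introduced in this PR, and the output shape depends on the model's fast processor defaults):

import torch
from PIL import Image
from transformers import AutoImageProcessor

# use_fast=True selects the torchvision-backed class, e.g. CLIPImageProcessorFast,
# when torchvision is installed.
processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

image = Image.new("RGB", (640, 480))
# Inputs are converted to torch tensors and processed on the requested device.
inputs = processor(images=image, return_tensors="pt", device="cuda" if torch.cuda.is_available() else "cpu")
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224]) for CLIP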
Yoni Gozlan 2025-02-04 17:52:31 -05:00 committed by GitHub
parent 8d73a38606
commit fa56dcc2ab
66 changed files with 4047 additions and 2244 deletions

View File

@ -61,6 +61,11 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
[[autodoc]] BlipImageProcessor
- preprocess
## BlipImageProcessorFast
[[autodoc]] BlipImageProcessorFast
- preprocess
<frameworkcontent>
<pt>

View File

@ -251,6 +251,11 @@ The resource should ideally demonstrate something new instead of duplicating an
[[autodoc]] CLIPImageProcessor
- preprocess
## CLIPImageProcessorFast
[[autodoc]] CLIPImageProcessorFast
- preprocess
## CLIPFeatureExtractor
[[autodoc]] CLIPFeatureExtractor

View File

@ -64,6 +64,11 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] ConvNextImageProcessor
- preprocess
## ConvNextImageProcessorFast
[[autodoc]] ConvNextImageProcessorFast
- preprocess
<frameworkcontent>
<pt>

View File

@ -125,6 +125,11 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] DeiTImageProcessor
- preprocess
## DeiTImageProcessorFast
[[autodoc]] DeiTImageProcessorFast
- preprocess
<frameworkcontent>
<pt>

View File

@ -195,6 +195,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlavaImageProcessor
- preprocess
## LlavaImageProcessorFast
[[autodoc]] LlavaImageProcessorFast
- preprocess
## LlavaProcessor
[[autodoc]] LlavaProcessor

View File

@ -288,6 +288,11 @@ model = AutoModelForImageTextToText.from_pretrained(
[[autodoc]] LlavaNextImageProcessor
- preprocess
## LlavaNextImageProcessorFast
[[autodoc]] LlavaNextImageProcessorFast
- preprocess
## LlavaNextProcessor
[[autodoc]] LlavaNextProcessor

View File

@ -100,8 +100,8 @@ import torch
from PIL import Image
import requests
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
model = LlavaOnevisionForConditionalGeneration.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda:0")
# prepare image and text prompt, using the appropriate prompt template
@ -298,8 +298,8 @@ First make sure to install flash-attn. Refer to the [original repository of Flas
from transformers import LlavaOnevisionForConditionalGeneration
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
use_flash_attention_2=True
).to(0)
@ -318,6 +318,11 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
[[autodoc]] LlavaOnevisionImageProcessor
## LlavaOnevisionImageProcessorFast
[[autodoc]] LlavaOnevisionImageProcessorFast
- preprocess
## LlavaOnevisionVideoProcessor
[[autodoc]] LlavaOnevisionVideoProcessor

View File

@ -214,6 +214,11 @@ Below is an expected speedup diagram that compares inference time between the na
[[autodoc]] SiglipImageProcessor
- preprocess
## SiglipImageProcessorFast
[[autodoc]] SiglipImageProcessorFast
- preprocess
## SiglipProcessor
[[autodoc]] SiglipProcessor

View File

@ -61,6 +61,11 @@ BLIP can perform a variety of multimodal tasks such as the following
[[autodoc]] BlipImageProcessor
- preprocess
## BlipImageProcessorFast
[[autodoc]] BlipImageProcessorFast
- preprocess
<frameworkcontent>
<pt>

View File

@ -133,6 +133,11 @@ Official Hugging Face and community resources to help you get started with CLIP
[[autodoc]] CLIPImageProcessor
- preprocess
## CLIPImageProcessorFast
[[autodoc]] CLIPImageProcessorFast
- preprocess
## CLIPFeatureExtractor
[[autodoc]] CLIPFeatureExtractor

View File

@ -64,6 +64,11 @@ Official Hugging Face and community resources to help you get started with ConvNeXT
[[autodoc]] ConvNextImageProcessor
- preprocess
## ConvNextImageProcessorFast
[[autodoc]] ConvNextImageProcessorFast
- preprocess
<frameworkcontent>
<pt>

View File

@ -98,6 +98,11 @@ Official Hugging Face and community resources to help you get started with DeiT
[[autodoc]] DeiTImageProcessor
- preprocess
## DeiTImageProcessorFast
[[autodoc]] DeiTImageProcessorFast
- preprocess
<frameworkcontent>
<pt>

View File

@ -452,10 +452,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
return model_inputs
def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of=None,
mean_resizing=True
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing=True
) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

View File

@ -70,10 +70,7 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
return (embeddings,) + vlm_outputs
def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of=None,
mean_resizing=True
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing=True
) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

View File

@ -1308,11 +1308,19 @@ except OptionalDependencyNotAvailable:
]
else:
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["models.blip"].append("BlipImageProcessorFast")
_import_structure["models.clip"].append("CLIPImageProcessorFast")
_import_structure["models.convnext"].append("ConvNextImageProcessorFast")
_import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
_import_structure["models.deit"].append("DeiTImageProcessorFast")
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
_import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast")
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
_import_structure["models.siglip"].append("SiglipImageProcessorFast")
_import_structure["models.vit"].append("ViTImageProcessorFast")
try:
@ -6442,11 +6450,19 @@ if TYPE_CHECKING:
from .utils.dummy_torchvision_objects import *
else:
from .image_processing_utils_fast import BaseImageProcessorFast
from .models.blip import BlipImageProcessorFast
from .models.clip import CLIPImageProcessorFast
from .models.convnext import ConvNextImageProcessorFast
from .models.deformable_detr import DeformableDetrImageProcessorFast
from .models.deit import DeiTImageProcessorFast
from .models.detr import DetrImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
from .models.pixtral import PixtralImageProcessorFast
from .models.qwen2_vl import Qwen2VLImageProcessorFast
from .models.rt_detr import RTDetrImageProcessorFast
from .models.siglip import SiglipImageProcessorFast
from .models.vit import ViTImageProcessorFast
try:

View File

@ -0,0 +1,655 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from argparse import ArgumentParser, Namespace
from datetime import date
from pathlib import Path
from ..utils import logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CURRENT_YEAR = date.today().year
TRANSFORMERS_PATH = Path(__file__).parent.parent
REPO_PATH = TRANSFORMERS_PATH.parent.parent
def add_import_structure_entry_init(content: str, fast_image_processor_name: str, model_name: str):
"""
Add an entry to the `_import_structure` dictionary in the `__init__.py` file of the transformers package.
"""
# Step 1: Find the block
block_regex = re.compile(
r"if not is_torchvision_available\(\):.*?else:\s*(\n(?P<indent>\s+)_import_structure\[.*?\].*?\n(?:\s*(?P=indent)_import_structure\[.*?\].*?\n)*)",
re.DOTALL,
)
match = block_regex.search(content)
if not match:
raise ValueError("Couldn't find the '_import_structure' block.")
# Capture the block content and indentation
block_content = match.group(1)
indent = match.group("indent")
# Step 2: Parse existing entries
lines = block_content.strip().split("\n")
entries = []
import_structure_header = indent + lines[0]
entries = lines[1:]
# Add the new entry, maintaining alphabetical order
new_entry = f'{indent}_import_structure["models.{model_name}"].append("{fast_image_processor_name}")'
if new_entry not in entries:
entries.append(new_entry)
entries.sort()
entries = [import_structure_header] + entries
# Step 3: Reconstruct the block
updated_block = "\n".join(entry for entry in entries)
# Replace the original block in the content
updated_content = content[: match.start(1)] + "\n" + updated_block + "\n" + content[match.end(1) :]
return updated_content
def add_import_statement_init(content: str, fast_image_processor_name: str, model_name: str):
"""
Add an import statement to the `__init__.py` file of the transformers package.
"""
# Step 1: Find the block
block_regex = re.compile(
r"if not is_torchvision_available\(\):\s+raise OptionalDependencyNotAvailable\(\)\s+except OptionalDependencyNotAvailable:\s+from \.utils\.dummy_torchvision_objects import \*\s+else:(?P<else_block>\s*(\n\s*from .+ import .*\n)+)(?=\s*try:\s+if not \(is_torchvision_available\(\) and is_timm_available\(\)\):)",
re.DOTALL,
)
match = block_regex.search(content)
if not match:
raise ValueError("Couldn't find the import statement block.")
block_content = match.group("else_block")  # The captured import block
# Step 2: Parse existing entries
lines = block_content.strip().split("\n")
entries = []
indent = " " * (len(lines[1]) - len(lines[1].lstrip()))
import_structure_header = indent + lines[0]
entries = lines[1:]
# Add the new entry, maintaining alphabetical order
new_entry = f"{indent}from .models.{model_name} import {fast_image_processor_name}"
if new_entry not in entries:
entries.append(new_entry)
entries.sort()
entries = [import_structure_header] + entries
# Step 3: Reconstruct the block
updated_block = "\n".join(entry for entry in entries)
# Replace the original block in the content
updated_content = (
content[: match.start("else_block")] + "\n" + updated_block + "\n\n" + content[match.end("else_block") :]
)
return updated_content
def add_fast_image_processor_to_main_init(fast_image_processor_name: str, model_name: str):
"""
Add the fast image processor to the main __init__.py file of the transformers package.
"""
with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
content = f.read()
# add _import_structure entry
content = add_import_structure_entry_init(content, fast_image_processor_name, model_name)
# add import statement
content = add_import_statement_init(content, fast_image_processor_name, model_name)
# write the updated content
with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
f.write(content)
def add_fast_image_processor_to_model_init(
fast_image_processing_module_file: str, fast_image_processor_name, model_name: str
):
"""
Add the fast image processor to the __init__.py file of the model.
"""
with open(TRANSFORMERS_PATH / "models" / model_name / "__init__.py", "r", encoding="utf-8") as f:
content = f.read()
fast_image_processing_module_file = fast_image_processing_module_file.split(os.sep)[-1].replace(".py", "")
if "import *" in content:
# we have an init file in the updated format
# get the indented block after if TYPE_CHECKING: and before else:, append the new import, sort the imports and write the updated content
# Step 1: Find the block
block_regex = re.compile(
r"if TYPE_CHECKING:\n(?P<if_block>.*?)(?=\s*else:)",
re.DOTALL,
)
match = block_regex.search(content)
if not match:
raise ValueError("Couldn't find the 'if TYPE_CHECKING' block.")
block_content = match.group("if_block") # The captured import block
# Step 2: Parse existing entries
entries = block_content.split("\n")
indent = " " * (len(entries[0]) - len(entries[0].lstrip()))
new_entry = f"{indent}from .{fast_image_processing_module_file} import *"
if new_entry not in entries:
entries.append(new_entry)
entries.sort()
updated_block = "\n".join(entry for entry in entries)
# Replace the original block in the content
updated_content = content[: match.start("if_block")] + updated_block + content[match.end("if_block") :]
else:
# we have an init file in the old format
# add "is_torchvision_available" import to from ...utils import (
# Regex to match import statements from transformers.utils
pattern = r"""
from\s+\.\.\.utils\s+import\s+
(?: # Non-capturing group for either:
([\w, ]+) # 1. Single-line imports (e.g., 'a, b')
| # OR
\((.*?)\) # 2. Multi-line imports (e.g., '(a, ... b)')
)
"""
regex = re.compile(pattern, re.VERBOSE | re.DOTALL)
def replacement_function(match):
# Extract existing imports
imports = (match.group(1) or match.group(2)).split(",")
imports = imports[:-1] if imports[-1] == "\n" else imports
imports = [imp.strip() for imp in imports]
# Add the new import if not already present
if "is_torchvision_available" not in imports:
imports.append("is_torchvision_available")
imports.sort()
# Convert to multi-line import in all cases
updated_imports = "(\n " + ",\n ".join(imports) + ",\n)"
return f"from ...utils import {updated_imports}"
# Replace all matches in the file content
updated_content = regex.sub(replacement_function, content)
vision_import_structure_block = f' _import_structure["{fast_image_processing_module_file[:-5]}"] = ["{fast_image_processor_name[:-4]}"]\n'
added_import_structure_block = (
"try:\n if not is_torchvision_available():\n"
" raise OptionalDependencyNotAvailable()\n"
"except OptionalDependencyNotAvailable:\n"
" pass\n"
"else:\n"
f' _import_structure["{fast_image_processing_module_file}"] = ["{fast_image_processor_name}"]\n'
)
if vision_import_structure_block not in updated_content:
raise ValueError("Couldn't find the 'vision _import_structure block' block.")
if added_import_structure_block not in updated_content:
updated_content = updated_content.replace(
vision_import_structure_block, vision_import_structure_block + "\n" + added_import_structure_block
)
vision_import_statement_block = (
f" from .{fast_image_processing_module_file[:-5]} import {fast_image_processor_name[:-4]}\n"
)
added_import_statement_block = (
" try:\n if not is_torchvision_available():\n"
" raise OptionalDependencyNotAvailable()\n"
" except OptionalDependencyNotAvailable:\n"
" pass\n"
" else:\n"
f" from .{fast_image_processing_module_file} import {fast_image_processor_name}\n"
)
if vision_import_statement_block not in updated_content:
raise ValueError("Couldn't find the 'vision _import_structure block' block.")
if added_import_statement_block not in updated_content:
updated_content = updated_content.replace(
vision_import_statement_block, vision_import_statement_block + "\n" + added_import_statement_block
)
# write the updated content
with open(TRANSFORMERS_PATH / "models" / model_name / "__init__.py", "w", encoding="utf-8") as f:
f.write(updated_content)
def add_fast_image_processor_to_auto(image_processor_name: str, fast_image_processor_name: str):
"""
Add the fast image processor to the auto module.
"""
with open(TRANSFORMERS_PATH / "models" / "auto" / "image_processing_auto.py", "r", encoding="utf-8") as f:
content = f.read()
# get all lines containing the image processor name
updated_content = content.replace(
f'("{image_processor_name}",)', f'("{image_processor_name}", "{fast_image_processor_name}")'
)
# write the updated content
with open(TRANSFORMERS_PATH / "models" / "auto" / "image_processing_auto.py", "w", encoding="utf-8") as f:
f.write(updated_content)
def add_fast_image_processor_to_dummy(fast_image_processor_name: str):
"""
Add the fast image processor to the dummy torchvision objects file.
"""
dummy_torchvision_objects_file = TRANSFORMERS_PATH / "utils" / "dummy_torchvision_objects.py"
with open(dummy_torchvision_objects_file, "r", encoding="utf-8") as f:
content = f.read()
# regex to find objects starting with "class " and ending with "ImageProcessorFast", including "ImageProcessorFast" in the match
image_processor_names = re.findall(r"class (\w*ImageProcessorFast)", content)
image_processor_names.append(fast_image_processor_name)
image_processor_names.sort()
index_new = image_processor_names.index(fast_image_processor_name)
new_dummy_object = (
f"class {fast_image_processor_name}(metaclass=DummyObject):\n"
' _backends = ["torchvision"]\n\n'
" def __init__(self, *args, **kwargs):\n"
' requires_backends(self, ["torchvision"])\n'
)
if new_dummy_object not in content:
if index_new != len(image_processor_names) - 1:
# add the dummy object just before the next ImageProcessorFast
first_line = f"class {image_processor_names[index_new+1]}(metaclass=DummyObject):"
updated_content = content.replace(first_line, new_dummy_object + "\n\n" + first_line)
else:
# add the dummy object at the very end
updated_content = content + "\n\n" + new_dummy_object
# write the updated content
with open(dummy_torchvision_objects_file, "w", encoding="utf-8") as f:
f.write(updated_content)
def add_fast_image_processor_to_doc(fast_image_processor_name: str, model_name: str):
"""
Add the fast image processor to the model's doc file.
"""
doc_source = REPO_PATH / "docs" / "source"
# find the doc files
doc_files = list(doc_source.glob(f"*/model_doc/{model_name}.md"))
if not doc_files:
# try again with "-"
doc_files = list(doc_source.glob(f"*/model_doc/{model_name.replace('_', '-')}.md"))
if not doc_files:
raise ValueError(f"No doc files found for {model_name}")
base_doc_string = (
f"## {fast_image_processor_name[:-4]}\n\n" f"[[autodoc]] {fast_image_processor_name[:-4]}\n" " - preprocess"
)
fast_doc_string = (
f"## {fast_image_processor_name}\n\n" f"[[autodoc]] {fast_image_processor_name}\n" " - preprocess"
)
for doc_file in doc_files:
with open(doc_file, "r", encoding="utf-8") as f:
content = f.read()
if fast_doc_string not in content:
# add the fast image processor to the doc
updated_content = content.replace(
base_doc_string,
base_doc_string + "\n\n" + fast_doc_string,
)
# write the updated content
with open(doc_file, "w", encoding="utf-8") as f:
f.write(updated_content)
def add_fast_image_processor_to_tests(fast_image_processor_name: str, model_name: str):
"""
Add the fast image processor to the image processing tests.
"""
tests_path = REPO_PATH / "tests" / "models" / model_name
test_file = tests_path / f"test_image_processing_{model_name}.py"
if not os.path.exists(test_file):
logger.warning(f"No test file found for {model_name}. Skipping.")
return
with open(test_file, "r", encoding="utf-8") as f:
content = f.read()
# add is_torchvision_available import to the imports
# Regex to match import statements from transformers.utils
pattern = r"""
from\s+transformers\.utils\s+import\s+
(?: # Non-capturing group for either:
([\w, ]+) # 1. Single-line imports (e.g., 'a, b')
| # OR
\((.*?)\) # 2. Multi-line imports (e.g., '(a, ... b)')
)
"""
regex = re.compile(pattern, re.VERBOSE | re.DOTALL)
def replacement_function(match):
# Extract existing imports
existing_imports = (match.group(1) or match.group(2)).split(",")
existing_imports = existing_imports[:-1] if existing_imports[-1] == "\n" else existing_imports
existing_imports = [imp.strip() for imp in existing_imports]
# Add the new import if not already present
if "is_torchvision_available" not in existing_imports:
existing_imports.append("is_torchvision_available")
existing_imports.sort()
# Rebuild the import statement
if match.group(1): # Single-line import
updated_imports = ", ".join(existing_imports)
else: # Multi-line import
updated_imports = "(\n " + ",\n ".join(existing_imports) + ",\n)"
return f"from transformers.utils import {updated_imports}"
# Replace all matches in the file content
updated_content = regex.sub(replacement_function, content)
# add the fast image processor to the imports
base_import_string = f" from transformers import {fast_image_processor_name[:-4]}"
fast_import_string = (
" if is_torchvision_available():\n" f" from transformers import {fast_image_processor_name}"
)
if fast_import_string not in updated_content:
updated_content = updated_content.replace(base_import_string, base_import_string + "\n\n" + fast_import_string)
# get line starting with " image_processing_class = " and add a line after it starting with " fast_image_processing_class = "
image_processing_class_line = re.search(r" image_processing_class = .*", updated_content)
if not image_processing_class_line:
logger.warning(f"Couldn't find the 'image_processing_class' line in {test_file}. Skipping.")
return
fast_image_processing_class_line = (
f" fast_image_processing_class = {fast_image_processor_name} if is_torchvision_available() else None"
)
if " fast_image_processing_class = " not in updated_content:
updated_content = updated_content.replace(
image_processing_class_line.group(0),
image_processing_class_line.group(0) + "\n" + fast_image_processing_class_line,
)
# write the updated content
with open(test_file, "w", encoding="utf-8") as f:
f.write(updated_content)
def get_fast_image_processing_content_header(content: str) -> str:
"""
Get the header of the slow image processor file.
"""
# get all lines before and including the line containing """Image processor
content_header = re.search(r"^(.*?\n)*?\"\"\"Image processor.*", content)
content_header = content_header.group(0)
content_header = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content_header)
content_header = content_header.replace("Image processor", "Fast Image processor")
return content_header
def write_default_fast_image_processor_file(
fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str
):
"""
Write a default fast image processor file. Used when encountering a problem while parsing the slow image processor file.
"""
imports = "\n\nfrom ...image_processing_utils_fast import BaseImageProcessorFast\n\n\n"
content_header = get_fast_image_processing_content_header(content_base_file)
content_base_file = (
f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
" # To be implemented\n"
" resample = None\n"
" image_mean = None\n"
" image_std = None\n"
" size = None\n"
" default_to_square = None\n"
" crop_size = None\n"
" do_resize = None\n"
" do_center_crop = None\n"
" do_rescale = None\n"
" do_normalize = None\n"
" do_convert_rgb = None\n\n\n"
f'__all__ = ["{fast_image_processor_name}"]\n'
)
content = content_header + imports + content_base_file
with open(fast_image_processing_module_file, "w", encoding="utf-8") as f:
f.write(content)
def add_fast_image_processor_file(
fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str
):
"""
Add the fast image processor file to the model's folder.
"""
# if the file already exists, do nothing
if os.path.exists(fast_image_processing_module_file):
print(f"{fast_image_processing_module_file} already exists. Skipping.")
return
regex = rf"class {fast_image_processor_name[:-4]}.*?(\n\S|$)"
match = re.search(regex, content_base_file, re.DOTALL)
if not match:
print(f"Couldn't find the {fast_image_processor_name[:-4]} class in {fast_image_processing_module_file}")
print("Creating a new file with the default content.")
return write_default_fast_image_processor_file(
fast_image_processing_module_file, fast_image_processor_name, content_base_file
)
# Exclude the last unindented line
slow_class_content = match.group(0).rstrip()
# get default args:
# find the __init__ block which start with def __init__ and ends with def
match = re.search(r"def __init__.*?def ", slow_class_content, re.DOTALL)
if not match:
print(
f"Couldn't find the __init__ block for {fast_image_processor_name[:-4]} in {fast_image_processing_module_file}"
)
print("Creating a new file with the default content.")
return write_default_fast_image_processor_file(
fast_image_processing_module_file, fast_image_processor_name, content_base_file
)
init = match.group(0)
init_signature_block = init.split(")")[0]
arg_names = init_signature_block.split(":")
arg_names = [arg_name.split("\n")[-1].strip() for arg_name in arg_names]
# get the default values
default_args = re.findall(r"= (.*?)(?:,|\))", init_signature_block)
# build default args dict
default_args_dict = dict(zip(arg_names, default_args))
pattern_default_size = r"size = size if size is not None else\s+(.*)"
match_default_size = re.findall(pattern_default_size, init)
default_args_dict["size"] = match_default_size[0] if match_default_size else None
pattern_default_crop_size = r"crop_size = crop_size if crop_size is not None else\s+(.*)"
match_default_crop_size = re.findall(pattern_default_crop_size, init)
default_args_dict["crop_size"] = match_default_crop_size[0] if match_default_crop_size else None
pattern_default_image_mean = r"self.image_mean = image_mean if image_mean is not None else\s+(.*)"
match_default_image_mean = re.findall(pattern_default_image_mean, init)
default_args_dict["image_mean"] = match_default_image_mean[0] if match_default_image_mean else None
pattern_default_image_std = r"self.image_std = image_std if image_std is not None else\s+(.*)"
match_default_image_std = re.findall(pattern_default_image_std, init)
default_args_dict["image_std"] = match_default_image_std[0] if match_default_image_std else None
default_args_dict["default_to_square"] = False if "(size, default_to_square=False" in init else None
content_header = get_fast_image_processing_content_header(content_base_file)
content_base_file = (
f"@add_start_docstrings(\n"
f' "Constructs a fast {fast_image_processor_name.replace("ImageProcessorFast", "")} image processor.",\n'
f" BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,\n)\n"
f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
" # This generated class can be used as a starting point for the fast image processor.\n"
" # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"
" # only the default values should be set in the class.\n"
" # If the image processor requires more complex augmentations, methods from BaseImageProcessorFast can be overridden.\n"
" # In most cases, only the `_preprocess` method should be overridden.\n\n"
" # For an example of a fast image processor requiring more complex augmentations, see `LlavaNextImageProcessorFast`.\n\n"
" # Default values should be checked against the slow image processor\n"
" # None values left after checking can be removed\n"
f' resample = {default_args_dict.get("resample")}\n'
f' image_mean = {default_args_dict.get("image_mean")}\n'
f' image_std = {default_args_dict.get("image_std")}\n'
f' size = {default_args_dict.get("size")}\n'
f' default_to_square = {default_args_dict.get("default_to_square")}\n'
f' crop_size = {default_args_dict.get("crop_size")}\n'
f' do_resize = {default_args_dict.get("do_resize")}\n'
f' do_center_crop = {default_args_dict.get("do_center_crop")}\n'
f' do_rescale = {default_args_dict.get("do_rescale")}\n'
f' do_normalize = {default_args_dict.get("do_normalize")}\n'
f' do_convert_rgb = {default_args_dict.get("do_convert_rgb")}\n\n\n'
f'__all__ = ["{fast_image_processor_name}"]\n'
)
imports = (
"\n\nfrom ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast\n"
)
image_utils_imports = []
if default_args_dict.get("resample") is not None and "PILImageResampling" in default_args_dict.get("resample"):
image_utils_imports.append("PILImageResampling")
if default_args_dict.get("image_mean") is not None and not any(
char.isdigit() for char in default_args_dict.get("image_mean")
):
image_utils_imports.append(default_args_dict.get("image_mean"))
if default_args_dict.get("image_std") is not None and not any(
char.isdigit() for char in default_args_dict.get("image_std")
):
image_utils_imports.append(default_args_dict.get("image_std"))
if image_utils_imports:
# sort imports
image_utils_imports.sort()
imports += f"from ...image_utils import {', '.join(image_utils_imports)}\n"
imports += "from ...utils import add_start_docstrings\n"
content = content_header + imports + "\n\n" + content_base_file
with open(fast_image_processing_module_file, "w", encoding="utf-8") as f:
f.write(content)
def add_fast_image_processor(model_name: str):
"""
Add the necessary references to the fast image processor in the transformers package,
and create the fast image processor file in the model's folder.
"""
model_module = TRANSFORMERS_PATH / "models" / model_name
image_processing_module_file = list(model_module.glob("image_processing*.py"))
if not image_processing_module_file:
raise ValueError(f"No image processing module found in {model_module}")
elif len(image_processing_module_file) > 1:
for file_name in image_processing_module_file:
if not str(file_name).endswith("_fast.py"):
image_processing_module_file = str(file_name)
break
else:
image_processing_module_file = str(image_processing_module_file[0])
with open(image_processing_module_file, "r", encoding="utf-8") as f:
content_base_file = f.read()
# regex to find object starting with "class " and ending with "ImageProcessor", including "ImageProcessor" in the match
image_processor_name = re.findall(r"class (\w*ImageProcessor)", content_base_file)
if not image_processor_name:
raise ValueError(f"No ImageProcessor class found in {image_processing_module_file}")
elif len(image_processor_name) > 1:
raise ValueError(f"Multiple ImageProcessor classes found in {image_processing_module_file}")
image_processor_name = image_processor_name[0]
fast_image_processor_name = image_processor_name + "Fast"
fast_image_processing_module_file = image_processing_module_file.replace(".py", "_fast.py")
print(f"Adding {fast_image_processor_name} to {fast_image_processing_module_file}")
add_fast_image_processor_to_main_init(
fast_image_processor_name=fast_image_processor_name,
model_name=model_name,
)
add_fast_image_processor_to_model_init(
fast_image_processing_module_file=fast_image_processing_module_file,
fast_image_processor_name=fast_image_processor_name,
model_name=model_name,
)
add_fast_image_processor_to_auto(
image_processor_name=image_processor_name,
fast_image_processor_name=fast_image_processor_name,
)
add_fast_image_processor_to_dummy(fast_image_processor_name=fast_image_processor_name)
add_fast_image_processor_to_doc(
fast_image_processor_name=fast_image_processor_name,
model_name=model_name,
)
add_fast_image_processor_to_tests(
fast_image_processor_name=fast_image_processor_name,
model_name=model_name,
)
add_fast_image_processor_file(
fast_image_processing_module_file=fast_image_processing_module_file,
fast_image_processor_name=fast_image_processor_name,
content_base_file=content_base_file,
)
def add_new_model_like_command_factory(args: Namespace):
return AddFastImageProcessorCommand(model_name=args.model_name)
class AddFastImageProcessorCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
add_fast_image_processor_parser = parser.add_parser("add-fast-image-processor")
add_fast_image_processor_parser.add_argument(
"--model-name",
type=str,
required=True,
help="The name of the folder containing the model's implementation.",
)
add_fast_image_processor_parser.set_defaults(func=add_new_model_like_command_factory)
def __init__(self, model_name: str, *args):
self.model_name = model_name
def run(self):
add_fast_image_processor(model_name=self.model_name)
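
As a quick orientation, the new subcommand can also be driven from Python (a minimal sketch; the "vit" folder name is just an example, and the module path assumes the file lands under transformers/commands/ as the relative imports suggest). The equivalent shell form registered above is `transformers-cli add-fast-image-processor --model-name vit`.

from transformers.commands.add_fast_image_processor import add_fast_image_processor

# Generates image_processing_vit_fast.py next to the slow processor and wires the new class
# into the main and model __init__.py files, the auto image-processor mapping, the torchvision
# dummy objects, the model doc pages, and the image processing tests.
add_fast_image_processor(model_name="vit")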

View File

@ -15,6 +15,7 @@
from transformers import HfArgumentParser
from .add_fast_image_processor import AddFastImageProcessorCommand
from .add_new_model_like import AddNewModelLikeCommand
from .chat import ChatCommand
from .convert import ConvertCommand
@ -40,6 +41,7 @@ def main():
UserCommands.register_subcommand(commands_parser)
AddNewModelLikeCommand.register_subcommand(commands_parser)
LfsCommands.register_subcommand(commands_parser)
AddFastImageProcessorCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()

View File

@ -13,13 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Iterable, Optional, Union
import numpy as np
from .image_processing_base import BatchFeature, ImageProcessingMixin
from .image_transforms import center_crop, normalize, rescale
from .image_utils import ChannelDimension
from .image_utils import ChannelDimension, get_image_size
from .utils import logging
@ -285,3 +286,23 @@ def select_best_resolution(original_size: tuple, possible_resolutions: list) ->
best_fit = (height, width)
return best_fit
def get_patch_output_size(image, target_resolution, input_data_format):
"""
Given an image and a target resolution, calculate the output size of the image after resizing to fit within the target resolution while preserving the aspect ratio.
"""
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
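
For intuition, the helper above can be exercised directly (a minimal sketch, assuming `get_patch_output_size` is importable from `transformers.image_processing_utils` as this file suggests; the image shape is arbitrary):

import numpy as np
from transformers.image_processing_utils import get_patch_output_size
from transformers.image_utils import ChannelDimension

# Fit a 480x640 image into a 336x336 target: width is the binding edge
# (336/640 = 0.525 < 336/480 = 0.7), so the result is (ceil(480 * 0.525), 336) = (252, 336).
image = np.zeros((3, 480, 640), dtype=np.uint8)
print(get_patch_output_size(image, (336, 336), ChannelDimension.FIRST))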

View File

@ -13,94 +13,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache, partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypedDict, Union
from .image_processing_utils import BaseImageProcessor
from .utils.import_utils import is_torch_available, is_torchvision_available
import numpy as np
from .image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from .image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
get_size_with_aspect_ratio,
group_images_by_shape,
reorder_images,
)
from .image_utils import (
ChannelDimension,
ImageInput,
ImageType,
SizeDict,
get_image_size,
get_image_size_for_max_height_width,
get_image_type,
infer_channel_dimension_format,
make_flat_list_of_images,
validate_fast_preprocess_arguments,
validate_kwargs,
)
from .processing_utils import Unpack
from .utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
logging,
)
if is_torchvision_available():
from torchvision.transforms import Compose
if is_vision_available():
from .image_utils import PILImageResampling
if is_torch_available():
import torch
if is_torchvision_available():
from .image_utils import pil_torch_interpolation_mapping
@dataclass(frozen=True)
class SizeDict:
"""
Hashable dictionary to store image size information.
"""
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
height: int = None
width: int = None
longest_edge: int = None
shortest_edge: int = None
max_height: int = None
max_width: int = None
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
raise KeyError(f"Key {key} not found in SizeDict.")
class BaseImageProcessorFast(BaseImageProcessor):
_transform_params = None
def _build_transforms(self, **kwargs) -> "Compose":
"""
Given the input settings e.g. do_resize, build the image transforms.
"""
raise NotImplementedError
def _validate_params(self, **kwargs) -> None:
for k, v in kwargs.items():
if k not in self._transform_params:
raise ValueError(f"Invalid transform parameter {k}={v}.")
@functools.lru_cache(maxsize=1)
def get_transforms(self, **kwargs) -> "Compose":
self._validate_params(**kwargs)
return self._build_transforms(**kwargs)
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_transform_params", None)
return encoder_dict
def get_image_size_for_max_height_width(
image_size: Tuple[int, int],
max_height: int,
max_width: int,
) -> Tuple[int, int]:
"""
Computes the output image size given the input image size and the maximum allowed height and width, keeping the aspect ratio.
Important: even if image_height < max_height and image_width < max_width, the image will be resized
so that at least one of its edges equals max_height or max_width.
For example:
- input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
- input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
Args:
image_size (`Tuple[int, int]`):
The (height, width) of the image to resize.
max_height (`int`):
The maximum allowed height.
max_width (`int`):
The maximum allowed width.
"""
height, width = image_size
height_scale = max_height / height
width_scale = max_width / width
min_scale = min(height_scale, width_scale)
new_height = int(height * min_scale)
new_width = int(width * min_scale)
return new_height, new_width
logger = logging.get_logger(__name__)
def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
@ -131,3 +101,603 @@ def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]:
_, max_height, max_width = max_across_indices([img.shape for img in images])
return (max_height, max_width)
def divide_to_patches(
image: Union[np.ndarray, "torch.Tensor"], patch_size: int
) -> List[Union[np.ndarray, "torch.Tensor"]]:
"""
Divides an image into patches of a specified size.
Args:
image (`Union[np.ndarray, "torch.Tensor"]`):
The input image.
patch_size (`int`):
The size of each patch.
Returns:
list: A list of `Union[np.ndarray, "torch.Tensor"]` representing the patches.
"""
patches = []
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
patch = image[:, i : i + patch_size, j : j + patch_size]
patches.append(patch)
return patches
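# For example, a (3, 224, 224) image with patch_size=112 yields 4 patches of shape
# (3, 112, 112), ordered row by row: top-left, top-right, bottom-left, bottom-right.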
class DefaultFastImageProcessorInitKwargs(TypedDict, total=False):
do_resize: Optional[bool]
size: Optional[Dict[str, int]]
default_to_square: Optional[bool]
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]]
do_center_crop: Optional[bool]
crop_size: Optional[Dict[str, int]]
do_rescale: Optional[bool]
rescale_factor: Optional[Union[int, float]]
do_normalize: Optional[bool]
image_mean: Optional[Union[float, List[float]]]
image_std: Optional[Union[float, List[float]]]
do_convert_rgb: Optional[bool]
class DefaultFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorInitKwargs):
return_tensors: Optional[Union[str, TensorType]]
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
device: Optional["torch.device"]
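# Model-specific fast processors can extend these typed kwargs with their own keys and point
# `valid_preprocess_kwargs` at the subclass so `preprocess` validates and forwards them.
# Illustrative sketch only (the class name and `do_pad` key are placeholders, not part of this module):
#
#     class ExampleFastImageProcessorKwargs(DefaultFastImageProcessorPreprocessKwargs, total=False):
#         do_pad: Optional[bool]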
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING = r"""
Args:
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `self.size`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
Whether to default to a square image when resizing, if size is an int.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
`preprocess` method.
crop_size (`Dict[str, int]` *optional*, defaults to `self.crop_size`):
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
method.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB."""
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS = r"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the output image after resizing.
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the output image after applying `center_crop`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
device (`torch.device`, *optional*):
The device to process the images on. If unset, the device is inferred from the input images."""
@add_start_docstrings(
"Constructs a fast base image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class BaseImageProcessorFast(BaseImageProcessor):
resample = None
image_mean = None
image_std = None
size = None
default_to_square = True
crop_size = None
do_resize = None
do_center_crop = None
do_rescale = None
rescale_factor = 1 / 255
do_normalize = None
do_convert_rgb = None
model_input_names = ["pixel_values"]
valid_init_kwargs = DefaultFastImageProcessorInitKwargs
valid_preprocess_kwargs = DefaultFastImageProcessorPreprocessKwargs
def __init__(
self,
**kwargs: Unpack[DefaultFastImageProcessorInitKwargs],
) -> None:
super().__init__(**kwargs)
size = kwargs.pop("size", self.size)
self.size = (
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
if size is not None
else None
)
crop_size = kwargs.pop("crop_size", self.crop_size)
self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
for key in self.valid_init_kwargs.__annotations__.keys():
kwarg = kwargs.pop(key, None)
if kwarg is not None:
setattr(self, key, kwarg)
else:
setattr(self, key, getattr(self, key, None))
def resize(
self,
image: "torch.Tensor",
size: SizeDict,
interpolation: "F.InterpolationMode" = None,
**kwargs,
) -> "torch.Tensor":
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`torch.Tensor`):
Image to resize.
size (`SizeDict`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image, e.g. `InterpolationMode.BICUBIC`.
Returns:
`torch.Tensor`: The resized image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if size.shortest_edge and size.longest_edge:
# Resize the image so that the shortest edge or the longest edge is of the given size
# while maintaining the aspect ratio of the original image.
new_size = get_size_with_aspect_ratio(
image.size()[-2:],
size.shortest_edge,
size.longest_edge,
)
elif size.shortest_edge:
new_size = get_resize_output_image_size(
image,
size=size.shortest_edge,
default_to_square=False,
input_data_format=ChannelDimension.FIRST,
)
elif size.max_height and size.max_width:
new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
elif size.height and size.width:
new_size = (size.height, size.width)
else:
raise ValueError(
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
f" {size}."
)
return F.resize(image, new_size, interpolation=interpolation)
def rescale(
self,
image: "torch.Tensor",
scale: float,
**kwargs,
) -> "torch.Tensor":
"""
Rescale an image by a scale factor. image = image * scale.
Args:
image (`torch.Tensor`):
Image to rescale.
scale (`float`):
The scaling factor to rescale pixel values by.
Returns:
`torch.Tensor`: The rescaled image.
"""
return image * scale
def normalize(
self,
image: "torch.Tensor",
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
**kwargs,
) -> "torch.Tensor":
"""
Normalize an image. image = (image - image_mean) / image_std.
Args:
image (`torch.Tensor`):
Image to normalize.
mean (`torch.Tensor`, `float` or `Iterable[float]`):
Image mean to use for normalization.
std (`torch.Tensor`, `float` or `Iterable[float]`):
Image standard deviation to use for normalization.
Returns:
`torch.Tensor`: The normalized image.
"""
return F.normalize(image, mean, std)
def rescale_and_normalize(
self,
images: "torch.Tensor",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
) -> "torch.Tensor":
"""
Rescale and normalize images.
"""
if do_rescale and do_normalize:
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
images = images * rescale_factor
elif do_normalize:
images = self.normalize(images, image_mean, image_std)
return images
def center_crop(
self,
image: "torch.Tensor",
size: Dict[str, int],
**kwargs,
) -> "torch.Tensor":
"""
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
any edge, the image is padded with 0's and then center cropped.
Args:
image (`"torch.Tensor"`):
Image to center crop.
size (`Dict[str, int]`):
Size of the output image.
Returns:
`torch.Tensor`: The center cropped image.
"""
if size.height is None or size.width is None:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return F.center_crop(image, (size["height"], size["width"]))
def convert_to_rgb(
self,
image: ImageInput,
) -> ImageInput:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
Args:
image (ImageInput):
The image to convert.
Returns:
ImageInput: The converted image.
"""
return convert_to_rgb(image)
def _prepare_images_structure(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _process_image(
self,
image: ImageInput,
do_convert_rgb: Optional[bool] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
) -> "torch.Tensor":
image_type = get_image_type(image)
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
if do_convert_rgb:
image = self.convert_to_rgb(image)
if image_type == ImageType.PIL:
image = F.pil_to_tensor(image)
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
image = torch.from_numpy(image).contiguous()
# Infer the channel dimension format if not provided
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
# We force the channel dimension to be first for torch tensors as this is what torchvision expects.
image = image.permute(2, 0, 1).contiguous()
# Now that we have torch tensors, we can move them to the right device
if device is not None:
image = image.to(device)
return image
def _prepare_input_images(
self,
images: ImageInput,
do_convert_rgb: bool = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
) -> List["torch.Tensor"]:
"""
Prepare the input images for processing.
"""
images = self._prepare_images_structure(images)
process_image_fn = partial(
self._process_image,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
with ThreadPoolExecutor() as executor:
processed_images = list(executor.map(process_image_fn, images))
return processed_images
@lru_cache(maxsize=10)
def _prepare_process_arguments(
self,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
device: Optional["torch.device"] = None,
) -> tuple:
"""
Prepare the arguments for the process method.
"""
validate_fast_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
do_center_crop=do_center_crop,
crop_size=crop_size,
resample=resample,
return_tensors=return_tensors,
data_format=data_format,
)
if do_rescale and do_normalize:
# Fused rescale and normalize
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
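# Folding 1 / rescale_factor into the mean and std keeps rescale + normalize a single fused op:
# (x * rescale_factor - mean) / std == (x - mean / rescale_factor) / (std / rescale_factor),
# so `rescale_and_normalize` can call normalize once on the float image without a separate multiplication.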
interpolation = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
return image_mean, image_std, interpolation
@add_start_docstrings(BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS)
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[DefaultFastImageProcessorPreprocessKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_preprocess_kwargs.__annotations__.keys()
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_preprocess_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Pop kwargs that need further processing or won't be used in _preprocess
default_to_square = kwargs.pop("default_to_square")
size = kwargs.pop("size")
crop_size = kwargs.pop("crop_size")
image_mean = kwargs.pop("image_mean")
image_std = kwargs.pop("image_std")
data_format = kwargs.pop("data_format")
resample = kwargs.pop("resample")
# Make hashable for cache
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square)) if size is not None else None
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size")) if crop_size is not None else None
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
image_mean, image_std, interpolation = self._prepare_process_arguments(
size=size,
crop_size=crop_size,
resample=resample,
image_mean=image_mean,
image_std=image_std,
data_format=data_format if data_format is not None else ChannelDimension.FIRST,
device=images[0].device,
do_resize=kwargs.get("do_resize"),
do_center_crop=kwargs.get("do_center_crop"),
do_rescale=kwargs.get("do_rescale"),
rescale_factor=kwargs.get("rescale_factor"),
do_normalize=kwargs.get("do_normalize"),
return_tensors=kwargs.get("return_tensors"),
)
return self._preprocess(
images=images,
size=size,
crop_size=crop_size,
interpolation=interpolation,
image_mean=image_mean,
image_std=image_std,
**kwargs,
)
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
return encoder_dict
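A minimal usage sketch of the fast preprocessing flow above (editorial, not part of the diff; the checkpoint name and sizes are illustrative, and torchvision is assumed to be installed). Images that share a shape are stacked and transformed as one batch before being reordered:
from PIL import Image
from transformers import ViTImageProcessorFast

processor = ViTImageProcessorFast.from_pretrained("google/vit-base-patch16-224")
images = [Image.new("RGB", (640, 480)), Image.new("RGB", (640, 480)), Image.new("RGB", (320, 240))]
# The two 640x480 images are resized/normalized as a single stacked tensor, the third separately.
batch = processor(images=images, return_tensors="pt")
print(batch.pixel_values.shape)  # torch.Size([3, 3, 224, 224]) with the default 224x224 config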
class SemanticSegmentationMixin:
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`MobileNetV2ForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
# if is_torch_tensor(target_sizes):
# target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
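A hedged usage sketch for the mixin above (editorial, not part of the diff; the checkpoint name is given only as an example of a model whose processor exposes this method):
from PIL import Image
from transformers import AutoImageProcessor, MobileNetV2ForSemanticSegmentation

checkpoint = "google/deeplabv3_mobilenet_v2_1.0_513"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = MobileNetV2ForSemanticSegmentation.from_pretrained(checkpoint)
image = Image.new("RGB", (513, 513))
outputs = model(**processor(images=image, return_tensors="pt"))
# Resize the logits back to the original (height, width) and take the per-pixel argmax
maps = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
print(maps[0].shape)  # torch.Size([513, 513])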

View File

@ -15,7 +15,7 @@
import warnings
from math import ceil
from typing import Iterable, List, Optional, Sequence, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
@ -31,8 +31,6 @@ from .utils.import_utils import (
is_flax_available,
is_tf_available,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
requires_backends,
)
@ -52,11 +50,6 @@ if is_tf_available():
if is_flax_available():
import jax.numpy as jnp
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
from torchvision.transforms import functional as F
def to_channel_dimension_format(
image: np.ndarray,
@ -216,6 +209,45 @@ def to_pil_image(
return PIL.Image.fromarray(image, mode=image_mode)
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
"""
Computes the output image size given the input image size and the desired output size.
Args:
image_size (`Tuple[int, int]`):
The input image size.
size (`int`):
The desired output size.
max_size (`int`, *optional*):
The maximum allowed output size.
"""
height, width = image_size
raw_size = None
if max_size is not None:
min_original_size = float(min((height, width)))
max_original_size = float(max((height, width)))
if max_original_size / min_original_size * size > max_size:
raw_size = max_size * min_original_size / max_original_size
size = int(round(raw_size))
if (height <= width and height == size) or (width <= height and width == size):
oh, ow = height, width
elif width < height:
ow = size
if max_size is not None and raw_size is not None:
oh = int(raw_size * height / width)
else:
oh = int(size * height / width)
else:
oh = size
if max_size is not None and raw_size is not None:
ow = int(raw_size * width / height)
else:
ow = int(size * width / height)
return (oh, ow)
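# Worked example (editorial, not in the diff): for image_size=(480, 640), size=800, max_size=1333,
# 640 / 480 * 800 ≈ 1067 <= 1333 so no clamping applies; the shorter edge becomes 800 and the
# function returns (800, 1066), preserving the aspect ratio.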
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
def get_resize_output_image_size(
input_image: np.ndarray,
@ -821,32 +853,37 @@ def _cast_tensor_to_float(x):
return x.float()
class FusedRescaleNormalize:
def group_images_by_shape(
images: List["torch.Tensor"],
) -> Tuple[Dict[Tuple[int, int], List["torch.Tensor"]], Dict[int, Tuple[Tuple[int, int], int]]]:
"""
Rescale and normalize the input image in one step.
Groups images by shape.
Returns a dictionary with the shape as key and a list of images with that shape as value,
and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.
"""
def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
self.mean = torch.tensor(mean) * (1.0 / rescale_factor)
self.std = torch.tensor(std) * (1.0 / rescale_factor)
self.inplace = inplace
def __call__(self, image: "torch.Tensor"):
image = _cast_tensor_to_float(image)
return F.normalize(image, self.mean, self.std, inplace=self.inplace)
grouped_images = {}
grouped_images_index = {}
for i, image in enumerate(images):
shape = image.shape[1:]
if shape not in grouped_images:
grouped_images[shape] = []
grouped_images[shape].append(image)
grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
# stack images with the same shape
grouped_images = {shape: torch.stack(images, dim=0) for shape, images in grouped_images.items()}
return grouped_images, grouped_images_index
class Rescale:
def reorder_images(
processed_images: Dict[Tuple[int, int], "torch.Tensor"], grouped_images_index: Dict[int, Tuple[int, int]]
) -> List["torch.Tensor"]:
"""
Rescale the input image by rescale factor: image *= rescale_factor.
Reconstructs a list of images in the original order.
"""
def __init__(self, rescale_factor: float = 1.0):
self.rescale_factor = rescale_factor
def __call__(self, image: "torch.Tensor"):
image = image * self.rescale_factor
return image
return [
processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
for i in range(len(grouped_images_index))
]
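A small round-trip sketch of the two helpers above (editorial, not part of the diff):
import torch

images = [torch.rand(3, 224, 224), torch.rand(3, 224, 224), torch.rand(3, 336, 336)]
grouped, index = group_images_by_shape(images)
# grouped is keyed by spatial shape: a (2, 3, 224, 224) stack and a (1, 3, 336, 336) stack
restored = reorder_images(grouped, index)
assert all(torch.equal(a, b) for a, b in zip(images, restored))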
class NumpyToTensor:

View File

@ -16,6 +16,7 @@
import base64
import os
from contextlib import redirect_stdout
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
@ -426,6 +427,37 @@ def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> T
raise ValueError(f"Unsupported data format: {channel_dim}")
def get_image_size_for_max_height_width(
image_size: Tuple[int, int],
max_height: int,
max_width: int,
) -> Tuple[int, int]:
"""
Computes the output image size given the input image and the maximum allowed height and width, keeping the aspect ratio.
Important: even if image_height < max_height and image_width < max_width, the image will be resized
so that at least one of the edges equals max_height or max_width.
For example:
- input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
- input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
Args:
image_size (`Tuple[int, int]`):
The size `(height, width)` of the input image.
max_height (`int`):
The maximum allowed height.
max_width (`int`):
The maximum allowed width.
"""
height, width = image_size
height_scale = max_height / height
width_scale = max_width / width
min_scale = min(height_scale, width_scale)
new_height = int(height * min_scale)
new_width = int(width * min_scale)
return new_height, new_width
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
if (
isinstance(annotation, dict)
@ -795,12 +827,16 @@ def validate_fast_preprocess_arguments(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# Extra checks for ImageProcessorFast
if return_tensors != "pt":
if return_tensors is not None and return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
@ -1190,3 +1226,22 @@ def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str])
unused_key_str = ", ".join(unused_keys)
# TODO raise a warning here instead of simply logging?
logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
@dataclass(frozen=True)
class SizeDict:
"""
Hashable dictionary to store image size information.
"""
height: int = None
width: int = None
longest_edge: int = None
shortest_edge: int = None
max_height: int = None
max_width: int = None
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
raise KeyError(f"Key {key} not found in SizeDict.")
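Because the dataclass is frozen, `SizeDict` instances are hashable, which is what allows size arguments to flow through `lru_cache`-decorated helpers such as `_prepare_process_arguments` above. A quick editorial sketch:
size = SizeDict(height=224, width=224)
print(size["height"], size.width)  # 224 224
print(hash(size) == hash(SizeDict(height=224, width=224)))  # True, so it can serve as a cache key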

View File

@ -59,20 +59,20 @@ else:
("aria", ("AriaImageProcessor")),
("beit", ("BeitImageProcessor",)),
("bit", ("BitImageProcessor",)),
("blip", ("BlipImageProcessor",)),
("blip-2", ("BlipImageProcessor",)),
("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
("bridgetower", ("BridgeTowerImageProcessor",)),
("chameleon", ("ChameleonImageProcessor",)),
("chinese_clip", ("ChineseCLIPImageProcessor",)),
("clip", ("CLIPImageProcessor",)),
("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
("conditional_detr", ("ConditionalDetrImageProcessor",)),
("convnext", ("ConvNextImageProcessor",)),
("convnextv2", ("ConvNextImageProcessor",)),
("cvt", ("ConvNextImageProcessor",)),
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("data2vec-vision", ("BeitImageProcessor",)),
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
("deit", ("DeiTImageProcessor",)),
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
("depth_anything", ("DPTImageProcessor",)),
("deta", ("DetaImageProcessor",)),
("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
@ -85,27 +85,27 @@ else:
("flava", ("FlavaImageProcessor",)),
("focalnet", ("BitImageProcessor",)),
("fuyu", ("FuyuImageProcessor",)),
("git", ("CLIPImageProcessor",)),
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glpn", ("GLPNImageProcessor",)),
("got_ocr2", ("GotOcr2ImageProcessor",)),
("grounding-dino", ("GroundingDinoImageProcessor",)),
("groupvit", ("CLIPImageProcessor",)),
("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("hiera", ("BitImageProcessor",)),
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("idefics3", ("Idefics3ImageProcessor",)),
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor",)),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
("kosmos-2", ("CLIPImageProcessor",)),
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
("levit", ("LevitImageProcessor",)),
("llava", ("LlavaImageProcessor",)),
("llava_next", ("LlavaNextImageProcessor",)),
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
("llava_next_video", ("LlavaNextVideoImageProcessor",)),
("llava_onevision", ("LlavaOnevisionImageProcessor",)),
("llava_onevision", ("LlavaOnevisionImageProcessor", "LlavaOnevisionImageProcessorFast")),
("mask2former", ("Mask2FormerImageProcessor",)),
("maskformer", ("MaskFormerImageProcessor",)),
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
@ -119,7 +119,7 @@ else:
("oneformer", ("OneFormerImageProcessor",)),
("owlv2", ("Owlv2ImageProcessor",)),
("owlvit", ("OwlViTImageProcessor",)),
("paligemma", ("SiglipImageProcessor",)),
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
@ -127,13 +127,13 @@ else:
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("regnet", ("ConvNextImageProcessor",)),
("resnet", ("ConvNextImageProcessor",)),
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
("sam", ("SamImageProcessor",)),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),
("siglip", ("SiglipImageProcessor",)),
("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("superglue", "SuperGlueImageProcessor"),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
@ -146,16 +146,16 @@ else:
("tvp", ("TvpImageProcessor",)),
("udop", ("LayoutLMv3ImageProcessor",)),
("upernet", ("SegformerImageProcessor",)),
("van", ("ConvNextImageProcessor",)),
("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("videomae", ("VideoMAEImageProcessor",)),
("vilt", ("ViltImageProcessor",)),
("vipllava", ("CLIPImageProcessor",)),
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vit_hybrid", ("ViTHybridImageProcessor",)),
("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vitmatte", ("VitMatteImageProcessor",)),
("xclip", ("CLIPImageProcessor",)),
("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("yolos", ("YolosImageProcessor",)),
("zoedepth", ("ZoeDepthImageProcessor",)),
]
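With the mapping entries above, the fast variants become reachable through `AutoImageProcessor`. A hedged sketch (editorial; the checkpoint name is illustrative and torchvision must be installed for the fast class to be selected):
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
print(type(processor).__name__)  # CLIPImageProcessorFast when torchvision is available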

View File

@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_blip import *
from .image_processing_blip import *
from .image_processing_blip_fast import *
from .modeling_blip import *
from .modeling_tf_blip import *
from .processing_blip import *

View File

@ -0,0 +1,39 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for BLIP."""
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...utils import add_start_docstrings
@add_start_docstrings(
"Constructs a fast BLIP image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class BlipImageProcessorFast(BaseImageProcessorFast):
# To be checked against the slow image processor
# None values left after checking can be removed
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"height": 384, "width": 384}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
__all__ = ["BlipImageProcessorFast"]

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_clip import *
from .feature_extraction_clip import *
from .image_processing_clip import *
from .image_processing_clip_fast import *
from .modeling_clip import *
from .modeling_flax_clip import *
from .modeling_tf_clip import *

View File

@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for CLIP."""
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...utils import add_start_docstrings
@add_start_docstrings(
"Constructs a fast CLIP image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class CLIPImageProcessorFast(BaseImageProcessorFast):
# To be checked against the slow image processor
# None values left after checking can be removed
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"shortest_edge": 224}
default_to_square = False
crop_size = {"height": 224, "width": 224}
do_resize = True
do_center_crop = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
__all__ = ["CLIPImageProcessorFast"]

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_convnext import *
from .feature_extraction_convnext import *
from .image_processing_convnext import *
from .image_processing_convnext_fast import *
from .modeling_convnext import *
from .modeling_tf_convnext import *
else:

View File

@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for ConvNeXT."""
from typing import Dict, List, Optional, Union
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_transforms import get_resize_output_image_size
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class ConvNextFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
crop_pct: Optional[float]
class ConvNextFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
crop_pct: Optional[float]
@add_start_docstrings(
r"Constructs a fast ConvNeXT image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
crop_pct (`float`, *optional*):
Percentage of the image to crop. Only has an effect if size < 384. Can be
overridden by `crop_pct` in the `preprocess` method.
""",
)
class ConvNextImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"shortest_edge": 384}
default_to_square = False
do_resize = True
do_rescale = True
do_normalize = True
crop_pct = 224 / 256
valid_init_kwargs = ConvNextFastImageProcessorInitKwargs
valid_preprocess_kwargs = ConvNextFastImageProcessorPreprocessKwargs
def __init__(self, **kwargs: Unpack[ConvNextFastImageProcessorInitKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
crop_pct (`float`, *optional*):
Percentage of the image to crop. Only has an effect if size < 384. Can be
overridden by `crop_pct` in the `preprocess` method.
""",
)
def preprocess(
self, images: ImageInput, **kwargs: Unpack[ConvNextFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def resize(
self,
image: "torch.Tensor",
size: Dict[str, int],
crop_pct: float,
interpolation: PILImageResampling = PILImageResampling.BICUBIC,
**kwargs,
) -> "torch.Tensor":
"""
Resize an image.
Args:
image (`torch.Tensor`):
Image to resize.
size (`Dict[str, int]`):
Dictionary of the form `{"shortest_edge": int}`, specifying the size of the output image. If
`size["shortest_edge"]` >= 384 image is resized to `(size["shortest_edge"], size["shortest_edge"])`.
Otherwise, the smaller edge of the image will be matched to `int(size["shortest_edge"] / crop_pct)`,
after which the image is cropped to `(size["shortest_edge"], size["shortest_edge"])`.
crop_pct (`float`):
Percentage of the image to crop. Only has an effect if size < 384.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
Returns:
`torch.Tensor`: Resized image.
"""
if not size.shortest_edge:
raise ValueError(f"Size dictionary must contain 'shortest_edge' key. Got {size.keys()}")
shortest_edge = size["shortest_edge"]
if shortest_edge < 384:
# maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
resize_shortest_edge = int(shortest_edge / crop_pct)
resize_size = get_resize_output_image_size(
image, size=resize_shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST
)
image = F.resize(
image,
resize_size,
interpolation=interpolation,
**kwargs,
)
# then crop to (shortest_edge, shortest_edge)
return F.center_crop(
image,
(shortest_edge, shortest_edge),
**kwargs,
)
else:
# warping (no cropping) when evaluated at 384 or larger
return F.resize(
image,
(shortest_edge, shortest_edge),
interpolation=interpolation,
**kwargs,
)
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: Dict[str, int],
crop_pct: float,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: int,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(
image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation
)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
__all__ = ["ConvNextImageProcessorFast"]

View File

@ -4,13 +4,16 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_deformable_detr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import functools
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
SizeDict,
get_image_size_for_max_height_width,
get_max_height_width,
@ -24,21 +27,17 @@ from ...image_utils import (
AnnotationType,
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
get_image_size,
get_image_type,
infer_channel_dimension_format,
make_list_of_images,
validate_annotations,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
logging,
)
from .image_processing_deformable_detr import get_size_with_aspect_ratio
@ -47,9 +46,6 @@ from .image_processing_deformable_detr import get_size_with_aspect_ratio
if is_torch_available():
import torch
if is_vision_available():
from ...image_utils import pil_torch_interpolation_mapping
if is_torchvision_v2_available():
from torchvision.io import read_image
@ -61,6 +57,24 @@ elif is_torchvision_available():
logger = logging.get_logger(__name__)
class DeformableDetrFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
format: Optional[Union[str, AnnotationFormat]]
do_convert_annotations: Optional[bool]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
class DeformableDetrFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
format: Optional[AnnotationFormat]
annotations: Optional[Dict]
do_convert_annotations: Optional[bool]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
return_segmentation_masks: Optional[bool]
masks_path: Optional[Union[str, pathlib.Path]]
SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
@ -261,44 +275,12 @@ def prepare_coco_panoptic_annotation(
return new_target
class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a fast DeformableDetr image processor.
Args:
@add_start_docstrings(
"Constructs a fast DeformableDetr image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
do_resize (`bool`, *optional*, defaults to `True`):
Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be
overridden by the `do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
in the `preprocess` method. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
less or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
`max_width`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
`preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_annotations (`bool`, *optional*, defaults to `True`):
Controls whether to convert the annotations to the format expected by the DEFORMABLE_DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
@ -312,29 +294,28 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
"""
""",
)
class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
format = AnnotationFormat.COCO_DETECTION
do_resize = True
do_rescale = True
do_normalize = True
do_pad = True
size = {"shortest_edge": 800, "longest_edge": 1333}
default_to_square = False
model_input_names = ["pixel_values", "pixel_mask"]
valid_init_kwargs = DeformableDetrFastImageProcessorInitKwargs
valid_preprocess_kwargs = DeformableDetrFastImageProcessorPreprocessKwargs
def __init__(
self,
format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
do_convert_annotations: Optional[bool] = None,
do_pad: bool = True,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
) -> None:
def __init__(self, **kwargs: Unpack[DeformableDetrFastImageProcessorInitKwargs]) -> None:
if "pad_and_return_pixel_mask" in kwargs:
do_pad = kwargs.pop("pad_and_return_pixel_mask")
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
size = kwargs.pop("size", None)
if "max_size" in kwargs:
logger.warning_once(
"The `max_size` parameter is deprecated and will be removed in v4.26. "
@ -345,46 +326,15 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
max_size = None if size is None else 1333
size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
size = get_size_dict(size, max_size=max_size, default_to_square=False)
self.size = get_size_dict(size, max_size=max_size, default_to_square=False)
# Backwards compatibility
if do_convert_annotations is None:
do_convert_annotations = do_normalize
do_convert_annotations = kwargs.get("do_convert_annotations", None)
do_normalize = kwargs.get("do_normalize", None)
if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
super().__init__(**kwargs)
self.format = format
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.do_convert_annotations = do_convert_annotations
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
self.pad_size = pad_size
self._valid_processor_keys = [
"images",
"annotations",
"return_segmentation_masks",
"masks_path",
"do_resize",
"size",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"do_convert_annotations",
"image_mean",
"image_std",
"do_pad",
"pad_size",
"format",
"return_tensors",
"data_format",
"input_data_format",
]
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
@ -619,187 +569,85 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
return image, pixel_mask, annotation
@functools.lru_cache(maxsize=1)
def _validate_input_arguments(
self,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
do_resize: bool,
size: Dict[str, int],
resample: "PILImageResampling",
data_format: Union[str, ChannelDimension],
return_tensors: Union[TensorType, str],
):
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
if do_resize and None in (size, resample):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and None in (image_mean, image_std):
raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
do_convert_annotations (`bool`, *optional*, defaults to `True`):
Controls whether to convert the annotations to the format expected by the DEFORMABLE_DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
If `pad_size` is provided, the image will be padded to the specified dimensions.
Otherwise, the image will be padded to the maximum height and width of the batch.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
""",
)
def preprocess(
self,
images: ImageInput,
annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
return_segmentation_masks: bool = None,
masks_path: Optional[Union[str, pathlib.Path]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[Union[int, float]] = None,
do_normalize: Optional[bool] = None,
do_convert_annotations: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
format: Optional[Union[str, AnnotationFormat]] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
self, images: ImageInput, **kwargs: Unpack[DeformableDetrFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.
Args:
images (`ImageInput`):
Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
do_resize (`bool`, *optional*, defaults to self.do_resize):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to self.size):
Size of the image's `(height, width)` dimensions after resizing. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
less or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
`max_width`.
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample):
Resampling filter to use when resizing the image.
do_rescale (`bool`, *optional*, defaults to self.do_rescale):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
Rescale factor to use when rescaling the image.
do_normalize (`bool`, *optional*, defaults to self.do_normalize):
Whether to normalize the image.
do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
Whether to convert the annotations to the format expected by the model. Converts the bounding
boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
and in relative coordinates.
image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
Mean to use when normalizing the image.
image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
Standard deviation to use when normalizing the image.
do_pad (`bool`, *optional*, defaults to self.do_pad):
Whether to pad the image. If `True`, padding will be applied to the bottom and right of
the image with zeros. If `pad_size` is provided, the image will be padded to the specified
dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
Format of the annotations.
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
Type of tensors to return. If `None`, will return the list of images.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
"""
if "pad_and_return_pixel_mask" in kwargs:
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
logger.warning_once(
"The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
"use `do_pad` instead."
)
do_pad = kwargs.pop("pad_and_return_pixel_mask")
if "max_size" in kwargs:
logger.warning_once(
"The `max_size` argument is deprecated and will be removed in a future version, use"
" `size['longest_edge']` instead."
)
size = kwargs.pop("max_size")
do_resize = self.do_resize if do_resize is None else do_resize
size = self.size if size is None else size
size = get_size_dict(size=size, default_to_square=False)
resample = self.resample if resample is None else resample
do_rescale = self.do_rescale if do_rescale is None else do_rescale
rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
do_normalize = self.do_normalize if do_normalize is None else do_normalize
image_mean = self.image_mean if image_mean is None else image_mean
image_std = self.image_std if image_std is None else image_std
do_convert_annotations = (
self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
)
do_pad = self.do_pad if do_pad is None else do_pad
pad_size = self.pad_size if pad_size is None else pad_size
format = self.format if format is None else format
device = kwargs.pop("device", None)
kwargs["size"] = kwargs.pop("max_size")
# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
images = make_list_of_images(images)
image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
self._validate_input_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
return_tensors=return_tensors,
data_format=data_format,
)
return super().preprocess(images, **kwargs)
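A hedged usage sketch for the typed-kwargs `preprocess` above (editorial, not part of the diff; the processor is instantiated with its defaults and the annotation keys follow the COCO detection format):
from PIL import Image
from transformers import DeformableDetrImageProcessorFast

processor = DeformableDetrImageProcessorFast()
image = Image.new("RGB", (640, 480))
annotation = {
"image_id": 0,
"annotations": [{"bbox": [10, 10, 100, 120], "category_id": 1, "area": 12000.0, "iscrowd": 0}],
}
encoding = processor(images=image, annotations=annotation, return_tensors="pt")
print(sorted(encoding.keys()))  # expected: ["labels", "pixel_mask", "pixel_values"]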
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
do_convert_annotations: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.
"""
if annotations is not None and isinstance(annotations, dict):
annotations = [annotations]
@ -823,26 +671,6 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
)
data = {}
if image_type == ImageType.PIL:
images = [F.pil_to_tensor(image) for image in images]
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
images = [torch.from_numpy(image).contiguous() for image in images]
if device is not None:
images = [image.to(device) for image in images]
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.LAST:
images = [image.permute(2, 0, 1).contiguous() for image in images]
input_data_format = ChannelDimension.FIRST
if do_rescale and do_normalize:
# fused rescale and normalize
new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor)
new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor)
processed_images = []
processed_annotations = []
@ -856,15 +684,10 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
format,
return_segmentation_masks=return_segmentation_masks,
masks_path=masks_path,
input_data_format=input_data_format,
input_data_format=ChannelDimension.FIRST,
)
if do_resize:
interpolation = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, (PILImageResampling, int))
else resample
)
resized_image = self.resize(image, size=size, interpolation=interpolation)
if annotations is not None:
annotation = self.resize_annotation(
@ -876,14 +699,14 @@ class DeformableDetrImageProcessorFast(BaseImageProcessorFast):
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, input_data_format))
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
processed_images.append(image)
processed_annotations.append(annotation)

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_deit import *
from .feature_extraction_deit import *
from .image_processing_deit import *
from .image_processing_deit_fast import *
from .modeling_deit import *
from .modeling_tf_deit import *
else:

View File

@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for DeiT."""
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
PILImageResampling,
)
from ...utils import add_start_docstrings
@add_start_docstrings(
"Constructs a fast DeiT image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class DeiTImageProcessorFast(BaseImageProcessorFast):
# To be checked against the slow image processor
# None values left after checking can be removed
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 256, "width": 256}
crop_size = {"height": 224, "width": 224}
do_resize = True
do_center_crop = True
do_rescale = True
do_normalize = True
__all__ = ["DeiTImageProcessorFast"]

View File

@ -14,7 +14,6 @@
# limitations under the License.
"""Fast Image processor class for DETR."""
import functools
import io
import pathlib
from collections import defaultdict
@ -22,7 +21,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
SizeDict,
get_image_size_for_max_height_width,
get_max_height_width,
@ -40,17 +43,14 @@ from ...image_utils import (
AnnotationType,
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
get_image_size,
get_image_type,
infer_channel_dimension_format,
make_list_of_images,
validate_annotations,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
@ -72,8 +72,6 @@ if is_torch_available():
if is_vision_available():
import PIL
from ...image_utils import pil_torch_interpolation_mapping
if is_torchvision_v2_available():
from torchvision.io import read_image
@ -285,44 +283,29 @@ def prepare_coco_panoptic_annotation(
return new_target
class DetrImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a fast Detr image processor.
class DetrFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
format: Optional[Union[str, AnnotationFormat]]
do_convert_annotations: Optional[bool]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
Args:
class DetrFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
format: Optional[AnnotationFormat]
annotations: Optional[Dict]
do_convert_annotations: Optional[bool]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
return_segmentation_masks: Optional[bool]
masks_path: Optional[Union[str, pathlib.Path]]
@add_start_docstrings(
"Constructs a fast Detr image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
do_resize (`bool`, *optional*, defaults to `True`):
Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be
overridden by the `do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
in the `preprocess` method. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
less or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
`max_width`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
`preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_annotations (`bool`, *optional*, defaults to `True`):
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
@ -336,29 +319,28 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
"""
""",
)
class DetrImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
format = AnnotationFormat.COCO_DETECTION
do_resize = True
do_rescale = True
do_normalize = True
do_pad = True
size = {"shortest_edge": 800, "longest_edge": 1333}
default_to_square = False
model_input_names = ["pixel_values", "pixel_mask"]
valid_init_kwargs = DetrFastImageProcessorInitKwargs
valid_preprocess_kwargs = DetrFastImageProcessorPreprocessKwargs
def __init__(
self,
format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
do_convert_annotations: Optional[bool] = None,
do_pad: bool = True,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
) -> None:
def __init__(self, **kwargs: Unpack[DetrFastImageProcessorInitKwargs]) -> None:
if "pad_and_return_pixel_mask" in kwargs:
do_pad = kwargs.pop("pad_and_return_pixel_mask")
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
size = kwargs.pop("size", None)
if "max_size" in kwargs:
logger.warning_once(
"The `max_size` parameter is deprecated and will be removed in v4.26. "
@ -369,46 +351,15 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
max_size = None if size is None else 1333
size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
size = get_size_dict(size, max_size=max_size, default_to_square=False)
self.size = get_size_dict(size, max_size=max_size, default_to_square=False)
# Backwards compatibility
if do_convert_annotations is None:
do_convert_annotations = do_normalize
do_convert_annotations = kwargs.get("do_convert_annotations", None)
do_normalize = kwargs.get("do_normalize", None)
if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
super().__init__(**kwargs)
self.format = format
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.do_convert_annotations = do_convert_annotations
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
self.pad_size = pad_size
self._valid_processor_keys = [
"images",
"annotations",
"return_segmentation_masks",
"masks_path",
"do_resize",
"size",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"do_convert_annotations",
"image_mean",
"image_std",
"do_pad",
"pad_size",
"format",
"return_tensors",
"data_format",
"input_data_format",
]
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
@ -643,187 +594,83 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
return image, pixel_mask, annotation
@functools.lru_cache(maxsize=1)
def _validate_input_arguments(
self,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
do_resize: bool,
size: Dict[str, int],
resample: "PILImageResampling",
data_format: Union[str, ChannelDimension],
return_tensors: Union[TensorType, str],
):
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
if do_resize and None in (size, resample):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and None in (image_mean, image_std):
raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
def preprocess(
self,
images: ImageInput,
annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
return_segmentation_masks: bool = None,
masks_path: Optional[Union[str, pathlib.Path]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[Union[int, float]] = None,
do_normalize: Optional[bool] = None,
do_convert_annotations: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
format: Optional[Union[str, AnnotationFormat]] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.
Args:
images (`ImageInput`):
Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
do_resize (`bool`, *optional*, defaults to self.do_resize):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to self.size):
Size of the image's `(height, width)` dimensions after resizing. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
less or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
`max_width`.
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample):
Resampling filter to use when resizing the image.
do_rescale (`bool`, *optional*, defaults to self.do_rescale):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
Rescale factor to use when rescaling the image.
do_normalize (`bool`, *optional*, defaults to self.do_normalize):
Whether to normalize the image.
do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
Whether to convert the annotations to the format expected by the model. Converts the bounding
boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
and in relative coordinates.
image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
Mean to use when normalizing the image.
image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
Standard deviation to use when normalizing the image.
do_pad (`bool`, *optional*, defaults to self.do_pad):
Whether to pad the image. If `True`, padding will be applied to the bottom and right of
the image with zeros. If `pad_size` is provided, the image will be padded to the specified
dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
Format of the annotations.
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
Type of tensors to return. If `None`, will return the list of images.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
do_convert_annotations (`bool`, *optional*, defaults to `True`):
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
If `pad_size` is provided, the image will be padded to the specified dimensions.
Otherwise, the image will be padded to the maximum height and width of the batch.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[DetrFastImageProcessorPreprocessKwargs]) -> BatchFeature:
if "pad_and_return_pixel_mask" in kwargs:
kwargs["do_pad"] = kwargs.pop("pad_and_return_pixel_mask")
logger.warning_once(
"The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
"use `do_pad` instead."
)
do_pad = kwargs.pop("pad_and_return_pixel_mask")
if "max_size" in kwargs:
logger.warning_once(
"The `max_size` argument is deprecated and will be removed in a future version, use"
" `size['longest_edge']` instead."
)
size = kwargs.pop("max_size")
do_resize = self.do_resize if do_resize is None else do_resize
size = self.size if size is None else size
size = get_size_dict(size=size, default_to_square=False)
resample = self.resample if resample is None else resample
do_rescale = self.do_rescale if do_rescale is None else do_rescale
rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
do_normalize = self.do_normalize if do_normalize is None else do_normalize
image_mean = self.image_mean if image_mean is None else image_mean
image_std = self.image_std if image_std is None else image_std
do_convert_annotations = (
self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
)
do_pad = self.do_pad if do_pad is None else do_pad
pad_size = self.pad_size if pad_size is None else pad_size
format = self.format if format is None else format
device = kwargs.pop("device", None)
kwargs["size"] = kwargs.pop("max_size")
# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
images = make_list_of_images(images)
image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
self._validate_input_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
return_tensors=return_tensors,
data_format=data_format,
)
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
do_convert_annotations: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.
"""
if annotations is not None and isinstance(annotations, dict):
annotations = [annotations]
@ -847,26 +694,6 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
)
data = {}
if image_type == ImageType.PIL:
images = [F.pil_to_tensor(image) for image in images]
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
images = [torch.from_numpy(image).contiguous() for image in images]
if device is not None:
images = [image.to(device) for image in images]
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.LAST:
images = [image.permute(2, 0, 1).contiguous() for image in images]
input_data_format = ChannelDimension.FIRST
if do_rescale and do_normalize:
# fused rescale and normalize
new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor)
new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor)
processed_images = []
processed_annotations = []
@ -880,15 +707,10 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
format,
return_segmentation_masks=return_segmentation_masks,
masks_path=masks_path,
input_data_format=input_data_format,
input_data_format=ChannelDimension.FIRST,
)
if do_resize:
interpolation = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, (PILImageResampling, int))
else resample
)
resized_image = self.resize(image, size=size, interpolation=interpolation)
if annotations is not None:
annotation = self.resize_annotation(
@ -900,14 +722,14 @@ class DetrImageProcessorFast(BaseImageProcessorFast):
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, input_data_format))
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
processed_images.append(image)
processed_annotations.append(annotation)
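For orientation, here is a minimal usage sketch of the refactored fast DETR processor with the typed kwargs defined above. The dummy image and COCO-detection-style annotation are illustrative assumptions (not taken from this diff), and torchvision is assumed to be installed.

```python
# Minimal sketch (illustrative, not part of this diff): calling DetrImageProcessorFast
# with a dummy image and a COCO-detection-style annotation.
from PIL import Image
from transformers import DetrImageProcessorFast

image = Image.new("RGB", (640, 480))  # stand-in for a real image
annotation = {
    "image_id": 0,
    "annotations": [{"bbox": [10, 20, 100, 80], "category_id": 1, "area": 8000, "iscrowd": 0}],
}

processor = DetrImageProcessorFast()
inputs = processor(
    images=image,
    annotations=annotation,
    return_tensors="pt",  # the fast path returns PyTorch tensors
    do_pad=True,          # pad to `pad_size` or to the largest image in the batch
)
print(inputs["pixel_values"].shape, inputs["pixel_mask"].shape)
```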

View File

@ -19,6 +19,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_llava import *
from .image_processing_llava_fast import *
from .modeling_llava import *
from .processing_llava import *
else:

View File

@ -420,7 +420,7 @@ class LlavaImageProcessor(BaseImageProcessor):
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
if do_rescale:
images = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(

View File

@ -0,0 +1,209 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for LLaVa."""
from typing import List, Optional, Tuple, Union
from ...image_processing_utils import (
BatchFeature,
)
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
)
if is_vision_available():
from ...image_utils import PILImageResampling
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class LlavaFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
do_pad: Optional[bool]
class LlavaFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
do_pad: Optional[bool]
@add_start_docstrings(
"Constructs a fast Llava image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter in the `preprocess` method.
""",
)
class LlavaImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"shortest_edge": 224}
default_to_square = False
crop_size = {"height": 224, "width": 224}
do_pad = False
do_resize = True
do_center_crop = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
valid_init_kwargs = LlavaFastImageProcessorInitKwargs
valid_preprocess_kwargs = LlavaFastImageProcessorPreprocessKwargs
def __init__(self, **kwargs: Unpack[LlavaFastImageProcessorInitKwargs]) -> None:
super().__init__(**kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter in the `preprocess` method.
""",
)
def preprocess(
self, images: ImageInput, **kwargs: Unpack[LlavaFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def pad_to_square(
self,
images: "torch.Tensor",
background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
"""
Pads an image to a square based on the longest edge.
Args:
images (`torch.Tensor`):
The images to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for single-channel images or a
tuple of integers for multi-channel images. If passed as an integer
in multi-channel mode, it will default to `0` in subsequent channels.
Returns:
`torch.Tensor`: The padded images.
"""
height, width = get_image_size(images, ChannelDimension.FIRST)
if height == width:
return images
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
if isinstance(background_color, int):
background_color = [background_color] + [0] * (num_channels - 1)
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
max_dim = max(height, width)
paste_x_left = (max_dim - width) // 2
paste_y_left = (max_dim - height) // 2
paste_x_right = max_dim - width - paste_x_left
paste_y_right = max_dim - height - paste_y_left
padded_images = F.pad(
images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
)
return padded_images
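For a concrete feel of the centering arithmetic above, a tiny sketch that pads a 300x480 tensor to a 480x480 square with torchvision's functional `pad` (the v2 namespace is assumed available, matching the import fallback at the top of this file; the numbers are only an example):

```python
# Illustration only: center a 300x480 image inside a 480x480 square, mirroring the
# left/top/right/bottom offsets computed in pad_to_square above.
import torch
from torchvision.transforms.v2 import functional as F

image = torch.zeros(3, 300, 480)
max_dim = max(image.shape[-2:])                            # 480
left, top = (max_dim - 480) // 2, (max_dim - 300) // 2     # 0, 90
right, bottom = max_dim - 480 - left, max_dim - 300 - top  # 0, 90
padded = F.pad(image, padding=[left, top, right, bottom], fill=0)
print(padded.shape)  # torch.Size([3, 480, 480])
```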
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_pad: bool,
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_pad:
stacked_images = self.pad_to_square(
images=stacked_images, background_color=tuple(int(x * 255) for x in self.image_mean)
)
resized_images_grouped[shape] = stacked_images
padded_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for batched resizing
# Needed in case do_pad is False, or padding returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(padded_images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
__all__ = ["LlavaImageProcessorFast"]

View File

@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_llava_next import *
from .image_processing_llava_next import *
from .image_processing_llava_next_fast import *
from .modeling_llava_next import *
from .processing_llava_next import *
else:

View File

@ -0,0 +1,323 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for LLaVa-NeXT."""
from typing import List, Optional, Union
from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
divide_to_patches,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
make_flat_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class LlavaNextFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
image_grid_pinpoints: Optional[List[List[int]]]
do_pad: Optional[bool]
class LlavaNextFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
image_grid_pinpoints: Optional[List[List[int]]]
do_pad: Optional[bool]
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class LlavaNextImageProcessorFast(BaseImageProcessorFast):
# To be checked against the slow image processor
# None values left after checking can be removed
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"shortest_edge": 224}
default_to_square = False
crop_size = {"height": 224, "width": 224}
do_resize = True
do_center_crop = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
valid_init_kwargs = LlavaNextFastImageProcessorInitKwargs
valid_preprocess_kwargs = LlavaNextFastImageProcessorPreprocessKwargs
def __init__(self, **kwargs: Unpack[LlavaNextFastImageProcessorInitKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
image_grid_pinpoints (`List`, *optional*):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
def preprocess(
self, images: ImageInput, **kwargs: Unpack[LlavaNextFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _resize_for_patching(
self,
image: "torch.Tensor",
target_resolution: tuple,
interpolation: "F.InterpolationMode",
input_data_format: ChannelDimension,
) -> "torch.Tensor":
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image ("torch.Tensor"):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
"torch.Tensor": The resized and padded image.
"""
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
return resized_image
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
return padded_image
def _get_image_patches(
self,
image: "torch.Tensor",
grid_pinpoints,
size: tuple,
patch_size: int,
interpolation: "F.InterpolationMode",
) -> List["torch.Tensor"]:
"""
Process an image with variable resolutions by dividing it into patches.
Args:
image ("torch.Tensor"):
The input image to be processed.
grid_pinpoints (List):
A list of possible resolutions.
size (`tuple`):
Size to resize the original image to.
patch_size (`int`):
Size of the patches to divide the image into.
interpolation (`"InterpolationMode"`):
Resampling filter to use if resizing the image.
Returns:
List["torch.Tensor"]: A list of NumPy arrays containing the processed image patches.
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints must be a list of possible resolutions.")
possible_resolutions = grid_pinpoints
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
best_resolution = select_best_resolution(image_size, possible_resolutions)
resized_image = self._resize_for_patching(
image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST
)
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST)
patches = divide_to_patches(padded_image, patch_size=patch_size)
resized_original_image = F.resize(image, size=size, interpolation=interpolation)
image_patches = [resized_original_image] + patches
return image_patches
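To make the patching step concrete, here is a hedged re-implementation of the row-major, non-overlapping split that `divide_to_patches` is used for above; it is illustrative only and not the library's implementation:

```python
# Illustrative patch split: a padded image whose height/width are multiples of
# patch_size is cut into non-overlapping square patches in row-major order.
from typing import List

import torch


def split_into_patches(image: torch.Tensor, patch_size: int) -> List[torch.Tensor]:
    _, height, width = image.shape
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append(image[:, top : top + patch_size, left : left + patch_size])
    return patches


padded = torch.rand(3, 672, 336)       # e.g. best_resolution = (672, 336)
patches = split_into_patches(padded, 336)
print(len(patches), patches[0].shape)  # 2 patches of shape (3, 336, 336)
```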
def _pad_for_batching(
self,
pixel_values: List["torch.Tensor"],
) -> List["torch.Tensor"]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
A list of pixel values, one tensor per image, each of shape (`num_patches`, `image_in_3D`)
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
for image in pixel_values
]
return pixel_values
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
image_grid_pinpoints: List[List[int]],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
processed_images = []
image_sizes = []
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
patch_size = crop_size.height
elif size and size.height:
patch_size = size.height
else:
patch_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
interpolation=interpolation,
)
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
)
__all__ = ["LlavaNextImageProcessorFast"]

View File

@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_llava_onevision import *
from .image_processing_llava_onevision import *
from .image_processing_llava_onevision_fast import *
from .modeling_llava_onevision import *
from .processing_llava_onevision import *
from .video_processing_llava_onevision import *

View File

@ -119,7 +119,7 @@ def _get_patch_output_size(image, target_resolution, input_data_format):
class LlavaOnevisionImageProcessor(BaseImageProcessor):
r"""
Constructs a LLaVa-Onevisino-Video video processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
Constructs a LLaVa-Onevision image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.
Args:
do_resize (`bool`, *optional*, defaults to `True`):

View File

@ -0,0 +1,305 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/llava_onevision/modular_llava_onevision.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_llava_onevision.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import List, Optional, Union
from ...image_processing_utils import BatchFeature, get_patch_output_size, select_best_resolution
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
divide_to_patches,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
make_flat_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, add_start_docstrings, is_torch_available, is_torchvision_v2_available
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class LlavaOnevisionFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
image_grid_pinpoints: Optional[List[List[int]]]
do_pad: Optional[bool]
class LlavaOnevisionFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
image_grid_pinpoints: Optional[List[List[int]]]
do_pad: Optional[bool]
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method. Not used for processing videos.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"height": 384, "width": 384}
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip
valid_init_kwargs = LlavaOnevisionFastImageProcessorInitKwargs
valid_preprocess_kwargs = LlavaOnevisionFastImageProcessorPreprocessKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[LlavaOnevisionFastImageProcessorInitKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
image_grid_pinpoints (`List`, *optional*):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
def preprocess(
self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _resize_for_patching(
self,
image: "torch.Tensor",
target_resolution: tuple,
interpolation: "F.InterpolationMode",
input_data_format: ChannelDimension,
) -> "torch.Tensor":
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image ("torch.Tensor"):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
"torch.Tensor": The resized and padded image.
"""
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
return resized_image
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
return padded_image
def _get_image_patches(
self,
image: "torch.Tensor",
grid_pinpoints,
size: tuple,
patch_size: int,
interpolation: "F.InterpolationMode",
) -> List["torch.Tensor"]:
"""
Process an image with variable resolutions by dividing it into patches.
Args:
image ("torch.Tensor"):
The input image to be processed.
grid_pinpoints (List):
A list of possible resolutions.
size (`tuple`):
Size to resize the original image to.
patch_size (`int`):
Size of the patches to divide the image into.
interpolation (`"InterpolationMode"`):
Resampling filter to use if resizing the image.
Returns:
List["torch.Tensor"]: A list of NumPy arrays containing the processed image patches.
"""
if not isinstance(grid_pinpoints, list):
raise TypeError("grid_pinpoints must be a list of possible resolutions.")
possible_resolutions = grid_pinpoints
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
best_resolution = select_best_resolution(image_size, possible_resolutions)
resized_image = self._resize_for_patching(
image, best_resolution, interpolation=interpolation, input_data_format=ChannelDimension.FIRST
)
padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=ChannelDimension.FIRST)
patches = divide_to_patches(padded_image, patch_size=patch_size)
resized_original_image = F.resize(image, size=size, interpolation=interpolation)
image_patches = [resized_original_image] + patches
return image_patches
def _pad_for_batching(
self,
pixel_values: List["torch.Tensor"],
) -> List["torch.Tensor"]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
A list of pixel values, one tensor per image, each of shape (`num_patches`, `image_in_3D`)
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
for image in pixel_values
]
return pixel_values
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
image_grid_pinpoints: List[List[int]],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
processed_images = []
image_sizes = []
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
patch_size = crop_size.height
elif size and size.height:
patch_size = size.height
else:
patch_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
interpolation=interpolation,
)
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
)
__all__ = ["LlavaOnevisionImageProcessorFast"]

View File

@ -0,0 +1,45 @@
from transformers.models.llava_next.image_processing_llava_next_fast import LlavaNextImageProcessorFast
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
PILImageResampling,
)
from ...utils import add_start_docstrings, logging
logger = logging.get_logger(__name__)
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method. Not used for processing videos.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"height": 384, "width": 384}
crop_size = None
default_to_square = False
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip
model_input_names = ["pixel_values_videos"]
__all__ = ["LlavaOnevisionImageProcessorFast"]

View File

@ -17,21 +17,24 @@
from typing import Dict, List, Optional, Union
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast
from ...image_utils import (
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
get_image_size,
get_image_type,
infer_channel_dimension_format,
make_list_of_images,
validate_fast_preprocess_arguments,
validate_kwargs,
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
ImageInput,
PILImageResampling,
SizeDict,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
@ -39,7 +42,6 @@ from ...utils import (
logging,
)
from .image_processing_pixtral import (
convert_to_rgb,
get_resize_output_image_size,
)
@ -51,7 +53,7 @@ if is_torch_available():
if is_torchvision_available():
if is_vision_available():
from ...image_utils import pil_torch_interpolation_mapping
pass
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
@ -59,93 +61,56 @@ if is_torchvision_available():
from torchvision.transforms import functional as F
class PixtralImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a fast Pixtral image processor that leverages torchvision.
class PixtralFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
patch_size: Optional[Dict[str, int]]
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
Size of the maximum dimension of either the height or width dimension of the image. Used to control how
images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
class PixtralFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
patch_size: Optional[Dict[str, int]]
@add_start_docstrings(
r"Constructs a fast ConvNeXT image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
""",
)
class PixtralImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = [0.48145466, 0.4578275, 0.40821073]
image_std = [0.26862954, 0.26130258, 0.27577711]
patch_size = {"height": 16, "width": 16}
size = {"longest_edge": 1024}
default_to_square = True
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
valid_init_kwargs = PixtralFastImageProcessorInitKwargs
valid_preprocess_kwargs = PixtralFastImageProcessorPreprocessKwargs
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
patch_size: Dict[str, int] = None,
resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
def __init__(self, **kwargs: Unpack[PixtralFastImageProcessorInitKwargs]):
super().__init__(**kwargs)
size = size if size is not None else {"longest_edge": 1024}
patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
patch_size = get_size_dict(patch_size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.patch_size = patch_size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
self.do_convert_rgb = do_convert_rgb
self._valid_processor_keys = [
"images",
"do_resize",
"size",
"patch_size",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format",
]
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
""",
)
def preprocess(
self, images: ImageInput, **kwargs: Unpack[PixtralFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def resize(
self,
image: torch.Tensor,
size: Dict[str, int],
patch_size: Dict[str, int],
size: SizeDict,
patch_size: SizeDict,
interpolation: "F.InterpolationMode" = None,
**kwargs,
) -> torch.Tensor:
@ -156,37 +121,28 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
Args:
image (`torch.Tensor`):
Image to resize.
size (`Dict[str, int]`):
size (`SizeDict`):
Dict containing the longest possible edge of the image.
patch_size (`Dict[str, int]`):
patch_size (`SizeDict`):
Patch size used to calculate the size of the output image.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
Resampling filter to use when resizing the image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if "longest_edge" in size:
size = (size["longest_edge"], size["longest_edge"])
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
if size.longest_edge:
size = (size.longest_edge, size.longest_edge)
elif size.height and size.width:
size = (size.height, size.width)
else:
raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
if "height" in patch_size and "width" in patch_size:
patch_size = (patch_size["height"], patch_size["width"])
if patch_size.height and patch_size.width:
patch_size = (patch_size.height, patch_size.width)
else:
raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
output_size = get_resize_output_image_size(
image,
size=size,
patch_size=patch_size,
)
return F.resize(
image,
size=output_size,
interpolation=interpolation,
**kwargs,
)
output_size = get_resize_output_image_size(image, size=size, patch_size=patch_size)
return F.resize(image, size=output_size, interpolation=interpolation, **kwargs)
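# --- Minimal sketch (not part of this diff) of why the checks above switch from
# dict-key lookups to attribute access: the fast path receives a SizeDict whose
# unset fields read as None/falsy. MiniSizeDict below is a hypothetical stand-in
# for the real SizeDict from image_processing_utils_fast.
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class MiniSizeDict:
    height: Optional[int] = None
    width: Optional[int] = None
    longest_edge: Optional[int] = None

size = MiniSizeDict(longest_edge=1024)
if size.longest_edge:
    target = (size.longest_edge, size.longest_edge)
elif size.height and size.width:
    target = (size.height, size.width)
else:
    raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
print(target)  # (1024, 1024)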
# Adapted from transformers.models.pixtral.image_processing_pixtral.PixtralImageProcessor._pad_for_batching
def _pad_for_batching(
@ -205,177 +161,64 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
List[`torch.Tensor`]: The padded images.
"""
max_shape = (
max([size[0] for size in image_sizes]),
max([size[1] for size in image_sizes]),
)
max_shape = (max([size[0] for size in image_sizes]), max([size[1] for size in image_sizes]))
pixel_values = [
torch.nn.functional.pad(
image,
pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0]),
)
torch.nn.functional.pad(image, pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0]))
for image, size in zip(pixel_values, image_sizes)
]
return torch.stack(pixel_values)
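# --- Illustrative sketch (not part of this diff): what _pad_for_batching does,
# reduced to two tensors. Each image is zero-padded on the bottom/right to the
# batch-wide maximum height and width before stacking.
import torch
import torch.nn.functional as TNF

images = [torch.rand(3, 32, 48), torch.rand(3, 40, 40)]
sizes = [(32, 48), (40, 40)]
max_h = max(h for h, _ in sizes)
max_w = max(w for _, w in sizes)
padded = [
    TNF.pad(img, pad=(0, max_w - w, 0, max_h - h))  # pad order: (left, right, top, bottom)
    for img, (h, w) in zip(images, sizes)
]
print(torch.stack(padded).shape)  # torch.Size([2, 3, 40, 48])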
def preprocess(
def _preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
patch_size: Dict[str, int] = None,
resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
patch_size: Dict[str, int],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: Dict[str, int],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Describes the maximum input dimensions to the model.
patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
Patch size in the model. Used to calculate the image after resizing.
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
patch_size = patch_size if patch_size is not None else self.patch_size
patch_size = get_size_dict(patch_size, default_to_square=True)
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
device = kwargs.pop("device", None)
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
images = make_list_of_images(images)
image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
validate_fast_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
return_tensors=return_tensors,
data_format=data_format,
)
if do_rescale and do_normalize:
# fused rescale and normalize
new_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
new_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
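# --- A note on the fused form (sketch, not from this diff): with
# mean' = mean / rescale_factor and std' = std / rescale_factor,
# normalize(x, mean', std') == normalize(x * rescale_factor, mean, std),
# so a single normalize call covers both rescaling and normalization.
import torch
from torchvision.transforms import functional as TVF

x = torch.randint(0, 256, (3, 4, 4), dtype=torch.float32)
mean, std, rescale_factor = [0.5, 0.5, 0.5], [0.2, 0.2, 0.2], 1 / 255
fused = TVF.normalize(x, [m / rescale_factor for m in mean], [s / rescale_factor for s in std])
reference = TVF.normalize(x * rescale_factor, mean, std)
print(torch.allclose(fused, reference, atol=1e-5))  # True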
batch_images = []
batch_image_sizes = []
for image in images:
if do_convert_rgb:
image = convert_to_rgb(image)
if image_type == ImageType.PIL:
image = F.pil_to_tensor(image)
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
image = torch.from_numpy(image).contiguous()
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image.permute(2, 0, 1).contiguous()
image = image.to(device)
patch_size = SizeDict(**patch_size)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
interpolation = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, (PILImageResampling, int))
else resample
)
image = self.resize(
image=image,
size=size,
patch_size=patch_size,
interpolation=interpolation,
stacked_images = self.resize(
image=stacked_images, size=size, patch_size=patch_size, interpolation=interpolation
)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
batch_image_sizes = [grouped_images_index[i][0] for i in range(len(grouped_images_index))]
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
batch_images.append(image)
batch_image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
pixel_values = self._pad_for_batching(
pixel_values=batch_images,
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
padded_images = self._pad_for_batching(
pixel_values=processed_images,
image_sizes=batch_image_sizes,
)
return BatchFeature(
data={"pixel_values": pixel_values, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
data={"pixel_values": padded_images, "image_sizes": batch_image_sizes}, tensor_type=return_tensors
)
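# --- Minimal sketch (not part of this diff) of the group-by-shape batching used
# in _preprocess above. Assumed semantics of group_images_by_shape /
# reorder_images: same-shaped images are stacked and transformed once, then
# scattered back into the original order.
import torch

def group_by_shape(images):
    grouped, index = {}, []
    for img in images:
        shape = tuple(img.shape)
        grouped.setdefault(shape, []).append(img)
        index.append((shape, len(grouped[shape]) - 1))
    return {s: torch.stack(imgs) for s, imgs in grouped.items()}, index

def reorder(processed, index):
    return [processed[shape][pos] for shape, pos in index]

images = [torch.rand(3, 16, 16), torch.rand(3, 8, 8), torch.rand(3, 16, 16)]
grouped, index = group_by_shape(images)
processed = {shape: stack * 2.0 for shape, stack in grouped.items()}  # one batched op per shape
print([tuple(t.shape) for t in reorder(processed, index)])  # original order preserved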


@ -156,7 +156,7 @@ class Qwen2_5_VLImageProcessor(BaseImageProcessor):
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
self.do_convert_rgb = do_convert_rgb
def _preprocess(


@ -149,7 +149,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
self.do_convert_rgb = do_convert_rgb
def _preprocess(


@ -23,30 +23,29 @@ from typing import Dict, List, Optional, Union
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BaseImageProcessorFast,
)
from ...image_transforms import (
convert_to_rgb,
DefaultFastImageProcessorInitKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
SizeDict,
VideoInput,
get_image_size,
get_image_type,
infer_channel_dimension_format,
make_batched_videos,
make_flat_list_of_images,
make_list_of_images,
valid_images,
validate_preprocess_arguments,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
@ -60,8 +59,7 @@ if is_torch_available():
import torch
if is_vision_available():
from ...image_utils import pil_torch_interpolation_mapping
pass
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
@ -71,27 +69,18 @@ elif is_torchvision_available():
logger = logging.get_logger(__name__)
class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a fast Qwen2-VL image processor that dynamically resizes images based on the original images.
class Qwen2VLFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
min_pixels: Optional[int]
max_pixels: Optional[int]
patch_size: Optional[int]
temporal_patch_size: Optional[int]
merge_size: Optional[int]
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use when resizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
@add_start_docstrings(
"Constructs a fast Qwen2-VL image processor that dynamically resizes images based on the original images.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
min_pixels (`int`, *optional*, defaults to `56 * 56`):
The min pixels of the image to resize the image.
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
@ -102,57 +91,42 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to 2):
The merge size of the vision encoder to llm encoder.
"""
""",
)
class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
do_resize = True
resample = PILImageResampling.BICUBIC
size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
do_rescale = True
do_normalize = True
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
do_convert_rgb = True
patch_size = 14
temporal_patch_size = 2
merge_size = 2
min_pixels = 56 * 56
max_pixels = 28 * 28 * 1280
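# For reference, the defaults above work out to min_pixels = 56 * 56 = 3_136
# and max_pixels = 28 * 28 * 1280 = 1_003_520 pixels.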
valid_init_kwargs = Qwen2VLFastImageProcessorInitKwargs
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
def __init__(
self,
do_resize: bool = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
min_pixels: int = 56 * 56,
max_pixels: int = 28 * 28 * 1280,
patch_size: int = 14,
temporal_patch_size: int = 2,
merge_size: int = 2,
**kwargs,
) -> None:
def __init__(self, **kwargs: Unpack[Qwen2VLFastImageProcessorInitKwargs]):
super().__init__(**kwargs)
self.do_resize = do_resize
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
self.do_convert_rgb = do_convert_rgb
def _preprocess(
self,
images: Union[ImageInput, VideoInput],
do_resize: bool = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional[Union[str, torch.device]] = None,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_convert_rgb: bool,
input_data_format: Optional[Union[str, ChannelDimension]],
device: Optional[Union[str, torch.device]],
):
"""
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
@ -164,8 +138,8 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
Optional list of dictionaries containing additional information about vision inputs.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
@ -178,50 +152,28 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
device (`torch.device`, *optional*):
The device to process the images on. If unset, the device is inferred from the input images.
"""
images = make_list_of_images(images)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
image_type = get_image_type(images[0])
if image_type == ImageType.PIL:
images = [F.pil_to_tensor(image) for image in images]
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
images = [torch.from_numpy(image).contiguous() for image in images]
if device is not None:
images = [image.to(device) for image in images]
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.LAST:
images = [image.permute(2, 0, 1).contiguous() for image in images]
input_data_format = ChannelDimension.FIRST
if do_rescale and do_normalize:
# fused rescale and normalize
image_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor)
image_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor)
height, width = get_image_size(images[0], channel_dim=input_data_format)
interpolation = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
images = self._prepare_input_images(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
height, width = get_image_size(images[0], channel_dim=ChannelDimension.FIRST)
resized_height, resized_width = height, width
processed_images = []
for image in images:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
resized_height, resized_width = smart_resize(
height,
@ -230,19 +182,25 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
image = F.resize(image, size=(resized_height, resized_width), interpolation=interpolation)
stacked_images = F.resize(
stacked_images, size=(resized_height, resized_width), interpolation=interpolation
)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images.append(image)
patches = torch.stack(processed_images)
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
patches = torch.stack(processed_images, dim=0)
if patches.shape[0] % self.temporal_patch_size != 0:
repeats = patches[-1].unsqueeze(0).repeat(self.temporal_patch_size - 1, 1, 1, 1)
patches = torch.cat([patches, repeats], dim=0)
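# --- Illustrative sketch (not part of this diff): the temporal padding above, in
# isolation. The last frame is repeated until the frame count is divisible by
# temporal_patch_size.
import torch

temporal_patch_size = 2
patches = torch.rand(3, 3, 224, 224)  # (frames, channels, height, width)
if patches.shape[0] % temporal_patch_size != 0:
    repeats = patches[-1].unsqueeze(0).repeat(temporal_patch_size - 1, 1, 1, 1)
    patches = torch.cat([patches, repeats], dim=0)
print(patches.shape[0])  # 4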
@ -275,7 +233,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
videos: VideoInput = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
@ -285,6 +243,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
**kwargs,
):
"""
@ -334,7 +293,8 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
device (`torch.device`, *optional*):
The device to process the images on. If unset, the device is inferred from the input images.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
@ -345,12 +305,25 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
device = kwargs.pop("device", None)
# Make hashable for cache
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
size = SizeDict(**size) if size is not None else None
image_mean = tuple(image_mean) if image_mean is not None else None
image_std = tuple(image_std) if image_std is not None else None
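# --- Why tuples (sketch, not from this diff): lists are unhashable, so they
# cannot be passed to an lru_cache-backed helper. The function below is a
# hypothetical stand-in for the cached argument preparation.
import functools

@functools.lru_cache(maxsize=8)
def cached_norm_params(image_mean, image_std):
    return tuple(m * 255 for m in image_mean), tuple(s * 255 for s in image_std)

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
cached_norm_params(tuple(mean), tuple(std))   # works, result is cached
# cached_norm_params(mean, std)               # TypeError: unhashable type: 'list'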
image_mean, image_std, interpolation = self._prepare_process_arguments(
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
data_format=data_format,
device=device,
)
if images is not None:
images = make_flat_list_of_images(images)
if videos is not None:
@ -362,29 +335,19 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
if images is not None:
pixel_values, vision_grid_thws = [], []
for image in images:
patches, image_grid_thw = self._preprocess(
image,
do_resize=do_resize,
resample=resample,
size=size,
interpolation=interpolation,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
@ -401,13 +364,13 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
patches, video_grid_thw = self._preprocess(
images,
do_resize=do_resize,
resample=resample,
size=size,
interpolation=interpolation,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,


@ -4,14 +4,18 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_rt_detr.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import functools
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorInitKwargs,
DefaultFastImageProcessorPreprocessKwargs,
SizeDict,
add_start_docstrings,
get_image_size_for_max_height_width,
get_max_height_width,
safe_squeeze,
@ -24,21 +28,16 @@ from ...image_utils import (
AnnotationType,
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
get_image_size,
get_image_type,
infer_channel_dimension_format,
make_list_of_images,
validate_annotations,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
filter_out_non_signature_kwargs,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
requires_backends,
)
from .image_processing_rt_detr import get_size_with_aspect_ratio
@ -47,15 +46,30 @@ from .image_processing_rt_detr import get_size_with_aspect_ratio
if is_torch_available():
import torch
if is_vision_available():
from ...image_utils import pil_torch_interpolation_mapping
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
elif is_torchvision_available():
from torchvision.transforms import functional as F
class RTDetrFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
format: Optional[Union[str, AnnotationFormat]]
do_convert_annotations: Optional[bool]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
class RTDetrFastImageProcessorPreprocessKwargs(DefaultFastImageProcessorPreprocessKwargs):
format: Optional[AnnotationFormat]
annotations: Optional[Dict]
do_convert_annotations: Optional[bool]
do_pad: Optional[bool]
pad_size: Optional[Dict[str, int]]
return_segmentation_masks: Optional[bool]
masks_path: Optional[Union[str, pathlib.Path]]
SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
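# --- Illustrative sketch (not part of this diff): the COCO-detection style
# annotation payload accepted through the kwargs above ("image_id" plus a list
# of per-object dicts). The numeric values are placeholders.
coco_detection_annotation = {
    "image_id": 0,
    "annotations": [
        {"bbox": [10.0, 20.0, 50.0, 40.0], "category_id": 3, "area": 2000.0, "iscrowd": 0},
    ],
}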
@ -118,49 +132,17 @@ def prepare_coco_detection_annotation(
return new_target
class RTDetrImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a fast RTDetr image processor.
Args:
@add_start_docstrings(
"Constructs a fast RTDetr image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
do_resize (`bool`, *optional*, defaults to `True`):
Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be
overridden by the `do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
in the `preprocess` method. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
less or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
`max_width`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `False`):
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
`preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_annotations (`bool`, *optional*, defaults to `True`):
Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `False`):
do_pad (`bool`, *optional*, defaults to `True`):
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
If `pad_size` is provided, the image will be padded to the specified dimensions.
@ -169,45 +151,32 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
"""
""",
)
class RTDetrImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
format = AnnotationFormat.COCO_DETECTION
do_resize = True
do_rescale = True
do_normalize = False
do_pad = False
size = {"height": 640, "width": 640}
default_to_square = False
model_input_names = ["pixel_values", "pixel_mask"]
valid_init_kwargs = RTDetrFastImageProcessorInitKwargs
valid_preprocess_kwargs = RTDetrFastImageProcessorPreprocessKwargs
do_convert_annotations = True
def __init__(
self,
format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = False,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
do_convert_annotations: bool = True,
do_pad: bool = False,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
) -> None:
size = size if size is not None else {"height": 640, "width": 640}
size = get_size_dict(size, default_to_square=False)
if do_convert_annotations is None:
do_convert_annotations = do_normalize
def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorInitKwargs]) -> None:
# Backwards compatibility
do_convert_annotations = kwargs.get("do_convert_annotations", None)
do_normalize = kwargs.get("do_normalize", None)
if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
super().__init__(**kwargs)
self.format = format
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.do_convert_annotations = do_convert_annotations
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
self.pad_size = pad_size
def prepare_annotation(
self,
@ -419,174 +388,71 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
return image, pixel_mask, annotation
@functools.lru_cache(maxsize=1)
def _validate_input_arguments(
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
do_convert_annotations (`bool`, *optional*, defaults to `True`):
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
If `pad_size` is provided, the image will be padded to the specified dimensions.
Otherwise, the image will be padded to the maximum height and width of the batch.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
""",
)
def preprocess(
self, images: ImageInput, **kwargs: Unpack[RTDetrFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
return super().preprocess(images, **kwargs)
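# --- Illustrative usage sketch (not part of this diff): running the fast RT-DETR
# processor on a dummy image with a COCO-detection annotation. Assumes torch and
# torchvision are installed; the annotation values are placeholders.
from PIL import Image
from transformers import RTDetrImageProcessorFast

processor = RTDetrImageProcessorFast()
dummy = Image.new("RGB", (640, 480))
target = {
    "image_id": 0,
    "annotations": [{"bbox": [10, 20, 50, 40], "category_id": 1, "area": 2000, "iscrowd": 0}],
}
batch = processor.preprocess(dummy, annotations=target, return_tensors="pt")
print(batch["pixel_values"].shape)  # padded (batch, 3, H, W) tensor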
def _preprocess(
self,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
do_resize: bool,
size: Dict[str, int],
resample: "PILImageResampling",
data_format: Union[str, ChannelDimension],
return_tensors: Union[TensorType, str],
):
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
if do_resize and None in (size, resample):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and None in (image_mean, image_std):
raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
@filter_out_non_signature_kwargs(extra=["device"])
def preprocess(
self,
images: ImageInput,
annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
return_segmentation_masks: bool = None,
masks_path: Optional[Union[str, pathlib.Path]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[Union[int, float]] = None,
do_normalize: Optional[bool] = None,
do_convert_annotations: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
format: Optional[Union[str, AnnotationFormat]] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
do_convert_annotations: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.
Args:
images (`ImageInput`):
Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
do_resize (`bool`, *optional*, defaults to self.do_resize):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to self.size):
Size of the image's `(height, width)` dimensions after resizing. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
less or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
`max_width`.
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample):
Resampling filter to use when resizing the image.
do_rescale (`bool`, *optional*, defaults to self.do_rescale):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
Rescale factor to use when rescaling the image.
do_normalize (`bool`, *optional*, defaults to self.do_normalize):
Whether to normalize the image.
do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
Whether to convert the annotations to the format expected by the model. Converts the bounding
boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
and in relative coordinates.
image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
Mean to use when normalizing the image.
image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
Standard deviation to use when normalizing the image.
do_pad (`bool`, *optional*, defaults to self.do_pad):
Whether to pad the image. If `True`, padding will be applied to the bottom and right of
the image with zeros. If `pad_size` is provided, the image will be padded to the specified
dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
Format of the annotations.
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
Type of tensors to return. If `None`, will return the list of images.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
"""
do_resize = self.do_resize if do_resize is None else do_resize
size = self.size if size is None else size
size = get_size_dict(size=size, default_to_square=True)
resample = self.resample if resample is None else resample
do_rescale = self.do_rescale if do_rescale is None else do_rescale
rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
do_normalize = self.do_normalize if do_normalize is None else do_normalize
image_mean = self.image_mean if image_mean is None else image_mean
image_std = self.image_std if image_std is None else image_std
do_convert_annotations = (
self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
)
do_pad = self.do_pad if do_pad is None else do_pad
pad_size = self.pad_size if pad_size is None else pad_size
format = self.format if format is None else format
return_tensors = "pt" if return_tensors is None else return_tensors
device = kwargs.pop("device", None)
# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
images = make_list_of_images(images)
image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
self._validate_input_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
return_tensors=return_tensors,
data_format=data_format,
)
if annotations is not None and isinstance(annotations, dict):
annotations = [annotations]
@ -601,27 +467,6 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
data = {}
if image_type == ImageType.PIL:
images = [F.pil_to_tensor(image) for image in images]
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
images = [torch.from_numpy(image).contiguous() for image in images]
if device is not None:
images = [image.to(device) for image in images]
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.LAST:
images = [image.permute(2, 0, 1).contiguous() for image in images]
input_data_format = ChannelDimension.FIRST
if do_rescale and do_normalize:
# fused rescale and normalize
new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor)
new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor)
processed_images = []
processed_annotations = []
pixel_masks = [] # Initialize pixel_masks here
@ -634,15 +479,10 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
format,
return_segmentation_masks=return_segmentation_masks,
masks_path=masks_path,
input_data_format=input_data_format,
input_data_format=ChannelDimension.FIRST,
)
if do_resize:
interpolation = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, (PILImageResampling, int))
else resample
)
resized_image = self.resize(image, size=size, interpolation=interpolation)
if annotations is not None:
annotation = self.resize_annotation(
@ -654,14 +494,14 @@ class RTDetrImageProcessorFast(BaseImageProcessorFast):
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, input_data_format))
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
processed_images.append(image)
processed_annotations.append(annotation)


@ -1,12 +1,18 @@
import pathlib
from typing import Dict, List, Optional, Tuple, Union
from transformers.models.detr.image_processing_detr_fast import DetrImageProcessorFast
from transformers.models.detr.image_processing_detr_fast import (
DetrFastImageProcessorInitKwargs,
DetrFastImageProcessorPreprocessKwargs,
DetrImageProcessorFast,
)
from ...image_processing_utils import BatchFeature, get_size_dict
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
SizeDict,
add_start_docstrings,
get_max_height_width,
)
from ...image_transforms import center_to_corners_format
@ -17,21 +23,16 @@ from ...image_utils import (
AnnotationType,
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
get_image_size,
get_image_type,
infer_channel_dimension_format,
make_list_of_images,
validate_annotations,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
filter_out_non_signature_kwargs,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
logging,
requires_backends,
)
@ -40,9 +41,6 @@ from ...utils import (
if is_torch_available():
import torch
if is_vision_available():
from ...image_utils import pil_torch_interpolation_mapping
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
@ -114,49 +112,60 @@ def prepare_coco_detection_annotation(
return new_target
class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
r"""
Constructs a fast RTDetr image processor.
class RTDetrFastImageProcessorInitKwargs(DetrFastImageProcessorInitKwargs):
pass
Args:
class RTDetrFastImageProcessorPreprocessKwargs(DetrFastImageProcessorPreprocessKwargs):
pass
class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
format = AnnotationFormat.COCO_DETECTION
do_convert_annotations = True
do_resize = True
do_rescale = True
do_normalize = False
do_pad = False
size = {"height": 640, "width": 640}
default_to_square = False
model_input_names = ["pixel_values", "pixel_mask"]
valid_init_kwargs = RTDetrFastImageProcessorInitKwargs
valid_preprocess_kwargs = RTDetrFastImageProcessorPreprocessKwargs
def __init__(self, **kwargs: Unpack[RTDetrFastImageProcessorInitKwargs]) -> None:
# Backwards compatibility
do_convert_annotations = kwargs.get("do_convert_annotations", None)
do_normalize = kwargs.get("do_normalize", None)
if do_convert_annotations is None and getattr(self, "do_convert_annotations", None) is None:
self.do_convert_annotations = do_normalize if do_normalize is not None else self.do_normalize
BaseImageProcessorFast.__init__(self, **kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
Data format of the annotations. One of "coco_detection" or "coco_panoptic".
do_resize (`bool`, *optional*, defaults to `True`):
Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be
overridden by the `do_resize` parameter in the `preprocess` method.
size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
in the `preprocess` method. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
less or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
`max_width`.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Resampling filter to use if resizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `False`):
Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
`preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_annotations (`bool`, *optional*, defaults to `True`):
Controls whether to convert the annotations to the format expected by the RT_DETR model. Converts the
Controls whether to convert the annotations to the format expected by the DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `False`):
do_pad (`bool`, *optional*, defaults to `True`):
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
If `pad_size` is provided, the image will be padded to the specified dimensions.
@ -165,43 +174,16 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
"""
def __init__(
self,
format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: Union[PILImageResampling, "F.InterpolationMode"] = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = False,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
do_convert_annotations: bool = True,
do_pad: bool = False,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
) -> None:
size = size if size is not None else {"height": 640, "width": 640}
size = get_size_dict(size, default_to_square=False)
if do_convert_annotations is None:
do_convert_annotations = do_normalize
BaseImageProcessorFast.__init__(self, **kwargs)
self.format = format
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.do_convert_annotations = do_convert_annotations
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
self.pad_size = pad_size
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
""",
)
def preprocess(
self, images: ImageInput, **kwargs: Unpack[RTDetrFastImageProcessorPreprocessKwargs]
) -> BatchFeature:
return BaseImageProcessorFast().preprocess(images, **kwargs)
def prepare_annotation(
self,
@ -223,145 +205,31 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
raise ValueError(f"Format {format} is not supported.")
return target
@filter_out_non_signature_kwargs(extra=["device"])
def preprocess(
def _preprocess(
self,
images: ImageInput,
annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
return_segmentation_masks: bool = None,
masks_path: Optional[Union[str, pathlib.Path]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[Union[PILImageResampling, "F.InterpolationMode"]] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[Union[int, float]] = None,
do_normalize: Optional[bool] = None,
do_convert_annotations: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
format: Optional[Union[str, AnnotationFormat]] = None,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
pad_size: Optional[Dict[str, int]] = None,
**kwargs,
images: List["torch.Tensor"],
annotations: Optional[Union[AnnotationType, List[AnnotationType]]],
return_segmentation_masks: bool,
masks_path: Optional[Union[str, pathlib.Path]],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
do_convert_annotations: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
pad_size: Optional[Dict[str, int]],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
"""
Preprocess an image or a batch of images so that it can be used by the model.
Args:
images (`ImageInput`):
Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
- "file_name" (`str`): The file name of the image.
return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
Whether to return segmentation masks.
masks_path (`str` or `pathlib.Path`, *optional*):
Path to the directory containing the segmentation masks.
do_resize (`bool`, *optional*, defaults to self.do_resize):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to self.size):
Size of the image's `(height, width)` dimensions after resizing. Available options are:
- `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
Do NOT keep the aspect ratio.
- `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
the aspect ratio and keeping the shortest edge less than or equal to `shortest_edge` and the longest edge
less than or equal to `longest_edge`.
- `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
aspect ratio and keeping the height less than or equal to `max_height` and the width less than or equal to
`max_width`.
resample (`PILImageResampling` or `InterpolationMode`, *optional*, defaults to self.resample):
Resampling filter to use when resizing the image.
do_rescale (`bool`, *optional*, defaults to self.do_rescale):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
Rescale factor to use when rescaling the image.
do_normalize (`bool`, *optional*, defaults to self.do_normalize):
Whether to normalize the image.
do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
Whether to convert the annotations to the format expected by the model. Converts the bounding
boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
and in relative coordinates.
image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
Mean to use when normalizing the image.
image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
Standard deviation to use when normalizing the image.
do_pad (`bool`, *optional*, defaults to self.do_pad):
Whether to pad the image. If `True`, padding will be applied to the bottom and right of
the image with zeros. If `pad_size` is provided, the image will be padded to the specified
dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
Format of the annotations.
return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
Type of tensors to return. If `None`, will return the list of images.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
pad_size (`Dict[str, int]`, *optional*):
The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
"""
do_resize = self.do_resize if do_resize is None else do_resize
size = self.size if size is None else size
size = get_size_dict(size=size, default_to_square=True)
resample = self.resample if resample is None else resample
do_rescale = self.do_rescale if do_rescale is None else do_rescale
rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
do_normalize = self.do_normalize if do_normalize is None else do_normalize
image_mean = self.image_mean if image_mean is None else image_mean
image_std = self.image_std if image_std is None else image_std
do_convert_annotations = (
self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
)
do_pad = self.do_pad if do_pad is None else do_pad
pad_size = self.pad_size if pad_size is None else pad_size
format = self.format if format is None else format
return_tensors = "pt" if return_tensors is None else return_tensors
device = kwargs.pop("device", None)
# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
images = make_list_of_images(images)
image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
self._validate_input_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
return_tensors=return_tensors,
data_format=data_format,
)
if annotations is not None and isinstance(annotations, dict):
annotations = [annotations]
@ -376,27 +244,6 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
data = {}
if image_type == ImageType.PIL:
images = [F.pil_to_tensor(image) for image in images]
elif image_type == ImageType.NUMPY:
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
images = [torch.from_numpy(image).contiguous() for image in images]
if device is not None:
images = [image.to(device) for image in images]
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
if input_data_format == ChannelDimension.LAST:
images = [image.permute(2, 0, 1).contiguous() for image in images]
input_data_format = ChannelDimension.FIRST
if do_rescale and do_normalize:
# fused rescale and normalize
new_mean = torch.tensor(image_mean, device=images[0].device) * (1.0 / rescale_factor)
new_std = torch.tensor(image_std, device=images[0].device) * (1.0 / rescale_factor)
processed_images = []
processed_annotations = []
pixel_masks = [] # Initialize pixel_masks here
@ -409,15 +256,10 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
format,
return_segmentation_masks=return_segmentation_masks,
masks_path=masks_path,
input_data_format=input_data_format,
input_data_format=ChannelDimension.FIRST,
)
if do_resize:
interpolation = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, (PILImageResampling, int))
else resample
)
resized_image = self.resize(image, size=size, interpolation=interpolation)
if annotations is not None:
annotation = self.resize_annotation(
@ -429,14 +271,14 @@ class RTDetrImageProcessorFast(DetrImageProcessorFast, BaseImageProcessorFast):
if do_rescale and do_normalize:
# fused rescale and normalize
image = F.normalize(image.to(dtype=torch.float32), new_mean, new_std)
image = F.normalize(image.to(dtype=torch.float32), image_mean, image_std)
elif do_rescale:
image = image * rescale_factor
elif do_normalize:
image = F.normalize(image, image_mean, image_std)
if do_convert_annotations and annotations is not None:
annotation = self.normalize_annotation(annotation, get_image_size(image, input_data_format))
annotation = self.normalize_annotation(annotation, get_image_size(image, ChannelDimension.FIRST))
processed_images.append(image)
processed_annotations.append(annotation)
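The loop above relies on a small but effective optimization: when both rescaling and normalization are requested, the rescale factor is folded into the mean and standard deviation so that a single `F.normalize` call replaces two element-wise passes. The standalone sketch below (not part of the diff; it only assumes torch and torchvision are installed) checks that the fused form matches the naive two-step version.

```python
# Sketch of the fused rescale + normalize trick used in the fast processors:
# (x * r - mean) / std == (x - mean / r) / (std / r), so pre-dividing mean/std by the
# rescale factor lets one normalize() call replace a multiply followed by a normalize.
import torch
from torchvision.transforms import functional as F

image = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
rescale_factor = 1 / 255
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# naive two-step reference: rescale to [0, 1], then normalize
reference = F.normalize(image.float() * rescale_factor, mean.tolist(), std.tolist())

# fused: fold the rescale factor into mean/std and normalize once
fused = F.normalize(image.float(), (mean / rescale_factor).tolist(), (std / rescale_factor).tolist())

# allow for tiny float32 rounding differences between the two orderings
torch.testing.assert_close(reference, fused, rtol=1e-4, atol=1e-4)
```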

View File

@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_siglip import *
from .image_processing_siglip import *
from .image_processing_siglip_fast import *
from .modeling_siglip import *
from .processing_siglip import *
from .tokenization_siglip import *

View File

@ -0,0 +1,41 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for SigLIP."""
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
PILImageResampling,
)
from ...utils import add_start_docstrings
@add_start_docstrings(
"Constructs a fast SigLIP image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class SiglipImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 224, "width": 224}
default_to_square = False
do_resize = True
do_rescale = True
do_normalize = True
__all__ = ["SiglipImageProcessorFast"]
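The new SigLIP fast processor is fully declarative: it only sets defaults and inherits `__init__` and `preprocess` from `BaseImageProcessorFast`. A hedged usage sketch follows; the checkpoint name is an assumption and torchvision must be installed.

```python
# Usage sketch (not from the diff): instantiate the fast SigLIP processor directly,
# or request the fast variant of a checkpoint.
import numpy as np
from PIL import Image

from transformers import AutoImageProcessor, SiglipImageProcessorFast

image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

processor = SiglipImageProcessorFast()  # uses the declarative class-level defaults
inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])

# assumed checkpoint name; `use_fast=True` selects the torchvision-backed class
processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224", use_fast=True)
```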

View File

@ -14,290 +14,32 @@
# limitations under the License.
"""Fast Image processor class for ViT."""
import functools
from typing import Dict, List, Optional, Union
from ...image_processing_base import BatchFeature
from ...image_processing_utils import get_size_dict
from ...image_processing_utils_fast import BaseImageProcessorFast, SizeDict
from ...image_transforms import FusedRescaleNormalize, NumpyToTensor, Rescale, convert_to_rgb
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BaseImageProcessorFast,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
ImageType,
PILImageResampling,
get_image_type,
make_list_of_images,
pil_torch_interpolation_mapping,
)
from ...utils import TensorType, logging
from ...utils.import_utils import is_torch_available, is_torchvision_available
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
if is_torchvision_available():
from torchvision.transforms import Compose, Normalize, PILToTensor, Resize
from ...utils import (
add_start_docstrings,
)
@add_start_docstrings(
"Constructs a fast ViT image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class ViTImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a ViT image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
`preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
_transform_params = [
"do_resize",
"do_rescale",
"do_normalize",
"size",
"resample",
"rescale_factor",
"image_mean",
"image_std",
"image_type",
]
def __init__(
self,
do_resize: bool = True,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size)
self.do_resize = do_resize
self.do_rescale = do_rescale
self.do_normalize = do_normalize
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_convert_rgb = do_convert_rgb
def _build_transforms(
self,
do_resize: bool,
size: Dict[str, int],
resample: PILImageResampling,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
image_type: ImageType,
) -> "Compose":
"""
Given the input settings build the image transforms using `torchvision.transforms.Compose`.
"""
transforms = []
# All PIL and numpy values need to be converted to a torch tensor
# to keep cross compatibility with slow image processors
if image_type == ImageType.PIL:
transforms.append(PILToTensor())
elif image_type == ImageType.NUMPY:
transforms.append(NumpyToTensor())
if do_resize:
transforms.append(
Resize((size["height"], size["width"]), interpolation=pil_torch_interpolation_mapping[resample])
)
# We can combine rescale and normalize into a single operation for speed
if do_rescale and do_normalize:
transforms.append(FusedRescaleNormalize(image_mean, image_std, rescale_factor=rescale_factor))
elif do_rescale:
transforms.append(Rescale(rescale_factor=rescale_factor))
elif do_normalize:
transforms.append(Normalize(image_mean, image_std))
return Compose(transforms)
@functools.lru_cache(maxsize=1)
def _validate_input_arguments(
self,
return_tensors: Union[str, TensorType],
do_resize: bool,
size: Dict[str, int],
resample: PILImageResampling,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
data_format: Union[str, ChannelDimension],
image_type: ImageType,
):
if return_tensors != "pt":
raise ValueError("Only returning PyTorch tensors is currently supported.")
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")
if do_resize and None in (size, resample):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and None in (image_mean, image_std):
raise ValueError("Image mean and standard deviation must be specified if do_normalize is True.")
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = "pt",
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
do_convert_rgb: Optional[bool] = None,
**kwargs,
):
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
resizing.
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Only "pt" is supported
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. The following formats are currently supported:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
do_convert_rgb (`bool`, *optional*):
Whether to convert the image to RGB.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
return_tensors = "pt" if return_tensors is None else return_tensors
# Make hashable for cache
size = SizeDict(**size)
image_mean = tuple(image_mean) if isinstance(image_mean, list) else image_mean
image_std = tuple(image_std) if isinstance(image_std, list) else image_std
images = make_list_of_images(images)
image_type = get_image_type(images[0])
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
raise ValueError(f"Unsupported input image type {image_type}")
self._validate_input_arguments(
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
data_format=data_format,
image_type=image_type,
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
transforms = self.get_transforms(
do_resize=do_resize,
do_rescale=do_rescale,
do_normalize=do_normalize,
size=size,
resample=resample,
rescale_factor=rescale_factor,
image_mean=image_mean,
image_std=image_std,
image_type=image_type,
)
transformed_images = [transforms(image) for image in images]
data = {"pixel_values": torch.stack(transformed_images, dim=0)}
return BatchFeature(data, tensor_type=return_tensors)
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 224, "width": 224}
do_resize = True
do_rescale = True
do_normalize = True
__all__ = ["ViTImageProcessorFast"]
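The ViT diff above is representative of the whole refactor: per the hunk header, some 290 lines of per-model transform building, validation and caching collapse into roughly 30, with `BaseImageProcessorFast` providing `__init__`, `preprocess` and the grouped-by-size batch path. As a rough illustration (hypothetical model name, not part of the PR), a new fast processor now looks like this:

```python
# Illustrative only: "MyModelImageProcessorFast" is a hypothetical name. After the
# refactor, a minimal fast image processor is a set of declarative defaults; resizing,
# rescaling, normalization and tensor handling are inherited from BaseImageProcessorFast.
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling


class MyModelImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 224, "width": 224}
    do_resize = True
    do_rescale = True
    do_normalize = True
```

Per-model behaviour that cannot be expressed as defaults (extra typed kwargs, custom `_preprocess` steps) is added by overriding the corresponding hook, as the RT-DETR diff earlier in this change does.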

View File

@ -9,6 +9,27 @@ class BaseImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class BlipImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class CLIPImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class ConvNextImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class DeformableDetrImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
@ -16,6 +37,13 @@ class DeformableDetrImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class DeiTImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class DetrImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
@ -23,6 +51,27 @@ class DetrImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class LlavaImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class LlavaNextImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class LlavaOnevisionImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class PixtralImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
@ -44,6 +93,13 @@ class RTDetrImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class SiglipImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class ViTImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
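These entries extend the torchvision dummy objects: each new fast class gets a placeholder so that `from transformers import XImageProcessorFast` still succeeds without torchvision, while any actual use raises a clear error. A simplified sketch of that behaviour (not the actual `DummyObject`/`requires_backends` implementation) is:

```python
# Simplified sketch; the real implementation uses the DummyObject metaclass and
# requires_backends from transformers.utils.
class _DummyBlipImageProcessorFast:
    _backends = ["torchvision"]

    def __init__(self, *args, **kwargs):
        raise ImportError(
            "BlipImageProcessorFast requires the torchvision backend. "
            "Install it with `pip install torchvision`."
        )


try:
    _DummyBlipImageProcessorFast()
except ImportError as err:
    print(err)
```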

View File

@ -17,7 +17,7 @@
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import BlipImageProcessor
if is_torchvision_available():
from transformers import BlipImageProcessorFast
class BlipImageProcessingTester:
def __init__(
@ -88,6 +91,7 @@ class BlipImageProcessingTester:
@require_vision
class BlipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = BlipImageProcessor if is_vision_available() else None
fast_image_processing_class = BlipImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -98,50 +102,36 @@ class BlipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processor = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
self.assertTrue(hasattr(image_processor, "image_mean"))
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
self.assertTrue(hasattr(image_processor, "image_mean"))
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
@require_torch
@require_vision
class BlipImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = BlipImageProcessor if is_vision_available() else None
fast_image_processing_class = BlipImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = BlipImageProcessingTester(self, num_channels=4)
self.expected_encoded_image_num_channels = 3
self.image_processor_tester = BlipImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processor = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
self.assertTrue(hasattr(image_processor, "image_mean"))
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
@unittest.skip(reason="BlipImageProcessor does not support 4 channels yet") # FIXME Amy
def test_call_numpy(self):
return super().test_call_numpy()
@unittest.skip(reason="BlipImageProcessor does not support 4 channels yet") # FIXME Amy
def test_call_pytorch(self):
return super().test_call_torch()
@unittest.skip(reason="BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
def test_call_pil(self):
pass
@unittest.skip(reason="BLIP doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy
def test_call_numpy_4_channels(self):
pass
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
self.assertTrue(hasattr(image_processor, "image_mean"))
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))

View File

@ -17,7 +17,7 @@
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import CLIPImageProcessor
if is_torchvision_available():
from transformers import CLIPImageProcessorFast
class CLIPImageProcessingTester:
def __init__(
@ -44,6 +47,7 @@ class CLIPImageProcessingTester:
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
super().__init__()
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
@ -92,6 +96,7 @@ class CLIPImageProcessingTester:
@require_vision
class CLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = CLIPImageProcessor if is_vision_available() else None
fast_image_processing_class = CLIPImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -102,21 +107,23 @@ class CLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
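All of the test updates in this change follow the same pattern: instead of exercising only `self.image_processing_class`, they loop over `self.image_processor_list` so the slow and fast variants share one test body. The helper presumably collects whichever of the two class attributes is defined; a minimal sketch of that assumption:

```python
# Assumed shape of the helper used by the tests above; the real ImageProcessingTestMixin
# in test_image_processing_common.py may differ in details.
class ImageProcessingTestMixinSketch:
    image_processing_class = None       # slow, PIL/numpy-based processor
    fast_image_processing_class = None  # torchvision-backed fast processor

    @property
    def image_processor_list(self):
        candidates = (self.image_processing_class, self.fast_image_processing_class)
        return [cls for cls in candidates if cls is not None]
```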

View File

@ -17,7 +17,7 @@
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import ConvNextImageProcessor
if is_torchvision_available():
from transformers import ConvNextImageProcessorFast
class ConvNextImageProcessingTester:
def __init__(
@ -85,6 +88,7 @@ class ConvNextImageProcessingTester:
@require_vision
class ConvNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ConvNextImageProcessor if is_vision_available() else None
fast_image_processing_class = ConvNextImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -95,17 +99,25 @@ class ConvNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "crop_pct"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "crop_pct"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
@unittest.skip(
"Skipping as ConvNextImageProcessor uses center_crop and center_crop functions are not equivalent for fast and slow processors"
)
def test_slow_fast_equivalence_batched(self):
pass
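The skip exists because the mixin's slow/fast equivalence tests compare the two pipelines' outputs within a tolerance, and ConvNeXT's crop path diverges between the PIL-based and torchvision `center_crop` implementations. A hedged sketch of the kind of check being skipped (the real test lives in `ImageProcessingTestMixin` and may differ):

```python
# Sketch of a slow/fast equivalence check; tolerance values here are illustrative.
import torch

def assert_slow_fast_equivalent(slow_processor, fast_processor, images, atol=1e-1, rtol=1e-3):
    slow = slow_processor(images, return_tensors="pt").pixel_values
    fast = fast_processor(images, return_tensors="pt").pixel_values
    torch.testing.assert_close(slow.float(), fast.float(), atol=atol, rtol=rtol)
```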

View File

@ -17,7 +17,7 @@
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import DeiTImageProcessor
if is_torchvision_available():
from transformers import DeiTImageProcessorFast
class DeiTImageProcessingTester:
def __init__(
@ -90,6 +93,7 @@ class DeiTImageProcessingTester:
@require_vision
class DeiTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = DeiTImageProcessor if is_vision_available() else None
fast_image_processing_class = DeiTImageProcessorFast if is_torchvision_available() else None
test_cast_dtype = True
def setUp(self):
@ -101,20 +105,22 @@ class DeiTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

View File

@ -20,7 +20,7 @@ from typing import Tuple, Union
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -30,6 +30,11 @@ if is_vision_available():
from transformers import LlavaImageProcessor
if is_torchvision_available():
from torchvision.transforms import functional as F
from transformers import LlavaImageProcessorFast
class LlavaImageProcessingTester:
def __init__(
@ -50,6 +55,7 @@ class LlavaImageProcessingTester:
image_std=[0.26862954, 0.26130258, 0.27577711],
do_convert_rgb=True,
):
super().__init__()
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
@ -103,6 +109,7 @@ class LlavaImageProcessingTester:
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Llava
class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = LlavaImageProcessor if is_vision_available() else None
fast_image_processing_class = LlavaImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -114,25 +121,27 @@ class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Ignore copy
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
# Ignore copy
def test_padding(self):
@ -157,45 +166,72 @@ class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
result.paste(image, ((height - width) // 2, 0))
return result
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for i, image_processing_class in enumerate(self.image_processor_list):
image_processor = image_processing_class.from_dict(self.image_processor_dict)
numpify = i == 0
torchify = i == 1
image_inputs = self.image_processor_tester.prepare_image_inputs(
equal_resolution=False, numpify=numpify, torchify=torchify
)
# test with images in channel-last and channel-first format
for image in image_inputs:
padded_image = image_processor.pad_to_square(image)
padded_image_original = pad_to_square_original(Image.fromarray(image))
padded_image_original = np.array(padded_image_original)
# test with images in channel-last and channel-first format (only channel-first for torch)
for image in image_inputs:
padded_image = image_processor.pad_to_square(image)
if i == 0:
padded_image_original = pad_to_square_original(Image.fromarray(image))
padded_image_original = np.array(padded_image_original)
np.testing.assert_allclose(padded_image, padded_image_original)
np.testing.assert_allclose(padded_image, padded_image_original)
padded_image = image_processor.pad_to_square(image.transpose(2, 0, 1), input_data_format="channels_first")
padded_image = padded_image.transpose(1, 2, 0)
padded_image = image_processor.pad_to_square(
image.transpose(2, 0, 1), input_data_format="channels_first"
)
padded_image = padded_image.transpose(1, 2, 0)
np.testing.assert_allclose(padded_image, padded_image_original)
np.testing.assert_allclose(padded_image, padded_image_original)
else:
padded_image_original = pad_to_square_original(F.to_pil_image(image))
padded_image = padded_image.permute(1, 2, 0)
np.testing.assert_allclose(padded_image, padded_image_original)
# test background color
background_color = (122, 116, 104)
for image in image_inputs:
padded_image = image_processor.pad_to_square(image, background_color=background_color)
padded_image_original = pad_to_square_original(Image.fromarray(image), background_color=background_color)
padded_image_original = np.array(padded_image_original)
# test background color
background_color = (122, 116, 104)
for image in image_inputs:
padded_image = image_processor.pad_to_square(image, background_color=background_color)
if i == 0:
padded_image_original = pad_to_square_original(
Image.fromarray(image), background_color=background_color
)
else:
padded_image_original = pad_to_square_original(
F.to_pil_image(image), background_color=background_color
)
padded_image = padded_image.permute(1, 2, 0)
padded_image_original = np.array(padded_image_original)
np.testing.assert_allclose(padded_image, padded_image_original)
np.testing.assert_allclose(padded_image, padded_image_original)
background_color = 122
for image in image_inputs:
padded_image = image_processor.pad_to_square(image, background_color=background_color)
padded_image_original = pad_to_square_original(Image.fromarray(image), background_color=background_color)
padded_image_original = np.array(padded_image_original)
background_color = 122
for image in image_inputs:
padded_image = image_processor.pad_to_square(image, background_color=background_color)
if i == 0:
padded_image_original = pad_to_square_original(
Image.fromarray(image), background_color=background_color
)
else:
padded_image_original = pad_to_square_original(
F.to_pil_image(image), background_color=background_color
)
padded_image = padded_image.permute(1, 2, 0)
padded_image_original = np.array(padded_image_original)
np.testing.assert_allclose(padded_image, padded_image_original)
np.testing.assert_allclose(padded_image, padded_image_original)
# background color length should match channel length
with self.assertRaises(ValueError):
padded_image = image_processor.pad_to_square(image_inputs[0], background_color=(122, 104))
# background color length should match channel length
with self.assertRaises(ValueError):
padded_image = image_processor.pad_to_square(image_inputs[0], background_color=(122, 104))
with self.assertRaises(ValueError):
padded_image = image_processor.pad_to_square(image_inputs[0], background_color=(122, 104, 0, 0))
with self.assertRaises(ValueError):
padded_image = image_processor.pad_to_square(image_inputs[0], background_color=(122, 104, 0, 0))
@unittest.skip(reason="LLaVa does not support 4 channel images yet")
# Ignore copy
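The reworked `test_padding` above distinguishes the numpy path (index 0: channels-last arrays compared against a PIL reference) from the torch path (index 1: channels-first tensors permuted back before comparison), and checks that `background_color` must be either a scalar or match the channel count. A minimal numpy-only sketch of the padding behaviour being tested (the real `pad_to_square` implementations handle more input formats):

```python
# Numpy-only sketch of pad_to_square as exercised by the test: channels-last input,
# centered paste, background color broadcast per channel.
import numpy as np

def pad_to_square(image: np.ndarray, background_color=0) -> np.ndarray:
    height, width, num_channels = image.shape
    if height == width:
        return image
    if isinstance(background_color, int):
        background_color = (background_color,) * num_channels
    if len(background_color) != num_channels:
        raise ValueError(f"background_color must have {num_channels} elements, got {len(background_color)}")
    side = max(height, width)
    result = np.empty((side, side, num_channels), dtype=image.dtype)
    result[...] = np.asarray(background_color, dtype=image.dtype)
    top = (side - height) // 2
    left = (side - width) // 2
    result[top : top + height, left : left + width] = image
    return result
```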

View File

@ -20,7 +20,7 @@ import numpy as np
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from transformers.models.llava_next.image_processing_llava_next import select_best_resolution
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -33,6 +33,9 @@ if is_vision_available():
from transformers import LlavaNextImageProcessor
if is_torchvision_available():
from transformers import LlavaNextImageProcessorFast
class LlavaNextImageProcessingTester:
def __init__(
@ -52,6 +55,7 @@ class LlavaNextImageProcessingTester:
image_std=OPENAI_CLIP_STD,
do_convert_rgb=True,
):
super().__init__()
size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
@ -102,6 +106,7 @@ class LlavaNextImageProcessingTester:
@require_vision
class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = LlavaNextImageProcessor if is_vision_available() else None
fast_image_processing_class = LlavaNextImageProcessorFast if is_torchvision_available() else None
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->LlavaNext
def setUp(self):
@ -114,26 +119,28 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "image_grid_pinpoints"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "image_grid_pinpoints"))
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
def test_select_best_resolution(self):
possible_resolutions = [[672, 336], [336, 672], [672, 672], [336, 1008], [1008, 336]]
@ -143,59 +150,62 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertEqual(best_resolution, (672, 336))
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@unittest.skip(
reason="LlavaNextImageProcessor doesn't treat 4 channel PIL and numpy consistently yet"
@ -204,19 +214,20 @@ class LlavaNextImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
pass
def test_nested_input(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
# Test batched as a list of images
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched as a list of images
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched as a nested list of images, where each sublist is one batch
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
# Test batched as a nested list of images, where each sublist is one batch
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1445, 3, 18, 18)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
# Image processor should return same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())
# Image processor should return same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())
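
Throughout these test diffs the recurring change is the same: each test now loops over self.image_processor_list so that one set of assertions covers both the slow (PIL/NumPy-based) image processor and, when torchvision is available, its fast counterpart. A minimal, self-contained sketch of that pattern, using hypothetical DummySlowProcessor/DummyFastProcessor classes rather than the real transformers classes:

# Minimal sketch of the shared-test pattern; the processor classes here are
# hypothetical stand-ins, not the real transformers image processors.
import unittest


class DummySlowProcessor:
    def __init__(self, **kwargs):
        self.do_resize = kwargs.get("do_resize", True)
        self.size = kwargs.get("size", {"height": 20, "width": 20})


class DummyFastProcessor(DummySlowProcessor):
    # stands in for a torchvision-backed fast processor
    pass


class SharedProcessorTests(unittest.TestCase):
    image_processing_class = DummySlowProcessor
    fast_image_processing_class = DummyFastProcessor  # would be None without torchvision
    image_processor_dict = {"do_resize": True, "size": {"height": 20, "width": 20}}

    @property
    def image_processor_list(self):
        # keep only the processors that are actually available
        return [cls for cls in (self.image_processing_class, self.fast_image_processing_class) if cls is not None]

    def test_image_processor_properties(self):
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
            self.assertTrue(hasattr(image_processing, "do_resize"))
            self.assertTrue(hasattr(image_processing, "size"))


if __name__ == "__main__":
    unittest.main()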

View File

@ -151,13 +151,14 @@ class LlavaNextVideoProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
def test_call_pil(self):
# Initialize image_processing

View File

@ -19,7 +19,7 @@ import numpy as np
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -30,7 +30,10 @@ if is_torch_available():
if is_vision_available():
from PIL import Image
from transformers import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor
from transformers import LlavaOnevisionImageProcessor
if is_torchvision_available():
from transformers import LlavaOnevisionImageProcessorFast, LlavaOnevisionVideoProcessor
class LlavaOnevisionImageProcessingTester:
@ -49,6 +52,7 @@ class LlavaOnevisionImageProcessingTester:
image_std=OPENAI_CLIP_STD,
do_convert_rgb=True,
):
super().__init__()
size = size if size is not None else {"height": 20, "width": 20}
self.parent = parent
self.batch_size = batch_size
@ -121,6 +125,7 @@ class LlavaOnevisionImageProcessingTester:
@require_vision
class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = LlavaOnevisionImageProcessor if is_vision_available() else None
fast_image_processing_class = LlavaOnevisionImageProcessorFast if is_torchvision_available() else None
video_processing_class = LlavaOnevisionVideoProcessor if is_vision_available() else None
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->LlavaOnevision
@ -134,14 +139,15 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "image_grid_pinpoints"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "image_grid_pinpoints"))
def test_video_processor_properties(self):
image_processing = self.video_processing_class(**self.image_processor_dict)
@ -153,66 +159,70 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 20, "width": 20})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_numpy(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@unittest.skip(
reason="LlavaOnevisionImageProcessor doesn't treat 4 channel PIL and numpy consistently yet"
@ -221,22 +231,23 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
pass
def test_nested_input(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
# Test batched as a list of images
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched as a list of images
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched as a nested list of images, where each sublist is one batch
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
# Test batched as a nested list of images, where each sublist is one batch
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 1522, 3, 20, 20)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
# Image processor should return same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())
# Image processor should return same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())
def test_call_pil_video(self):
# Initialize image_processing
@ -289,3 +300,9 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos
expected_output_video_shape = (7, 8, 3, 20, 20)
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
@unittest.skip(
reason="LlavaOnevisionImageProcessorFast doesn't compile (infinitely) when using class transforms"
) # FIXME yoni
def test_can_compile_fast_image_processor(self):
pass

View File

@ -262,11 +262,43 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
torch.testing.assert_close(
encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], rtol=1e-2, atol=1e-2
encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], rtol=100, atol=1e-1
)
@require_vision
@require_torch
def test_slow_fast_equivalence_batched(self):
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
for i in range(len(encoding_slow.pixel_values)):
self.assertTrue(
torch.allclose(encoding_slow.pixel_values[i][0], encoding_fast.pixel_values[i][0], atol=1e-1)
)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values[i][0] - encoding_fast.pixel_values[i][0])).item(), 1e-3
)
torch.testing.assert_close(
encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], rtol=100, atol=1e-1
)
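
Unlike the stacked-tensor comparison in the shared mixin, Pixtral's override compares the outputs image by image, because its pixel_values come back as nested lists of per-image tensors (hence the pixel_values[i][0] indexing above) and images of different resolutions cannot be stacked into one tensor. A small sketch of that per-image comparison, assuming two already-computed nested outputs:

# Sketch of a per-image slow/fast comparison for processors whose pixel_values
# are lists of variable-sized tensors. `slow_outputs` / `fast_outputs` are
# assumed to be nested lists of torch.Tensor, as in the test above.
import torch


def assert_per_image_close(slow_outputs, fast_outputs, atol=1e-1, mean_tol=1e-3):
    for slow_images, fast_images in zip(slow_outputs, fast_outputs):
        for slow_image, fast_image in zip(slow_images, fast_images):
            assert torch.allclose(slow_image, fast_image, atol=atol)
            # the mean absolute difference is a tighter check than the elementwise tolerance
            assert torch.mean(torch.abs(slow_image - fast_image)).item() <= mean_tol


# usage with dummy, identical outputs
dummy = [[torch.rand(3, 16, 16), torch.rand(3, 24, 32)]]
assert_per_image_close(dummy, [[t.clone() for t in batch] for batch in dummy])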
@slow
@require_torch_gpu
@require_vision

View File

@ -17,7 +17,7 @@
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -25,6 +25,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import SiglipImageProcessor
if is_torchvision_available():
from transformers import SiglipImageProcessorFast
class SiglipImageProcessingTester:
def __init__(
@ -89,6 +92,7 @@ class SiglipImageProcessingTester:
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Siglip
class SiglipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SiglipImageProcessor if is_vision_available() else None
fast_image_processing_class = SiglipImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -100,25 +104,27 @@ class SiglipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Ignore copy
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
# Ignore copy
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size={"height": 84, "width": 84}
)
self.assertEqual(image_processor.size, {"height": 84, "width": 84})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size={"height": 84, "width": 84}
)
self.assertEqual(image_processor.size, {"height": 84, "width": 84})
@unittest.skip(reason="not supported")
# Ignore copy

View File

@ -152,13 +152,14 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
def test_call_pil(self):
# Initialize image_processing

View File

@ -25,8 +25,8 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import ViTImageProcessor
if is_torchvision_available():
from transformers import ViTImageProcessorFast
if is_torchvision_available():
from transformers import ViTImageProcessorFast
class ViTImageProcessingTester:

View File

@ -165,23 +165,50 @@ class ImageProcessingTestMixin:
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, rtol=1e-1, atol=1e-2)
@require_vision
@require_torch
def test_slow_fast_equivalence_batched(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
@require_vision
@require_torch
@ -194,7 +221,8 @@ class ImageProcessingTestMixin:
def measure_time(image_processor, image):
# Warmup
_ = image_processor(image, return_tensors="pt")
for _ in range(5):
_ = image_processor(image, return_tensors="pt")
start = time.time()
_ = image_processor(image, return_tensors="pt")
return time.time() - start
@ -270,8 +298,31 @@ class ImageProcessingTestMixin:
image_processor_fast_1.save_pretrained(tmpdirname)
image_processor_slow_1 = self.image_processing_class.from_pretrained(tmpdirname)
self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
dict_slow_0 = image_processor_slow_0.to_dict()
dict_slow_1 = image_processor_slow_1.to_dict()
difference = {
key: dict_slow_0.get(key) if key in dict_slow_0 else dict_slow_1.get(key)
for key in set(dict_slow_0) ^ set(dict_slow_1)
}
dict_slow_0 = {key: dict_slow_0[key] for key in set(dict_slow_0) & set(dict_slow_1)}
dict_slow_1 = {key: dict_slow_1[key] for key in set(dict_slow_0) & set(dict_slow_1)}
# check that all additional keys are None, except for `default_to_square` which is only set in fast processors
self.assertTrue(all(value is None for key, value in difference.items() if key not in ["default_to_square"]))
# check that the remaining keys are the same
self.assertEqual(dict_slow_0, dict_slow_1)
dict_fast_0 = image_processor_fast_0.to_dict()
dict_fast_1 = image_processor_fast_1.to_dict()
difference = {
key: dict_fast_0.get(key) if key in dict_fast_0 else dict_fast_1.get(key)
for key in set(dict_fast_0) ^ set(dict_fast_1)
}
dict_fast_0 = {key: dict_fast_0[key] for key in set(dict_fast_0) & set(dict_fast_1)}
dict_fast_1 = {key: dict_fast_1[key] for key in set(dict_fast_0) & set(dict_fast_1)}
# check that all additional keys are None, except for `default_to_square` which is only set in fast processors
self.assertTrue(all(value is None for key, value in difference.items() if key not in ["default_to_square"]))
# check that the remaining keys are the same
self.assertEqual(dict_fast_0, dict_fast_1)
def test_save_load_fast_slow_auto(self):
"Test that we can load a fast image processor from a slow one and vice-versa using AutoImageProcessor."
@ -293,8 +344,31 @@ class ImageProcessingTestMixin:
image_processor_fast_1.save_pretrained(tmpdirname)
image_processor_slow_1 = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=False)
self.assertEqual(image_processor_slow_0.to_dict(), image_processor_slow_1.to_dict())
self.assertEqual(image_processor_fast_0.to_dict(), image_processor_fast_1.to_dict())
dict_slow_0 = image_processor_slow_0.to_dict()
dict_slow_1 = image_processor_slow_1.to_dict()
difference = {
key: dict_slow_0.get(key) if key in dict_slow_0 else dict_slow_1.get(key)
for key in set(dict_slow_0) ^ set(dict_slow_1)
}
dict_slow_0 = {key: dict_slow_0[key] for key in set(dict_slow_0) & set(dict_slow_1)}
dict_slow_1 = {key: dict_slow_1[key] for key in set(dict_slow_0) & set(dict_slow_1)}
# check that all additional keys are None, except for `default_to_square` which is only set in fast processors
self.assertTrue(all(value is None for key, value in difference.items() if key not in ["default_to_square"]))
# check that the remaining keys are the same
self.assertEqual(dict_slow_0, dict_slow_1)
dict_fast_0 = image_processor_fast_0.to_dict()
dict_fast_1 = image_processor_fast_1.to_dict()
difference = {
key: dict_fast_0.get(key) if key in dict_fast_0 else dict_fast_1.get(key)
for key in set(dict_fast_0) ^ set(dict_fast_1)
}
dict_fast_0 = {key: dict_fast_0[key] for key in set(dict_fast_0) & set(dict_fast_1)}
dict_fast_1 = {key: dict_fast_1[key] for key in set(dict_fast_0) & set(dict_fast_1)}
# check that all additional keys are None, except for `default_to_square` which is only set in fast processors
self.assertTrue(all(value is None for key, value in difference.items() if key not in ["default_to_square"]))
# check that the remaining keys are the same
self.assertEqual(dict_fast_0, dict_fast_1)
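
The two round-trip tests above compare the slow and fast configuration dicts in the same three steps: split the keys into shared keys and keys that only one dict defines (the symmetric difference), require the extra keys to be None except for default_to_square (which only fast processors set), then require the shared keys to match exactly. A compact sketch of that comparison as a standalone helper, with hypothetical inputs:

# Hypothetical helper mirroring the dict comparison used in the round-trip tests above.
def assert_processor_dicts_equivalent(dict_0, dict_1, ignored_keys=("default_to_square",)):
    shared_keys = set(dict_0) & set(dict_1)
    # keys that only one of the two processors defines
    difference = {
        key: dict_0.get(key) if key in dict_0 else dict_1.get(key)
        for key in set(dict_0) ^ set(dict_1)
    }
    # extra keys must carry no information, unless they are explicitly ignored
    assert all(value is None for key, value in difference.items() if key not in ignored_keys)
    # shared keys must round-trip unchanged
    assert {k: dict_0[k] for k in shared_keys} == {k: dict_1[k] for k in shared_keys}


# e.g. a slow processor saved and reloaded as fast gains `default_to_square`
assert_processor_dicts_equivalent(
    {"do_resize": True, "resample": 2},
    {"do_resize": True, "resample": 2, "default_to_square": True, "device": None},
)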
def test_init_without_params(self):
for image_processing_class in self.image_processor_list:

View File

@ -833,6 +833,10 @@ def match_docstring_with_signature(obj: Any) -> Optional[Tuple[str, str]]:
# Nothing to do, no parameters are documented.
return
if "kwargs" in signature and signature["kwargs"].annotation != inspect._empty:
# Inspecting signature with typed kwargs is not supported yet.
return
indent = find_indent(obj_doc_lines[idx])
arguments = {}
current_arg = None
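
The new early return skips docstring/signature matching whenever a **kwargs parameter carries a type annotation, which is how the typed-kwargs (Unpack[...]) signatures introduced for fast image processors show up under inspect. A minimal sketch of what that check detects (the TypedDict name here is only illustrative):

# Illustration of the typed-kwargs detection; `PreprocessKwargs` is a stand-in
# for TypedDicts such as the fast image processors' kwargs classes.
import inspect
from typing import TypedDict

try:  # Unpack lives in typing from Python 3.11, typing_extensions before that
    from typing import Unpack
except ImportError:
    from typing_extensions import Unpack


class PreprocessKwargs(TypedDict, total=False):
    do_resize: bool
    size: dict


def preprocess(images, **kwargs: Unpack[PreprocessKwargs]):
    return images


signature = inspect.signature(preprocess).parameters
# plain `**kwargs` has an empty annotation; typed kwargs do not, so the docstring
# checker bails out early for them instead of trying to match arguments
print(signature["kwargs"].annotation != inspect.Parameter.empty)  # True -> skipped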

View File

@ -1066,6 +1066,8 @@ TYPE_TO_FILE_TYPE = {
"Processor": "processing",
"ImageProcessor": "image_processing",
"ImageProcessorFast": "image_processing*_fast", # "*" indicates where to insert the model name before the "_fast" suffix
"FastImageProcessorInitKwargs": "image_processing*_fast",
"FastImageProcessorPreprocessKwargs": "image_processing*_fast",
"FeatureExtractor": "feature_extractor",
"ProcessorKwargs": "processing",
"ImagesKwargs": "processing",