Add Idefics 3! (#32473)

* Add Idefics 3!

* fixes to make both pipelines identical

* fix for quantized models

* First pass at the review

* remove vocab size from the main config (it's still in the text_config)

* hot fix for merve

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* re-add model_type for text_config

* remove support for old_cache

* remove hidden_size from main config

* rename idefics3 HF repo

* few changes suggested in the PR

* fix to input_data_format computation

* remove overwrite of _autoset_attn_implementation following @zucchini-nlp suggestion

* improve example

* few improvements from amy's review

* big change to enable processing input images as numpy arrays

* Changes to the code to uniformize processor kwargs

* image processing tests

* image processing tests fixes and some bugs they discovered

* addressed review comments from Yoni

* fix modeling tests

* remove special tokens that are not special

* fixes tests

* skip failing tests - they also fail for idefics2

* added paper and readded the tests with multi gpu, who knows

* Update docs/source/en/model_doc/idefics3.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* review amy until image_processing_idefics3

* last comments from Amy

* review amy

* Update src/transformers/models/idefics3/image_processing_idefics3.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/idefics3/modeling_idefics3.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/idefics3.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* doc improvement - amy review

* fix runtime error during fine-tuning

* amy's review

* Update src/transformers/models/idefics3/image_processing_idefics3.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/idefics3/image_processing_idefics3.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/idefics3/modeling_idefics3.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* ruff

* amy's comment on the order

* ruff ruff

* fix copies

* square images when they are not splitted

* ruff :(

* Update src/transformers/models/idefics3/image_processing_idefics3.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/idefics3/test_processing_idefics3.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix small bug introduced in refactor

* amy's image processing changes

* fixes peft tests and ruff

* modify to_pil_image from transformers. and review from emanuele.

* add modified to_pil_image

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Andrés Marafioti 2024-09-25 21:28:49 +02:00 committed by GitHub
parent f0eabf6c7d
commit f2c388e3f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 4482 additions and 2 deletions

View File

@ -830,6 +830,8 @@
title: IDEFICS
- local: model_doc/idefics2
title: Idefics2
- local: model_doc/idefics3
title: Idefics3
- local: model_doc/instructblip
title: InstructBLIP
- local: model_doc/instructblipvideo

View File

@ -169,6 +169,7 @@ Flax), PyTorch, and/or TensorFlow.
| [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ |
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
| [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
| [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ |
| [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ |
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
| [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |

View File

@ -0,0 +1,73 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Idefics3
## Overview
The Idefics3 model was proposed in [Building and better understanding vision-language models: insights and future directions](https://huggingface.co/papers/2408.12637) by Hugo Laurençon, Andrés Marafioti, Victor Sanh, and Léo Tronchon.
Idefics3 is an adaptation of the Idefics2 model with three main differences:
- It uses Llama3 for the text model.
- It uses an updated processing logic for the images.
- It removes the perceiver.
The abstract from the paper is the following:
*The field of vision-language models (VLMs), which take images and texts as inputs and output texts, is rapidly evolving and has yet to reach consensus on several key aspects of the development pipeline, including data, architecture, and training methods. This paper can be seen as a tutorial for building a VLM. We begin by providing a comprehensive overview of the current state-of-the-art approaches, highlighting the strengths and weaknesses of each, addressing the major challenges in the field, and suggesting promising research directions for underexplored areas. We then walk through the practical steps to build Idefics3-8B, a powerful VLM that significantly outperforms its predecessor Idefics2-8B, while being trained efficiently, exclusively on open datasets, and using a straightforward pipeline. These steps include the creation of Docmatix, a dataset for improving document understanding capabilities, which is 240 times larger than previously available datasets. We release the model along with the datasets created for its training.*
## Usage tips
Input images are processed either by upsampling (if resizing is enabled) or at their original resolution. The resizing behavior depends on two parameters: do_resize and size.
If `do_resize` is set to `True`, the model resizes images so that the longest edge is 4*364 pixels by default.
The default resizing behavior can be customized by passing a dictionary to the `size` parameter. For example, `{"longest_edge": 4 * 364}` is the default, but you can change it to a different value if needed.
Heres how to control resizing and set a custom size:
```python
image_processor = Idefics3ImageProcessor(do_resize=True, size={"longest_edge": 2 * 364}, max_image_size=364)
```
Additionally, the `max_image_size` parameter, which controls the size of each square patch the image is decomposed into, is set to 364 by default but can be adjusted as needed. After resizing (if applicable), the image processor decomposes the images into square patches based on the `max_image_size` parameter.
This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [andimarafioti](https://huggingface.co/andito).
## Idefics3Config
[[autodoc]] Idefics3Config
## Idefics3Model
[[autodoc]] Idefics3Model
- forward
## Idefics3ForConditionalGeneration
[[autodoc]] Idefics3ForConditionalGeneration
- forward
## Idefics3ImageProcessor
[[autodoc]] Idefics3ImageProcessor
- preprocess
## Idefics3Processor
[[autodoc]] Idefics3Processor
- __call__

View File

@ -54,6 +54,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)

View File

@ -481,6 +481,7 @@ _import_structure = {
"models.ibert": ["IBertConfig"],
"models.idefics": ["IdeficsConfig"],
"models.idefics2": ["Idefics2Config"],
"models.idefics3": ["Idefics3Config"],
"models.imagegpt": ["ImageGPTConfig"],
"models.informer": ["InformerConfig"],
"models.instructblip": [
@ -1191,6 +1192,7 @@ else:
_import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
_import_structure["models.idefics"].extend(["IdeficsImageProcessor"])
_import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
_import_structure["models.idefics3"].extend(["Idefics3ImageProcessor"])
_import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"])
_import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"])
_import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
@ -2420,6 +2422,14 @@ else:
"Idefics2Processor",
]
)
_import_structure["models.idefics3"].extend(
[
"Idefics3ForConditionalGeneration",
"Idefics3Model",
"Idefics3PreTrainedModel",
"Idefics3Processor",
]
)
_import_structure["models.imagegpt"].extend(
[
"ImageGPTForCausalImageModeling",
@ -5289,6 +5299,7 @@ if TYPE_CHECKING:
IdeficsConfig,
)
from .models.idefics2 import Idefics2Config
from .models.idefics3 import Idefics3Config
from .models.imagegpt import ImageGPTConfig
from .models.informer import InformerConfig
from .models.instructblip import (
@ -6037,6 +6048,7 @@ if TYPE_CHECKING:
from .models.grounding_dino import GroundingDinoImageProcessor
from .models.idefics import IdeficsImageProcessor
from .models.idefics2 import Idefics2ImageProcessor
from .models.idefics3 import Idefics3ImageProcessor
from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor
from .models.instructblipvideo import InstructBlipVideoImageProcessor
from .models.layoutlmv2 import (
@ -7071,6 +7083,12 @@ if TYPE_CHECKING:
Idefics2PreTrainedModel,
Idefics2Processor,
)
from .models.idefics3 import (
Idefics3ForConditionalGeneration,
Idefics3Model,
Idefics3PreTrainedModel,
Idefics3Processor,
)
from .models.imagegpt import (
ImageGPTForCausalImageModeling,
ImageGPTForImageClassification,

View File

@ -162,6 +162,7 @@ def _rescale_for_pil_conversion(image):
def to_pil_image(
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None,
image_mode: Optional[str] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image":
"""
@ -175,6 +176,8 @@ def to_pil_image(
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
and `False` otherwise.
image_mode (`str`, *optional*):
The mode to use for the PIL image. If unset, will use the default mode for the input image type.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
@ -207,7 +210,7 @@ def to_pil_image(
image = rescale(image, 255)
image = image.astype(np.uint8)
return PIL.Image.fromarray(image)
return PIL.Image.fromarray(image, mode=image_mode)
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366

View File

@ -115,6 +115,7 @@ from . import (
ibert,
idefics,
idefics2,
idefics3,
imagegpt,
informer,
instructblip,

View File

@ -133,6 +133,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("ibert", "IBertConfig"),
("idefics", "IdeficsConfig"),
("idefics2", "Idefics2Config"),
("idefics3", "Idefics3Config"),
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
@ -432,6 +433,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("ibert", "I-BERT"),
("idefics", "IDEFICS"),
("idefics2", "Idefics2"),
("idefics3", "Idefics3"),
("imagegpt", "ImageGPT"),
("informer", "Informer"),
("instructblip", "InstructBLIP"),

View File

@ -89,6 +89,7 @@ else:
("hiera", ("BitImageProcessor",)),
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("idefics3", ("Idefics3ImageProcessor",)),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor",)),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),

View File

@ -130,6 +130,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("ibert", "IBertModel"),
("idefics", "IdeficsModel"),
("idefics2", "Idefics2Model"),
("idefics3", "Idefics3Model"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
@ -315,6 +316,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("ibert", "IBertForMaskedLM"),
("idefics", "IdeficsForVisionText2Text"),
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
@ -733,6 +735,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
("chameleon", "ChameleonForConditionalGeneration"),
("git", "GitForCausalLM"),
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),

View File

@ -65,6 +65,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("hubert", "Wav2Vec2Processor"),
("idefics", "IdeficsProcessor"),
("idefics2", "Idefics2Processor"),
("idefics3", "Idefics3Processor"),
("instructblip", "InstructBlipProcessor"),
("instructblipvideo", "InstructBlipVideoProcessor"),
("kosmos-2", "Kosmos2Processor"),

View File

@ -219,6 +219,7 @@ else:
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("idefics3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
(

View File

@ -1098,7 +1098,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
std = (
self.config.text_config.initializer_range
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)

View File

@ -0,0 +1,72 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_idefics3": ["Idefics3Config"]}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_idefics3"] = ["Idefics3ImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_idefics3"] = [
"Idefics3ForConditionalGeneration",
"Idefics3PreTrainedModel",
"Idefics3Model",
]
_import_structure["processing_idefics3"] = ["Idefics3Processor"]
if TYPE_CHECKING:
from .configuration_idefics3 import Idefics3Config
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_idefics3 import Idefics3ImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_idefics3 import (
Idefics3ForConditionalGeneration,
Idefics3Model,
Idefics3PreTrainedModel,
)
from .processing_idefics3 import Idefics3Processor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

View File

@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Idefics3 model configuration"""
import os
from typing import Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class Idefics3VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Idefics3VisionModel`]. It is used to instantiate a
Idefics3 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) used in the Idefics3 model
[HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1152):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 32):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
intializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation for initializing all weight matrices in the model.
Example:
```python
>>> from transformers.models.idefics3.modeling_idefics3 import Idefics3VisionTransformer
>>> from transformers.models.idefics3.configuration_idefics3 import Idefics3VisionConfig
>>> # Initializing a Idefics3VisionConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = Idefics3VisionConfig()
>>> # Initializing a Idefics3VisionTransformer (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = Idefics3VisionTransformer(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "idefics3"
def __init__(
self,
hidden_size=1152,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=16,
num_channels=3,
image_size=224,
patch_size=32,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from Idefics3Config
if config_dict.get("model_type") == "idefics3":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class Idefics3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Idefics3Model`]. It is used to instantiate a
Idefics3 model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the model of the Idefics3
[HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should cache the key/value pairs of the attention mechanism. Only
relevant if `config.is_decoder=True`.
image_token_id (`int`, *optional*, defaults to 128257):
The id of the "image" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether or not to tie the word embeddings with the token embeddings.
vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`):
Custom vision config or dict for the vision tower
text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
Custom text config or dict for the text model
scale_factor (`int`, *optional*, defaults to 2):
The scale factor for the image encoder.
pad_token_id (`int`, *optional*, defaults to 128002):
The id of the padding token.
Example:
```python
>>> from transformers import Idefics3Model, Idefics3Config
>>> # Initializing configuration
>>> configuration = Idefics3Config()
>>> # Initializing a model from the configuration
>>> model = Idefics3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "idefics3"
is_composition = True
def __init__(
self,
use_cache=True,
image_token_id=128257,
tie_word_embeddings=False,
vision_config=None,
text_config=None,
scale_factor=2,
pad_token_id=128_002,
**kwargs,
):
self.image_token_id = image_token_id
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
if vision_config is None:
self.vision_config = Idefics3VisionConfig()
logger.info("vision_config is None, using default vision config")
elif isinstance(vision_config, dict):
self.vision_config = Idefics3VisionConfig(**vision_config)
elif isinstance(vision_config, Idefics3VisionConfig):
self.vision_config = vision_config
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
logger.info("text_config is None, using default text config")
text_config = CONFIG_MAPPING["llama"](
rms_norm_eps=1e-5,
pad_token_id=pad_token_id,
tie_word_embeddings=False,
)
self.text_config = text_config
self.scale_factor = scale_factor
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)

View File

@ -0,0 +1,214 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Idefics3Config,
Idefics3ForConditionalGeneration,
Idefics3ImageProcessor,
Idefics3Processor,
LlamaConfig,
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/idefics3/convert_idefics3_weights_to_hf.py --original_model_id HuggingFaceM4/Idefics3-8B-Llama3 --output_hub_path org/idefics3
"""
KEYS_TO_MODIFY_MAPPING = {
"lm_head.weight": "lm_head.linear.weight",
"model.layers": "model.text_model.layers",
"model.norm": "model.text_model.norm",
"model.modality_projection": "model.connector.modality_projection",
}
WEIGHTS_TO_MERGE_MAPPING = (
# (weights to merge in merging order), (new weight name)
(
("model.embed_tokens.weight", "model.embed_tokens.additional_embedding.weight"),
"model.text_model.embed_tokens.weight",
),
(("lm_head.linear.weight", "additional_fc.weight"), "lm_head.weight"),
)
WEIGHTS_TO_DROP = (
# The original model had a vision head, but this is never used
"model.vision_model.head",
)
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
old_state_dict_keys = set(state_dict.keys())
# Flattened list of weights to merge. We keep these in the original state dict to merge them later
original_weights_to_merge = [w for weights in WEIGHTS_TO_MERGE_MAPPING for w in weights[0]]
# for key, value in state_dict.items():
for old_key in old_state_dict_keys:
if old_key.endswith(".inv_freq") or any(w in old_key for w in WEIGHTS_TO_DROP):
state_dict.pop(old_key)
continue
key = old_key
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
weight = state_dict.pop(old_key)
if key in original_weights_to_merge:
new_state_dict[key] = weight
# Bit of a hack - we need to keep the original weights to merge them later
state_dict[key] = weight
else:
new_state_dict[key] = weight
return new_state_dict
def merge_weights(state_dict, new_state_dict):
old_weight_names = set(state_dict.keys())
# Merge the weights
for weights_to_merge, new_weight_name in WEIGHTS_TO_MERGE_MAPPING:
for weight_to_merge in weights_to_merge:
print(weight_to_merge)
assert weight_to_merge in state_dict, f"Weight {weight_to_merge} is missing in the state dict"
weight = state_dict.pop(weight_to_merge)
if new_weight_name not in new_state_dict:
new_state_dict[new_weight_name] = [weight]
else:
new_state_dict[new_weight_name].append(weight)
old_weight_names.remove(weight_to_merge)
new_state_dict[new_weight_name] = torch.cat(new_state_dict[new_weight_name], dim=0)
# Remove the weights that were merged
for weights_to_merge, new_weight_name in WEIGHTS_TO_MERGE_MAPPING:
for weight in weights_to_merge:
if weight in new_state_dict and weight != new_weight_name:
new_state_dict.pop(weight)
return new_state_dict
def get_config(checkpoint):
# We load the config then recreate to use the text_config
# download the config file
filepath = hf_hub_download(repo_id=checkpoint, filename="config.json")
with open(filepath, "r") as f:
config_json = json.load(f)
# Setup the vision config
vision_config = config_json.pop("vision_config")
vision_config.pop("vision_model_name", None)
if "embed_dim" in vision_config:
vision_config["hidden_size"] = vision_config.pop("embed_dim")
config_json["vocab_size"] = config_json.pop("vocab_size") + config_json.pop("additional_vocab_size")
image_token_id = config_json.pop("image_token_id", config_json["vocab_size"] - 2)
use_cache = config_json.pop("use_cache", True)
tie_word_embeddings = config_json.pop("tie_word_embeddings", True)
scale_factor = config_json.pop("scale_factor", 2)
vocab_size = config_json.pop("vocab_size", 100000)
# Remove "freeze" params from the config
config_json = {k: v for k, v in config_json.items() if not k.startswith("freeze_")}
text_config = LlamaConfig(**config_json)
config = Idefics3Config(
text_config=text_config,
vision_config=vision_config,
use_cache=use_cache,
image_token_id=image_token_id,
tie_word_embeddings=tie_word_embeddings,
scale_factor=scale_factor,
vocab_size=vocab_size,
)
return config
def convert_idefics3_hub_to_hf(original_model_id, output_hub_path, push_to_hub):
# The original model maps to AutoModelForCausalLM, converted we map to Idefics3ForConditionalGeneration
original_model = AutoModelForCausalLM.from_pretrained(
original_model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)
# The original model doesn't use the Idefics3 processing objects
image_processor = Idefics3ImageProcessor()
tokenizer = AutoTokenizer.from_pretrained(original_model_id)
processor = Idefics3Processor(
image_processor=image_processor,
tokenizer=tokenizer,
)
state_dict = original_model.state_dict()
new_state_dict = convert_state_dict_to_hf(state_dict)
# Merge weights
new_state_dict = merge_weights(state_dict, new_state_dict)
del state_dict
config = get_config(original_model_id)
print(config)
with init_empty_weights():
model = Idefics3ForConditionalGeneration(config)
model.load_state_dict(new_state_dict, strict=True, assign=True)
model.save_pretrained(output_hub_path)
processor.save_pretrained(output_hub_path)
if push_to_hub:
model.push_to_hub(output_hub_path, private=True)
processor.push_to_hub(output_hub_path, private=True)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--original_model_id",
help="Hub location of the text model",
)
parser.add_argument(
"--output_hub_path",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="If set, the model will be pushed to the hub after conversion.",
)
args = parser.parse_args()
convert_idefics3_hub_to_hf(args.original_model_id, args.output_hub_path, args.push_to_hub)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,890 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import PaddingMode, pad, to_channel_dimension_format, to_pil_image
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
is_valid_image,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
from PIL import Image
def _resize_output_size_rescale_to_max_len(
height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
) -> Tuple[int, int]:
"""
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
min_len (`int`, *optional*, defaults to 1):
Minimum size of the output image.
max_len (`int`, *optional*, defaults to the maximum size of the image):
Maximum size of the output image.
Returns:
The output size of the image after resizing.
"""
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
if width >= height:
width = max_len
height = int(width / aspect_ratio)
if height % 2 != 0:
height += 1
elif height > width:
height = max_len
width = int(height * aspect_ratio)
if width % 2 != 0:
width += 1
# Avoid resizing to a size smaller than min_len
height = max(height, min_len)
width = max(width, min_len)
return height, width
def _resize_output_size_scale_below_upper_bound(
height: int, width: int, max_len: Optional[Dict[str, int]] = None
) -> Tuple[int, int]:
"""
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
Defines the maximum dimensions of the image.
Returns:
The output size of the image after resizing.
"""
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
if width >= height and width > max_len:
width = max_len
height = int(width / aspect_ratio)
elif height > width and height > max_len:
height = max_len
width = int(height * aspect_ratio)
# Avoid resizing to a size smaller than 1
height = max(height, 1)
width = max(width, 1)
return height, width
def get_resize_output_image_size(
image,
resolution_max_side: int,
max_image_size: int = 1820,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
"""
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
Args:
image (`np.ndarray`):
Image to resize.
resolution_max_side (`int`):
The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
input aspect ratio, with a lower bound of `min_image_size`.
max_image_size (`int`, *optional*, defaults to 1820):
Maximum image resolution. If the image is larger than this size, the longest edge will be resized to this
value, with the shortest edge resized to keep the input aspect ratio, with a lower bound of `min_image_size`.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
The output size of the image after resizing.
"""
if resolution_max_side > max_image_size:
raise ValueError("`resolution_max_side` cannot be larger than `max_image_size`")
height, width = get_image_size(image, channel_dim=input_data_format)
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
# Find the output size when scaling the image to be below the max_image_size
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=max_image_size)
return height, width
# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
"""
Convert a single image or a list of images to a list of numpy arrays.
Args:
images (`ImageInput`):
A single image or a list of images.
Returns:
A list of numpy arrays.
"""
# If it's a single image, convert it to a list of lists
if is_valid_image(images):
images = [[images]]
# If it's a list of images, it's a single batch, so convert it to a list of lists
elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
images = [images]
# If it's a list of batches, it's already in the right format
elif (
isinstance(images, (list, tuple))
and len(images) > 0
and isinstance(images[0], (list, tuple))
and is_valid_image(images[0][0])
):
pass
else:
raise ValueError(
"Invalid input type. Must be a single image, a list of images, or a list of batches of images."
)
return images
# Copied from transformers.models.detr.image_processing_detr.max_across_indices
def max_across_indices(values: Iterable[Any]) -> List[Any]:
"""
Return the maximum value across all indices of an iterable of values.
"""
return [max(values_i) for values_i in zip(*values)]
def get_max_height_width(
images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
max_height = max_width = float("-inf")
for images in images_list:
for image in images:
height, width = get_image_size(image, channel_dim=input_data_format)
max_height = max(height, max_height)
max_width = max(width, max_width)
return (max_height, max_width)
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(
image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`np.ndarray`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
def convert_to_rgb(
image: np.ndarray,
palette: Optional[PIL.ImagePalette.ImagePalette] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> ImageInput:
"""
Converts an image to RGB format.
Args:
image (`np.ndarray`):
The image to convert.
palette (List[int], *optional*):
The palette to use if given.
data_format (ChannelDimension or str, *optional*):
The channel dimension format for the output image. If not provided, it will be the same as the input image.
input_data_format (ChannelDimension or str, *optional*):
The channel dimension format of the input image.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
# The resized image from PIL will always have channels last, so find the input format first.
data_format = input_data_format if data_format is None else data_format
mode = "P" if palette is not None else None
image = to_pil_image(image, image_mode=mode)
if image.mode == "P" and palette is not None:
image.putpalette(palette)
image_rgba = image.convert("RGBA")
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
alpha_composite = Image.alpha_composite(background, image_rgba)
alpha_composite = alpha_composite.convert("RGB")
output_array = np.array(alpha_composite)
# The image is always in channels last format after converting from a PIL image
output_array = to_channel_dimension_format(output_array, data_format, input_channel_dim=ChannelDimension.LAST)
return output_array
# FIXME Amy: make a more general crop function that isn't just centre crop
def _crop(
image: np.ndarray,
w1: int,
h1: int,
w2: int,
h2: int,
data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
if data_format is None:
data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
if data_format == ChannelDimension.FIRST:
image = image[:, h1:h2, w1:w2]
elif data_format == ChannelDimension.LAST:
image = image[h1:h2, w1:w2, :]
else:
raise ValueError("Invalid channel dimension format.")
return image
class Idefics3ImageProcessor(BaseImageProcessor):
r"""
Constructs a Idefics3 image processor.
Args:
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
Only has an effect if the input image is in the PIL format.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
shortest edge resized to keep the input aspect ratio.
size (`Dict`, *optional*, defaults to `{"longest_edge": 4 * 364}`):
Controls the size of the output image. This is a dictionary containing the key "longest_edge".
The image will be resized such that the longest edge is <= `size["longest_edge"]` and the shortest edge is resized
to keep the input aspect ratio.
resample (`Resampling`, *optional*, defaults to `Resampling.LANCZOS`):
Resampling filter to use when resizing the image.
do_image_splitting (`bool`, *optional*, defaults to `True`):
Whether to split the image into sub-images concatenated with the original image. They are split into patches
such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
rescale_factor (`float`, *optional*, defaults to `1/255`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
a standard deviation of `image_std`.
image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether or not to pad the images to the largest height and width in the batch and number of images per
sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_convert_rgb: bool = True,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.LANCZOS,
do_image_splitting: bool = True,
max_image_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.do_convert_rgb = do_convert_rgb
self.do_resize = do_resize
self.size = size if size is not None else {"longest_edge": 4 * 364}
self.resample = resample
self.do_image_splitting = do_image_splitting
self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 364}
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.LANCZOS,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image. The longest edge of the image is resized to size["longest_edge"], with the shortest edge
resized to keep the input aspect ratio. Can also be used with size["height"] and size["width"].
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
Resampling filter to use when resizing the image.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the output image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
# The resized image from PIL will always have channels last, so find the input format first.
data_format = input_data_format if data_format is None else data_format
if "longest_edge" in size:
size = get_resize_output_image_size(
image, resolution_max_side=size["longest_edge"], input_data_format=input_data_format
)
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
image_mode = None
if image.ndim == 2 or image.shape[-1] == 1:
image_mode = "P"
image = to_pil_image(image, image_mode=image_mode)
resized_image = image.resize((size[1], size[0]), resample=resample)
resized_image = np.array(resized_image)
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
# so we need to add it back if necessary.
resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
# The image is always in channels last format after converting from a PIL image
resized_image = to_channel_dimension_format(
resized_image, data_format, input_channel_dim=ChannelDimension.LAST
)
return resized_image
def split_image(
self,
image,
max_image_size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.LANCZOS,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Split an image into squares of side max_image_size and the original image resized to max_image_size.
That means that a single image becomes a sequence of images.
This is a "trick" to spend more compute on each image with no changes in the vision encoder.
1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
sub-images of the same size each (image_size, image_size). Typically, 364x364.
3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
Args:
image (`np.ndarray`):
Images to split.
max_image_size (`Dict[str, int]`):
Maximum size of the output image. If the image is larger than this size, it will be split into
patches of this size, and the original image will be concatenated with the patches, resized to max_size.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
Resampling filter to use when resizing the image.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the output image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
height, width = get_image_size(image, channel_dim=input_data_format)
max_height = max_width = max_image_size["longest_edge"]
frames = []
if height > max_height or width > max_width:
# Calculate the number of splits
num_splits_h = math.ceil(height / max_height)
num_splits_w = math.ceil(width / max_width)
# Calculate the optimal width and height for the sub-images
optimal_height = math.ceil(height / num_splits_h)
optimal_width = math.ceil(width / num_splits_w)
# Iterate through each row and column
for r in range(num_splits_h):
for c in range(num_splits_w):
# Calculate the starting point of the crop
start_x = c * optimal_width
start_y = r * optimal_height
# Calculate the ending point of the crop
end_x = min(start_x + optimal_width, width)
end_y = min(start_y + optimal_height, height)
# Crop the image
cropped_image = _crop(
image,
start_x,
start_y,
end_x,
end_y,
data_format=data_format,
)
frames.append(cropped_image)
# For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
global_image_height, global_image_width = max_height, max_width
if height != global_image_height or width != global_image_width:
image = self.resize(
image,
{"height": global_image_height, "width": global_image_width},
resample=resample,
input_data_format=data_format,
)
else:
num_splits_h, num_splits_w = 0, 0
frames.append(image)
return frames, num_splits_h, num_splits_w
def resize_for_vision_encoder(
self,
image: np.ndarray,
vision_encoder_max_size: int,
resample: PILImageResampling = PILImageResampling.LANCZOS,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
Args:
image (`np.ndarray`):
Images to resize.
vision_encoder_max_size (`int`):
Maximum size of the output image. If the image is larger than this size, it will be split into
patches of this size, and the original image will be concatenated with the patches, resized to max_size.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
Resampling filter to use when resizing the image.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the output image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred
"""
height, width = get_image_size(image, channel_dim=input_data_format)
aspect_ratio = width / height
if width >= height:
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
height = int(width / aspect_ratio)
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
elif height > width:
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
width = int(height * aspect_ratio)
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
new_size = {"height": height, "width": width}
return self.resize(
image, size=new_size, resample=resample, input_data_format=input_data_format, data_format=data_format
)
def _pad_image(
self,
image: np.ndarray,
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size
pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image
def pad(
self,
images: List[np.ndarray],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
For a list of images, for each images, pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width.
For each sample in the batch, pads the sample with empty images to the max_number of images per sample in the batch. Optionally returns a pixel mask.
Args:
images (`List[np.ndarray]`):
List of list of images to pad. Pads to the largest height and width in the batch.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_max_height_width(images, input_data_format=input_data_format)
batch_size = len(images)
max_num_images = max(len(images_) for images_ in images)
input_data_format = (
infer_channel_dimension_format(images[0][0], num_channels=(1, 3, 4))
if input_data_format is None
else input_data_format
)
data_format = input_data_format if data_format is None else data_format
if input_data_format == ChannelDimension.FIRST:
n_channels = images[0][0].shape[0]
elif input_data_format == ChannelDimension.LAST:
n_channels = images[0][0].shape[-1]
else:
raise ValueError("Invalid channel dimension format.")
def empty_image(size, input_data_format):
if input_data_format == ChannelDimension.FIRST:
return np.zeros((n_channels, *size), dtype=np.uint8)
elif input_data_format == ChannelDimension.LAST:
return np.zeros((*size, n_channels), dtype=np.uint8)
padded_images_list = [
[empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
]
padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
for batch_idx in range(batch_size):
for sample_idx, image in enumerate(images[batch_idx]):
padded_images_list[batch_idx][sample_idx] = self._pad_image(
image,
pad_size,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
padded_masks[batch_idx][sample_idx] = make_pixel_mask(
image, output_size=pad_size, input_data_format=input_data_format
)
padded_masks = padded_masks if return_pixel_mask else None
return padded_images_list, padded_masks
def preprocess(
self,
images: ImageInput,
do_convert_rgb: Optional[bool] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_image_splitting: Optional[bool] = None,
do_rescale: Optional[bool] = None,
max_image_size: Optional[Dict[str, int]] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_row_col_info: bool = False,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Preprocess a batch of images.
Args:
images (`ImageInput`):
A list of images to preprocess.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. With the longest edge resized to keep the input aspect ratio.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
Whether to split the image into sub-images concatenated with the original image. They are split into patches
such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
max_image_size (`Dict`, *optional*, defaults to `self.max_image_size`):
Maximum resolution of the images. If the image is larger than this size, the image is split into patches.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether or not to pad the images to the largest height and width in the batch.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
return_row_col_info (`bool`, *optional*, default to `False`):
Whether to return the number of rows and columns of the split images. This is used for the
`Idefics3Processor` to generate prompt strings based on the number of rows and columns.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
max_image_size = max_image_size if max_image_size is not None else self.max_image_size
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
do_pad = do_pad if do_pad is not None else self.do_pad
images_list = make_list_of_images(images)
if not valid_images(images_list[0]):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# save the palettes for conversion to RGB
palettes_list = [
[im.getpalette() if isinstance(im, Image.Image) and im.mode == "P" else None for im in images]
for images in images_list
]
# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]
if is_scaled_image(images_list[0][0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
# We assume that all images have the same channel dimension format.
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
# Extra channel dimension for grayscale images
if input_data_format == ChannelDimension.LAST:
images_list = [
[np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
]
elif input_data_format == ChannelDimension.FIRST:
images_list = [
[np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
]
else:
raise ValueError(f"Invalid channel dimension format {input_data_format}.")
if do_resize:
images_list = [
[
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
for images in images_list
]
if do_image_splitting:
# We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
# for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
# for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
images_list = [
[
self.resize_for_vision_encoder(
image, max_image_size["longest_edge"], resample=resample, input_data_format=input_data_format
)
for image in images
]
for images in images_list
]
images_list_split_arrays = []
palettes_list_split_arrays = []
images_list_rows = []
images_list_cols = []
for images, palettes in zip(images_list, palettes_list):
split_image_arrays = []
split_palettes_arrays = []
image_rows = []
image_cols = []
for image, palette in zip(images, palettes):
split_image_array, rows, cols = self.split_image(
image,
max_image_size=max_image_size,
input_data_format=input_data_format,
)
split_image_arrays.extend(split_image_array)
split_palettes_arrays.extend([palette] * len(split_image_array))
image_rows.append(rows)
image_cols.append(cols)
images_list_split_arrays.append(split_image_arrays)
palettes_list_split_arrays.append(split_palettes_arrays)
images_list_rows.append(image_rows)
images_list_cols.append(image_cols)
images_list = images_list_split_arrays
palettes_list = palettes_list_split_arrays
else:
# We square the images to max_image_size
images_list = [
[
self.resize(
image=image,
size={"height": max_image_size["longest_edge"], "width": max_image_size["longest_edge"]},
resample=resample,
input_data_format=input_data_format,
)
for image in images
]
for images in images_list
]
images_list_rows = [[0] * len(images) for images in images_list]
images_list_cols = [[0] * len(images) for images in images_list]
if do_convert_rgb:
images_list = [
[convert_to_rgb(img, palette) for img, palette in zip(images, palettes)]
for images, palettes in zip(images_list, palettes_list)
]
if do_rescale:
images_list = [
[self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
for images in images_list
]
if do_normalize:
images_list = [
[
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
for images in images_list
]
pixel_attention_mask = None
if do_pad:
images_list, pixel_attention_mask = self.pad(
images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
)
if data_format is not None:
images_list = [
[
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
for image in images
]
for images in images_list
]
# Faster tensor conversion
data = {"pixel_values": np.array(images_list) if do_pad and return_tensors is not None else images_list}
if pixel_attention_mask is not None:
data["pixel_attention_mask"] = (
np.array(pixel_attention_mask) if do_pad and return_tensors is not None else pixel_attention_mask
)
encoding = BatchFeature(data=data, tensor_type=return_tensors)
# This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
if return_row_col_info:
encoding["rows"] = images_list_rows
encoding["cols"] = images_list_cols
return encoding

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,344 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Idefics3.
"""
import re
import sys
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
from ...utils import logging
if TYPE_CHECKING:
from ...tokenization_utils_base import PreTokenizedInput
if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
logger = logging.get_logger(__name__)
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem)
def _prompt_split_image(image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token):
"""Prompt with expanded image tokens for when the image is split into patches."""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
f"{fake_token_around_image}" + f"<row_{n_h + 1}_col_{n_w + 1}>" + f"{image_token}" * image_seq_len
)
text_split_images += "\n"
text_split_images += (
f"\n{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
return text_split_images
def _prompt_single_image(image_seq_len, fake_token_around_image, image_token, global_img_token):
"""Prompt with expanded image tokens for a single image."""
return (
f"{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
def get_image_prompt_string(
image_rows, image_cols, image_seq_len, fake_token_around_image, image_token, global_img_token
):
if image_rows == 0 and image_cols == 0:
return _prompt_single_image(
image_seq_len,
fake_token_around_image=fake_token_around_image,
image_token=image_token,
global_img_token=global_img_token,
)
return _prompt_split_image(
image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_img_token
)
class Idefics3ImagesKwargs(ImagesKwargs, total=False):
return_row_col_info: Optional[bool]
max_image_size: Optional[Dict[str, int]]
class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: Idefics3ImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"is_split_into_words": False,
},
"images_kwargs": {
"return_row_col_info": True,
},
}
Idefics3ProcessorKwargs.__annotations__["images_kwargs"] = Idefics3ImagesKwargs # python 3.8 compatibility
class Idefics3Processor(ProcessorMixin):
r"""
Constructs a Idefics3 processor which wraps a LLama tokenizer and Idefics3 image processor into a single processor.
[`Idefics3Processor`] offers all the functionalities of [`Idefics3ImageProcessor`] and [`Idefics3TokenizerFast`]. See
the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
Args:
image_processor (`Idefics3ImageProcessor`):
An instance of [`Idefics3ImageProcessor`]. The image processor is a required input.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
image_seq_len (`int`, *optional*, defaults to 169):
The length of the image sequence i.e. the number of <image> tokens per image in the input.
This parameter is used to build the string from the input prompt and image tokens and should match the
value the model used. It is computed as: image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "Idefics3ImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 169, chat_template: str = None, **kwargs):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
self.image_token = AddedToken("<image>", normalized=False, special=True)
self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
self.global_image_tag = "<global-img>" # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
self.image_seq_len = image_seq_len
# This regex matches one or more occurrences of <global-img> tags (optionally surrounded by newline characters)
# or <row_x_col_y> tags (where x and y are digits, also optionally surrounded by newline characters).
self._regex_to_remove_extra_special_tokens = re.compile(r"(\n?<global-img>\n?|<row_\d+_col_\d+>\n?)+")
tokens_to_add = {
"additional_special_tokens": [
self.fake_image_token,
self.image_token,
self.end_of_utterance_token,
]
}
tokenizer.add_special_tokens(tokens_to_add)
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
def _extract_images_from_prompts(self, prompts):
prompt_images = []
for prompt in prompts:
images = []
for elem in prompt:
if is_valid_image(elem):
images.append(elem)
elif is_url(elem):
images.append(load_image(elem))
prompt_images.append(images)
return prompt_images
def __call__(
self,
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
audio=None,
videos=None,
image_seq_len: Optional[int] = None,
**kwargs: Unpack[Idefics3ProcessorKwargs],
) -> BatchEncoding:
"""
Processes the input prompts and returns a BatchEncoding.
Example:
```python
>>> import requests
>>> from transformers import Idefics3Processor
>>> from transformers.image_utils import load_image
>>> processor = Idefics3Processor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
>>> processor.image_processor.do_image_splitting = False # Force as False to simplify the example
>>> url1 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
>>> url2 = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
>>> image1, image2 = load_image(url1), load_image(url2)
>>> images = [[image1], [image2]]
>>> text = [
... "<image>In this image, we see",
... "bla bla bla<image>",
... ]
>>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
>>> input_ids = outputs.input_ids
>>> input_tokens = processor.tokenizer.batch_decode(input_ids)
>>> print(input_tokens)
['<|begin_of_text|><fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image> In this image, we see', '<|reserved_special_token_0|><|reserved_special_token_0|><|reserved_special_token_0|><|begin_of_text|>bla bla bla<fake_token_around_image><global-img>((<image>)*169)<fake_token_around_image>']
```
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. If is of type `List[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
text (`Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
Wherever an image token, `<image>` is encountered it is expanded to
`<fake_token_around_image>` + `<row_x_col_y>` + `<image>` * `image_seq_len` * <fake_token_around_image>`.
image_seq_len (`int`, *optional*):
The length of the image sequence. If not provided, the default value of self.image_seq_len is used.
image_seq_len should be equal to int(((image_size // patch_size) ** 2) / (scale_factor**2))
return_tensors (`Union[str, TensorType]`, *optional*):
If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
information.
"""
if text is None and images is None:
raise ValueError("You must provide either `text` or `images`.")
output_kwargs = self._merge_kwargs(
Idefics3ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Temporary fix for "padding_side" in init_kwargs
output_kwargs["text_kwargs"].pop("padding_side", None)
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
n_images_in_text = []
n_images_in_images = []
inputs = BatchFeature()
if images is not None:
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
elif (
not isinstance(images, list)
and not isinstance(images[0], list)
and not is_image_or_image_url(images[0][0])
):
raise ValueError(
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
)
n_images_in_images = [len(sample) for sample in images]
# Load images if they are URLs
images = [[load_image(im) if is_url(im) else im for im in sample] for sample in images]
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
inputs.update(image_inputs)
if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])
fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
global_img_token = self.global_image_tag
prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
n_images_in_text.append(sample.count(image_token))
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
n_cols,
image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")
# Place in the image prompt strings where the image tokens are
sample = split_sample[0]
for i, image_prompt_string in enumerate(image_prompt_strings):
sample += image_prompt_string + split_sample[i + 1]
prompt_strings.append(sample)
text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)
if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)
return inputs
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Idefics3TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
batched_decode_output = self.tokenizer.batch_decode(*args, **kwargs)
return [self._regex_to_remove_extra_special_tokens.sub("<image>", s) for s in batched_decode_output]
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Idefics3TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
decode_output = self.tokenizer.decode(*args, **kwargs)
return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

View File

@ -307,6 +307,17 @@ class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, Comm
}
```
For Python 3.8 compatibility, when inheriting from this class and overriding one of the kwargs,
you need to manually update the __annotations__ dictionary. This can be done as follows:
```python
class CustomProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: CustomImagesKwargs
CustomProcessorKwargs.__annotations__["images_kwargs"] = CustomImagesKwargs # python 3.8 compatibility
```python
"""
common_kwargs: CommonKwargs = {

View File

@ -4870,6 +4870,34 @@ class Idefics2Processor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Idefics3ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Idefics3Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Idefics3PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Idefics3Processor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ImageGPTForCausalImageModeling(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -296,6 +296,13 @@ class Idefics2ImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class Idefics3ImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class ImageGPTFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]

View File

View File

@ -0,0 +1,285 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.image_utils import PILImageResampling
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin
if is_vision_available():
from PIL import Image
from transformers import Idefics3ImageProcessor
if is_torch_available():
import torch
class Idefics3ImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
num_images=1,
image_size=18,
min_resolution=30,
max_resolution=40,
do_resize=True,
size=None,
max_image_size=None,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_convert_rgb=True,
do_pad=True,
do_image_splitting=True,
resample=PILImageResampling.LANCZOS,
):
super().__init__()
self.size = size if size is not None else {"longest_edge": max_resolution}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.num_images = num_images
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.resample = resample
self.do_image_splitting = do_image_splitting
self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 20}
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.do_pad = do_pad
def prepare_image_processor_dict(self):
return {
"do_convert_rgb": self.do_convert_rgb,
"do_resize": self.do_resize,
"size": self.size,
"max_image_size": self.max_image_size,
"do_rescale": self.do_rescale,
"rescale_factor": self.rescale_factor,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_pad": self.do_pad,
"do_image_splitting": self.do_image_splitting,
}
def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to Idefics3ImageProcessor,
assuming do_resize is set to True. The expected size in that case the max image size.
"""
return self.max_image_size["longest_edge"], self.max_image_size["longest_edge"]
def expected_output_image_shape(self, images):
height, width = self.get_expected_values(images, batched=True)
effective_nb_images = (
self.num_images * 5 if self.do_image_splitting else 1
) # 5 is a squared image divided into 4 + global image resized
return effective_nb_images, self.num_channels, height, width
def prepare_image_inputs(
self,
batch_size=None,
min_resolution=None,
max_resolution=None,
num_channels=None,
num_images=None,
size_divisor=None,
equal_resolution=False,
numpify=False,
torchify=False,
):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
One can specify whether the images are of the same resolution or not.
"""
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
batch_size = batch_size if batch_size is not None else self.batch_size
min_resolution = min_resolution if min_resolution is not None else self.min_resolution
max_resolution = max_resolution if max_resolution is not None else self.max_resolution
num_channels = num_channels if num_channels is not None else self.num_channels
num_images = num_images if num_images is not None else self.num_images
images_list = []
for i in range(batch_size):
images = []
for j in range(num_images):
if equal_resolution:
width = height = max_resolution
else:
# To avoid getting image width/height 0
if size_divisor is not None:
# If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
min_resolution = max(size_divisor, min_resolution)
width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
images_list.append(images)
if not numpify and not torchify:
# PIL expects the channel dimension as last dimension
images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list]
if torchify:
images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
if numpify:
# Numpy images are typically in channels last format
images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list]
return images_list
@require_torch
@require_vision
class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Idefics3ImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = Idefics3ImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
self.assertTrue(hasattr(image_processing, "max_image_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
for image in sample_images:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_numpy_4_channels(self):
# Idefics3 always processes images as RGB, so it always returns images with 3 channels
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processor_dict = self.image_processor_dict
image_processing = self.image_processing_class(**image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
for image in sample_images:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for images in image_inputs:
for image in images:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
)
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for images in image_inputs:
for image in images:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
# Test batched
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)

View File

@ -0,0 +1,529 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Idefics3 model."""
import copy
import gc
import unittest
from io import BytesIO
import requests
from transformers import (
AutoProcessor,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import (
Idefics3Config,
Idefics3ForConditionalGeneration,
Idefics3Model,
)
else:
is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
class Idefics3VisionText2TextModelTester:
def __init__(
self,
parent,
is_training=True,
batch_size=2,
scale_factor=2,
num_images=2,
vision_config={
"image_size": 16,
"patch_size": 4,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 32,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
text_config={
"vocab_size": 100,
"hidden_size": 64,
"intermediate_size": 56,
"num_hidden_layers": 3,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"hidden_act": "silu",
"max_position_embeddings": 256,
"initializer_range": 0.02,
"rms_norm_eps": 1e-6,
"pad_token_id": 2,
"bos_token_id": 0,
"eos_token_id": 1,
"image_token_id": 57,
"tie_word_embeddings": False,
"rope_theta": 10000.0,
"sliding_window": 32,
"attention_dropout": 0.0,
},
use_cache=False,
tie_word_embeddings=False,
image_token_id=57,
):
self.parent = parent
self.is_training = is_training
self.batch_size = batch_size
self.num_images = num_images
self.scale_factor = scale_factor
self.seq_length = (
int(((vision_config["image_size"] // vision_config["patch_size"]) ** 2) / (self.scale_factor**2))
* self.num_images
)
self.use_cache = use_cache
self.image_token_id = image_token_id
self.tie_word_embeddings = tie_word_embeddings
# Hack - add properties here so use common tests
self.vocab_size = text_config["vocab_size"]
self.num_hidden_layers = text_config["num_hidden_layers"]
self.num_attention_heads = text_config["num_attention_heads"]
self.hidden_size = text_config["hidden_size"]
self.vision_config = vision_config
self.text_config = text_config
def get_config(self):
return Idefics3Config(
use_cache=self.use_cache,
image_token_id=self.image_token_id,
tie_word_embeddings=self.tie_word_embeddings,
vision_config=self.vision_config,
text_config=self.text_config,
vocab_size=self.vocab_size,
scale_factor=self.scale_factor,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.num_images,
3, # Idefics3ImageProcessor always generates RGB pixel values
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 1
# For simplicity just set the last n tokens to the image token
n_image_tokens_per_batch = self.seq_length
input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id
attention_mask = input_ids.ne(1).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class Idefics3ModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `Idefics3`.
"""
all_model_classes = (Idefics3Model,) if is_torch_available() else ()
fx_compatible = False
test_torchscript = False
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
def setUp(self):
self.model_tester = Idefics3VisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Idefics3Config, has_text_modality=False)
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
def test_inputs_embeds():
pass
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_generate_padding_right(self):
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self):
pass
# We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
if self.model_tester.is_training is False:
model.eval()
model_vocab_size = config.text_config.vocab_size
# Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
# Ignore copy
# Check that the model can still do a forward pass successfully (every parameter should be resized)
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
n_images = self.model_tester.num_images * self.model_tester.seq_length
model.image_token_id = model_vocab_size - 15 - 1
inputs_dict["input_ids"][:, -n_images:] = model.image_token_id
# make sure that decoder_input_ids are resized as well
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
models_equal = True
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.text_config.vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
self.assertTrue(model.config.text_config.vocab_size + 10, model_vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
self.assertTrue(model_embed.weight.shape[0], model.config.text_config.vocab_size)
self.assertTrue(model.config.text_config.vocab_size, model.vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
target_dimension = 128
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0], target_dimension)
with self.assertRaisesRegex(
ValueError,
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
):
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
# We need to override as we need to prepare such that the image token is the last token
def test_resize_embeddings_untied(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
original_config.tie_word_embeddings = False
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config).to(torch_device)
# if no output embeddings -> leave test
if model.get_output_embeddings() is None:
continue
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_vocab_size = config.text_config.vocab_size
model.resize_token_embeddings(model_vocab_size + 10)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
# Check bias if present
if output_embeds.bias is not None:
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
# Check bias if present
if output_embeds.bias is not None:
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
n_images = self.model_tester.num_images * self.model_tester.seq_length
model.image_token_id = model_vocab_size - 15 - 1
inputs_dict["input_ids"][:, -n_images:] = model.image_token_id
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
@require_torch
class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
"""
Model tester for `Idefics3ForConditionalGeneration`.
"""
all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
test_torchscript = False
def setUp(self):
self.model_tester = Idefics3VisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Idefics3Config, has_text_modality=False)
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
def test_inputs_embeds():
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_generate_padding_right(self):
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self):
pass
# We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.text_config.vocab_size
# Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
n_images = self.model_tester.num_images * self.model_tester.seq_length
model.model.image_token_id = model_vocab_size - 15 - 1
inputs_dict["input_ids"][:, -n_images:] = model.model.image_token_id
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
models_equal = True
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.text_config.vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
self.assertTrue(model.config.text_config.vocab_size + 10, model_vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
self.assertTrue(model_embed.weight.shape[0], model.config.text_config.vocab_size)
self.assertTrue(model.config.text_config.vocab_size, model.vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
target_dimension = 128
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0], target_dimension)
with self.assertRaisesRegex(
ValueError,
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
):
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
# We need to override as we need to prepare such that the image token is the last token
def test_resize_embeddings_untied(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
original_config.tie_word_embeddings = False
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config).to(torch_device)
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_vocab_size = config.text_config.vocab_size
model.resize_token_embeddings(model_vocab_size + 10)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
# Check bias if present
if output_embeds.bias is not None:
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15)
self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
# Check bias if present
if output_embeds.bias is not None:
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
# Input ids should be clamped to the maximum size of the vocabulary - 1 and the image token should be the last token
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 2)
n_images = self.model_tester.num_images * self.model_tester.seq_length
model.model.image_token_id = model_vocab_size - 15 - 1
inputs_dict["input_ids"][:, -n_images:] = model.model.image_token_id
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
@require_torch
class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
self.image1 = Image.open(
BytesIO(
requests.get(
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
).content
)
)
self.image2 = Image.open(
BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content)
)
self.image3 = Image.open(
BytesIO(
requests.get(
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
).content
)
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
@unittest.skip("multi-gpu tests are disabled for now")
def test_integration_test(self):
model = Idefics3ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/Idefics3-8B-Llama3",
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Create inputs
text = "<image>In this image, we see"
images = self.image1
inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
inputs.to(torch_device)
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
expected_generated_text = "<image>In this image, we see the Statue of Liberty, which is located on Liberty"
self.assertEqual(generated_texts[0], expected_generated_text)
@slow
@require_bitsandbytes
@unittest.skip("multi-gpu tests are disabled for now")
def test_integration_test_4bit(self):
# Let' s make sure we test the preprocessing to replace what is used
model = Idefics3ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/Idefics3-8B-Llama3",
load_in_4bit=True,
device_map="auto",
)
# Create pixel inputs
text = ["<image>In this image, we see", "bla, bla <image><image>"]
images = [[self.image1], [self.image2, self.image3]]
inputs = self.processor(text=text, images=images, padding=True, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
expected_generated_text = "<image>In this image, we see the Statue of Liberty, trees, buildings, water"
self.assertEqual(generated_texts[0], expected_generated_text)

View File

@ -0,0 +1,462 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from io import BytesIO
import numpy as np
import requests
from transformers import Idefics3Processor
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from PIL import Image
@require_torch
@require_vision
class Idefics3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Idefics3Processor
@classmethod
def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
processor = Idefics3Processor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", image_seq_len=2)
processor.save_pretrained(cls.tmpdirname)
cls.image1 = Image.open(
BytesIO(
requests.get(
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
).content
)
)
cls.image2 = Image.open(
BytesIO(requests.get("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg").content)
)
cls.image3 = Image.open(
BytesIO(
requests.get(
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"
).content
)
)
cls.bos_token = processor.tokenizer.bos_token
cls.image_token = processor.image_token.content
cls.fake_image_token = processor.fake_image_token.content
cls.global_img_token = processor.global_image_tag
cls.bos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.bos_token)
cls.image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.image_token)
cls.fake_image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.fake_image_token)
cls.global_img_tokens_id = processor.tokenizer(cls.global_img_token, add_special_tokens=False)["input_ids"]
cls.padding_token_id = processor.tokenizer.pad_token_id
cls.image_seq_len = processor.image_seq_len
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def get_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
def get_split_image_expected_tokens(self, processor, image_rows, image_cols):
text_split_images = []
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
[self.fake_image_token_id]
+ processor.tokenizer(f"<row_{n_h + 1}_col_{n_w + 1}>", add_special_tokens=False)["input_ids"]
+ [self.image_token_id] * self.image_seq_len
)
text_split_images += processor.tokenizer("\n", add_special_tokens=False)["input_ids"]
text_split_images = text_split_images[:-1] # remove last newline
# add double newline, as it gets its own token
text_split_images += processor.tokenizer("\n\n", add_special_tokens=False)["input_ids"]
text_split_images += (
[self.fake_image_token_id]
+ self.global_img_tokens_id
+ [self.image_token_id] * self.image_seq_len
+ [self.fake_image_token_id]
)
return text_split_images
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname)
def test_process_interleaved_images_prompts_no_image_splitting(self):
processor = self.get_processor()
processor.image_processor.do_image_splitting = False
# Test that a single image is processed correctly
inputs = processor(images=self.image1)
image1_expected_size = (364, 364)
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 1, 3, *image1_expected_size))
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 1, *image1_expected_size))
# fmt: on
# Test a single sample with image and text
image_str = "<image>"
text_str = "In this image, we see"
text = image_str + text_str
inputs = processor(text=text, images=self.image1)
# fmt: off
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
expected_input_ids = [[self.bos_token_id] + [self.fake_image_token_id] + self.global_img_tokens_id + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
self.assertEqual(inputs["input_ids"], expected_input_ids)
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 1, 3, *image1_expected_size))
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 1, *image1_expected_size))
# fmt: on
# Test that batch is correctly processed
image_str = "<image>"
text_str_1 = "In this image, we see"
text_str_2 = "In this image, we see"
text = [
image_str + text_str_1,
image_str + image_str + text_str_2,
]
images = [[self.image1], [self.image2, self.image3]]
inputs = processor(text=text, images=images, padding=True)
# fmt: off
tokenized_sentence_1 = processor.tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = processor.tokenizer(text_str_2, add_special_tokens=False)
image_tokens = [self.fake_image_token_id] + self.global_img_tokens_id + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id]
expected_input_ids_1 = [self.bos_token_id] + image_tokens + tokenized_sentence_1["input_ids"]
expected_input_ids_2 = [self.bos_token_id] + 2 * image_tokens + tokenized_sentence_2["input_ids"]
# Pad the first input to match the second input
pad_len = len(expected_input_ids_2) - len(expected_input_ids_1)
padded_expected_input_ids_1 = [self.padding_token_id] * pad_len + expected_input_ids_1
self.assertEqual(
inputs["input_ids"], [padded_expected_input_ids_1, expected_input_ids_2]
)
self.assertEqual(
inputs["attention_mask"],
[[0] * pad_len + [1] * len(expected_input_ids_1), [1] * len(expected_input_ids_2)]
)
self.assertEqual(np.array(inputs['pixel_values']).shape, (2, 2, 3, 364, 364))
self.assertEqual(np.array(inputs['pixel_attention_mask']).shape, (2, 2, 364, 364))
# fmt: on
def test_process_interleaved_images_prompts_image_splitting(self):
processor = self.get_processor()
processor.image_processor.do_image_splitting = True
# Test that a single image is processed correctly
inputs = processor(images=self.image1)
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 13, 3, 364, 364))
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 13, 364, 364))
# fmt: on
self.maxDiff = None
# Test a single sample with image and text
image_str = "<image>"
text_str = "In this image, we see"
text = image_str + text_str
inputs = processor(text=text, images=self.image1)
# fmt: off
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
split_image1_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
expected_input_ids_1 = [[self.bos_token_id] + split_image1_tokens + tokenized_sentence["input_ids"]]
self.assertEqual(inputs["input_ids"], expected_input_ids_1)
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids_1[0])])
self.assertEqual(np.array(inputs["pixel_values"]).shape, (1, 13, 3, 364, 364))
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (1, 13, 364, 364))
# fmt: on
# Test that batch is correctly processed
image_str = "<image>"
text_str_1 = "In this image, we see"
text_str_2 = "bla, bla"
text = [
image_str + text_str_1,
text_str_2 + image_str + image_str,
]
images = [[self.image1], [self.image2, self.image3]]
inputs = processor(text=text, images=images, padding=True)
# fmt: off
tokenized_sentence_1 = processor.tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = processor.tokenizer(text_str_2, add_special_tokens=False)
split_image1_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
split_image2_tokens = self.get_split_image_expected_tokens(processor, 4, 4)
split_image3_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
expected_input_ids_1 = [self.bos_token_id] + split_image1_tokens + tokenized_sentence_1["input_ids"]
expected_input_ids_2 = [self.bos_token_id] + tokenized_sentence_2["input_ids"] + split_image2_tokens + split_image3_tokens
# Pad the first input to match the second input
pad_len = len(expected_input_ids_2) - len(expected_input_ids_1)
padded_expected_input_ids_1 = [self.padding_token_id] * pad_len + expected_input_ids_1
self.assertEqual(
inputs["input_ids"], [padded_expected_input_ids_1, expected_input_ids_2]
)
self.assertEqual(
inputs["attention_mask"],
[[0] * pad_len + [1] * len(expected_input_ids_1), [1] * len(expected_input_ids_2)]
)
self.assertEqual(np.array(inputs['pixel_values']).shape, (2, 30, 3, 364, 364))
self.assertEqual(np.array(inputs['pixel_attention_mask']).shape, (2, 30, 364, 364))
# fmt: on
def test_add_special_tokens_processor(self):
processor = self.get_processor()
image_str = "<image>"
text_str = "In this image, we see"
text = text_str + image_str
# fmt: off
inputs = processor(text=text, images=self.image1, add_special_tokens=False)
tokenized_sentence = processor.tokenizer(text_str, add_special_tokens=False)
split_image1_tokens = self.get_split_image_expected_tokens(processor, 3, 4)
expected_input_ids = [tokenized_sentence["input_ids"] + split_image1_tokens]
self.assertEqual(inputs["input_ids"], expected_input_ids)
inputs = processor(text=text, images=self.image1)
expected_input_ids = [[self.bos_token_id] + tokenized_sentence["input_ids"] + split_image1_tokens]
self.assertEqual(inputs["input_ids"], expected_input_ids)
# fmt: on
def test_apply_chat_template(self):
# Message contains content which a mix of lists with images and image urls and string
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What do these images show?"},
{"type": "image"},
{"type": "image"},
"What do these images show?",
],
},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.",
}
],
},
{"role": "user", "content": [{"type": "text", "text": "And who is that?"}]},
]
processor = self.get_processor()
# Make short sequence length to test that the fake tokens are added correctly
rendered = processor.apply_chat_template(messages, add_generation_prompt=True)
expected_rendered = (
"<|begin_of_text|>User: What do these images show?<image><image><end_of_utterance>\n"
"Assistant: The first image shows the statue of Liberty in New York. The second image picture depicts Idefix, the dog of Obelix in Asterix and Obelix.<end_of_utterance>\n"
"User: And who is that?<end_of_utterance>\n"
"Assistant:"
)
self.assertEqual(rendered, expected_rendered)
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer <image>"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 3)
self.assertEqual(len(inputs["pixel_values"][0][0][0]), 364) # crop size doesn't affect our image processor
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component(
"image_processor", max_image_size={"longest_edge": 32}, size={"longest_edge": 32}
)
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor, image_seq_len=2)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer <image>"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 3)
self.assertEqual(len(inputs["pixel_values"][0][0][0]), 32)
self.assertEqual(len(inputs["input_ids"][0]), 117)
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=30)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer<image>"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt", max_length=30)
self.assertEqual(len(inputs["input_ids"][0]), 30)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer<image>"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
inputs = processor(
text=input_str,
images=image_input,
common_kwargs={"return_tensors": "pt"},
images_kwargs={"max_image_size": {"longest_edge": 32}},
text_kwargs={"padding": "max_length", "max_length": 120, "truncation": "longest_first"},
)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[3], 32)
self.assertEqual(len(inputs["input_ids"][0]), 120)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer<image>"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"max_image_size": {"longest_edge": 32}},
"text_kwargs": {"padding": "max_length", "max_length": 120, "truncation": "longest_first"},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[3], 32)
self.assertEqual(len(inputs["input_ids"][0]), 120)
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=30)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer<image>"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 30)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["<image>lower newer", "<image>upper older longer string"]
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=[image_input, image_input],
return_tensors="pt",
padding="longest",
max_length=76,
truncation=True,
max_image_size={"longest_edge": 30},
)
self.assertEqual(inputs["pixel_values"].shape[2], 3)
self.assertEqual(inputs["pixel_values"].shape[3], 30)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer<image>"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_image_size={"longest_edge": 32},
padding="max_length",
max_length=120,
truncation="longest_first",
)
self.assertEqual(inputs["pixel_values"].shape[3], 32)
self.assertEqual(len(inputs["input_ids"][0]), 120)

View File

@ -82,6 +82,7 @@ PRIVATE_MODELS = [
"SeamlessM4Tv2TextToUnitModel",
"SeamlessM4Tv2CodeHifiGan",
"SeamlessM4Tv2TextToUnitForConditionalGeneration",
"Idefics3VisionTransformer",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be.