Uniformize kwargs for Idefics/2 processors (#32568)

* Uniformize Idefics processor kwargs and add tests

* Uniformize Idefics2 processor kwargs

* Add image_processor tests for Idefics

* Add BC for args order change in Idefics2 processor and update docs

* Add support for multiple images per prompt in image-text-to-text mode for Idefics

* Fix processor input args in Idefics tests

* Improve common processing tests, remove unnecessary tests, update processing uniformization

* Fix docstrings in Idefics

* Fix processor tests for Idefics/Idefics2
Yoni Gozlan 2024-10-03 18:08:24 +02:00 committed by GitHub
parent b0c5660e88
commit 074aa3b3fd
6 changed files with 409 additions and 160 deletions
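The net effect for users: `IdeficsProcessor` and `Idefics2Processor` now share the standard `__call__` signature (`images`, `text`, `audio`, `videos`, plus typed kwargs). A minimal sketch of the uniformized call — the checkpoint name comes from the tests below, the dummy image is illustrative:

```python
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
image = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))  # dummy black image

# Uniformized order: images first, then text. Flat kwargs are routed to the
# right component: `padding` goes to the tokenizer, `return_tensors` to both.
inputs = processor(
    images=image,
    text="<image>In this image, we see",
    padding="longest",
    return_tensors="pt",
)
```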

View File

@@ -16,13 +16,21 @@
Processor class for IDEFICS.
"""
from typing import Callable, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
from urllib.parse import urlparse
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
from ...processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
TextKwargs,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_tf_available, is_torch_available
from ...utils.deprecation import deprecate_kwarg
if is_torch_available():
@@ -34,6 +42,32 @@ if is_tf_available():
IMAGE_TOKEN = "<image>"
class IdeficsImagesKwargs(ImagesKwargs, total=False):
transform: Optional[Callable]
image_size: Optional[Dict[str, int]]
image_mean: Optional[Union[float, List[float]]]
image_std: Optional[Union[float, List[float]]]
class IdeficsTextKwargs(TextKwargs, total=False):
add_eos_token: Optional[bool]
add_end_of_utterance_token: Optional[bool]
class IdeficsProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: IdeficsTextKwargs
images_kwargs: IdeficsImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": False,
"padding": "longest",
"add_eos_token": False,
},
"images_kwargs": {},
"common_kwargs": {"return_tensors": "pt"},
}
# copied from m4.training.packing
def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
# Set elements >= num_classes to -1
@@ -199,52 +233,32 @@ class IdeficsProcessor(ProcessorMixin):
else False
)
@deprecate_kwarg(old_name="prompts", version="5.0.0", new_name="text", raise_if_both_names=True)
def __call__(
self,
prompts: Union[List[TextInput], List[List[TextInput]]],
padding: Union[bool, str, PaddingStrategy] = "longest",
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
transform: Callable = None,
add_eos_token=False,
add_end_of_utterance_token=None,
debug=False,
return_tensors="pt",
) -> BatchEncoding:
images=None,
text: Union[
TextInput,
PreTokenizedInput,
List[TextInput],
List[PreTokenizedInput],
List[List[TextInput]],
List[List[PreTokenizedInput]],
] = None,
audio=None,
videos=None,
**kwargs: Unpack[IdeficsProcessorKwargs],
) -> BatchFeature:
"""This method takes batched or non-batched prompts made of text and images and converts them into prompts that
the model was trained on and prepares the image pixel values for the model to process.
Args:
prompts (`Union[List[TextInput], List[List[TextInput]]]`):
images (`Union[PIL.Image, str, List[PIL.Image], List[str]]`):
either a single image or a batched list of images - can be passed in when text contains only text prompts,
in order to use the image-text-to-text behavior.
text (`Union[List[TextInput], List[List[TextInput]]]`):
either a single prompt or a batched list of prompts - see the detailed description immediately after
the end of the arguments doc section.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `"longest"`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding. This will raise an error if the input sequences are of different
lengths.
Note: Unlike most processors, which set `padding=False` by default, `IdeficsProcessor` sets `padding="longest"`
by default. See https://github.com/huggingface/transformers/pull/29449#pullrequestreview-1925576061 for why.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
transform (`Callable`, *optional*):
A custom transform function that accepts a single image can be passed for training. For example,
`torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
set of transforms will be applied to the images
add_eos_token (`bool`, *optional*, defaults to `False`):
Adds `eos_token` at the end of the final prompt if `True`.
add_end_of_utterance_token (`bool`, *optional*):
Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
image). If `None` the tokenizer will be checked instead and if this token is found in
`additional_special_tokens` then the value will be `True`.
debug (`bool`, *optional*, defaults to `False`):
If `True`, helps debug prompt generation by dumping useful information
return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
The type of tensors to return. Can be one of:
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
@@ -255,7 +269,7 @@ class IdeficsProcessor(ProcessorMixin):
Detailed explanation:
Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
Each entry in `text` is either a text to be passed as is or an image that will be processed.
An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
@@ -279,7 +293,7 @@ class IdeficsProcessor(ProcessorMixin):
"Describe this image.\nAssistant:",
]
inputs = processor(prompts, return_tensors="pt")
inputs = processor(text=prompts, return_tensors="pt")
generated_ids = model.generate(**inputs, max_length=100)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```
@@ -311,18 +325,55 @@ class IdeficsProcessor(ProcessorMixin):
transforms.Normalize(mean=self.image_mean, std=self.image_std),
]
)
inputs = processor(prompts, transform=image_transform, return_tensors="pt")
inputs = processor(text=prompts, transform=image_transform, return_tensors="pt")
```
To help debug prompt generation, enable `debug=True`, which will show you what's happening.
"""
if images is None and text is None:
raise ValueError("You need to specify either `text` or `images` and `text`.")
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)
if images is None:
# assuming the user wants to use the old behavior with prompts as the only argument
prompts = text
elif text is not None:
# Assuming image-text-to-text behavior:
# Check if batched images are provided
if not isinstance(images, (list, tuple)):
images = [images]
if isinstance(text, str):
text = [text]
# Check if batched images and text are in the correct format
if isinstance(text, (list, tuple)) and len(text) != len(images):
raise ValueError(
"When providing both images and text arguments, the number of text prompts should be the same as the number of images."
"If you want to have several images per prompt, images should be nested as such: images=[[img1, img2], [img3, img4], ...] for text=[prompt1, prompt2, ...]."
)
# Check that only text is present in the prompts
if not all(isinstance(i, str) for i in text):
raise ValueError("When using the image-text-to-text behavior, the prompts should only contain text.")
if isinstance(images[0], (list, tuple)):
# if nested images, nest text as well
text = [[i] for i in text]
prompts = list(zip(images, text))
output_kwargs = self._merge_kwargs(
IdeficsProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
add_end_of_utterance_token = output_kwargs["text_kwargs"].pop("add_end_of_utterance_token", None)
# if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
if add_end_of_utterance_token is None:
add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
# turn non-batched prompts into batched
if not any(isinstance(i, list) for i in prompts):
if not any(isinstance(i, (list, tuple)) for i in prompts):
prompts = [prompts]
fake_token = "<fake_token_around_image>"
@@ -371,21 +422,14 @@ class IdeficsProcessor(ProcessorMixin):
if add_eos_token:
full_text += self.tokenizer.eos_token
if debug is True:
print(f"{full_text=}")
image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
image_objects = self.image_processor(image_objects, **output_kwargs["images_kwargs"])
all_prompts.append(full_text)
all_images.append(image_objects)
text_encoding = self.tokenizer(
text=all_prompts,
add_special_tokens=False,
padding=padding,
truncation=truncation,
max_length=max_length,
)
# For BC
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
text_encoding = self.tokenizer(all_prompts, **output_kwargs["text_kwargs"])
all_texts = text_encoding["input_ids"]
all_attention_masks = text_encoding["attention_mask"]
@@ -398,12 +442,12 @@ class IdeficsProcessor(ProcessorMixin):
output_images = []
output_attention_masks = []
for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
padded_input_ids = text
for text_single, attention_mask, extracted_images in zip(all_texts, all_attention_masks, all_images):
padded_input_ids = text_single
image_count = padded_input_ids.count(self.image_token_id)
local_max_num_images = min(image_count, max_num_images)
current_images = images[:local_max_num_images]
current_images = extracted_images[:local_max_num_images]
if len(current_images) > 0:
if return_tensors == "pt":
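
This file's changes add an image-text-to-text path to `IdeficsProcessor`. As the `ValueError` message above spells out, several images per prompt are passed as one inner list per text entry. A minimal sketch, assuming `processor` is an `IdeficsProcessor` (e.g. loaded from `HuggingFaceM4/idefics-9b`, the checkpoint used in the integration test below) and using dummy images:

```python
import numpy as np
from PIL import Image

def dummy_image():
    # Illustrative stand-in for a real photo.
    return Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))

prompts = ["User: Compare these images. Assistant:", "User: And these? Assistant:"]
# One inner list of images per prompt, as required by the check above.
images = [[dummy_image(), dummy_image()], [dummy_image(), dummy_image()]]

inputs = processor(images=images, text=prompts, padding="longest", return_tensors="pt")
```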

View File

@@ -1584,7 +1584,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
... "In which city is that bridge located?<image>",
... ]
>>> images = [[image1, image2], [image3]]
>>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to("cuda")
>>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")
>>> # Generate
>>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)
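
The doc example above only flips the argument order to images-first. Because `__call__` now runs `_validate_images_text_input_order`, the old text-first positional order is detected and swapped back, so existing callers keep working. A sketch of the intended BC behavior, reusing `prompts` and `images` from the example:

```python
# New canonical order: images first, then text.
inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt")
# Old positional order (text first) is detected and swapped back internally.
inputs_bc = processor(prompts, images, padding=True, return_tensors="pt")
```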

View File

@@ -20,9 +20,15 @@ from typing import TYPE_CHECKING, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import AddedToken, BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
from ...utils import TensorType, logging
from ...processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import AddedToken, TextInput
from ...utils import logging
if TYPE_CHECKING:
@@ -40,6 +46,23 @@ def is_image_or_image_url(elem):
return is_url(elem) or is_valid_image(elem)
class Idefics2ImagesKwargs(ImagesKwargs, total=False):
image_seq_len: Optional[int]
class Idefics2ProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: Idefics2ImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"is_split_into_words": False,
},
"images_kwargs": {},
}
class Idefics2Processor(ProcessorMixin):
r"""
Constructs an IDEFICS2 processor which wraps a LLaMA tokenizer and an IDEFICS2 image processor into a single processor.
@@ -97,16 +120,12 @@ class Idefics2Processor(ProcessorMixin):
def __call__(
self,
text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
image_seq_len: Optional[int] = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
is_split_into_words: bool = False,
add_special_tokens: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchEncoding:
text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
audio=None,
videos=None,
**kwargs: Unpack[Idefics2ProcessorKwargs],
) -> BatchFeature:
"""
Processes the input prompts and returns a BatchFeature.
@@ -130,7 +149,7 @@ class Idefics2Processor(ProcessorMixin):
... "<image>In this image, we see",
... "bla bla bla<image>",
... ]
>>> outputs = processor(text=text, images=images, return_tensors="pt", padding=True)
>>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
>>> input_ids = outputs.input_ids
>>> input_tokens = processor.tokenizer.batch_decode(input_ids)
>>> print(input_tokens)
@@ -138,6 +157,9 @@
```
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. If is of type `List[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
text (`Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
@@ -145,27 +167,22 @@
Wherever an image token, `<image>`, is encountered, it is expanded to
`<fake_token_around_image>` + `<image>` * `image_seq_len` + `<fake_token_around_image>`.
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. If is of type `List[ImageInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
image_seq_len (`int`, *optional*):
The length of the image sequence. If not provided, the default value is used.
padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `False`):
Padding strategy applied to the input ids. See [`PreTrainedTokenizerFast.pad`] for more information.
truncation (`Union[bool, str, TruncationStrategy]`, *optional*):
Truncation strategy applied to the input ids. See [`PreTrainedTokenizerFast.truncate`] for more information.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding/truncation length. See
[`PreTrainedTokenizerFast.__call__`] for more information.
is_split_into_words (`bool`, *optional*, defaults to `False`):
Whether the input text is split into words or not. If set to `True`, the tokenizer will skip the
tokenization process and assume the input is already tokenized.
add_special_tokens (`bool`, *optional*, defaults to `True`):
Whether to add special tokens or not. See [`PreTrainedTokenizerFast.__call__`] for more information.
return_tensors (`Union[str, TensorType]`, *optional*):
If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
information.
"""
if text is None and images is None:
raise ValueError("You must provide either `text` or `images`.")
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)
output_kwargs = self._merge_kwargs(
Idefics2ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
image_seq_len = output_kwargs["images_kwargs"].pop("image_seq_len", None)
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
n_images_in_text = []
@@ -194,15 +211,7 @@ class Idefics2Processor(ProcessorMixin):
sample = sample.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
prompt_strings.append(sample)
text_inputs = self.tokenizer(
text=prompt_strings,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
is_split_into_words=is_split_into_words,
return_tensors=return_tensors,
)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)
if images is not None:
@@ -227,7 +236,7 @@ class Idefics2Processor(ProcessorMixin):
# Load images if they are URLs
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, return_tensors=return_tensors)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
inputs.update(image_inputs)
return inputs
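
With `Idefics2ProcessorKwargs` wired in, kwargs can be passed flat (matched by name to the modality that declares them) or pre-nested per modality; `_merge_kwargs` layers user values over tokenizer init kwargs over the `_defaults` above. A sketch of the two equivalent styles, assuming `processor` is an `Idefics2Processor` and `image` a PIL image:

```python
# Flat style: each kwarg is routed by name to its modality.
out_flat = processor(
    images=image,
    text="<image>In this image, we see",
    padding="max_length",
    max_length=64,
    image_seq_len=2,  # declared in Idefics2ImagesKwargs above
    return_tensors="pt",
)

# Nested style: the same kwargs, grouped explicitly per modality.
out_nested = processor(
    images=image,
    text="<image>In this image, we see",
    text_kwargs={"padding": "max_length", "max_length": 64},
    images_kwargs={"image_seq_len": 2},
    common_kwargs={"return_tensors": "pt"},
)
```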

View File

@@ -662,7 +662,7 @@ class IdeficsModelIntegrationTest(TestCasePlus):
"HuggingFaceM4/idefics-9b", quantization_config=quantization_config, device_map="auto"
)
processor = self.default_processor
inputs = processor(prompts, return_tensors="pt", padding="longest").to(torch_device)
inputs = processor(text=prompts, return_tensors="pt", padding="longest").to(torch_device)
generated_ids = model.generate(**inputs, max_length=100)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
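
The integration test now passes `text=` explicitly. Because `IdeficsProcessor.__call__` carries `@deprecate_kwarg(old_name="prompts", new_name="text", raise_if_both_names=True)` (first file above), the old keyword should keep working until v5.0.0. A sketch of the expected transition behavior, inferred from the decorator's arguments:

```python
inputs = processor(text=prompts, return_tensors="pt")     # preferred going forward
inputs = processor(prompts=prompts, return_tensors="pt")  # deprecated alias, warns (assumed)
processor(prompts=prompts, text=prompts)                  # raises: both names supplied
```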

View File

@@ -12,11 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
from transformers.testing_utils import TestCasePlus, require_torch, require_vision
from transformers import (
AutoProcessor,
IdeficsImageProcessor,
IdeficsProcessor,
LlamaTokenizerFast,
PreTrainedTokenizerFast,
)
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_torch_available():
import torch
@@ -24,37 +37,32 @@ if is_torch_available():
if is_vision_available():
from PIL import Image
from transformers import (
AutoProcessor,
IdeficsImageProcessor,
IdeficsProcessor,
LlamaTokenizerFast,
PreTrainedTokenizerFast,
)
@require_torch
@require_vision
class IdeficsProcessorTest(TestCasePlus):
def setUp(self):
super().setUp()
class IdeficsProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = IdeficsProcessor
self.checkpoint_path = self.get_auto_remove_tmp_dir()
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = IdeficsImageProcessor(return_tensors="pt")
tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/tiny-random-idefics")
processor = IdeficsProcessor(image_processor, tokenizer)
processor.save_pretrained(self.checkpoint_path)
processor.save_pretrained(self.tmpdirname)
self.input_keys = ["pixel_values", "input_ids", "attention_mask", "image_attention_mask"]
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.checkpoint_path, **kwargs).tokenizer
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.checkpoint_path, **kwargs).image_processor
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_prompts(self):
"""This function prepares a list of PIL images"""
@@ -100,13 +108,13 @@ class IdeficsProcessorTest(TestCasePlus):
def test_save_load_pretrained_additional_features(self):
processor = IdeficsProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
processor.save_pretrained(self.checkpoint_path)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
processor = IdeficsProcessor.from_pretrained(
self.checkpoint_path, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
@@ -124,7 +132,7 @@ class IdeficsProcessorTest(TestCasePlus):
prompts = self.prepare_prompts()
# test that all prompts succeeded
input_processor = processor(prompts, return_tensors="pt", padding="longest")
input_processor = processor(text=prompts, return_tensors="pt", padding="longest")
for key in self.input_keys:
assert torch.is_tensor(input_processor[key])
@@ -157,8 +165,8 @@ class IdeficsProcessorTest(TestCasePlus):
]
prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
max_length = processor(prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt")
longest = processor(prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt")
max_length = processor(text=prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt")
longest = processor(text=prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt")
decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
@@ -185,8 +193,8 @@ class IdeficsProcessorTest(TestCasePlus):
([0] * 10) + ([1] * 10),
]
prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
max_length = processor(prompts, padding="max_length", truncation=True, max_length=20)
longest = processor(prompts, padding="longest", truncation=True, max_length=30)
max_length = processor(text=prompts, padding="max_length", truncation=True, max_length=20)
longest = processor(text=prompts, padding="longest", truncation=True, max_length=30)
decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
@@ -204,7 +212,143 @@ class IdeficsProcessorTest(TestCasePlus):
processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
prompts = self.prepare_prompts()
inputs = processor(prompts, padding="longest", return_tensors="pt")
inputs = processor(text=prompts, padding="longest", return_tensors="pt")
# For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask', 'image_attention_mask']
self.assertSetEqual(set(inputs.keys()), set(self.input_keys))
# Override the following tests as Idefics image processor does not accept do_rescale and rescale_factor
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", image_size=234)
tokenizer = self.get_component("tokenizer", max_length=117)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0][0]), 234)
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", image_size=234)
tokenizer = self.get_component("tokenizer", max_length=117)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, image_size=224)
self.assertEqual(len(inputs["pixel_values"][0][0][0]), 224)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
image_size=214,
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[3], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs(batch_size=2)
image_input = self.prepare_image_inputs(batch_size=2)
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
image_size=214,
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[3], 214)
self.assertEqual(len(inputs["input_ids"][0]), 8)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"image_size": 214},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[3], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"image_size": 214},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[3], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

View File

@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from io import BytesIO
from typing import Optional
import requests
@@ -22,16 +25,30 @@ from transformers import Idefics2Processor
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from PIL import Image
from transformers import (
AutoProcessor,
Idefics2Processor,
)
@require_torch
@require_vision
class Idefics2ProcessorTest(unittest.TestCase):
class Idefics2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Idefics2Processor
def setUp(self):
self.processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b", image_seq_len=2)
self.tmpdirname = tempfile.mkdtemp()
processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b", image_seq_len=2)
processor.save_pretrained(self.tmpdirname)
self.image1 = Image.open(
BytesIO(
requests.get(
@@ -49,22 +66,35 @@ class Idefics2ProcessorTest(unittest.TestCase):
).content
)
)
self.bos_token = self.processor.tokenizer.bos_token
self.image_token = self.processor.image_token.content
self.fake_image_token = self.processor.fake_image_token.content
self.bos_token = processor.tokenizer.bos_token
self.image_token = processor.image_token.content
self.fake_image_token = processor.fake_image_token.content
self.bos_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.bos_token)
self.image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.image_token)
self.fake_image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.fake_image_token)
self.image_seq_len = self.processor.image_seq_len
self.bos_token_id = processor.tokenizer.convert_tokens_to_ids(self.bos_token)
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(self.image_token)
self.fake_image_token_id = processor.tokenizer.convert_tokens_to_ids(self.fake_image_token)
self.image_seq_len = processor.image_seq_len
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def get_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_process_interleaved_images_prompts_no_image_splitting(self):
old_image_splitting = self.processor.image_processor.do_image_splitting
tokenizer = self.get_tokenizer()
processor = self.get_processor()
self.processor.image_processor.do_image_splitting = False
processor.image_processor.do_image_splitting = False
# Test that a single image is processed correctly
inputs = self.processor(images=self.image1)
inputs = processor(images=self.image1)
self.assertEqual(inputs["pixel_values"].shape, (1, 1, 3, 653, 980))
self.assertEqual(inputs["pixel_attention_mask"].shape, (1, 1, 653, 980))
# fmt: on
@@ -73,10 +103,10 @@ class Idefics2ProcessorTest(unittest.TestCase):
image_str = "<image>"
text_str = "In this image, we see"
text = image_str + text_str
inputs = self.processor(text=text, images=self.image1)
inputs = processor(text=text, images=self.image1)
# fmt: off
tokenized_sentence = self.processor.tokenizer(text_str, add_special_tokens=False)
tokenized_sentence = tokenizer(text_str, add_special_tokens=False)
expected_input_ids = [[self.bos_token_id] + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
self.assertEqual(inputs["input_ids"], expected_input_ids)
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
@@ -95,11 +125,11 @@ class Idefics2ProcessorTest(unittest.TestCase):
]
images = [[self.image1], [self.image2, self.image3]]
inputs = self.processor(text=text, images=images, padding=True)
inputs = processor(text=text, images=images, padding=True)
# fmt: off
tokenized_sentence_1 = self.processor.tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = self.processor.tokenizer(text_str_2, add_special_tokens=False)
tokenized_sentence_1 = tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = tokenizer(text_str_2, add_special_tokens=False)
expected_input_ids_1 = [self.bos_token_id] + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence_1["input_ids"]
expected_input_ids_2 = [self.bos_token_id] + tokenized_sentence_2["input_ids"] + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id]
# Pad the first input to match the second input
@@ -117,15 +147,13 @@ class Idefics2ProcessorTest(unittest.TestCase):
self.assertEqual(inputs['pixel_attention_mask'].shape, (2, 2, 767, 980))
# fmt: on
self.processor.image_processor.do_image_splitting = old_image_splitting
def test_process_interleaved_images_prompts_image_splitting(self):
old_image_splitting = self.processor.image_processor.do_image_splitting
self.processor.image_processor.do_image_splitting = True
processor = self.get_processor()
tokenizer = self.get_tokenizer()
processor.image_processor.do_image_splitting = True
# Test that a single image is processed correctly
inputs = self.processor(images=self.image1)
inputs = processor(images=self.image1)
self.assertEqual(inputs["pixel_values"].shape, (1, 5, 3, 653, 980))
self.assertEqual(inputs["pixel_attention_mask"].shape, (1, 5, 653, 980))
# fmt: on
@@ -134,10 +162,10 @@ class Idefics2ProcessorTest(unittest.TestCase):
image_str = "<image>"
text_str = "In this image, we see"
text = image_str + text_str
inputs = self.processor(text=text, images=self.image1)
inputs = processor(text=text, images=self.image1)
# fmt: off
tokenized_sentence = self.processor.tokenizer(text_str, add_special_tokens=False)
tokenized_sentence = tokenizer(text_str, add_special_tokens=False)
expected_input_ids = [[self.bos_token_id] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
self.assertEqual(inputs["input_ids"], expected_input_ids)
self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
@@ -156,11 +184,11 @@ class Idefics2ProcessorTest(unittest.TestCase):
]
images = [[self.image1], [self.image2, self.image3]]
inputs = self.processor(text=text, images=images, padding=True)
inputs = processor(text=text, images=images, padding=True)
# fmt: off
tokenized_sentence_1 = self.processor.tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = self.processor.tokenizer(text_str_2, add_special_tokens=False)
tokenized_sentence_1 = tokenizer(text_str_1, add_special_tokens=False)
tokenized_sentence_2 = tokenizer(text_str_2, add_special_tokens=False)
expected_input_ids_1 = [self.bos_token_id] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id] + tokenized_sentence_1["input_ids"]
expected_input_ids_2 = [self.bos_token_id] + tokenized_sentence_2["input_ids"] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id]
# Pad the first input to match the second input
@@ -178,22 +206,22 @@ class Idefics2ProcessorTest(unittest.TestCase):
self.assertEqual(inputs['pixel_attention_mask'].shape, (2, 10, 767, 980))
# fmt: on
self.processor.image_processor.do_image_splitting = old_image_splitting
def test_add_special_tokens_processor(self):
processor = self.get_processor()
tokenizer = self.get_tokenizer()
image_str = "<image>"
text_str = "In this image, we see"
text = text_str + image_str
n_image_repeat = 5 if self.processor.image_processor.do_image_splitting else 1
n_image_repeat = 5 if processor.image_processor.do_image_splitting else 1
# fmt: off
inputs = self.processor(text=text, images=self.image1, add_special_tokens=False)
tokenized_sentence = self.processor.tokenizer(text_str, add_special_tokens=False)
inputs = processor(text=text, images=self.image1, add_special_tokens=False)
tokenized_sentence = tokenizer(text_str, add_special_tokens=False)
expected_input_ids = [tokenized_sentence["input_ids"] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * n_image_repeat + [self.fake_image_token_id]]
self.assertEqual(inputs["input_ids"], expected_input_ids)
inputs = self.processor(text=text, images=self.image1)
inputs = processor(text=text, images=self.image1)
expected_input_ids = [[self.bos_token_id] + tokenized_sentence["input_ids"] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * n_image_repeat + [self.fake_image_token_id]]
self.assertEqual(inputs["input_ids"], expected_input_ids)
# fmt: on
@@ -222,7 +250,7 @@ class Idefics2ProcessorTest(unittest.TestCase):
{"role": "user", "content": [{"type": "text", "text": "And who is that?"}]},
]
processor = self.processor
processor = self.get_processor()
# Make short sequence length to test that the fake tokens are added correctly
rendered = processor.apply_chat_template(messages, add_generation_prompt=True)
@@ -233,3 +261,27 @@ class Idefics2ProcessorTest(unittest.TestCase):
"Assistant:"
)
self.assertEqual(rendered, expected_rendered)
# Override as Idefics2Processor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer <image>"
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return ["lower newer <image>"]
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
batch_size - 2
)
# Override as Idefics2Processor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
"""This function prepares a list of PIL images for testing"""
if batch_size is None:
return super().prepare_image_inputs()
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
return [[super().prepare_image_inputs()]] * batch_size