Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Uniformize kwargs for Idefics/2 processors (#32568)
* Add uniformized Idefics processor kwargs and tests
* Uniformize Idefics2 processor kwargs
* Add image_processor tests for Idefics
* Add BC for args order change in the Idefics2 processor and update docs
* Add support for multiple images per prompt in image-text-to-text mode for Idefics
* Fix processor input args in Idefics tests
* Improve common processing tests, remove unnecessary tests, update processing uniformization
* Fix docstrings for Idefics
* Fix processor tests for Idefics/Idefics2
This commit is contained in:
parent: b0c5660e88
commit: 074aa3b3fd
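The change is easiest to see from the caller's side: both processors now expose the library-standard `images`/`text`/`audio`/`videos` signature and take everything else as flat or structured kwargs. A minimal sketch of the new call pattern, assuming a synthetic PIL image and the public `HuggingFaceM4/idefics2-8b` checkpoint (both illustrative, not part of this diff):

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
image = Image.new("RGB", (64, 64))  # stand-in for a real image

# Uniformized signature: images first, text second; a BC shim
# (_validate_images_text_input_order) still accepts the old order.
inputs = processor(
    images=[image],
    text=["<image>In this image, we see"],
    padding="longest",    # dispatched to text_kwargs
    return_tensors="pt",  # dispatched to common_kwargs
)
```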
src/transformers/models/idefics/processing_idefics.py

@@ -16,13 +16,21 @@
 Processor class for IDEFICS.
 """

-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 from urllib.parse import urlparse

 from ...feature_extraction_utils import BatchFeature
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
+from ...processing_utils import (
+    ImagesKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+    TextKwargs,
+    Unpack,
+    _validate_images_text_input_order,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...utils import is_tf_available, is_torch_available
+from ...utils.deprecation import deprecate_kwarg


 if is_torch_available():
@@ -34,6 +42,32 @@ if is_tf_available():

 IMAGE_TOKEN = "<image>"


+class IdeficsImagesKwargs(ImagesKwargs, total=False):
+    transform: Optional[Callable]
+    image_size: Optional[Dict[str, int]]
+    image_mean: Optional[Union[float, List[float]]]
+    image_std: Optional[Union[float, List[float]]]
+
+
+class IdeficsTextKwargs(TextKwargs, total=False):
+    add_eos_token: Optional[bool]
+    add_end_of_utterance_token: Optional[bool]
+
+
+class IdeficsProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: IdeficsTextKwargs
+    images_kwargs: IdeficsImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": False,
+            "padding": "longest",
+            "add_eos_token": False,
+        },
+        "images_kwargs": {},
+        "common_kwargs": {"return_tensors": "pt"},
+    }
+
+
 # copied from m4.training.packing
 def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
     # Set elements >= num_classes to -1
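These TypedDicts drive `ProcessorMixin._merge_kwargs`: the per-modality `_defaults` are overlaid with tokenizer init kwargs and then with whatever the caller passes. A rough, self-contained sketch of that precedence, using hypothetical user values (this is an illustration, not the library implementation):

```python
# Hedged sketch of how structured kwargs merge; values are made up.
defaults = {
    "text_kwargs": {"add_special_tokens": False, "padding": "longest", "add_eos_token": False},
    "images_kwargs": {},
    "common_kwargs": {"return_tensors": "pt"},
}
user_kwargs = {"padding": "max_length", "max_length": 76, "image_size": {"shortest_edge": 224}}

# Flat user kwargs are dispatched to the modality bucket that declares them,
# then each bucket is overlaid on its defaults (user values win).
text_keys = {"add_special_tokens", "padding", "max_length", "add_eos_token"}
image_keys = {"transform", "image_size", "image_mean", "image_std"}

merged = {bucket: dict(values) for bucket, values in defaults.items()}
for key, value in user_kwargs.items():
    bucket = (
        "text_kwargs" if key in text_keys
        else "images_kwargs" if key in image_keys
        else "common_kwargs"
    )
    merged[bucket][key] = value

print(merged["text_kwargs"]["padding"])  # "max_length" — the user value overrides the default
```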
@@ -199,52 +233,32 @@ class IdeficsProcessor(ProcessorMixin):
             else False
         )

+    @deprecate_kwarg(old_name="prompts", version="5.0.0", new_name="text", raise_if_both_names=True)
     def __call__(
         self,
-        prompts: Union[List[TextInput], List[List[TextInput]]],
-        padding: Union[bool, str, PaddingStrategy] = "longest",
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        transform: Callable = None,
-        add_eos_token=False,
-        add_end_of_utterance_token=None,
-        debug=False,
-        return_tensors="pt",
-    ) -> BatchEncoding:
+        images=None,
+        text: Union[
+            TextInput,
+            PreTokenizedInput,
+            List[TextInput],
+            List[PreTokenizedInput],
+            List[List[TextInput]],
+            List[List[PreTokenizedInput]],
+        ] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[IdeficsProcessorKwargs],
+    ) -> BatchFeature:
         """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
         the model was trained on and prepares the image pixel values for the model to process.

         Args:
-            prompts (`Union[List[TextInput], List[List[TextInput]]]`):
+            images (`Union[PIL.Image, str, List[PIL.Image], List[str]]`):
+                either a single image or a batched list of images - can be passed in when text contains only text
+                prompts, in order to use the image-text-to-text behavior.
+            text (`Union[List[TextInput], List[List[TextInput]]]`):
                 either a single prompt or a batched list of prompts - see the detailed description immediately after
                 the end of the arguments doc section.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `"longest"`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a
-                  single sequence is provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'`: No padding. This will raise an error if the input sequences are of
-                  different lengths.
-                Note: Unlike most processors, which set padding=`False` by default, `IdeficsProcessor` sets
-                `padding="longest"` by default. See
-                https://github.com/huggingface/transformers/pull/29449#pullrequestreview-1925576061 for why.
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
-            transform (`Callable`, *optional*):
-                A custom transform function that accepts a single image can be passed for training. For example,
-                `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
-                set of transforms will be applied to the images.
-            add_eos_token (`bool`, *optional*, defaults to `False`):
-                Adds `eos_token` at the end of the final prompt if `True`.
-            add_end_of_utterance_token (`bool`, *optional*):
-                Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by
-                an image). If `None` the tokenizer will be checked instead and if this token is found in
-                `additional_special_tokens` then the value will be `True`.
-            debug (`bool`, *optional*, defaults to `False`):
-                `True` value will help debug prompt generation by dumping useful information.
             return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
                 The type of tensors to return. Can be one of:
                 - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
@@ -255,7 +269,7 @@ class IdeficsProcessor(ProcessorMixin):

         Detailed explanation:

-        Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
+        Each entry in `text` is either a text to be passed as is or an image that will be processed.

         An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.

@@ -279,7 +293,7 @@ class IdeficsProcessor(ProcessorMixin):
             "Describe this image.\nAssistant:",
         ]

-        inputs = processor(prompts, return_tensors="pt")
+        inputs = processor(text=prompts, return_tensors="pt")
         generated_ids = model.generate(**inputs, max_length=100)
         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         ```
@@ -311,18 +325,55 @@ class IdeficsProcessor(ProcessorMixin):
                 transforms.Normalize(mean=self.image_mean, std=self.image_std),
             ]
         )
-        inputs = processor(prompts, transform=image_transform, return_tensors="pt")
+        inputs = processor(text=prompts, transform=image_transform, return_tensors="pt")
         ```

         In order to help debug prompt generation enable `debug=True` which will show you what's happening.

         """
+        if images is None and text is None:
+            raise ValueError("You need to specify either `text` or `images` and `text`.")
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        if images is None:
+            # assuming the user wants to use the old behavior with prompts as the only argument
+            prompts = text
+        elif text is not None:
+            # Assuming image-text-to-text behavior:
+            # Check if batched images are provided
+            if not isinstance(images, (list, tuple)):
+                images = [images]
+            if isinstance(text, str):
+                text = [text]
+            # Check if batched images and text are in the correct format
+            if isinstance(text, (list, tuple)) and len(text) != len(images):
+                raise ValueError(
+                    "When providing both images and text arguments, the number of text prompts should be the same as"
+                    " the number of images. If you want to have several images per prompt, images should be nested as"
+                    " such: images=[[img1, img2], [img3, img4], ...] for text=[prompt1, prompt2, ...]."
+                )
+            # Check that only text is present in the prompts
+            if not all(isinstance(i, str) for i in text):
+                raise ValueError("When using the image-text-to-text behavior, the prompts should only contain text.")
+            if isinstance(images[0], (list, tuple)):
+                # if nested images, nest text as well
+                text = [[i] for i in text]
+            prompts = list(zip(images, text))
+
+        output_kwargs = self._merge_kwargs(
+            IdeficsProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
+        add_end_of_utterance_token = output_kwargs["text_kwargs"].pop("add_end_of_utterance_token", None)
+
         # if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
         if add_end_of_utterance_token is None:
             add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
         # turn non-batched prompts into batched
-        if not any(isinstance(i, list) for i in prompts):
+        if not any(isinstance(i, (list, tuple)) for i in prompts):
             prompts = [prompts]

         fake_token = "<fake_token_around_image>"
@@ -371,21 +422,14 @@ class IdeficsProcessor(ProcessorMixin):
             if add_eos_token:
                 full_text += self.tokenizer.eos_token

-            if debug is True:
-                print(f"{full_text=}")
-
-            image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
+            image_objects = self.image_processor(image_objects, **output_kwargs["images_kwargs"])

             all_prompts.append(full_text)
             all_images.append(image_objects)

-        text_encoding = self.tokenizer(
-            text=all_prompts,
-            add_special_tokens=False,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        # For BC
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
+        text_encoding = self.tokenizer(all_prompts, **output_kwargs["text_kwargs"])
         all_texts = text_encoding["input_ids"]
         all_attention_masks = text_encoding["attention_mask"]
@@ -398,12 +442,12 @@ class IdeficsProcessor(ProcessorMixin):
         output_images = []
         output_attention_masks = []

-        for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
-            padded_input_ids = text
+        for text_single, attention_mask, extracted_images in zip(all_texts, all_attention_masks, all_images):
+            padded_input_ids = text_single
             image_count = padded_input_ids.count(self.image_token_id)
             local_max_num_images = min(image_count, max_num_images)

-            current_images = images[:local_max_num_images]
+            current_images = extracted_images[:local_max_num_images]

             if len(current_images) > 0:
                 if return_tensors == "pt":
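The validation added above pins down the accepted input shapes for the image-text-to-text path: one text prompt per top-level image entry, with multiple images per prompt expressed by nesting. A small illustration of the two accepted layouts (images are synthetic stand-ins, not part of this diff):

```python
from PIL import Image

img1, img2, img3 = (Image.new("RGB", (32, 32)) for _ in range(3))

# One image per prompt: flat images list, same length as text.
flat_images = [img1, img2]
prompts = ["User: Describe this image.", "User: Describe this image."]

# Several images per prompt: nest one inner list per prompt;
# the processor then nests the text to match and zips the pairs.
nested_images = [[img1, img2], [img3]]
prompts_for_nested = ["User: Compare these.", "User: Describe this."]

# An un-nested mismatch (e.g. three images for two prompts) raises the
# ValueError shown in the diff above.
```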
src/transformers/models/idefics2/modeling_idefics2.py

@@ -1584,7 +1584,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
         ...     "In which city is that bridge located?<image>",
         ... ]
         >>> images = [[image1, image2], [image3]]
-        >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to("cuda")
+        >>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")

         >>> # Generate
         >>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)
src/transformers/models/idefics2/processing_idefics2.py

@@ -20,9 +20,15 @@ from typing import TYPE_CHECKING, List, Optional, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, load_image
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import AddedToken, BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import (
+    ImagesKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+    Unpack,
+    _validate_images_text_input_order,
+)
+from ...tokenization_utils_base import AddedToken, TextInput
+from ...utils import logging


 if TYPE_CHECKING:
@@ -40,6 +46,23 @@ def is_image_or_image_url(elem):
     return is_url(elem) or is_valid_image(elem)


+class Idefics2ImagesKwargs(ImagesKwargs, total=False):
+    image_seq_len: Optional[int]
+
+
+class Idefics2ProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Idefics2ImagesKwargs
+
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "is_split_into_words": False,
+        },
+        "images_kwargs": {},
+    }
+
+
 class Idefics2Processor(ProcessorMixin):
     r"""
     Constructs an IDEFICS2 processor which wraps a Llama tokenizer and IDEFICS2 image processor into a single processor.
@@ -97,16 +120,12 @@ class Idefics2Processor(ProcessorMixin):

     def __call__(
         self,
-        text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
         images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
-        image_seq_len: Optional[int] = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        is_split_into_words: bool = False,
-        add_special_tokens: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> BatchEncoding:
+        text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Idefics2ProcessorKwargs],
+    ) -> BatchFeature:
         """
         Processes the input prompts and returns a BatchFeature.

@@ -130,7 +149,7 @@ class Idefics2Processor(ProcessorMixin):
         ...     "<image>In this image, we see",
         ...     "bla bla bla<image>",
         ... ]
-        >>> outputs = processor(text=text, images=images, return_tensors="pt", padding=True)
+        >>> outputs = processor(images=images, text=text, return_tensors="pt", padding=True)
         >>> input_ids = outputs.input_ids
         >>> input_tokens = processor.tokenizer.batch_decode(input_ids)
         >>> print(input_tokens)
@@ -138,6 +157,9 @@ class Idefics2Processor(ProcessorMixin):
         ```

         Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. If it is of type `List[ImageInput]`, it is assumed that this is for a single prompt, i.e. of batch size 1.
             text (`Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]`, *optional*):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
@@ -145,27 +167,22 @@ class Idefics2Processor(ProcessorMixin):

                 Wherever an image token, `<image>`, is encountered it is expanded to
                 `<fake_token_around_image>` + `<image>` * `image_seq_len` + `<fake_token_around_image>`.
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. If it is of type `List[ImageInput]`, it is assumed that this is for a single prompt, i.e. of batch size 1.
             image_seq_len (`int`, *optional*):
                 The length of the image sequence. If not provided, the default value is used.
-            padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `False`):
-                Padding strategy applied to the input ids. See [`PreTrainedTokenizerFast.pad`] for more information.
-            truncation (`Union[bool, str, TruncationStrategy]`, *optional*):
-                Truncation strategy applied to the input ids. See [`PreTrainedTokenizerFast.truncate`] for more information.
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding/truncation length. See
-                [`PreTrainedTokenizerFast.__call__`] for more information.
-            is_split_into_words (`bool`, *optional*, defaults to `False`):
-                Whether the input text is split into words or not. If set to `True`, the tokenizer will skip the
-                tokenization process and assume the input is already tokenized.
-            add_special_tokens (`bool`, *optional*, defaults to `True`):
-                Whether to add special tokens or not. See [`PreTrainedTokenizerFast.__call__`] for more information.
-            return_tensors (`Union[str, TensorType]`, *optional*):
-                If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for
-                more information.
         """
         if text is None and images is None:
             raise ValueError("You must provide either `text` or `images`.")
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)

+        output_kwargs = self._merge_kwargs(
+            Idefics2ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        image_seq_len = output_kwargs["images_kwargs"].pop("image_seq_len", None)
         image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len

         n_images_in_text = []
@@ -194,15 +211,7 @@ class Idefics2Processor(ProcessorMixin):
                 sample = sample.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
                 prompt_strings.append(sample)

-            text_inputs = self.tokenizer(
-                text=prompt_strings,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                is_split_into_words=is_split_into_words,
-                return_tensors=return_tensors,
-            )
+            text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
             inputs.update(text_inputs)

         if images is not None:
@@ -227,7 +236,7 @@ class Idefics2Processor(ProcessorMixin):

             # Load images if they are URLs
             images = [[load_image(im) for im in sample] for sample in images]
-            image_inputs = self.image_processor(images, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
             inputs.update(image_inputs)

         return inputs
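As the docstring above describes, each `<image>` in the text is expanded before tokenization, and adjacent fake tokens are collapsed (the `sample.replace` line in the diff). A small sketch of that expansion rule; `image_seq_len=2` mirrors the tests below, and this is an illustration rather than the library's implementation:

```python
# Hedged sketch of the <image> expansion rule described in the docstring.
fake_image_token = "<fake_token_around_image>"
image_token = "<image>"
image_seq_len = 2  # the tests below load the processor with image_seq_len=2

text = "<image>In this image, we see"
expanded = text.replace(
    image_token, fake_image_token + image_token * image_seq_len + fake_image_token
)
# Adjacent fake tokens collapse to one, as in the diff above:
expanded = expanded.replace(fake_image_token * 2, fake_image_token)

print(expanded)
# <fake_token_around_image><image><image><fake_token_around_image>In this image, we see
```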
tests/models/idefics/test_modeling_idefics.py

@@ -662,7 +662,7 @@ class IdeficsModelIntegrationTest(TestCasePlus):
             "HuggingFaceM4/idefics-9b", quantization_config=quantization_config, device_map="auto"
         )
         processor = self.default_processor
-        inputs = processor(prompts, return_tensors="pt", padding="longest").to(torch_device)
+        inputs = processor(text=prompts, return_tensors="pt", padding="longest").to(torch_device)
         generated_ids = model.generate(**inputs, max_length=100)
         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
tests/models/idefics/test_processor_idefics.py

@@ -12,11 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import shutil
+import tempfile
 import unittest

 import numpy as np

-from transformers.testing_utils import TestCasePlus, require_torch, require_vision
+from transformers import (
+    AutoProcessor,
+    IdeficsImageProcessor,
+    IdeficsProcessor,
+    LlamaTokenizerFast,
+    PreTrainedTokenizerFast,
+)
+from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available

+from ...test_processing_common import ProcessorTesterMixin
+

 if is_torch_available():
     import torch
@@ -24,37 +37,32 @@ if is_torch_available():

 if is_vision_available():
     from PIL import Image

-    from transformers import (
-        AutoProcessor,
-        IdeficsImageProcessor,
-        IdeficsProcessor,
-        LlamaTokenizerFast,
-        PreTrainedTokenizerFast,
-    )
-

 @require_torch
 @require_vision
-class IdeficsProcessorTest(TestCasePlus):
-    def setUp(self):
-        super().setUp()
+class IdeficsProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = IdeficsProcessor

-        self.checkpoint_path = self.get_auto_remove_tmp_dir()
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()

         image_processor = IdeficsImageProcessor(return_tensors="pt")
         tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceM4/tiny-random-idefics")

         processor = IdeficsProcessor(image_processor, tokenizer)

-        processor.save_pretrained(self.checkpoint_path)
+        processor.save_pretrained(self.tmpdirname)

         self.input_keys = ["pixel_values", "input_ids", "attention_mask", "image_attention_mask"]

     def get_tokenizer(self, **kwargs):
-        return AutoProcessor.from_pretrained(self.checkpoint_path, **kwargs).tokenizer
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

     def get_image_processor(self, **kwargs):
-        return AutoProcessor.from_pretrained(self.checkpoint_path, **kwargs).image_processor
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
     def prepare_prompts(self):
         """This function prepares a list of PIL images"""
@@ -100,13 +108,13 @@ class IdeficsProcessorTest(TestCasePlus):

     def test_save_load_pretrained_additional_features(self):
         processor = IdeficsProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
-        processor.save_pretrained(self.checkpoint_path)
+        processor.save_pretrained(self.tmpdirname)

         tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
         image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

         processor = IdeficsProcessor.from_pretrained(
-            self.checkpoint_path, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
         )

         self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
@@ -124,7 +132,7 @@ class IdeficsProcessorTest(TestCasePlus):
         prompts = self.prepare_prompts()

         # test that all prompts succeeded
-        input_processor = processor(prompts, return_tensors="pt", padding="longest")
+        input_processor = processor(text=prompts, return_tensors="pt", padding="longest")
         for key in self.input_keys:
             assert torch.is_tensor(input_processor[key])
@@ -157,8 +165,8 @@ class IdeficsProcessorTest(TestCasePlus):
         ]
         prompts = [[prompt] for prompt in self.prepare_prompts()[2]]

-        max_length = processor(prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt")
-        longest = processor(prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt")
+        max_length = processor(text=prompts, padding="max_length", truncation=True, max_length=20, return_tensors="pt")
+        longest = processor(text=prompts, padding="longest", truncation=True, max_length=30, return_tensors="pt")

         decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
         decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
@@ -185,8 +193,8 @@ class IdeficsProcessorTest(TestCasePlus):
             ([0] * 10) + ([1] * 10),
         ]
         prompts = [[prompt] for prompt in self.prepare_prompts()[2]]
-        max_length = processor(prompts, padding="max_length", truncation=True, max_length=20)
-        longest = processor(prompts, padding="longest", truncation=True, max_length=30)
+        max_length = processor(text=prompts, padding="max_length", truncation=True, max_length=20)
+        longest = processor(text=prompts, padding="longest", truncation=True, max_length=30)

         decoded_max_length = processor.tokenizer.decode(max_length["input_ids"][-1])
         decoded_longest = processor.tokenizer.decode(longest["input_ids"][-1])
@@ -204,7 +212,143 @@ class IdeficsProcessorTest(TestCasePlus):
         processor = IdeficsProcessor(tokenizer=tokenizer, image_processor=image_processor)
         prompts = self.prepare_prompts()

-        inputs = processor(prompts, padding="longest", return_tensors="pt")
+        inputs = processor(text=prompts, padding="longest", return_tensors="pt")

         # For now the processor supports only ['pixel_values', 'input_ids', 'attention_mask', 'image_attention_mask']
         self.assertSetEqual(set(inputs.keys()), set(self.input_keys))

+    # Override the following tests as the Idefics image processor does not accept do_rescale and rescale_factor
+    @require_torch
+    @require_vision
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", image_size=234)
+        tokenizer = self.get_component("tokenizer", max_length=117)
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+        self.assertEqual(len(inputs["pixel_values"][0][0][0]), 234)
+
+    @require_torch
+    @require_vision
+    def test_kwargs_overrides_default_image_processor_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", image_size=234)
+        tokenizer = self.get_component("tokenizer", max_length=117)
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input, image_size=224)
+        self.assertEqual(len(inputs["pixel_values"][0][0][0]), 224)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            image_size=214,
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[3], 214)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs(batch_size=2)
+        image_input = self.prepare_image_inputs(batch_size=2)
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            image_size=214,
+            padding="longest",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[3], 214)
+        self.assertEqual(len(inputs["input_ids"][0]), 8)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = self.prepare_text_inputs()
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"image_size": 214},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+        self.assertEqual(inputs["pixel_values"].shape[3], 214)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested_from_dict(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = self.prepare_text_inputs()
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"image_size": 214},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.assertEqual(inputs["pixel_values"].shape[3], 214)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
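The unstructured and structured test variants above assert the same contract: flat kwargs and nested per-modality dicts must be routed identically. A compact sketch of the equivalence being tested (values mirror the tests; the processor call itself is elided):

```python
# Flat form: the processor sorts each key into its modality bucket.
flat = dict(return_tensors="pt", image_size=214, padding="max_length", max_length=76)

# Structured form: the caller pre-sorts the same values.
structured = {
    "common_kwargs": {"return_tensors": "pt"},
    "images_kwargs": {"image_size": 214},
    "text_kwargs": {"padding": "max_length", "max_length": 76},
}

# Both calls are expected to yield identical outputs:
#   processor(text=..., images=..., **flat)
#   processor(text=..., images=..., **structured)
```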
tests/models/idefics2/test_processing_idefics2.py

@@ -13,8 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import shutil
+import tempfile
 import unittest
 from io import BytesIO
+from typing import Optional

 import requests

@@ -22,16 +25,30 @@ from transformers import Idefics2Processor
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available

+from ...test_processing_common import ProcessorTesterMixin
+

 if is_vision_available():
     from PIL import Image

+    from transformers import (
+        AutoProcessor,
+        Idefics2Processor,
+    )
+

 @require_torch
 @require_vision
-class Idefics2ProcessorTest(unittest.TestCase):
+class Idefics2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = Idefics2Processor
+
     def setUp(self):
-        self.processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b", image_seq_len=2)
+        self.tmpdirname = tempfile.mkdtemp()
+
+        processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b", image_seq_len=2)
+
+        processor.save_pretrained(self.tmpdirname)
+
         self.image1 = Image.open(
             BytesIO(
                 requests.get(
@@ -49,22 +66,35 @@ class Idefics2ProcessorTest(unittest.TestCase):
                 ).content
             )
         )
-        self.bos_token = self.processor.tokenizer.bos_token
-        self.image_token = self.processor.image_token.content
-        self.fake_image_token = self.processor.fake_image_token.content
+        self.bos_token = processor.tokenizer.bos_token
+        self.image_token = processor.image_token.content
+        self.fake_image_token = processor.fake_image_token.content

-        self.bos_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.bos_token)
-        self.image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.image_token)
-        self.fake_image_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.fake_image_token)
-        self.image_seq_len = self.processor.image_seq_len
+        self.bos_token_id = processor.tokenizer.convert_tokens_to_ids(self.bos_token)
+        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(self.image_token)
+        self.fake_image_token_id = processor.tokenizer.convert_tokens_to_ids(self.fake_image_token)
+        self.image_seq_len = processor.image_seq_len
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def get_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)

     def test_process_interleaved_images_prompts_no_image_splitting(self):
-        old_image_splitting = self.processor.image_processor.do_image_splitting
+        tokenizer = self.get_tokenizer()
+        processor = self.get_processor()

-        self.processor.image_processor.do_image_splitting = False
+        processor.image_processor.do_image_splitting = False

         # Test that a single image is processed correctly
-        inputs = self.processor(images=self.image1)
+        inputs = processor(images=self.image1)
         self.assertEqual(inputs["pixel_values"].shape, (1, 1, 3, 653, 980))
         self.assertEqual(inputs["pixel_attention_mask"].shape, (1, 1, 653, 980))
         # fmt: on
@@ -73,10 +103,10 @@ class Idefics2ProcessorTest(unittest.TestCase):
         image_str = "<image>"
         text_str = "In this image, we see"
         text = image_str + text_str
-        inputs = self.processor(text=text, images=self.image1)
+        inputs = processor(text=text, images=self.image1)

         # fmt: off
-        tokenized_sentence = self.processor.tokenizer(text_str, add_special_tokens=False)
+        tokenized_sentence = tokenizer(text_str, add_special_tokens=False)
         expected_input_ids = [[self.bos_token_id] + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
         self.assertEqual(inputs["input_ids"], expected_input_ids)
         self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
@@ -95,11 +125,11 @@ class Idefics2ProcessorTest(unittest.TestCase):
         ]
         images = [[self.image1], [self.image2, self.image3]]

-        inputs = self.processor(text=text, images=images, padding=True)
+        inputs = processor(text=text, images=images, padding=True)

         # fmt: off
-        tokenized_sentence_1 = self.processor.tokenizer(text_str_1, add_special_tokens=False)
-        tokenized_sentence_2 = self.processor.tokenizer(text_str_2, add_special_tokens=False)
+        tokenized_sentence_1 = tokenizer(text_str_1, add_special_tokens=False)
+        tokenized_sentence_2 = tokenizer(text_str_2, add_special_tokens=False)
         expected_input_ids_1 = [self.bos_token_id] + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + tokenized_sentence_1["input_ids"]
         expected_input_ids_2 = [self.bos_token_id] + tokenized_sentence_2["input_ids"] + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len + [self.fake_image_token_id]
         # Pad the first input to match the second input
@@ -117,15 +147,13 @@ class Idefics2ProcessorTest(unittest.TestCase):
         self.assertEqual(inputs['pixel_attention_mask'].shape, (2, 2, 767, 980))
         # fmt: on

-        self.processor.image_processor.do_image_splitting = old_image_splitting
-
     def test_process_interleaved_images_prompts_image_splitting(self):
-        old_image_splitting = self.processor.image_processor.do_image_splitting
-
-        self.processor.image_processor.do_image_splitting = True
+        processor = self.get_processor()
+        tokenizer = self.get_tokenizer()
+        processor.image_processor.do_image_splitting = True

         # Test that a single image is processed correctly
-        inputs = self.processor(images=self.image1)
+        inputs = processor(images=self.image1)
         self.assertEqual(inputs["pixel_values"].shape, (1, 5, 3, 653, 980))
         self.assertEqual(inputs["pixel_attention_mask"].shape, (1, 5, 653, 980))
         # fmt: on
@@ -134,10 +162,10 @@ class Idefics2ProcessorTest(unittest.TestCase):
         image_str = "<image>"
         text_str = "In this image, we see"
         text = image_str + text_str
-        inputs = self.processor(text=text, images=self.image1)
+        inputs = processor(text=text, images=self.image1)

         # fmt: off
-        tokenized_sentence = self.processor.tokenizer(text_str, add_special_tokens=False)
+        tokenized_sentence = tokenizer(text_str, add_special_tokens=False)
         expected_input_ids = [[self.bos_token_id] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id] + tokenized_sentence["input_ids"]]
         self.assertEqual(inputs["input_ids"], expected_input_ids)
         self.assertEqual(inputs["attention_mask"], [[1] * len(expected_input_ids[0])])
@@ -156,11 +184,11 @@ class Idefics2ProcessorTest(unittest.TestCase):
         ]
         images = [[self.image1], [self.image2, self.image3]]

-        inputs = self.processor(text=text, images=images, padding=True)
+        inputs = processor(text=text, images=images, padding=True)

         # fmt: off
-        tokenized_sentence_1 = self.processor.tokenizer(text_str_1, add_special_tokens=False)
-        tokenized_sentence_2 = self.processor.tokenizer(text_str_2, add_special_tokens=False)
+        tokenized_sentence_1 = tokenizer(text_str_1, add_special_tokens=False)
+        tokenized_sentence_2 = tokenizer(text_str_2, add_special_tokens=False)
         expected_input_ids_1 = [self.bos_token_id] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id] + tokenized_sentence_1["input_ids"]
         expected_input_ids_2 = [self.bos_token_id] + tokenized_sentence_2["input_ids"] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * 5 + [self.fake_image_token_id]
         # Pad the first input to match the second input
@@ -178,22 +206,22 @@ class Idefics2ProcessorTest(unittest.TestCase):
         self.assertEqual(inputs['pixel_attention_mask'].shape, (2, 10, 767, 980))
         # fmt: on

-        self.processor.image_processor.do_image_splitting = old_image_splitting
-
     def test_add_special_tokens_processor(self):
+        processor = self.get_processor()
+        tokenizer = self.get_tokenizer()
         image_str = "<image>"
         text_str = "In this image, we see"
         text = text_str + image_str

-        n_image_repeat = 5 if self.processor.image_processor.do_image_splitting else 1
+        n_image_repeat = 5 if processor.image_processor.do_image_splitting else 1

         # fmt: off
-        inputs = self.processor(text=text, images=self.image1, add_special_tokens=False)
-        tokenized_sentence = self.processor.tokenizer(text_str, add_special_tokens=False)
+        inputs = processor(text=text, images=self.image1, add_special_tokens=False)
+        tokenized_sentence = tokenizer(text_str, add_special_tokens=False)
         expected_input_ids = [tokenized_sentence["input_ids"] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * n_image_repeat + [self.fake_image_token_id]]
         self.assertEqual(inputs["input_ids"], expected_input_ids)

-        inputs = self.processor(text=text, images=self.image1)
+        inputs = processor(text=text, images=self.image1)
         expected_input_ids = [[self.bos_token_id] + tokenized_sentence["input_ids"] + ([self.fake_image_token_id] + [self.image_token_id] * self.image_seq_len) * n_image_repeat + [self.fake_image_token_id]]
         self.assertEqual(inputs["input_ids"], expected_input_ids)
         # fmt: on
@@ -222,7 +250,7 @@ class Idefics2ProcessorTest(unittest.TestCase):
             {"role": "user", "content": [{"type": "text", "text": "And who is that?"}]},
         ]

-        processor = self.processor
+        processor = self.get_processor()
         # Make short sequence length to test that the fake tokens are added correctly
         rendered = processor.apply_chat_template(messages, add_generation_prompt=True)
@@ -233,3 +261,27 @@ class Idefics2ProcessorTest(unittest.TestCase):
             "Assistant:"
         )
         self.assertEqual(rendered, expected_rendered)
+
+    # Override as Idefics2Processor needs image tokens in prompts
+    def prepare_text_inputs(self, batch_size: Optional[int] = None):
+        if batch_size is None:
+            return "lower newer <image>"
+
+        if batch_size < 1:
+            raise ValueError("batch_size must be greater than 0")
+
+        if batch_size == 1:
+            return ["lower newer <image>"]
+        return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
+            batch_size - 2
+        )
+
+    # Override as Idefics2Processor needs nested images to work properly with batched inputs
+    @require_vision
+    def prepare_image_inputs(self, batch_size: Optional[int] = None):
+        """This function prepares a list of PIL images for testing"""
+        if batch_size is None:
+            return super().prepare_image_inputs()
+        if batch_size < 1:
+            raise ValueError("batch_size must be greater than 0")
+        return [[super().prepare_image_inputs()]] * batch_size