Uniformize kwargs for image-text-to-text processors (#32544)

* uniformize FUYU processor kwargs

* Uniformize instructblip processor kwargs

* Fix processor kwargs and tests Fuyu, InstructBlip, Kosmos2

* Uniformize llava_next processor

* Fix save_load test for processor with chat_template only as extra init args

* Fix import Unpack

* Fix Fuyu Processor import

* Fix FuyuProcessor import

* Fix FuyuProcessor

* Add defaults for specific kwargs kosmos2

* Fix Udop to return BatchFeature instead of BatchEncoding and uniformize kwargs

* Add tests processor Udop

* remove Copied from in processing Udop as change of input orders caused by BatchEncoding -> BatchFeature

* Fix overwrite tests kwargs processors

* Add warnings and BC for changes in processor inputs order, change docs, add BC for text_pair as arg for Udop

* Fix processing test fuyu

* remove unnecessary pad_token check in instructblip ProcessorTest

* Fix BC tests and cleanup

* Fix imports fuyu

* Uniformize Pix2Struct

* Fix wrong name for FuyuProcessorKwargs

* Fix slow tests reversed inputs align fuyu llava-next, change udop warning

* Fix wrong logging import udop

* Add check images text input order

* Fix copies

* change text pair handling when positional arg

* rebase on main, fix imports in test_processing_common

* remove optional args and udop uniformization from this PR

* fix failing tests

* remove unnecessary test, fix processing utils and test processing common

* cleanup Unpack

* cleanup

* fix conflict grounding dino
Yoni Gozlan 2024-09-24 21:28:19 -04:00 committed by GitHub
parent fa0bb0fe76
commit 5f0c181f4e
24 changed files with 763 additions and 852 deletions
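The pattern repeated across the touched processors is the same: `images` comes first, `text` second, the long tail of tokenizer arguments is replaced by typed `**kwargs`, and call-time kwargs are merged with per-model defaults and the tokenizer init kwargs. A condensed sketch of that pattern (`MyProcessorKwargs` and `MyProcessor` are placeholder names, not part of this diff):

```py
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack


class MyProcessorKwargs(ProcessingKwargs, total=False):
    # per-model defaults; overridden by tokenizer init kwargs, then by call-time kwargs
    _defaults = {
        "text_kwargs": {"padding": False},
        "images_kwargs": {},
    }


class MyProcessor(ProcessorMixin):
    def __call__(self, images=None, text=None, audio=None, videos=None, **kwargs: Unpack[MyProcessorKwargs]):
        # reversed (text, images) positional calls are detected and swapped for BC
        output_kwargs = self._merge_kwargs(
            MyProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # output_kwargs["text_kwargs"], ["images_kwargs"] and ["common_kwargs"]
        # are then forwarded to the tokenizer and the image processor
        ...
```

The merge order (class `_defaults`, then tokenizer init kwargs, then call-time kwargs) is what lets the per-model defaults above replace the hard-coded keyword arguments removed in the diffs below.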

View File

@@ -46,7 +46,7 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 image = Image.open(requests.get(url, stream=True).raw)
 candidate_labels = ["an image of a cat", "an image of a dog"]
-inputs = processor(text=candidate_labels, images=image, return_tensors="pt")
+inputs = processor(images=image, text=candidate_labels, return_tensors="pt")

 with torch.no_grad():
     outputs = model(**inputs)

View File

@@ -18,16 +18,16 @@ rendered properly in your Markdown viewer.
 ## Overview

 The Fuyu model was created by [ADEPT](https://www.adept.ai/blog/fuyu-8b), and authored by Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar.

 The authors introduced Fuyu-8B, a decoder-only multimodal model based on the classic transformers architecture, with query and key normalization. A linear encoder is added to create multimodal embeddings from image inputs.

 By treating image tokens like text tokens and using a special image-newline character, the model knows when an image line ends. Image positional embeddings are removed. This avoids the need for different training phases for various image resolutions. With 8 billion parameters and licensed under CC-BY-NC, Fuyu-8B is notable for its ability to handle both text and images, its impressive context size of 16K, and its overall performance.

 <Tip warning={true}>

 The `Fuyu` models were trained using `bfloat16`, but the original inference uses `float16` The checkpoints uploaded on the hub use `torch_dtype = 'float16'` which will be
 used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.

 The `dtype` of the online weights is mostly irrelevant, unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online) then it will be cast to the default `dtype` of `torch` (becomes `torch.float32`). Users should specify the `torch_dtype` they want, and if they don't it will be `torch.float32`.
@@ -56,7 +56,7 @@ tar -xvf 8b_base_model_release.tar
 ```
 Then, model can be loaded via:
 ```py
 from transformers import FuyuConfig, FuyuForCausalLM
 model_config = FuyuConfig()
 model = FuyuForCausalLM(model_config).from_pretrained('/output/path')
@@ -81,7 +81,7 @@ text_prompt = "Generate a coco-style caption.\\n"
 bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
 bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
-inputs_to_model = processor(text=text_prompt, images=bus_image_pil)
+inputs_to_model = processor(images=bus_image_pil, text=text_prompt)
 ```
@@ -90,7 +90,7 @@ This model was contributed by [Molbap](https://huggingface.co/Molbap).
 The original code can be found [here](https://github.com/persimmon-ai-labs/adept-inference).

 - Fuyu uses a `sentencepiece` based tokenizer, with a `Unigram` model. It supports bytefallback, which is only available in `tokenizers==0.14.0` for the fast tokenizer.
 The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece.

 - The authors suggest to use the following prompt for image captioning: `f"Generate a coco-style caption.\\n"`

View File

@@ -133,7 +133,7 @@ import requests
 processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
 model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
 model.to("cuda:0")

 # prepare image and text prompt, using the appropriate prompt template
@@ -150,7 +150,7 @@ conversation = [
     },
 ]
 prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
-inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
+inputs = processor(image, prompt, return_tensors="pt").to("cuda:0")

 # autoregressively complete prompt
 output = model.generate(**inputs, max_new_tokens=100)
@@ -222,7 +222,7 @@ prompts = [prompt_1, prompt_2]
 # We can simply feed images in the order they have to be used in the text prompt
 # Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
-inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)
+inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device)

 # Generate
 generate_ids = model.generate(**inputs, max_new_tokens=30)
@@ -266,8 +266,8 @@ First make sure to install flash-attn. Refer to the [original repository of Flas
 from transformers import LlavaNextForConditionalGeneration

 model = LlavaNextForConditionalGeneration.from_pretrained(
     model_id,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
     use_flash_attention_2=True
 ).to(0)

View File

@@ -1575,7 +1575,7 @@ class AlignModel(AlignPreTrainedModel):
         >>> image = Image.open(requests.get(url, stream=True).raw)

         >>> inputs = processor(
-        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
+        ...     images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
         ... )

         >>> outputs = model(**inputs)

View File

@@ -19,11 +19,7 @@ Image/Text processor class for ALIGN
 from typing import List, Union

 from ...image_utils import ImageInput
-from ...processing_utils import (
-    ProcessingKwargs,
-    ProcessorMixin,
-    Unpack,
-)
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
 from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
@@ -76,8 +72,8 @@ class AlignProcessor(ProcessorMixin):
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         images: ImageInput = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         audio=None,
         videos=None,
         **kwargs: Unpack[AlignProcessorKwargs],
@@ -90,13 +86,13 @@ class AlignProcessor(ProcessorMixin):
         to the doctsring of the above two methods for more information.

         Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
             text (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
                     - `'tf'`: Return TensorFlow `tf.constant` objects.
@@ -114,6 +110,9 @@ class AlignProcessor(ProcessorMixin):
         """
         if text is None and images is None:
             raise ValueError("You must specify either text or images.")
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
         output_kwargs = self._merge_kwargs(
             AlignProcessorKwargs,
             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
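For reference, the backward-compatibility check added here behaves roughly like the following standalone sketch. This is a simplified illustration, not the `transformers` implementation; per the commit message, the real `_validate_images_text_input_order` also warns about the deprecated order:

```py
def swap_if_reversed(images, text):
    """Simplified stand-in for the reversed-input check (illustrative only)."""

    def looks_like_text(x):
        return isinstance(x, str) or (
            isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], str)
        )

    # if the first argument looks like text and the second does not, assume the
    # legacy (text, images) call order and swap so old positional calls keep working
    if looks_like_text(images) and not looks_like_text(text):
        return text, images
    return images, text
```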

View File

@@ -265,7 +265,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> prompt = "Generate a coco-style caption.\n"

-       >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+       >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=7)

View File

@@ -21,9 +21,10 @@ from typing import Dict, List, Optional, Tuple, Union
 import numpy as np

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
-from ...utils import TensorType, is_torch_available, logging, requires_backends
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_available, logging, requires_backends

 if is_torch_available():
@@ -49,6 +50,24 @@ TOKEN_POINT_CLOSE_STRING = "<0x03>"  # </point>
 BEGINNING_OF_ANSWER_STRING = "<0x04>"  # <boa>

+class FuyuProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_attention_mask": True,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
+
 def full_unpacked_stream_to_tensor(
     all_bi_tokens_to_place: List[int],
     full_unpacked_stream: List["torch.Tensor"],
@@ -452,23 +471,11 @@ class FuyuProcessor(ProcessorMixin):
     def __call__(
         self,
-        text=None,
-        images=None,
-        add_special_tokens: bool = True,
-        return_attention_mask: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_token_type_ids: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        images: ImageInput = None,
+        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[FuyuProcessorKwargs],
     ) -> "FuyuBatchFeature":
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -478,13 +485,13 @@ class FuyuProcessor(ProcessorMixin):
         of the above two methods for more information.

         Args:
+            images (`PIL.Image.Image`, `List[PIL.Image.Image]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
             text (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `List[PIL.Image.Image]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.

         Returns:
             [`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields:
@@ -498,31 +505,24 @@ class FuyuProcessor(ProcessorMixin):
         requires_backends(self, ["torch"])

         # --- Check input validity ---
-        if not return_attention_mask:
-            raise ValueError("`return_attention_mask=False` is not supported for this model.")
         if text is None and images is None:
             raise ValueError("You have to specify either text or images. Both cannot be None.")
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            FuyuProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
+            raise ValueError("`return_attention_mask=False` is not supported for this model.")
+
         if text is not None and images is None:
             logger.warning("You are processing a text with no associated image. Make sure it is intended.")
             self.current_processor = self.tokenizer
-            text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
-            )
+            text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
             return text_encoding

         if text is None and images is not None:
@@ -537,7 +537,8 @@
         # --- Preprocess images using self.image_processor ---

         # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
-        image_encoding = self.image_processor.preprocess(images, return_tensors="pt")
+        output_kwargs["images_kwargs"]["return_tensors"] = "pt"
+        image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
         batch_images = image_encoding["images"]
         image_unpadded_heights = image_encoding["image_unpadded_heights"]
         image_unpadded_widths = image_encoding["image_unpadded_widths"]
@@ -568,7 +569,7 @@
             )
             all_encodings.append(sample_encoding)
         batch_encoding = self._left_pad_inputs_with_attention_mask(
-            model_inputs=all_encodings, return_attention_mask=return_attention_mask
+            model_inputs=all_encodings, return_attention_mask=True
         )
         return FuyuBatchFeature(data=batch_encoding)
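With `FuyuProcessorKwargs` in place, tokenizer options travel through the same call instead of being separate named parameters. A usage sketch (checkpoint and image URL taken from the Fuyu docs and tests elsewhere in this diff):

```py
import io

import requests
from PIL import Image

from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")

bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(io.BytesIO(requests.get(bus_image_url).content))

# images first, text second; text kwargs (padding, max_length, truncation, ...)
# are routed to the tokenizer through FuyuProcessorKwargs
inputs = processor(images=image, text="Generate a coco-style caption.\n", return_tensors="pt")
```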

View File

@@ -17,26 +17,41 @@ Processor class for InstructBLIP. Largely copy of Blip2Processor with addition o
 """

 import os
-from typing import List, Optional, Union
+from typing import List, Union

 from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import (
     AddedToken,
     BatchEncoding,
-    PaddingStrategy,
     PreTokenizedInput,
     TextInput,
-    TruncationStrategy,
 )
-from ...utils import TensorType, logging
+from ...utils import logging
 from ..auto import AutoTokenizer

 logger = logging.get_logger(__name__)

+
+class InstructBlipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
+
+
 class InstructBlipProcessor(ProcessorMixin):
     r"""
     Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
@@ -72,31 +87,33 @@ class InstructBlipProcessor(ProcessorMixin):
         self,
         images: ImageInput = None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_token_type_ids: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[InstructBlipProcessorKwargs],
     ) -> BatchFeature:
         """
         This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
         [`BertTokenizerFast.__call__`] to prepare text for the model.

         Please refer to the docstring of the above two methods for more information.
+
+        Args:
+            images (`ImageInput`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
         """
         if images is None and text is None:
             raise ValueError("You have to specify at least images or text.")

+        output_kwargs = self._merge_kwargs(
+            InstructBlipProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         encoding = BatchFeature()

         if text is not None:
@@ -105,24 +122,7 @@ class InstructBlipProcessor(ProcessorMixin):
             elif not isinstance(text, list) and not isinstance(text[0], str):
                 raise ValueError("Invalid input text. Please provide a string, or a list of strings")

-            _text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=None,  # needed to concatenate below
-                **kwargs,
-            )
+            _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])

             # if we know how many query tokens, expand text inside processor. We need this hacky manipulation
             # because BLIP expects image tokens to be at the beginning even before BOS token
@@ -145,31 +145,17 @@
             )

             # cast to desired return tensors type after concatenating
-            text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
-            encoding.update(text_encoding)
-            qformer_text_encoding = self.qformer_tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
-            )
+            text_encoding = BatchEncoding(
+                text_encoding, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")
+            )
+            encoding.update(text_encoding)
+            qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
             encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
             encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

         if images is not None:
-            image_encoding = self.image_processor(images, return_tensors=return_tensors)
+            image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
             encoding.update(image_encoding)

         return encoding
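Note that a single `text_kwargs` dict now drives both `self.tokenizer` and `self.qformer_tokenizer`, so call-time options such as `padding` apply to both encodings. A hedged usage sketch (the checkpoint name is an assumption for illustration):

```py
import requests
from PIL import Image

from transformers import InstructBlipProcessor

# checkpoint name assumed for illustration
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# padding is applied to both the language-model and the Q-Former tokenization
inputs = processor(
    images=image,
    text="What is unusual about this image?",
    padding=True,
    return_tensors="pt",
)
```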

View File

@@ -21,10 +21,9 @@ from typing import List, Optional, Tuple, Union
 from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput, is_batched
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
 from ...tokenization_utils import AddedToken
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...tokenization_utils_base import BatchEncoding, TextInput

 BboxInput = Union[
@@ -35,6 +34,37 @@ BboxInput = Union[
 ]

+
+class Kosmos2ImagesKwargs(ImagesKwargs, total=False):
+    bboxes: Optional[List[float]]
+    num_image_tokens: Optional[int]
+    first_image_token_id: Optional[int]
+
+
+class Kosmos2TextKwargs(TextKwargs, total=False):
+    add_eos_token: Optional[bool]
+
+
+class Kosmos2ProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: Kosmos2TextKwargs
+    images_kwargs: Kosmos2ImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "verbose": True,
+            "add_eos_token": False,
+        },
+        "images_kwargs": {
+            "num_image_tokens": 64,
+        },
+    }
+
+
 class Kosmos2Processor(ProcessorMixin):
     r"""
     Constructs an KOSMOS-2 processor which wraps a KOSMOS-2 image processor and a KOSMOS-2 tokenizer into a single
@@ -56,7 +86,7 @@ class Kosmos2Processor(ProcessorMixin):
     attributes = ["image_processor", "tokenizer"]
     valid_kwargs = ["num_patch_index_tokens"]
     image_processor_class = "CLIPImageProcessor"
-    tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
+    tokenizer_class = "AutoTokenizer"

     def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024, *kwargs):
         tokenizer.return_token_type_ids = False
@@ -107,20 +137,9 @@ class Kosmos2Processor(ProcessorMixin):
         self,
         images: ImageInput = None,
         text: Union[TextInput, List[TextInput]] = None,
-        bboxes: BboxInput = None,
-        num_image_tokens: Optional[int] = 64,
-        first_image_token_id: Optional[int] = None,
-        add_special_tokens: bool = True,
-        add_eos_token: bool = False,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Kosmos2ProcessorKwargs],
     ) -> BatchFeature:
         """
         This method uses [`CLIPImageProcessor.__call__`] method to prepare image(s) for the model, and
@@ -145,10 +164,25 @@
         if images is None and text is None:
             raise ValueError("You have to specify either images or text.")

+        output_kwargs = self._merge_kwargs(
+            Kosmos2ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        bboxes = output_kwargs["images_kwargs"].pop("bboxes", None)
+        num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens", 64)
+        first_image_token_id = output_kwargs["images_kwargs"].pop("first_image_token_id", None)
+        add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
+
+        add_special_tokens = output_kwargs["text_kwargs"]["add_special_tokens"]
+        padding = output_kwargs["text_kwargs"]["padding"]
+        return_tensors = output_kwargs["text_kwargs"].setdefault("return_tensors", None)
+
         encoding = BatchFeature()

         if images is not None:
-            image_encoding = self.image_processor(images, return_tensors=return_tensors)
+            image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
             encoding.update(image_encoding)

         if text is not None:
@@ -159,21 +193,18 @@
                 text = f"{self.tokenizer.bos_token}{text}"
             elif isinstance(text, list):
                 text = [f"{self.tokenizer.bos_token}{s}" for s in text]

-            text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=(add_special_tokens and add_eos_token),
-                padding=padding and images is None,
-                truncation=truncation,
-                max_length=max_length,
-                pad_to_multiple_of=pad_to_multiple_of if images is None else pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                verbose=verbose,
-                return_tensors=return_tensors if images is None else None,
-                **kwargs,
-            )
+            output_kwargs["text_kwargs"]["add_special_tokens"] = (
+                output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token
+            )
+            output_kwargs["text_kwargs"]["padding"] = padding if images is None else False
+            output_kwargs["text_kwargs"]["return_tensors"] = return_tensors if images is None else None
+            text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
             encoding.update(text_encoding)
+            output_kwargs["text_kwargs"]["add_special_tokens"] = add_special_tokens
+            output_kwargs["text_kwargs"]["padding"] = padding
+            output_kwargs["text_kwargs"]["return_tensors"] = return_tensors

         if text is not None and images is not None:
             # Use the id of the first token after <unk>
             if first_image_token_id is None:
@@ -218,18 +249,12 @@
             )
             _, min_len_not_padded = sorted_length[0]
             idx, _ = sorted_length[-1]

-            text_encoding = self.tokenizer(
-                text=[text[idx]],
-                add_special_tokens=(add_special_tokens and add_eos_token),
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                pad_to_multiple_of=pad_to_multiple_of,
-                verbose=verbose,
-                return_tensors=None,
-                **kwargs,
-            )
+            output_kwargs["text_kwargs"]["add_special_tokens"] = (
+                output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token
+            )
+            output_kwargs["text_kwargs"]["return_tensors"] = None
+            text_encoding = self.tokenizer(text=[text[idx]], **output_kwargs["text_kwargs"])
             max_len_padded = len(text_encoding.input_ids[0])

             if min_len_not_padded != max_len_padded:
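The model-specific options (`bboxes`, `num_image_tokens`, `first_image_token_id`, `add_eos_token`) are now declared on `Kosmos2ImagesKwargs`/`Kosmos2TextKwargs` and passed as plain keyword arguments; `_merge_kwargs` routes them to the right group before they are popped above. A usage sketch (checkpoint name and prompt format are assumptions for illustration):

```py
import requests
from PIL import Image

from transformers import AutoProcessor

# checkpoint name assumed for illustration
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# bboxes / num_image_tokens / add_eos_token used to be dedicated parameters of
# __call__; they are now typed kwargs routed through Kosmos2ProcessorKwargs
inputs = processor(
    images=image,
    text="<grounding> An image of",
    num_image_tokens=64,
    add_eos_token=False,
    return_tensors="pt",
)
```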

View File

@@ -715,7 +715,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
             image_patches = self.get_image_patches(
                 image,
                 image_grid_pinpoints,
-                size=(size["shortest_edge"], size["shortest_edge"]),
+                size=(size["shortest_edge"], size["shortest_edge"])
+                if "shortest_edge" in size
+                else (min(size["height"], size["width"]), min(size["height"], size["width"])),
                 patch_size=crop_size["height"],
                 resample=resample,
                 data_format=input_data_format,
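The new `size` handling is equivalent to the following small helper (an illustrative restatement of the expression above, not a function in the library):

```py
def resize_target(size: dict) -> tuple:
    # prefer a square target based on "shortest_edge"; otherwise fall back to the
    # smaller of an explicit "height"/"width" pair
    if "shortest_edge" in size:
        return (size["shortest_edge"], size["shortest_edge"])
    return (min(size["height"], size["width"]), min(size["height"], size["width"]))
```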

View File

@@ -763,7 +763,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

-       >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+       >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)

View File

@@ -16,19 +16,30 @@
 Processor class for LLaVa-NeXT.
 """

-from typing import List, Optional, Union
+from typing import List, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_processing_utils import select_best_resolution
 from ...image_utils import ImageInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging

 logger = logging.get_logger(__name__)

+
+class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {
+            "do_pad": True,
+        },
+    }
+
+
 class LlavaNextProcessor(ProcessorMixin):
     r"""
     Constructs a LLaVa-NeXT processor which wraps a LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.
@@ -74,13 +85,11 @@ class LlavaNextProcessor(ProcessorMixin):
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
         images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        do_pad: Optional[bool] = True,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[LlavaNextProcessorKwargs],
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -90,36 +99,13 @@
         of the above two methods for more information.

         Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
             text (`str`, `List[str]`, `List[List[str]]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            do_pad (`bool`, *optional*, defaults to self.do_pad):
-                Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
-                and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.

         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -130,8 +116,18 @@ class LlavaNextProcessor(ProcessorMixin):
             `None`).
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
+        if images is None and text is None:
+            raise ValueError("You have to specify at least images or text.")
+
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            LlavaNextProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         if images is not None:
-            image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         else:
             image_inputs = {}
@@ -164,13 +160,7 @@ class LlavaNextProcessor(ProcessorMixin):
             prompt_strings.append(sample)
         prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]

-        text_inputs = self.tokenizer(
-            prompt_strings,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])

         return BatchFeature(data={**text_inputs, **image_inputs})
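Defaults declared in `LlavaNextProcessorKwargs` (`do_pad=True` for images, `padding=False` for text) can be overridden per call; a usage sketch (checkpoint and image URL from the docs above, prompt format assumed):

```py
import requests
from PIL import Image

from transformers import LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"  # prompt format assumed

# do_pad is routed to the image processor, padding to the tokenizer
inputs = processor(images=image, text=prompt, do_pad=False, padding=True, return_tensors="pt")
```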

View File

@@ -18,9 +18,34 @@ Processor class for Pix2Struct.
 from typing import List, Optional, Union

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class Pix2StructImagesKwargs(ImagesKwargs, total=False):
+    max_patches: Optional[int]
+    header_text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
+
+
+class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Pix2StructImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {
+            "max_patches": 2048,
+        },
+    }


 class Pix2StructProcessor(ProcessorMixin):
@@ -50,23 +75,10 @@ class Pix2StructProcessor(ProcessorMixin):
         self,
         images=None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        max_patches: Optional[int] = 2048,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_token_type_ids: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        **kwargs,
-    ) -> BatchEncoding:
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[Pix2StructProcessorKwargs],
+    ) -> Union[BatchEncoding, BatchFeature]:
         """
         This method uses [`Pix2StructImageProcessor.preprocess`] method to prepare image(s) for the model, and
         [`T5TokenizerFast.__call__`] to prepare text for the model.
@@ -76,59 +88,27 @@
         if images is None and text is None:
             raise ValueError("You have to specify either images or text.")

+        output_kwargs = self._merge_kwargs(
+            Pix2StructProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         # Get only text
         if images is None and not self.image_processor.is_vqa:
             self.current_processor = self.tokenizer
-            text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
-            )
+            text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
             return text_encoding

         if not self.image_processor.is_vqa:
             # add pixel_values
-            encoding_image_processor = self.image_processor(
-                images, return_tensors=return_tensors, max_patches=max_patches, **kwargs
-            )
+            encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
         else:
             # add pixel_values and bbox
-            encoding_image_processor = self.image_processor(
-                images, return_tensors=return_tensors, max_patches=max_patches, header_text=text, **kwargs
-            )
+            output_kwargs["images_kwargs"].setdefault("header_text", text)
+            encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])

         if text is not None and not self.image_processor.is_vqa:
-            text_encoding = self.tokenizer(
-                text=text,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_token_type_ids=return_token_type_ids,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
-                **kwargs,
-            )
+            text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])

             if "attention_mask" in text_encoding:
                 text_encoding["decoder_attention_mask"] = text_encoding.pop("attention_mask")
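`max_patches` (default 2048) and, for VQA checkpoints, `header_text` are now declared on `Pix2StructImagesKwargs` and forwarded to the image processor. A usage sketch (checkpoint name assumed for illustration):

```py
import requests
from PIL import Image

from transformers import Pix2StructProcessor

# checkpoint name assumed for illustration
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# max_patches overrides the 2048 default and is routed to the image processor
inputs = processor(images=image, text="A picture of", max_patches=1024, return_tensors="pt")
```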

View File

@@ -626,7 +626,7 @@ class AlignModelIntegrationTest(unittest.TestCase):
         image = prepare_img()
         texts = ["a photo of a cat", "a photo of a dog"]
-        inputs = processor(text=texts, images=image, return_tensors="pt").to(torch_device)
+        inputs = processor(images=image, text=texts, return_tensors="pt").to(torch_device)

         # forward pass
         with torch.no_grad():

View File

@@ -330,7 +330,7 @@ class FuyuModelIntegrationTest(unittest.TestCase):
         text_prompt_coco_captioning = "Generate a coco-style caption.\n"

-        inputs = processor(text=text_prompt_coco_captioning, images=image, return_tensors="pt")
+        inputs = processor(images=image, text=text_prompt_coco_captioning, return_tensors="pt")
         generated_ids = model.generate(**inputs, max_new_tokens=10)

         # take the last 8 tokens (in order to skip special \n\x04 characters) and decode them

View File

@@ -1,17 +1,25 @@
 import io
+import tempfile
 import unittest

 import requests

-from transformers import AutoTokenizer, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_gpu, slow
+from transformers import (
+    AutoProcessor,
+    AutoTokenizer,
+    FuyuImageProcessor,
+    FuyuProcessor,
+    is_torch_available,
+    is_vision_available,
+)
+from transformers.testing_utils import require_torch, require_vision
+
+from ...test_processing_common import ProcessorTesterMixin

 if is_vision_available():
     from PIL import Image

-if is_vision_available() and is_torch_available():
-    from transformers import FuyuImageProcessor, FuyuProcessor
-
 if is_torch_available():
     import torch
@@ -20,21 +28,36 @@ if is_torch_available():
 @require_torch
-@require_torch_gpu
-@slow
-class FuyuProcessingTest(unittest.TestCase):  # TODO Which mixins do we add here?
+@require_vision
+class FuyuProcessingTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = FuyuProcessor
     """ """

     def setUp(self):
-        pretrained_model_name = "adept/fuyu-8b"
-        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
-        self.image_processor = FuyuImageProcessor()
+        self.tmpdirname = tempfile.mkdtemp()

-        self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer)
+        image_processor = FuyuImageProcessor()
+        tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b")
+
+        processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
+        processor.save_pretrained(self.tmpdirname)
+
         self.text_prompt = "Generate a coco-style caption.\\n"
         bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
         self.bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))

+    def get_processor(self):
+        image_processor = FuyuImageProcessor()
+        tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b")
+        processor = FuyuProcessor(image_processor, tokenizer, **self.prepare_processor_dict())
+        return processor
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
     def test_fuyu_processing(self):
         """
         Test to ensure that the standard processing on a gold example matches adept's code.
@@ -43,7 +66,7 @@ class FuyuProcessingTest(unittest.TestCase):  # TODO Which mixins do we add here
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64) EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64) EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 
71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)
one_image_bus_model_inputs = self.processor(text=self.text_prompt, images=self.bus_image_pil) one_image_bus_model_inputs = self.get_processor()(text=self.text_prompt, images=self.bus_image_pil)
# fmt: on # fmt: on
torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS) torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS)
@ -53,8 +76,8 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
""" """
Test to check processor works with just text input Test to check processor works with just text input
""" """
processor_outputs = self.processor(text=self.text_prompt) processor_outputs = self.get_processor()(text=self.text_prompt)
tokenizer_outputs = self.tokenizer(self.text_prompt) tokenizer_outputs = self.get_tokenizer()(self.text_prompt)
self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"]) self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"])
def test_fuyu_processing_no_text(self): def test_fuyu_processing_no_text(self):
@ -90,7 +113,7 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
]).to(torch.int64) ]).to(torch.int64)
# fmt: on # fmt: on
processor_outputs = self.processor(images=self.bus_image_pil) processor_outputs = self.get_processor()(images=self.bus_image_pil)
self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all()) self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all())
def test_fuyu_processing_multiple_image_sample(self): def test_fuyu_processing_multiple_image_sample(self):
@ -107,7 +130,7 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
# Batch of two images - equally sized # Batch of two images - equally sized
images = [self.bus_image_pil, self.bus_image_pil] images = [self.bus_image_pil, self.bus_image_pil]
processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images) processor_outputs = self.get_processor()(text=[self.text_prompt, self.text_prompt], images=images)
self.assertTrue( self.assertTrue(
( (
@ -124,18 +147,18 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
# Processes single images with different sizes as expected # Processes single images with different sizes as expected
images = [self.bus_image_pil] images = [self.bus_image_pil]
processor_outputs = self.processor(text=self.text_prompt, images=images) processor_outputs = self.get_processor()(text=self.text_prompt, images=images)
self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all()) self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all())
self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all()) self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all())
images = [self.bus_image_pil.resize((64, 300))] images = [self.bus_image_pil.resize((64, 300))]
processor_outputs = self.processor(text=self.text_prompt, images=images) processor_outputs = self.get_processor()(text=self.text_prompt, images=images)
self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all()) self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all())
self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all()) self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all())
# Batch of two images - different sizes. Left-pads the smaller image inputs # Batch of two images - different sizes. Left-pads the smaller image inputs
images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))] images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))]
processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images) processor_outputs = self.get_processor()(text=[self.text_prompt, self.text_prompt], images=images)
padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1] padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1]
padded_single_resized_image_patch = torch.cat( padded_single_resized_image_patch = torch.cat(
@ -156,6 +179,155 @@ class FuyuProcessingTest(unittest.TestCase): # TODO Which mixins do we add here
self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all()) self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all())
self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all()) self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all())
# Rewrite as Fuyu supports tokenizer kwargs only when image is None.
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# Fuyu uses tokenizer kwargs only when image is None.
image_input = None
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
@unittest.skip("Fuyu processor does not support image_processor kwargs")
def test_image_processor_defaults_preserved_by_image_kwargs(self):
pass
@unittest.skip("Fuyu processor does not support image_processor kwargs")
def test_kwargs_overrides_default_image_processor_kwargs(self):
pass
# Rewrite as Fuyu supports tokenizer kwargs only when image is None.
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# Fuyu uses tokenizer kwargs only when image is None.
image_input = None
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
# Rewrite as Fuyu image processor does not return pixel values
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# Fuyu uses tokenizer kwargs only when image is None.
image_input = None
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Rewrite as Fuyu image processor does not return pixel values
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# Fuyu uses tokenizer kwargs only when image is None.
image_input = None
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Rewrite as Fuyu supports tokenizer kwargs only when image is None.
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# Fuyu uses tokenizer kwargs only when image is None.
image_input = None
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
max_length=76,
)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Rewrite as Fuyu supports tokenizer kwargs only when image is None.
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
# Fuyu uses tokenizer kwargs only when image is None.
image_input = None
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="longest",
max_length=76,
)
self.assertEqual(len(inputs["input_ids"][0]), 6)
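For reference, a minimal usage sketch of the behaviour pinned down above (not part of this diff), assuming the public adept/fuyu-8b checkpoint: flat tokenizer kwargs such as padding and max_length only take effect when no image is passed.
from transformers import FuyuProcessor
processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
# Text-only call: tokenizer kwargs are forwarded to the tokenizer and honoured.
text_only = processor(
    text="lower newer",
    images=None,
    return_tensors="pt",
    padding="max_length",
    max_length=76,
)
print(text_only["input_ids"].shape)  # expected (1, 76)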
@require_torch @require_torch
class TestImageTextProcessingUtils(unittest.TestCase): class TestImageTextProcessingUtils(unittest.TestCase):

View File

@ -17,7 +17,7 @@ import unittest
import pytest import pytest
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin from ...test_processing_common import ProcessorTesterMixin
@ -179,261 +179,3 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
list(inputs.keys()), list(inputs.keys()),
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"], ["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
) )
# Override as InstructBlipProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", padding="longest")
qformer_tokenizer = self.get_component("qformer_tokenizer", padding="longest")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 6)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_doubly_passed_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
images_kwargs={"size": {"height": 222, "width": 222}},
size={"height": 214, "width": 214},
)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
def test_overlapping_text_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_kwargs = {}
processor_kwargs["image_processor"] = self.get_component("image_processor")
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
if "video_processor" in self.processor_class.attributes:
processor_kwargs["video_processor"] = self.get_component("video_processor")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
text_kwargs={"padding": "do_not_pad"},
)
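A rough sketch of the uniformized call the removed overrides exercised (illustrative, not from this diff), assuming the public Salesforce/instructblip-vicuna-7b checkpoint; per-modality options can be nested or passed flat, and one call drives the tokenizer, the qformer tokenizer and the image processor.
import numpy as np
from PIL import Image
from transformers import InstructBlipProcessor
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
inputs = processor(
    text="lower newer",
    images=image,
    common_kwargs={"return_tensors": "pt"},
    images_kwargs={"size": {"height": 214, "width": 214}},
    text_kwargs={"padding": "max_length", "max_length": 76},
)
print(list(inputs.keys()))
# e.g. ['input_ids', 'attention_mask', 'qformer_input_ids', 'qformer_attention_mask', 'pixel_values']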

View File

@ -15,18 +15,15 @@ import shutil
import tempfile import tempfile
import unittest import unittest
import numpy as np
import pytest import pytest
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin from ...test_processing_common import ProcessorTesterMixin
if is_vision_available(): if is_vision_available():
from PIL import Image
from transformers import ( from transformers import (
AutoProcessor, AutoProcessor,
BertTokenizerFast, BertTokenizerFast,
@ -65,16 +62,6 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
# Ignore copy
def prepare_image_inputs(self):
"""This function prepares a list of list of PIL images"""
video_inputs = [
[Image.fromarray(np.random.randint(255, size=(30, 400, 3), dtype=np.uint8)) for _ in range(5)]
for _ in range(2)
]
return video_inputs
def test_save_load_pretrained_additional_features(self): def test_save_load_pretrained_additional_features(self):
processor = InstructBlipVideoProcessor( processor = InstructBlipVideoProcessor(
tokenizer=self.get_tokenizer(), tokenizer=self.get_tokenizer(),
@ -193,261 +180,3 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
list(inputs.keys()), list(inputs.keys()),
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"], ["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
) )
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", padding="longest")
qformer_tokenizer = self.get_component("qformer_tokenizer", padding="longest")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 6)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_doubly_passed_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
images_kwargs={"size": {"height": 222, "width": 222}},
size={"height": 214, "width": 214},
)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
def test_overlapping_text_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_kwargs = {}
processor_kwargs["image_processor"] = self.get_component("image_processor")
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
if "video_processor" in self.processor_class.attributes:
processor_kwargs["video_processor"] = self.get_component("video_processor")
qformer_tokenizer = self.get_component("qformer_tokenizer")
processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
text_kwargs={"padding": "do_not_pad"},
)

View File

@ -61,7 +61,7 @@ class Kosmos2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
image_processor = CLIPImageProcessor() image_processor = CLIPImageProcessor(do_center_crop=False)
# We have a SentencePiece fixture for testing # We have a SentencePiece fixture for testing
slow_tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB) slow_tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB)
@ -487,3 +487,147 @@ class Kosmos2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(outputs.input_ids.numpy().tolist()[-1], EXPECTED_IDS_BATCH[-1]) self.assertListEqual(outputs.input_ids.numpy().tolist()[-1], EXPECTED_IDS_BATCH[-1])
self.assertListEqual(outputs.attention_mask.numpy().tolist()[-1], EXPECTED_MASK_BATCH[-1]) self.assertListEqual(outputs.attention_mask.numpy().tolist()[-1], EXPECTED_MASK_BATCH[-1])
self.assertListEqual(outputs.image_embeds_position_mask.numpy().tolist()[-1], EXPECTED_IMG_POS_MASK_BATCH[-1]) self.assertListEqual(outputs.image_embeds_position_mask.numpy().tolist()[-1], EXPECTED_IMG_POS_MASK_BATCH[-1])
# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# set image input to None
image_input = None
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_length=112,
padding="max_length",
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
# Rewrite to test only image_processor kwargs
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
# Rewrite to test only image_processor kwargs
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# set image input to None
image_input = None
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# set image input to None
image_input = None
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
max_length=76,
)
self.assertEqual(len(inputs["input_ids"][0]), 76)
# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
# set image input to None
image_input = None
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(len(inputs["input_ids"][0]), 10)
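A minimal sketch of the convention these rewrites encode (illustrative, not from this diff), assuming the public microsoft/kosmos-2-patch14-224 checkpoint: custom padding/max_length only take effect when images is None.
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
# Text-only call: tokenizer kwargs are applied as-is.
text_only = processor(
    text="lower newer",
    images=None,
    return_tensors="pt",
    padding="max_length",
    max_length=112,
)
print(text_only["input_ids"].shape)  # expected (1, 112)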

View File

@ -338,7 +338,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
load_in_4bit=True, load_in_4bit=True,
) )
inputs = self.processor(self.prompt, self.image, return_tensors="pt") inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")
# verify inputs against original implementation # verify inputs against original implementation
filepath = hf_hub_download( filepath = hf_hub_download(
@ -390,8 +390,8 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
cats_image = Image.open(requests.get(url, stream=True).raw) cats_image = Image.open(requests.get(url, stream=True).raw)
inputs = self.processor( inputs = self.processor(
[self.prompt, self.prompt],
images=[self.image, cats_image], images=[self.image, cats_image],
text=[self.prompt, self.prompt],
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
).to(torch_device) ).to(torch_device)
@ -415,7 +415,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
) )
prompt_with_unk = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]" prompt_with_unk = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]"
inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt") inputs = self.processor(images=self.image, text=prompt_with_unk, return_tensors="pt")
# verify single forward pass # verify single forward pass
inputs = inputs.to(torch_device) inputs = inputs.to(torch_device)
@ -445,7 +445,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw) lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
inputs = self.processor( inputs = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
).to(torch_device) ).to(torch_device)
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
@ -498,10 +498,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw) lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
inputs_batched = self.processor( inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
).to(torch_device) ).to(torch_device)
inputs_single = self.processor(self.prompt, images=lowres_img, return_tensors="pt", padding=True).to( inputs_single = self.processor(images=lowres_img, text=self.prompt, return_tensors="pt", padding=True).to(
torch_device torch_device
) )
@ -527,7 +527,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw) lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
inputs_batched = self.processor( inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
).to(torch_device) ).to(torch_device)
# model is in eval mode by default so we should get pad on the left side # model is in eval mode by default so we should get pad on the left side
@ -607,13 +607,13 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs # check processing with expansion of inputs
processor.vision_feature_select_strategy = "default" processor.vision_feature_select_strategy = "default"
processor.patch_size = 14 processor.patch_size = 14
inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2356) self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2356)
# check processing without expansion of inputs (legacy behavior) # check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None processor.vision_feature_select_strategy = None
processor.patch_size = None processor.patch_size = None
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 17) self.assertTrue(inputs.input_ids.shape[-1] == 17)
# generate exactly 20 tokens # generate exactly 20 tokens
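The call sites above switch from positional (prompt, image) arguments to explicit keywords; a short sketch of the updated convention (illustrative, not from this diff), assuming the public llava-hf/llava-v1.6-mistral-7b-hf checkpoint.
import requests
from PIL import Image
from transformers import LlavaNextProcessor
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
# Keyword arguments make the uniformized (images, text) order explicit.
inputs = processor(images=image, text=prompt, return_tensors="pt", padding=True)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)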

View File

@ -18,7 +18,9 @@ import unittest
import torch import torch
from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
from transformers.testing_utils import require_vision from transformers.testing_utils import (
require_vision,
)
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin from ...test_processing_common import ProcessorTesterMixin

View File

@ -37,6 +37,8 @@ if is_vision_available():
@require_torch @require_torch
class Pix2StructProcessorTest(ProcessorTesterMixin, unittest.TestCase): class Pix2StructProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Pix2StructProcessor processor_class = Pix2StructProcessor
text_input_name = "decoder_input_ids"
images_input_name = "flattened_patches"
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
@ -180,3 +182,148 @@ class Pix2StructProcessorTest(ProcessorTesterMixin, unittest.TestCase):
# For now the processor supports only ["flattened_patches", "input_ids", "attention_mask", "decoder_attention_mask"] # For now the processor supports only ["flattened_patches", "input_ids", "attention_mask", "decoder_attention_mask"]
self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"]) self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
# Rewrite as the Pix2Struct processor returns "flattened_patches" and not "pixel_values"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", max_patches=1024, patch_size={"height": 8, "width": 8})
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["flattened_patches"][0][0]), 194)
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
# Rewrite as the Pix2Struct processor returns "flattened_patches" and not "pixel_values"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", max_patches=4096)
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, max_patches=1024)
self.assertEqual(len(inputs["flattened_patches"][0]), 1024)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
# Rewrite as the Pix2Struct processor returns "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_patches=1024,
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
# Rewrite as the Pix2Struct processor returns "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_patches=1024,
padding="longest",
max_length=76,
)
self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
self.assertEqual(len(inputs["decoder_input_ids"][0]), 5)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
# Rewrite as the Pix2Struct processor returns "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"max_patches": 1024},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
# Rewrite as the Pix2Struct processor returns "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"max_patches": 1024},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)
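A brief sketch of the output names these rewrites rely on (illustrative, not from this diff), assuming the public google/pix2struct-textcaps-base checkpoint: when an image is present the text is tokenized into decoder_input_ids, the image into flattened_patches, and max_patches is routed to the image processor.
import numpy as np
from PIL import Image
from transformers import Pix2StructProcessor
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
inputs = processor(
    text="lower newer",
    images=image,
    return_tensors="pt",
    max_patches=1024,
    padding="max_length",
    max_length=76,
)
print(inputs["flattened_patches"].shape)  # expected (1, 1024, patch_dim)
print(inputs["decoder_input_ids"].shape)  # expected (1, 76)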

View File

@ -18,9 +18,7 @@ import unittest
import requests import requests
import torch import torch
from transformers.testing_utils import ( from transformers.testing_utils import require_vision
require_vision,
)
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin from ...test_processing_common import ProcessorTesterMixin

View File

@ -66,7 +66,7 @@ class ProcessorTesterMixin:
component_class = processor_class_from_name(component_class_name) component_class = processor_class_from_name(component_class_name)
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
if attribute == "tokenizer" and not component.pad_token: if "tokenizer" in attribute and not component.pad_token:
component.pad_token = "[TEST_PAD]" component.pad_token = "[TEST_PAD]"
if component.pad_token_id is None: if component.pad_token_id is None:
component.pad_token_id = 0 component.pad_token_id = 0
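The widened check means any attribute whose name contains "tokenizer" (e.g. qformer_tokenizer) now receives the test pad-token fallback. A toy sketch of that logic, using a stand-in object rather than a real tokenizer:
from types import SimpleNamespace
def ensure_pad_token(attribute, component):
    # Mirror of the widened condition: applies to tokenizer, qformer_tokenizer, ...
    if "tokenizer" in attribute and not component.pad_token:
        component.pad_token = "[TEST_PAD]"
        if component.pad_token_id is None:
            component.pad_token_id = 0
    return component
qformer_tokenizer = SimpleNamespace(pad_token=None, pad_token_id=None)
ensure_pad_token("qformer_tokenizer", qformer_tokenizer)
print(qformer_tokenizer.pad_token, qformer_tokenizer.pad_token_id)  # [TEST_PAD] 0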
@ -322,14 +322,8 @@ class ProcessorTesterMixin:
def test_overlapping_text_kwargs_handling(self): def test_overlapping_text_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes: if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}") self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_kwargs = {} processor_components = self.prepare_components()
processor_kwargs["image_processor"] = self.get_component("image_processor") processor = self.processor_class(**processor_components)
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
if "video_processor" in self.processor_class.attributes:
processor_kwargs["video_processor"] = self.get_component("video_processor")
processor = self.processor_class(**processor_kwargs)
self.skip_processor_without_typed_kwargs(processor) self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer" input_str = "lower newer"