Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)
Uniformize kwargs for image-text-to-text processors (#32544)
* uniformize FUYU processor kwargs
* Uniformize instructblip processor kwargs
* Fix processor kwargs and tests Fuyu, InstructBlip, Kosmos2
* Uniformize llava_next processor
* Fix save_load test for processor with chat_template only as extra init args
* Fix import Unpack
* Fix Fuyu Processor import
* Fix FuyuProcessor import
* Fix FuyuProcessor
* Add defaults for specific kwargs kosmos2
* Fix Udop to return BatchFeature instead of BatchEncoding and uniformize kwargs
* Add tests processor Udop
* remove Copied from in processing Udop as change of input orders caused by BatchEncoding -> BatchFeature
* Fix overwrite tests kwargs processors
* Add warnings and BC for changes in processor inputs order, change docs, add BC for text_pair as arg for Udop
* Fix processing test fuyu
* remove unnecessary pad_token check in instructblip ProcessorTest
* Fix BC tests and cleanup
* FIx imports fuyu
* Uniformize Pix2Struct
* Fix wrong name for FuyuProcessorKwargs
* Fix slow tests reversed inputs align fuyu llava-next, change udop warning
* Fix wrong logging import udop
* Add check images text input order
* Fix copies
* change text pair handling when positional arg
* rebase on main, fix imports in test_processing_common
* remove optional args and udop uniformization from this PR
* fix failing tests
* remove unnecessary test, fix processing utils and test processing common
* cleanup Unpack
* cleanup
* fix conflict grounding dino
This commit is contained in:
parent: fa0bb0fe76
commit: 5f0c181f4e
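Taken together, the hunks below all apply the same pattern: `images` moves ahead of `text` in every `__call__`, the long flat lists of tokenizer arguments are replaced by a typed `*ProcessorKwargs` class merged through `self._merge_kwargs`, and the old `(text, images)` order keeps working via a backward-compatibility check. A minimal before/after sketch of the user-facing call (illustrative only; `processor`, `prompt` and `image` stand in for any of the models touched below):

```py
# Before this PR: text-first order, extra options handled ad hoc per model
inputs = processor(text=prompt, images=image, return_tensors="pt")

# After this PR: images-first order; extra options are validated and routed through
# the model's ProcessorKwargs groups (text_kwargs / images_kwargs / common_kwargs)
inputs = processor(images=image, text=prompt, padding=True, return_tensors="pt")
```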
@@ -46,7 +46,7 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["an image of a cat", "an image of a dog"]

-inputs = processor(text=candidate_labels, images=image, return_tensors="pt")
+inputs = processor(images=image, text=candidate_labels, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

@@ -18,16 +18,16 @@ rendered properly in your Markdown viewer.

## Overview

The Fuyu model was created by [ADEPT](https://www.adept.ai/blog/fuyu-8b), and authored by Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar.

The authors introduced Fuyu-8B, a decoder-only multimodal model based on the classic transformers architecture, with query and key normalization. A linear encoder is added to create multimodal embeddings from image inputs.

By treating image tokens like text tokens and using a special image-newline character, the model knows when an image line ends. Image positional embeddings are removed. This avoids the need for different training phases for various image resolutions. With 8 billion parameters and licensed under CC-BY-NC, Fuyu-8B is notable for its ability to handle both text and images, its impressive context size of 16K, and its overall performance.

<Tip warning={true}>

The `Fuyu` models were trained using `bfloat16`, but the original inference uses `float16`. The checkpoints uploaded on the hub use `torch_dtype = 'float16'`, which will be used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.

The `dtype` of the online weights is mostly irrelevant, unless you are using `torch_dtype="auto"` when initializing a model with `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype="auto")`. The reason is that the model will first be downloaded (using the `dtype` of the checkpoints online) and then cast to the default `dtype` of `torch` (which becomes `torch.float32`). Users should specify the `torch_dtype` they want, and if they don't, it will be `torch.float32`.

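To make the `torch_dtype` behaviour described in the tip concrete, a minimal sketch (assuming the `adept/fuyu-8b` checkpoint, whose weights are stored in `float16` on the Hub):

```py
from transformers import AutoModelForCausalLM

# With torch_dtype="auto", the weights keep the dtype stored in the checkpoint (float16 here).
model_auto = AutoModelForCausalLM.from_pretrained("adept/fuyu-8b", torch_dtype="auto")
print(model_auto.dtype)  # torch.float16

# Without it, the weights are cast to the default torch dtype after download.
model_default = AutoModelForCausalLM.from_pretrained("adept/fuyu-8b")
print(model_default.dtype)  # torch.float32
```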
@@ -56,7 +56,7 @@ tar -xvf 8b_base_model_release.tar
```
Then the model can be loaded via:

```py
from transformers import FuyuConfig, FuyuForCausalLM
model_config = FuyuConfig()
model = FuyuForCausalLM(model_config).from_pretrained('/output/path')

@@ -81,7 +81,7 @@ text_prompt = "Generate a coco-style caption.\\n"

bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
-inputs_to_model = processor(text=text_prompt, images=bus_image_pil)
+inputs_to_model = processor(images=bus_image_pil, text=text_prompt)

```

@@ -90,7 +90,7 @@ This model was contributed by [Molbap](https://huggingface.co/Molbap).
The original code can be found [here](https://github.com/persimmon-ai-labs/adept-inference).

- Fuyu uses a `sentencepiece` based tokenizer, with a `Unigram` model. It supports bytefallback, which is only available in `tokenizers==0.14.0` for the fast tokenizer.
  The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece.

- The authors suggest using the following prompt for image captioning: `f"Generate a coco-style caption.\\n"`

@@ -133,7 +133,7 @@ import requests

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda:0")

# prepare image and text prompt, using the appropriate prompt template

@@ -150,7 +150,7 @@ conversation = [
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
-inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
+inputs = processor(image, prompt, return_tensors="pt").to("cuda:0")

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)

@@ -222,7 +222,7 @@ prompts = [prompt_1, prompt_2]

# We can simply feed images in the order they have to be used in the text prompt
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
-inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)
+inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)

@@ -266,8 +266,8 @@ First make sure to install flash-attn. Refer to the [original repository of Flas
from transformers import LlavaNextForConditionalGeneration

model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    use_flash_attention_2=True
).to(0)

@@ -1575,7 +1575,7 @@ class AlignModel(AlignPreTrainedModel):
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
-       ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
+       ...     images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)

@@ -19,11 +19,7 @@ Image/Text processor class for ALIGN
from typing import List, Union

from ...image_utils import ImageInput
-from ...processing_utils import (
-    ProcessingKwargs,
-    ProcessorMixin,
-    Unpack,
-)
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput

@@ -76,8 +72,8 @@ class AlignProcessor(ProcessorMixin):

    def __call__(
        self,
-       text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
+       text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[AlignProcessorKwargs],

@@ -90,13 +86,13 @@ class AlignProcessor(ProcessorMixin):
        to the docstring of the above two methods for more information.

        Args:
+           images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+               The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+               tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-           images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-               The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-               tensor. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.

@@ -114,6 +110,9 @@ class AlignProcessor(ProcessorMixin):
        """
        if text is None and images is None:
            raise ValueError("You must specify either text or images.")
+       # check if images and text inputs are reversed for BC
+       images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            AlignProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,

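For readers unfamiliar with the new helper, a simplified sketch of the kind of check `_validate_images_text_input_order` performs (illustrative only; the real implementation lives in `processing_utils.py` and handles more input types):

```py
import warnings


def swap_if_reversed(images, text):
    # Hypothetical stand-in for _validate_images_text_input_order, for illustration only.
    def looks_like_text(x):
        return isinstance(x, str) or (isinstance(x, list) and len(x) > 0 and isinstance(x[0], str))

    # If the first argument looks like text and the second does not, the caller most likely
    # used the old (text, images) order; swap and warn instead of failing.
    if looks_like_text(images) and not looks_like_text(text):
        warnings.warn("You may have passed (text, images) positionally; swapping for backward compatibility.")
        return text, images
    return images, text
```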
@@ -265,7 +265,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> prompt = "Generate a coco-style caption.\n"

-       >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+       >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=7)

@@ -21,9 +21,10 @@ from typing import Dict, List, Optional, Tuple, Union

import numpy as np

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
-from ...utils import TensorType, is_torch_available, logging, requires_backends
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_available, logging, requires_backends


if is_torch_available():

@@ -49,6 +50,24 @@ TOKEN_POINT_CLOSE_STRING = "<0x03>"  # </point>
BEGINNING_OF_ANSWER_STRING = "<0x04>"  # <boa>


+class FuyuProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_attention_mask": True,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
+
+
def full_unpacked_stream_to_tensor(
    all_bi_tokens_to_place: List[int],
    full_unpacked_stream: List["torch.Tensor"],

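Once merged with the tokenizer's init kwargs and the call-time kwargs, these typed defaults surface as ordinary keyword arguments on the processor call. A short usage sketch (assuming the `adept/fuyu-8b` checkpoint; note from the tests further down that Fuyu only honours tokenizer kwargs when no image is passed):

```py
from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")

# Tokenizer options are plain kwargs on the processor; _merge_kwargs routes them
# into text_kwargs on top of the _defaults declared above.
inputs = processor(
    text="Generate a coco-style caption.\n",
    padding="max_length",
    max_length=64,
    return_tensors="pt",
)
print(inputs["input_ids"].shape)  # (1, 64)
```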
@@ -452,23 +471,11 @@ class FuyuProcessor(ProcessorMixin):

    def __call__(
        self,
-       text=None,
-       images=None,
-       add_special_tokens: bool = True,
-       return_attention_mask: bool = True,
-       padding: Union[bool, str, PaddingStrategy] = False,
-       truncation: Union[bool, str, TruncationStrategy] = None,
-       max_length: Optional[int] = None,
-       stride: int = 0,
-       pad_to_multiple_of: Optional[int] = None,
-       return_overflowing_tokens: bool = False,
-       return_special_tokens_mask: bool = False,
-       return_offsets_mapping: bool = False,
-       return_token_type_ids: bool = False,
-       return_length: bool = False,
-       verbose: bool = True,
-       return_tensors: Optional[Union[str, TensorType]] = None,
-       **kwargs,
+       images: ImageInput = None,
+       text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+       audio=None,
+       videos=None,
+       **kwargs: Unpack[FuyuProcessorKwargs],
    ) -> "FuyuBatchFeature":
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`

@@ -478,13 +485,13 @@
        of the above two methods for more information.

        Args:
+           images (`PIL.Image.Image`, `List[PIL.Image.Image]`):
+               The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+               tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-           images (`PIL.Image.Image`, `List[PIL.Image.Image]`):
-               The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-               tensor. Both channels-first and channels-last formats are supported.

        Returns:
            [`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields:

@@ -498,31 +505,24 @@
        requires_backends(self, ["torch"])

        # --- Check input validity ---
-       if not return_attention_mask:
-           raise ValueError("`return_attention_mask=False` is not supported for this model.")
        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be None.")
+       # check if images and text inputs are reversed for BC
+       images, text = _validate_images_text_input_order(images, text)
+
+       output_kwargs = self._merge_kwargs(
+           FuyuProcessorKwargs,
+           tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+           **kwargs,
+       )
+
+       if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
+           raise ValueError("`return_attention_mask=False` is not supported for this model.")
+
        if text is not None and images is None:
            logger.warning("You are processing a text with no associated image. Make sure it is intended.")
            self.current_processor = self.tokenizer
-           text_encoding = self.tokenizer(
-               text=text,
-               add_special_tokens=add_special_tokens,
-               padding=padding,
-               truncation=truncation,
-               max_length=max_length,
-               stride=stride,
-               pad_to_multiple_of=pad_to_multiple_of,
-               return_attention_mask=return_attention_mask,
-               return_overflowing_tokens=return_overflowing_tokens,
-               return_special_tokens_mask=return_special_tokens_mask,
-               return_offsets_mapping=return_offsets_mapping,
-               return_token_type_ids=return_token_type_ids,
-               return_length=return_length,
-               verbose=verbose,
-               return_tensors=return_tensors,
-               **kwargs,
-           )
+           text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
            return text_encoding

        if text is None and images is not None:

@@ -537,7 +537,8 @@
        # --- Preprocess images using self.image_processor ---

        # FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
-       image_encoding = self.image_processor.preprocess(images, return_tensors="pt")
+       output_kwargs["images_kwargs"]["return_tensors"] = "pt"
+       image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
        batch_images = image_encoding["images"]
        image_unpadded_heights = image_encoding["image_unpadded_heights"]
        image_unpadded_widths = image_encoding["image_unpadded_widths"]

@@ -568,7 +569,7 @@
            )
            all_encodings.append(sample_encoding)
        batch_encoding = self._left_pad_inputs_with_attention_mask(
-           model_inputs=all_encodings, return_attention_mask=return_attention_mask
+           model_inputs=all_encodings, return_attention_mask=True
        )
        return FuyuBatchFeature(data=batch_encoding)

@@ -17,26 +17,41 @@ Processor class for InstructBLIP. Largely copy of Blip2Processor with addition o
"""

import os
-from typing import List, Optional, Union
+from typing import List, Union

from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import (
    AddedToken,
    BatchEncoding,
-   PaddingStrategy,
    PreTokenizedInput,
    TextInput,
-   TruncationStrategy,
)
-from ...utils import TensorType, logging
+from ...utils import logging
from ..auto import AutoTokenizer


logger = logging.get_logger(__name__)


+class InstructBlipProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {},
+    }
+
+
class InstructBlipProcessor(ProcessorMixin):
    r"""
    Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single

@@ -72,31 +87,33 @@ class InstructBlipProcessor(ProcessorMixin):
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-       add_special_tokens: bool = True,
-       padding: Union[bool, str, PaddingStrategy] = False,
-       truncation: Union[bool, str, TruncationStrategy] = None,
-       max_length: Optional[int] = None,
-       stride: int = 0,
-       pad_to_multiple_of: Optional[int] = None,
-       return_attention_mask: Optional[bool] = None,
-       return_overflowing_tokens: bool = False,
-       return_special_tokens_mask: bool = False,
-       return_offsets_mapping: bool = False,
-       return_token_type_ids: bool = False,
-       return_length: bool = False,
-       verbose: bool = True,
-       return_tensors: Optional[Union[str, TensorType]] = None,
-       **kwargs,
+       audio=None,
+       videos=None,
+       **kwargs: Unpack[InstructBlipProcessorKwargs],
    ) -> BatchFeature:
        """
        This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.
+       Args:
+           images (`ImageInput`):
+               The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+               tensor. Both channels-first and channels-last formats are supported.
+           text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
+               The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+               (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+               `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least images or text.")

+       output_kwargs = self._merge_kwargs(
+           InstructBlipProcessorKwargs,
+           tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+           **kwargs,
+       )
+
        encoding = BatchFeature()

        if text is not None:

@@ -105,24 +122,7 @@
            elif not isinstance(text, list) and not isinstance(text[0], str):
                raise ValueError("Invalid input text. Please provide a string, or a list of strings")

-           _text_encoding = self.tokenizer(
-               text=text,
-               add_special_tokens=add_special_tokens,
-               padding=padding,
-               truncation=truncation,
-               max_length=max_length,
-               stride=stride,
-               pad_to_multiple_of=pad_to_multiple_of,
-               return_attention_mask=return_attention_mask,
-               return_overflowing_tokens=return_overflowing_tokens,
-               return_special_tokens_mask=return_special_tokens_mask,
-               return_offsets_mapping=return_offsets_mapping,
-               return_token_type_ids=return_token_type_ids,
-               return_length=return_length,
-               verbose=verbose,
-               return_tensors=None,  # needed to concatenate below
-               **kwargs,
-           )
+           _text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])

            # if we know how many query tokens, expand text inside processor. We need this hacky manipulation
            # because BLIP expects image tokens to be at the beginning even before BOS token

@@ -145,31 +145,17 @@
            )

            # cast to desired return tensors type after concatenating
-           text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
-           encoding.update(text_encoding)
-           qformer_text_encoding = self.qformer_tokenizer(
-               text=text,
-               add_special_tokens=add_special_tokens,
-               padding=padding,
-               truncation=truncation,
-               max_length=max_length,
-               stride=stride,
-               pad_to_multiple_of=pad_to_multiple_of,
-               return_attention_mask=return_attention_mask,
-               return_overflowing_tokens=return_overflowing_tokens,
-               return_special_tokens_mask=return_special_tokens_mask,
-               return_offsets_mapping=return_offsets_mapping,
-               return_token_type_ids=return_token_type_ids,
-               return_length=return_length,
-               verbose=verbose,
-               return_tensors=return_tensors,
-               **kwargs,
+           text_encoding = BatchEncoding(
+               text_encoding, tensor_type=output_kwargs["common_kwargs"].get("return_tensors")
            )
+
+           encoding.update(text_encoding)
+           qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
            encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

        if images is not None:
-           image_encoding = self.image_processor(images, return_tensors=return_tensors)
+           image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
            encoding.update(image_encoding)

        return encoding

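As a usage sketch of the uniformized InstructBLIP call (assuming the `Salesforce/instructblip-vicuna-7b` checkpoint; both the main tokenizer and the Q-Former tokenizer now consume the same merged `text_kwargs`):

```py
import requests
from PIL import Image
from transformers import InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text="What is shown in this picture?", padding=True, return_tensors="pt")
# Expect input_ids / attention_mask, qformer_input_ids / qformer_attention_mask, pixel_values
print(sorted(inputs.keys()))
```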
@@ -21,10 +21,9 @@ from typing import List, Optional, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput, is_batched
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from ...tokenization_utils import AddedToken
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...tokenization_utils_base import BatchEncoding, TextInput


BboxInput = Union[

@@ -35,6 +34,37 @@ BboxInput = Union[
]


+class Kosmos2ImagesKwargs(ImagesKwargs, total=False):
+    bboxes: Optional[List[float]]
+    num_image_tokens: Optional[int]
+    first_image_token_id: Optional[int]
+
+
+class Kosmos2TextKwargs(TextKwargs, total=False):
+    add_eos_token: Optional[bool]
+
+
+class Kosmos2ProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: Kosmos2TextKwargs
+    images_kwargs: Kosmos2ImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "verbose": True,
+            "add_eos_token": False,
+        },
+        "images_kwargs": {
+            "num_image_tokens": 64,
+        },
+    }
+
+
class Kosmos2Processor(ProcessorMixin):
    r"""
    Constructs a KOSMOS-2 processor which wraps a KOSMOS-2 image processor and a KOSMOS-2 tokenizer into a single

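To show how the model-specific keys declared above reach the processor, a short sketch (assuming the `microsoft/kosmos-2-patch14-224` checkpoint and the COCO image already used elsewhere in this diff; `num_image_tokens` is routed into `images_kwargs` and `add_eos_token` into `text_kwargs`):

```py
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    images=image,
    text="<grounding> An image of",
    num_image_tokens=64,   # Kosmos2ImagesKwargs
    add_eos_token=False,   # Kosmos2TextKwargs
    return_tensors="pt",
)
```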
@@ -56,7 +86,7 @@ class Kosmos2Processor(ProcessorMixin):

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["num_patch_index_tokens"]
    image_processor_class = "CLIPImageProcessor"
-   tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
+   tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024, *kwargs):
        tokenizer.return_token_type_ids = False

@@ -107,20 +137,9 @@
        self,
        images: ImageInput = None,
        text: Union[TextInput, List[TextInput]] = None,
-       bboxes: BboxInput = None,
-       num_image_tokens: Optional[int] = 64,
-       first_image_token_id: Optional[int] = None,
-       add_special_tokens: bool = True,
-       add_eos_token: bool = False,
-       padding: Union[bool, str, PaddingStrategy] = False,
-       truncation: Union[bool, str, TruncationStrategy] = None,
-       max_length: Optional[int] = None,
-       pad_to_multiple_of: Optional[int] = None,
-       return_attention_mask: Optional[bool] = None,
-       return_length: bool = False,
-       verbose: bool = True,
-       return_tensors: Optional[Union[str, TensorType]] = None,
-       **kwargs,
+       audio=None,
+       videos=None,
+       **kwargs: Unpack[Kosmos2ProcessorKwargs],
    ) -> BatchFeature:
        """
        This method uses [`CLIPImageProcessor.__call__`] method to prepare image(s) for the model, and

@@ -145,10 +164,25 @@
        if images is None and text is None:
            raise ValueError("You have to specify either images or text.")

+       output_kwargs = self._merge_kwargs(
+           Kosmos2ProcessorKwargs,
+           tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+           **kwargs,
+       )
+
+       bboxes = output_kwargs["images_kwargs"].pop("bboxes", None)
+       num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens", 64)
+       first_image_token_id = output_kwargs["images_kwargs"].pop("first_image_token_id", None)
+       add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
+
+       add_special_tokens = output_kwargs["text_kwargs"]["add_special_tokens"]
+       padding = output_kwargs["text_kwargs"]["padding"]
+       return_tensors = output_kwargs["text_kwargs"].setdefault("return_tensors", None)
+
        encoding = BatchFeature()

        if images is not None:
-           image_encoding = self.image_processor(images, return_tensors=return_tensors)
+           image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
            encoding.update(image_encoding)

        if text is not None:

@@ -159,21 +193,18 @@
                text = f"{self.tokenizer.bos_token}{text}"
            elif isinstance(text, list):
                text = [f"{self.tokenizer.bos_token}{s}" for s in text]

-           text_encoding = self.tokenizer(
-               text=text,
-               add_special_tokens=(add_special_tokens and add_eos_token),
-               padding=padding and images is None,
-               truncation=truncation,
-               max_length=max_length,
-               pad_to_multiple_of=pad_to_multiple_of if images is None else pad_to_multiple_of,
-               return_attention_mask=return_attention_mask,
-               verbose=verbose,
-               return_tensors=return_tensors if images is None else None,
-               **kwargs,
+           output_kwargs["text_kwargs"]["add_special_tokens"] = (
+               output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token
            )
+           output_kwargs["text_kwargs"]["padding"] = padding if images is None else False
+           output_kwargs["text_kwargs"]["return_tensors"] = return_tensors if images is None else None
+           text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
            encoding.update(text_encoding)

+           output_kwargs["text_kwargs"]["add_special_tokens"] = add_special_tokens
+           output_kwargs["text_kwargs"]["padding"] = padding
+           output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
+
        if text is not None and images is not None:
            # Use the id of the first token after <unk>
            if first_image_token_id is None:

@@ -218,18 +249,12 @@
            )
            _, min_len_not_padded = sorted_length[0]
            idx, _ = sorted_length[-1]

-           text_encoding = self.tokenizer(
-               text=[text[idx]],
-               add_special_tokens=(add_special_tokens and add_eos_token),
-               padding=padding,
-               truncation=truncation,
-               max_length=max_length,
-               pad_to_multiple_of=pad_to_multiple_of,
-               verbose=verbose,
-               return_tensors=None,
-               **kwargs,
+           output_kwargs["text_kwargs"]["add_special_tokens"] = (
+               output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token
            )
+           output_kwargs["text_kwargs"]["return_tensors"] = None
+
+           text_encoding = self.tokenizer(text=[text[idx]], **output_kwargs["text_kwargs"])
            max_len_padded = len(text_encoding.input_ids[0])

            if min_len_not_padded != max_len_padded:

@@ -715,7 +715,9 @@ class LlavaNextImageProcessor(BaseImageProcessor):
                image_patches = self.get_image_patches(
                    image,
                    image_grid_pinpoints,
-                   size=(size["shortest_edge"], size["shortest_edge"]),
+                   size=(size["shortest_edge"], size["shortest_edge"])
+                   if "shortest_edge" in size
+                   else (min(size["height"], size["width"]), min(size["height"], size["width"])),
                    patch_size=crop_size["height"],
                    resample=resample,
                    data_format=input_data_format,

@@ -763,7 +763,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

-       >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+       >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)

@@ -16,19 +16,30 @@
Processor class for LLaVa-NeXT.
"""

-from typing import List, Optional, Union
+from typing import List, Union

from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging


logger = logging.get_logger(__name__)


+class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {
+            "do_pad": True,
+        },
+    }
+
+
class LlavaNextProcessor(ProcessorMixin):
    r"""
    Constructs a LLaVa-NeXT processor which wraps a LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.

@@ -74,13 +85,11 @@ class LlavaNextProcessor(ProcessorMixin):

    def __call__(
        self,
-       text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: ImageInput = None,
-       padding: Union[bool, str, PaddingStrategy] = False,
-       truncation: Union[bool, str, TruncationStrategy] = None,
-       max_length: Optional[int] = None,
-       do_pad: Optional[bool] = True,
-       return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+       text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+       audio=None,
+       videos=None,
+       **kwargs: Unpack[LlavaNextProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`

|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence if provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
do_pad (`bool`, *optional*, defaults to self.do_pad):
|
||||
Whether to pad the image. If `True` will pad the images in the batch to the largest image in the batch
|
||||
and create a pixel mask. Padding will be applied to the bottom and right of the image with zeros.
|
||||
truncation (`bool`, *optional*):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
@ -130,8 +116,18 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
if images is None and text is None:
|
||||
raise ValueError("You have to specify at least images or text.")
|
||||
# check if images and text inputs are reversed for BC
|
||||
images, text = _validate_images_text_input_order(images, text)
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
LlavaNextProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, do_pad=do_pad, return_tensors=return_tensors)
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
else:
|
||||
image_inputs = {}
|
||||
|
||||
@ -164,13 +160,7 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
prompt_strings.append(sample)
|
||||
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt_strings,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs})
|
||||
|
||||
|
@@ -18,9 +18,34 @@ Processor class for Pix2Struct.

from typing import List, Optional, Union

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...feature_extraction_utils import BatchFeature
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+
+
+class Pix2StructImagesKwargs(ImagesKwargs, total=False):
+    max_patches: Optional[int]
+    header_text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
+
+
+class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
+    images_kwargs: Pix2StructImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_token_type_ids": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {
+            "max_patches": 2048,
+        },
+    }


class Pix2StructProcessor(ProcessorMixin):

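A sketch of how the Pix2Struct-specific keys above are used in practice (assuming the `google/pix2struct-docvqa-base` checkpoint, a VQA variant, so the question is also forwarded to the image processor as `header_text`):

```py
import requests
from PIL import Image
from transformers import Pix2StructProcessor

processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-base")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
image = Image.open(requests.get(url, stream=True).raw)

# max_patches is routed to the image processor through images_kwargs (default 2048 above)
inputs = processor(images=image, text="What vehicle is shown?", max_patches=1024, return_tensors="pt")
```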
@@ -50,23 +75,10 @@ class Pix2StructProcessor(ProcessorMixin):
        self,
        images=None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-       add_special_tokens: bool = True,
-       padding: Union[bool, str, PaddingStrategy] = False,
-       truncation: Union[bool, str, TruncationStrategy] = None,
-       max_length: Optional[int] = None,
-       max_patches: Optional[int] = 2048,
-       stride: int = 0,
-       pad_to_multiple_of: Optional[int] = None,
-       return_attention_mask: Optional[bool] = None,
-       return_overflowing_tokens: bool = False,
-       return_special_tokens_mask: bool = False,
-       return_offsets_mapping: bool = False,
-       return_token_type_ids: bool = False,
-       return_length: bool = False,
-       verbose: bool = True,
-       return_tensors: Optional[Union[str, TensorType]] = None,
-       **kwargs,
-   ) -> BatchEncoding:
+       audio=None,
+       videos=None,
+       **kwargs: Unpack[Pix2StructProcessorKwargs],
+   ) -> Union[BatchEncoding, BatchFeature]:
        """
        This method uses [`Pix2StructImageProcessor.preprocess`] method to prepare image(s) for the model, and
        [`T5TokenizerFast.__call__`] to prepare text for the model.

@@ -76,59 +88,27 @@
        if images is None and text is None:
            raise ValueError("You have to specify either images or text.")

+       output_kwargs = self._merge_kwargs(
+           Pix2StructProcessorKwargs,
+           tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+           **kwargs,
+       )
        # Get only text
        if images is None and not self.image_processor.is_vqa:
            self.current_processor = self.tokenizer
-           text_encoding = self.tokenizer(
-               text=text,
-               add_special_tokens=add_special_tokens,
-               padding=padding,
-               truncation=truncation,
-               max_length=max_length,
-               stride=stride,
-               pad_to_multiple_of=pad_to_multiple_of,
-               return_attention_mask=return_attention_mask,
-               return_overflowing_tokens=return_overflowing_tokens,
-               return_special_tokens_mask=return_special_tokens_mask,
-               return_offsets_mapping=return_offsets_mapping,
-               return_token_type_ids=return_token_type_ids,
-               return_length=return_length,
-               verbose=verbose,
-               return_tensors=return_tensors,
-               **kwargs,
-           )
+           text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
            return text_encoding

        if not self.image_processor.is_vqa:
            # add pixel_values
-           encoding_image_processor = self.image_processor(
-               images, return_tensors=return_tensors, max_patches=max_patches, **kwargs
-           )
+           encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            # add pixel_values and bbox
-           encoding_image_processor = self.image_processor(
-               images, return_tensors=return_tensors, max_patches=max_patches, header_text=text, **kwargs
-           )
+           output_kwargs["images_kwargs"].setdefault("header_text", text)
+           encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])

        if text is not None and not self.image_processor.is_vqa:
-           text_encoding = self.tokenizer(
-               text=text,
-               add_special_tokens=add_special_tokens,
-               padding=padding,
-               truncation=truncation,
-               max_length=max_length,
-               stride=stride,
-               pad_to_multiple_of=pad_to_multiple_of,
-               return_attention_mask=return_attention_mask,
-               return_overflowing_tokens=return_overflowing_tokens,
-               return_special_tokens_mask=return_special_tokens_mask,
-               return_offsets_mapping=return_offsets_mapping,
-               return_token_type_ids=return_token_type_ids,
-               return_length=return_length,
-               verbose=verbose,
-               return_tensors=return_tensors,
-               **kwargs,
-           )
+           text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])

        if "attention_mask" in text_encoding:
            text_encoding["decoder_attention_mask"] = text_encoding.pop("attention_mask")

@@ -626,7 +626,7 @@ class AlignModelIntegrationTest(unittest.TestCase):

        image = prepare_img()
        texts = ["a photo of a cat", "a photo of a dog"]
-       inputs = processor(text=texts, images=image, return_tensors="pt").to(torch_device)
+       inputs = processor(images=image, text=texts, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():

@@ -330,7 +330,7 @@ class FuyuModelIntegrationTest(unittest.TestCase):

        text_prompt_coco_captioning = "Generate a coco-style caption.\n"

-       inputs = processor(text=text_prompt_coco_captioning, images=image, return_tensors="pt")
+       inputs = processor(images=image, text=text_prompt_coco_captioning, return_tensors="pt")
        generated_ids = model.generate(**inputs, max_new_tokens=10)

        # take the last 8 tokens (in order to skip special \n\x04 characters) and decode them

@@ -1,17 +1,25 @@
import io
+import tempfile
import unittest

import requests

-from transformers import AutoTokenizer, is_torch_available, is_vision_available
-from transformers.testing_utils import require_torch, require_torch_gpu, slow
+from transformers import (
+    AutoProcessor,
+    AutoTokenizer,
+    FuyuImageProcessor,
+    FuyuProcessor,
+    is_torch_available,
+    is_vision_available,
+)
+from transformers.testing_utils import require_torch, require_vision
+
+from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from PIL import Image

-if is_vision_available() and is_torch_available():
-    from transformers import FuyuImageProcessor, FuyuProcessor
-
if is_torch_available():
    import torch

@@ -20,21 +28,36 @@ if is_torch_available():


@require_torch
-@require_torch_gpu
-@slow
-class FuyuProcessingTest(unittest.TestCase):  # TODO Which mixins do we add here?
-    """ """
+@require_vision
+class FuyuProcessingTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = FuyuProcessor

    def setUp(self):
-       pretrained_model_name = "adept/fuyu-8b"
-       self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
-       self.image_processor = FuyuImageProcessor()
+       self.tmpdirname = tempfile.mkdtemp()
+
+       image_processor = FuyuImageProcessor()
+       tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b")
+
+       processor = FuyuProcessor(image_processor=image_processor, tokenizer=tokenizer)
+       processor.save_pretrained(self.tmpdirname)

-       self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer)
        self.text_prompt = "Generate a coco-style caption.\\n"
        bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
        self.bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))

+   def get_processor(self):
+       image_processor = FuyuImageProcessor()
+       tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b")
+       processor = FuyuProcessor(image_processor, tokenizer, **self.prepare_processor_dict())
+
+       return processor
+
+   def get_tokenizer(self, **kwargs):
+       return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+   def get_image_processor(self, **kwargs):
+       return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
    def test_fuyu_processing(self):
        """
        Test to ensure that the standard processing on a gold example matches adept's code.

@@ -43,7 +66,7 @@ class FuyuProcessingTest(unittest.TestCase):  # TODO Which mixins do we add here
EXPECTED_IMAGE_PATCH_INPUTS = torch.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, -1, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, -1, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, -1, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -1, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, -1, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, -1, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -1, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, -1, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, -1, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, -1, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, -1, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, -1, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, -1, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,]]).to(torch.int64)
EXPECTED_PADDED_UNPACKED_TOKEN_INPUTS = torch.Tensor([[71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71011, 71019, 1, 128340, 71374, 71389, 120412, 71377, 71835, 71374, 73615, 71375, 71399, 71435, 71122,]]).to(torch.int64)

-       one_image_bus_model_inputs = self.processor(text=self.text_prompt, images=self.bus_image_pil)
+       one_image_bus_model_inputs = self.get_processor()(text=self.text_prompt, images=self.bus_image_pil)

        # fmt: on
        torch.testing.assert_close(one_image_bus_model_inputs["image_patches_indices"], EXPECTED_IMAGE_PATCH_INPUTS)

@@ -53,8 +76,8 @@ class FuyuProcessingTest
        """
        Test to check processor works with just text input
        """
-       processor_outputs = self.processor(text=self.text_prompt)
-       tokenizer_outputs = self.tokenizer(self.text_prompt)
+       processor_outputs = self.get_processor()(text=self.text_prompt)
+       tokenizer_outputs = self.get_tokenizer()(self.text_prompt)
        self.assertEqual(processor_outputs["input_ids"], tokenizer_outputs["input_ids"])

    def test_fuyu_processing_no_text(self):

@@ -90,7 +113,7 @@ class FuyuProcessingTest
        ]).to(torch.int64)
        # fmt: on

-       processor_outputs = self.processor(images=self.bus_image_pil)
+       processor_outputs = self.get_processor()(images=self.bus_image_pil)
        self.assertTrue((processor_outputs["image_patches_indices"] == EXPECTED_IMAGE_PATCH_INPUTS).all())

    def test_fuyu_processing_multiple_image_sample(self):

@@ -107,7 +130,7 @@ class FuyuProcessingTest

        # Batch of two images - equally sized
        images = [self.bus_image_pil, self.bus_image_pil]
-       processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
+       processor_outputs = self.get_processor()(text=[self.text_prompt, self.text_prompt], images=images)

        self.assertTrue(
            (

@@ -124,18 +147,18 @@ class FuyuProcessingTest

        # Processes single images with different sizes as expected
        images = [self.bus_image_pil]
-       processor_outputs = self.processor(text=self.text_prompt, images=images)
+       processor_outputs = self.get_processor()(text=self.text_prompt, images=images)
        self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_IMAGE_PATCH_INPUTS).all())
        self.assertTrue((processor_outputs["input_ids"] == SINGLE_PADDED_UNPACKED_TOKEN_INPUTS).all())

        images = [self.bus_image_pil.resize((64, 300))]
-       processor_outputs = self.processor(text=self.text_prompt, images=images)
+       processor_outputs = self.get_processor()(text=self.text_prompt, images=images)
        self.assertTrue((processor_outputs["image_patches_indices"] == SINGLE_RESIZED_IMAGE_PATCH_INPUTS).all())
        self.assertTrue((processor_outputs["input_ids"] == SINGLE_RESIZED_PADDED_UNPACKED_TOKEN_INPUTS).all())

        # Batch of two images - different sizes. Left-pads the smaller image inputs
        images = [self.bus_image_pil, self.bus_image_pil.resize((64, 300))]
-       processor_outputs = self.processor(text=[self.text_prompt, self.text_prompt], images=images)
+       processor_outputs = self.get_processor()(text=[self.text_prompt, self.text_prompt], images=images)

        padding_len_patch = SINGLE_IMAGE_PATCH_INPUTS.shape[1] - SINGLE_RESIZED_IMAGE_PATCH_INPUTS.shape[1]
        padded_single_resized_image_patch = torch.cat(

@@ -156,6 +179,155 @@
        self.assertTrue((processor_outputs["image_patches_indices"] == expected_image_patch_inputs).all())
        self.assertTrue((processor_outputs["input_ids"] == expected_padded_unpacked_token_inputs).all())

+    # Rewrite as Fuyu supports tokenizer kwargs only when image is None.
+    @require_vision
+    @require_torch
+    def test_kwargs_overrides_default_tokenizer_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer", max_length=117)
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        # Fuyu uses tokenizer kwargs only when image is None.
+        image_input = None
+
+        inputs = processor(
+            text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
+        )
+        self.assertEqual(len(inputs["input_ids"][0]), 112)
+
+    @unittest.skip("Fuyu processor does not support image_processor kwargs")
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        pass
+
+    @unittest.skip("Fuyu processor does not support image_processor kwargs")
+    def test_kwargs_overrides_default_image_processor_kwargs(self):
+        pass
+
+    # Rewrite as Fuyu supports tokenizer kwargs only when image is None.
+    @require_vision
+    @require_torch
+    def test_tokenizer_defaults_preserved_by_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+        input_str = "lower newer"
+        # Fuyu uses tokenizer kwargs only when image is None.
+        image_input = None
+
+        inputs = processor(text=input_str, images=image_input, return_tensors="pt")
+        self.assertEqual(len(inputs["input_ids"][0]), 117)
+
+    # Rewrite as Fuyu image processor does not return pixel values
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        # Fuyu uses tokenizer kwargs only when image is None.
+        image_input = None
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    # Rewrite as Fuyu image processor does not return pixel values
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested_from_dict(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
# Fuyu uses tokenizer kwargs only when image is None.
|
||||
image_input = None
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
# Rewrite as Fuyu supports tokenizer kwargs only when image is None.
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_unstructured_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
# Fuyu uses tokenizer kwargs only when image is None.
|
||||
image_input = None
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
# Rewrite as Fuyu supports tokenizer kwargs only when image is None.
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_unstructured_kwargs_batched(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer", "upper older longer string"]
|
||||
# Fuyu uses tokenizer kwargs only when image is None.
|
||||
image_input = None
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 6)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestImageTextProcessingUtils(unittest.TestCase):
|
||||
|
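For context, a minimal sketch of the two calling conventions the Fuyu tests above exercise, assuming a FuyuProcessor loaded from the "adept/fuyu-8b" checkpoint; the flat and nested forms are meant to be equivalent, and for Fuyu the tokenizer kwargs only take effect when images is None:

from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")

# Unstructured: tokenizer kwargs passed flat alongside text/images.
flat = processor(
    text="Describe the image.",
    images=None,
    return_tensors="pt",
    padding="max_length",
    max_length=76,
)

# Structured: the same kwargs grouped per modality, as in the nested-kwargs tests.
nested = processor(
    text="Describe the image.",
    images=None,
    common_kwargs={"return_tensors": "pt"},
    text_kwargs={"padding": "max_length", "max_length": 76},
)
# Both calls should produce input_ids padded to length 76.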
@ -17,7 +17,7 @@ import unittest

import pytest

from transformers.testing_utils import require_torch, require_vision
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin
@ -179,261 +179,3 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
list(inputs.keys()),
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", padding="longest")
qformer_tokenizer = self.get_component("qformer_tokenizer", padding="longest")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_doubly_passed_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer"]
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
images_kwargs={"size": {"height": 222, "width": 222}},
size={"height": 214, "width": 214},
)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

# Override as InstructBlipProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_overlapping_text_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_kwargs = {}
processor_kwargs["image_processor"] = self.get_component("image_processor")
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
if "video_processor" in self.processor_class.attributes:
processor_kwargs["video_processor"] = self.get_component("video_processor")

qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
text_kwargs={"padding": "do_not_pad"},
)
@ -15,18 +15,15 @@ import shutil
import tempfile
import unittest

import numpy as np
import pytest

from transformers.testing_utils import require_torch, require_vision
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from PIL import Image

from transformers import (
AutoProcessor,
BertTokenizerFast,
@ -65,16 +62,6 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

# Ignore copy
def prepare_image_inputs(self):
"""This function prepares a list of list of PIL images"""

video_inputs = [
[Image.fromarray(np.random.randint(255, size=(30, 400, 3), dtype=np.uint8)) for _ in range(5)]
for _ in range(2)
]
return video_inputs

def test_save_load_pretrained_additional_features(self):
processor = InstructBlipVideoProcessor(
tokenizer=self.get_tokenizer(),
@ -193,261 +180,3 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
list(inputs.keys()),
["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", padding="longest")
qformer_tokenizer = self.get_component("qformer_tokenizer", padding="longest")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
qformer_tokenizer = self.get_component("qformer_tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_doubly_passed_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer"]
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
images_kwargs={"size": {"height": 222, "width": 222}},
size={"height": 214, "width": 214},
)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

# Override as InstructBlipVideoProcessor has qformer_tokenizer
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_overlapping_text_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_kwargs = {}
processor_kwargs["image_processor"] = self.get_component("image_processor")
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
if "video_processor" in self.processor_class.attributes:
processor_kwargs["video_processor"] = self.get_component("video_processor")

qformer_tokenizer = self.get_component("qformer_tokenizer")

processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
text_kwargs={"padding": "do_not_pad"},
)
@ -61,7 +61,7 @@ class Kosmos2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

image_processor = CLIPImageProcessor()
image_processor = CLIPImageProcessor(do_center_crop=False)

# We have a SentencePiece fixture for testing
slow_tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB)
@ -487,3 +487,147 @@ class Kosmos2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(outputs.input_ids.numpy().tolist()[-1], EXPECTED_IDS_BATCH[-1])
self.assertListEqual(outputs.attention_mask.numpy().tolist()[-1], EXPECTED_MASK_BATCH[-1])
self.assertListEqual(outputs.image_embeds_position_mask.numpy().tolist()[-1], EXPECTED_IMG_POS_MASK_BATCH[-1])

# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# set image input to None
image_input = None

inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_length=112,
padding="max_length",
)

self.assertEqual(len(inputs["input_ids"][0]), 112)

# Rewrite to test only image_processor kwargs
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

# Rewrite to test only image_processor kwargs
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
# set image input to None
image_input = None

inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)

# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
# set image input to None
image_input = None
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
padding="max_length",
max_length=76,
)

self.assertEqual(len(inputs["input_ids"][0]), 76)

# Rewrite as Kosmos-2 supports custom padding only when image is None.
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
# set image input to None
image_input = None
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)

self.assertEqual(len(inputs["input_ids"][0]), 10)
@ -338,7 +338,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
load_in_4bit=True,
)

inputs = self.processor(self.prompt, self.image, return_tensors="pt")
inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")

# verify inputs against original implementation
filepath = hf_hub_download(
@ -390,8 +390,8 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
cats_image = Image.open(requests.get(url, stream=True).raw)

inputs = self.processor(
[self.prompt, self.prompt],
images=[self.image, cats_image],
text=[self.prompt, self.prompt],
return_tensors="pt",
padding=True,
).to(torch_device)
@ -415,7 +415,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
)

prompt_with_unk = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]"
inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt")
inputs = self.processor(images=self.image, text=prompt_with_unk, return_tensors="pt")

# verify single forward pass
inputs = inputs.to(torch_device)
@ -445,7 +445,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

inputs = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
).to(torch_device)
pixel_values = inputs["pixel_values"]

@ -498,10 +498,10 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
).to(torch_device)

inputs_single = self.processor(self.prompt, images=lowres_img, return_tensors="pt", padding=True).to(
inputs_single = self.processor(images=lowres_img, text=self.prompt, return_tensors="pt", padding=True).to(
torch_device
)

@ -527,7 +527,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)

inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
images=[lowres_img, cats_image], text=[self.prompt, self.prompt], return_tensors="pt", padding=True
).to(torch_device)

# model is in eval mode by default so we should get pad on the left side
@ -607,13 +607,13 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2356)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 17)

# generate exactly 20 tokens
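The hunks above all move to the uniformized calling convention: images and text are passed as keyword arguments, with images listed first, rather than the old positional (text, image) order. A minimal hedged sketch of the keyword form, assuming the "llava-hf/llava-v1.6-mistral-7b-hf" processor and a sample COCO image; the old positional order may still be accepted for backward compatibility, but the keyword form is what these tests pin down:

import requests
from PIL import Image
from transformers import LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

# Keyword form exercised by the updated tests: images first, then text.
inputs = processor(images=image, text=prompt, padding=True, return_tensors="pt")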
@ -18,7 +18,9 @@ import unittest
import torch

from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
from transformers.testing_utils import require_vision
from transformers.testing_utils import (
require_vision,
)
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin
@ -37,6 +37,8 @@ if is_vision_available():
@require_torch
class Pix2StructProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Pix2StructProcessor
text_input_name = "decoder_input_ids"
images_input_name = "flattened_patches"

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
@ -180,3 +182,148 @@ class Pix2StructProcessorTest(ProcessorTesterMixin, unittest.TestCase):

# For now the processor supports only ["flattened_patches", "input_ids", "attention_mask", "decoder_attention_mask"]
self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])

@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
# Rewrite as pix2struct processor return "flattened_patches" and not "pixel_values"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", max_patches=1024, patch_size={"height": 8, "width": 8})
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["flattened_patches"][0][0]), 194)

@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
# Rewrite as pix2struct processor return "flattened_patches" and not "pixel_values"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", max_patches=4096)
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, max_patches=1024)
self.assertEqual(len(inputs["flattened_patches"][0]), 1024)

@require_torch
@require_vision
def test_unstructured_kwargs(self):
# Rewrite as pix2struct processor return "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_patches=1024,
padding="max_length",
max_length=76,
)

self.assertEqual(inputs["flattened_patches"].shape[1], 1024)
self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
# Rewrite as pix2struct processor return "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
max_patches=1024,
padding="longest",
max_length=76,
)

self.assertEqual(inputs["flattened_patches"].shape[1], 1024)

self.assertEqual(len(inputs["decoder_input_ids"][0]), 5)

@require_torch
@require_vision
def test_structured_kwargs_nested(self):
# Rewrite as pix2struct processor return "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"max_patches": 1024},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["flattened_patches"].shape[1], 1024)

self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)

@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
# Rewrite as pix2struct processor return "decoder_input_ids" and not "input_ids"
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"max_patches": 1024},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["flattened_patches"].shape[1], 1024)

self.assertEqual(len(inputs["decoder_input_ids"][0]), 76)
@ -18,9 +18,7 @@ import unittest
import requests
import torch

from transformers.testing_utils import (
require_vision,
)
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin
@ -66,7 +66,7 @@ class ProcessorTesterMixin:

component_class = processor_class_from_name(component_class_name)
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
if attribute == "tokenizer" and not component.pad_token:
if "tokenizer" in attribute and not component.pad_token:
component.pad_token = "[TEST_PAD]"
if component.pad_token_id is None:
component.pad_token_id = 0
@ -322,14 +322,8 @@ class ProcessorTesterMixin:
def test_overlapping_text_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_kwargs = {}
processor_kwargs["image_processor"] = self.get_component("image_processor")
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
if "video_processor" in self.processor_class.attributes:
processor_kwargs["video_processor"] = self.get_component("video_processor")
processor = self.processor_class(**processor_kwargs)
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
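The last hunk swaps the hand-rolled component setup for prepare_components(). A plausible sketch of such a helper, under the assumption that it simply builds every attribute declared on processor_class via get_component (which, per the first hunk, now handles the pad-token fallback for any tokenizer-like attribute); this is illustrative, not necessarily the library's exact implementation:

def prepare_components(self):
    # Build one instance per declared processor attribute,
    # e.g. ["image_processor", "tokenizer", "video_processor"].
    components = {}
    for attribute in self.processor_class.attributes:
        components[attribute] = self.get_component(attribute)
    return components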