Uniformize kwargs for Pixtral processor (#33521)

* add uniformized pixtral and kwargs

* update doc

* fix _validate_images_text_input_order

* nit
Author: Yoni Gozlan, 2024-09-17 14:44:27 -04:00 (committed by GitHub)
parent c29a8694b0
commit d8500cd229
7 changed files with 255 additions and 62 deletions


@@ -51,7 +51,7 @@ IMG_URLS = [
 ]
 PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"

-inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
+inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt").to("cuda")
 generate_ids = model.generate(**inputs, max_new_tokens=500)
 ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
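
Note: only the documented canonical order changes here; the processor still accepts the legacy order. A hedged sketch (PROMPT and IMG_URLS as defined in the snippet above):

    inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt")  # new canonical order
    inputs = processor(PROMPT, IMG_URLS, return_tensors="pt")  # legacy positional order, detected and
                                                               # reordered for BC by
                                                               # _validate_images_text_input_order
                                                               # (see the processing_utils.py hunk below)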


@@ -43,7 +43,8 @@ else:

 if TYPE_CHECKING:
-    from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig
+    from .configuration_pixtral import PixtralVisionConfig
+    from .processing_pixtral import PixtralProcessor

     try:
         if not is_torch_available():


@@ -16,18 +16,36 @@
 Processor class for Pixtral.
 """

-from typing import List, Optional, Union
+import sys
+from typing import List, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, load_image
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack


 logger = logging.get_logger(__name__)


+class PixtralProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {},
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }
+
+
 # Copied from transformers.models.idefics2.processing_idefics2.is_url
 def is_url(val) -> bool:
     return isinstance(val, str) and val.startswith("http")
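
The `_defaults` above are the starting point for `ProcessorMixin._merge_kwargs`; per-call kwargs override them, whether passed flat or nested per modality. A hedged sketch of the two equivalent call styles this enables (values are illustrative, mirroring the tests added below):

    # Flat kwargs are routed to the right modality via the typed dict:
    inputs = processor(images=image, text=prompt, padding="max_length", max_length=76)

    # Nested, structured form, as exercised by test_structured_kwargs_nested:
    inputs = processor(
        images=image,
        text=prompt,
        images_kwargs={"size": {"height": 240, "width": 240}},
        text_kwargs={"padding": "max_length", "max_length": 76},
        common_kwargs={"return_tensors": "pt"},
    )
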
@@ -143,12 +161,11 @@ class PixtralProcessor(ProcessorMixin):
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length=None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[PixtralProcessorKwargs],
     ) -> BatchMixFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -158,26 +175,13 @@ class PixtralProcessor(ProcessorMixin):
         of the above two methods for more information.

         Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
             text (`str`, `List[str]`, `List[List[str]]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
@@ -195,6 +199,15 @@ class PixtralProcessor(ProcessorMixin):
               `None`).
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            PixtralProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         if images is not None:
             if is_image_or_image_url(images):
                 images = [[images]]
@@ -209,7 +222,7 @@ class PixtralProcessor(ProcessorMixin):
                     "Invalid input images. Please provide a single image or a list of images or a list of list of images."
                 )
             images = [[load_image(im) for im in sample] for sample in images]
-            image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
         else:
             image_inputs = {}
@@ -246,16 +259,9 @@ class PixtralProcessor(ProcessorMixin):
             while "<placeholder>" in sample:
                 replace_str = replace_strings.pop(0)
                 sample = sample.replace("<placeholder>", replace_str, 1)
             prompt_strings.append(sample)

-        text_inputs = self.tokenizer(
-            prompt_strings,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
         return BatchMixFeature(data={**text_inputs, **image_inputs})

     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
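
After `_merge_kwargs`, each backend receives only its own slice, with `common_kwargs` folded into every modality. A hedged sketch of the resolved structure for a call with no per-call overrides (assuming the `_defaults` above):

    output_kwargs = {
        "text_kwargs": {"padding": False, "return_tensors": "pt"},
        "images_kwargs": {"return_tensors": "pt"},
        "common_kwargs": {"return_tensors": "pt"},
    }
    # so the two backend calls above reduce to roughly:
    #   self.image_processor(images, patch_size=self.patch_size, return_tensors="pt")
    #   self.tokenizer(prompt_strings, padding=False, return_tensors="pt")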


@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
 import numpy as np

 from .dynamic_module_utils import custom_object_save
-from .image_utils import ChannelDimension, is_vision_available, valid_images
+from .image_utils import ChannelDimension, is_valid_image, is_vision_available

 if is_vision_available():
@@ -1003,6 +1003,20 @@ def _validate_images_text_input_order(images, text):
     in the processor's `__call__` method before calling this method.
     """

+    def is_url(val) -> bool:
+        return isinstance(val, str) and val.startswith("http")
+
+    def _is_valid_images_input_for_processor(imgs):
+        # If we have an list of images, make sure every image is valid
+        if isinstance(imgs, (list, tuple)):
+            for img in imgs:
+                if not _is_valid_images_input_for_processor(img):
+                    return False
+        # If not a list or tuple, we have been given a single image or batched tensor of images
+        elif not (is_valid_image(imgs) or is_url(imgs)):
+            return False
+        return True
+
     def _is_valid_text_input_for_processor(t):
         if isinstance(t, str):
             # Strings are fine
@@ -1019,11 +1033,11 @@ def _validate_images_text_input_order(images, text):
     def _is_valid(input, validator):
         return validator(input) or input is None

-    images_is_valid = _is_valid(images, valid_images)
-    images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False
+    images_is_valid = _is_valid(images, _is_valid_images_input_for_processor)
+    images_is_text = _is_valid_text_input_for_processor(images)

     text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
-    text_is_images = valid_images(text) if not text_is_valid else False
+    text_is_images = _is_valid_images_input_for_processor(text)

     # Handle cases where both inputs are valid
     if images_is_valid and text_is_valid:
         return images, text
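
Net effect: URL strings now count as valid image inputs during order validation, so reversed arguments are still untangled. A hedged illustration mirroring the new test added below:

    images, text = _validate_images_text_input_order(
        images=["text1", "text2"],  # caller swapped the arguments
        text=["https://url1", "https://url2"],
    )
    assert images == ["https://url1", "https://url2"]
    assert text == ["text1", "text2"]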


@@ -11,14 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import shutil
+import tempfile
 import unittest

 import requests
 import torch

-from transformers.testing_utils import require_vision
+from transformers.testing_utils import (
+    require_torch,
+    require_vision,
+)
 from transformers.utils import is_vision_available

+from ...test_processing_common import ProcessorTesterMixin
+
 if is_vision_available():
     from PIL import Image
@@ -27,7 +34,7 @@ if is_vision_available():

 @require_vision
-class PixtralProcessorTest(unittest.TestCase):
+class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     processor_class = PixtralProcessor

     @classmethod
@@ -40,15 +47,20 @@ class PixtralProcessorTest(unittest.TestCase):
         cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)

     def setUp(self):
-        super().setUp()
+        self.tmpdirname = tempfile.mkdtemp()
         # FIXME - just load the processor directly from the checkpoint
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b")
         image_processor = PixtralImageProcessor()
-        self.processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
+        processor.save_pretrained(self.tmpdirname)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)

     @unittest.skip("No chat template was set for this model (yet)")
     def test_chat_template(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
         expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:"

         messages = [
@@ -60,11 +72,12 @@ class PixtralProcessorTest(unittest.TestCase):
                 ],
             },
         ]
-        formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
         self.assertEqual(expected_prompt, formatted_prompt)

     @unittest.skip("No chat template was set for this model (yet)")
     def test_image_token_filling(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
         # Important to check with non square image
         image = torch.randint(0, 2, (3, 500, 316))
         expected_image_tokens = 1526
@@ -79,8 +92,8 @@ class PixtralProcessorTest(unittest.TestCase):
                 ],
             },
         ]
-        inputs = self.processor(
-            text=[self.processor.apply_chat_template(messages)],
+        inputs = processor(
+            text=[processor.apply_chat_template(messages)],
             images=[image],
             return_tensors="pt",
         )
@@ -88,14 +101,15 @@ class PixtralProcessorTest(unittest.TestCase):
         self.assertEqual(expected_image_tokens, image_tokens)

     def test_processor_with_single_image(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
         prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"

         # Make small for checking image token expansion
-        self.processor.image_processor.size = {"longest_edge": 30}
-        self.processor.image_processor.patch_size = {"height": 2, "width": 2}
+        processor.image_processor.size = {"longest_edge": 30}
+        processor.image_processor.patch_size = {"height": 2, "width": 2}

         # Test passing in an image
-        inputs_image = self.processor(text=prompt_string, images=self.image_0, return_tensors="pt")
+        inputs_image = processor(text=prompt_string, images=self.image_0, return_tensors="pt")

         self.assertIn("input_ids", inputs_image)
         self.assertTrue(len(inputs_image["input_ids"]) == 1)
         self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
@@ -115,7 +129,7 @@ class PixtralProcessorTest(unittest.TestCase):
         # fmt: on

         # Test passing in a url
-        inputs_url = self.processor(text=prompt_string, images=self.url_0, return_tensors="pt")
+        inputs_url = processor(text=prompt_string, images=self.url_0, return_tensors="pt")

         self.assertIn("input_ids", inputs_url)
         self.assertTrue(len(inputs_url["input_ids"]) == 1)
         self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
@@ -135,14 +149,15 @@ class PixtralProcessorTest(unittest.TestCase):
         # fmt: on

     def test_processor_with_multiple_images_single_list(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
         prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"

         # Make small for checking image token expansion
-        self.processor.image_processor.size = {"longest_edge": 30}
-        self.processor.image_processor.patch_size = {"height": 2, "width": 2}
+        processor.image_processor.size = {"longest_edge": 30}
+        processor.image_processor.patch_size = {"height": 2, "width": 2}

         # Test passing in an image
-        inputs_image = self.processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
+        inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")

         self.assertIn("input_ids", inputs_image)
         self.assertTrue(len(inputs_image["input_ids"]) == 1)
         self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
@@ -162,7 +177,7 @@ class PixtralProcessorTest(unittest.TestCase):
         # fmt: on

         # Test passing in a url
-        inputs_url = self.processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
+        inputs_url = processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")

         self.assertIn("input_ids", inputs_url)
         self.assertTrue(len(inputs_url["input_ids"]) == 1)
         self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
@@ -181,19 +196,20 @@ class PixtralProcessorTest(unittest.TestCase):
         # fmt: on

     def test_processor_with_multiple_images_multiple_lists(self):
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
         prompt_string = [
             "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
             "USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
         ]
-        self.processor.tokenizer.pad_token = "</s>"
+        processor.tokenizer.pad_token = "</s>"
         image_inputs = [[self.image_0, self.image_1], [self.image_2]]

         # Make small for checking image token expansion
-        self.processor.image_processor.size = {"longest_edge": 30}
-        self.processor.image_processor.patch_size = {"height": 2, "width": 2}
+        processor.image_processor.size = {"longest_edge": 30}
+        processor.image_processor.patch_size = {"height": 2, "width": 2}

         # Test passing in an image
-        inputs_image = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
+        inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)

         self.assertIn("input_ids", inputs_image)
         self.assertTrue(len(inputs_image["input_ids"]) == 2)
         self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
@@ -213,7 +229,7 @@ class PixtralProcessorTest(unittest.TestCase):
         # fmt: on

         # Test passing in a url
-        inputs_url = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
+        inputs_url = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)

         self.assertIn("input_ids", inputs_url)
         self.assertTrue(len(inputs_url["input_ids"]) == 2)
         self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
@@ -231,3 +247,145 @@ class PixtralProcessorTest(unittest.TestCase):
             [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
         )
         # fmt: on
+
+    # Override all tests requiring shape as returning tensor batches is not supported by PixtralProcessor
+    @require_torch
+    @require_vision
+    def test_image_processor_defaults_preserved_by_image_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", size={"height": 240, "width": 240})
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input)
+        # Added dimension by pixtral image processor
+        self.assertEqual(len(inputs["pixel_values"][0][0][0][0]), 240)
+
+    @require_torch
+    @require_vision
+    def test_kwargs_overrides_default_image_processor_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor", size={"height": 400, "width": 400})
+        tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        inputs = processor(text=input_str, images=image_input, size={"height": 240, "width": 240})
+        self.assertEqual(len(inputs["pixel_values"][0][0][0][0]), 240)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"size": {"height": 240, "width": 240}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_structured_kwargs_nested_from_dict(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+
+        # Define the kwargs for each modality
+        all_kwargs = {
+            "common_kwargs": {"return_tensors": "pt"},
+            "images_kwargs": {"size": {"height": 240, "width": 240}},
+            "text_kwargs": {"padding": "max_length", "max_length": 76},
+        }
+
+        inputs = processor(text=input_str, images=image_input, **all_kwargs)
+        self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = "lower newer"
+        image_input = self.prepare_image_inputs()
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            size={"height": 240, "width": 240},
+            padding="max_length",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
+        self.assertEqual(len(inputs["input_ids"][0]), 76)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        # images needs to be nested to detect multiple prompts
+        image_input = [self.prepare_image_inputs()] * 2
+        inputs = processor(
+            text=input_str,
+            images=image_input,
+            return_tensors="pt",
+            size={"height": 240, "width": 240},
+            padding="longest",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
+        self.assertEqual(len(inputs["input_ids"][0]), 4)
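
The nesting comment in the batched test is load-bearing: Pixtral reads a flat list as several images for one prompt, so batching requires one sub-list per prompt. A hedged sketch of the distinction:

    # one prompt, two images:
    processor(text="USER: [IMG][IMG] ...", images=[image_a, image_b])
    # two prompts, one image each (what the batched test builds):
    processor(text=["prompt 1", "prompt 2"], images=[[image_a], [image_b]])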


@@ -64,6 +64,8 @@ class ProcessorTesterMixin:
             component = component_class.from_pretrained(self.tmpdirname, **kwargs)  # noqa
         if attribute == "tokenizer" and not component.pad_token:
             component.pad_token = "[TEST_PAD]"
+            if component.pad_token_id is None:
+                component.pad_token_id = 0

         return component
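
Hedged rationale for the guard (assumption: some test tokenizers fail to resolve an out-of-vocab pad token string to an id): the new padding="max_length" test paths need a usable pad_token_id, so the mixin pins a fallback when the assignment leaves it unset.

    tokenizer.pad_token = "[TEST_PAD]"   # string may not exist in the vocab...
    if tokenizer.pad_token_id is None:   # ...leaving the id unresolved
        tokenizer.pad_token_id = 0       # any valid id is fine for tests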


@@ -80,6 +80,18 @@ class ProcessingUtilTester(unittest.TestCase):
         self.assertTrue(np.array_equal(valid_images[0], images[0]))
         self.assertEqual(valid_text, text)

+        # list of strings and list of url images inputs
+        images = ["https://url1", "https://url2"]
+        text = ["text1", "text2"]
+        # test correct text and images order
+        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
+        self.assertEqual(valid_images, images)
+        self.assertEqual(valid_text, text)
+
+        # test incorrect text and images order
+        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
+        self.assertEqual(valid_images, images)
+        self.assertEqual(valid_text, text)
+
         # list of strings and nested list of numpy images inputs
         images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
         text = ["text1", "text2"]