Uniformize kwargs for Pixtral processor (#33521)

* add uniformized pixtral and kwargs

* update doc

* fix _validate_images_text_input_order

* nit
Yoni Gozlan 2024-09-17 14:44:27 -04:00 committed by GitHub
parent c29a8694b0
commit d8500cd229
7 changed files with 255 additions and 62 deletions

View File: docs/source/en/model_doc/pixtral.md

@@ -51,7 +51,7 @@ IMG_URLS = [
]
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt").to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=500)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
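Note that both argument orders keep working after this change. A minimal sketch, assuming the `processor`, `PROMPT`, and `IMG_URLS` loaded above; the re-ordering behavior comes from `_validate_images_text_input_order` (shown further down):

# New documented order: images first, text second.
inputs_new = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt")
# Legacy positional order (text first) is still accepted: the swapped
# inputs are detected and re-ordered internally for backward compatibility.
inputs_old = processor(PROMPT, IMG_URLS, return_tensors="pt")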

View File: src/transformers/models/pixtral/__init__.py

@@ -43,7 +43,8 @@ else:
if TYPE_CHECKING:
from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig
from .configuration_pixtral import PixtralVisionConfig
from .processing_pixtral import PixtralProcessor
try:
if not is_torch_available():

View File: src/transformers/models/pixtral/processing_pixtral.py

@@ -16,18 +16,36 @@
Processor class for Pixtral.
"""
from typing import List, Optional, Union
import sys
from typing import List, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
logger = logging.get_logger(__name__)
class PixtralProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {},
"common_kwargs": {
"return_tensors": "pt",
},
}
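To make the `_defaults` mechanics concrete, here is a minimal sketch of the merge performed by `self._merge_kwargs` in `__call__` below. This is an illustration, not the actual `ProcessorMixin._merge_kwargs` implementation (the real method also routes flat kwargs to the right modality); the assumed priority is per-call kwargs > tokenizer init kwargs > class `_defaults`:

def merge_kwargs_sketch(defaults, tokenizer_init_kwargs, call_kwargs):
    # Start from the class-level defaults...
    merged = {modality: dict(opts) for modality, opts in defaults.items()}
    # ...let kwargs saved at tokenizer init time override them...
    merged["text_kwargs"].update(tokenizer_init_kwargs)
    # ...and let explicit per-call kwargs win over everything else.
    for modality, opts in call_kwargs.items():
        merged.setdefault(modality, {}).update(opts)
    return merged

output_kwargs = merge_kwargs_sketch(
    PixtralProcessorKwargs._defaults,
    tokenizer_init_kwargs={},
    call_kwargs={"text_kwargs": {"padding": "max_length", "max_length": 76}},
)
# output_kwargs["text_kwargs"]   -> {"padding": "max_length", "max_length": 76}
# output_kwargs["common_kwargs"] -> {"return_tensors": "pt"}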
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
@@ -143,12 +161,11 @@ class PixtralProcessor(ProcessorMixin):
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[PixtralProcessorKwargs],
) -> BatchMixFeature:
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
@@ -158,26 +175,13 @@ class PixtralProcessor(ProcessorMixin):
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
@@ -195,6 +199,15 @@ class PixtralProcessor(ProcessorMixin):
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)
output_kwargs = self._merge_kwargs(
PixtralProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
if is_image_or_image_url(images):
images = [[images]]
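The branches elided between these hunks normalize every accepted input shape to a list of per-prompt image lists. A hedged sketch of the full normalization as a hypothetical standalone helper (the real code inlines this logic):

def normalize_images_sketch(images):
    if is_image_or_image_url(images):
        # A single image or URL becomes one sample holding one image.
        return [[images]]
    if isinstance(images, list) and images and is_image_or_image_url(images[0]):
        # A flat list of images becomes a single sample.
        return [images]
    if isinstance(images, list) and images and isinstance(images[0], list):
        # Already a list of per-prompt image lists: pass through unchanged.
        return images
    raise ValueError(
        "Invalid input images. Please provide a single image or a list of images "
        "or a list of list of images."
    )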
@@ -209,7 +222,7 @@
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
)
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
else:
image_inputs = {}
@@ -246,16 +259,9 @@
while "<placeholder>" in sample:
replace_str = replace_strings.pop(0)
sample = sample.replace("<placeholder>", replace_str, 1)
prompt_strings.append(sample)
text_inputs = self.tokenizer(
prompt_strings,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return BatchMixFeature(data={**text_inputs, **image_inputs})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
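With the uniformized signature, callers can pass kwargs either nested per modality or flat. A usage sketch assuming a loaded `processor` and a PIL `image` (the values mirror the tests added below):

# Nested form: one dict per modality.
inputs = processor(
    images=image,
    text="Describe the image.",
    images_kwargs={"size": {"height": 240, "width": 240}},
    text_kwargs={"padding": "max_length", "max_length": 76},
    common_kwargs={"return_tensors": "pt"},
)
# Flat form: each kwarg is routed to the matching modality automatically.
inputs = processor(
    images=image,
    text="Describe the image.",
    size={"height": 240, "width": 240},
    padding="max_length",
    max_length=76,
    return_tensors="pt",
)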

View File: src/transformers/processing_utils.py

@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
import numpy as np
from .dynamic_module_utils import custom_object_save
from .image_utils import ChannelDimension, is_vision_available, valid_images
from .image_utils import ChannelDimension, is_valid_image, is_vision_available
if is_vision_available():
@@ -1003,6 +1003,20 @@ def _validate_images_text_input_order(images, text):
in the processor's `__call__` method before calling this method.
"""
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
def _is_valid_images_input_for_processor(imgs):
# If we have a list of images, make sure every image is valid
if isinstance(imgs, (list, tuple)):
for img in imgs:
if not _is_valid_images_input_for_processor(img):
return False
# If not a list or tuple, we have been given a single image or batched tensor of images
elif not (is_valid_image(imgs) or is_url(imgs)):
return False
return True
def _is_valid_text_input_for_processor(t):
if isinstance(t, str):
# Strings are fine
@@ -1019,11 +1033,11 @@ def _validate_images_text_input_order(images, text):
def _is_valid(input, validator):
return validator(input) or input is None
images_is_valid = _is_valid(images, valid_images)
images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False
images_is_valid = _is_valid(images, _is_valid_images_input_for_processor)
images_is_text = _is_valid_text_input_for_processor(images)
text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
text_is_images = valid_images(text) if not text_is_valid else False
text_is_images = _is_valid_images_input_for_processor(text)
# Handle cases where both inputs are valid
if images_is_valid and text_is_valid:
return images, text
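A quick sketch of what the relaxed validation now accepts, mirroring the new test cases below: URL strings count as valid image inputs, and swapped arguments are corrected rather than rejected.

images = ["https://url1", "https://url2"]
text = ["text1", "text2"]
# The correct order passes through unchanged...
assert _validate_images_text_input_order(images, text) == (images, text)
# ...and swapped inputs are detected and re-ordered.
assert _validate_images_text_input_order(text, images) == (images, text)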

View File: tests/models/pixtral/test_processing_pixtral.py

@@ -11,14 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import requests
import torch
from transformers.testing_utils import require_vision
from transformers.testing_utils import (
require_torch,
require_vision,
)
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from PIL import Image
@@ -27,7 +34,7 @@ if is_vision_available():
@require_vision
class PixtralProcessorTest(unittest.TestCase):
class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = PixtralProcessor
@classmethod
@@ -40,15 +47,20 @@ class PixtralProcessorTest(unittest.TestCase):
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
def setUp(self):
super().setUp()
self.tmpdirname = tempfile.mkdtemp()
# FIXME - just load the processor directly from the checkpoint
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b")
image_processor = PixtralImageProcessor()
self.processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
processor.save_pretrained(self.tmpdirname)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
@unittest.skip("No chat template was set for this model (yet)")
def test_chat_template(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:"
messages = [
@@ -60,11 +72,12 @@ class PixtralProcessorTest(unittest.TestCase):
],
},
]
formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
@unittest.skip("No chat template was set for this model (yet)")
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
# Important to check with a non-square image
image = torch.randint(0, 2, (3, 500, 316))
expected_image_tokens = 1526
@@ -79,8 +92,8 @@ class PixtralProcessorTest(unittest.TestCase):
],
},
]
inputs = self.processor(
text=[self.processor.apply_chat_template(messages)],
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
@@ -88,14 +101,15 @@ class PixtralProcessorTest(unittest.TestCase):
self.assertEqual(expected_image_tokens, image_tokens)
def test_processor_with_single_image(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:"
# Make small for checking image token expansion
self.processor.image_processor.size = {"longest_edge": 30}
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
processor.image_processor.size = {"longest_edge": 30}
processor.image_processor.patch_size = {"height": 2, "width": 2}
# Test passing in an image
inputs_image = self.processor(text=prompt_string, images=self.image_0, return_tensors="pt")
inputs_image = processor(text=prompt_string, images=self.image_0, return_tensors="pt")
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
@@ -115,7 +129,7 @@ class PixtralProcessorTest(unittest.TestCase):
# fmt: on
# Test passing in a url
inputs_url = self.processor(text=prompt_string, images=self.url_0, return_tensors="pt")
inputs_url = processor(text=prompt_string, images=self.url_0, return_tensors="pt")
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
@@ -135,14 +149,15 @@ class PixtralProcessorTest(unittest.TestCase):
# fmt: on
def test_processor_with_multiple_images_single_list(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
# Make small for checking image token expansion
self.processor.image_processor.size = {"longest_edge": 30}
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
processor.image_processor.size = {"longest_edge": 30}
processor.image_processor.patch_size = {"height": 2, "width": 2}
# Test passing in an image
inputs_image = self.processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 1)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
@@ -162,7 +177,7 @@ class PixtralProcessorTest(unittest.TestCase):
# fmt: on
# Test passing in a url
inputs_url = self.processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
inputs_url = processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt")
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 1)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
@@ -181,19 +196,20 @@ class PixtralProcessorTest(unittest.TestCase):
# fmt: on
def test_processor_with_multiple_images_multiple_lists(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
prompt_string = [
"USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:",
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
]
self.processor.tokenizer.pad_token = "</s>"
processor.tokenizer.pad_token = "</s>"
image_inputs = [[self.image_0, self.image_1], [self.image_2]]
# Make small for checking image token expansion
self.processor.image_processor.size = {"longest_edge": 30}
self.processor.image_processor.patch_size = {"height": 2, "width": 2}
processor.image_processor.size = {"longest_edge": 30}
processor.image_processor.patch_size = {"height": 2, "width": 2}
# Test passing in an image
inputs_image = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 2)
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
@@ -213,7 +229,7 @@ class PixtralProcessorTest(unittest.TestCase):
# fmt: on
# Test passing in a url
inputs_url = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
inputs_url = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_url)
self.assertTrue(len(inputs_url["input_ids"]) == 2)
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
@@ -231,3 +247,145 @@ class PixtralProcessorTest(unittest.TestCase):
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
)
# fmt: on
# Override all tests that rely on tensor shapes, since returning batched tensors is not supported by PixtralProcessor
@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size={"height": 240, "width": 240})
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
# Extra dimension added by the Pixtral image processor
self.assertEqual(len(inputs["pixel_values"][0][0][0][0]), 240)
@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size={"height": 400, "width": 400})
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input, size={"height": 240, "width": 240})
self.assertEqual(len(inputs["pixel_values"][0][0][0][0]), 240)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 240, "width": 240}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 240, "width": 240}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 240, "width": 240},
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
# images need to be nested to detect multiple prompts
image_input = [self.prepare_image_inputs()] * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 240, "width": 240},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240)
self.assertEqual(len(inputs["input_ids"][0]), 4)

View File: tests/test_processing_common.py

@@ -64,6 +64,8 @@ class ProcessorTesterMixin:
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
if attribute == "tokenizer" and not component.pad_token:
component.pad_token = "[TEST_PAD]"
if component.pad_token_id is None:
component.pad_token_id = 0
return component

View File: tests/utils/test_processing_utils.py

@@ -80,6 +80,18 @@ class ProcessingUtilTester(unittest.TestCase):
self.assertTrue(np.array_equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
# list of strings and list of url images inputs
images = ["https://url1", "https://url2"]
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# list of strings and nested list of numpy images inputs
images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
text = ["text1", "text2"]