mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Paligemma support for multi-image (#33447)
* upadte * Update src/transformers/models/paligemma/processing_paligemma.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * update docs * better example in tests * support image tokens * read token * Update tests/models/paligemma/test_processing_paligemma.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * nit: naming * Update docs/source/en/model_doc/paligemma.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * conflicts after rebasing --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
parent
55b7a0404e
commit
3e039d3827
@ -29,7 +29,20 @@ This model was contributed by [Molbap](https://huggingface.co/Molbap).
|
||||
|
||||
## Usage tips
|
||||
|
||||
Inference with PaliGemma can be performed as follows:
|
||||
- PaliGemma is not meant for conversational use, and it works best when fine-tuning to a specific use case. Some downstream tasks on which PaliGemma can be fine-tuned include image captioning, visual question answering (VQA), object detection, referring expression segmentation and document understanding.
|
||||
- One can use `PaliGemmaProcessor` to prepare images, text and optional labels for the model. When fine-tuning a PaliGemma model, the `suffix` argument can be passed to the processor which creates the `labels` for the model:
|
||||
|
||||
```python
|
||||
prompt = "What is on the flower?"
|
||||
answer = "a bee"
|
||||
inputs = processor(images=raw_image, text=prompt, suffix=answer, return_tensors="pt")
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
|
||||
The model can accept a single or multiple images. According to the [paper](https://arxiv.org/abs/2407.07726v1), the checkpoint PaliGemma can transfer to tasks which take multiple images as input. NLVR2 is one such task, which asks one question about two images, and requires looking at both to give the correct answer. Here's an example code for single and multi image inference.
|
||||
|
||||
### Single-image Inference
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
@ -44,16 +57,31 @@ raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
inputs = processor(raw_image, prompt, return_tensors="pt")
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
```
|
||||
|
||||
- PaliGemma is not meant for conversational use, and it works best when fine-tuning to a specific use case. Some downstream tasks on which PaliGemma can be fine-tuned include image captioning, visual question answering (VQA), object detection, referring expression segmentation and document understanding.
|
||||
- One can use `PaliGemmaProcessor` to prepare images, text and optional labels for the model. When fine-tuning a PaliGemma model, the `suffix` argument can be passed to the processor which creates the `labels` for the model:
|
||||
### Multi-image Inference
|
||||
|
||||
```python
|
||||
prompt = "What is on the flower?"
|
||||
answer = "a bee"
|
||||
inputs = processor(images=raw_image, text=prompt, suffix=answer, return_tensors="pt")
|
||||
model_id = "google/paligemma-3b-ft-nlvr2-448" # checkpoint tuned for multiple images
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
|
||||
processor = PaliGemmaProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "answer en Which of the two pictures shows a snowman, first or second?"
|
||||
stop_sign_image = Image.open(
|
||||
requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
|
||||
)
|
||||
snow_image = Image.open(
|
||||
requests.get(
|
||||
"https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", stream=True
|
||||
).raw
|
||||
)
|
||||
|
||||
inputs = processor(images=[[snow_image, stop_sign_image]], text=prompt, return_tensors="pt")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
@ -77,7 +77,7 @@ def _is_str_or_image(elem):
|
||||
return isinstance(elem, (str)) or is_image_or_image_url(elem)
|
||||
|
||||
|
||||
def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
|
||||
def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
|
||||
"""
|
||||
Builds a string from the input prompt and image tokens.
|
||||
For example, for the call:
|
||||
@ -94,8 +94,33 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
|
||||
bos_token (`str`): The beginning of sentence token.
|
||||
image_seq_len (`int`): The length of the image sequence.
|
||||
image_token (`str`): The image token.
|
||||
num_images (`int`): Number of images in the prompt.
|
||||
"""
|
||||
return f"{image_token * image_seq_len}{bos_token}{prompt}\n"
|
||||
return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
|
||||
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images
|
||||
def make_batched_images(images) -> List[List[ImageInput]]:
|
||||
"""
|
||||
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
||||
|
||||
Args:
|
||||
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
||||
The input image.
|
||||
|
||||
Returns:
|
||||
list: A list of images.
|
||||
"""
|
||||
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
||||
return [img for img_list in images for img in img_list]
|
||||
|
||||
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
||||
return images
|
||||
|
||||
elif is_valid_image(images):
|
||||
return [images]
|
||||
|
||||
raise ValueError(f"Could not make batched video from {images}")
|
||||
|
||||
|
||||
class PaliGemmaProcessor(ProcessorMixin):
|
||||
@ -230,29 +255,53 @@ class PaliGemmaProcessor(ProcessorMixin):
|
||||
)
|
||||
text = ""
|
||||
|
||||
if isinstance(text, List) and isinstance(images, List):
|
||||
if len(images) < len(text):
|
||||
raise ValueError(
|
||||
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
|
||||
)
|
||||
if _is_str_or_image(text):
|
||||
text = [text]
|
||||
elif isinstance(text, list) and _is_str_or_image(text[0]):
|
||||
pass
|
||||
if suffix is not None and _is_str_or_image(suffix):
|
||||
suffix = [suffix]
|
||||
if suffix is not None:
|
||||
suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
|
||||
|
||||
input_strings = [
|
||||
build_string_from_input(
|
||||
prompt=prompt,
|
||||
bos_token=self.tokenizer.bos_token,
|
||||
image_seq_len=self.image_seq_length,
|
||||
image_token=IMAGE_TOKEN,
|
||||
)
|
||||
for prompt in text
|
||||
]
|
||||
if text is not None and images is not None:
|
||||
if not any(IMAGE_TOKEN in sample for sample in text):
|
||||
logger.warning(
|
||||
"You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special "
|
||||
"image tokens in the text, as many tokens as there are images per each text. It is recommended to "
|
||||
"add `<image>` tokens in the very beginning of your text and `<bos>` token after that. For this call, we will infer how many images "
|
||||
"each text has and add special tokens."
|
||||
)
|
||||
|
||||
if isinstance(text, List) and isinstance(images, List):
|
||||
if len(images) != len(text):
|
||||
raise ValueError(
|
||||
f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
|
||||
)
|
||||
|
||||
# make a nested list of lists to be able to iterate over the images and text below
|
||||
if is_valid_image(images):
|
||||
images = [[images]]
|
||||
elif isinstance(images, list) and is_valid_image(images[0]):
|
||||
images = [[image] for image in images]
|
||||
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
|
||||
raise ValueError("images must be an image, list of images or list of list of images")
|
||||
|
||||
if suffix is not None and _is_str_or_image(suffix):
|
||||
suffix = [suffix]
|
||||
if suffix is not None:
|
||||
suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
|
||||
|
||||
input_strings = [
|
||||
build_string_from_input(
|
||||
prompt=prompt,
|
||||
bos_token=self.tokenizer.bos_token,
|
||||
image_seq_len=self.image_seq_length,
|
||||
image_token=IMAGE_TOKEN,
|
||||
num_images=len(image_list) if isinstance(image_list, list) else 1,
|
||||
)
|
||||
for prompt, image_list in zip(text, images)
|
||||
]
|
||||
images = make_batched_images(images)
|
||||
else:
|
||||
text = [sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length) for sample in text]
|
||||
input_strings = [f"{sample}\n" for sample in text]
|
||||
|
||||
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
|
||||
|
||||
|
@ -326,8 +326,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@ -349,8 +347,40 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_multiimage(self):
|
||||
model_id = "google/paligemma-3b-ft-nlvr2-448" # checkpoint tuned for multiple images
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
|
||||
processor = PaliGemmaProcessor.from_pretrained(model_id)
|
||||
prompt = "answer en There is no snowman in any of the images. Is this true or false?"
|
||||
stop_sign_image = Image.open(
|
||||
requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
|
||||
)
|
||||
snow_image = Image.open(
|
||||
requests.get(
|
||||
"https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg", stream=True
|
||||
).raw
|
||||
)
|
||||
|
||||
inputs = processor(text=prompt, images=[[snow_image, snow_image]], return_tensors="pt")
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
EXPECTED_DECODED_TEXT = "answer en There is no snowman in any of the images. Is this true or false?\nFalse"
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
# try another prompt with two different image this time
|
||||
prompt = "answer en There is exactly one snowman. Is this true or false?"
|
||||
inputs = processor(text=prompt, images=[[snow_image, stop_sign_image]], return_tensors="pt")
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
EXPECTED_DECODED_TEXT = "answer en There is exactly one snowman. Is this true or false?\nTrue"
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
def test_small_model_integration_test_paligemma_VQA(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@ -370,8 +400,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_empty_prompt(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@ -392,8 +420,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_batched(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@ -420,9 +446,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_batched_bf16(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@ -452,9 +475,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_small_model_integration_test_paligemma_batched_f16(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
@ -485,9 +505,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_integration_detection_bug(self):
|
||||
# this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
|
||||
# impacted negatively segmentation generations.
|
||||
@ -511,8 +528,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
|
||||
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_paligemma_index_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
|
||||
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
|
||||
@ -536,9 +551,6 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_paligemma_finetuning_with_suffixes_bf16(self):
|
||||
# this is a supplementary test to ensure paligemma fine-tuning that relies on token_type_ids is robust to future changes
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
|
84
tests/models/paligemma/test_processing_paligemma.py
Normal file
84
tests/models/paligemma/test_processing_paligemma.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoProcessor, GemmaTokenizerFast, PaliGemmaProcessor
|
||||
from transformers.testing_utils import require_read_token, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import SiglipImageProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_read_token
|
||||
class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = PaliGemmaProcessor
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = SiglipImageProcessor(do_center_crop=False)
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("google/gemma-7b")
|
||||
image_processor.image_seq_length = 32
|
||||
|
||||
processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def test_text_with_image_tokens(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
text_multi_images = "<image><image><bos>Dummy text!"
|
||||
text_single_image = "<image><bos>Dummy text!"
|
||||
text_no_image = "Dummy text!"
|
||||
|
||||
image = self.prepare_image_inputs()[0]
|
||||
|
||||
out_noimage = processor(text=text_no_image, images=image, return_tensors="np")
|
||||
out_singlimage = processor(text=text_single_image, images=image, return_tensors="np")
|
||||
for k in out_noimage:
|
||||
self.assertTrue(out_noimage[k].tolist() == out_singlimage[k].tolist())
|
||||
|
||||
out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
|
||||
out_noimage = processor(text=text_no_image, images=[[image, image]], return_tensors="np")
|
||||
|
||||
# We can't be sure what is users intention, whether user want "one text + two images" or user forgot to add the second text
|
||||
with self.assertRaises(ValueError):
|
||||
out_noimage = processor(text=text_no_image, images=[image, image], return_tensors="np")
|
||||
|
||||
for k in out_noimage:
|
||||
self.assertTrue(out_noimage[k].tolist() == out_multiimages[k].tolist())
|
||||
|
||||
text_batched = ["Dummy text!", "Dummy text!"]
|
||||
text_batched_with_image = ["<image><bos>Dummy text!", "<image><bos>Dummy text!"]
|
||||
out_images = processor(text=text_batched_with_image, images=[image, image], return_tensors="np")
|
||||
out_noimage_nested = processor(text=text_batched, images=[[image], [image]], return_tensors="np")
|
||||
out_noimage = processor(text=text_batched, images=[image, image], return_tensors="np")
|
||||
for k in out_noimage:
|
||||
self.assertTrue(out_noimage[k].tolist() == out_images[k].tolist() == out_noimage_nested[k].tolist())
|
Loading…
Reference in New Issue
Block a user