Add image text to text pipeline (#34170)

* Standardize image-text-to-text models output

add post_process_image_text_to_text to chameleon and cleanup

Fix legacy kwarg behavior and deprecation warning

add post_process_image_text_to_text to qwen2_vl and llava_onevision

Add post_process_image_text_to_text to idefics3, mllama, pixtral processor

* nit var name post_process_image_text_to_text udop

* nit fix deprecation warnings

* Add image-text-to-text pipeline

* add support for image url in chat template for pipeline

* Reformat to be fully compatible with chat templates

* Add tests chat template

* Fix imports and tests

* Add pipeline tag

* change logic handling of single prompt and multiple images

* add pipeline mapping to models

* fix batched inference

* fix tests

* Add manual batching for preprocessing

* Fix outputs with nested images

* Add support for all common processing kwargs

* Add default padding when multiple text inputs (batch size>1)

* nit change version deprecation warning

* Add support for text only inference

* add chat_template warnings

* Add pipeline tests and add copied from post process function

* Fix batched pipeline tests

* nit

* Fix pipeline tests blip2

* remove unnecessary max_new_tokens

* revert processing kosmos2 and remove unnecessary max_new_tokens

* fix pipeline tests idefics

* Force try loading processor if pipeline supports it

* revert load_processor change

* hardcode loading only processor

* remove unnecessary try except

* skip imagetexttotext tests for kosmos2 as tiny model causes problems

* Make code clearer

* Address review comments

* remove preprocessing logic from pipeline

* fix fuyu

* add BC resize fuyu

* Move post_process_image_text_to_text to ProcessorMixin

* add guard in post_process

* fix zero shot object detection pipeline

* add support for generator input in pipeline

* nit

* change default image-text-to-text model to llava onevision

* fix owlv2 size dict

* Change legacy deprecation warning to only show when True
Yoni Gozlan 2024-10-31 15:48:11 -04:00 committed by GitHub
parent c443d8d536
commit 203e27059b
47 changed files with 988 additions and 33 deletions
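Before the file-by-file changes, here is a minimal usage sketch of the task this commit adds. It relies on the default checkpoint registered for the task below (`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`); the generated text is illustrative and will vary from run to run.

```python
from transformers import pipeline

# Minimal sketch of the new "image-text-to-text" task. The checkpoint is the
# default registered for the task in pipelines/__init__.py; any
# AutoModelForImageTextToText checkpoint can be substituted.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
outputs = pipe(text=messages, max_new_tokens=20)
print(outputs[0]["generated_text"])  # full chat, ending with a new assistant message
```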


@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following.
- __call__
- all
### ImageTextToTextPipeline
[[autodoc]] ImageTextToTextPipeline
- __call__
- all
### MaskGenerationPipeline
[[autodoc]] MaskGenerationPipeline


@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
- __call__
- all
### ImageTextToTextPipeline
[[autodoc]] ImageTextToTextPipeline
- __call__
- all
### VisualQuestionAnsweringPipeline
[[autodoc]] VisualQuestionAnsweringPipeline


@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all
### ImageTextToTextPipeline
[[autodoc]] ImageTextToTextPipeline
- __call__
- all
### MaskGenerationPipeline
[[autodoc]] MaskGenerationPipeline


@ -868,6 +868,7 @@ _import_structure = {
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageSegmentationPipeline",
"ImageTextToTextPipeline",
"ImageToImagePipeline",
"ImageToTextPipeline",
"JsonPipelineDataFormat",
@ -5794,6 +5795,7 @@ if TYPE_CHECKING:
ImageClassificationPipeline,
ImageFeatureExtractionPipeline,
ImageSegmentationPipeline,
ImageTextToTextPipeline,
ImageToImagePipeline,
ImageToTextPipeline,
JsonPipelineDataFormat,


@ -385,6 +385,27 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image
def load_images(
images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
"""Loads images, handling different levels of nesting.
Args:
images: A single image, a list of images, or a list of lists of images to load.
timeout: Timeout for loading images.
Returns:
A single image, a list of images, or a list of lists of images.
"""
if isinstance(images, (list, tuple)):
if len(images) and isinstance(images[0], (list, tuple)):
return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
else:
return [load_image(image, timeout=timeout) for image in images]
else:
return load_image(images, timeout=timeout)
def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
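The new `load_images` helper above preserves whatever level of nesting it receives. A small illustrative sketch (the local file names are placeholders):

```python
from transformers.image_utils import load_images

# A single source returns a single PIL image, a flat list returns a flat list,
# and a nested list (e.g. several images per prompt) keeps its nesting.
single = load_images("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
flat = load_images(["img_a.png", "img_b.png"])                      # [Image, Image]
nested = load_images([["img_a.png"], ["img_b.png", "img_c.png"]])   # [[Image], [Image, Image]]
```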


@ -114,6 +114,7 @@ else:
("oneformer", ("OneFormerImageProcessor",)),
("owlv2", ("Owlv2ImageProcessor",)),
("owlvit", ("OwlViTImageProcessor",)),
("paligemma", ("SiglipImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),


@ -24,12 +24,16 @@ from typing import List, Optional, Union
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
class DonutProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}
logger = logging.get_logger(__name__)
class DonutProcessor(ProcessorMixin):
r"""
Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single
@ -85,6 +89,16 @@ class DonutProcessor(ProcessorMixin):
[`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
"""
# For backward compatibility
legacy = kwargs.pop("legacy", True)
if legacy:
# With `add_special_tokens=True`, the performance of Donut is degraded when working with both images and text.
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, if both images and text are provided, the default value of `add_special_tokens` "
"will be changed to `False` when calling the tokenizer if `add_special_tokens` is unset. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)
if self._in_target_context_manager:
return self.current_processor(images, text, **kwargs)
@ -100,6 +114,8 @@ class DonutProcessor(ProcessorMixin):
if images is not None:
inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
if text is not None:
if not legacy and images is not None:
output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
if text is None:
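To illustrate the `legacy` flag introduced above, a hedged sketch of the two call styles; the checkpoint, prompt, and blank image are placeholders for illustration only.

```python
from PIL import Image
from transformers import DonutProcessor

# Placeholder checkpoint and image, just for the sketch.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
image = Image.new("RGB", (224, 224))

# Default (legacy) behavior: emits the deprecation warning above, and the
# tokenizer keeps add_special_tokens=True when both images and text are passed.
legacy_inputs = processor(images=image, text="parse this receipt", return_tensors="pt")

# New behavior: with both images and text, add_special_tokens defaults to False.
new_inputs = processor(images=image, text="parse this receipt", legacy=False, return_tensors="pt")
```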


@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
pad,
resize,
@ -475,6 +475,7 @@ class FuyuImageProcessor(BaseImageProcessor):
input_data_format = infer_channel_dimension_format(batch_images[0][0])
original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
size = get_size_dict(size) # for BC
if do_resize:
batch_images = [


@ -264,10 +264,10 @@ def _tokenize_prompts_with_image_and_batch(
bos_token = tokenizer.vocab["|ENDOFTEXT|"]
prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
if add_beginning_of_answer_token:
boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
# Only add bbox open token to the last subsequence since that is what will be completed
for token_seq in prompts_tokens:
token_seq[-1].append(boa)
token_seq[-1].append(beginning_of_answer)
# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
@ -682,6 +682,32 @@ class FuyuProcessor(ProcessorMixin):
return results
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
containing the token ids of the generated sequences.
Returns:
`List[str]`: The decoded text output.
"""
beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
# get boa index for each outputted sequence tensor
# start all generated sequences from the beginning of the answer token, pad to have consistent length
unpadded_output_sequences = [
seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
]
max_len = max(len(seq) for seq in unpadded_output_sequences)
# convert to torch and pad sequences
padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
for i, seq in enumerate(unpadded_output_sequences):
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)
return self.batch_decode(padded_output_sequences, skip_special_tokens=True)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please


@ -22,12 +22,16 @@ from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
class GitProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}
logger = logging.get_logger(__name__)
class GitProcessor(ProcessorMixin):
r"""
Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.
@ -91,6 +95,15 @@ class GitProcessor(ProcessorMixin):
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, if both images and text are provided, the last token (EOS token) "
"of the input_ids and attention_mask tensors will be removed. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)
if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")
@ -110,6 +123,10 @@ class GitProcessor(ProcessorMixin):
if images is not None:
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data.update(image_features)
if not legacy:
data["input_ids"] = data["input_ids"][:, :-1]
data["attention_mask"] = data["attention_mask"][:, :-1]
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))
def batch_decode(self, *args, **kwargs):


@ -428,6 +428,21 @@ class Kosmos2Processor(ProcessorMixin):
return clean_text_and_extract_entities_with_bboxes(caption)
return caption
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]
@property
# Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
def model_input_names(self):


@ -342,6 +342,22 @@ class MllamaProcessor(ProcessorMixin):
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names


@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_to_corners_format,
pad,
@ -399,6 +399,7 @@ class Owlv2ImageProcessor(BaseImageProcessor):
image_std = image_std if image_std is not None else self.image_std
size = size if size is not None else self.size
size = get_size_dict(size) # for BC
images = make_list_of_images(images)


@ -21,6 +21,7 @@ from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import logging
class Pix2StructImagesKwargs(ImagesKwargs, total=False):
@ -48,6 +49,9 @@ class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
}
logger = logging.get_logger(__name__)
class Pix2StructProcessor(ProcessorMixin):
r"""
Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single
@ -85,6 +89,15 @@ class Pix2StructProcessor(ProcessorMixin):
Please refer to the docstring of the above two methods for more information.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
logger.warning_once(
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
"In the new behavior, If both images and text are provided, image_processor is not a VQA processor, and `add_special_tokens` is unset, "
"the default value of `add_special_tokens` will be changed to `False` when calling the tokenizer. "
"To test the new behavior, set `legacy=False`as a processor call argument."
)
if images is None and text is None:
raise ValueError("You have to specify either images or text.")
@ -93,8 +106,12 @@ class Pix2StructProcessor(ProcessorMixin):
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None)
# Get only text
if images is None and not self.image_processor.is_vqa:
output_kwargs["text_kwargs"]["add_special_tokens"] = (
add_special_tokens if add_special_tokens is not None else True
)
self.current_processor = self.tokenizer
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
return text_encoding
@ -108,6 +125,9 @@ class Pix2StructProcessor(ProcessorMixin):
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
if text is not None and not self.image_processor.is_vqa:
output_kwargs["text_kwargs"]["add_special_tokens"] = (
add_special_tokens if add_special_tokens is not None else legacy
)
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
if "attention_mask" in text_encoding:


@ -168,6 +168,22 @@ class Qwen2VLProcessor(ProcessorMixin):
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names


@ -208,20 +208,6 @@ class UdopProcessor(ProcessorMixin):
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
@property
def model_input_names(self):
return ["pixel_values", "input_ids", "bbox", "attention_mask"]


@ -67,6 +67,7 @@ from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .image_feature_extraction import ImageFeatureExtractionPipeline
from .image_segmentation import ImageSegmentationPipeline
from .image_text_to_text import ImageTextToTextPipeline
from .image_to_image import ImageToImagePipeline
from .image_to_text import ImageToTextPipeline
from .mask_generation import MaskGenerationPipeline
@ -119,6 +120,7 @@ if is_torch_available():
AutoModelForDocumentQuestionAnswering,
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForImageTextToText,
AutoModelForMaskedLM,
AutoModelForMaskGeneration,
AutoModelForObjectDetection,
@ -384,6 +386,17 @@ SUPPORTED_TASKS = {
},
"type": "multimodal",
},
"image-text-to-text": {
"impl": ImageTextToTextPipeline,
"tf": (),
"pt": (AutoModelForImageTextToText,) if is_torch_available() else (),
"default": {
"model": {
"pt": ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "2c9ba3b"),
}
},
"type": "multimodal",
},
"object-detection": {
"impl": ObjectDetectionPipeline,
"tf": (),
@ -601,6 +614,7 @@ def pipeline(
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
- `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
- `"image-text-to-text"`: will return a [`ImageTextToTextPipeline`].
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].


@ -951,6 +951,14 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
self._num_workers = kwargs.pop("num_workers", None)
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
# In processor only mode, we can get the modality processors from the processor
if self.processor is not None and all(
[self.tokenizer is None, self.feature_extractor is None, self.image_processor is None]
):
self.tokenizer = getattr(self.processor, "tokenizer", None)
self.feature_extractor = getattr(self.processor, "feature_extractor", None)
self.image_processor = getattr(self.processor, "image_processor", None)
if self.image_processor is None and self.feature_extractor is not None:
if isinstance(self.feature_extractor, BaseImageProcessor):
# Backward compatible change, if users called
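A hedged sketch of what the processor-only branch above enables: constructing the new pipeline from a model plus a single processor, and letting `Pipeline.__init__` derive the tokenizer and image processor from it. The checkpoint is chosen for illustration, and the top-level export of `AutoModelForImageTextToText` is assumed.

```python
from transformers import AutoModelForImageTextToText, AutoProcessor, ImageTextToTextPipeline

model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"  # illustrative checkpoint
model = AutoModelForImageTextToText.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Only `processor` is passed; the base class now fills in the tokenizer and
# image processor from the processor's attributes.
pipe = ImageTextToTextPipeline(model=model, processor=processor)
assert pipe.tokenizer is processor.tokenizer
assert pipe.image_processor is processor.image_processor
```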


@ -0,0 +1,416 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from typing import Dict, List, Optional, Union
from ..processing_utils import ProcessingKwargs, Unpack
from ..utils import (
add_end_docstrings,
is_torch_available,
is_vision_available,
logging,
requires_backends,
)
from .base import Pipeline, build_pipeline_init_args
if is_vision_available():
from PIL import Image
from ..image_utils import load_images, valid_images
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from .pt_utils import KeyDataset
logger = logging.get_logger(__name__)
IMAGE_TOKEN = "<image>"
class ReturnType(enum.Enum):
TENSORS = 0
NEW_TEXT = 1
FULL_TEXT = 2
class Chat:
"""This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
to this format because the rest of the pipeline code tends to assume that lists of messages are
actually a batch of samples rather than messages in the same conversation."""
def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image", List["Image.Image"]]):
for message in messages:
if not ("role" in message and "content" in message):
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
images = retrieve_images_in_messages(messages, images)
self.messages = messages
self.images = images
def retrieve_images_in_messages(
messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
):
"""
Retrieve and combine images from the chat and the images passed as input.
"""
if images is None:
images = []
idx_images = 0
retrieved_images = []
for message in messages:
for content in message["content"]:
if isinstance(content, dict) and content.get("type") == "image":
if "image" in content:
retrieved_images.append(content["image"])
elif idx_images < len(images):
retrieved_images.append(images[idx_images])
idx_images += 1
else:
raise ValueError(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
)
# The number of images passed should be consistent with the number of images in the chat without an image key
if idx_images != len(images):
raise ValueError(
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
)
return retrieved_images
@add_end_docstrings(build_pipeline_init_args(has_processor=True))
class ImageTextToTextPipeline(Pipeline):
"""
Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text.
When the underlying model is a conversational model, it can also accept one or more chats,
in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.
Example:
```python
>>> from transformers import pipeline
>>> pipe = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base")
>>> pipe("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", text="A photo of")
[{'generated_text': 'a photo of two birds'}]
```
```python
>>> from transformers import pipeline
>>> pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
>>> messages = [
>>> {
>>> "role": "user",
>>> "content": [
>>> {
>>> "type": "image",
>>> "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
>>> },
>>> {"type": "text", "text": "Describe this image."},
>>> ],
>>> },
>>> {
>>> "role": "assistant",
>>> "content": [
>>> {"type": "text", "text": "There is a dog and"},
>>> ],
>>> },
>>> ]
>>> pipe(text=messages, max_new_tokens=20, return_full_text=False)
[{'input_text': [{'role': 'user',
'content': [{'type': 'image',
'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
{'type': 'text', 'text': 'Describe this image.'}]},
{'role': 'assistant',
'content': [{'type': 'text', 'text': 'There is a dog and'}]}],
'generated_text': ' a person in the image. The dog is sitting on the sand, and the person is sitting on'}]
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
This image-text-to-text pipeline can currently be loaded from pipeline() using the following task identifier:
"image-text-to-text".
See the list of available models on
[huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-text-to-text).
"""
_load_processor = True
_load_image_processor = False
_load_feature_extractor = False
_load_tokenizer = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
def _sanitize_parameters(
self,
max_new_tokens=None,
generate_kwargs=None,
timeout=None,
return_full_text=None,
return_tensors=None,
return_type=None,
continue_final_message=None,
**kwargs: Unpack[ProcessingKwargs],
):
forward_kwargs = {}
preprocess_params = {}
postprocess_params = {}
preprocess_params["processing_kwargs"] = kwargs
if timeout is not None:
preprocess_params["timeout"] = timeout
if continue_final_message is not None:
preprocess_params["continue_final_message"] = continue_final_message
if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
if max_new_tokens is not None:
if "generate_kwargs" not in forward_kwargs:
forward_kwargs["generate_kwargs"] = {}
if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
raise ValueError(
"'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
" please use only one"
)
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
if return_full_text is not None and return_type is None:
if return_tensors is not None:
raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
if return_tensors is not None and return_type is None:
return_type = ReturnType.TENSORS
if return_type is not None:
postprocess_params["return_type"] = return_type
if continue_final_message is not None:
postprocess_params["continue_final_message"] = continue_final_message
return preprocess_params, forward_kwargs, postprocess_params
def __call__(
self,
images: Optional[
Union[str, List[str], List[List[str]], "Image.Image", List["Image.Image"], List[List["Image.Image"]]]
] = None,
text: Optional[Union[str, List[str], List[dict]]] = None,
**kwargs,
):
"""
Generate a text given text and the image(s) passed as inputs.
Args:
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
The pipeline handles three types of images:
- A string containing a HTTP(s) link pointing to an image
- A string containing a local path to an image
- An image loaded in PIL directly
The pipeline accepts either a single image or a batch of images.
text (str, List[str], `List[Dict[str, Union[str, PIL.Image]]]`):
The text to be used for generation. If a list of strings is passed, the length of the list should be the
same as the number of images. Text can also follow the chat format: a list of dictionaries where each
dictionary represents a message in a conversation. Each dictionary should have two keys: 'role' and
'content'. 'role' should be one of 'user', 'system' or 'assistant'. 'content' should be a list of dictionaries
containing the text of the message and the type of the message. The type of the message can be either
'text' or 'image'. If the type is 'image', no text is needed.
return_tensors (`bool`, *optional*, defaults to `False`):
Returns the tensors of predictions (as token indices) in the outputs. If set to
`True`, the decoded text is not returned.
return_text (`bool`, *optional*):
Returns the decoded texts in the outputs.
return_full_text (`bool`, *optional*, defaults to `True`):
If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
specified at the same time as `return_text`.
continue_final_message (`bool`, *optional*): This indicates that you want the model to continue the
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
By default this is `True` when the final message in the input chat has the `assistant` role and
`False` otherwise, but you can manually override that behaviour by setting this flag.
Return:
A list or a list of lists of `dict`: Each result comes as a dictionary with the following keys (cannot return a combination
of both `generated_text` and `generated_token_ids`):
- **generated_text** (`str`, present when `return_text=True`) -- The generated text.
- **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token
ids of the generated text.
- **input_text** (`str`) -- The input text.
"""
if images is None and text is None:
raise ValueError("You must at least provide either text or images.")
if images is not None and text is None and not valid_images(images):
"""
Supports the following format
- {"image": image, "text": text}
- [{"image": image, "text": text}]
- Generator and datasets
This is a common pattern in other multimodal pipelines, so we support it here as well.
"""
return super().__call__(images, **kwargs)
if isinstance(text, (list, tuple, KeyDataset)) and isinstance(text[0], (list, tuple, dict)):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(text[0], dict):
return super().__call__(Chat(text, images), **kwargs)
else:
if images is None:
images = [None] * len(text)
chats = [Chat(chat, image) for chat, image in zip(text, images)] # 🐈 🐈 🐈
return super().__call__(chats, **kwargs)
# encourage the user to use the chat format if supported
if getattr(self.processor, "chat_template", None) is not None:
logger.warning_once(
"The input data was not formatted as a chat with dicts containing 'role' and 'content' keys, even though this model supports chat. "
"Consider using the chat format for better results. For more information, see https://huggingface.co/docs/transformers/en/chat_templating"
)
# support text only generation
if images is None:
return super().__call__(text, **kwargs)
if text is None:
raise ValueError("You must provide text for this pipeline.")
return super().__call__({"images": images, "text": text}, **kwargs)
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
# In case we only have text inputs
if isinstance(inputs, (list, tuple, str)):
images = None
text = inputs
inputs_text = inputs
else:
if isinstance(inputs, Chat):
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
# because very few models support multiple separate, consecutive assistant messages
if continue_final_message is None:
continue_final_message = inputs.messages[-1]["role"] == "assistant"
text = self.processor.apply_chat_template(
inputs.messages,
add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_tensors=self.framework,
)
inputs_text = inputs
images = inputs.images
else:
text = inputs["text"]
inputs_text = inputs["text"]
images = inputs["images"]
images = load_images(images)
# if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1:
processing_kwargs.setdefault("padding", True)
model_inputs = self.processor(
images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
).to(dtype=self.torch_dtype)
model_inputs["text"] = inputs_text
return model_inputs
def _forward(self, model_inputs, generate_kwargs=None):
generate_kwargs = {} if generate_kwargs is None else generate_kwargs
prompt_text = model_inputs.pop("text")
input_ids = (
model_inputs["input_ids"] if "input_ids" in model_inputs else model_inputs["decoder_input_ids"]
) # for decoder-only models
generated_sequence = self.model.generate(**model_inputs, **generate_kwargs)
return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
input_texts = model_outputs["prompt_text"]
input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
generated_sequence = model_outputs["generated_sequence"]
input_ids = model_outputs["input_ids"]
if return_type == ReturnType.TENSORS:
return [
{"input_text": input_texts[i], "generated_token_ids": generated_sequence[i]}
for i in range(len(input_texts))
]
# Decode inputs and outputs the same way to remove input text from generated text if present
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)
# Force consistent behavior for including the input text in the output
if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
# Remove the input text from the generated text if the generated text starts with the input text
# (accounting for the possibility of a space between the input and generated text)
new_generated_texts = []
for text_generated, decoded_input in zip(generated_texts, decoded_inputs):
# There can be added characters before the input text, so we need to find the beginning of the input text in the generated text
index_input_text = text_generated.find(decoded_input)
# Limit the search to 2 residual characters, like spaces or new lines, to avoid removing a large part of the answer
if 0 <= index_input_text <= 2:
# If the input text is found, we remove it
new_generated_texts.append(text_generated[index_input_text + len(decoded_input) :])
else:
new_generated_texts.append(text_generated)
generated_texts = new_generated_texts
if return_type == ReturnType.FULL_TEXT:
full_texts = []
for prompt_text, generated_text in zip(input_texts, generated_texts):
if isinstance(prompt_text, str):
generated_text = prompt_text + generated_text
elif isinstance(prompt_text, Chat):
if continue_final_message is None:
# If the user passes a chat ending in an assistant message, we treat it as a prefill by
# default because very few models support multiple separate, consecutive assistant messages
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
if continue_final_message:
# With assistant prefill, concat onto the end of the last message
new_text = dict(prompt_text.messages[-1]["content"][-1].items())
new_text["text"] += generated_text
generated_text = list(prompt_text.messages)[:-1] + [
{
"role": prompt_text.messages[-1]["role"],
"content": prompt_text.messages[-1]["content"][:-1] + [new_text],
}
]
else:
# When we're not starting from a prefill, the output is a new assistant message
generated_text = list(prompt_text.messages) + [
{"role": "assistant", "content": generated_text}
]
full_texts.append(generated_text)
generated_texts = full_texts
records = [
{
"input_text": input_text.messages if isinstance(input_text, Chat) else input_text,
"generated_text": generated_text,
}
for input_text, generated_text in zip(input_texts, generated_texts)
]
return records
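A short sketch of the three return modes handled by `postprocess` above; the checkpoint and prompt are illustrative and outputs will vary.

```python
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
image = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
prompt = "<image> What do you see? Assistant:"

full = pipe(image, text=prompt, max_new_tokens=10)                              # prompt + completion
new_only = pipe(image, text=prompt, max_new_tokens=10, return_full_text=False)  # completion only
tokens = pipe(image, text=prompt, max_new_tokens=10, return_tensors=True)       # generated_token_ids
```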


@ -134,6 +134,10 @@ class ImageToTextPipeline(Pipeline):
image = load_image(image, timeout=timeout)
if prompt is not None:
logger.warning_once(
"Passing `prompt` to the `image-to-text` pipeline is deprecated and will be removed in version 4.48"
" of 🤗 Transformers. Use the `image-text-to-text` pipeline instead",
)
if not isinstance(prompt, str):
raise ValueError(
f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "


@ -7,7 +7,7 @@ from .base import ChunkPipeline, build_pipeline_init_args
if is_vision_available():
from PIL import Image
from ..image_utils import load_image
from ..image_utils import load_image, valid_images
if is_torch_available():
import torch
@ -130,8 +130,23 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
if isinstance(image, (str, Image.Image)):
inputs = {"image": image, "candidate_labels": candidate_labels}
elif isinstance(image, (list, tuple)) and valid_images(image):
return list(
super().__call__(
({"image": img, "candidate_labels": labels} for img, labels in zip(image, candidate_labels)),
**kwargs,
)
)
else:
"""
Supports the following format
- {"image": image, "candidate_labels": candidate_labels}
- [{"image": image, "candidate_labels": candidate_labels}]
- Generator and datasets
This is a common pattern in other multimodal pipelines, so we support it here as well.
"""
inputs = image
results = super().__call__(inputs, **kwargs)
return results
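The new list branch above lets the zero-shot object detection pipeline take a batch of images together with per-image candidate labels; a hedged sketch under the assumption of the checkpoint and labels below.

```python
import requests
from PIL import Image
from transformers import pipeline

detector = pipeline("zero-shot-object-detection", model="google/owlvit-base-patch32")
url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
image = Image.open(requests.get(url, stream=True).raw)

# One list of candidate labels per image; the pipeline zips them with the images.
results = detector([image, image], candidate_labels=[["bird", "branch"], ["bird", "beak"]])
```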


@ -1107,6 +1107,20 @@ class ProcessorMixin(PushToHubMixin):
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of a VLM to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
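For reference, a self-contained sketch of the decode flow this default enables (checkpoint chosen for illustration; processors may override the method with model-specific logic, as several do in this PR):

```python
import requests
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id)

url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="<image> What do you see?", return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=20)
# The new ProcessorMixin default simply batch-decodes, skipping special tokens.
print(processor.post_process_image_text_to_text(generated_ids))
```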
def _validate_images_text_input_order(images, text):
"""


@ -436,6 +436,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"feature-extraction": BlipModel,
"image-to-text": BlipForConditionalGeneration,
"visual-question-answering": BlipForQuestionAnswering,
"image-text-to-text": BlipForConditionalGeneration,
}
if is_torch_available()
else {}


@ -767,6 +767,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
"feature-extraction": Blip2Model,
"image-to-text": Blip2ForConditionalGeneration,
"visual-question-answering": Blip2ForConditionalGeneration,
"image-text-to-text": Blip2ForConditionalGeneration,
}
if is_torch_available()
else {}


@ -276,6 +276,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
{
"feature-extraction": ChameleonModel,
"text-generation": ChameleonForConditionalGeneration,
"image-text-to-text": ChameleonForConditionalGeneration,
}
if is_torch_available()
else {}


@ -265,7 +265,9 @@ class FuyuModelTester:
@require_torch
class FuyuModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = {"text-generation": FuyuForCausalLM} if is_torch_available() else {}
pipeline_model_mapping = (
{"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {}
)
test_head_masking = False
test_pruning = False


@ -401,7 +401,12 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": GitModel, "image-to-text": GitForCausalLM, "text-generation": GitForCausalLM}
{
"feature-extraction": GitModel,
"image-to-text": GitForCausalLM,
"text-generation": GitForCausalLM,
"image-text-to-text": GitForCausalLM,
}
if is_torch_available()
else {}
)


@ -332,7 +332,11 @@ class IdeficsModelTester:
@require_torch
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": IdeficsModel} if is_torch_available() else {}
pipeline_model_mapping = (
{"feature-extraction": IdeficsModel, "image-text-to-text": IdeficsForVisionText2Text}
if is_torch_available()
else {}
)
test_pruning = False
test_headmasking = False
test_torchscript = False


@ -375,6 +375,7 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": Idefics2ForConditionalGeneration} if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True


@ -317,6 +317,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": Idefics3ForConditionalGeneration} if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True


@ -455,6 +455,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester:
@require_torch
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration}
fx_compatible = False
test_head_masking = False
test_pruning = False


@ -257,7 +257,11 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Kosmos2ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": Kosmos2Model, "image-to-text": Kosmos2ForConditionalGeneration}
{
"feature-extraction": Kosmos2Model,
"image-to-text": Kosmos2ForConditionalGeneration,
"image-text-to-text": Kosmos2ForConditionalGeneration,
}
if is_torch_available()
else {}
)
@ -269,6 +273,7 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
_is_composite = True
# TODO: `image-to-text` pipeline for this model needs Processor.
# TODO: Tiny model needs fixing for `image-text-to-text` (latent_query_num=3 not compatible with num_image_tokens=64).
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
@ -279,7 +284,10 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
feature_extractor_name,
processor_name,
):
return pipeline_test_case_name == "ImageToTextPipelineTests"
return (
pipeline_test_case_name == "ImageToTextPipelineTests"
or pipeline_test_case_name == "ImageTextToTextPipelineTests"
)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)


@ -183,7 +183,11 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {}
pipeline_model_mapping = (
{"image-to-text": LlavaForConditionalGeneration, "image-text-to-text": LlavaForConditionalGeneration}
if is_torch_available()
else {}
)
test_pruning = False
test_head_masking = False
_is_composite = True


@ -216,6 +216,7 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": LlavaNextForConditionalGeneration} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
_is_composite = True


@ -217,6 +217,9 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
all_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
{"image-text-to-text": LlavaOnevisionForConditionalGeneration} if is_torch_available() else {}
)
test_pruning = False
test_head_masking = False
_is_composite = True


@ -264,6 +264,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": MllamaForConditionalGeneration} if is_torch_available() else ()
test_pruning = False
test_head_masking = False
test_torchscript = False


@ -183,6 +183,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": PaliGemmaForConditionalGeneration}
fx_compatible = False
test_pruning = False
test_torchscript = False


@ -420,7 +420,11 @@ class Pix2StructModelTester:
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
pipeline_model_mapping = (
{"image-to-text": Pix2StructForConditionalGeneration, "image-text-to-text": Pix2StructForConditionalGeneration}
if is_torch_available()
else {}
)
fx_compatible = False
test_head_masking = False
test_pruning = False


@ -224,6 +224,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
test_pruning = False
test_head_masking = False


@ -275,7 +275,11 @@ class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
all_generative_model_classes = (UdopForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": UdopModel} if is_torch_available() else {}
pipeline_model_mapping = (
{"feature-extraction": UdopModel, "image-text-to-text": UdopForConditionalGeneration}
if is_torch_available()
else {}
)
fx_compatible = False
test_pruning = False
test_torchscript = False


@ -170,6 +170,7 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-text-to-text": VipLlavaForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_pruning = False
test_resize_embeddings = True


@ -0,0 +1,260 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
from transformers.pipelines import ImageTextToTextPipeline, pipeline
from transformers.testing_utils import (
is_pipeline_test,
require_torch,
require_vision,
slow,
)
from .test_pipelines_common import ANY
if is_vision_available():
from PIL import Image
else:
class Image:
@staticmethod
def open(*args, **kwargs):
pass
@is_pipeline_test
@require_vision
class ImageTextToTextPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, image_processor, torch_dtype="float32"):
pipe = ImageTextToTextPipeline(model=model, processor=processor, torch_dtype=torch_dtype)
image_token = getattr(processor.tokenizer, "image_token", "")
examples = [
{
"images": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
"text": f"{image_token}This is a ",
},
{
"images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
"text": f"{image_token}Here I see a ",
},
]
return pipe, examples
def run_pipeline_test(self, pipe, examples):
outputs = pipe(examples[0].get("images"), text=examples[0].get("text"))
self.assertEqual(
outputs,
[
{"input_text": ANY(str), "generated_text": ANY(str)},
],
)
@require_torch
def test_small_model_pt_token(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
text = "<image> What this is? Assistant: This is"
outputs = pipe(image, text=text)
self.assertEqual(
outputs,
[
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
}
],
)
outputs = pipe([image, image], text=[text, text])
self.assertEqual(
outputs,
[
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
},
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
},
],
)
@require_torch
def test_consistent_batching_behaviour(self):
pipe = pipeline("image-text-to-text", model="microsoft/kosmos-2-patch14-224")
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
prompt = "a photo of"
outputs = pipe([image, image], text=[prompt, prompt])
outputs_batched = pipe([image, image], text=[prompt, prompt], batch_size=2)
self.assertEqual(outputs, outputs_batched)
@slow
@require_torch
def test_model_pt_chat_template(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
image_ny = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
image_chicago = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between these two images?"},
{"type": "image"},
{"type": "image"},
],
}
]
outputs = pipe([image_ny, image_chicago], text=messages)
self.assertEqual(
outputs,
[
{
"input_text": [
{
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between these two images?"},
{"type": "image"},
{"type": "image"},
],
}
],
"generated_text": [
{
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between these two images?"},
{"type": "image"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows",
},
],
}
],
)
@slow
@require_torch
def test_model_pt_chat_template_continue_final_message(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "There is a dog and"},
],
},
]
outputs = pipe(text=messages)
self.assertEqual(
outputs,
[
{
"input_text": [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
},
{"role": "assistant", "content": [{"type": "text", "text": "There is a dog and"}]},
],
"generated_text": [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
},
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "There is a dog and a person in the image. The dog is sitting on the sand, and the person is sitting on",
}
],
},
],
}
],
)
@slow
@require_torch
def test_model_pt_chat_template_new_text(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
}
]
outputs = pipe(text=messages, return_full_text=False)
self.assertEqual(
outputs,
[
{
"input_text": [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
}
],
"generated_text": "In the image, a woman is sitting on the sandy beach, her legs crossed in a relaxed manner",
}
],
)


@ -14,7 +14,12 @@
import unittest
from transformers import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, is_vision_available, pipeline
from transformers import (
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
ZeroShotObjectDetectionPipeline,
is_vision_available,
pipeline,
)
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
@ -52,9 +57,11 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase):
processor=None,
torch_dtype="float32",
):
object_detector = pipeline(
"zero-shot-object-detection",
model="hf-internal-testing/tiny-random-owlvit-object-detection",
object_detector = ZeroShotObjectDetectionPipeline(
model=model,
processor=processor,
tokenizer=tokenizer,
image_processor=image_processor,
torch_dtype=torch_dtype,
)
@ -67,7 +74,7 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase):
return object_detector, examples
def run_pipeline_test(self, object_detector, examples):
outputs = object_detector(examples[0], threshold=0.0)
outputs = object_detector(examples[0].get("image"), examples[0].get("candidate_labels"), threshold=0.0)
n = len(outputs)
self.assertGreater(n, 0)


@ -71,6 +71,7 @@ from .pipelines.test_pipelines_fill_mask import FillMaskPipelineTests
from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
from .pipelines.test_pipelines_image_feature_extraction import ImageFeatureExtractionPipelineTests
from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
from .pipelines.test_pipelines_image_text_to_text import ImageTextToTextPipelineTests
from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
from .pipelines.test_pipelines_mask_generation import MaskGenerationPipelineTests
@ -102,6 +103,7 @@ pipeline_test_mapping = {
"image-classification": {"test": ImageClassificationPipelineTests},
"image-feature-extraction": {"test": ImageFeatureExtractionPipelineTests},
"image-segmentation": {"test": ImageSegmentationPipelineTests},
"image-text-to-text": {"test": ImageTextToTextPipelineTests},
"image-to-image": {"test": ImageToImagePipelineTests},
"image-to-text": {"test": ImageToTextPipelineTests},
"mask-generation": {"test": MaskGenerationPipelineTests},
@ -586,6 +588,18 @@ class PipelineTesterMixin:
def test_pipeline_image_segmentation_fp16(self):
self.run_task_tests(task="image-segmentation", torch_dtype="float16")
@is_pipeline_test
@require_vision
@require_torch
def test_pipeline_image_text_to_text(self):
self.run_task_tests(task="image-text-to-text")
@is_pipeline_test
@require_vision
@require_torch
def test_pipeline_image_text_to_text_fp16(self):
self.run_task_tests(task="image-text-to-text", torch_dtype="float16")
@is_pipeline_test
@require_vision
def test_pipeline_image_to_text(self):


@ -2896,7 +2896,7 @@
"model_classes": [
"IdeficsForVisionText2Text"
],
"sha": "2c2f2e2cd6b02a77d0cdd8c3767ba9a6267dbd20"
"sha": "a6be81294ff7a3d44f3aef0ed18e42b97c426831"
},
"IdeficsModel": {
"tokenizer_classes": [


@ -335,6 +335,7 @@ OBJECTS_TO_IGNORE = [
"ImageFeatureExtractionPipeline",
"ImageGPTConfig",
"ImageSegmentationPipeline",
"ImageTextToTextPipeline",
"ImageToImagePipeline",
"ImageToTextPipeline",
"InformerConfig",


@ -69,6 +69,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
("image-text-to-text", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"),
("image-to-image", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES", "AutoModelForImageToImage"),
("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),