Mirror of https://github.com/huggingface/transformers.git
Add image text to text pipeline (#34170)
* Standardize image-text-to-text models output; add post_process_image_text_to_text to Chameleon and clean up; fix legacy kwarg behavior and deprecation warning; add post_process_image_text_to_text to Qwen2-VL and LLaVA-OneVision; add post_process_image_text_to_text to the Idefics3, Mllama, and Pixtral processors
* nit: var name post_process_image_text_to_text in UDOP
* nit: fix deprecation warnings
* Add image-text-to-text pipeline
* Add support for image URLs in the chat template for the pipeline
* Reformat to be fully compatible with chat templates
* Add chat template tests
* Fix imports and tests
* Add pipeline tag
* Change logic for handling a single prompt with multiple images
* Add pipeline mapping to models
* Fix batched inference
* Fix tests
* Add manual batching for preprocessing
* Fix outputs with nested images
* Add support for all common processing kwargs
* Add default padding when there are multiple text inputs (batch size > 1)
* nit: change version in deprecation warning
* Add support for text-only inference
* Add chat_template warnings
* Add pipeline tests and add copied-from post-process function
* Fix batched pipeline tests
* nit
* Fix BLIP-2 pipeline tests
* Remove unnecessary max_new_tokens
* Revert Kosmos-2 processing and remove unnecessary max_new_tokens
* Fix IDEFICS pipeline tests
* Force trying to load the processor if the pipeline supports it
* Revert load_processor change
* Hardcode loading only the processor
* Remove unnecessary try/except
* Skip ImageTextToText tests for Kosmos-2 as the tiny model causes problems
* Make code clearer
* Address review comments
* Remove preprocessing logic from the pipeline
* Fix Fuyu
* Add BC resize for Fuyu
* Move post_process_image_text_to_text to ProcessorMixin
* Add guard in post_process
* Fix zero-shot object detection pipeline
* Add support for generator input in the pipeline
* nit
* Change default image-text-to-text model to LLaVA-OneVision
* Fix OWLv2 size dict
* Change legacy deprecation warning to only show when True
This commit is contained in: parent c443d8d536, commit 203e27059b
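
For orientation, a minimal usage sketch of the new task added by this PR. The checkpoint is the default registered in `SUPPORTED_TASKS` below, and the chat format mirrors the docstring example further down; the generated text will vary by model:

```python
from transformers import pipeline

# Default checkpoint registered for the new "image-text-to-text" task in this PR.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
outputs = pipe(text=messages, max_new_tokens=30)
print(outputs[0]["generated_text"])
```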
@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageTextToTextPipeline
|
||||
|
||||
[[autodoc]] ImageTextToTextPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### MaskGenerationPipeline
|
||||
|
||||
[[autodoc]] MaskGenerationPipeline
|
||||
|
@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageTextToTextPipeline
|
||||
|
||||
[[autodoc]] ImageTextToTextPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### VisualQuestionAnsweringPipeline
|
||||
|
||||
[[autodoc]] VisualQuestionAnsweringPipeline
|
||||
|
@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageTextToTextPipeline
|
||||
|
||||
[[autodoc]] ImageTextToTextPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### MaskGenerationPipeline
|
||||
|
||||
[[autodoc]] MaskGenerationPipeline
|
||||
|
@ -868,6 +868,7 @@ _import_structure = {
|
||||
"ImageClassificationPipeline",
|
||||
"ImageFeatureExtractionPipeline",
|
||||
"ImageSegmentationPipeline",
|
||||
"ImageTextToTextPipeline",
|
||||
"ImageToImagePipeline",
|
||||
"ImageToTextPipeline",
|
||||
"JsonPipelineDataFormat",
|
||||
@ -5794,6 +5795,7 @@ if TYPE_CHECKING:
|
||||
ImageClassificationPipeline,
|
||||
ImageFeatureExtractionPipeline,
|
||||
ImageSegmentationPipeline,
|
||||
ImageTextToTextPipeline,
|
||||
ImageToImagePipeline,
|
||||
ImageToTextPipeline,
|
||||
JsonPipelineDataFormat,
|
||||
|
@ -385,6 +385,27 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
|
||||
return image
|
||||
|
||||
|
||||
def load_images(
|
||||
images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
|
||||
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
|
||||
"""Loads images, handling different levels of nesting.
|
||||
|
||||
Args:
|
||||
images: A single image, a list of images, or a list of lists of images to load.
|
||||
timeout: Timeout for loading images.
|
||||
|
||||
Returns:
|
||||
A single image, a list of images, a list of lists of images.
|
||||
"""
|
||||
if isinstance(images, (list, tuple)):
|
||||
if len(images) and isinstance(images[0], (list, tuple)):
|
||||
return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
|
||||
else:
|
||||
return [load_image(image, timeout=timeout) for image in images]
|
||||
else:
|
||||
return load_image(images, timeout=timeout)
|
||||
|
||||
|
||||
def validate_preprocess_arguments(
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
|
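
A brief sketch of the three nesting levels that the new `load_images` helper above handles (the file names are hypothetical placeholders):

```python
from transformers.image_utils import load_images

# A flat batch of images -> a flat list of PIL images.
images = load_images(["cat.png", "dog.png"])

# One image group per sample (e.g. multi-image prompts) -> a list of lists of PIL images.
nested = load_images([["cat.png", "dog.png"], ["bird.png"]])

# A single image (path, URL, or PIL image) -> a single PIL image.
single = load_images("cat.png")
```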
@ -114,6 +114,7 @@ else:
|
||||
("oneformer", ("OneFormerImageProcessor",)),
|
||||
("owlv2", ("Owlv2ImageProcessor",)),
|
||||
("owlvit", ("OwlViTImageProcessor",)),
|
||||
("paligemma", ("SiglipImageProcessor",)),
|
||||
("perceiver", ("PerceiverImageProcessor",)),
|
||||
("pix2struct", ("Pix2StructImageProcessor",)),
|
||||
("pixtral", ("PixtralImageProcessor",)),
|
||||
|
@ -24,12 +24,16 @@ from typing import List, Optional, Union
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
class DonutProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {}
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DonutProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single
|
||||
@ -85,6 +89,16 @@ class DonutProcessor(ProcessorMixin):
|
||||
[`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
legacy = kwargs.pop("legacy", True)
|
||||
if legacy:
|
||||
# With `add_special_tokens=True`, the performance of Donut is degraded when working with both images and text.
|
||||
logger.warning_once(
|
||||
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
|
||||
"In the new behavior, if both images and text are provided, the default value of `add_special_tokens` "
|
||||
"will be changed to `False` when calling the tokenizer if `add_special_tokens` is unset. "
|
||||
"To test the new behavior, set `legacy=False`as a processor call argument."
|
||||
)
|
||||
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(images, text, **kwargs)
|
||||
|
||||
@ -100,6 +114,8 @@ class DonutProcessor(ProcessorMixin):
|
||||
if images is not None:
|
||||
inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
if text is not None:
|
||||
if not legacy and images is not None:
|
||||
output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
|
||||
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
if text is None:
|
||||
|
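
To make the Donut behavior switch above concrete, a hedged sketch of opting into the new default (checkpoint, image path, and task prompt are illustrative):

```python
from PIL import Image
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")  # example checkpoint
image = Image.open("document.png")  # hypothetical local scan

# legacy=True (the current default) keeps add_special_tokens=True and emits the warning above.
# legacy=False opts into the new behavior: add_special_tokens defaults to False when both
# images and text are passed and the caller has not set it explicitly.
inputs = processor(images=image, text="<s_docvqa>", legacy=False, return_tensors="pt")
```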
@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
pad,
|
||||
resize,
|
||||
@ -475,6 +475,7 @@ class FuyuImageProcessor(BaseImageProcessor):
|
||||
input_data_format = infer_channel_dimension_format(batch_images[0][0])
|
||||
|
||||
original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
|
||||
size = get_size_dict(size) # for BC
|
||||
|
||||
if do_resize:
|
||||
batch_images = [
|
||||
|
@ -264,10 +264,10 @@ def _tokenize_prompts_with_image_and_batch(
|
||||
bos_token = tokenizer.vocab["|ENDOFTEXT|"]
|
||||
prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
|
||||
if add_beginning_of_answer_token:
|
||||
boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
|
||||
beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
|
||||
# Only add bbox open token to the last subsequence since that is what will be completed
|
||||
for token_seq in prompts_tokens:
|
||||
token_seq[-1].append(boa)
|
||||
token_seq[-1].append(beginning_of_answer)
|
||||
|
||||
# Now we have a list of list of tokens which each list has a different
|
||||
# size. We want to extend this list to:
|
||||
@ -682,6 +682,32 @@ class FuyuProcessor(ProcessorMixin):
|
||||
|
||||
return results
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
containing the token ids of the generated sequences.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text output.
|
||||
"""
|
||||
beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
|
||||
# get the beginning-of-answer token index for each generated sequence tensor
|
||||
# start all generated sequences from the beginning of the answer token, pad to have consistent length
|
||||
unpadded_output_sequences = [
|
||||
seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
|
||||
]
|
||||
max_len = max(len(seq) for seq in unpadded_output_sequences)
|
||||
# convert to torch and pad sequences
|
||||
padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
|
||||
for i, seq in enumerate(unpadded_output_sequences):
|
||||
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)
|
||||
|
||||
return self.batch_decode(padded_output_sequences, skip_special_tokens=True)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
|
@ -22,12 +22,16 @@ from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
class GitProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {}
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GitProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.
|
||||
@ -91,6 +95,15 @@ class GitProcessor(ProcessorMixin):
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
legacy = kwargs.pop("legacy", True)
|
||||
if legacy:
|
||||
logger.warning_once(
|
||||
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
|
||||
"In the new behavior, if both images and text are provided, the last token (EOS token) "
|
||||
"of the input_ids and attention_mask tensors will be removed. "
|
||||
"To test the new behavior, set `legacy=False`as a processor call argument."
|
||||
)
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
||||
|
||||
@ -110,6 +123,10 @@ class GitProcessor(ProcessorMixin):
|
||||
if images is not None:
|
||||
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
data.update(image_features)
|
||||
if not legacy:
|
||||
data["input_ids"] = data["input_ids"][:, :-1]
|
||||
data["attention_mask"] = data["attention_mask"][:, :-1]
|
||||
|
||||
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
|
@ -428,6 +428,21 @@ class Kosmos2Processor(ProcessorMixin):
|
||||
return clean_text_and_extract_entities_with_bboxes(caption)
|
||||
return caption
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
|
||||
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
|
@ -342,6 +342,22 @@ class MllamaProcessor(ProcessorMixin):
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
|
@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
center_to_corners_format,
|
||||
pad,
|
||||
@ -399,6 +399,7 @@ class Owlv2ImageProcessor(BaseImageProcessor):
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size) # for BC
|
||||
|
||||
images = make_list_of_images(images)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from typing import List, Optional, Union
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
class Pix2StructImagesKwargs(ImagesKwargs, total=False):
|
||||
@ -48,6 +49,9 @@ class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
|
||||
}
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Pix2StructProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single
|
||||
@ -85,6 +89,15 @@ class Pix2StructProcessor(ProcessorMixin):
|
||||
|
||||
Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
legacy = kwargs.pop("legacy", True)
|
||||
if legacy:
|
||||
logger.warning_once(
|
||||
"Legacy behavior is being used. The current behavior will be deprecated in version 5.0.0. "
|
||||
"In the new behavior, If both images and text are provided, image_processor is not a VQA processor, and `add_special_tokens` is unset, "
|
||||
"the default value of `add_special_tokens` will be changed to `False` when calling the tokenizer. "
|
||||
"To test the new behavior, set `legacy=False`as a processor call argument."
|
||||
)
|
||||
|
||||
if images is None and text is None:
|
||||
raise ValueError("You have to specify either images or text.")
|
||||
|
||||
@ -93,8 +106,12 @@ class Pix2StructProcessor(ProcessorMixin):
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None)
|
||||
# Get only text
|
||||
if images is None and not self.image_processor.is_vqa:
|
||||
output_kwargs["text_kwargs"]["add_special_tokens"] = (
|
||||
add_special_tokens if add_special_tokens is not None else True
|
||||
)
|
||||
self.current_processor = self.tokenizer
|
||||
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||
return text_encoding
|
||||
@ -108,6 +125,9 @@ class Pix2StructProcessor(ProcessorMixin):
|
||||
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
if text is not None and not self.image_processor.is_vqa:
|
||||
output_kwargs["text_kwargs"]["add_special_tokens"] = (
|
||||
add_special_tokens if add_special_tokens is not None else legacy
|
||||
)
|
||||
text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||
|
||||
if "attention_mask" in text_encoding:
|
||||
|
@ -168,6 +168,22 @@ class Qwen2VLProcessor(ProcessorMixin):
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
|
@ -208,20 +208,6 @@ class UdopProcessor(ProcessorMixin):
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
return ["pixel_values", "input_ids", "bbox", "attention_mask"]
|
||||
|
@ -67,6 +67,7 @@ from .fill_mask import FillMaskPipeline
|
||||
from .image_classification import ImageClassificationPipeline
|
||||
from .image_feature_extraction import ImageFeatureExtractionPipeline
|
||||
from .image_segmentation import ImageSegmentationPipeline
|
||||
from .image_text_to_text import ImageTextToTextPipeline
|
||||
from .image_to_image import ImageToImagePipeline
|
||||
from .image_to_text import ImageToTextPipeline
|
||||
from .mask_generation import MaskGenerationPipeline
|
||||
@ -119,6 +120,7 @@ if is_torch_available():
|
||||
AutoModelForDocumentQuestionAnswering,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMaskGeneration,
|
||||
AutoModelForObjectDetection,
|
||||
@ -384,6 +386,17 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"image-text-to-text": {
|
||||
"impl": ImageTextToTextPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForImageTextToText,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", "2c9ba3b"),
|
||||
}
|
||||
},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"object-detection": {
|
||||
"impl": ObjectDetectionPipeline,
|
||||
"tf": (),
|
||||
@ -601,6 +614,7 @@ def pipeline(
|
||||
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
|
||||
- `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
|
||||
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
|
||||
- `"image-text-to-text"`: will return a [`ImageTextToTextPipeline`].
|
||||
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
|
||||
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
|
||||
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].
|
||||
|
@ -951,6 +951,14 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
self._num_workers = kwargs.pop("num_workers", None)
|
||||
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||
|
||||
# In processor only mode, we can get the modality processors from the processor
|
||||
if self.processor is not None and all(
|
||||
[self.tokenizer is None, self.feature_extractor is None, self.image_processor is None]
|
||||
):
|
||||
self.tokenizer = getattr(self.processor, "tokenizer", None)
|
||||
self.feature_extractor = getattr(self.processor, "feature_extractor", None)
|
||||
self.image_processor = getattr(self.processor, "image_processor", None)
|
||||
|
||||
if self.image_processor is None and self.feature_extractor is not None:
|
||||
if isinstance(self.feature_extractor, BaseImageProcessor):
|
||||
# Backward compatible change, if users called
|
||||
|
src/transformers/pipelines/image_text_to_text.py (new file, 416 lines)
@ -0,0 +1,416 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ..processing_utils import ProcessingKwargs, Unpack
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import Pipeline, build_pipeline_init_args
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from ..image_utils import load_images, valid_images
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
|
||||
from .pt_utils import KeyDataset
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
IMAGE_TOKEN = "<image>"
|
||||
|
||||
|
||||
class ReturnType(enum.Enum):
|
||||
TENSORS = 0
|
||||
NEW_TEXT = 1
|
||||
FULL_TEXT = 2
|
||||
|
||||
|
||||
class Chat:
|
||||
"""This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
|
||||
to this format because the rest of the pipeline code tends to assume that lists of messages are
|
||||
actually a batch of samples rather than messages in the same conversation."""
|
||||
|
||||
def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image", List["Image.Image"]]):
|
||||
for message in messages:
|
||||
if not ("role" in message and "content" in message):
|
||||
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
|
||||
images = retrieve_images_in_messages(messages, images)
|
||||
|
||||
self.messages = messages
|
||||
self.images = images
|
||||
|
||||
|
||||
def retrieve_images_in_messages(
|
||||
messages: dict, images: Optional[Union[str, List[str], "Image.Image", List["Image.Image"]]]
|
||||
):
|
||||
"""
|
||||
Retrieve and combine images from the chat and the images passed as input.
|
||||
"""
|
||||
if images is None:
|
||||
images = []
|
||||
idx_images = 0
|
||||
retrieved_images = []
|
||||
for message in messages:
|
||||
for content in message["content"]:
|
||||
if isinstance(content, dict) and content.get("type") == "image":
|
||||
if "image" in content:
|
||||
retrieved_images.append(content["image"])
|
||||
elif idx_images < len(images):
|
||||
retrieved_images.append(images[idx_images])
|
||||
idx_images += 1
|
||||
else:
|
||||
raise ValueError(
|
||||
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
||||
)
|
||||
|
||||
# The number of images passed should be consistent with the number of images in the chat without an image key
|
||||
if idx_images != len(images):
|
||||
raise ValueError(
|
||||
"The number of images in the chat messages should be the same as the number of images passed to the pipeline."
|
||||
)
|
||||
|
||||
return retrieved_images
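
A small worked example of how the helper above pairs inline images with the standalone `images` argument (the URLs are placeholders):

```python
# The first image entry carries its own image, the second is a placeholder,
# so exactly one standalone image must be supplied alongside the messages.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://example.com/a.png"},  # taken from the message itself
            {"type": "image"},                                        # filled from the `images` argument
            {"type": "text", "text": "Compare these two images."},
        ],
    }
]
images = ["https://example.com/b.png"]
# retrieve_images_in_messages(messages, images)
# -> ["https://example.com/a.png", "https://example.com/b.png"]
```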
|
||||
|
||||
|
||||
@add_end_docstrings(build_pipeline_init_args(has_processor=True))
|
||||
class ImageTextToTextPipeline(Pipeline):
|
||||
"""
|
||||
Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text.
|
||||
When the underlying model is a conversational model, it can also accept one or more chats,
|
||||
in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
|
||||
Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> pipe = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base")
|
||||
>>> pipe("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", text="A photo of")
|
||||
[{'generated_text': 'a photo of two birds'}]
|
||||
```
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
|
||||
>>> messages = [
|
||||
>>> {
|
||||
>>> "role": "user",
|
||||
>>> "content": [
|
||||
>>> {
|
||||
>>> "type": "image",
|
||||
>>> "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
|
||||
>>> },
|
||||
>>> {"type": "text", "text": "Describe this image."},
|
||||
>>> ],
|
||||
>>> },
|
||||
>>> {
|
||||
>>> "role": "assistant",
|
||||
>>> "content": [
|
||||
>>> {"type": "text", "text": "There is a dog and"},
|
||||
>>> ],
|
||||
>>> },
|
||||
>>> ]
|
||||
>>> pipe(text=messages, max_new_tokens=20, return_full_text=False)
|
||||
[{'input_text': [{'role': 'user',
|
||||
'content': [{'type': 'image',
|
||||
'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
|
||||
{'type': 'text', 'text': 'Describe this image.'}]},
|
||||
{'role': 'assistant',
|
||||
'content': [{'type': 'text', 'text': 'There is a dog and'}]}],
|
||||
'generated_text': ' a person in the image. The dog is sitting on the sand, and the person is sitting on'}]
|
||||
```
|
||||
|
||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||
|
||||
This image-text-to-text pipeline can currently be loaded from pipeline() using the following task identifier:
|
||||
"image-text-to-text".
|
||||
|
||||
See the list of available models on
|
||||
[huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-text-to-text).
|
||||
"""
|
||||
|
||||
_load_processor = True
|
||||
_load_image_processor = False
|
||||
_load_feature_extractor = False
|
||||
_load_tokenizer = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
|
||||
|
||||
def _sanitize_parameters(
|
||||
self,
|
||||
max_new_tokens=None,
|
||||
generate_kwargs=None,
|
||||
timeout=None,
|
||||
return_full_text=None,
|
||||
return_tensors=None,
|
||||
return_type=None,
|
||||
continue_final_message=None,
|
||||
**kwargs: Unpack[ProcessingKwargs],
|
||||
):
|
||||
forward_kwargs = {}
|
||||
preprocess_params = {}
|
||||
postprocess_params = {}
|
||||
|
||||
preprocess_params["processing_kwargs"] = kwargs
|
||||
|
||||
if timeout is not None:
|
||||
preprocess_params["timeout"] = timeout
|
||||
|
||||
if continue_final_message is not None:
|
||||
preprocess_params["continue_final_message"] = continue_final_message
|
||||
|
||||
if generate_kwargs is not None:
|
||||
forward_kwargs["generate_kwargs"] = generate_kwargs
|
||||
|
||||
if max_new_tokens is not None:
|
||||
if "generate_kwargs" not in forward_kwargs:
|
||||
forward_kwargs["generate_kwargs"] = {}
|
||||
if "max_new_tokens" in forward_kwargs["generate_kwargs"]:
|
||||
raise ValueError(
|
||||
"'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter,"
|
||||
" please use only one"
|
||||
)
|
||||
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
|
||||
|
||||
if return_full_text is not None and return_type is None:
|
||||
if return_tensors is not None:
|
||||
raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
|
||||
return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
|
||||
if return_tensors is not None and return_type is None:
|
||||
return_type = ReturnType.TENSORS
|
||||
if return_type is not None:
|
||||
postprocess_params["return_type"] = return_type
|
||||
if continue_final_message is not None:
|
||||
postprocess_params["continue_final_message"] = continue_final_message
|
||||
|
||||
return preprocess_params, forward_kwargs, postprocess_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: Optional[
|
||||
Union[str, List[str], List[List[str]], "Image.Image", List["Image.Image"], List[List["Image.Image"]]]
|
||||
] = None,
|
||||
text: Optional[Union[str, List[str], List[dict]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Generate text given the text and image(s) passed as inputs.
|
||||
|
||||
Args:
|
||||
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||
The pipeline handles three types of images:
|
||||
|
||||
- A string containing a HTTP(s) link pointing to an image
|
||||
- A string containing a local path to an image
|
||||
- An image loaded in PIL directly
|
||||
|
||||
The pipeline accepts either a single image or a batch of images.
|
||||
text (`str`, `List[str]`, `List[Dict[str, Union[str, PIL.Image]]]`):
|
||||
The text to be used for generation. If a list of strings is passed, the length of the list should be the
|
||||
same as the number of images. Text can also follow the chat format: a list of dictionaries where each
|
||||
dictionary represents a message in a conversation. Each dictionary should have two keys: 'role' and
|
||||
'content'. 'role' should be one of 'user', 'system' or 'assistant'. 'content' should be a list of dictionaries
|
||||
containing the text of the message and the type of the message. The type of the message can be either
|
||||
'text' or 'image'. If the type is 'image', no text is needed.
|
||||
return_tensors (`bool`, *optional*, defaults to `False`):
|
||||
Returns the tensors of predictions (as token indices) in the outputs. If set to
|
||||
`True`, the decoded text is not returned.
|
||||
return_text (`bool`, *optional*):
|
||||
Returns the decoded texts in the outputs.
|
||||
return_full_text (`bool`, *optional*, defaults to `True`):
|
||||
If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
|
||||
specified at the same time as `return_text`.
|
||||
continue_final_message (`bool`, *optional*): This indicates that you want the model to continue the
|
||||
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
|
||||
By default this is `True` when the final message in the input chat has the `assistant` role and
|
||||
`False` otherwise, but you can manually override that behaviour by setting this flag.
|
||||
|
||||
Return:
|
||||
A list or a list of lists of `dict`: Each result comes as a dictionary with the following keys (it cannot return a combination
|
||||
of both `generated_text` and `generated_token_ids`):
|
||||
|
||||
- **generated_text** (`str`, present when `return_text=True`) -- The generated text.
|
||||
- **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token
|
||||
ids of the generated text.
|
||||
- **input_text** (`str`) -- The input text.
|
||||
"""
|
||||
if images is None and text is None:
|
||||
raise ValueError("You must at least provide either text or images.")
|
||||
if images is not None and text is None and not valid_images(images):
|
||||
"""
|
||||
Supports the following formats
|
||||
- {"image": image, "text": text}
|
||||
- [{"image": image, "text": text}]
|
||||
- Generator and datasets
|
||||
This is a common pattern in other multimodal pipelines, so we support it here as well.
|
||||
"""
|
||||
return super().__call__(images, **kwargs)
|
||||
|
||||
if isinstance(text, (list, tuple, KeyDataset)) and isinstance(text[0], (list, tuple, dict)):
|
||||
# We have one or more prompts in list-of-dicts format, so this is chat mode
|
||||
if isinstance(text[0], dict):
|
||||
return super().__call__(Chat(text, images), **kwargs)
|
||||
else:
|
||||
if images is None:
|
||||
images = [None] * len(text)
|
||||
chats = [Chat(chat, image) for chat, image in zip(text, images)] # 🐈 🐈 🐈
|
||||
return super().__call__(chats, **kwargs)
|
||||
|
||||
# encourage the user to use the chat format if supported
|
||||
if getattr(self.processor, "chat_template", None) is not None:
|
||||
logger.warning_once(
|
||||
"The input data was not formatted as a chat with dicts containing 'role' and 'content' keys, even though this model supports chat. "
|
||||
"Consider using the chat format for better results. For more information, see https://huggingface.co/docs/transformers/en/chat_templating"
|
||||
)
|
||||
|
||||
# support text only generation
|
||||
if images is None:
|
||||
return super().__call__(text, **kwargs)
|
||||
if text is None:
|
||||
raise ValueError("You must provide text for this pipeline.")
|
||||
|
||||
return super().__call__({"images": images, "text": text}, **kwargs)
|
||||
|
||||
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
|
||||
# In case we only have text inputs
|
||||
if isinstance(inputs, (list, tuple, str)):
|
||||
images = None
|
||||
text = inputs
|
||||
inputs_text = inputs
|
||||
else:
|
||||
if isinstance(inputs, Chat):
|
||||
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
|
||||
# because very few models support multiple separate, consecutive assistant messages
|
||||
if continue_final_message is None:
|
||||
continue_final_message = inputs.messages[-1]["role"] == "assistant"
|
||||
text = self.processor.apply_chat_template(
|
||||
inputs.messages,
|
||||
add_generation_prompt=not continue_final_message,
|
||||
continue_final_message=continue_final_message,
|
||||
return_tensors=self.framework,
|
||||
)
|
||||
inputs_text = inputs
|
||||
images = inputs.images
|
||||
else:
|
||||
text = inputs["text"]
|
||||
inputs_text = inputs["text"]
|
||||
images = inputs["images"]
|
||||
|
||||
images = load_images(images)
|
||||
|
||||
# if batched text inputs, we set padding to True unless specified otherwise
|
||||
if isinstance(text, (list, tuple)) and len(text) > 1:
|
||||
processing_kwargs.setdefault("padding", True)
|
||||
model_inputs = self.processor(
|
||||
images=images, text=text, return_tensors=self.framework, legacy=False, **processing_kwargs
|
||||
).to(dtype=self.torch_dtype)
|
||||
|
||||
model_inputs["text"] = inputs_text
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs, generate_kwargs=None):
|
||||
generate_kwargs = {} if generate_kwargs is None else generate_kwargs
|
||||
prompt_text = model_inputs.pop("text")
|
||||
input_ids = (
|
||||
model_inputs["input_ids"] if "input_ids" in model_inputs else model_inputs["decoder_input_ids"]
|
||||
) # for decoder-only models
|
||||
generated_sequence = self.model.generate(**model_inputs, **generate_kwargs)
|
||||
|
||||
return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}
|
||||
|
||||
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
|
||||
input_texts = model_outputs["prompt_text"]
|
||||
input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
|
||||
generated_sequence = model_outputs["generated_sequence"]
|
||||
input_ids = model_outputs["input_ids"]
|
||||
if return_type == ReturnType.TENSORS:
|
||||
return [
|
||||
{"input_text": input_texts[i], "generated_token_ids": generated_sequence[i]}
|
||||
for i in range(len(input_texts))
|
||||
]
|
||||
|
||||
# Decode inputs and outputs the same way to remove input text from generated text if present
|
||||
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
|
||||
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)
|
||||
|
||||
# Force consistent behavior for including the input text in the output
|
||||
if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
|
||||
# Remove the input text from the generated text if the generated text starts with the input text
|
||||
# (accounting for the possibility of a space between the input and generated text)
|
||||
new_generated_texts = []
|
||||
for text_generated, decoded_input in zip(generated_texts, decoded_inputs):
|
||||
# There can be added characters before the input text, so we need to find the beginning of the input text in the generated text
|
||||
index_input_text = text_generated.find(decoded_input)
|
||||
# Limit the search to 2 residual characters, like spaces or new lines, to avoid removing a large part of the answer
|
||||
if 0 <= index_input_text <= 2:
|
||||
# If the input text is found, we remove it
|
||||
new_generated_texts.append(text_generated[index_input_text + len(decoded_input) :])
|
||||
else:
|
||||
new_generated_texts.append(text_generated)
|
||||
generated_texts = new_generated_texts
|
||||
if return_type == ReturnType.FULL_TEXT:
|
||||
full_texts = []
|
||||
for prompt_text, generated_text in zip(input_texts, generated_texts):
|
||||
if isinstance(prompt_text, str):
|
||||
generated_text = prompt_text + generated_text
|
||||
elif isinstance(prompt_text, Chat):
|
||||
if continue_final_message is None:
|
||||
# If the user passes a chat ending in an assistant message, we treat it as a prefill by
|
||||
# default because very few models support multiple separate, consecutive assistant messages
|
||||
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
|
||||
if continue_final_message:
|
||||
# With assistant prefill, concat onto the end of the last message
|
||||
new_text = dict(prompt_text.messages[-1]["content"][-1].items())
|
||||
new_text["text"] += generated_text
|
||||
generated_text = list(prompt_text.messages)[:-1] + [
|
||||
{
|
||||
"role": prompt_text.messages[-1]["role"],
|
||||
"content": prompt_text.messages[-1]["content"][:-1] + [new_text],
|
||||
}
|
||||
]
|
||||
else:
|
||||
# When we're not starting from a prefill, the output is a new assistant message
|
||||
generated_text = list(prompt_text.messages) + [
|
||||
{"role": "assistant", "content": generated_text}
|
||||
]
|
||||
full_texts.append(generated_text)
|
||||
generated_texts = full_texts
|
||||
|
||||
records = [
|
||||
{
|
||||
"input_text": input_text.messages if isinstance(input_text, Chat) else input_text,
|
||||
"generated_text": generated_text,
|
||||
}
|
||||
for input_text, generated_text in zip(input_texts, generated_texts)
|
||||
]
|
||||
|
||||
return records
|
@ -134,6 +134,10 @@ class ImageToTextPipeline(Pipeline):
|
||||
image = load_image(image, timeout=timeout)
|
||||
|
||||
if prompt is not None:
|
||||
logger.warning_once(
|
||||
"Passing `prompt` to the `image-to-text` pipeline is deprecated and will be removed in version 4.48"
|
||||
" of 🤗 Transformers. Use the `image-text-to-text` pipeline instead",
|
||||
)
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError(
|
||||
f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
|
||||
|
@ -7,7 +7,7 @@ from .base import ChunkPipeline, build_pipeline_init_args
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from ..image_utils import load_image
|
||||
from ..image_utils import load_image, valid_images
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@ -130,8 +130,23 @@ class ZeroShotObjectDetectionPipeline(ChunkPipeline):
|
||||
|
||||
if isinstance(image, (str, Image.Image)):
|
||||
inputs = {"image": image, "candidate_labels": candidate_labels}
|
||||
elif isinstance(image, (list, tuple)) and valid_images(image):
|
||||
return list(
|
||||
super().__call__(
|
||||
({"image": img, "candidate_labels": labels} for img, labels in zip(image, candidate_labels)),
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
else:
|
||||
"""
|
||||
Supports the following formats
|
||||
- {"image": image, "candidate_labels": candidate_labels}
|
||||
- [{"image": image, "candidate_labels": candidate_labels}]
|
||||
- Generator and datasets
|
||||
This is a common pattern in other multimodal pipelines, so we support it here as well.
|
||||
"""
|
||||
inputs = image
|
||||
|
||||
results = super().__call__(inputs, **kwargs)
|
||||
return results
|
||||
|
||||
|
@ -1107,6 +1107,20 @@ class ProcessorMixin(PushToHubMixin):
|
||||
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
|
||||
)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-process the output of a VLM to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
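
A hedged end-to-end sketch of where this new `ProcessorMixin` hook sits in a generate loop (the checkpoint and prompt format are illustrative, and it assumes `AutoModelForImageTextToText` is exported at the top level as it is used elsewhere in this PR):

```python
from PIL import Image
import requests
from transformers import AutoModelForImageTextToText, AutoProcessor

checkpoint = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"  # default pipeline checkpoint from this PR
processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForImageTextToText.from_pretrained(checkpoint)

url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="<image>Describe this image.", return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=20)
# Base implementation: batch-decode with skip_special_tokens=True. Model-specific processors
# (e.g. Fuyu or Kosmos-2 earlier in this diff) override it with their own cleanup.
print(processor.post_process_image_text_to_text(generated_ids))
```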
|
||||
|
||||
|
||||
def _validate_images_text_input_order(images, text):
|
||||
"""
|
||||
|
@ -436,6 +436,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
"feature-extraction": BlipModel,
|
||||
"image-to-text": BlipForConditionalGeneration,
|
||||
"visual-question-answering": BlipForQuestionAnswering,
|
||||
"image-text-to-text": BlipForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
|
@ -767,6 +767,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
||||
"feature-extraction": Blip2Model,
|
||||
"image-to-text": Blip2ForConditionalGeneration,
|
||||
"visual-question-answering": Blip2ForConditionalGeneration,
|
||||
"image-text-to-text": Blip2ForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
|
@ -276,6 +276,7 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
{
|
||||
"feature-extraction": ChameleonModel,
|
||||
"text-generation": ChameleonForConditionalGeneration,
|
||||
"image-text-to-text": ChameleonForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
|
@ -265,7 +265,9 @@ class FuyuModelTester:
|
||||
@require_torch
|
||||
class FuyuModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FuyuForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-generation": FuyuForCausalLM} if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
{"text-generation": FuyuForCausalLM, "image-text-to-text": FuyuForCausalLM} if is_torch_available() else {}
|
||||
)
|
||||
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
@ -401,7 +401,12 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": GitModel, "image-to-text": GitForCausalLM, "text-generation": GitForCausalLM}
|
||||
{
|
||||
"feature-extraction": GitModel,
|
||||
"image-to-text": GitForCausalLM,
|
||||
"text-generation": GitForCausalLM,
|
||||
"image-text-to-text": GitForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -332,7 +332,11 @@ class IdeficsModelTester:
|
||||
@require_torch
|
||||
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": IdeficsModel} if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": IdeficsModel, "image-text-to-text": IdeficsForVisionText2Text}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
|
@ -375,6 +375,7 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
|
||||
|
||||
all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": Idefics2ForConditionalGeneration} if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = True
|
||||
|
@ -317,6 +317,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
|
||||
|
||||
all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": Idefics3ForConditionalGeneration} if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = True
|
||||
|
@ -455,6 +455,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester:
|
||||
@require_torch
|
||||
class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": InstructBlipForConditionalGeneration}
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
@ -257,7 +257,11 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Kosmos2ForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": Kosmos2Model, "image-to-text": Kosmos2ForConditionalGeneration}
|
||||
{
|
||||
"feature-extraction": Kosmos2Model,
|
||||
"image-to-text": Kosmos2ForConditionalGeneration,
|
||||
"image-text-to-text": Kosmos2ForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
@ -269,6 +273,7 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
_is_composite = True
|
||||
|
||||
# TODO: `image-to-text` pipeline for this model needs Processor.
|
||||
# TODO: Tiny model needs fixing for `image-text-to-text` (latent_query_num=3 not compatible with num_image_tokens=64).
|
||||
def is_pipeline_test_to_skip(
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
@ -279,7 +284,10 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return pipeline_test_case_name == "ImageToTextPipelineTests"
|
||||
return (
|
||||
pipeline_test_case_name == "ImageToTextPipelineTests"
|
||||
or pipeline_test_case_name == "ImageTextToTextPipelineTests"
|
||||
)
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
@ -183,7 +183,11 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
|
||||
|
||||
all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
{"image-to-text": LlavaForConditionalGeneration, "image-text-to-text": LlavaForConditionalGeneration}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
_is_composite = True
|
||||
|
@ -216,6 +216,7 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
|
||||
all_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": LlavaNextForConditionalGeneration} if is_torch_available() else {}
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
_is_composite = True
|
||||
|
@ -217,6 +217,9 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
|
||||
all_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"image-text-to-text": LlavaOnevisionForConditionalGeneration} if is_torch_available() else {}
|
||||
)
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
_is_composite = True
|
||||
|
@ -264,6 +264,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
|
||||
all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": MllamaForConditionalGeneration} if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_torchscript = False
|
||||
|
@ -183,6 +183,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
|
||||
all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": PaliGemmaForConditionalGeneration}
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
|
@ -420,7 +420,11 @@ class Pix2StructModelTester:
|
||||
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
|
||||
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
{"image-to-text": Pix2StructForConditionalGeneration, "image-text-to-text": Pix2StructForConditionalGeneration}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
@ -224,6 +224,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
|
||||
all_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Qwen2VLForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
|
@ -275,7 +275,11 @@ class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (UdopForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": UdopModel} if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": UdopModel, "image-text-to-text": UdopForConditionalGeneration}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
|
@ -170,6 +170,7 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
||||
|
||||
all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-text-to-text": VipLlavaForConditionalGeneration} if is_torch_available() else {}
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = True
|
||||
|
tests/pipelines/test_pipelines_image_text_to_text.py (new file, 260 lines)
@ -0,0 +1,260 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available
from transformers.pipelines import ImageTextToTextPipeline, pipeline
from transformers.testing_utils import (
    is_pipeline_test,
    require_torch,
    require_vision,
    slow,
)

from .test_pipelines_common import ANY


if is_vision_available():
    from PIL import Image
else:

    class Image:
        @staticmethod
        def open(*args, **kwargs):
            pass


@is_pipeline_test
@require_vision
class ImageTextToTextPipelineTests(unittest.TestCase):
    model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING

    def get_test_pipeline(self, model, tokenizer, processor, image_processor, torch_dtype="float32"):
        pipe = ImageTextToTextPipeline(model=model, processor=processor, torch_dtype=torch_dtype)
        image_token = getattr(processor.tokenizer, "image_token", "")
        examples = [
            {
                "images": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
                "text": f"{image_token}This is a ",
            },
            {
                "images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
                "text": f"{image_token}Here I see a ",
            },
        ]
        return pipe, examples

    def run_pipeline_test(self, pipe, examples):
        outputs = pipe(examples[0].get("images"), text=examples[0].get("text"))
        self.assertEqual(
            outputs,
            [
                {"input_text": ANY(str), "generated_text": ANY(str)},
            ],
        )

    @require_torch
    def test_small_model_pt_token(self):
        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        text = "<image> What this is? Assistant: This is"

        outputs = pipe(image, text=text)
        self.assertEqual(
            outputs,
            [
                {
                    "input_text": "<image> What this is? Assistant: This is",
                    "generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
                }
            ],
        )

        outputs = pipe([image, image], text=[text, text])
        self.assertEqual(
            outputs,
            [
                {
                    "input_text": "<image> What this is? Assistant: This is",
                    "generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
                },
                {
                    "input_text": "<image> What this is? Assistant: This is",
                    "generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
                },
            ],
        )

    @require_torch
    def test_consistent_batching_behaviour(self):
        pipe = pipeline("image-text-to-text", model="microsoft/kosmos-2-patch14-224")
        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        prompt = "a photo of"

        outputs = pipe([image, image], text=[prompt, prompt])
        outputs_batched = pipe([image, image], text=[prompt, prompt], batch_size=2)
        self.assertEqual(outputs, outputs_batched)

    @slow
    @require_torch
    def test_model_pt_chat_template(self):
        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
        image_ny = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
        image_chicago = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What’s the difference between these two images?"},
                    {"type": "image"},
                    {"type": "image"},
                ],
            }
        ]
        outputs = pipe([image_ny, image_chicago], text=messages)
        self.assertEqual(
            outputs,
            [
                {
                    "input_text": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": "What’s the difference between these two images?"},
                                {"type": "image"},
                                {"type": "image"},
                            ],
                        }
                    ],
                    "generated_text": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": "What’s the difference between these two images?"},
                                {"type": "image"},
                                {"type": "image"},
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows",
                        },
                    ],
                }
            ],
        )

    @slow
    @require_torch
    def test_model_pt_chat_template_continue_final_message(self):
        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "There is a dog and"},
                ],
            },
        ]
        outputs = pipe(text=messages)
        self.assertEqual(
            outputs,
            [
                {
                    "input_text": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image",
                                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                                },
                                {"type": "text", "text": "Describe this image."},
                            ],
                        },
                        {"role": "assistant", "content": [{"type": "text", "text": "There is a dog and"}]},
                    ],
                    "generated_text": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image",
                                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                                },
                                {"type": "text", "text": "Describe this image."},
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": "There is a dog and a person in the image. The dog is sitting on the sand, and the person is sitting on",
                                }
                            ],
                        },
                    ],
                }
            ],
        )

    @slow
    @require_torch
    def test_model_pt_chat_template_new_text(self):
        pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ]
        outputs = pipe(text=messages, return_full_text=False)
        self.assertEqual(
            outputs,
            [
                {
                    "input_text": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image",
                                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                                },
                                {"type": "text", "text": "Describe this image."},
                            ],
                        }
                    ],
                    "generated_text": "In the image, a woman is sitting on the sandy beach, her legs crossed in a relaxed manner",
                }
            ],
        )

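For readers skimming the diff, a condensed usage sketch distilled from the tests above; the model id, fixture path, image URL, and prompts are the ones the tests use, and the generated text will vary with the environment.

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")

# Plain prompt with an explicit image token, as in test_small_model_pt_token.
outputs = pipe(
    "./tests/fixtures/tests_samples/COCO/000000039769.png",
    text="<image> What this is? Assistant: This is",
)
print(outputs[0]["generated_text"])

# Chat-template input, as in test_model_pt_chat_template_new_text; with
# return_full_text=False only the newly generated text is returned.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
print(pipe(text=messages, return_full_text=False)[0]["generated_text"])
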
@ -14,7 +14,12 @@

import unittest

from transformers import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, is_vision_available, pipeline
from transformers import (
    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
    ZeroShotObjectDetectionPipeline,
    is_vision_available,
    pipeline,
)
from transformers.testing_utils import (
    is_pipeline_test,
    nested_simplify,
@ -52,9 +57,11 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase):
        processor=None,
        torch_dtype="float32",
    ):
        object_detector = pipeline(
            "zero-shot-object-detection",
            model="hf-internal-testing/tiny-random-owlvit-object-detection",
        object_detector = ZeroShotObjectDetectionPipeline(
            model=model,
            processor=processor,
            tokenizer=tokenizer,
            image_processor=image_processor,
            torch_dtype=torch_dtype,
        )

@ -67,7 +74,7 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase):
        return object_detector, examples

    def run_pipeline_test(self, object_detector, examples):
        outputs = object_detector(examples[0], threshold=0.0)
        outputs = object_detector(examples[0].get("image"), examples[0].get("candidate_labels"), threshold=0.0)

        n = len(outputs)
        self.assertGreater(n, 0)

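The hunks above replace a hardcoded pipeline(...) call with a ZeroShotObjectDetectionPipeline built directly from the model, processor, tokenizer, and image processor handed to get_test_pipeline, and run_pipeline_test now passes the image and candidate labels explicitly. As a point of comparison, a small sketch of the equivalent high-level call; the candidate labels are illustrative, and the tiny checkpoint is the one the old code hardcoded.

from transformers import pipeline

detector = pipeline(
    "zero-shot-object-detection",
    model="hf-internal-testing/tiny-random-owlvit-object-detection",
)
predictions = detector(
    "./tests/fixtures/tests_samples/COCO/000000039769.png",
    candidate_labels=["cat", "remote control"],
    threshold=0.0,  # keep every box, as the test does
)
print(len(predictions))
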
@ -71,6 +71,7 @@ from .pipelines.test_pipelines_fill_mask import FillMaskPipelineTests
from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
from .pipelines.test_pipelines_image_feature_extraction import ImageFeatureExtractionPipelineTests
from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
from .pipelines.test_pipelines_image_text_to_text import ImageTextToTextPipelineTests
from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
from .pipelines.test_pipelines_mask_generation import MaskGenerationPipelineTests
@ -102,6 +103,7 @@ pipeline_test_mapping = {
    "image-classification": {"test": ImageClassificationPipelineTests},
    "image-feature-extraction": {"test": ImageFeatureExtractionPipelineTests},
    "image-segmentation": {"test": ImageSegmentationPipelineTests},
    "image-text-to-text": {"test": ImageTextToTextPipelineTests},
    "image-to-image": {"test": ImageToImagePipelineTests},
    "image-to-text": {"test": ImageToTextPipelineTests},
    "mask-generation": {"test": MaskGenerationPipelineTests},
@ -586,6 +588,18 @@ class PipelineTesterMixin:
    def test_pipeline_image_segmentation_fp16(self):
        self.run_task_tests(task="image-segmentation", torch_dtype="float16")

    @is_pipeline_test
    @require_vision
    @require_torch
    def test_pipeline_image_text_to_text(self):
        self.run_task_tests(task="image-text-to-text")

    @is_pipeline_test
    @require_vision
    @require_torch
    def test_pipeline_image_text_to_text_fp16(self):
        self.run_task_tests(task="image-text-to-text", torch_dtype="float16")

    @is_pipeline_test
    @require_vision
    def test_pipeline_image_to_text(self):

@ -2896,7 +2896,7 @@
        "model_classes": [
            "IdeficsForVisionText2Text"
        ],
        "sha": "2c2f2e2cd6b02a77d0cdd8c3767ba9a6267dbd20"
        "sha": "a6be81294ff7a3d44f3aef0ed18e42b97c426831"
    },
    "IdeficsModel": {
        "tokenizer_classes": [

@ -335,6 +335,7 @@ OBJECTS_TO_IGNORE = [
    "ImageFeatureExtractionPipeline",
    "ImageGPTConfig",
    "ImageSegmentationPipeline",
    "ImageTextToTextPipeline",
    "ImageToImagePipeline",
    "ImageToTextPipeline",
    "InformerConfig",

@ -69,6 +69,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
    ("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
    ("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
    ("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
    ("image-text-to-text", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"),
    ("image-to-image", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES", "AutoModelForImageToImage"),
    ("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
    ("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),

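The new PIPELINE_TAGS_AND_AUTO_MODELS entry ties the task tag to the MODEL_FOR_IMAGE_TEXT_TO_TEXT mapping and its auto class. Assuming AutoModelForImageTextToText is importable at this point in the release, a minimal loading sketch (the model id is reused from the tests above):

from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id)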