[image-to-text pipeline] Add conditional text support + GIT (#23362)
* First draft
* Remove print statements
* Add conditional generation
* Add more tests
* Remove scripts
* Remove BLIP specific lines
* Add support for pix2struct
* Add fast test
* Address comment
* Fix style
This commit is contained in:
parent e69feab8a1
commit 2f424d7979
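For context, this is the behavior the PR adds, shown as a minimal sketch against the BLIP checkpoint and fixture image used in the tests below (the printed captions are illustrative):

from transformers import pipeline

pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"

# Unconditional captioning, the pre-existing behavior.
print(pipe(image))

# New: a text prompt conditions the generated caption.
print(pipe(image, prompt="a photography of"))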
src/transformers/models/auto/modeling_auto.py
@@ -529,6 +529,8 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
     [
         ("blip", "BlipForConditionalGeneration"),
         ("blip-2", "Blip2ForConditionalGeneration"),
+        ("git", "GitForCausalLM"),
+        ("pix2struct", "Pix2StructForConditionalGeneration"),
         ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
     ]
 )
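Registering GIT and Pix2Struct in this mapping is what lets the pipeline factory and the auto class resolve those checkpoints; roughly, a sketch relying on the fact that AutoModelForVision2Seq is backed by MODEL_FOR_VISION_2_SEQ_MAPPING:

from transformers import AutoModelForVision2Seq

# With the two new entries, this resolves to GitForCausalLM.
model = AutoModelForVision2Seq.from_pretrained("microsoft/git-base-coco")
print(type(model).__name__)  # GitForCausalLM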
src/transformers/pipelines/image_to_text.py
@@ -20,6 +20,8 @@ if is_tf_available():
     from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING
 
 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING
 
 logger = logging.get_logger(__name__)
@@ -56,8 +58,13 @@ class ImageToTextPipeline(Pipeline):
             TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
         )
 
-    def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
+    def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
         forward_kwargs = {}
+        preprocess_params = {}
+
+        if prompt is not None:
+            preprocess_params["prompt"] = prompt
+
         if generate_kwargs is not None:
             forward_kwargs["generate_kwargs"] = generate_kwargs
         if max_new_tokens is not None:
@@ -69,7 +76,7 @@ class ImageToTextPipeline(Pipeline):
                     " please use only one"
                 )
             forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
-        return {}, forward_kwargs, {}
+        return preprocess_params, forward_kwargs, {}
 
     def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
         """
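_sanitize_parameters returns a (preprocess_params, forward_kwargs, postprocess_params) triple, so after this change a single call can route the prompt to preprocess and the generation options to _forward; a sketch, with illustrative argument values:

outputs = pipe(
    image,
    prompt="a photo of",                     # -> preprocess_params["prompt"]
    generate_kwargs={"max_new_tokens": 20},  # -> forward_kwargs["generate_kwargs"]
)

Per the unchanged context above, passing max_new_tokens both as a top-level argument and inside generate_kwargs still raises, so use only one.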
@@ -98,9 +105,43 @@ class ImageToTextPipeline(Pipeline):
         """
         return super().__call__(images, **kwargs)
 
-    def preprocess(self, image):
+    def preprocess(self, image, prompt=None):
         image = load_image(image)
-        model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+
+        if prompt is not None:
+            if not isinstance(prompt, str):
+                raise ValueError(
+                    f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
+                    "Note also that one single text can be provided for conditional image to text generation."
+                )
+
+            model_type = self.model.config.model_type
+
+            if model_type == "git":
+                model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+                input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
+                input_ids = [self.tokenizer.cls_token_id] + input_ids
+                input_ids = torch.tensor(input_ids).unsqueeze(0)
+                model_inputs.update({"input_ids": input_ids})
+
+            elif model_type == "pix2struct":
+                model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
+
+            elif model_type != "vision-encoder-decoder":
+                # vision-encoder-decoder does not support conditional generation
+                model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+                text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
+                model_inputs.update(text_inputs)
+
+            else:
+                raise ValueError(f"Model type {model_type} does not support conditional text generation")
+
+        else:
+            model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+
+        if self.model.config.model_type == "git" and prompt is None:
+            model_inputs["input_ids"] = None
+
         return model_inputs
 
     def _forward(self, model_inputs, generate_kwargs=None):
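The GIT branch builds decoder input ids by hand: the prompt is tokenized without special tokens and a CLS token is prepended, so generation continues the prompt verbatim. A standalone sketch of that encoding, using the checkpoint from the tests:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/git-base-coco")
prompt = "a photo of a"

# Same steps as the git branch above: no special tokens, CLS prepended manually.
input_ids = tokenizer(text=prompt, add_special_tokens=False).input_ids
input_ids = torch.tensor([tokenizer.cls_token_id] + input_ids).unsqueeze(0)
print(input_ids.shape)  # (1, prompt_length + 1), batch size 1

For Pix2Struct the prompt instead goes to the image processor as header_text, which renders the text onto the image itself, while BLIP-style models take the prompt as regular tokenizer input.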
tests/pipelines/test_pipelines_image_to_text.py
@@ -14,6 +14,8 @@
 
 import unittest
 
+import requests
+
 from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
 from transformers.pipelines import pipeline
 from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_vision, slow
@@ -125,6 +127,15 @@ class ImageToTextPipelineTests(unittest.TestCase):
             ],
         )
 
+    @require_torch
+    def test_small_model_pt_conditional(self):
+        pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-BlipForConditionalGeneration")
+        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+        prompt = "a photo of"
+
+        outputs = pipe(image, prompt=prompt)
+        self.assertTrue(outputs[0]["generated_text"].startswith(prompt))
+
     @slow
     @require_torch
     def test_large_model_pt(self):
@@ -143,6 +154,71 @@ class ImageToTextPipelineTests(unittest.TestCase):
             ],
         )
 
+    @slow
+    @require_torch
+    def test_generation_pt_blip(self):
+        pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
+        url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        outputs = pipe(image)
+        self.assertEqual(outputs, [{"generated_text": "a pink pokemon pokemon with a blue shirt and a blue shirt"}])
+
+    @slow
+    @require_torch
+    def test_generation_pt_git(self):
+        pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
+        url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        outputs = pipe(image)
+        self.assertEqual(outputs, [{"generated_text": "a cartoon of a purple character."}])
+
+    @slow
+    @require_torch
+    def test_conditional_generation_pt_blip(self):
+        pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
+        url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        prompt = "a photography of"
+
+        outputs = pipe(image, prompt=prompt)
+        self.assertEqual(outputs, [{"generated_text": "a photography of a volcano"}])
+
+        with self.assertRaises(ValueError):
+            outputs = pipe([image, image], prompt=[prompt, prompt])
+
+    @slow
+    @require_torch
+    def test_conditional_generation_pt_git(self):
+        pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
+        url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        prompt = "a photo of a"
+
+        outputs = pipe(image, prompt=prompt)
+        self.assertEqual(outputs, [{"generated_text": "a photo of a tent with a tent and a tent in the background."}])
+
+        with self.assertRaises(ValueError):
+            outputs = pipe([image, image], prompt=[prompt, prompt])
+
+    @slow
+    @require_torch
+    def test_conditional_generation_pt_pix2struct(self):
+        pipe = pipeline("image-to-text", model="google/pix2struct-ai2d-base")
+        url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        prompt = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
+
+        outputs = pipe(image, prompt=prompt)
+        self.assertEqual(outputs, [{"generated_text": "ash cloud"}])
+
+        with self.assertRaises(ValueError):
+            outputs = pipe([image, image], prompt=[prompt, prompt])
+
     @slow
     @require_tf
     def test_large_model_tf(self):
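All three conditional tests also pin down the single-string restriction: a list of prompts is rejected in preprocess rather than silently broadcast across the batch. A sketch of that error path:

# Lists of prompts are rejected; only a single string is accepted.
try:
    pipe([image, image], prompt=["a photo of", "a photo of"])
except ValueError as err:
    print(err)  # Received an invalid text input, got - <class 'list'> - ...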