Mirror of https://github.com/huggingface/transformers.git
Add Visual Question Answering (VQA) pipeline (#17286)
* wip
* rebase
* all tests pass
* rebase
* ready for PR
* address comments
* fix styles
* add require_torch to pipeline test
* remove remote image to improve CI consistency
* address comments; fix tf/flax tests
* address comments; fix tf/flax tests
* fix tests; add alias
* repo consistency tests
* Update src/transformers/pipelines/visual_question_answering.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* address comments
* Update src/transformers/pipelines/visual_question_answering.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* merge
* Update src/transformers/models/auto/modeling_auto.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* merge

Co-authored-by: Sijun He <sijunhe@Sijuns-MacBook-Pro.local>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a5282ab4bc
commit 66336dc183
@@ -38,6 +38,7 @@ There are two categories of pipeline abstractions to be aware about:
- [`Text2TextGenerationPipeline`]
- [`TokenClassificationPipeline`]
- [`TranslationPipeline`]
- [`VisualQuestionAnsweringPipeline`]
- [`ZeroShotClassificationPipeline`]
- [`ZeroShotImageClassificationPipeline`]
@@ -423,6 +424,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all

### VisualQuestionAnsweringPipeline

[[autodoc]] VisualQuestionAnsweringPipeline
- __call__
- all
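For reviewers, a minimal usage sketch of the pipeline documented above (illustrative, not part of this diff); it assumes the default `dandelin/vilt-b32-finetuned-vqa` checkpoint registered later in this PR and a local image file:

```python
from transformers import pipeline

# "visual-question-answering" (alias "vqa") loads the default ViLT checkpoint added in this PR.
vqa = pipeline("visual-question-answering")

preds = vqa(
    image="./tests/fixtures/tests_samples/COCO/000000039769.png",
    question="How many cats are there?",
    top_k=2,
)
# Each prediction is a dict like {"score": 0.88, "answer": "2"} (values taken from the PR's slow test).
print(preds)
```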

### ZeroShotClassificationPipeline

[[autodoc]] ZeroShotClassificationPipeline
@@ -122,6 +122,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its

[[autodoc]] AutoModelForVision2Seq

## AutoModelForVisualQuestionAnswering

[[autodoc]] AutoModelForVisualQuestionAnswering
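A quick illustration of the new auto class (an editor's sketch, not part of the diff), using the `dandelin/vilt-b32-finetuned-vqa` checkpoint referenced elsewhere in this PR:

```python
from transformers import AutoModelForVisualQuestionAnswering

# The ("vilt", "ViltForQuestionAnswering") entry added to
# MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING resolves this checkpoint
# to a ViltForQuestionAnswering instance.
model = AutoModelForVisualQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
```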

## AutoModelForAudioClassification

[[autodoc]] AutoModelForAudioClassification
@@ -377,6 +377,7 @@ _import_structure = {
"TextGenerationPipeline",
"TokenClassificationPipeline",
"TranslationPipeline",
"VisualQuestionAnsweringPipeline",
"ZeroShotClassificationPipeline",
"ZeroShotImageClassificationPipeline",
"pipeline",
@@ -758,6 +759,7 @@ else:
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
@@ -783,6 +785,7 @@ else:
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForVision2Seq",
"AutoModelForVisualQuestionAnswering",
"AutoModelWithLMHead",
]
)
@@ -2961,6 +2964,7 @@ if TYPE_CHECKING:
TextGenerationPipeline,
TokenClassificationPipeline,
TranslationPipeline,
VisualQuestionAnsweringPipeline,
ZeroShotClassificationPipeline,
ZeroShotImageClassificationPipeline,
pipeline,
@@ -3291,6 +3295,7 @@ if TYPE_CHECKING:
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
@@ -3316,6 +3321,7 @@ if TYPE_CHECKING:
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelWithLMHead,
)
from .models.bart import (
@@ -64,6 +64,7 @@ else:
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"AutoModel",
@@ -89,6 +90,7 @@ else:
"AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification",
"AutoModelForVision2Seq",
"AutoModelForVisualQuestionAnswering",
"AutoModelWithLMHead",
]
@@ -202,6 +204,7 @@ if TYPE_CHECKING:
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
@@ -227,6 +230,7 @@ if TYPE_CHECKING:
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelWithLMHead,
)
@@ -64,6 +64,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("speech_to_text", "Speech2TextFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("vilt", "ViltFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
@@ -548,6 +548,12 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    ]
)

MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("vilt", "ViltForQuestionAnswering"),
    ]
)

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
@@ -706,6 +712,9 @@ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
@@ -813,6 +822,17 @@ AutoModelForTableQuestionAnswering = auto_class_update(
)


class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)
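Sketch of how the mapping added above is consumed (illustrative, not part of the diff): the `_LazyAutoMapping` ties `ViltConfig` to `ViltForQuestionAnswering`, so the auto class can instantiate the right head from a config or checkpoint.

```python
from transformers import AutoModelForVisualQuestionAnswering, ViltConfig

# from_config resolves through MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING without
# downloading weights; the result is a randomly initialized ViltForQuestionAnswering.
model = AutoModelForVisualQuestionAnswering.from_config(ViltConfig())
print(type(model).__name__)  # ViltForQuestionAnswering
```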


class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
@@ -229,6 +229,7 @@ else:
("tapas", ("TapasTokenizer", None)),
("tapex", ("TapexTokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
@@ -61,6 +61,7 @@ from .token_classification import (
TokenClassificationArgumentHandler,
TokenClassificationPipeline,
)
from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
@@ -94,6 +95,7 @@ if is_torch_available():
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
@@ -109,6 +111,7 @@ if is_torch_available():
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelForVisualQuestionAnswering,
)
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
@@ -121,6 +124,7 @@ logger = logging.get_logger(__name__)
TASK_ALIASES = {
    "sentiment-analysis": "text-classification",
    "ner": "token-classification",
    "vqa": "visual-question-answering",
}
SUPPORTED_TASKS = {
    "audio-classification": {
@@ -190,6 +194,19 @@ SUPPORTED_TASKS = {
        },
        "type": "text",
    },
    "visual-question-answering": {
        "impl": VisualQuestionAnsweringPipeline,
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {
                "pt": "dandelin/vilt-b32-finetuned-vqa",
                "tokenizer": "dandelin/vilt-b32-finetuned-vqa",
                "feature_extractor": "dandelin/vilt-b32-finetuned-vqa",
            },
        },
        "type": "multimodal",
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
src/transformers/pipelines/visual_question_answering.py (new file, 115 lines)
@@ -0,0 +1,115 @@
from typing import Union

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING

logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class VisualQuestionAnsweringPipeline(Pipeline):
    """
    Visual Question Answering pipeline using an `AutoModelForVisualQuestionAnswering`. This pipeline is currently
    only available in PyTorch.

    This visual question answering pipeline can currently be loaded from [`pipeline`] using the following task
    identifiers: `"visual-question-answering", "vqa"`.

    The models that this pipeline can use are models that have been fine-tuned on a visual question answering task.
    See the up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=visual-question-answering).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING)

    def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, **kwargs):
        preprocess_params, postprocess_params = {}, {}
        if padding is not None:
            preprocess_params["padding"] = padding
        if truncation is not None:
            preprocess_params["truncation"] = truncation
        if top_k is not None:
            postprocess_params["top_k"] = top_k
        return preprocess_params, {}, postprocess_params

    def __call__(self, image: Union["Image.Image", str], question: str = None, **kwargs):
        r"""
        Answers open-ended questions about images. The pipeline accepts several types of inputs which are detailed
        below:

        - `pipeline(image=image, question=question)`
        - `pipeline({"image": image, "question": question})`
        - `pipeline([{"image": image, "question": question}])`
        - `pipeline([{"image": image, "question": question}, {"image": image, "question": question}])`

        Args:
            image (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. If given a single image, it can be
                broadcasted to multiple questions.
            question (`str`, `List[str]`):
                The question(s) asked. If given a single question, it can be broadcasted to multiple images.
            top_k (`int`, *optional*, defaults to 5):
                The number of top labels that will be returned by the pipeline. If the provided number is higher than
                the number of labels available in the model configuration, it will default to the number of labels.
        Return:
            A dictionary or a list of dictionaries containing the result. The dictionaries contain the following keys:

            - **answer** (`str`) -- The answer predicted by the model.
            - **score** (`float`) -- The score attributed by the model to that answer.
        """
        if isinstance(image, (Image.Image, str)) and isinstance(question, str):
            inputs = {"image": image, "question": question}
        else:
            """
            Supports the following format
            - {"image": image, "question": question}
            - [{"image": image, "question": question}]
            - Generator and datasets
            """
            inputs = image
        results = super().__call__(inputs, **kwargs)
        return results

    def preprocess(self, inputs, padding=False, truncation=False):
        image = load_image(inputs["image"])
        model_inputs = self.tokenizer(
            inputs["question"], return_tensors=self.framework, padding=padding, truncation=truncation
        )
        image_features = self.feature_extractor(images=image, return_tensors=self.framework)
        model_inputs.update(image_features)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs, top_k=5):
        if top_k > self.model.config.num_labels:
            top_k = self.model.config.num_labels

        if self.framework == "pt":
            probs = model_outputs.logits.sigmoid()[0]
            scores, ids = probs.topk(top_k)
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        scores = scores.tolist()
        ids = ids.tolist()
        return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
@@ -409,6 +409,9 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
MODEL_FOR_VISION_2_SEQ_MAPPING = None


MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None


MODEL_MAPPING = None


@@ -576,6 +579,13 @@ class AutoModelForVision2Seq(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class AutoModelWithLMHead(metaclass=DummyObject):
    _backends = ["torch"]


tests/pipelines/test_pipelines_visual_question_answering.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, is_vision_available
from transformers.pipelines import pipeline
from transformers.testing_utils import (
    is_pipeline_test,
    nested_simplify,
    require_tf,
    require_torch,
    require_vision,
    slow,
)

from .test_pipelines_common import ANY, PipelineTestCaseMeta


if is_vision_available():
    from PIL import Image
else:

    class Image:
        @staticmethod
        def open(*args, **kwargs):
            pass


@is_pipeline_test
@require_torch
@require_vision
class VisualQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING

    def get_test_pipeline(self, model, tokenizer, feature_extractor):
        vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
        examples = [
            {
                "image": Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
                "question": "How many cats are there?",
            },
            {
                "image": "./tests/fixtures/tests_samples/COCO/000000039769.png",
                "question": "How many cats are there?",
            },
        ]
        return vqa_pipeline, examples

    def run_pipeline_test(self, vqa_pipeline, examples):
        outputs = vqa_pipeline(examples, top_k=1)
        self.assertEqual(
            outputs,
            [
                [{"score": ANY(float), "answer": ANY(str)}],
                [{"score": ANY(float), "answer": ANY(str)}],
            ],
        )

    @require_torch
    def test_small_model_pt(self):
        vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        question = "How many cats are there?"

        outputs = vqa_pipeline(image=image, question="How many cats are there?", top_k=2)
        self.assertEqual(
            outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
        )

        outputs = vqa_pipeline({"image": image, "question": question}, top_k=2)
        self.assertEqual(
            outputs, [{"score": ANY(float), "answer": ANY(str)}, {"score": ANY(float), "answer": ANY(str)}]
        )

    @slow
    @require_torch
    def test_large_model_pt(self):
        vqa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        question = "How many cats are there?"

        outputs = vqa_pipeline(image=image, question=question, top_k=2)
        self.assertEqual(
            nested_simplify(outputs, decimals=4), [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]
        )

        outputs = vqa_pipeline({"image": image, "question": question}, top_k=2)
        self.assertEqual(
            nested_simplify(outputs, decimals=4), [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]
        )

        outputs = vqa_pipeline(
            [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [[{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}]] * 2,
        )

    @require_tf
    @unittest.skip("Visual question answering not implemented in TF")
    def test_small_model_tf(self):
        pass