[WIP] Add ZeroShotObjectDetectionPipeline (#18445) (#18930)

* Add ZeroShotObjectDetectionPipeline (#18445)

* Add AutoModelForZeroShotObjectDetection task

This commit also adds the following

- Add explicit _processor method for ZeroShotObjectDetectionPipeline.
  This is necessary as pipelines don't auto infer processors yet and
  `OwlVitProcessor` wraps tokenizer and feature_extractor together, to
  process multiple images at once

- Add auto tests and other tests for ZeroShotObjectDetectionPipeline

* Add AutoModelForZeroShotObjectDetection task

This commit also adds the following

- Add explicit _processor method for ZeroShotObjectDetectionPipeline.
  This is necessary as pipelines don't auto infer processors yet and
  `OwlVitProcessor` wraps tokenizer and feature_extractor together, to
  process multiple images at once

- Add auto tests and other tests for ZeroShotObjectDetectionPipeline

* Add batching for ZeroShotObjectDetectionPipeline

* Fix doc-string ZeroShotObjectDetectionPipeline

* Fix output format: ZeroShotObjectDetectionPipeline
This commit is contained in:
Amrit Sahu 2022-10-07 19:30:19 +05:30 committed by GitHub
parent 331ea019d7
commit e9a49babee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 605 additions and 0 deletions

View File

@ -43,6 +43,7 @@ There are two categories of pipeline abstractions to be aware about:
- [`VisualQuestionAnsweringPipeline`]
- [`ZeroShotClassificationPipeline`]
- [`ZeroShotImageClassificationPipeline`]
- [`ZeroShotObjectDetectionPipeline`]
## The pipeline abstraction
@ -456,6 +457,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all
### ZeroShotObjectDetectionPipeline
[[autodoc]] ZeroShotObjectDetectionPipeline
- __call__
- all
## Parent class: `Pipeline`
[[autodoc]] Pipeline

View File

@ -174,6 +174,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
[[autodoc]] AutoModelForInstanceSegmentation
## AutoModelForZeroShotObjectDetection
[[autodoc]] AutoModelForZeroShotObjectDetection
## TFAutoModel
[[autodoc]] TFAutoModel

View File

@ -442,6 +442,7 @@ _import_structure = {
"VisualQuestionAnsweringPipeline",
"ZeroShotClassificationPipeline",
"ZeroShotImageClassificationPipeline",
"ZeroShotObjectDetectionPipeline",
"pipeline",
],
"processing_utils": ["ProcessorMixin"],
@ -878,6 +879,7 @@ else:
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
@ -905,6 +907,7 @@ else:
"AutoModelForVision2Seq",
"AutoModelForVisualQuestionAnswering",
"AutoModelWithLMHead",
"AutoModelForZeroShotObjectDetection",
]
)
_import_structure["models.bart"].extend(
@ -3407,6 +3410,7 @@ if TYPE_CHECKING:
VisualQuestionAnsweringPipeline,
ZeroShotClassificationPipeline,
ZeroShotImageClassificationPipeline,
ZeroShotObjectDetectionPipeline,
pipeline,
)
from .processing_utils import ProcessorMixin
@ -3772,6 +3776,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
@ -3800,6 +3805,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead,
)
from .models.bart import (

View File

@ -69,6 +69,7 @@ else:
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
@ -96,6 +97,7 @@ else:
"AutoModelForVisualQuestionAnswering",
"AutoModelForDocumentQuestionAnswering",
"AutoModelWithLMHead",
"AutoModelForZeroShotObjectDetection",
]
try:
@ -215,6 +217,7 @@ if TYPE_CHECKING:
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
AutoModel,
@ -243,6 +246,7 @@ if TYPE_CHECKING:
AutoModelForVideoClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotObjectDetection,
AutoModelWithLMHead,
)

View File

@ -472,6 +472,13 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Object Detection mapping
("owlvit", "OwlViTForObjectDetection")
]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
@ -830,6 +837,9 @@ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
@ -1016,6 +1026,15 @@ class AutoModelForObjectDetection(_BaseAutoModelClass):
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
AutoModelForZeroShotObjectDetection = auto_class_update(
AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)
class AutoModelForVideoClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING

View File

@ -72,6 +72,7 @@ from .token_classification import (
from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline
if is_tf_available():
@ -124,6 +125,7 @@ if is_torch_available():
AutoModelForTokenClassification,
AutoModelForVision2Seq,
AutoModelForVisualQuestionAnswering,
AutoModelForZeroShotObjectDetection,
)
if TYPE_CHECKING:
from ..modeling_tf_utils import TFPreTrainedModel
@ -335,6 +337,13 @@ SUPPORTED_TASKS = {
"default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
"type": "image",
},
"zero-shot-object-detection": {
"impl": ZeroShotObjectDetectionPipeline,
"tf": (),
"pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
"default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
"type": "multimodal",
},
}
NO_FEATURE_EXTRACTOR_TASKS = set()

View File

@ -0,0 +1,278 @@
from typing import Dict, List, Union
import numpy as np
from ..tokenization_utils_base import BatchEncoding
from ..utils import (
add_end_docstrings,
is_tf_available,
is_torch_available,
is_vision_available,
logging,
requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available():
from PIL import Image
from ..image_utils import load_image
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotObjectDetectionPipeline(Pipeline):
"""
Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
objects when you provide an image and a set of `candidate_labels`.
This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"zero-shot-object-detection"`.
See the list of available models on
[huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.framework == "tf":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING)
def __call__(
self,
images: Union[str, List[str], "Image.Image", List["Image.Image"]],
text_queries: Union[str, List[str], List[List[str]]] = None,
**kwargs
):
"""
Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
Args:
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
The pipeline handles three types of images:
- A string containing an http url pointing to an image
- A string containing a local path to an image
- An image loaded in PIL directly
text_queries (`str` or `List[str]` or `List[List[str]]`): Text queries to query the target image with.
If given multiple images, `text_queries` should be provided as a list of lists, where each nested list
contains the text queries for the corresponding image.
threshold (`float`, *optional*, defaults to 0.1):
The probability necessary to make a prediction.
top_k (`int`, *optional*, defaults to None):
The number of top predictions that will be returned by the pipeline. If the provided number is `None`
or higher than the number of predictions available, it will default to the number of predictions.
Return:
A list of lists containing prediction results, one list per input image. Each list contains dictionaries
with the following keys:
- **label** (`str`) -- Text query corresponding to the found object.
- **score** (`float`) -- Score corresponding to the object (between 0 and 1).
- **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
"""
if isinstance(text_queries, str) or (isinstance(text_queries, List) and not isinstance(text_queries[0], List)):
if isinstance(images, (str, Image.Image)):
inputs = {"images": images, "text_queries": text_queries}
elif isinstance(images, List):
assert len(images) == 1, "Input text_queries and images must have correspondance"
inputs = {"images": images[0], "text_queries": text_queries}
else:
raise TypeError(f"Innapropriate type of images: {type(images)}")
elif isinstance(text_queries, str) or (isinstance(text_queries, List) and isinstance(text_queries[0], List)):
if isinstance(images, (Image.Image, str)):
images = [images]
assert len(images) == len(text_queries), "Input text_queries and images must have correspondance"
inputs = {"images": images, "text_queries": text_queries}
else:
"""
Supports the following format
- {"images": images, "text_queries": text_queries}
"""
inputs = images
results = super().__call__(inputs, **kwargs)
return results
def _sanitize_parameters(self, **kwargs):
postprocess_params = {}
if "threshold" in kwargs:
postprocess_params["threshold"] = kwargs["threshold"]
if "top_k" in kwargs:
postprocess_params["top_k"] = kwargs["top_k"]
return {}, {}, postprocess_params
def preprocess(self, inputs):
if not isinstance(inputs["images"], List):
inputs["images"] = [inputs["images"]]
images = [load_image(img) for img in inputs["images"]]
text_queries = inputs["text_queries"]
if isinstance(text_queries, str) or isinstance(text_queries[0], str):
text_queries = [text_queries]
target_sizes = [torch.IntTensor([[img.height, img.width]]) for img in images]
target_sizes = torch.cat(target_sizes)
inputs = self._processor(text=inputs["text_queries"], images=images, return_tensors="pt")
return {"target_sizes": target_sizes, "text_queries": text_queries, **inputs}
def _forward(self, model_inputs):
target_sizes = model_inputs.pop("target_sizes")
text_queries = model_inputs.pop("text_queries")
outputs = self.model(**model_inputs)
model_outputs = outputs.__class__({"target_sizes": target_sizes, "text_queries": text_queries, **outputs})
return model_outputs
def postprocess(self, model_outputs, threshold=0.1, top_k=None):
texts = model_outputs["text_queries"]
outputs = self.feature_extractor.post_process(
outputs=model_outputs, target_sizes=model_outputs["target_sizes"]
)
results = []
for i in range(len(outputs)):
keep = outputs[i]["scores"] >= threshold
labels = outputs[i]["labels"][keep].tolist()
scores = outputs[i]["scores"][keep].tolist()
boxes = [self._get_bounding_box(box) for box in outputs[i]["boxes"][keep]]
result = [
{"score": score, "label": texts[i][label], "box": box}
for score, label, box in zip(scores, labels, boxes)
]
result = sorted(result, key=lambda x: x["score"], reverse=True)
if top_k:
result = result[:top_k]
results.append(result)
return results
def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
"""
Turns list [xmin, xmax, ymin, ymax] into dict { "xmin": xmin, ... }
Args:
box (`torch.Tensor`): Tensor containing the coordinates in corners format.
Returns:
bbox (`Dict[str, int]`): Dict containing the coordinates in corners format.
"""
if self.framework != "pt":
raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.")
xmin, ymin, xmax, ymax = box.int().tolist()
bbox = {
"xmin": xmin,
"ymin": ymin,
"xmax": xmax,
"ymax": ymax,
}
return bbox
# Replication of OwlViTProcessor __call__ method, since pipelines don't auto infer processor's yet!
def _processor(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
"""
Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
`kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
doctsring of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if text is None and images is None:
raise ValueError("You have to specify at least one text or image. Both cannot be none.")
if text is not None:
if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
elif isinstance(text, List) and isinstance(text[0], List):
encodings = []
# Maximum number of queries across batch
max_num_queries = max([len(t) for t in text])
# Pad all batch samples to max number of text queries
for t in text:
if len(t) != max_num_queries:
t = t + [" "] * (max_num_queries - len(t))
encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
encodings.append(encoding)
else:
raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
if return_tensors == "np":
input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
elif return_tensors == "pt" and is_torch_available():
import torch
input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
elif return_tensors == "tf" and is_tf_available():
import tensorflow as tf
input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
else:
raise ValueError("Target return tensor type could not be returned")
encoding = BatchEncoding()
encoding["input_ids"] = input_ids
encoding["attention_mask"] = attention_mask
if images is not None:
image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)

View File

@ -418,6 +418,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None
MODEL_MAPPING = None
@ -606,6 +609,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
requires_backends(self, ["torch"])
class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelWithLMHead(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -0,0 +1,263 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, is_vision_available, pipeline
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_vision,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
if is_vision_available():
from PIL import Image
else:
class Image:
@staticmethod
def open(*args, **kwargs):
pass
@require_vision
@require_torch
@is_pipeline_test
class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
def get_test_pipeline(self, model, tokenizer, feature_extractor):
object_detector = pipeline(
"zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection"
)
examples = [
{
"images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
"text_queries": ["cat", "remote", "couch"],
}
]
return object_detector, examples
def run_pipeline_test(self, object_detector, examples):
batch_outputs = object_detector(examples, threshold=0.0)
self.assertEqual(len(examples), len(batch_outputs))
for outputs in batch_outputs:
for output_per_image in outputs:
self.assertGreater(len(output_per_image), 0)
for detected_object in output_per_image:
self.assertEqual(
detected_object,
{
"score": ANY(float),
"label": ANY(str),
"box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
},
)
@require_tf
@unittest.skip("Zero Shot Object Detection not implemented in TF")
def test_small_model_tf(self):
pass
@require_torch
def test_small_model_pt(self):
object_detector = pipeline(
"zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection"
)
outputs = object_detector(
"./tests/fixtures/tests_samples/COCO/000000039769.png",
text_queries=["cat", "remote", "couch"],
threshold=0.64,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
]
],
)
outputs = object_detector(
["./tests/fixtures/tests_samples/COCO/000000039769.png"],
text_queries=["cat", "remote", "couch"],
threshold=0.64,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
]
],
)
outputs = object_detector(
"./tests/fixtures/tests_samples/COCO/000000039769.png",
text_queries=[["cat", "remote", "couch"]],
threshold=0.64,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
]
],
)
outputs = object_detector(
[
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"http://images.cocodataset.org/val2017/000000039769.jpg",
],
text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
threshold=0.64,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
],
[
{"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
{"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
{"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
{"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
],
],
)
@require_torch
@slow
def test_large_model_pt(self):
object_detector = pipeline("zero-shot-object-detection")
outputs = object_detector(
"http://images.cocodataset.org/val2017/000000039769.jpg", text_queries=["cat", "remote", "couch"]
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
{"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
{"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
]
],
)
outputs = object_detector(
[
"http://images.cocodataset.org/val2017/000000039769.jpg",
"http://images.cocodataset.org/val2017/000000039769.jpg",
],
text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
{"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
{"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
],
[
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
{"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
{"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
],
],
)
@require_tf
@unittest.skip("Zero Shot Object Detection not implemented in TF")
def test_large_model_tf(self):
pass
@require_torch
@slow
def test_threshold(self):
threshold = 0.2
object_detector = pipeline("zero-shot-object-detection")
outputs = object_detector(
"http://images.cocodataset.org/val2017/000000039769.jpg",
text_queries=["cat", "remote", "couch"],
threshold=threshold,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
{"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
]
],
)
@require_torch
@slow
def test_top_k(self):
top_k = 2
object_detector = pipeline("zero-shot-object-detection")
outputs = object_detector(
"http://images.cocodataset.org/val2017/000000039769.jpg",
text_queries=["cat", "remote", "couch"],
top_k=top_k,
)
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
{"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
]
],
)

View File

@ -58,6 +58,11 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
(
"zero-shot-object-detection",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES",
"AutoModelForZeroShotObjectDetection",
),
("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),