Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
* Add ZeroShotObjectDetectionPipeline (#18445)

* Add AutoModelForZeroShotObjectDetection task

  This commit also adds the following:
  - An explicit _processor method for ZeroShotObjectDetectionPipeline. This is necessary because pipelines don't auto-infer processors yet, and `OwlViTProcessor` wraps the tokenizer and feature extractor together to process multiple images at once.
  - Auto tests and other tests for ZeroShotObjectDetectionPipeline

* Add batching for ZeroShotObjectDetectionPipeline

* Fix doc-string of ZeroShotObjectDetectionPipeline

* Fix output format of ZeroShotObjectDetectionPipeline
This commit is contained in:
parent
331ea019d7
commit
e9a49babee
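
For orientation before reading the diff: the new pipeline is loaded with the "zero-shot-object-detection" task identifier and called with an image plus `text_queries`, exactly as exercised in the tests added below. A minimal usage sketch (the explicit checkpoint and the threshold value are illustrative choices, not requirements of this commit):

    from transformers import pipeline

    # Load the zero-shot object detector; the task defaults to an OwlViT checkpoint
    detector = pipeline("zero-shot-object-detection", model="google/owlvit-base-patch32")

    # Detect objects described by free-form text queries in a single image
    predictions = detector(
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        text_queries=["cat", "remote", "couch"],
        threshold=0.1,
    )
    # Returns one list per input image; each detection is
    # {"score": float, "label": str, "box": {"xmin", "ymin", "xmax", "ymax"}}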
@@ -43,6 +43,7 @@ There are two categories of pipeline abstractions to be aware about:
- [`VisualQuestionAnsweringPipeline`]
- [`ZeroShotClassificationPipeline`]
- [`ZeroShotImageClassificationPipeline`]
- [`ZeroShotObjectDetectionPipeline`]

## The pipeline abstraction

@@ -456,6 +457,12 @@ See [`TokenClassificationPipeline`] for all details.
    - __call__
    - all

### ZeroShotObjectDetectionPipeline

[[autodoc]] ZeroShotObjectDetectionPipeline
    - __call__
    - all

## Parent class: `Pipeline`

[[autodoc]] Pipeline
@@ -174,6 +174,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its

[[autodoc]] AutoModelForInstanceSegmentation

## AutoModelForZeroShotObjectDetection

[[autodoc]] AutoModelForZeroShotObjectDetection

## TFAutoModel

[[autodoc]] TFAutoModel
@@ -442,6 +442,7 @@ _import_structure = {
        "VisualQuestionAnsweringPipeline",
        "ZeroShotClassificationPipeline",
        "ZeroShotImageClassificationPipeline",
        "ZeroShotObjectDetectionPipeline",
        "pipeline",
    ],
    "processing_utils": ["ProcessorMixin"],
@@ -878,6 +879,7 @@ else:
            "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
            "MODEL_MAPPING",
            "MODEL_WITH_LM_HEAD_MAPPING",
            "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
            "AutoModel",
            "AutoModelForAudioClassification",
            "AutoModelForAudioFrameClassification",
@@ -905,6 +907,7 @@ else:
            "AutoModelForVision2Seq",
            "AutoModelForVisualQuestionAnswering",
            "AutoModelWithLMHead",
            "AutoModelForZeroShotObjectDetection",
        ]
    )
    _import_structure["models.bart"].extend(
@@ -3407,6 +3410,7 @@ if TYPE_CHECKING:
        VisualQuestionAnsweringPipeline,
        ZeroShotClassificationPipeline,
        ZeroShotImageClassificationPipeline,
        ZeroShotObjectDetectionPipeline,
        pipeline,
    )
    from .processing_utils import ProcessorMixin
@@ -3772,6 +3776,7 @@ if TYPE_CHECKING:
            MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
            MODEL_FOR_VISION_2_SEQ_MAPPING,
            MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
            MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
            MODEL_MAPPING,
            MODEL_WITH_LM_HEAD_MAPPING,
            AutoModel,
@@ -3800,6 +3805,7 @@ if TYPE_CHECKING:
            AutoModelForVideoClassification,
            AutoModelForVision2Seq,
            AutoModelForVisualQuestionAnswering,
            AutoModelForZeroShotObjectDetection,
            AutoModelWithLMHead,
        )
        from .models.bart import (
@@ -69,6 +69,7 @@ else:
        "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
        "MODEL_MAPPING",
        "MODEL_WITH_LM_HEAD_MAPPING",
        "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
        "AutoModel",
        "AutoModelForAudioClassification",
        "AutoModelForAudioFrameClassification",
@@ -96,6 +97,7 @@ else:
        "AutoModelForVisualQuestionAnswering",
        "AutoModelForDocumentQuestionAnswering",
        "AutoModelWithLMHead",
        "AutoModelForZeroShotObjectDetection",
    ]

try:
@@ -215,6 +217,7 @@ if TYPE_CHECKING:
            MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
            MODEL_FOR_VISION_2_SEQ_MAPPING,
            MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
            MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
            MODEL_MAPPING,
            MODEL_WITH_LM_HEAD_MAPPING,
            AutoModel,
@@ -243,6 +246,7 @@ if TYPE_CHECKING:
            AutoModelForVideoClassification,
            AutoModelForVision2Seq,
            AutoModelForVisualQuestionAnswering,
            AutoModelForZeroShotObjectDetection,
            AutoModelWithLMHead,
        )
@@ -472,6 +472,13 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    ]
)

MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Object Detection mapping
        ("owlvit", "OwlViTForObjectDetection")
    ]
)

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
@@ -830,6 +837,9 @@ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
@@ -1016,6 +1026,15 @@ class AutoModelForObjectDetection(_BaseAutoModelClass):
AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")


class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)


class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
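
The new auto class resolves OwlViT checkpoints to `OwlViTForObjectDetection` through the mapping above. A short sketch of using it outside the pipeline (the checkpoint is the task default registered later in this diff; loading the processor via `AutoProcessor` is an assumption for illustration, since the pipeline itself replicates the processor call instead):

    import requests
    import torch
    from PIL import Image

    from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

    # "owlvit" configs map to OwlViTForObjectDetection via the new mapping
    processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
    model = AutoModelForZeroShotObjectDetection.from_pretrained("google/owlvit-base-patch32")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    # One list of text queries per image
    inputs = processor(text=[["cat", "remote", "couch"]], images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)  # raw logits and boxes; the pipeline post-processes these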
@@ -72,6 +72,7 @@ from .token_classification import (
from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline


if is_tf_available():
@@ -124,6 +125,7 @@ if is_torch_available():
        AutoModelForTokenClassification,
        AutoModelForVision2Seq,
        AutoModelForVisualQuestionAnswering,
        AutoModelForZeroShotObjectDetection,
    )
if TYPE_CHECKING:
    from ..modeling_tf_utils import TFPreTrainedModel
@@ -335,6 +337,13 @@ SUPPORTED_TASKS = {
        "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
        "type": "image",
    },
    "zero-shot-object-detection": {
        "impl": ZeroShotObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
        "type": "multimodal",
    },
}

NO_FEATURE_EXTRACTOR_TASKS = set()
New file: src/transformers/pipelines/zero_shot_object_detection.py (278 lines)
@@ -0,0 +1,278 @@
from typing import Dict, List, Union

import numpy as np

from ..tokenization_utils_base import BatchEncoding
from ..utils import (
    add_end_docstrings,
    is_tf_available,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)
from .base import PIPELINE_INIT_ARGS, Pipeline


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING

logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotObjectDetectionPipeline(Pipeline):
    """
    Zero-shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
    objects when you provide an image and a set of `candidate_labels`.

    This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-object-detection"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING)

    def __call__(
        self,
        images: Union[str, List[str], "Image.Image", List["Image.Image"]],
        text_queries: Union[str, List[str], List[List[str]]] = None,
        **kwargs
    ):
        """
        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an http url pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

            text_queries (`str` or `List[str]` or `List[List[str]]`): Text queries to query the target image with.
                If given multiple images, `text_queries` should be provided as a list of lists, where each nested
                list contains the text queries for the corresponding image.

            threshold (`float`, *optional*, defaults to 0.1):
                The probability necessary to make a prediction.

            top_k (`int`, *optional*, defaults to None):
                The number of top predictions that will be returned by the pipeline. If the provided number is
                `None` or higher than the number of predictions available, it will default to the number of
                predictions.

        Return:
            A list of lists containing prediction results, one list per input image. Each list contains dictionaries
            with the following keys:

            - **label** (`str`) -- Text query corresponding to the found object.
            - **score** (`float`) -- Score corresponding to the object (between 0 and 1).
            - **box** (`Dict[str, int]`) -- Bounding box of the detected object in the image's original size. It is
              a dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys.
        """
        if isinstance(text_queries, str) or (isinstance(text_queries, List) and not isinstance(text_queries[0], List)):
            if isinstance(images, (str, Image.Image)):
                inputs = {"images": images, "text_queries": text_queries}
            elif isinstance(images, List):
                assert len(images) == 1, "Input text_queries and images must have correspondence"
                inputs = {"images": images[0], "text_queries": text_queries}
            else:
                raise TypeError(f"Inappropriate type of images: {type(images)}")

        elif isinstance(text_queries, str) or (isinstance(text_queries, List) and isinstance(text_queries[0], List)):
            if isinstance(images, (Image.Image, str)):
                images = [images]
            assert len(images) == len(text_queries), "Input text_queries and images must have correspondence"
            inputs = {"images": images, "text_queries": text_queries}
        else:
            """
            Supports the following format
            - {"images": images, "text_queries": text_queries}
            """
            inputs = images
        results = super().__call__(inputs, **kwargs)
        return results

    def _sanitize_parameters(self, **kwargs):
        postprocess_params = {}
        if "threshold" in kwargs:
            postprocess_params["threshold"] = kwargs["threshold"]
        if "top_k" in kwargs:
            postprocess_params["top_k"] = kwargs["top_k"]
        return {}, {}, postprocess_params

    def preprocess(self, inputs):
        if not isinstance(inputs["images"], List):
            inputs["images"] = [inputs["images"]]
        images = [load_image(img) for img in inputs["images"]]
        text_queries = inputs["text_queries"]
        if isinstance(text_queries, str) or isinstance(text_queries[0], str):
            text_queries = [text_queries]

        target_sizes = [torch.IntTensor([[img.height, img.width]]) for img in images]
        target_sizes = torch.cat(target_sizes)
        inputs = self._processor(text=inputs["text_queries"], images=images, return_tensors="pt")
        return {"target_sizes": target_sizes, "text_queries": text_queries, **inputs}

    def _forward(self, model_inputs):
        target_sizes = model_inputs.pop("target_sizes")
        text_queries = model_inputs.pop("text_queries")
        outputs = self.model(**model_inputs)

        model_outputs = outputs.__class__({"target_sizes": target_sizes, "text_queries": text_queries, **outputs})
        return model_outputs

    def postprocess(self, model_outputs, threshold=0.1, top_k=None):
        texts = model_outputs["text_queries"]

        outputs = self.feature_extractor.post_process(
            outputs=model_outputs, target_sizes=model_outputs["target_sizes"]
        )

        results = []
        for i in range(len(outputs)):
            keep = outputs[i]["scores"] >= threshold
            labels = outputs[i]["labels"][keep].tolist()
            scores = outputs[i]["scores"][keep].tolist()
            boxes = [self._get_bounding_box(box) for box in outputs[i]["boxes"][keep]]

            result = [
                {"score": score, "label": texts[i][label], "box": box}
                for score, label, box in zip(scores, labels, boxes)
            ]

            result = sorted(result, key=lambda x: x["score"], reverse=True)
            if top_k:
                result = result[:top_k]
            results.append(result)

        return results

    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
        """
        Turns list [xmin, ymin, xmax, ymax] into dict { "xmin": xmin, ... }

        Args:
            box (`torch.Tensor`): Tensor containing the coordinates in corners format.

        Returns:
            bbox (`Dict[str, int]`): Dict containing the coordinates in corners format.
        """
        if self.framework != "pt":
            raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.")
        xmin, ymin, xmax, ymax = box.int().tolist()
        bbox = {
            "xmin": xmin,
            "ymin": ymin,
            "xmax": xmax,
            "ymax": ymax,
        }
        return bbox

    # Replication of the OwlViTProcessor __call__ method, since pipelines don't auto-infer processors yet!
    def _processor(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
        """
        Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
        `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
        docstring of the above two methods for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
            `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
                a number of channels, H and W are image height and width.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is
              not `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """

        if text is None and images is None:
            raise ValueError("You have to specify at least one text or image. Both cannot be none.")

        if text is not None:
            if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
                encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]

            elif isinstance(text, List) and isinstance(text[0], List):
                encodings = []

                # Maximum number of queries across batch
                max_num_queries = max([len(t) for t in text])

                # Pad all batch samples to max number of text queries
                for t in text:
                    if len(t) != max_num_queries:
                        t = t + [" "] * (max_num_queries - len(t))

                    encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
                    encodings.append(encoding)
            else:
                raise TypeError("Input text should be a string, a list of strings or a nested list of strings")

            if return_tensors == "np":
                input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
                attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)

            elif return_tensors == "pt" and is_torch_available():
                import torch

                input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
                attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)

            elif return_tensors == "tf" and is_tf_available():
                import tensorflow as tf

                input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
                attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)

            else:
                raise ValueError("Target return tensor type could not be returned")

            encoding = BatchEncoding()
            encoding["input_ids"] = input_ids
            encoding["attention_mask"] = attention_mask

        if images is not None:
            image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)

        if text is not None and images is not None:
            encoding["pixel_values"] = image_features.pixel_values
            return encoding
        elif text is not None:
            return encoding
        else:
            return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
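
Since this PR also adds batching, here is a sketch of the batched call shape that the tests below rely on: a list of images paired with one nested list of queries per image, returning one list of detections per image (the URLs, threshold, and top_k values are illustrative):

    from transformers import pipeline

    detector = pipeline("zero-shot-object-detection")

    # One nested list of text queries per image; the output is one list of detections per image
    batched_predictions = detector(
        [
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            "http://images.cocodataset.org/val2017/000000039769.jpg",
        ],
        text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
        threshold=0.1,
        top_k=3,  # keep at most the 3 highest-scoring detections per image
    )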
@@ -418,6 +418,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None


MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None


MODEL_MAPPING = None


@@ -606,6 +609,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class AutoModelForZeroShotObjectDetection(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class AutoModelWithLMHead(metaclass=DummyObject):
    _backends = ["torch"]
New file: tests/pipelines/test_pipelines_zero_shot_object_detection.py (263 lines)
@@ -0,0 +1,263 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, is_vision_available, pipeline
from transformers.testing_utils import (
    is_pipeline_test,
    nested_simplify,
    require_tf,
    require_torch,
    require_vision,
    slow,
)

from .test_pipelines_common import ANY, PipelineTestCaseMeta


if is_vision_available():
    from PIL import Image
else:

    class Image:
        @staticmethod
        def open(*args, **kwargs):
            pass


@require_vision
@require_torch
@is_pipeline_test
class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):

    model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING

    def get_test_pipeline(self, model, tokenizer, feature_extractor):
        object_detector = pipeline(
            "zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection"
        )

        examples = [
            {
                "images": "./tests/fixtures/tests_samples/COCO/000000039769.png",
                "text_queries": ["cat", "remote", "couch"],
            }
        ]
        return object_detector, examples

    def run_pipeline_test(self, object_detector, examples):
        batch_outputs = object_detector(examples, threshold=0.0)

        self.assertEqual(len(examples), len(batch_outputs))
        for outputs in batch_outputs:
            for output_per_image in outputs:
                self.assertGreater(len(output_per_image), 0)
                for detected_object in output_per_image:
                    self.assertEqual(
                        detected_object,
                        {
                            "score": ANY(float),
                            "label": ANY(str),
                            "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
                        },
                    )

    @require_tf
    @unittest.skip("Zero Shot Object Detection not implemented in TF")
    def test_small_model_tf(self):
        pass

    @require_torch
    def test_small_model_pt(self):
        object_detector = pipeline(
            "zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection"
        )

        outputs = object_detector(
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
            text_queries=["cat", "remote", "couch"],
            threshold=0.64,
        )

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
                ]
            ],
        )

        outputs = object_detector(
            ["./tests/fixtures/tests_samples/COCO/000000039769.png"],
            text_queries=["cat", "remote", "couch"],
            threshold=0.64,
        )

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
                ]
            ],
        )

        outputs = object_detector(
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
            text_queries=[["cat", "remote", "couch"]],
            threshold=0.64,
        )

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
                ]
            ],
        )

        outputs = object_detector(
            [
                "./tests/fixtures/tests_samples/COCO/000000039769.png",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ],
            text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
            threshold=0.64,
        )

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
                ],
                [
                    {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}},
                    {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}},
                    {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}},
                    {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}},
                ],
            ],
        )

    @require_torch
    @slow
    def test_large_model_pt(self):
        object_detector = pipeline("zero-shot-object-detection")

        outputs = object_detector(
            "http://images.cocodataset.org/val2017/000000039769.jpg", text_queries=["cat", "remote", "couch"]
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
                    {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
                    {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
                    {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
                    {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
                ]
            ],
        )

        outputs = object_detector(
            [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ],
            text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]],
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
                    {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
                    {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
                    {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
                    {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
                ],
                [
                    {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
                    {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
                    {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
                    {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}},
                    {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}},
                ],
            ],
        )

    @require_tf
    @unittest.skip("Zero Shot Object Detection not implemented in TF")
    def test_large_model_tf(self):
        pass

    @require_torch
    @slow
    def test_threshold(self):
        threshold = 0.2
        object_detector = pipeline("zero-shot-object-detection")

        outputs = object_detector(
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            text_queries=["cat", "remote", "couch"],
            threshold=threshold,
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
                    {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
                    {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}},
                ]
            ],
        )

    @require_torch
    @slow
    def test_top_k(self):
        top_k = 2
        object_detector = pipeline("zero-shot-object-detection")

        outputs = object_detector(
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            text_queries=["cat", "remote", "couch"],
            top_k=top_k,
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}},
                    {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}},
                ]
            ],
        )
@@ -58,6 +58,11 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
    ("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
    ("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
    ("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
    (
        "zero-shot-object-detection",
        "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES",
        "AutoModelForZeroShotObjectDetection",
    ),
    ("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
    ("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
    ("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),