Object detection pipeline (#12886)
* Implement object-detection pipeline
* Define threshold const
* Add `threshold` argument
* Refactor
* Uncomment test inputs
* `rm
* Fix typo
* Fix typo
* Chore better doc
* Rm unnecessary lines
* Chore better naming
* Update src/transformers/pipelines/object_detection.py
* Update src/transformers/pipelines/object_detection.py
* Fix typo
* Add `detr-tiny` for tests
* Add `ObjectDetectionPipeline` to `trnsfrmrs/init`
* Implement new bbox format
* Update detr post_process
* Update `load_img` method obj det pipeline
* make style
* Implement new testing format for obj det pipeln
* Add guard pytorch specific code in pipeline
* Add doc
* Make pipeline_obj_tet tests deterministic
* Revert some changes to `post_process` COCO api
* Chore
* Update src/transformers/pipelines/object_detection.py
* Update src/transformers/pipelines/object_detection.py
* Update src/transformers/pipelines/object_detection.py
* Update src/transformers/pipelines/object_detection.py
* Update src/transformers/pipelines/object_detection.py
* Update src/transformers/pipelines/object_detection.py
* Rm timm requirement
* make fixup
* Add timm requirement to test
* Make fixup
* Guard torch.Tensor
* Chore
* Delete unnecessary comment

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
parent: 707105290b
commit: 2a15e8ccfb
@@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about:

     - :class:`~transformers.FeatureExtractionPipeline`
     - :class:`~transformers.FillMaskPipeline`
     - :class:`~transformers.ImageClassificationPipeline`
+    - :class:`~transformers.ObjectDetectionPipeline`
     - :class:`~transformers.QuestionAnsweringPipeline`
     - :class:`~transformers.SummarizationPipeline`
     - :class:`~transformers.TableQuestionAnsweringPipeline`
@@ -102,6 +103,13 @@ NerPipeline

 See :class:`~transformers.TokenClassificationPipeline` for all details.

+ObjectDetectionPipeline
+=======================================================================================================================
+
+.. autoclass:: transformers.ObjectDetectionPipeline
+    :special-members: __call__
+    :members:
+
 QuestionAnsweringPipeline
 =======================================================================================================================
@@ -142,6 +142,13 @@ AutoModelForAudioClassification

     :members:

+AutoModelForObjectDetection
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.AutoModelForObjectDetection
+    :members:
+
+
 TFAutoModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -294,6 +294,7 @@ _import_structure = {
         "ImageClassificationPipeline",
         "JsonPipelineDataFormat",
         "NerPipeline",
+        "ObjectDetectionPipeline",
         "PipedPipelineDataFormat",
         "Pipeline",
         "PipelineDataFormat",
@@ -558,6 +559,7 @@ if is_torch_available():
         "AutoModelForMaskedLM",
         "AutoModelForMultipleChoice",
         "AutoModelForNextSentencePrediction",
+        "AutoModelForObjectDetection",
         "AutoModelForPreTraining",
         "AutoModelForQuestionAnswering",
         "AutoModelForSeq2SeqLM",
@@ -2074,6 +2076,7 @@ if TYPE_CHECKING:
         ImageClassificationPipeline,
         JsonPipelineDataFormat,
         NerPipeline,
+        ObjectDetectionPipeline,
         PipedPipelineDataFormat,
         Pipeline,
         PipelineDataFormat,
@@ -2295,6 +2298,7 @@ if TYPE_CHECKING:
         AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
         AutoModelForNextSentencePrediction,
+        AutoModelForObjectDetection,
         AutoModelForPreTraining,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
@@ -52,6 +52,7 @@ if is_torch_available():
         "AutoModelForMaskedLM",
         "AutoModelForMultipleChoice",
         "AutoModelForNextSentencePrediction",
+        "AutoModelForObjectDetection",
         "AutoModelForPreTraining",
         "AutoModelForQuestionAnswering",
         "AutoModelForSeq2SeqLM",
@@ -143,6 +144,7 @@ if TYPE_CHECKING:
         AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
         AutoModelForNextSentencePrediction,
+        AutoModelForObjectDetection,
         AutoModelForPreTraining,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
@@ -34,6 +34,7 @@ from .configuration_auto import (
 FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
     [
         ("beit", "BeitFeatureExtractor"),
+        ("detr", "DetrFeatureExtractor"),
         ("deit", "DeiTFeatureExtractor"),
         ("hubert", "Wav2Vec2FeatureExtractor"),
         ("speech_to_text", "Speech2TextFeatureExtractor"),
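A minimal sketch of what this mapping enables, assuming the facebook/detr-resnet-50 checkpoint used elsewhere in this commit:

from transformers import AutoFeatureExtractor

# With the new ("detr", "DetrFeatureExtractor") entry, the auto class resolves DETR checkpoints
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
print(type(feature_extractor).__name__)  # DetrFeatureExtractor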
@@ -588,6 +588,13 @@ class AutoModelForImageClassification(_BaseAutoModelClass):

 AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")


+class AutoModelForObjectDetection(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
+
+
+AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
+
+
 class AutoModelForAudioClassification(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
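A rough sketch of how the new auto class pairs with the feature extractor's post_process, mirroring what the pipeline added below does internally (assumes torch, timm, and Pillow are installed; the checkpoint and image URL are the ones used in the tests):

import requests
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, AutoModelForObjectDetection

model_id = "facebook/detr-resnet-50"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModelForObjectDetection.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

inputs = feature_extractor(images=[image], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Rescale the predicted boxes to the original image size and keep confident detections
target_sizes = torch.IntTensor([[image.height, image.width]])
results = feature_extractor.post_process(outputs, target_sizes)[0]
keep = results["scores"] > 0.9
for score, label in zip(results["scores"][keep].tolist(), results["labels"][keep].tolist()):
    print(model.config.id2label[label], round(score, 4))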
@@ -44,6 +44,7 @@ from .conversational import Conversation, ConversationalPipeline
 from .feature_extraction import FeatureExtractionPipeline
 from .fill_mask import FillMaskPipeline
 from .image_classification import ImageClassificationPipeline
+from .object_detection import ObjectDetectionPipeline
 from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
 from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
 from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
@@ -91,6 +92,7 @@ if is_torch_available():
         AutoModelForCausalLM,
         AutoModelForImageClassification,
         AutoModelForMaskedLM,
+        AutoModelForObjectDetection,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
@@ -229,6 +231,12 @@ SUPPORTED_TASKS = {
         "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
     },
+    "object-detection": {
+        "impl": ObjectDetectionPipeline,
+        "tf": (),
+        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
+        "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+    },
 }
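A minimal usage sketch of the new task identifier; the default checkpoint comes from the registration above, and the example output (rounded) is taken from the slow tests in this commit:

from transformers import pipeline

# "object-detection" now resolves to ObjectDetectionPipeline, defaulting to facebook/detr-resnet-50
object_detector = pipeline("object-detection")
predictions = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.9)
# e.g. [{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, ...]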
src/transformers/pipelines/object_detection.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import requests

from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

if is_vision_available():
    from PIL import Image

if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING

logger = logging.get_logger(__name__)


Prediction = Dict[str, Any]
Predictions = List[Prediction]
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ObjectDetectionPipeline(Pipeline):
    """
    Object detection pipeline using any :obj:`AutoModelForObjectDetection`. This pipeline predicts bounding boxes of
    objects and their classes.

    This object detection pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"object-detection"`.

    See the list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=object-detection>`__.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        feature_extractor: PreTrainedFeatureExtractor,
        framework: Optional[str] = None,
        **kwargs
    ):
        super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")

        self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)

        self.feature_extractor = feature_extractor

    @staticmethod
    def load_image(image: Union[str, "Image.Image"]):
        if isinstance(image, str):
            if image.startswith("http://") or image.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                image = Image.open(requests.get(image, stream=True).raw)
            elif os.path.isfile(image):
                image = Image.open(image)
            else:
                raise ValueError(
                    f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
                )
        elif isinstance(image, Image.Image):
            pass
        else:
            raise ValueError(
                "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
            )
        image = image.convert("RGB")
        return image
    def __call__(
        self,
        images: Union[str, List[str], "Image", List["Image"]],
        threshold: Optional[float] = 0.9,
    ) -> Union[Predictions, List[Prediction]]:
        """
        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.

        Args:
            images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
            threshold (:obj:`float`, `optional`, defaults to 0.9):
                The minimum confidence score a detection needs in order to be returned.

        Return:
            A list of dictionaries or a list of lists of dictionaries containing the result. If the input is a single
            image, a list of dictionaries is returned; if the input is a list of several images, a list of lists of
            dictionaries is returned, one list per image.

            The dictionaries contain the following keys:

            - **label** (:obj:`str`) -- The class label identified by the model.
            - **score** (:obj:`float`) -- The score attributed by the model to that label.
            - **box** (:obj:`Dict[str, int]`) -- The bounding box of the detected object in the image's original size.
        """
        is_batched = isinstance(images, list)

        if not is_batched:
            images = [images]

        images = [self.load_image(image) for image in images]

        with torch.no_grad():
            inputs = self.feature_extractor(images=images, return_tensors="pt")
            outputs = self.model(**inputs)

            if self.framework == "pt":
                target_sizes = torch.IntTensor([[im.height, im.width] for im in images])
            else:
                raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")

            raw_annotations = self.feature_extractor.post_process(outputs, target_sizes)
            annotations = []
            for annotation in raw_annotations:
                keep = annotation["scores"] > threshold
                scores = annotation["scores"][keep]
                labels = annotation["labels"][keep]
                boxes = annotation["boxes"][keep]

                annotation["scores"] = scores.tolist()
                annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
                annotation["boxes"] = [self._get_bounding_box(box) for box in boxes]

                # {"scores": [...], ...} --> [{"score": x, ...}, ...]
                keys = ["score", "label", "box"]
                annotation = [
                    dict(zip(keys, vals))
                    for vals in zip(annotation["scores"], annotation["labels"], annotation["boxes"])
                ]

                annotations.append(annotation)

        if not is_batched:
            return annotations[0]

        return annotations
    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
        """
        Turns a tensor [xmin, ymin, xmax, ymax] into a dict {"xmin": xmin, ...}.

        Args:
            box (torch.Tensor): Tensor containing the coordinates in corners format.

        Returns:
            bbox (Dict[str, int]): Dict containing the coordinates in corners format.
        """
        if self.framework != "pt":
            raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")
        xmin, ymin, xmax, ymax = box.int().tolist()
        bbox = {
            "xmin": xmin,
            "ymin": ymin,
            "xmax": xmax,
            "ymax": ymax,
        }
        return bbox
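A sketch of constructing the pipeline directly from a model and feature extractor, as the tests added below do (requires torch, timm, and Pillow; a batch of images returns one list of predictions per image):

from transformers import AutoFeatureExtractor, AutoModelForObjectDetection, ObjectDetectionPipeline

model_id = "facebook/detr-resnet-50"
object_detector = ObjectDetectionPipeline(
    model=AutoModelForObjectDetection.from_pretrained(model_id),
    feature_extractor=AutoFeatureExtractor.from_pretrained(model_id),
)

# A single image yields a list of {"score", "label", "box"} dicts; a list of images yields a list of such lists
for prediction in object_detector("http://images.cocodataset.org/val2017/000000039769.jpg"):
    print(prediction["label"], prediction["score"], prediction["box"])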
@@ -415,6 +415,15 @@ class AutoModelForNextSentencePrediction:
         requires_backends(cls, ["torch"])


+class AutoModelForObjectDetection:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AutoModelForPreTraining:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
tests/test_pipelines_object_detection.py (new file, 253 lines)
@@ -0,0 +1,253 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import (
    MODEL_FOR_OBJECT_DETECTION_MAPPING,
    AutoFeatureExtractor,
    AutoModelForObjectDetection,
    ObjectDetectionPipeline,
    is_vision_available,
    pipeline,
)
from transformers.testing_utils import (
    is_pipeline_test,
    nested_simplify,
    require_datasets,
    require_tf,
    require_timm,
    require_torch,
    require_vision,
    slow,
)

from .test_pipelines_common import ANY, PipelineTestCaseMeta


if is_vision_available():
    from PIL import Image
else:

    class Image:
        @staticmethod
        def open(*args, **kwargs):
            pass
@require_vision
@require_timm
@require_torch
@is_pipeline_test
class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING

    @require_datasets
    def run_pipeline_test(self, model, tokenizer, feature_extractor):
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
        outputs = object_detector("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)

        self.assertGreater(len(outputs), 0)
        for detected_object in outputs:
            self.assertEqual(
                detected_object,
                {
                    "score": ANY(float),
                    "label": ANY(str),
                    "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
                },
            )

        import datasets

        dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test")

        batch = [
            Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            # RGBA
            dataset[0]["file"],
            # LA
            dataset[1]["file"],
            # L
            dataset[2]["file"],
        ]
        batch_outputs = object_detector(batch, threshold=0.0)

        self.assertEqual(len(batch), len(batch_outputs))
        for outputs in batch_outputs:
            self.assertGreater(len(outputs), 0)
            for detected_object in outputs:
                self.assertEqual(
                    detected_object,
                    {
                        "score": ANY(float),
                        "label": ANY(str),
                        "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
                    },
                )
    @require_tf
    @unittest.skip("Object detection not implemented in TF")
    def test_small_model_tf(self):
        pass

    @require_torch
    def test_small_model_pt(self):
        model_id = "mishig/tiny-detr-mobilenetsv3"

        model = AutoModelForObjectDetection.from_pretrained(model_id)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)

        outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.0)

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
                {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
            ],
        )

        outputs = object_detector(
            [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ],
            threshold=0.0,
        )

        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
                ],
                [
                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
                ],
            ],
        )
    @require_torch
    @slow
    def test_large_model_pt(self):
        model_id = "facebook/detr-resnet-50"

        model = AutoModelForObjectDetection.from_pretrained(model_id)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)

        outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg")
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
                {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
                {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
                {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
                {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
            ],
        )

        outputs = object_detector(
            [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ]
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
                    {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
                    {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
                    {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
                    {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
                ],
                [
                    {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
                    {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
                    {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
                    {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
                    {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
                ],
            ],
        )
    @require_torch
    @slow
    def test_integration_torch_object_detection(self):
        model_id = "facebook/detr-resnet-50"

        object_detector = pipeline("object-detection", model=model_id)

        outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg")
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
                {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
                {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
                {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
                {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
            ],
        )

        outputs = object_detector(
            [
                "http://images.cocodataset.org/val2017/000000039769.jpg",
                "http://images.cocodataset.org/val2017/000000039769.jpg",
            ]
        )
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                [
                    {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
                    {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
                    {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
                    {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
                    {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
                ],
                [
                    {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}},
                    {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}},
                    {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}},
                    {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
                    {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
                ],
            ],
        )
    @require_torch
    @slow
    def test_threshold(self):
        threshold = 0.9985
        model_id = "facebook/detr-resnet-50"

        object_detector = pipeline("object-detection", model=model_id)

        outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=threshold)
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}},
                {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}},
            ],
        )