transformers/tests/pipelines/test_pipelines_object_detection.py
NielsRogge 59407bbeb3
Add Deformable DETR (#17281)
* First draft

* More improvements

* Improve model, add custom CUDA code

* Import torch before

* Add script that imports custom layer

* Add everything in new ops directory

* Import custom layer in modeling file

* Fix ARCHIVE_MAP typo

* Creating the custom kernel on the fly.

* Import custom layer in modeling file

* More improvements

* Fix CUDA loading

* More improvements

* Improve conversion script

* Improve conversion script

* Make it work until encoder_outputs

* Make forward pass work

* More improvements

* Make logits match original implementation

* Make implementation also support single_scale model

* Add support for single_scale and dilation checkpoint

* Add support for with_box_refine model

* Support also two stage model

* Improve tests

* Fix more tests

* Make more tests pass

* Upload all models to the hub

* Clean up some code

* Improve decoder outputs

* Rename intermediate hidden states and reference points

* Improve model outputs

* Move tests to dedicated folder

* Improve model outputs

* Fix retain_grad test

* Improve docs

* Clean up and make test_initialization pass

* Improve variable names

* Add copied from statements

* Improve docs

* Fix style

* Improve docs

* Improve docs, move tests to model folder

* Fix rebase

* Remove DetrForSegmentation from auto mapping

* Apply suggestions from code review

* Improve variable names and docstrings

* Apply some more suggestions from code review

* Apply suggestion from code review

* better docs and variable names

* hint to num_queries and two_stage confusion

* remove asserts and code refactor

* add exception if two_stage is True and with_box_refine is False

* use f-strings

* Improve docs and variable names

* Fix code quality

* Fix rebase

* Add require_torch_gpu decorator

* Add pip install ninja to CI jobs

* Apply suggestion of @sgugger

* Remove DeformableDetrForObjectDetection from auto mapping

* Remove DeformableDetrModel from auto mapping

* Add model to toctree

* Add model back to mappings, skip model in pipeline tests

* Apply @sgugger's suggestion

* Fix imports in the init

* Fix copies

* Add CPU implementation

* Comment out GPU function

* Undo previous change

* Apply more suggestions

* Remove require_torch_gpu annotator

* Fix quality

* Add logger.info

* Fix logger

* Fix variable names

* Fix initialization

* Add missing initialization

* Update checkpoint name

* Add model to doc tests

* Add CPU/GPU equivalence test

* Add Deformable DETR to pipeline tests

* Skip model for object detection pipeline

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
2022-09-14 11:45:21 +02:00

261 lines
11 KiB
Python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import (
MODEL_FOR_OBJECT_DETECTION_MAPPING,
AutoFeatureExtractor,
AutoModelForObjectDetection,
ObjectDetectionPipeline,
is_vision_available,
pipeline,
)
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_timm,
require_torch,
require_vision,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
if is_vision_available():
    from PIL import Image
else:

    class Image:
        """Stand-in for ``PIL.Image`` so this module can be imported when vision support is unavailable.

        In this file ``Image.open`` is only called from ``ObjectDetectionPipelineTests.run_pipeline_test``,
        and that class is decorated with ``@require_vision``. The stub is therefore referenced at import
        time but not expected to be exercised.
        """

        @staticmethod
        def open(*args, **kwargs):
            # No-op replacement for ``PIL.Image.open``; accepts anything and returns nothing.
            return None
# Both refer to COCO val2017 image 000000039769: a local PNG fixture and the remote JPEG.
_COCO_IMAGE_PATH = "./tests/fixtures/tests_samples/COCO/000000039769.png"
_COCO_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Reference ``facebook/detr-resnet-50`` detections for ``_COCO_IMAGE_URL`` at the pipeline's default threshold,
# rounded with ``nested_simplify(..., decimals=4)``. Shared by the slow DETR tests so the pinned values live in
# exactly one place.
_DETR_RESNET_50_EXPECTED = [
    {"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
    {"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
    {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
    {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
    {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
]


@require_vision
@require_timm
@require_torch
@is_pipeline_test
class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    """Tests for ``ObjectDetectionPipeline``.

    ``get_test_pipeline`` / ``run_pipeline_test`` provide generic structural checks (presumably driven by
    ``PipelineTestCaseMeta`` over ``model_mapping``). The ``test_*`` methods pin concrete outputs for DETR
    checkpoints.
    """

    model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING

    def get_test_pipeline(self, model, tokenizer, feature_extractor):
        """Build the pipeline under test for ``model`` plus example inputs.

        Args:
            model: object-detection model instance to wrap.
            tokenizer: unused here (NOTE(review): presumably part of the common harness signature).
            feature_extractor: feature extractor passed straight to the pipeline.

        Returns:
            ``(object_detector, [path_to_local_coco_image])``.
        """
        if model.__class__.__name__ == "DeformableDetrForObjectDetection":
            self.skipTest("Deformable DETR requires a custom CUDA kernel.")

        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
        return object_detector, [_COCO_IMAGE_PATH]

    def _assert_is_detection(self, detected_object):
        """Assert that ``detected_object`` matches the pipeline's per-detection output schema."""
        self.assertEqual(
            detected_object,
            {
                "score": ANY(float),
                "label": ANY(str),
                "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
            },
        )

    def run_pipeline_test(self, object_detector, examples):
        """Check output structure for a single image and for a batch of mixed input kinds.

        ``threshold=0.0`` is passed so no candidate is filtered out by score, which the non-empty
        assertions rely on. ``examples`` (the list returned by ``get_test_pipeline``) is not used; the
        inputs are built here.
        """
        outputs = object_detector(_COCO_IMAGE_PATH, threshold=0.0)
        self.assertGreater(len(outputs), 0)
        for detected_object in outputs:
            self._assert_is_detection(detected_object)

        # Imported lazily; presumably ``datasets`` is an optional test dependency.
        import datasets

        dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
        # Mix of input kinds: an opened PIL image, a URL, and dataset ``file`` entries in various image modes.
        batch = [
            Image.open(_COCO_IMAGE_PATH),
            _COCO_IMAGE_URL,
            # RGBA
            dataset[0]["file"],
            # LA
            dataset[1]["file"],
            # L
            dataset[2]["file"],
        ]
        batch_outputs = object_detector(batch, threshold=0.0)
        self.assertEqual(len(batch), len(batch_outputs))
        for image_outputs in batch_outputs:
            self.assertGreater(len(image_outputs), 0)
            for detected_object in image_outputs:
                self._assert_is_detection(detected_object)

    def _assert_detr_resnet_50_outputs(self, object_detector):
        """Compare ``object_detector`` to ``_DETR_RESNET_50_EXPECTED`` for one image and a batch of two copies."""
        outputs = object_detector(_COCO_IMAGE_URL)
        self.assertEqual(nested_simplify(outputs, decimals=4), _DETR_RESNET_50_EXPECTED)

        outputs = object_detector([_COCO_IMAGE_URL, _COCO_IMAGE_URL])
        self.assertEqual(nested_simplify(outputs, decimals=4), [_DETR_RESNET_50_EXPECTED] * 2)

    @require_tf
    @unittest.skip("Object detection not implemented in TF")
    def test_small_model_tf(self):
        pass

    @require_torch
    def test_small_model_pt(self):
        """Pin the outputs of a tiny DETR checkpoint, single and batched, with score filtering disabled."""
        model_id = "hf-internal-testing/tiny-detr-mobilenetsv3"

        model = AutoModelForObjectDetection.from_pretrained(model_id)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)

        # Two identical detections are expected for this image.
        detection = {
            "score": 0.3376,
            "label": "LABEL_0",
            "box": {"xmin": 159, "ymin": 120, "xmax": 480, "ymax": 359},
        }
        expected = [detection, detection]

        outputs = object_detector(_COCO_IMAGE_URL, threshold=0.0)
        self.assertEqual(nested_simplify(outputs, decimals=4), expected)

        outputs = object_detector([_COCO_IMAGE_URL, _COCO_IMAGE_URL], threshold=0.0)
        self.assertEqual(nested_simplify(outputs, decimals=4), [expected, expected])

    @require_torch
    @slow
    def test_large_model_pt(self):
        """Pin ``facebook/detr-resnet-50`` outputs when the pipeline is assembled from explicit components."""
        model_id = "facebook/detr-resnet-50"

        model = AutoModelForObjectDetection.from_pretrained(model_id)
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)

        self._assert_detr_resnet_50_outputs(object_detector)

    @require_torch
    @slow
    def test_integration_torch_object_detection(self):
        """Pin ``facebook/detr-resnet-50`` outputs when the pipeline is built via the ``pipeline()`` factory."""
        model_id = "facebook/detr-resnet-50"
        object_detector = pipeline("object-detection", model=model_id)

        self._assert_detr_resnet_50_outputs(object_detector)

    @require_torch
    @slow
    def test_threshold(self):
        """Detections scoring below ``threshold`` are dropped; here only the two cat boxes remain."""
        threshold = 0.9985
        model_id = "facebook/detr-resnet-50"
        object_detector = pipeline("object-detection", model=model_id)

        outputs = object_detector(_COCO_IMAGE_URL, threshold=threshold)
        self.assertEqual(
            nested_simplify(outputs, decimals=4),
            [
                {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
                {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
            ],
        )