diff --git a/docs/source/main_classes/pipelines.rst b/docs/source/main_classes/pipelines.rst
index 857612afb1c..c10e2f08978 100644
--- a/docs/source/main_classes/pipelines.rst
+++ b/docs/source/main_classes/pipelines.rst
@@ -29,6 +29,7 @@ There are two categories of pipeline abstractions to be aware about:
     - :class:`~transformers.FeatureExtractionPipeline`
     - :class:`~transformers.FillMaskPipeline`
     - :class:`~transformers.ImageClassificationPipeline`
+    - :class:`~transformers.ObjectDetectionPipeline`
     - :class:`~transformers.QuestionAnsweringPipeline`
     - :class:`~transformers.SummarizationPipeline`
     - :class:`~transformers.TableQuestionAnsweringPipeline`
@@ -102,6 +103,13 @@ NerPipeline
 
 See :class:`~transformers.TokenClassificationPipeline` for all details.
 
+ObjectDetectionPipeline
+=======================================================================================================================
+
+.. autoclass:: transformers.ObjectDetectionPipeline
+    :special-members: __call__
+    :members:
+
 QuestionAnsweringPipeline
 =======================================================================================================================
 
diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst
index 2a013ed9b97..928d5184614 100644
--- a/docs/source/model_doc/auto.rst
+++ b/docs/source/model_doc/auto.rst
@@ -142,6 +142,13 @@ AutoModelForAudioClassification
     :members:
 
 
+AutoModelForObjectDetection
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.AutoModelForObjectDetection
+    :members:
+
+
 TFAutoModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index d2bbfb7a3cc..00eb78b4289 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -294,6 +294,7 @@ _import_structure = {
         "ImageClassificationPipeline",
         "JsonPipelineDataFormat",
         "NerPipeline",
+        "ObjectDetectionPipeline",
         "PipedPipelineDataFormat",
         "Pipeline",
         "PipelineDataFormat",
@@ -558,6 +559,7 @@ if is_torch_available():
             "AutoModelForMaskedLM",
             "AutoModelForMultipleChoice",
             "AutoModelForNextSentencePrediction",
+            "AutoModelForObjectDetection",
             "AutoModelForPreTraining",
             "AutoModelForQuestionAnswering",
             "AutoModelForSeq2SeqLM",
@@ -2074,6 +2076,7 @@ if TYPE_CHECKING:
         ImageClassificationPipeline,
         JsonPipelineDataFormat,
         NerPipeline,
+        ObjectDetectionPipeline,
         PipedPipelineDataFormat,
         Pipeline,
         PipelineDataFormat,
@@ -2295,6 +2298,7 @@ if TYPE_CHECKING:
         AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
         AutoModelForNextSentencePrediction,
+        AutoModelForObjectDetection,
         AutoModelForPreTraining,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py
index b1db8f95e8a..6b99d6a91ba 100644
--- a/src/transformers/models/auto/__init__.py
+++ b/src/transformers/models/auto/__init__.py
@@ -52,6 +52,7 @@ if is_torch_available():
         "AutoModelForMaskedLM",
         "AutoModelForMultipleChoice",
         "AutoModelForNextSentencePrediction",
+        "AutoModelForObjectDetection",
         "AutoModelForPreTraining",
         "AutoModelForQuestionAnswering",
         "AutoModelForSeq2SeqLM",
@@ -143,6 +144,7 @@ if TYPE_CHECKING:
         AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
         AutoModelForNextSentencePrediction,
+        AutoModelForObjectDetection,
         AutoModelForPreTraining,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 169bd5c7ce2..7fcd0dd5564 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -34,6 +34,7 @@ from .configuration_auto import (
 FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
     [
         ("beit", "BeitFeatureExtractor"),
+        ("detr", "DetrFeatureExtractor"),
         ("deit", "DeiTFeatureExtractor"),
         ("hubert", "Wav2Vec2FeatureExtractor"),
         ("speech_to_text", "Speech2TextFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index a3569d90ae6..2f8f099e3d8 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -588,6 +588,13 @@ class AutoModelForImageClassification(_BaseAutoModelClass):
 AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
 
 
+class AutoModelForObjectDetection(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
+
+
+AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
+
+
 class AutoModelForAudioClassification(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
 
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 35a870609cb..3bdf4e4d6f7 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -44,6 +44,7 @@ from .conversational import Conversation, ConversationalPipeline
 from .feature_extraction import FeatureExtractionPipeline
 from .fill_mask import FillMaskPipeline
 from .image_classification import ImageClassificationPipeline
+from .object_detection import ObjectDetectionPipeline
 from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
 from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
 from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
@@ -91,6 +92,7 @@ if is_torch_available():
         AutoModelForCausalLM,
         AutoModelForImageClassification,
         AutoModelForMaskedLM,
+        AutoModelForObjectDetection,
         AutoModelForQuestionAnswering,
         AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
@@ -229,6 +231,12 @@ SUPPORTED_TASKS = {
         "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
     },
+    "object-detection": {
+        "impl": ObjectDetectionPipeline,
+        "tf": (),
+        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
+        "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+    },
 }
diff --git a/src/transformers/pipelines/object_detection.py b/src/transformers/pipelines/object_detection.py
new file mode 100644
index 00000000000..cfed91f776d
--- /dev/null
+++ b/src/transformers/pipelines/object_detection.py
@@ -0,0 +1,176 @@
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import requests
+
+from ..feature_extraction_utils import PreTrainedFeatureExtractor
+from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
+from ..utils import logging
+from .base import PIPELINE_INIT_ARGS, Pipeline
+
+
+if TYPE_CHECKING:
+    from ..modeling_utils import PreTrainedModel
+
+if is_vision_available():
+    from PIL import Image
+
+if is_torch_available():
+    import torch
+
+    from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING
+
+logger = logging.get_logger(__name__)
+
+
+Prediction = Dict[str, Any]
+Predictions = List[Prediction]
+
+
+@add_end_docstrings(PIPELINE_INIT_ARGS)
+class ObjectDetectionPipeline(Pipeline):
+    """
+    Object detection pipeline using any :obj:`AutoModelForObjectDetection`. This pipeline predicts bounding boxes of
+    objects and their classes.
+
+    This object detection pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
+    identifier: :obj:`"object-detection"`.
+
+    See the list of available models on `huggingface.co/models
+    <https://huggingface.co/models?filter=object-detection>`__.
+    """
+
+    def __init__(
+        self,
+        model: "PreTrainedModel",
+        feature_extractor: PreTrainedFeatureExtractor,
+        framework: Optional[str] = None,
+        **kwargs
+    ):
+        super().__init__(model, feature_extractor=feature_extractor, framework=framework, **kwargs)
+
+        if self.framework == "tf":
+            raise ValueError(f"The {self.__class__} is only available in PyTorch.")
+
+        requires_backends(self, "vision")
+
+        self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
+
+        self.feature_extractor = feature_extractor
+
+    @staticmethod
+    def load_image(image: Union[str, "Image.Image"]):
+        if isinstance(image, str):
+            if image.startswith("http://") or image.startswith("https://"):
+                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
+                # like http_huggingface_co.png
+                image = Image.open(requests.get(image, stream=True).raw)
+            elif os.path.isfile(image):
+                image = Image.open(image)
+            else:
+                raise ValueError(
+                    f"Incorrect path or URL: URLs must start with `http://` or `https://`, and {image} is not a valid path"
+                )
+        elif isinstance(image, Image.Image):
+            pass
+        else:
+            raise ValueError(
+                "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
+            )
+        image = image.convert("RGB")
+        return image
+
+    def __call__(
+        self,
+        images: Union[str, List[str], "Image.Image", List["Image.Image"]],
+        threshold: Optional[float] = 0.9,
+    ) -> Union[Predictions, List[Prediction]]:
+        """
+        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
+
+        Args:
+            images (:obj:`str`, :obj:`List[str]`, :obj:`PIL.Image` or :obj:`List[PIL.Image]`):
+                The pipeline handles three types of images:
+
+                - A string containing an HTTP(S) link pointing to an image
+                - A string containing a local path to an image
+                - An image loaded in PIL directly
+
+                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
+                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
+            threshold (:obj:`float`, `optional`, defaults to 0.9):
+                The minimum score a detection must reach to be included in the output.
+
+        Return:
+            A list of dictionaries or a list of lists of dictionaries containing the result. If the input is a single
+            image, returns a list of dictionaries; if the input is a list of several images, returns a list of lists
+            of dictionaries, one list per image.
+
+            The dictionaries contain the following keys:
+
+            - **label** (:obj:`str`) -- The class label identified by the model.
+            - **score** (:obj:`float`) -- The score attributed by the model to that label.
+            - **box** (:obj:`Dict[str, int]`) -- The bounding box of the detected object in the image's original size.
+        """
+        is_batched = isinstance(images, list)
+
+        if not is_batched:
+            images = [images]
+
+        images = [self.load_image(image) for image in images]
+
+        with torch.no_grad():
+            inputs = self.feature_extractor(images=images, return_tensors="pt")
+            outputs = self.model(**inputs)
+
+            if self.framework == "pt":
+                target_sizes = torch.IntTensor([[im.height, im.width] for im in images])
+            else:
+                raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")
+
+            raw_annotations = self.feature_extractor.post_process(outputs, target_sizes)
+            annotations = []
+            for annotation in raw_annotations:
+                keep = annotation["scores"] > threshold
+                scores = annotation["scores"][keep]
+                labels = annotation["labels"][keep]
+                boxes = annotation["boxes"][keep]
+
+                annotation["scores"] = scores.tolist()
+                annotation["labels"] = [self.model.config.id2label[label.item()] for label in labels]
+                annotation["boxes"] = [self._get_bounding_box(box) for box in boxes]
+
+                # {"scores": [...], ...} --> [{"score":x, ...}, ...]
+                keys = ["score", "label", "box"]
+                annotation = [
+                    dict(zip(keys, vals))
+                    for vals in zip(annotation["scores"], annotation["labels"], annotation["boxes"])
+                ]
+
+                annotations.append(annotation)
+
+        if not is_batched:
+            return annotations[0]
+
+        return annotations
+
+    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
+        """
+        Turns list [xmin, ymin, xmax, ymax] into dict { "xmin": xmin, ... }
+
+        Args:
+            box (torch.Tensor): Tensor containing the coordinates in corners format.
+
+        Returns:
+            bbox (Dict[str, int]): Dict containing the coordinates in corners format.
+        """
+        if self.framework != "pt":
+            raise ValueError("The ObjectDetectionPipeline is only available in PyTorch.")
+        xmin, ymin, xmax, ymax = box.int().tolist()
+        bbox = {
+            "xmin": xmin,
+            "ymin": ymin,
+            "xmax": xmax,
+            "ymax": ymax,
+        }
+        return bbox
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 6d8f46f20e9..4eb80cec3f0 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -415,6 +415,15 @@ class AutoModelForNextSentencePrediction:
         requires_backends(cls, ["torch"])
 
 
+class AutoModelForObjectDetection:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AutoModelForPreTraining:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
diff --git a/tests/test_pipelines_object_detection.py b/tests/test_pipelines_object_detection.py
new file mode 100644
index 00000000000..1fc53ca3e9f
--- /dev/null
+++ b/tests/test_pipelines_object_detection.py
@@ -0,0 +1,253 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers import (
+    MODEL_FOR_OBJECT_DETECTION_MAPPING,
+    AutoFeatureExtractor,
+    AutoModelForObjectDetection,
+    ObjectDetectionPipeline,
+    is_vision_available,
+    pipeline,
+)
+from transformers.testing_utils import (
+    is_pipeline_test,
+    nested_simplify,
+    require_datasets,
+    require_tf,
+    require_timm,
+    require_torch,
+    require_vision,
+    slow,
+)
+
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
+
+
+if is_vision_available():
+    from PIL import Image
+else:
+
+    class Image:
+        @staticmethod
+        def open(*args, **kwargs):
+            pass
+
+
+@require_vision
+@require_timm
+@require_torch
+@is_pipeline_test
+class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+    model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
+
+    @require_datasets
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
+        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
+        outputs = object_detector("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
+
+        self.assertGreater(len(outputs), 0)
+        for detected_object in outputs:
+            self.assertEqual(
+                detected_object,
+                {
+                    "score": ANY(float),
+                    "label": ANY(str),
+                    "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
+                },
+            )
+
+        import datasets
+
+        dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test")
+
+        batch = [
+            Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
+            "http://images.cocodataset.org/val2017/000000039769.jpg",
+            # RGBA
+            dataset[0]["file"],
+            # LA
+            dataset[1]["file"],
+            # L
+            dataset[2]["file"],
+        ]
+        batch_outputs = object_detector(batch, threshold=0.0)
+
+        self.assertEqual(len(batch), len(batch_outputs))
+        for outputs in batch_outputs:
+            self.assertGreater(len(outputs), 0)
+            for detected_object in outputs:
+                self.assertEqual(
+                    detected_object,
+                    {
+                        "score": ANY(float),
+                        "label": ANY(str),
+                        "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
+                    },
+                )
+
+    @require_tf
+    @unittest.skip("Object detection not implemented in TF")
+    def test_small_model_tf(self):
+        pass
+
+    @require_torch
+    def test_small_model_pt(self):
+        model_id = "mishig/tiny-detr-mobilenetsv3"
+
+        model = AutoModelForObjectDetection.from_pretrained(model_id)
+        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
+        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
+
+        outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=0.0)
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
+                {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
+            ],
+        )
+
+        outputs = object_detector(
+            [
+                "http://images.cocodataset.org/val2017/000000039769.jpg",
+                "http://images.cocodataset.org/val2017/000000039769.jpg",
+            ],
+            threshold=0.0,
+        )
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                [
+                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
+                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
+                ],
+                [
+                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
+                    {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}},
], + ) + + @require_torch + @slow + def test_large_model_pt(self): + model_id = "facebook/detr-resnet-50" + + model = AutoModelForObjectDetection.from_pretrained(model_id) + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) + object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor) + + outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg") + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, + {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, + {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, + {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, + {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, + ], + ) + + outputs = object_detector( + [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + ] + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, + {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, + {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, + {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, + {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, + ], + [ + {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, + {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, + {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, + {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, + {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, + ], + ], + ) + + @require_torch + @slow + def test_integration_torch_object_detection(self): + model_id = "facebook/detr-resnet-50" + + object_detector = pipeline("object-detection", model=model_id) + + outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg") + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, + {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, + {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, + {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, + {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, + ], + ) + + outputs = object_detector( + [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + ] + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, + {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, + {"score": 0.9955, "label": "couch", "box": {"xmin": 0, 
"ymin": 1, "xmax": 1065, "ymax": 789}}, + {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, + {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, + ], + [ + {"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, + {"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, + {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, + {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, + {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, + ], + ], + ) + + @require_torch + @slow + def test_threshold(self): + threshold = 0.9985 + model_id = "facebook/detr-resnet-50" + + object_detector = pipeline("object-detection", model=model_id) + + outputs = object_detector("http://images.cocodataset.org/val2017/000000039769.jpg", threshold=threshold) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, + {"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, + ], + )