From e94384e4d879505dd0a0617fd8e199282d309a08 Mon Sep 17 00:00:00 2001
From: Ritik Nandwal <48522685+nandwalritik@users.noreply.github.com>
Date: Wed, 12 Oct 2022 18:24:20 +0530
Subject: [PATCH] Add depth estimation pipeline (#18618)

* Add initial files for depth estimation pipelines

* Add test file for depth estimation pipeline

* Update model mapping names

* Add updates for depth estimation output

* Add generic test

* Hopefully fixing the tests.

* Check if test passes

* Add make fixup and make fix-copies changes after rebase with main

* Rebase with main

* Fixing up depth pipeline.

* This is not used anymore.

* Fixing the test. `Image` is a module; `Image.Image` is the type.

* Update docs/source/en/main_classes/pipelines.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Nicolas Patry
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 docs/source/en/main_classes/pipelines.mdx     |   7 +-
 docs/source/en/model_doc/auto.mdx             |   4 +
 src/transformers/__init__.py                  |   6 +
 src/transformers/models/auto/__init__.py      |   4 +
 src/transformers/models/auto/modeling_auto.py |  15 +++
 src/transformers/pipelines/__init__.py        |   9 ++
 .../pipelines/depth_estimation.py             |  93 +++++++++++++++
 src/transformers/utils/dummy_pt_objects.py    |  10 ++
 .../test_pipelines_depth_estimation.py        | 107 ++++++++++++++++++
 utils/update_metadata.py                      |   1 +
 10 files changed, 255 insertions(+), 1 deletion(-)
 create mode 100644 src/transformers/pipelines/depth_estimation.py
 create mode 100644 tests/pipelines/test_pipelines_depth_estimation.py

diff --git a/docs/source/en/main_classes/pipelines.mdx b/docs/source/en/main_classes/pipelines.mdx
index 5374f1a4003..ef6adc48107 100644
--- a/docs/source/en/main_classes/pipelines.mdx
+++ b/docs/source/en/main_classes/pipelines.mdx
@@ -25,6 +25,7 @@ There are two categories of pipeline abstractions to be aware about:
 - [`AudioClassificationPipeline`]
 - [`AutomaticSpeechRecognitionPipeline`]
 - [`ConversationalPipeline`]
+- [`DepthEstimationPipeline`]
 - [`DocumentQuestionAnsweringPipeline`]
 - [`FeatureExtractionPipeline`]
 - [`FillMaskPipeline`]
@@ -344,12 +345,16 @@ That should enable you to do all the custom code you want.
     - __call__
     - all
 
+### DepthEstimationPipeline
+[[autodoc]] DepthEstimationPipeline
+    - __call__
+    - all
+
 ### DocumentQuestionAnsweringPipeline
 [[autodoc]] DocumentQuestionAnsweringPipeline
     - __call__
     - all
-
 ### FeatureExtractionPipeline
 [[autodoc]] FeatureExtractionPipeline
diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx
index 01db8c4b1f7..a6426eb3c2c 100644
--- a/docs/source/en/model_doc/auto.mdx
+++ b/docs/source/en/model_doc/auto.mdx
@@ -82,6 +82,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
 
 [[autodoc]] AutoModelForCausalLM
 
+## AutoModelForDepthEstimation
+
+[[autodoc]] AutoModelForDepthEstimation
+
 ## AutoModelForMaskedLM
 
 [[autodoc]] AutoModelForMaskedLM
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index b2fc16fdd43..0e69839f0ec 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -420,6 +420,7 @@ _import_structure = {
         "Conversation",
         "ConversationalPipeline",
         "CsvPipelineDataFormat",
+        "DepthEstimationPipeline",
         "DocumentQuestionAnsweringPipeline",
         "FeatureExtractionPipeline",
         "FillMaskPipeline",
@@ -859,6 +860,7 @@ else:
         "MODEL_FOR_CAUSAL_LM_MAPPING",
         "MODEL_FOR_CTC_MAPPING",
         "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
+        "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
         "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
         "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
         "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
@@ -888,6 +890,7 @@ else:
         "AutoModelForCausalLM",
         "AutoModelForCTC",
         "AutoModelForDocumentQuestionAnswering",
+        "AutoModelForDepthEstimation",
         "AutoModelForImageClassification",
         "AutoModelForImageSegmentation",
         "AutoModelForInstanceSegmentation",
@@ -3419,6 +3422,7 @@ if TYPE_CHECKING:
         Conversation,
         ConversationalPipeline,
         CsvPipelineDataFormat,
+        DepthEstimationPipeline,
         DocumentQuestionAnsweringPipeline,
         FeatureExtractionPipeline,
         FillMaskPipeline,
@@ -3788,6 +3792,7 @@ if TYPE_CHECKING:
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
         MODEL_FOR_CAUSAL_LM_MAPPING,
         MODEL_FOR_CTC_MAPPING,
+        MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
         MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
         MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
@@ -3817,6 +3822,7 @@ if TYPE_CHECKING:
         AutoModelForAudioXVector,
         AutoModelForCausalLM,
         AutoModelForCTC,
+        AutoModelForDepthEstimation,
         AutoModelForDocumentQuestionAnswering,
         AutoModelForImageClassification,
         AutoModelForImageSegmentation,
diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py
index 1964c73938f..acb0fa8b0f1 100644
--- a/src/transformers/models/auto/__init__.py
+++ b/src/transformers/models/auto/__init__.py
@@ -48,6 +48,7 @@ else:
         "MODEL_FOR_CAUSAL_LM_MAPPING",
         "MODEL_FOR_CTC_MAPPING",
         "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
+        "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
         "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
         "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
         "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
@@ -76,6 +77,7 @@ else:
         "AutoModelForAudioXVector",
         "AutoModelForCausalLM",
         "AutoModelForCTC",
+        "AutoModelForDepthEstimation",
         "AutoModelForImageClassification",
         "AutoModelForImageSegmentation",
         "AutoModelForInstanceSegmentation",
@@ -197,6 +199,7 @@ if TYPE_CHECKING:
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
         MODEL_FOR_CAUSAL_LM_MAPPING,
         MODEL_FOR_CTC_MAPPING,
+        MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
         MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
         MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
@@ -226,6 +229,7 @@ if TYPE_CHECKING:
         AutoModelForAudioXVector,
         AutoModelForCausalLM,
         AutoModelForCTC,
+        AutoModelForDepthEstimation,
         AutoModelForDocumentQuestionAnswering,
         AutoModelForImageClassification,
         AutoModelForImageSegmentation,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 9b83741aa96..6ef3b812c6a 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -480,6 +480,13 @@ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
     ]
 )
 
+MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
+    [
+        # Model for depth estimation mapping
+        ("dpt", "DPTForDepthEstimation"),
+        ("glpn", "GLPNForDepthEstimation"),
+    ]
+)
 MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Seq2Seq Causal LM mapping
@@ -845,6 +852,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODE
 MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
 )
+MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
 MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 )
@@ -1040,6 +1048,13 @@ AutoModelForZeroShotObjectDetection = auto_class_update(
 )
 
 
+class AutoModelForDepthEstimation(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
+
+
+AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
+
+
 class AutoModelForVideoClassification(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
 
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 0a878728185..4e8faa58d2e 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -32,6 +32,7 @@ from ..dynamic_module_utils import get_class_from_dynamic_module
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..models.auto.configuration_auto import AutoConfig
 from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
+from ..models.auto.modeling_auto import AutoModelForDepthEstimation
 from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
 from ..tokenization_utils import PreTrainedTokenizer
 from ..tokenization_utils_fast import PreTrainedTokenizerFast
@@ -51,6 +52,7 @@ from .base import (
     infer_framework_load_model,
 )
 from .conversational import Conversation, ConversationalPipeline
+from .depth_estimation import DepthEstimationPipeline
 from .document_question_answering import DocumentQuestionAnsweringPipeline
 from .feature_extraction import FeatureExtractionPipeline
 from .fill_mask import FillMaskPipeline
@@ -344,6 +346,13 @@ SUPPORTED_TASKS = {
         "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
         "type": "multimodal",
     },
+    "depth-estimation": {
+        "impl": DepthEstimationPipeline,
+        "tf": (),
+        "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
+        "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
+        "type": "image",
+    },
 }
 
 NO_FEATURE_EXTRACTOR_TASKS = set()
diff --git a/src/transformers/pipelines/depth_estimation.py b/src/transformers/pipelines/depth_estimation.py
new file mode 100644
index 00000000000..e826013a42f
--- /dev/null
+++ b/src/transformers/pipelines/depth_estimation.py
@@ -0,0 +1,93 @@
+from typing import List, Union
+
+import numpy as np
+
+from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
+from .base import PIPELINE_INIT_ARGS, Pipeline
+
+
+if is_vision_available():
+    from PIL import Image
+
+    from ..image_utils import load_image
+
+if is_torch_available():
+    import torch
+
+    from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING
+
+logger = logging.get_logger(__name__)
+
+
+@add_end_docstrings(PIPELINE_INIT_ARGS)
+class DepthEstimationPipeline(Pipeline):
+    """
+    Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image.
+
+    This depth estimation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
+    `"depth-estimation"`.
+
+    See the list of available models on
+    [huggingface.co/models](https://huggingface.co/models?filter=depth-estimation).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        requires_backends(self, "vision")
+        self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING)
+
+    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
+        """
+        Predict the depth of the image(s) passed as inputs.
+
+        Args:
+            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+                The pipeline handles three types of images:
+
+                - A string containing an HTTP link pointing to an image
+                - A string containing a local path to an image
+                - An image loaded in PIL directly
+
+                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
+                same format: all as HTTP links, all as local paths, or all as PIL images.
+
+        Return:
+            A dictionary or a list of dictionaries containing the result. If the input is a single image, a dictionary
+            will be returned; if the input is a list of several images, a list of dictionaries corresponding to the
+            images will be returned.
+
+            The dictionaries contain the following keys:
+
+            - **predicted_depth** (`torch.Tensor`) -- The predicted depth by the model as a `torch.Tensor`.
+            - **depth** (`PIL.Image`) -- The predicted depth by the model as a `PIL.Image`.
+ """ + return super().__call__(images, **kwargs) + + def _sanitize_parameters(self, **kwargs): + return {}, {}, {} + + def preprocess(self, image): + image = load_image(image) + self.image_size = image.size + model_inputs = self.feature_extractor(images=image, return_tensors=self.framework) + return model_inputs + + def _forward(self, model_inputs): + model_outputs = self.model(**model_inputs) + return model_outputs + + def postprocess(self, model_outputs): + predicted_depth = model_outputs.predicted_depth + prediction = torch.nn.functional.interpolate( + predicted_depth.unsqueeze(1), size=self.image_size[::-1], mode="bicubic", align_corners=False + ) + output = prediction.squeeze().cpu().numpy() + formatted = (output * 255 / np.max(output)).astype("uint8") + depth = Image.fromarray(formatted) + output_dict = {} + output_dict["predicted_depth"] = predicted_depth + output_dict["depth"] = depth + return output_dict diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 7e7917a783c..5c9cf9cb43f 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -358,6 +358,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = None MODEL_FOR_CTC_MAPPING = None +MODEL_FOR_DEPTH_ESTIMATION_MAPPING = None + + MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = None @@ -469,6 +472,13 @@ class AutoModelForCTC(metaclass=DummyObject): requires_backends(self, ["torch"]) +class AutoModelForDepthEstimation(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelForDocumentQuestionAnswering(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/test_pipelines_depth_estimation.py b/tests/pipelines/test_pipelines_depth_estimation.py new file mode 100644 index 00000000000..d42ba2a067c --- /dev/null +++ b/tests/pipelines/test_pipelines_depth_estimation.py @@ -0,0 +1,107 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import hashlib
+import unittest
+
+from transformers import MODEL_FOR_DEPTH_ESTIMATION_MAPPING, is_torch_available, is_vision_available
+from transformers.pipelines import DepthEstimationPipeline, pipeline
+from transformers.testing_utils import nested_simplify, require_tf, require_timm, require_torch, require_vision, slow
+
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
+
+
+if is_torch_available():
+    import torch
+
+if is_vision_available():
+    from PIL import Image
+else:
+
+    class Image:
+        @staticmethod
+        def open(*args, **kwargs):
+            pass
+
+
+def hashimage(image: Image) -> str:
+    m = hashlib.md5(image.tobytes())
+    return m.hexdigest()
+
+
+@require_vision
+@require_timm
+@require_torch
+class DepthEstimationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+
+    model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
+
+    def get_test_pipeline(self, model, tokenizer, feature_extractor):
+        depth_estimator = DepthEstimationPipeline(model=model, feature_extractor=feature_extractor)
+        return depth_estimator, [
+            "./tests/fixtures/tests_samples/COCO/000000039769.png",
+            "./tests/fixtures/tests_samples/COCO/000000039769.png",
+        ]
+
+    def run_pipeline_test(self, depth_estimator, examples):
+        outputs = depth_estimator("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        self.assertEqual({"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)}, outputs)
+        import datasets
+
+        dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
+        outputs = depth_estimator(
+            [
+                Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
+                "http://images.cocodataset.org/val2017/000000039769.jpg",
+                # RGBA
+                dataset[0]["file"],
+                # LA
+                dataset[1]["file"],
+                # L
+                dataset[2]["file"],
+            ]
+        )
+        self.assertEqual(
+            [
+                {"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
+                {"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
+                {"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
+                {"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
+                {"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)},
+            ],
+            outputs,
+        )
+
+    @require_tf
+    @unittest.skip("Depth estimation is not implemented in TF")
+    def test_small_model_tf(self):
+        pass
+
+    @slow
+    @require_torch
+    def test_large_model_pt(self):
+        model_id = "Intel/dpt-large"
+        depth_estimator = pipeline("depth-estimation", model=model_id)
+        outputs = depth_estimator("http://images.cocodataset.org/val2017/000000039769.jpg")
+        outputs["depth"] = hashimage(outputs["depth"])
+
+        # This seems flaky.
+        # self.assertEqual(outputs["depth"], "1a39394e282e9f3b0741a90b9f108977")
+        self.assertEqual(nested_simplify(outputs["predicted_depth"].max().item()), 29.304)
+        self.assertEqual(nested_simplify(outputs["predicted_depth"].min().item()), 2.662)
+
+    @require_torch
+    def test_small_model_pt(self):
+        # It is highly irregular to have no small tests.
+        self.skipTest("There is no hf-internal-testing tiny model for either GLPN or DPT")
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 8bb3b71672d..5e7169c2558 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -101,6 +101,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
         "_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
         "AutoModel",
     ),
+    ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
 ]
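
For reference, a minimal usage sketch of the pipeline added above. The task identifier, the default checkpoint (`Intel/dpt-large`), and the output keys all come from this patch; the image URL is the one exercised in the slow test, and the output file name is hypothetical.

```python
from transformers import pipeline

# "depth-estimation" resolves to DepthEstimationPipeline via SUPPORTED_TASKS;
# Intel/dpt-large is the default PyTorch checkpoint declared for the task.
depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large")

outputs = depth_estimator("http://images.cocodataset.org/val2017/000000039769.jpg")

# postprocess() returns two keys:
depth_image = outputs["depth"]  # PIL.Image.Image, rescaled to 0-255 at the input resolution
raw_depth = outputs["predicted_depth"]  # the raw torch.Tensor predicted by the model
depth_image.save("depth.png")  # hypothetical output path, for illustration
```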
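And a rough sketch of what the pipeline's three stages (`preprocess`, `_forward`, `postprocess`) do when composed by hand, assuming torch and PIL are installed; `AutoFeatureExtractor` stands in for the checkpoint's feature extractor, and the local file name is hypothetical:

```python
import numpy as np
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, AutoModelForDepthEstimation

feature_extractor = AutoFeatureExtractor.from_pretrained("Intel/dpt-large")
model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-large")

image = Image.open("cats.png")  # hypothetical local file

# preprocess: resize and normalize the image into model tensors
inputs = feature_extractor(images=image, return_tensors="pt")

# _forward: run the depth estimation model
with torch.no_grad():
    predicted_depth = model(**inputs).predicted_depth

# postprocess: upsample to the original (height, width) and render as an 8-bit image
prediction = torch.nn.functional.interpolate(
    predicted_depth.unsqueeze(1), size=image.size[::-1], mode="bicubic", align_corners=False
)
output = prediction.squeeze().cpu().numpy()
depth = Image.fromarray((output * 255 / np.max(output)).astype("uint8"))
```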