mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Add depth estimation pipeline (#18618)
* Add initial files for depth estimation pipelines * Add test file for depth estimation pipeline * Update model mapping names * Add updates for depth estimation output * Add generic test * Hopefully fixing the tests. * Check if test passes * Add make fixup and make fix-copies changes after rebase with main * Rebase with main * Fixing up depth pipeline. * This is not used anymore. * Fixing the test. `Image` is a module `Image.Image` is the type. * Update docs/source/en/main_classes/pipelines.mdx Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
4ed0fa3676
commit
e94384e4d8
@ -25,6 +25,7 @@ There are two categories of pipeline abstractions to be aware about:
|
||||
- [`AudioClassificationPipeline`]
|
||||
- [`AutomaticSpeechRecognitionPipeline`]
|
||||
- [`ConversationalPipeline`]
|
||||
- [`DepthEstimationPipeline`]
|
||||
- [`DocumentQuestionAnsweringPipeline`]
|
||||
- [`FeatureExtractionPipeline`]
|
||||
- [`FillMaskPipeline`]
|
||||
@ -344,12 +345,16 @@ That should enable you to do all the custom code you want.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### DepthEstimationPipeline
|
||||
[[autodoc]] DepthEstimationPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### DocumentQuestionAnsweringPipeline
|
||||
|
||||
[[autodoc]] DocumentQuestionAnsweringPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### FeatureExtractionPipeline
|
||||
|
||||
[[autodoc]] FeatureExtractionPipeline
|
||||
|
@ -82,6 +82,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
||||
|
||||
[[autodoc]] AutoModelForCausalLM
|
||||
|
||||
## AutoModelForDepthEstimation
|
||||
|
||||
[[autodoc]] AutoModelForDepthEstimation
|
||||
|
||||
## AutoModelForMaskedLM
|
||||
|
||||
[[autodoc]] AutoModelForMaskedLM
|
||||
|
@ -420,6 +420,7 @@ _import_structure = {
|
||||
"Conversation",
|
||||
"ConversationalPipeline",
|
||||
"CsvPipelineDataFormat",
|
||||
"DepthEstimationPipeline",
|
||||
"DocumentQuestionAnsweringPipeline",
|
||||
"FeatureExtractionPipeline",
|
||||
"FillMaskPipeline",
|
||||
@ -859,6 +860,7 @@ else:
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_CTC_MAPPING",
|
||||
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||
@ -888,6 +890,7 @@ else:
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForDocumentQuestionAnswering",
|
||||
"AutoModelForDepthEstimation",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForImageSegmentation",
|
||||
"AutoModelForInstanceSegmentation",
|
||||
@ -3419,6 +3422,7 @@ if TYPE_CHECKING:
|
||||
Conversation,
|
||||
ConversationalPipeline,
|
||||
CsvPipelineDataFormat,
|
||||
DepthEstimationPipeline,
|
||||
DocumentQuestionAnsweringPipeline,
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
@ -3788,6 +3792,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
|
||||
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
@ -3817,6 +3822,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForAudioXVector,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForDepthEstimation,
|
||||
AutoModelForDocumentQuestionAnswering,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
|
@ -48,6 +48,7 @@ else:
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_CTC_MAPPING",
|
||||
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||
@ -76,6 +77,7 @@ else:
|
||||
"AutoModelForAudioXVector",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForDepthEstimation",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForImageSegmentation",
|
||||
"AutoModelForInstanceSegmentation",
|
||||
@ -197,6 +199,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
|
||||
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||
@ -226,6 +229,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForAudioXVector,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForDepthEstimation,
|
||||
AutoModelForDocumentQuestionAnswering,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
|
@ -480,6 +480,13 @@ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for depth estimation mapping
|
||||
("dpt", "DPTForDepthEstimation"),
|
||||
("glpn", "GLPNForDepthEstimation"),
|
||||
]
|
||||
)
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Seq2Seq Causal LM mapping
|
||||
@ -845,6 +852,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODE
|
||||
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
||||
)
|
||||
@ -1040,6 +1048,13 @@ AutoModelForZeroShotObjectDetection = auto_class_update(
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForDepthEstimation(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
|
||||
|
||||
|
||||
class AutoModelForVideoClassification(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
||||
|
||||
|
@ -32,6 +32,7 @@ from ..dynamic_module_utils import get_class_from_dynamic_module
|
||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||
from ..models.auto.modeling_auto import AutoModelForDepthEstimation
|
||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
@ -51,6 +52,7 @@ from .base import (
|
||||
infer_framework_load_model,
|
||||
)
|
||||
from .conversational import Conversation, ConversationalPipeline
|
||||
from .depth_estimation import DepthEstimationPipeline
|
||||
from .document_question_answering import DocumentQuestionAnsweringPipeline
|
||||
from .feature_extraction import FeatureExtractionPipeline
|
||||
from .fill_mask import FillMaskPipeline
|
||||
@ -344,6 +346,13 @@ SUPPORTED_TASKS = {
|
||||
"default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"depth-estimation": {
|
||||
"impl": DepthEstimationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
|
||||
"type": "image",
|
||||
},
|
||||
}
|
||||
|
||||
NO_FEATURE_EXTRACTOR_TASKS = set()
|
||||
|
93
src/transformers/pipelines/depth_estimation.py
Normal file
93
src/transformers/pipelines/depth_estimation.py
Normal file
@ -0,0 +1,93 @@
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
class DepthEstimationPipeline(Pipeline):
    """
    Depth estimation pipeline using any `AutoModelForDepthEstimation`. This pipeline predicts the depth of an image.

    This depth estimation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"depth-estimation"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=depth-estimation).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Images are loaded and returned as PIL objects, so the vision backend is mandatory.
        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING)

    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
        """
        Predict the depth of the image(s) passed as inputs.

        Args:
            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images, which must then be passed as a list.
                Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
                images.

        Return:
            A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
            dictionary, if the input is a list of several images, will return a list of dictionaries corresponding to
            the images.

            The dictionaries contain the following keys:

            - **predicted_depth** (`torch.Tensor`) -- The raw depth predicted by the model, taken from the
              `predicted_depth` attribute of the model output.
            - **depth** (`PIL.Image`) -- The predicted depth resized to the input image size with bicubic
              interpolation, rescaled so that its maximum maps to 255 and converted to an 8-bit image.
        """
        return super().__call__(images, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        # This pipeline defines no preprocess/forward/postprocess parameters; extra kwargs are ignored.
        return {}, {}, {}

    def preprocess(self, image):
        image = load_image(image)
        # NOTE(review): the size is stored on the instance and read back in `postprocess`, so it always reflects
        # the most recently preprocessed image. This looks unsafe if several images are preprocessed before their
        # outputs are postprocessed (e.g. batching or dataloader workers) -- confirm against the base `Pipeline`.
        self.image_size = image.size
        model_inputs = self.feature_extractor(images=image, return_tensors=self.framework)
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs):
        predicted_depth = model_outputs.predicted_depth
        # `image_size` is reversed before being used as the interpolation target; presumably this converts
        # PIL's (width, height) into the (height, width) order expected by `interpolate`.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1), size=self.image_size[::-1], mode="bicubic", align_corners=False
        )
        output = prediction.squeeze().cpu().numpy()
        max_depth = np.max(output)
        if max_depth == 0:
            # Dividing by zero would fill the array with NaN/inf before the uint8 cast; emit a black image instead.
            formatted = np.zeros_like(output, dtype="uint8")
        else:
            formatted = (output * 255 / max_depth).astype("uint8")
        depth = Image.fromarray(formatted)
        output_dict = {}
        output_dict["predicted_depth"] = predicted_depth
        output_dict["depth"] = depth
        return output_dict
|
@ -358,6 +358,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = None
|
||||
MODEL_FOR_CTC_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = None
|
||||
|
||||
|
||||
@ -469,6 +472,13 @@ class AutoModelForCTC(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForDepthEstimation(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForDocumentQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
107
tests/pipelines/test_pipelines_depth_estimation.py
Normal file
107
tests/pipelines/test_pipelines_depth_estimation.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_DEPTH_ESTIMATION_MAPPING, is_torch_available, is_vision_available
|
||||
from transformers.pipelines import DepthEstimationPipeline, pipeline
|
||||
from transformers.testing_utils import nested_simplify, require_tf, require_timm, require_torch, require_vision, slow
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
else:
|
||||
|
||||
class Image:
|
||||
@staticmethod
|
||||
def open(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def hashimage(image: "Image.Image") -> str:
    """Return the hexadecimal MD5 digest of the raw bytes of `image` (as given by `image.tobytes()`)."""
    # The annotation is a string so it is never evaluated: `Image` is either the PIL module or the
    # fallback stub class defined when vision is unavailable, and the stub has no `Image` attribute.
    m = hashlib.md5(image.tobytes())
    return m.hexdigest()
|
||||
|
||||
|
||||
@require_vision
@require_timm
@require_torch
class DepthEstimationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    # Pipeline tests for the "depth-estimation" task; `model_mapping` is consumed by the test metaclass.
    model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING

    def get_test_pipeline(self, model, tokenizer, feature_extractor):
        # Build the pipeline under test together with two example inputs (the same local fixture twice).
        sample_path = "./tests/fixtures/tests_samples/COCO/000000039769.png"
        pipe = DepthEstimationPipeline(model=model, feature_extractor=feature_extractor)
        return pipe, [sample_path, sample_path]

    def run_pipeline_test(self, depth_estimator, examples):
        # Check the output structure for a single image, then for a mixed batch of inputs.
        local_image = "./tests/fixtures/tests_samples/COCO/000000039769.png"

        single_result = depth_estimator(local_image)
        self.assertEqual({"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)}, single_result)

        import datasets

        dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
        batch_inputs = [
            Image.open(local_image),
            "http://images.cocodataset.org/val2017/000000039769.jpg",
            # RGBA
            dataset[0]["file"],
            # LA
            dataset[1]["file"],
            # L
            dataset[2]["file"],
        ]
        batch_results = depth_estimator(batch_inputs)

        # One result dictionary is expected per input image.
        expected = [{"predicted_depth": ANY(torch.Tensor), "depth": ANY(Image.Image)} for _ in batch_inputs]
        self.assertEqual(expected, batch_results)

    @require_tf
    @unittest.skip("Depth estimation is not implemented in TF")
    def test_small_model_tf(self):
        pass

    @slow
    @require_torch
    def test_large_model_pt(self):
        depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large")
        outputs = depth_estimator("http://images.cocodataset.org/val2017/000000039769.jpg")
        outputs["depth"] = hashimage(outputs["depth"])

        # This seems flaky.
        # self.assertEqual(outputs["depth"], "1a39394e282e9f3b0741a90b9f108977")
        raw_depth = outputs["predicted_depth"]
        self.assertEqual(nested_simplify(raw_depth.max().item()), 29.304)
        self.assertEqual(nested_simplify(raw_depth.min().item()), 2.662)

    @require_torch
    def test_small_model_pt(self):
        # This is highly irregular to have no small tests.
        self.skipTest("There is not hf-internal-testing tiny model for either GLPN nor DPT")
|
@ -101,6 +101,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
"_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModel",
|
||||
),
|
||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user