Image pipelines spec compliance (#33899)
* Update many similar visual pipelines
* Add input tests
* Add ImageToText as well
* Add output tests
* Add output tests
* Add output tests
* OutputElement -> Output
* Correctly test elements
* make fixup
* fix typo in the task list
* Fix VQA testing
* Add copyright to image_classification.py
* Revert changes to VQA pipeline because outputs have differences - will move to another PR
* make fixup
* Remove deprecation warnings
parent e2001c3413
commit 3b44d2f042
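The recurring change across the diffs below is a rename of the first `__call__` argument from `images` to `inputs` (or the singular `image` for zero-shot) to match the Hub task specs, with a temporary shim that keeps the old keyword working. A minimal usage sketch of the resulting behaviour (the checkpoint and image URL are illustrative, not from this commit):

```python
from transformers import pipeline

depth = pipeline("depth-estimation", model="Intel/dpt-large")  # illustrative checkpoint
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # illustrative image

# Spec-compliant call: positional, or `inputs=...`
out = depth(url)

# Still accepted during the deprecation window: the old `images=` keyword
# is popped from kwargs and remapped to `inputs` internally
out = depth(images=url)

print(type(out["predicted_depth"]))  # torch.Tensor
print(type(out["depth"]))            # PIL.Image
```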
src/transformers/pipelines/depth_estimation.py

@@ -1,3 +1,4 @@
+import warnings
 from typing import List, Union
 
 import numpy as np
@@ -50,12 +51,12 @@ class DepthEstimationPipeline(Pipeline):
         requires_backends(self, "vision")
         self.check_model_type(MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
 
-    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
+    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
         """
         Predict the depth(s) of the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing a http link pointing to an image
@@ -65,9 +66,10 @@ class DepthEstimationPipeline(Pipeline):
             The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
             Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
             images.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
+            parameters (`Dict`, *optional*):
+                A dictionary of argument names to parameter values, to control pipeline behaviour.
+                The only parameter available right now is `timeout`, which is the length of time, in seconds,
+                that the pipeline should wait before giving up on trying to download an image.
 
         Return:
             A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@@ -79,12 +81,22 @@ class DepthEstimationPipeline(Pipeline):
             - **predicted_depth** (`torch.Tensor`) -- The predicted depth by the model as a `torch.Tensor`.
             - **depth** (`PIL.Image`) -- The predicted depth by the model as a `PIL.Image`.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the depth-estimation pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
-    def _sanitize_parameters(self, timeout=None, **kwargs):
+    def _sanitize_parameters(self, timeout=None, parameters=None, **kwargs):
         preprocess_params = {}
         if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = timeout
+        if isinstance(parameters, dict) and "timeout" in parameters:
+            preprocess_params["timeout"] = parameters["timeout"]
         return preprocess_params, {}, {}
 
     def preprocess(self, image, timeout=None):
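Depth estimation also gains a `parameters` dict as the spec-compliant home for call-time options, while the bare `timeout` kwarg starts its deprecation. Both forms below should work as of this commit (checkpoint and URL illustrative):

```python
from transformers import pipeline

depth = pipeline("depth-estimation", model="Intel/dpt-large")  # illustrative checkpoint
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Deprecated form: emits a FutureWarning as of this commit
_ = depth(url, timeout=5.0)

# Spec-compliant form: per-call options travel in the `parameters` dict
_ = depth(url, parameters={"timeout": 5.0})
```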
src/transformers/pipelines/image_classification.py

@@ -1,3 +1,17 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
 from typing import List, Union
 
 import numpy as np
@@ -99,6 +113,9 @@ class ImageClassificationPipeline(Pipeline):
     def _sanitize_parameters(self, top_k=None, function_to_apply=None, timeout=None):
         preprocess_params = {}
         if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = timeout
         postprocess_params = {}
         if top_k is not None:
@@ -109,12 +126,12 @@ class ImageClassificationPipeline(Pipeline):
             postprocess_params["function_to_apply"] = function_to_apply
         return preprocess_params, {}, postprocess_params
 
-    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
+    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
         """
         Assign labels to the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing a http link pointing to an image
@@ -142,9 +159,6 @@ class ImageClassificationPipeline(Pipeline):
             top_k (`int`, *optional*, defaults to 5):
                 The number of top labels that will be returned by the pipeline. If the provided number is higher than
                 the number of labels available in the model configuration, it will default to the number of labels.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing result. If the input is a single image, will return a
@@ -156,7 +170,12 @@ class ImageClassificationPipeline(Pipeline):
             - **label** (`str`) -- The label identified by the model.
             - **score** (`int`) -- The score attributed by the model for that label.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the image-classification pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
     def preprocess(self, image, timeout=None):
         image = load_image(image, timeout=timeout)
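The same deprecation shim is inlined in each of the pipeline diffs here. Factored out, the shared pattern is roughly the following (a hypothetical helper for illustration only; the commit deliberately inlines it per pipeline so each copy can be deleted when its deprecation completes):

```python
def _resolve_inputs(inputs, kwargs, task_name):
    # Hypothetical helper mirroring the inlined shim: accept the legacy
    # `images=` keyword, then insist that *some* input was provided.
    if "images" in kwargs:
        inputs = kwargs.pop("images")
    if inputs is None:
        raise ValueError(f"Cannot call the {task_name} pipeline without an inputs argument!")
    return inputs
```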
src/transformers/pipelines/image_segmentation.py

@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Union
 
 import numpy as np
@@ -90,16 +91,19 @@ class ImageSegmentationPipeline(Pipeline):
         if "overlap_mask_area_threshold" in kwargs:
             postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
         if "timeout" in kwargs:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_kwargs["timeout"] = kwargs["timeout"]
 
         return preprocess_kwargs, {}, postprocess_kwargs
 
-    def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
+    def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]]:
         """
         Perform segmentation (detect masks & classes) in the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing an HTTP(S) link pointing to an image
@@ -118,9 +122,6 @@ class ImageSegmentationPipeline(Pipeline):
                 Threshold to use when turning the predicted masks into binary values.
             overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
                 Mask overlap threshold to eliminate small, disconnected segments.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A dictionary or a list of dictionaries containing the result. If the input is a single image, will return a
@@ -136,7 +137,12 @@ class ImageSegmentationPipeline(Pipeline):
             - **score** (*optional* `float`) -- Optionally, when the model is capable of estimating a confidence of the
               "object" described by the label and the mask.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the image-segmentation pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
     def preprocess(self, image, subtask=None, timeout=None):
         image = load_image(image, timeout=timeout)
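Call-side, image segmentation keeps its `threshold` and `overlap_mask_area_threshold` knobs; only the input name changes. A sketch (checkpoint and file name illustrative):

```python
from transformers import pipeline

segmenter = pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic")  # illustrative
for pred in segmenter("street.jpg", threshold=0.9, overlap_mask_area_threshold=0.5):
    # Each prediction carries a label, an optional score, and a mask
    print(pred["label"], pred.get("score"))
```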
src/transformers/pipelines/image_to_text.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from typing import List, Union
 
 from ..utils import (
@@ -80,6 +81,9 @@ class ImageToTextPipeline(Pipeline):
         if prompt is not None:
             preprocess_params["prompt"] = prompt
         if timeout is not None:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = timeout
 
         if max_new_tokens is not None:
@@ -94,12 +98,12 @@ class ImageToTextPipeline(Pipeline):
 
         return preprocess_params, forward_params, {}
 
-    def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
+    def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):
         """
         Assign labels to the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing a HTTP(s) link pointing to an image
@@ -113,16 +117,18 @@ class ImageToTextPipeline(Pipeline):
 
             generate_kwargs (`Dict`, *optional*):
                 Pass it to send all of these arguments directly to `generate` allowing full control of this function.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
 
             - **generated_text** (`str`) -- The generated text.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs:
+            inputs = kwargs.pop("images")
+        if inputs is None:
+            raise ValueError("Cannot call the image-to-text pipeline without an inputs argument!")
+        return super().__call__(inputs, **kwargs)
 
     def preprocess(self, image, prompt=None, timeout=None):
         image = load_image(image, timeout=timeout)
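For image-to-text, generation options still flow through `generate_kwargs`; only the input argument is renamed. A sketch using the tiny checkpoint that appears in the tests further down (image URL illustrative):

```python
from transformers import pipeline

captioner = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # illustrative image

outputs = captioner(url, generate_kwargs={"max_new_tokens": 10})
print(outputs)  # [{"generated_text": "..."}]
```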
src/transformers/pipelines/object_detection.py

@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Union
 
 from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
@@ -63,6 +64,9 @@ class ObjectDetectionPipeline(Pipeline):
     def _sanitize_parameters(self, **kwargs):
         preprocess_params = {}
         if "timeout" in kwargs:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = kwargs["timeout"]
         postprocess_kwargs = {}
         if "threshold" in kwargs:
@@ -74,7 +78,7 @@ class ObjectDetectionPipeline(Pipeline):
         Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing an HTTP(S) link pointing to an image
@@ -85,9 +89,6 @@ class ObjectDetectionPipeline(Pipeline):
                 same format: all as HTTP(S) links, all as local paths, or all as PIL images.
             threshold (`float`, *optional*, defaults to 0.5):
                 The probability necessary to make a prediction.
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
 
         Return:
             A list of dictionaries or a list of list of dictionaries containing the result. If the input is a single
@@ -100,7 +101,9 @@ class ObjectDetectionPipeline(Pipeline):
             - **score** (`float`) -- The score attributed by the model for that label.
             - **box** (`List[Dict[str, int]]`) -- The bounding box of detected object in image's original size.
         """
-
+        # After deprecation of this is completed, remove the default `None` value for `images`
+        if "images" in kwargs and "inputs" not in kwargs:
+            kwargs["inputs"] = kwargs.pop("images")
         return super().__call__(*args, **kwargs)
 
     def preprocess(self, image, timeout=None):
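Object detection is the odd one out: its `__call__` keeps a `*args, **kwargs` signature, so the shim rewrites `kwargs["images"]` to `kwargs["inputs"]` instead of reassigning a named parameter. From the caller's side nothing changes (checkpoint and URL illustrative):

```python
from transformers import pipeline

detector = pipeline("object-detection", model="facebook/detr-resnet-50")  # illustrative
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

for obj in detector(url, threshold=0.7):
    # box is a dict with xmin/ymin/xmax/ymax, as asserted in the tests below
    print(obj["label"], obj["score"], obj["box"])
```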
src/transformers/pipelines/zero_shot_image_classification.py

@@ -1,3 +1,4 @@
+import warnings
 from collections import UserDict
 from typing import List, Union
 
@@ -73,12 +74,12 @@ class ZeroShotImageClassificationPipeline(Pipeline):
             else MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
         )
 
-    def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
+    def __call__(self, image: Union[str, List[str], "Image", List["Image"]] = None, **kwargs):
         """
         Assign labels to the image(s) passed as inputs.
 
         Args:
-            images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+            image (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                 The pipeline handles three types of images:
 
                 - A string containing a http link pointing to an image
@@ -93,13 +94,6 @@ class ZeroShotImageClassificationPipeline(Pipeline):
                 replacing the placeholder with the candidate_labels. Pass "{}" if *candidate_labels* are
                 already formatted.
 
-            timeout (`float`, *optional*, defaults to None):
-                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
-                the call may block forever.
-
-            tokenizer_kwargs (`dict`, *optional*):
-                Additional dictionary of keyword arguments passed along to the tokenizer.
-
         Return:
             A list of dictionaries containing one entry per proposed label. Each dictionary contains the
             following keys:
@@ -107,17 +101,29 @@ class ZeroShotImageClassificationPipeline(Pipeline):
             - **score** (`float`) -- The score attributed by the model to that label. It is a value between
               0 and 1, computed as the `softmax` of `logits_per_image`.
         """
-        return super().__call__(images, **kwargs)
+        # After deprecation of this is completed, remove the default `None` value for `image`
+        if "images" in kwargs:
+            image = kwargs.pop("images")
+        if image is None:
+            raise ValueError("Cannot call the zero-shot-image-classification pipeline without an images argument!")
+        return super().__call__(image, **kwargs)
 
     def _sanitize_parameters(self, tokenizer_kwargs=None, **kwargs):
         preprocess_params = {}
         if "candidate_labels" in kwargs:
             preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
         if "timeout" in kwargs:
+            warnings.warn(
+                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
+            )
             preprocess_params["timeout"] = kwargs["timeout"]
         if "hypothesis_template" in kwargs:
             preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
         if tokenizer_kwargs is not None:
+            warnings.warn(
+                "The `tokenizer_kwargs` argument is deprecated and will be removed in version 5 of Transformers",
+                FutureWarning,
+            )
             preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
 
         return preprocess_params, {}, {}
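Zero-shot image classification renames the argument to the singular `image` to match its Hub spec, and starts deprecating `tokenizer_kwargs` alongside `timeout`. A sketch (checkpoint and URL illustrative):

```python
from transformers import pipeline

classifier = pipeline("zero-shot-image-classification", model="openai/clip-vit-base-patch32")  # illustrative
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

preds = classifier(url, candidate_labels=["cat", "dog"], hypothesis_template="a photo of a {}")
print(preds)  # [{"label": "cat", "score": ...}, {"label": "dog", "score": ...}]
```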
tests/pipelines/test_pipelines_depth_estimation.py

@@ -14,11 +14,13 @@
 
 import unittest
 
+from huggingface_hub import DepthEstimationOutput
 from huggingface_hub.utils import insecure_hashlib
 
 from transformers import MODEL_FOR_DEPTH_ESTIMATION_MAPPING, is_torch_available, is_vision_available
 from transformers.pipelines import DepthEstimationPipeline, pipeline
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     nested_simplify,
     require_tf,
@@ -94,6 +96,9 @@ class DepthEstimationPipelineTests(unittest.TestCase):
             outputs,
         )
 
+        for single_output in outputs:
+            compare_pipeline_output_to_hub_spec(single_output, DepthEstimationOutput)
+
     @require_tf
     @unittest.skip(reason="Depth estimation is not implemented in TF")
     def test_small_model_tf(self):
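The new test assertions lean on `compare_pipeline_output_to_hub_spec` from `transformers.testing_utils`, which checks pipeline output dicts against the dataclass specs shipped in `huggingface_hub`. A rough sketch of the idea (an assumption about its semantics, not the actual implementation):

```python
import dataclasses

def compare_output_to_spec_sketch(output: dict, spec_class) -> None:
    # Simplified illustration: every required field of the Hub spec dataclass
    # should be present as a key in the pipeline's output dictionary.
    for field in dataclasses.fields(spec_class):
        required = (
            field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING
        )
        if required:
            assert field.name in output, f"Pipeline output is missing spec key: {field.name}"
```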
tests/pipelines/test_pipelines_image_classification.py

@@ -14,6 +14,8 @@
 
 import unittest
 
+from huggingface_hub import ImageClassificationOutputElement
+
 from transformers import (
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
@@ -23,6 +25,7 @@ from transformers import (
 )
 from transformers.pipelines import ImageClassificationPipeline, pipeline
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     nested_simplify,
     require_tf,
@@ -121,6 +124,10 @@ class ImageClassificationPipelineTests(unittest.TestCase):
             ],
         )
 
+        for single_output in outputs:
+            for output_element in single_output:
+                compare_pipeline_output_to_hub_spec(output_element, ImageClassificationOutputElement)
+
     @require_torch
     def test_small_model_pt(self):
         small_model = "hf-internal-testing/tiny-random-vit"
tests/pipelines/test_pipelines_image_segmentation.py

@@ -20,6 +20,7 @@ import datasets
 import numpy as np
 import requests
 from datasets import load_dataset
+from huggingface_hub import ImageSegmentationOutputElement
 from huggingface_hub.utils import insecure_hashlib
 
 from transformers import (
@@ -36,6 +37,7 @@ from transformers import (
     pipeline,
 )
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     nested_simplify,
     require_tf,
@@ -168,6 +170,10 @@ class ImageSegmentationPipelineTests(unittest.TestCase):
             f"Expected [{n}, {n}, {n}, {n}, {n}], got {[len(item) for item in outputs]}",
         )
 
+        for single_output in outputs:
+            for output_element in single_output:
+                compare_pipeline_output_to_hub_spec(output_element, ImageSegmentationOutputElement)
+
     @require_tf
     @unittest.skip(reason="Image segmentation not implemented in TF")
     def test_small_model_tf(self):
tests/pipelines/test_pipelines_image_to_text.py

@@ -15,10 +15,12 @@
 import unittest
 
 import requests
+from huggingface_hub import ImageToTextOutput
 
 from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
 from transformers.pipelines import ImageToTextPipeline, pipeline
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     require_tf,
     require_torch,
@@ -103,6 +105,9 @@ class ImageToTextPipelineTests(unittest.TestCase):
             [{"generated_text": "growth"}],
         )
 
+        for single_output in outputs:
+            compare_pipeline_output_to_hub_spec(single_output, ImageToTextOutput)
+
     @require_torch
     def test_small_model_pt(self):
         pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2")
tests/pipelines/test_pipelines_object_detection.py

@@ -14,6 +14,8 @@
 
 import unittest
 
+from huggingface_hub import ObjectDetectionOutputElement
+
 from transformers import (
     MODEL_FOR_OBJECT_DETECTION_MAPPING,
     AutoFeatureExtractor,
@@ -22,7 +24,8 @@ from transformers import (
     is_vision_available,
     pipeline,
 )
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     nested_simplify,
     require_pytesseract,
@@ -101,6 +104,7 @@ class ObjectDetectionPipelineTests(unittest.TestCase):
                     "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)},
                 },
             )
+            compare_pipeline_output_to_hub_spec(detected_object, ObjectDetectionOutputElement)
 
     @require_tf
     @unittest.skip(reason="Object detection not implemented in TF")
tests/pipelines/test_pipelines_zero_shot_image_classification.py

@@ -14,9 +14,12 @@
 
 import unittest
 
+from huggingface_hub import ZeroShotImageClassificationOutputElement
+
 from transformers import is_vision_available
 from transformers.pipelines import pipeline
 from transformers.testing_utils import (
+    compare_pipeline_output_to_hub_spec,
     is_pipeline_test,
     nested_simplify,
     require_tf,
@@ -127,6 +130,9 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
             ],
         )
 
+        for single_output in output:
+            compare_pipeline_output_to_hub_spec(single_output, ZeroShotImageClassificationOutputElement)
+
     @require_torch
     def test_small_model_pt_fp16(self):
         self.test_small_model_pt(torch_dtype="float16")
tests/test_pipeline_mixin.py

@@ -25,9 +25,27 @@ from pathlib import Path
 from textwrap import dedent
 from typing import get_args
 
-from huggingface_hub import AudioClassificationInput, AutomaticSpeechRecognitionInput
+from huggingface_hub import (
+    AudioClassificationInput,
+    AutomaticSpeechRecognitionInput,
+    DepthEstimationInput,
+    ImageClassificationInput,
+    ImageSegmentationInput,
+    ImageToTextInput,
+    ObjectDetectionInput,
+    ZeroShotImageClassificationInput,
+)
 
-from transformers.pipelines import AudioClassificationPipeline, AutomaticSpeechRecognitionPipeline
+from transformers.pipelines import (
+    AudioClassificationPipeline,
+    AutomaticSpeechRecognitionPipeline,
+    DepthEstimationPipeline,
+    ImageClassificationPipeline,
+    ImageSegmentationPipeline,
+    ImageToTextPipeline,
+    ObjectDetectionPipeline,
+    ZeroShotImageClassificationPipeline,
+)
 from transformers.testing_utils import (
     is_pipeline_test,
     require_decord,
@@ -105,6 +123,12 @@ task_to_pipeline_and_spec_mapping = {
     # task spec in the HF Hub
     "audio-classification": (AudioClassificationPipeline, AudioClassificationInput),
     "automatic-speech-recognition": (AutomaticSpeechRecognitionPipeline, AutomaticSpeechRecognitionInput),
+    "depth-estimation": (DepthEstimationPipeline, DepthEstimationInput),
+    "image-classification": (ImageClassificationPipeline, ImageClassificationInput),
+    "image-segmentation": (ImageSegmentationPipeline, ImageSegmentationInput),
+    "image-to-text": (ImageToTextPipeline, ImageToTextInput),
+    "object-detection": (ObjectDetectionPipeline, ObjectDetectionInput),
+    "zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
 }
 
 for task, task_info in pipeline_test_mapping.items():
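With the six new entries in `task_to_pipeline_and_spec_mapping`, the generic mixin test now covers the image pipelines too. Conceptually it walks the mapping and compares each pipeline's `__call__` signature against the Hub input dataclass, along the lines of the following (a sketch of the idea, not the mixin's actual code):

```python
import inspect
from dataclasses import fields

def check_input_spec_sketch(task_to_pipeline_and_spec_mapping):
    # Illustrative check: each field of the Hub input spec should correspond
    # to a parameter of the pipeline's __call__ (or be absorbed by **kwargs).
    for task, (pipeline_class, input_spec) in task_to_pipeline_and_spec_mapping.items():
        params = inspect.signature(pipeline_class.__call__).parameters
        accepts_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
        for field in fields(input_spec):
            assert field.name in params or accepts_kwargs, (task, field.name)
```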