mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Image Feature Extraction pipeline (#28216)
* Draft pipeline * Fixup * Fix docstrings * Update doctest * Update pipeline_model_mapping * Update docstring * Update tests * Update src/transformers/pipelines/image_feature_extraction.py Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> * Fix docstrings - review comments * Remove pipeline mapping for composite vision models * Add to pipeline tests * Remove for flava (multimodal) * safe pil import * Add requirements for pipeline run * Account for super slow efficientnet * Review comments * Fix tests * Swap order of kwargs * Use build_pipeline_init_args * Add back FE pipeline for Vilt * Include image_processor_kwargs in docstring * Mark test as flaky * Update TODO * Update tests/pipelines/test_pipelines_image_feature_extraction.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add license header --------- Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
7addc9346c
commit
ba3264b4e8
@ -469,6 +469,12 @@ Pipelines available for multimodal tasks include the following.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageFeatureExtractionPipeline
|
||||
|
||||
[[autodoc]] ImageFeatureExtractionPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageToTextPipeline
|
||||
|
||||
[[autodoc]] ImageToTextPipeline
|
||||
|
@ -25,7 +25,7 @@ Recognition、Masked Language Modeling、Sentiment Analysis、Feature Extraction
|
||||
パイプラインの抽象化には2つのカテゴリーがある:
|
||||
|
||||
- [`pipeline`] は、他のすべてのパイプラインをカプセル化する最も強力なオブジェクトです。
|
||||
- タスク固有のパイプラインは、[オーディオ](#audio)、[コンピューター ビジョン](#computer-vision)、[自然言語処理](#natural-language-processing)、および [マルチモーダル](#multimodal) タスクで使用できます。
|
||||
- タスク固有のパイプラインは、[オーディオ](#audio)、[コンピューター ビジョン](#computer-vision)、[自然言語処理](#natural-language-processing)、および [マルチモーダル](#multimodal) タスクで使用できます。
|
||||
|
||||
## The pipeline abstraction
|
||||
|
||||
@ -477,6 +477,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageFeatureExtractionPipeline
|
||||
|
||||
[[autodoc]] ImageFeatureExtractionPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageToTextPipeline
|
||||
|
||||
[[autodoc]] ImageToTextPipeline
|
||||
|
@ -435,7 +435,7 @@ See [`TokenClassificationPipeline`] for all details.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
## 多模态
|
||||
## 多模态
|
||||
|
||||
可用于多模态任务的pipeline包括以下几种。
|
||||
|
||||
@ -451,6 +451,12 @@ See [`TokenClassificationPipeline`] for all details.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageFeatureExtractionPipeline
|
||||
|
||||
[[autodoc]] ImageFeatureExtractionPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### ImageToTextPipeline
|
||||
|
||||
[[autodoc]] ImageToTextPipeline
|
||||
|
@ -973,6 +973,7 @@ _import_structure = {
|
||||
"FeatureExtractionPipeline",
|
||||
"FillMaskPipeline",
|
||||
"ImageClassificationPipeline",
|
||||
"ImageFeatureExtractionPipeline",
|
||||
"ImageSegmentationPipeline",
|
||||
"ImageToImagePipeline",
|
||||
"ImageToTextPipeline",
|
||||
@ -5709,6 +5710,7 @@ if TYPE_CHECKING:
|
||||
FeatureExtractionPipeline,
|
||||
FillMaskPipeline,
|
||||
ImageClassificationPipeline,
|
||||
ImageFeatureExtractionPipeline,
|
||||
ImageSegmentationPipeline,
|
||||
ImageToImagePipeline,
|
||||
ImageToTextPipeline,
|
||||
|
@ -66,6 +66,7 @@ from .document_question_answering import DocumentQuestionAnsweringPipeline
|
||||
from .feature_extraction import FeatureExtractionPipeline
|
||||
from .fill_mask import FillMaskPipeline
|
||||
from .image_classification import ImageClassificationPipeline
|
||||
from .image_feature_extraction import ImageFeatureExtractionPipeline
|
||||
from .image_segmentation import ImageSegmentationPipeline
|
||||
from .image_to_image import ImageToImagePipeline
|
||||
from .image_to_text import ImageToTextPipeline
|
||||
@ -362,6 +363,18 @@ SUPPORTED_TASKS = {
|
||||
},
|
||||
"type": "image",
|
||||
},
|
||||
"image-feature-extraction": {
|
||||
"impl": ImageFeatureExtractionPipeline,
|
||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||
"pt": (AutoModel,) if is_torch_available() else (),
|
||||
"default": {
|
||||
"model": {
|
||||
"pt": ("google/vit-base-patch16-224", "29e7a1e183"),
|
||||
"tf": ("google/vit-base-patch16-224", "29e7a1e183"),
|
||||
}
|
||||
},
|
||||
"type": "image",
|
||||
},
|
||||
"image-segmentation": {
|
||||
"impl": ImageSegmentationPipeline,
|
||||
"tf": (),
|
||||
@ -500,6 +513,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
|
||||
- `"feature-extraction"`
|
||||
- `"fill-mask"`
|
||||
- `"image-classification"`
|
||||
- `"image-feature-extraction"`
|
||||
- `"image-segmentation"`
|
||||
- `"image-to-text"`
|
||||
- `"image-to-image"`
|
||||
@ -586,6 +600,7 @@ def pipeline(
|
||||
- `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
|
||||
- `"fill-mask"`: will return a [`FillMaskPipeline`]:.
|
||||
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
|
||||
- `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
|
||||
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
|
||||
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
|
||||
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
|
||||
|
@ -14,7 +14,7 @@ from .base import GenericTensor, Pipeline, build_pipeline_init_args
|
||||
)
|
||||
class FeatureExtractionPipeline(Pipeline):
|
||||
"""
|
||||
Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
|
||||
Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
|
||||
transformer, which can be used as features in downstream tasks.
|
||||
|
||||
Example:
|
||||
|
92
src/transformers/pipelines/image_feature_extraction.py
Normal file
92
src/transformers/pipelines/image_feature_extraction.py
Normal file
@ -0,0 +1,92 @@
|
||||
from typing import Dict
|
||||
|
||||
from ..utils import add_end_docstrings, is_vision_available
|
||||
from .base import GenericTensor, Pipeline, build_pipeline_init_args
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ..image_utils import load_image
|
||||
|
||||
|
||||
@add_end_docstrings(
|
||||
build_pipeline_init_args(has_image_processor=True),
|
||||
"""
|
||||
image_processor_kwargs (`dict`, *optional*):
|
||||
Additional dictionary of keyword arguments passed along to the image processor e.g.
|
||||
{"size": {"height": 100, "width": 100}}
|
||||
""",
|
||||
)
|
||||
class ImageFeatureExtractionPipeline(Pipeline):
|
||||
"""
|
||||
Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
|
||||
transformer, which can be used as features in downstream tasks.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
|
||||
>>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
|
||||
>>> result.shape # This is a tensor of shape [1, sequence_lenth, hidden_dimension] representing the input image.
|
||||
torch.Size([1, 197, 768])
|
||||
```
|
||||
|
||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||
|
||||
This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
|
||||
`"image-feature-extraction"`.
|
||||
|
||||
All vision models may be used for this pipeline. See a list of all models, including community-contributed models on
|
||||
[huggingface.co/models](https://huggingface.co/models).
|
||||
"""
|
||||
|
||||
def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, **kwargs):
|
||||
preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
|
||||
postprocess_params = {"return_tensors": return_tensors} if return_tensors is not None else {}
|
||||
|
||||
if "timeout" in kwargs:
|
||||
preprocess_params["timeout"] = kwargs["timeout"]
|
||||
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
|
||||
image = load_image(image, timeout=timeout)
|
||||
model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
|
||||
return model_inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
model_outputs = self.model(**model_inputs)
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs, return_tensors=False):
|
||||
# [0] is the first available tensor, logits or last_hidden_state.
|
||||
if return_tensors:
|
||||
return model_outputs[0]
|
||||
if self.framework == "pt":
|
||||
return model_outputs[0].tolist()
|
||||
elif self.framework == "tf":
|
||||
return model_outputs[0].numpy().tolist()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
Extract the features of the input(s).
|
||||
|
||||
Args:
|
||||
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||
The pipeline handles three types of images:
|
||||
|
||||
- A string containing a http link pointing to an image
|
||||
- A string containing a local path to an image
|
||||
- An image loaded in PIL directly
|
||||
|
||||
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
|
||||
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
|
||||
images.
|
||||
timeout (`float`, *optional*, defaults to None):
|
||||
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
|
||||
the call may block forever.
|
||||
Return:
|
||||
A nested list of `float`: The features computed by the model.
|
||||
"""
|
||||
return super().__call__(*args, **kwargs)
|
@ -1,3 +1,18 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from ..utils import (
|
||||
|
@ -242,7 +242,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": BeitModel,
|
||||
"image-feature-extraction": BeitModel,
|
||||
"image-classification": BeitForImageClassification,
|
||||
"image-segmentation": BeitForSemanticSegmentation,
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ class BitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (BitModel, BitForImageClassification, BitBackbone) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": BitModel, "image-classification": BitForImageClassification}
|
||||
{"image-feature-extraction": BitModel, "image-classification": BitForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -429,7 +429,10 @@ class BlipModelTester:
|
||||
class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (BlipModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": BlipModel, "image-to-text": BlipForConditionalGeneration}
|
||||
{
|
||||
"feature-extraction": BlipModel,
|
||||
"image-to-text": BlipForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -477,7 +477,9 @@ class CLIPModelTester:
|
||||
@require_torch
|
||||
class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (CLIPModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": CLIPModel} if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
|
||||
)
|
||||
fx_compatible = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
@ -185,7 +185,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
|
||||
{"image-feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -172,7 +172,7 @@ class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
|
||||
{"image-feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -180,7 +180,7 @@ class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
|
||||
{"image-feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -151,7 +151,7 @@ class CvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
|
||||
{"image-feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -178,7 +178,7 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Data2VecVisionModel,
|
||||
"image-feature-extraction": Data2VecVisionModel,
|
||||
"image-classification": Data2VecVisionForImageClassification,
|
||||
"image-segmentation": Data2VecVisionForSemanticSegmentation,
|
||||
}
|
||||
|
@ -191,7 +191,7 @@ class DeformableDetrModelTester:
|
||||
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
|
||||
{"image-feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -206,7 +206,7 @@ class DeiTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": DeiTModel,
|
||||
"image-feature-extraction": DeiTModel,
|
||||
"image-classification": (DeiTForImageClassification, DeiTForImageClassificationWithTeacher),
|
||||
}
|
||||
if is_torch_available()
|
||||
|
@ -217,7 +217,7 @@ class DetaModelTester:
|
||||
class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DetaModel, DetaForObjectDetection) if is_torchvision_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
|
||||
{"image-feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
|
||||
if is_torchvision_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -182,7 +182,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": DetrModel,
|
||||
"image-feature-extraction": DetrModel,
|
||||
"image-segmentation": DetrForSegmentation,
|
||||
"object-detection": DetrForObjectDetection,
|
||||
}
|
||||
|
@ -207,7 +207,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
|
||||
{"image-feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -217,7 +217,7 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
|
||||
{"image-feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -145,7 +145,7 @@ class DonutSwinModelTester:
|
||||
@require_torch
|
||||
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": DonutSwinModel} if is_torch_available() else {}
|
||||
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
|
||||
fx_compatible = True
|
||||
|
||||
test_pruning = False
|
||||
|
@ -163,7 +163,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"depth-estimation": DPTForDepthEstimation,
|
||||
"feature-extraction": DPTModel,
|
||||
"image-feature-extraction": DPTModel,
|
||||
"image-segmentation": DPTForSemanticSegmentation,
|
||||
}
|
||||
if is_torch_available()
|
||||
|
@ -190,7 +190,7 @@ class EfficientFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": EfficientFormerModel,
|
||||
"image-feature-extraction": EfficientFormerModel,
|
||||
"image-classification": (
|
||||
EfficientFormerForImageClassification,
|
||||
EfficientFormerForImageClassificationWithTeacher,
|
||||
|
@ -130,7 +130,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
|
||||
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
|
||||
{"image-feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
@ -216,6 +216,12 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
model = EfficientNetModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@is_pipeline_test
|
||||
@require_vision
|
||||
@slow
|
||||
def test_pipeline_image_feature_extraction(self):
|
||||
super().test_pipeline_image_feature_extraction()
|
||||
|
||||
@is_pipeline_test
|
||||
@require_vision
|
||||
@slow
|
||||
|
@ -238,7 +238,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
|
||||
{"image-feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -146,7 +146,9 @@ class GLPNModelTester:
|
||||
class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"depth-estimation": GLPNForDepthEstimation, "feature-extraction": GLPNModel} if is_torch_available() else {}
|
||||
{"depth-estimation": GLPNForDepthEstimation, "image-feature-extraction": GLPNModel}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
test_head_masking = False
|
||||
|
@ -271,7 +271,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
)
|
||||
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
|
||||
{"image-feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -176,7 +176,7 @@ class LevitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": LevitModel,
|
||||
"image-feature-extraction": LevitModel,
|
||||
"image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher),
|
||||
}
|
||||
if is_torch_available()
|
||||
|
@ -197,7 +197,7 @@ class Mask2FormerModelTester:
|
||||
@require_torch
|
||||
class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": Mask2FormerModel} if is_torch_available() else {}
|
||||
pipeline_model_mapping = {"image-feature-extraction": Mask2FormerModel} if is_torch_available() else {}
|
||||
|
||||
is_encoder_decoder = False
|
||||
test_pruning = False
|
||||
|
@ -197,7 +197,7 @@ class MaskFormerModelTester:
|
||||
class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
|
||||
{"image-feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -31,7 +31,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import MgpstrForSceneTextRecognition
|
||||
from transformers import MgpstrForSceneTextRecognition, MgpstrModel
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@ -118,7 +118,11 @@ class MgpstrModelTester:
|
||||
@require_torch
|
||||
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": MgpstrForSceneTextRecognition, "image-feature-extraction": MgpstrModel}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
fx_compatible = False
|
||||
|
||||
test_pruning = False
|
||||
|
@ -147,7 +147,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
|
||||
all_model_classes = (MobileNetV1Model, MobileNetV1ForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
|
||||
{"image-feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -195,7 +195,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MobileNetV2Model,
|
||||
"image-feature-extraction": MobileNetV2Model,
|
||||
"image-classification": MobileNetV2ForImageClassification,
|
||||
"image-segmentation": MobileNetV2ForSemanticSegmentation,
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MobileViTModel,
|
||||
"image-feature-extraction": MobileViTModel,
|
||||
"image-classification": MobileViTForImageClassification,
|
||||
"image-segmentation": MobileViTForSemanticSegmentation,
|
||||
}
|
||||
|
@ -190,7 +190,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MobileViTV2Model,
|
||||
"image-feature-extraction": MobileViTV2Model,
|
||||
"image-classification": MobileViTV2ForImageClassification,
|
||||
"image-segmentation": MobileViTV2ForSemanticSegmentation,
|
||||
}
|
||||
|
@ -204,7 +204,7 @@ class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": NatModel, "image-classification": NatForImageClassification}
|
||||
{"image-feature-extraction": NatModel, "image-classification": NatForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -433,7 +433,10 @@ class Owlv2ModelTester:
|
||||
class Owlv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Owlv2Model,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": Owlv2Model, "zero-shot-object-detection": Owlv2ForObjectDetection}
|
||||
{
|
||||
"feature-extraction": Owlv2Model,
|
||||
"zero-shot-object-detection": Owlv2ForObjectDetection,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -428,7 +428,10 @@ class OwlViTModelTester:
|
||||
class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (OwlViTModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": OwlViTModel, "zero-shot-object-detection": OwlViTForObjectDetection}
|
||||
{
|
||||
"feature-extraction": OwlViTModel,
|
||||
"zero-shot-object-detection": OwlViTForObjectDetection,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -124,7 +124,7 @@ class PoolFormerModelTester:
|
||||
class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (PoolFormerModel, PoolFormerForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": PoolFormerModel, "image-classification": PoolFormerForImageClassification}
|
||||
{"image-feature-extraction": PoolFormerModel, "image-classification": PoolFormerForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -158,7 +158,7 @@ def prepare_img():
|
||||
class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (PvtModel, PvtForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": PvtModel, "image-classification": PvtForImageClassification}
|
||||
{"image-feature-extraction": PvtModel, "image-classification": PvtForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -126,7 +126,7 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
|
||||
{"image-feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -170,7 +170,7 @@ class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
|
||||
{"image-feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -171,7 +171,7 @@ class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": SegformerModel,
|
||||
"image-feature-extraction": SegformerModel,
|
||||
"image-classification": SegformerForImageClassification,
|
||||
"image-segmentation": SegformerForSemanticSegmentation,
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ class SwiftFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
|
||||
all_model_classes = (SwiftFormerModel, SwiftFormerForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": SwiftFormerModel, "image-classification": SwiftFormerForImageClassification}
|
||||
{"image-feature-extraction": SwiftFormerModel, "image-classification": SwiftFormerForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -232,7 +232,7 @@ class SwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": SwinModel, "image-classification": SwinForImageClassification}
|
||||
{"image-feature-extraction": SwinModel, "image-classification": SwinForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -162,7 +162,7 @@ class Swin2SRModelTester:
|
||||
class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Swin2SRModel, Swin2SRForImageSuperResolution) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": Swin2SRModel, "image-to-image": Swin2SRForImageSuperResolution}
|
||||
{"image-feature-extraction": Swin2SRModel, "image-to-image": Swin2SRForImageSuperResolution}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -217,7 +217,7 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
|
||||
{"image-feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -200,7 +200,7 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": TableTransformerModel, "object-detection": TableTransformerForObjectDetection}
|
||||
{"image-feature-extraction": TableTransformerModel, "object-detection": TableTransformerForObjectDetection}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -228,7 +228,7 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ViltModel, "visual-question-answering": ViltForQuestionAnswering}
|
||||
{"image-feature-extraction": ViltModel, "visual-question-answering": ViltForQuestionAnswering}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -193,7 +193,7 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ViTModel, "image-classification": ViTForImageClassification}
|
||||
{"image-feature-extraction": ViTModel, "image-classification": ViTForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -156,7 +156,7 @@ class ViTHybridModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
|
||||
all_model_classes = (ViTHybridModel, ViTHybridForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ViTHybridModel, "image-classification": ViTHybridForImageClassification}
|
||||
{"image-feature-extraction": ViTHybridModel, "image-classification": ViTHybridForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -164,7 +164,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
|
||||
all_model_classes = (ViTMAEModel, ViTMAEForPreTraining) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": ViTMAEModel} if is_torch_available() else {}
|
||||
pipeline_model_mapping = {"image-feature-extraction": ViTMAEModel} if is_torch_available() else {}
|
||||
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
|
@ -152,7 +152,7 @@ class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (ViTMSNModel, ViTMSNForImageClassification) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": ViTMSNModel, "image-classification": ViTMSNForImageClassification}
|
||||
{"image-feature-extraction": ViTMSNModel, "image-classification": ViTMSNForImageClassification}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
@ -168,7 +168,9 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (YolosModel, YolosForObjectDetection) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": YolosModel, "object-detection": YolosForObjectDetection} if is_torch_available() else {}
|
||||
{"image-feature-extraction": YolosModel, "object-detection": YolosForObjectDetection}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
test_pruning = False
|
||||
|
157
tests/pipelines/test_pipelines_image_feature_extraction.py
Normal file
157
tests/pipelines/test_pipelines_image_feature_extraction.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import (
|
||||
MODEL_MAPPING,
|
||||
TF_MODEL_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
ImageFeatureExtractionPipeline,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_MAPPING
|
||||
tf_model_mapping = TF_MODEL_MAPPING
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
|
||||
)
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs[0][0]),
|
||||
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
|
||||
)
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs[0][0]),
|
||||
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
|
||||
|
||||
@require_torch
|
||||
def test_image_processing_small_model_pt(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
|
||||
)
|
||||
|
||||
# test with image processor parameters
|
||||
image_processor_kwargs = {"size": {"height": 300, "width": 300}}
|
||||
img = prepare_img()
|
||||
with pytest.raises(ValueError):
|
||||
# Image doesn't match model input size
|
||||
feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||
|
||||
image_processor_kwargs = {"image_mean": [0, 0, 0], "image_std": [1, 1, 1]}
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
|
||||
|
||||
@require_tf
|
||||
def test_image_processing_small_model_tf(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
|
||||
)
|
||||
|
||||
# test with image processor parameters
|
||||
image_processor_kwargs = {"size": {"height": 300, "width": 300}}
|
||||
img = prepare_img()
|
||||
with pytest.raises(ValueError):
|
||||
# Image doesn't match model input size
|
||||
feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||
|
||||
image_processor_kwargs = {"image_mean": [0, 0, 0], "image_std": [1, 1, 1]}
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
|
||||
|
||||
@require_torch
|
||||
def test_return_tensors_pt(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
|
||||
)
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img, return_tensors=True)
|
||||
self.assertTrue(torch.is_tensor(outputs))
|
||||
|
||||
@require_tf
|
||||
def test_return_tensors_tf(self):
|
||||
feature_extractor = pipeline(
|
||||
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
|
||||
)
|
||||
img = prepare_img()
|
||||
outputs = feature_extractor(img, return_tensors=True)
|
||||
self.assertTrue(tf.is_tensor(outputs))
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor):
|
||||
if processor is None:
|
||||
self.skipTest("No image processor")
|
||||
|
||||
elif type(model.config) in TOKENIZER_MAPPING:
|
||||
self.skipTest("This is a bimodal model, we need to find a more consistent way to switch on those models.")
|
||||
|
||||
elif model.config.is_encoder_decoder:
|
||||
self.skipTest(
|
||||
"""encoder_decoder models are trickier for this pipeline.
|
||||
Do we want encoder + decoder inputs to get some featues?
|
||||
Do we want encoder only features ?
|
||||
For now ignore those.
|
||||
"""
|
||||
)
|
||||
|
||||
feature_extractor = ImageFeatureExtractionPipeline(model=model, image_processor=processor)
|
||||
img = prepare_img()
|
||||
return feature_extractor, [img, img]
|
||||
|
||||
def run_pipeline_test(self, feature_extractor, examples):
|
||||
imgs = examples
|
||||
outputs = feature_extractor(imgs[0])
|
||||
|
||||
self.assertEqual(len(outputs), 1)
|
||||
|
||||
outputs = feature_extractor(imgs)
|
||||
self.assertEqual(len(outputs), 2)
|
@ -39,6 +39,7 @@ from .pipelines.test_pipelines_document_question_answering import DocumentQuesti
|
||||
from .pipelines.test_pipelines_feature_extraction import FeatureExtractionPipelineTests
|
||||
from .pipelines.test_pipelines_fill_mask import FillMaskPipelineTests
|
||||
from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
|
||||
from .pipelines.test_pipelines_image_feature_extraction import ImageFeatureExtractionPipelineTests
|
||||
from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
|
||||
from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
|
||||
from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
|
||||
@ -70,6 +71,7 @@ pipeline_test_mapping = {
|
||||
"feature-extraction": {"test": FeatureExtractionPipelineTests},
|
||||
"fill-mask": {"test": FillMaskPipelineTests},
|
||||
"image-classification": {"test": ImageClassificationPipelineTests},
|
||||
"image-feature-extraction": {"test": ImageFeatureExtractionPipelineTests},
|
||||
"image-segmentation": {"test": ImageSegmentationPipelineTests},
|
||||
"image-to-image": {"test": ImageToImagePipelineTests},
|
||||
"image-to-text": {"test": ImageToTextPipelineTests},
|
||||
@ -374,6 +376,13 @@ class PipelineTesterMixin:
|
||||
def test_pipeline_image_to_text(self):
|
||||
self.run_task_tests(task="image-to-text")
|
||||
|
||||
@is_pipeline_test
|
||||
@require_timm
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_pipeline_image_feature_extraction(self):
|
||||
self.run_task_tests(task="image-feature-extraction")
|
||||
|
||||
@unittest.skip(reason="`run_pipeline_test` is currently not implemented.")
|
||||
@is_pipeline_test
|
||||
@require_vision
|
||||
|
@ -324,6 +324,7 @@ OBJECTS_TO_IGNORE = [
|
||||
"IdeficsConfig",
|
||||
"IdeficsProcessor",
|
||||
"ImageClassificationPipeline",
|
||||
"ImageFeatureExtractionPipeline",
|
||||
"ImageGPTConfig",
|
||||
"ImageSegmentationPipeline",
|
||||
"ImageToImagePipeline",
|
||||
|
Loading…
Reference in New Issue
Block a user