Image Feature Extraction pipeline (#28216)

* Draft pipeline

* Fixup

* Fix docstrings

* Update doctest

* Update pipeline_model_mapping

* Update docstring

* Update tests

* Update src/transformers/pipelines/image_feature_extraction.py

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>

* Fix docstrings - review comments

* Remove pipeline mapping for composite vision models

* Add to pipeline tests

* Remove for flava (multimodal)

* safe pil import

* Add requirements for pipeline run

* Account for super slow efficientnet

* Review comments

* Fix tests

* Swap order of kwargs

* Use build_pipeline_init_args

* Add back FE pipeline for Vilt

* Include image_processor_kwargs in docstring

* Mark test as flaky

* Update TODO

* Update tests/pipelines/test_pipelines_image_feature_extraction.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add license header

---------

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
amyeroberts 2024-02-05 14:50:07 +00:00 committed by GitHub
parent 7addc9346c
commit ba3264b4e8
60 changed files with 387 additions and 53 deletions

View File

@ -469,6 +469,12 @@ Pipelines available for multimodal tasks include the following.
- __call__
- all
### ImageFeatureExtractionPipeline
[[autodoc]] ImageFeatureExtractionPipeline
- __call__
- all
### ImageToTextPipeline
[[autodoc]] ImageToTextPipeline

View File

@ -25,7 +25,7 @@ Recognition, Masked Language Modeling, Sentiment Analysis, Feature Extraction
There are two categories of pipeline abstractions:
- [`pipeline`] is the most powerful object, encapsulating all other pipelines.
- Task-specific pipelines are available for [audio](#audio), [computer vision](#computer-vision), [natural language processing](#natural-language-processing), and [multimodal](#multimodal) tasks.
## The pipeline abstraction
@ -477,6 +477,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
- __call__
- all
### ImageFeatureExtractionPipeline
[[autodoc]] ImageFeatureExtractionPipeline
- __call__
- all
### ImageToTextPipeline
[[autodoc]] ImageToTextPipeline

View File

@ -435,7 +435,7 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all
## Multimodal
Pipelines available for multimodal tasks include the following.
@ -451,6 +451,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all
### ImageFeatureExtractionPipeline
[[autodoc]] ImageFeatureExtractionPipeline
- __call__
- all
### ImageToTextPipeline
[[autodoc]] ImageToTextPipeline

View File

@ -973,6 +973,7 @@ _import_structure = {
"FeatureExtractionPipeline",
"FillMaskPipeline",
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageSegmentationPipeline",
"ImageToImagePipeline",
"ImageToTextPipeline",
@ -5709,6 +5710,7 @@ if TYPE_CHECKING:
FeatureExtractionPipeline,
FillMaskPipeline,
ImageClassificationPipeline,
ImageFeatureExtractionPipeline,
ImageSegmentationPipeline,
ImageToImagePipeline,
ImageToTextPipeline,

View File

@ -66,6 +66,7 @@ from .document_question_answering import DocumentQuestionAnsweringPipeline
from .feature_extraction import FeatureExtractionPipeline
from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .image_feature_extraction import ImageFeatureExtractionPipeline
from .image_segmentation import ImageSegmentationPipeline
from .image_to_image import ImageToImagePipeline
from .image_to_text import ImageToTextPipeline
@ -362,6 +363,18 @@ SUPPORTED_TASKS = {
},
"type": "image",
},
"image-feature-extraction": {
"impl": ImageFeatureExtractionPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
"default": {
"model": {
"pt": ("google/vit-base-patch16-224", "29e7a1e183"),
"tf": ("google/vit-base-patch16-224", "29e7a1e183"),
}
},
"type": "image",
},
"image-segmentation": {
"impl": ImageSegmentationPipeline,
"tf": (),
@ -500,6 +513,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
- `"feature-extraction"`
- `"fill-mask"`
- `"image-classification"`
- `"image-feature-extraction"`
- `"image-segmentation"`
- `"image-to-text"`
- `"image-to-image"`
@ -586,6 +600,7 @@ def pipeline(
- `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
- `"fill-mask"`: will return a [`FillMaskPipeline`]:.
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
- `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
- `"image-to-text"`: will return a [`ImageToTextPipeline`].

View File

@ -14,7 +14,7 @@ from .base import GenericTensor, Pipeline, build_pipeline_init_args
)
class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
transformer, which can be used as features in downstream tasks.
Example:

View File

@ -0,0 +1,92 @@
from typing import Dict
from ..utils import add_end_docstrings, is_vision_available
from .base import GenericTensor, Pipeline, build_pipeline_init_args
if is_vision_available():
from ..image_utils import load_image
@add_end_docstrings(
build_pipeline_init_args(has_image_processor=True),
"""
image_processor_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the image processor, e.g.
`{"size": {"height": 100, "width": 100}}`.
""",
)
class ImageFeatureExtractionPipeline(Pipeline):
"""
Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
transformer, which can be used as features in downstream tasks.
Example:
```python
>>> from transformers import pipeline
>>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
>>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
>>> result.shape # This is a tensor of shape [1, sequence_length, hidden_dimension] representing the input image.
torch.Size([1, 197, 768])
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
`"image-feature-extraction"`.
All vision models may be used for this pipeline. See a list of all models, including community-contributed models, on
[huggingface.co/models](https://huggingface.co/models).
"""
def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, **kwargs):
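# The three dicts returned below are routed, in order, to preprocess, _forward and
# postprocess; this pipeline passes no extra kwargs to _forward.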
preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
postprocess_params = {"return_tensors": return_tensors} if return_tensors is not None else {}
if "timeout" in kwargs:
preprocess_params["timeout"] = kwargs["timeout"]
return preprocess_params, {}, postprocess_params
def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
image = load_image(image, timeout=timeout)
model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
return model_inputs
def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs
def postprocess(self, model_outputs, return_tensors=False):
# [0] is the first available tensor, logits or last_hidden_state.
if return_tensors:
return model_outputs[0]
if self.framework == "pt":
return model_outputs[0].tolist()
elif self.framework == "tf":
return model_outputs[0].numpy().tolist()
def __call__(self, *args, **kwargs):
"""
Extract the features of the input(s).
Args:
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
The pipeline handles three types of images:
- A string containing an HTTP link pointing to an image
- A string containing a local path to an image
- An image loaded in PIL directly
The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
same format: all as HTTP links, all as local paths, or all as PIL images.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
the call may block forever.
Return:
A nested list of `float`: The features computed by the model.
"""
return super().__call__(*args, **kwargs)
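As a concrete illustration of the return contract above, here is a short hedged sketch: the parrots URL is the one from the class docstring, and the mean-pooling step is an illustrative choice on our side, since the pipeline returns raw hidden states without any pooling.

```python
import numpy as np
from transformers import pipeline

extractor = pipeline(task="image-feature-extraction", framework="pt")
url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"

# Default output: nested Python lists of shape [1, sequence_length, hidden_dimension].
features = extractor(url)
tokens = np.array(features[0])  # (sequence_length, hidden_dimension)

# Mean-pool over the sequence axis to get one embedding vector per image
# (a pooling choice of ours, not something the pipeline does for you).
embedding = tokens.mean(axis=0)

# With return_tensors=True, the framework tensor is returned directly.
tensor = extractor(url, return_tensors=True)  # e.g. torch.Size([1, 197, 768]) for ViT-base
```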

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from ..utils import (

View File

@ -242,7 +242,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
)
pipeline_model_mapping = (
{
"feature-extraction": BeitModel,
"image-feature-extraction": BeitModel,
"image-classification": BeitForImageClassification,
"image-segmentation": BeitForSemanticSegmentation,
}

View File

@ -162,7 +162,7 @@ class BitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (BitModel, BitForImageClassification, BitBackbone) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": BitModel, "image-classification": BitForImageClassification}
{"image-feature-extraction": BitModel, "image-classification": BitForImageClassification}
if is_torch_available()
else {}
)

View File

@ -429,7 +429,10 @@ class BlipModelTester:
class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (BlipModel,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": BlipModel, "image-to-text": BlipForConditionalGeneration}
{
"feature-extraction": BlipModel,
"image-to-text": BlipForConditionalGeneration,
}
if is_torch_available()
else {}
)

View File

@ -477,7 +477,9 @@ class CLIPModelTester:
@require_torch
class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": CLIPModel} if is_torch_available() else {}
pipeline_model_mapping = (
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
)
fx_compatible = True
test_head_masking = False
test_pruning = False

View File

@ -185,7 +185,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
{"image-feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
if is_torch_available()
else {}
)

View File

@ -172,7 +172,7 @@ class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
{"image-feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
if is_torch_available()
else {}
)

View File

@ -180,7 +180,7 @@ class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
{"image-feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
if is_torch_available()
else {}
)

View File

@ -151,7 +151,7 @@ class CvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
{"image-feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
if is_torch_available()
else {}
)

View File

@ -178,7 +178,7 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
)
pipeline_model_mapping = (
{
"feature-extraction": Data2VecVisionModel,
"image-feature-extraction": Data2VecVisionModel,
"image-classification": Data2VecVisionForImageClassification,
"image-segmentation": Data2VecVisionForSemanticSegmentation,
}

View File

@ -191,7 +191,7 @@ class DeformableDetrModelTester:
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
{"image-feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
if is_torch_available()
else {}
)

View File

@ -206,7 +206,7 @@ class DeiTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
)
pipeline_model_mapping = (
{
"feature-extraction": DeiTModel,
"image-feature-extraction": DeiTModel,
"image-classification": (DeiTForImageClassification, DeiTForImageClassificationWithTeacher),
}
if is_torch_available()

View File

@ -217,7 +217,7 @@ class DetaModelTester:
class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DetaModel, DetaForObjectDetection) if is_torchvision_available() else ()
pipeline_model_mapping = (
{"feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
{"image-feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
if is_torchvision_available()
else {}
)

View File

@ -182,7 +182,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
)
pipeline_model_mapping = (
{
"feature-extraction": DetrModel,
"image-feature-extraction": DetrModel,
"image-segmentation": DetrForSegmentation,
"object-detection": DetrForObjectDetection,
}

View File

@ -207,7 +207,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
{"image-feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
if is_torch_available()
else {}
)

View File

@ -217,7 +217,7 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
{"image-feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
if is_torch_available()
else {}
)

View File

@ -145,7 +145,7 @@ class DonutSwinModelTester:
@require_torch
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": DonutSwinModel} if is_torch_available() else {}
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
fx_compatible = True
test_pruning = False

View File

@ -163,7 +163,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_model_mapping = (
{
"depth-estimation": DPTForDepthEstimation,
"feature-extraction": DPTModel,
"image-feature-extraction": DPTModel,
"image-segmentation": DPTForSemanticSegmentation,
}
if is_torch_available()

View File

@ -190,7 +190,7 @@ class EfficientFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
)
pipeline_model_mapping = (
{
"feature-extraction": EfficientFormerModel,
"image-feature-extraction": EfficientFormerModel,
"image-classification": (
EfficientFormerForImageClassification,
EfficientFormerForImageClassificationWithTeacher,

View File

@ -130,7 +130,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
{"image-feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
if is_torch_available()
else {}
)
@ -216,6 +216,12 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
model = EfficientNetModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@is_pipeline_test
@require_vision
@slow
def test_pipeline_image_feature_extraction(self):
super().test_pipeline_image_feature_extraction()
@is_pipeline_test
@require_vision
@slow

View File

@ -238,7 +238,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
else ()
)
pipeline_model_mapping = (
{"feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
{"image-feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
if is_torch_available()
else {}
)

View File

@ -146,7 +146,9 @@ class GLPNModelTester:
class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else ()
pipeline_model_mapping = (
{"depth-estimation": GLPNForDepthEstimation, "feature-extraction": GLPNModel} if is_torch_available() else {}
{"depth-estimation": GLPNForDepthEstimation, "image-feature-extraction": GLPNModel}
if is_torch_available()
else {}
)
test_head_masking = False

View File

@ -271,7 +271,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
)
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
{"image-feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
if is_torch_available()
else {}
)

View File

@ -176,7 +176,7 @@ class LevitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
)
pipeline_model_mapping = (
{
"feature-extraction": LevitModel,
"image-feature-extraction": LevitModel,
"image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher),
}
if is_torch_available()

View File

@ -197,7 +197,7 @@ class Mask2FormerModelTester:
@require_torch
class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": Mask2FormerModel} if is_torch_available() else {}
pipeline_model_mapping = {"image-feature-extraction": Mask2FormerModel} if is_torch_available() else {}
is_encoder_decoder = False
test_pruning = False

View File

@ -197,7 +197,7 @@ class MaskFormerModelTester:
class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
{"image-feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
if is_torch_available()
else {}
)

View File

@ -31,7 +31,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import MgpstrForSceneTextRecognition
from transformers import MgpstrForSceneTextRecognition, MgpstrModel
if is_vision_available():
@ -118,7 +118,11 @@ class MgpstrModelTester:
@require_torch
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {}
pipeline_model_mapping = (
{"feature-extraction": MgpstrForSceneTextRecognition, "image-feature-extraction": MgpstrModel}
if is_torch_available()
else {}
)
fx_compatible = False
test_pruning = False

View File

@ -147,7 +147,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
all_model_classes = (MobileNetV1Model, MobileNetV1ForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
{"image-feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
if is_torch_available()
else {}
)

View File

@ -195,7 +195,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
)
pipeline_model_mapping = (
{
"feature-extraction": MobileNetV2Model,
"image-feature-extraction": MobileNetV2Model,
"image-classification": MobileNetV2ForImageClassification,
"image-segmentation": MobileNetV2ForSemanticSegmentation,
}

View File

@ -188,7 +188,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
)
pipeline_model_mapping = (
{
"feature-extraction": MobileViTModel,
"image-feature-extraction": MobileViTModel,
"image-classification": MobileViTForImageClassification,
"image-segmentation": MobileViTForSemanticSegmentation,
}

View File

@ -190,7 +190,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
pipeline_model_mapping = (
{
"feature-extraction": MobileViTV2Model,
"image-feature-extraction": MobileViTV2Model,
"image-classification": MobileViTV2ForImageClassification,
"image-segmentation": MobileViTV2ForSemanticSegmentation,
}

View File

@ -204,7 +204,7 @@ class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": NatModel, "image-classification": NatForImageClassification}
{"image-feature-extraction": NatModel, "image-classification": NatForImageClassification}
if is_torch_available()
else {}
)

View File

@ -433,7 +433,10 @@ class Owlv2ModelTester:
class Owlv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Owlv2Model,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": Owlv2Model, "zero-shot-object-detection": Owlv2ForObjectDetection}
{
"feature-extraction": Owlv2Model,
"zero-shot-object-detection": Owlv2ForObjectDetection,
}
if is_torch_available()
else {}
)

View File

@ -428,7 +428,10 @@ class OwlViTModelTester:
class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (OwlViTModel,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": OwlViTModel, "zero-shot-object-detection": OwlViTForObjectDetection}
{
"feature-extraction": OwlViTModel,
"zero-shot-object-detection": OwlViTForObjectDetection,
}
if is_torch_available()
else {}
)

View File

@ -124,7 +124,7 @@ class PoolFormerModelTester:
class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (PoolFormerModel, PoolFormerForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": PoolFormerModel, "image-classification": PoolFormerForImageClassification}
{"image-feature-extraction": PoolFormerModel, "image-classification": PoolFormerForImageClassification}
if is_torch_available()
else {}
)

View File

@ -158,7 +158,7 @@ def prepare_img():
class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (PvtModel, PvtForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": PvtModel, "image-classification": PvtForImageClassification}
{"image-feature-extraction": PvtModel, "image-classification": PvtForImageClassification}
if is_torch_available()
else {}
)

View File

@ -126,7 +126,7 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
{"image-feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
if is_torch_available()
else {}
)

View File

@ -170,7 +170,7 @@ class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
{"image-feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
if is_torch_available()
else {}
)

View File

@ -171,7 +171,7 @@ class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
)
pipeline_model_mapping = (
{
"feature-extraction": SegformerModel,
"image-feature-extraction": SegformerModel,
"image-classification": SegformerForImageClassification,
"image-segmentation": SegformerForSemanticSegmentation,
}

View File

@ -139,7 +139,7 @@ class SwiftFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
all_model_classes = (SwiftFormerModel, SwiftFormerForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": SwiftFormerModel, "image-classification": SwiftFormerForImageClassification}
{"image-feature-extraction": SwiftFormerModel, "image-classification": SwiftFormerForImageClassification}
if is_torch_available()
else {}
)

View File

@ -232,7 +232,7 @@ class SwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": SwinModel, "image-classification": SwinForImageClassification}
{"image-feature-extraction": SwinModel, "image-classification": SwinForImageClassification}
if is_torch_available()
else {}
)

View File

@ -162,7 +162,7 @@ class Swin2SRModelTester:
class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Swin2SRModel, Swin2SRForImageSuperResolution) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": Swin2SRModel, "image-to-image": Swin2SRForImageSuperResolution}
{"image-feature-extraction": Swin2SRModel, "image-to-image": Swin2SRForImageSuperResolution}
if is_torch_available()
else {}
)

View File

@ -217,7 +217,7 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
{"image-feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
if is_torch_available()
else {}
)

View File

@ -200,7 +200,7 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
else ()
)
pipeline_model_mapping = (
{"feature-extraction": TableTransformerModel, "object-detection": TableTransformerForObjectDetection}
{"image-feature-extraction": TableTransformerModel, "object-detection": TableTransformerForObjectDetection}
if is_torch_available()
else {}
)

View File

@ -228,7 +228,7 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ViltModel, "visual-question-answering": ViltForQuestionAnswering}
{"image-feature-extraction": ViltModel, "visual-question-answering": ViltForQuestionAnswering}
if is_torch_available()
else {}
)

View File

@ -193,7 +193,7 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ViTModel, "image-classification": ViTForImageClassification}
{"image-feature-extraction": ViTModel, "image-classification": ViTForImageClassification}
if is_torch_available()
else {}
)

View File

@ -156,7 +156,7 @@ class ViTHybridModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
all_model_classes = (ViTHybridModel, ViTHybridForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": ViTHybridModel, "image-classification": ViTHybridForImageClassification}
{"image-feature-extraction": ViTHybridModel, "image-classification": ViTHybridForImageClassification}
if is_torch_available()
else {}
)

View File

@ -164,7 +164,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
all_model_classes = (ViTMAEModel, ViTMAEForPreTraining) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": ViTMAEModel} if is_torch_available() else {}
pipeline_model_mapping = {"image-feature-extraction": ViTMAEModel} if is_torch_available() else {}
test_pruning = False
test_torchscript = False

View File

@ -152,7 +152,7 @@ class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (ViTMSNModel, ViTMSNForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": ViTMSNModel, "image-classification": ViTMSNForImageClassification}
{"image-feature-extraction": ViTMSNModel, "image-classification": ViTMSNForImageClassification}
if is_torch_available()
else {}
)

View File

@ -168,7 +168,9 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (YolosModel, YolosForObjectDetection) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": YolosModel, "object-detection": YolosForObjectDetection} if is_torch_available() else {}
{"image-feature-extraction": YolosModel, "object-detection": YolosForObjectDetection}
if is_torch_available()
else {}
)
test_pruning = False

View File

@ -0,0 +1,157 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import pytest
from transformers import (
MODEL_MAPPING,
TF_MODEL_MAPPING,
TOKENIZER_MAPPING,
ImageFeatureExtractionPipeline,
is_tf_available,
is_torch_available,
is_vision_available,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
if is_vision_available():
from PIL import Image
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@is_pipeline_test
class ImageFeatureExtractionPipelineTests(unittest.TestCase):
model_mapping = MODEL_MAPPING
tf_model_mapping = TF_MODEL_MAPPING
@require_torch
def test_small_model_pt(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
)
img = prepare_img()
outputs = feature_extractor(img)
self.assertEqual(
nested_simplify(outputs[0][0]),
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
@require_tf
def test_small_model_tf(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
)
img = prepare_img()
outputs = feature_extractor(img)
self.assertEqual(
nested_simplify(outputs[0][0]),
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
@require_torch
def test_image_processing_small_model_pt(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
)
# test with image processor parameters
image_processor_kwargs = {"size": {"height": 300, "width": 300}}
img = prepare_img()
with pytest.raises(ValueError):
# Image doesn't match model input size
feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
image_processor_kwargs = {"image_mean": [0, 0, 0], "image_std": [1, 1, 1]}
img = prepare_img()
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
@require_tf
def test_image_processing_small_model_tf(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
)
# test with image processor parameters
image_processor_kwargs = {"size": {"height": 300, "width": 300}}
img = prepare_img()
with pytest.raises(ValueError):
# Image doesn't match model input size
feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
image_processor_kwargs = {"image_mean": [0, 0, 0], "image_std": [1, 1, 1]}
img = prepare_img()
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
@require_torch
def test_return_tensors_pt(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
)
img = prepare_img()
outputs = feature_extractor(img, return_tensors=True)
self.assertTrue(torch.is_tensor(outputs))
@require_tf
def test_return_tensors_tf(self):
feature_extractor = pipeline(
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
)
img = prepare_img()
outputs = feature_extractor(img, return_tensors=True)
self.assertTrue(tf.is_tensor(outputs))
def get_test_pipeline(self, model, tokenizer, processor):
if processor is None:
self.skipTest("No image processor")
elif type(model.config) in TOKENIZER_MAPPING:
self.skipTest("This is a bimodal model, we need to find a more consistent way to switch on those models.")
elif model.config.is_encoder_decoder:
self.skipTest(
"""encoder_decoder models are trickier for this pipeline.
Do we want encoder + decoder inputs to get some featues?
Do we want encoder only features ?
For now ignore those.
"""
)
feature_extractor = ImageFeatureExtractionPipeline(model=model, image_processor=processor)
img = prepare_img()
return feature_extractor, [img, img]
def run_pipeline_test(self, feature_extractor, examples):
imgs = examples
outputs = feature_extractor(imgs[0])
self.assertEqual(len(outputs), 1)
outputs = feature_extractor(imgs)
self.assertEqual(len(outputs), 2)
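The `image_processor_kwargs` override exercised by the tests above doubles as the end-user knob for tweaking preprocessing at call time. A hedged sketch using the same tiny testing checkpoint and fixture image (run from the repository root so the fixture path resolves):

```python
from transformers import pipeline

extractor = pipeline(
    task="image-feature-extraction",
    model="hf-internal-testing/tiny-random-vit",
    framework="pt",
)

# Override normalization at call time, without touching the saved image
# processor config; mirrors test_image_processing_small_model_pt above.
features = extractor(
    "./tests/fixtures/tests_samples/COCO/000000039769.png",
    image_processor_kwargs={"image_mean": [0, 0, 0], "image_std": [1, 1, 1]},
)
```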

View File

@ -39,6 +39,7 @@ from .pipelines.test_pipelines_document_question_answering import DocumentQuesti
from .pipelines.test_pipelines_feature_extraction import FeatureExtractionPipelineTests
from .pipelines.test_pipelines_fill_mask import FillMaskPipelineTests
from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
from .pipelines.test_pipelines_image_feature_extraction import ImageFeatureExtractionPipelineTests
from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
@ -70,6 +71,7 @@ pipeline_test_mapping = {
"feature-extraction": {"test": FeatureExtractionPipelineTests},
"fill-mask": {"test": FillMaskPipelineTests},
"image-classification": {"test": ImageClassificationPipelineTests},
"image-feature-extraction": {"test": ImageFeatureExtractionPipelineTests},
"image-segmentation": {"test": ImageSegmentationPipelineTests},
"image-to-image": {"test": ImageToImagePipelineTests},
"image-to-text": {"test": ImageToTextPipelineTests},
@ -374,6 +376,13 @@ class PipelineTesterMixin:
def test_pipeline_image_to_text(self):
self.run_task_tests(task="image-to-text")
@is_pipeline_test
@require_timm
@require_vision
@require_torch
def test_pipeline_image_feature_extraction(self):
self.run_task_tests(task="image-feature-extraction")
@unittest.skip(reason="`run_pipeline_test` is currently not implemented.")
@is_pipeline_test
@require_vision

View File

@ -324,6 +324,7 @@ OBJECTS_TO_IGNORE = [
"IdeficsConfig",
"IdeficsProcessor",
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageGPTConfig",
"ImageSegmentationPipeline",
"ImageToImagePipeline",