mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-19 20:48:22 +06:00
Image Feature Extraction pipeline (#28216)
* Draft pipeline * Fixup * Fix docstrings * Update doctest * Update pipeline_model_mapping * Update docstring * Update tests * Update src/transformers/pipelines/image_feature_extraction.py Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> * Fix docstrings - review comments * Remove pipeline mapping for composite vision models * Add to pipeline tests * Remove for flava (multimodal) * safe pil import * Add requirements for pipeline run * Account for super slow efficientnet * Review comments * Fix tests * Swap order of kwargs * Use build_pipeline_init_args * Add back FE pipeline for Vilt * Include image_processor_kwargs in docstring * Mark test as flaky * Update TODO * Update tests/pipelines/test_pipelines_image_feature_extraction.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add license header --------- Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
7addc9346c
commit
ba3264b4e8
@ -469,6 +469,12 @@ Pipelines available for multimodal tasks include the following.
|
|||||||
- __call__
|
- __call__
|
||||||
- all
|
- all
|
||||||
|
|
||||||
|
### ImageFeatureExtractionPipeline
|
||||||
|
|
||||||
|
[[autodoc]] ImageFeatureExtractionPipeline
|
||||||
|
- __call__
|
||||||
|
- all
|
||||||
|
|
||||||
### ImageToTextPipeline
|
### ImageToTextPipeline
|
||||||
|
|
||||||
[[autodoc]] ImageToTextPipeline
|
[[autodoc]] ImageToTextPipeline
|
||||||
|
@ -477,6 +477,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
|
|||||||
- __call__
|
- __call__
|
||||||
- all
|
- all
|
||||||
|
|
||||||
|
### ImageFeatureExtractionPipeline
|
||||||
|
|
||||||
|
[[autodoc]] ImageFeatureExtractionPipeline
|
||||||
|
- __call__
|
||||||
|
- all
|
||||||
|
|
||||||
### ImageToTextPipeline
|
### ImageToTextPipeline
|
||||||
|
|
||||||
[[autodoc]] ImageToTextPipeline
|
[[autodoc]] ImageToTextPipeline
|
||||||
|
@ -451,6 +451,12 @@ See [`TokenClassificationPipeline`] for all details.
|
|||||||
- __call__
|
- __call__
|
||||||
- all
|
- all
|
||||||
|
|
||||||
|
### ImageFeatureExtractionPipeline
|
||||||
|
|
||||||
|
[[autodoc]] ImageFeatureExtractionPipeline
|
||||||
|
- __call__
|
||||||
|
- all
|
||||||
|
|
||||||
### ImageToTextPipeline
|
### ImageToTextPipeline
|
||||||
|
|
||||||
[[autodoc]] ImageToTextPipeline
|
[[autodoc]] ImageToTextPipeline
|
||||||
|
@ -973,6 +973,7 @@ _import_structure = {
|
|||||||
"FeatureExtractionPipeline",
|
"FeatureExtractionPipeline",
|
||||||
"FillMaskPipeline",
|
"FillMaskPipeline",
|
||||||
"ImageClassificationPipeline",
|
"ImageClassificationPipeline",
|
||||||
|
"ImageFeatureExtractionPipeline",
|
||||||
"ImageSegmentationPipeline",
|
"ImageSegmentationPipeline",
|
||||||
"ImageToImagePipeline",
|
"ImageToImagePipeline",
|
||||||
"ImageToTextPipeline",
|
"ImageToTextPipeline",
|
||||||
@ -5709,6 +5710,7 @@ if TYPE_CHECKING:
|
|||||||
FeatureExtractionPipeline,
|
FeatureExtractionPipeline,
|
||||||
FillMaskPipeline,
|
FillMaskPipeline,
|
||||||
ImageClassificationPipeline,
|
ImageClassificationPipeline,
|
||||||
|
ImageFeatureExtractionPipeline,
|
||||||
ImageSegmentationPipeline,
|
ImageSegmentationPipeline,
|
||||||
ImageToImagePipeline,
|
ImageToImagePipeline,
|
||||||
ImageToTextPipeline,
|
ImageToTextPipeline,
|
||||||
|
@ -66,6 +66,7 @@ from .document_question_answering import DocumentQuestionAnsweringPipeline
|
|||||||
from .feature_extraction import FeatureExtractionPipeline
|
from .feature_extraction import FeatureExtractionPipeline
|
||||||
from .fill_mask import FillMaskPipeline
|
from .fill_mask import FillMaskPipeline
|
||||||
from .image_classification import ImageClassificationPipeline
|
from .image_classification import ImageClassificationPipeline
|
||||||
|
from .image_feature_extraction import ImageFeatureExtractionPipeline
|
||||||
from .image_segmentation import ImageSegmentationPipeline
|
from .image_segmentation import ImageSegmentationPipeline
|
||||||
from .image_to_image import ImageToImagePipeline
|
from .image_to_image import ImageToImagePipeline
|
||||||
from .image_to_text import ImageToTextPipeline
|
from .image_to_text import ImageToTextPipeline
|
||||||
@ -362,6 +363,18 @@ SUPPORTED_TASKS = {
|
|||||||
},
|
},
|
||||||
"type": "image",
|
"type": "image",
|
||||||
},
|
},
|
||||||
|
"image-feature-extraction": {
|
||||||
|
"impl": ImageFeatureExtractionPipeline,
|
||||||
|
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||||
|
"pt": (AutoModel,) if is_torch_available() else (),
|
||||||
|
"default": {
|
||||||
|
"model": {
|
||||||
|
"pt": ("google/vit-base-patch16-224", "29e7a1e183"),
|
||||||
|
"tf": ("google/vit-base-patch16-224", "29e7a1e183"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "image",
|
||||||
|
},
|
||||||
"image-segmentation": {
|
"image-segmentation": {
|
||||||
"impl": ImageSegmentationPipeline,
|
"impl": ImageSegmentationPipeline,
|
||||||
"tf": (),
|
"tf": (),
|
||||||
@ -500,6 +513,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
|
|||||||
- `"feature-extraction"`
|
- `"feature-extraction"`
|
||||||
- `"fill-mask"`
|
- `"fill-mask"`
|
||||||
- `"image-classification"`
|
- `"image-classification"`
|
||||||
|
- `"image-feature-extraction"`
|
||||||
- `"image-segmentation"`
|
- `"image-segmentation"`
|
||||||
- `"image-to-text"`
|
- `"image-to-text"`
|
||||||
- `"image-to-image"`
|
- `"image-to-image"`
|
||||||
@ -586,6 +600,7 @@ def pipeline(
|
|||||||
- `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
|
- `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
|
||||||
- `"fill-mask"`: will return a [`FillMaskPipeline`]:.
|
- `"fill-mask"`: will return a [`FillMaskPipeline`]:.
|
||||||
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
|
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
|
||||||
|
- `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
|
||||||
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
|
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
|
||||||
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
|
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
|
||||||
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
|
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
|
||||||
|
@ -14,7 +14,7 @@ from .base import GenericTensor, Pipeline, build_pipeline_init_args
|
|||||||
)
|
)
|
||||||
class FeatureExtractionPipeline(Pipeline):
|
class FeatureExtractionPipeline(Pipeline):
|
||||||
"""
|
"""
|
||||||
Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
|
Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
|
||||||
transformer, which can be used as features in downstream tasks.
|
transformer, which can be used as features in downstream tasks.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
92
src/transformers/pipelines/image_feature_extraction.py
Normal file
92
src/transformers/pipelines/image_feature_extraction.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from ..utils import add_end_docstrings, is_vision_available
|
||||||
|
from .base import GenericTensor, Pipeline, build_pipeline_init_args
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from ..image_utils import load_image
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(
|
||||||
|
build_pipeline_init_args(has_image_processor=True),
|
||||||
|
"""
|
||||||
|
image_processor_kwargs (`dict`, *optional*):
|
||||||
|
Additional dictionary of keyword arguments passed along to the image processor e.g.
|
||||||
|
{"size": {"height": 100, "width": 100}}
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
class ImageFeatureExtractionPipeline(Pipeline):
|
||||||
|
"""
|
||||||
|
Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
|
||||||
|
transformer, which can be used as features in downstream tasks.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import pipeline
|
||||||
|
|
||||||
|
>>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
|
||||||
|
>>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
|
||||||
|
>>> result.shape # This is a tensor of shape [1, sequence_lenth, hidden_dimension] representing the input image.
|
||||||
|
torch.Size([1, 197, 768])
|
||||||
|
```
|
||||||
|
|
||||||
|
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||||
|
|
||||||
|
This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
|
||||||
|
`"image-feature-extraction"`.
|
||||||
|
|
||||||
|
All vision models may be used for this pipeline. See a list of all models, including community-contributed models on
|
||||||
|
[huggingface.co/models](https://huggingface.co/models).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, **kwargs):
|
||||||
|
preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
|
||||||
|
postprocess_params = {"return_tensors": return_tensors} if return_tensors is not None else {}
|
||||||
|
|
||||||
|
if "timeout" in kwargs:
|
||||||
|
preprocess_params["timeout"] = kwargs["timeout"]
|
||||||
|
|
||||||
|
return preprocess_params, {}, postprocess_params
|
||||||
|
|
||||||
|
def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
|
||||||
|
image = load_image(image, timeout=timeout)
|
||||||
|
model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def _forward(self, model_inputs):
|
||||||
|
model_outputs = self.model(**model_inputs)
|
||||||
|
return model_outputs
|
||||||
|
|
||||||
|
def postprocess(self, model_outputs, return_tensors=False):
|
||||||
|
# [0] is the first available tensor, logits or last_hidden_state.
|
||||||
|
if return_tensors:
|
||||||
|
return model_outputs[0]
|
||||||
|
if self.framework == "pt":
|
||||||
|
return model_outputs[0].tolist()
|
||||||
|
elif self.framework == "tf":
|
||||||
|
return model_outputs[0].numpy().tolist()
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Extract the features of the input(s).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
|
||||||
|
The pipeline handles three types of images:
|
||||||
|
|
||||||
|
- A string containing a http link pointing to an image
|
||||||
|
- A string containing a local path to an image
|
||||||
|
- An image loaded in PIL directly
|
||||||
|
|
||||||
|
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
|
||||||
|
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
|
||||||
|
images.
|
||||||
|
timeout (`float`, *optional*, defaults to None):
|
||||||
|
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
|
||||||
|
the call may block forever.
|
||||||
|
Return:
|
||||||
|
A nested list of `float`: The features computed by the model.
|
||||||
|
"""
|
||||||
|
return super().__call__(*args, **kwargs)
|
@ -1,3 +1,18 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
|
@ -242,7 +242,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": BeitModel,
|
"image-feature-extraction": BeitModel,
|
||||||
"image-classification": BeitForImageClassification,
|
"image-classification": BeitForImageClassification,
|
||||||
"image-segmentation": BeitForSemanticSegmentation,
|
"image-segmentation": BeitForSemanticSegmentation,
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ class BitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (BitModel, BitForImageClassification, BitBackbone) if is_torch_available() else ()
|
all_model_classes = (BitModel, BitForImageClassification, BitBackbone) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": BitModel, "image-classification": BitForImageClassification}
|
{"image-feature-extraction": BitModel, "image-classification": BitForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -429,7 +429,10 @@ class BlipModelTester:
|
|||||||
class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (BlipModel,) if is_torch_available() else ()
|
all_model_classes = (BlipModel,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": BlipModel, "image-to-text": BlipForConditionalGeneration}
|
{
|
||||||
|
"feature-extraction": BlipModel,
|
||||||
|
"image-to-text": BlipForConditionalGeneration,
|
||||||
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -477,7 +477,9 @@ class CLIPModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (CLIPModel,) if is_torch_available() else ()
|
all_model_classes = (CLIPModel,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = {"feature-extraction": CLIPModel} if is_torch_available() else {}
|
pipeline_model_mapping = (
|
||||||
|
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
|
||||||
|
)
|
||||||
fx_compatible = True
|
fx_compatible = True
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
@ -185,7 +185,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
|
{"image-feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -172,7 +172,7 @@ class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
|
{"image-feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -180,7 +180,7 @@ class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
|
{"image-feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -151,7 +151,7 @@ class CvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
|
all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
|
{"image-feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -178,7 +178,7 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": Data2VecVisionModel,
|
"image-feature-extraction": Data2VecVisionModel,
|
||||||
"image-classification": Data2VecVisionForImageClassification,
|
"image-classification": Data2VecVisionForImageClassification,
|
||||||
"image-segmentation": Data2VecVisionForSemanticSegmentation,
|
"image-segmentation": Data2VecVisionForSemanticSegmentation,
|
||||||
}
|
}
|
||||||
|
@ -191,7 +191,7 @@ class DeformableDetrModelTester:
|
|||||||
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_torch_available() else ()
|
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
|
{"image-feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -206,7 +206,7 @@ class DeiTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": DeiTModel,
|
"image-feature-extraction": DeiTModel,
|
||||||
"image-classification": (DeiTForImageClassification, DeiTForImageClassificationWithTeacher),
|
"image-classification": (DeiTForImageClassification, DeiTForImageClassificationWithTeacher),
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
@ -217,7 +217,7 @@ class DetaModelTester:
|
|||||||
class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (DetaModel, DetaForObjectDetection) if is_torchvision_available() else ()
|
all_model_classes = (DetaModel, DetaForObjectDetection) if is_torchvision_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
|
{"image-feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
|
||||||
if is_torchvision_available()
|
if is_torchvision_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -182,7 +182,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": DetrModel,
|
"image-feature-extraction": DetrModel,
|
||||||
"image-segmentation": DetrForSegmentation,
|
"image-segmentation": DetrForSegmentation,
|
||||||
"object-detection": DetrForObjectDetection,
|
"object-detection": DetrForObjectDetection,
|
||||||
}
|
}
|
||||||
|
@ -207,7 +207,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
|
{"image-feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -217,7 +217,7 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
|
{"image-feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -145,7 +145,7 @@ class DonutSwinModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
|
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = {"feature-extraction": DonutSwinModel} if is_torch_available() else {}
|
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
|
||||||
fx_compatible = True
|
fx_compatible = True
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
@ -163,7 +163,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"depth-estimation": DPTForDepthEstimation,
|
"depth-estimation": DPTForDepthEstimation,
|
||||||
"feature-extraction": DPTModel,
|
"image-feature-extraction": DPTModel,
|
||||||
"image-segmentation": DPTForSemanticSegmentation,
|
"image-segmentation": DPTForSemanticSegmentation,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
@ -190,7 +190,7 @@ class EfficientFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": EfficientFormerModel,
|
"image-feature-extraction": EfficientFormerModel,
|
||||||
"image-classification": (
|
"image-classification": (
|
||||||
EfficientFormerForImageClassification,
|
EfficientFormerForImageClassification,
|
||||||
EfficientFormerForImageClassificationWithTeacher,
|
EfficientFormerForImageClassificationWithTeacher,
|
||||||
|
@ -130,7 +130,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
|||||||
|
|
||||||
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
|
all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
|
{"image-feature-extraction": EfficientNetModel, "image-classification": EfficientNetForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
@ -216,6 +216,12 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
|||||||
model = EfficientNetModel.from_pretrained(model_name)
|
model = EfficientNetModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@is_pipeline_test
|
||||||
|
@require_vision
|
||||||
|
@slow
|
||||||
|
def test_pipeline_image_feature_extraction(self):
|
||||||
|
super().test_pipeline_image_feature_extraction()
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
@require_vision
|
@require_vision
|
||||||
@slow
|
@slow
|
||||||
|
@ -238,7 +238,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
|
{"image-feature-extraction": FocalNetModel, "image-classification": FocalNetForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -146,7 +146,9 @@ class GLPNModelTester:
|
|||||||
class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else ()
|
all_model_classes = (GLPNModel, GLPNForDepthEstimation) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"depth-estimation": GLPNForDepthEstimation, "feature-extraction": GLPNModel} if is_torch_available() else {}
|
{"depth-estimation": GLPNForDepthEstimation, "image-feature-extraction": GLPNModel}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
@ -271,7 +271,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
|
all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
|
{"image-feature-extraction": ImageGPTModel, "image-classification": ImageGPTForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -176,7 +176,7 @@ class LevitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": LevitModel,
|
"image-feature-extraction": LevitModel,
|
||||||
"image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher),
|
"image-classification": (LevitForImageClassification, LevitForImageClassificationWithTeacher),
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
|
@ -197,7 +197,7 @@ class Mask2FormerModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else ()
|
all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = {"feature-extraction": Mask2FormerModel} if is_torch_available() else {}
|
pipeline_model_mapping = {"image-feature-extraction": Mask2FormerModel} if is_torch_available() else {}
|
||||||
|
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
@ -197,7 +197,7 @@ class MaskFormerModelTester:
|
|||||||
class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
|
{"image-feature-extraction": MaskFormerModel, "image-segmentation": MaskFormerForInstanceSegmentation}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -31,7 +31,7 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import MgpstrForSceneTextRecognition
|
from transformers import MgpstrForSceneTextRecognition, MgpstrModel
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@ -118,7 +118,11 @@ class MgpstrModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
|
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = {"feature-extraction": MgpstrForSceneTextRecognition} if is_torch_available() else {}
|
pipeline_model_mapping = (
|
||||||
|
{"feature-extraction": MgpstrForSceneTextRecognition, "image-feature-extraction": MgpstrModel}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
@ -147,7 +147,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
|
|
||||||
all_model_classes = (MobileNetV1Model, MobileNetV1ForImageClassification) if is_torch_available() else ()
|
all_model_classes = (MobileNetV1Model, MobileNetV1ForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
|
{"image-feature-extraction": MobileNetV1Model, "image-classification": MobileNetV1ForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -195,7 +195,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": MobileNetV2Model,
|
"image-feature-extraction": MobileNetV2Model,
|
||||||
"image-classification": MobileNetV2ForImageClassification,
|
"image-classification": MobileNetV2ForImageClassification,
|
||||||
"image-segmentation": MobileNetV2ForSemanticSegmentation,
|
"image-segmentation": MobileNetV2ForSemanticSegmentation,
|
||||||
}
|
}
|
||||||
|
@ -188,7 +188,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": MobileViTModel,
|
"image-feature-extraction": MobileViTModel,
|
||||||
"image-classification": MobileViTForImageClassification,
|
"image-classification": MobileViTForImageClassification,
|
||||||
"image-segmentation": MobileViTForSemanticSegmentation,
|
"image-segmentation": MobileViTForSemanticSegmentation,
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
|
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": MobileViTV2Model,
|
"image-feature-extraction": MobileViTV2Model,
|
||||||
"image-classification": MobileViTV2ForImageClassification,
|
"image-classification": MobileViTV2ForImageClassification,
|
||||||
"image-segmentation": MobileViTV2ForSemanticSegmentation,
|
"image-segmentation": MobileViTV2ForSemanticSegmentation,
|
||||||
}
|
}
|
||||||
|
@ -204,7 +204,7 @@ class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": NatModel, "image-classification": NatForImageClassification}
|
{"image-feature-extraction": NatModel, "image-classification": NatForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -433,7 +433,10 @@ class Owlv2ModelTester:
|
|||||||
class Owlv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Owlv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Owlv2Model,) if is_torch_available() else ()
|
all_model_classes = (Owlv2Model,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": Owlv2Model, "zero-shot-object-detection": Owlv2ForObjectDetection}
|
{
|
||||||
|
"feature-extraction": Owlv2Model,
|
||||||
|
"zero-shot-object-detection": Owlv2ForObjectDetection,
|
||||||
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -428,7 +428,10 @@ class OwlViTModelTester:
|
|||||||
class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class OwlViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (OwlViTModel,) if is_torch_available() else ()
|
all_model_classes = (OwlViTModel,) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": OwlViTModel, "zero-shot-object-detection": OwlViTForObjectDetection}
|
{
|
||||||
|
"feature-extraction": OwlViTModel,
|
||||||
|
"zero-shot-object-detection": OwlViTForObjectDetection,
|
||||||
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -124,7 +124,7 @@ class PoolFormerModelTester:
|
|||||||
class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (PoolFormerModel, PoolFormerForImageClassification) if is_torch_available() else ()
|
all_model_classes = (PoolFormerModel, PoolFormerForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": PoolFormerModel, "image-classification": PoolFormerForImageClassification}
|
{"image-feature-extraction": PoolFormerModel, "image-classification": PoolFormerForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -158,7 +158,7 @@ def prepare_img():
|
|||||||
class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (PvtModel, PvtForImageClassification) if is_torch_available() else ()
|
all_model_classes = (PvtModel, PvtForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": PvtModel, "image-classification": PvtForImageClassification}
|
{"image-feature-extraction": PvtModel, "image-classification": PvtForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -126,7 +126,7 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
|
all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
|
{"image-feature-extraction": RegNetModel, "image-classification": RegNetForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -170,7 +170,7 @@ class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
|
{"image-feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -171,7 +171,7 @@ class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{
|
{
|
||||||
"feature-extraction": SegformerModel,
|
"image-feature-extraction": SegformerModel,
|
||||||
"image-classification": SegformerForImageClassification,
|
"image-classification": SegformerForImageClassification,
|
||||||
"image-segmentation": SegformerForSemanticSegmentation,
|
"image-segmentation": SegformerForSemanticSegmentation,
|
||||||
}
|
}
|
||||||
|
@ -139,7 +139,7 @@ class SwiftFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
|
|
||||||
all_model_classes = (SwiftFormerModel, SwiftFormerForImageClassification) if is_torch_available() else ()
|
all_model_classes = (SwiftFormerModel, SwiftFormerForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": SwiftFormerModel, "image-classification": SwiftFormerForImageClassification}
|
{"image-feature-extraction": SwiftFormerModel, "image-classification": SwiftFormerForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -232,7 +232,7 @@ class SwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": SwinModel, "image-classification": SwinForImageClassification}
|
{"image-feature-extraction": SwinModel, "image-classification": SwinForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -162,7 +162,7 @@ class Swin2SRModelTester:
|
|||||||
class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Swin2SRModel, Swin2SRForImageSuperResolution) if is_torch_available() else ()
|
all_model_classes = (Swin2SRModel, Swin2SRForImageSuperResolution) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": Swin2SRModel, "image-to-image": Swin2SRForImageSuperResolution}
|
{"image-feature-extraction": Swin2SRModel, "image-to-image": Swin2SRForImageSuperResolution}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -217,7 +217,7 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
|
{"image-feature-extraction": Swinv2Model, "image-classification": Swinv2ForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -200,7 +200,7 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": TableTransformerModel, "object-detection": TableTransformerForObjectDetection}
|
{"image-feature-extraction": TableTransformerModel, "object-detection": TableTransformerForObjectDetection}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -228,7 +228,7 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ViltModel, "visual-question-answering": ViltForQuestionAnswering}
|
{"image-feature-extraction": ViltModel, "visual-question-answering": ViltForQuestionAnswering}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -193,7 +193,7 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ViTModel, "image-classification": ViTForImageClassification}
|
{"image-feature-extraction": ViTModel, "image-classification": ViTForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -156,7 +156,7 @@ class ViTHybridModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||||||
|
|
||||||
all_model_classes = (ViTHybridModel, ViTHybridForImageClassification) if is_torch_available() else ()
|
all_model_classes = (ViTHybridModel, ViTHybridForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ViTHybridModel, "image-classification": ViTHybridForImageClassification}
|
{"image-feature-extraction": ViTHybridModel, "image-classification": ViTHybridForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -164,7 +164,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
all_model_classes = (ViTMAEModel, ViTMAEForPreTraining) if is_torch_available() else ()
|
all_model_classes = (ViTMAEModel, ViTMAEForPreTraining) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = {"feature-extraction": ViTMAEModel} if is_torch_available() else {}
|
pipeline_model_mapping = {"image-feature-extraction": ViTMAEModel} if is_torch_available() else {}
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
@ -152,7 +152,7 @@ class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (ViTMSNModel, ViTMSNForImageClassification) if is_torch_available() else ()
|
all_model_classes = (ViTMSNModel, ViTMSNForImageClassification) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": ViTMSNModel, "image-classification": ViTMSNForImageClassification}
|
{"image-feature-extraction": ViTMSNModel, "image-classification": ViTMSNForImageClassification}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
@ -168,7 +168,9 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (YolosModel, YolosForObjectDetection) if is_torch_available() else ()
|
all_model_classes = (YolosModel, YolosForObjectDetection) if is_torch_available() else ()
|
||||||
pipeline_model_mapping = (
|
pipeline_model_mapping = (
|
||||||
{"feature-extraction": YolosModel, "object-detection": YolosForObjectDetection} if is_torch_available() else {}
|
{"image-feature-extraction": YolosModel, "object-detection": YolosForObjectDetection}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
157
tests/pipelines/test_pipelines_image_feature_extraction.py
Normal file
157
tests/pipelines/test_pipelines_image_feature_extraction.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
MODEL_MAPPING,
|
||||||
|
TF_MODEL_MAPPING,
|
||||||
|
TOKENIZER_MAPPING,
|
||||||
|
ImageFeatureExtractionPipeline,
|
||||||
|
is_tf_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_vision_available,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
# We will verify our results on an image of cute cats
|
||||||
|
def prepare_img():
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
@is_pipeline_test
|
||||||
|
class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||||
|
model_mapping = MODEL_MAPPING
|
||||||
|
tf_model_mapping = TF_MODEL_MAPPING
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt(self):
|
||||||
|
feature_extractor = pipeline(
|
||||||
|
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
|
||||||
|
)
|
||||||
|
img = prepare_img()
|
||||||
|
outputs = feature_extractor(img)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs[0][0]),
|
||||||
|
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_small_model_tf(self):
|
||||||
|
feature_extractor = pipeline(
|
||||||
|
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
|
||||||
|
)
|
||||||
|
img = prepare_img()
|
||||||
|
outputs = feature_extractor(img)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs[0][0]),
|
||||||
|
[-1.417, -0.392, -1.264, -1.196, 1.648, 0.885, 0.56, -0.606, -1.175, 0.823, 1.912, 0.081, -0.053, 1.119, -0.062, -1.757, -0.571, 0.075, 0.959, 0.118, 1.201, -0.672, -0.498, 0.364, 0.937, -1.623, 0.228, 0.19, 1.697, -1.115, 0.583, -0.981]) # fmt: skip
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_image_processing_small_model_pt(self):
|
||||||
|
feature_extractor = pipeline(
|
||||||
|
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# test with image processor parameters
|
||||||
|
image_processor_kwargs = {"size": {"height": 300, "width": 300}}
|
||||||
|
img = prepare_img()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# Image doesn't match model input size
|
||||||
|
feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||||
|
|
||||||
|
image_processor_kwargs = {"image_mean": [0, 0, 0], "image_std": [1, 1, 1]}
|
||||||
|
img = prepare_img()
|
||||||
|
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||||
|
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_image_processing_small_model_tf(self):
|
||||||
|
feature_extractor = pipeline(
|
||||||
|
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
|
||||||
|
)
|
||||||
|
|
||||||
|
# test with image processor parameters
|
||||||
|
image_processor_kwargs = {"size": {"height": 300, "width": 300}}
|
||||||
|
img = prepare_img()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# Image doesn't match model input size
|
||||||
|
feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||||
|
|
||||||
|
image_processor_kwargs = {"image_mean": [0, 0, 0], "image_std": [1, 1, 1]}
|
||||||
|
img = prepare_img()
|
||||||
|
outputs = feature_extractor(img, image_processor_kwargs=image_processor_kwargs)
|
||||||
|
self.assertEqual(np.squeeze(outputs).shape, (226, 32))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_return_tensors_pt(self):
|
||||||
|
feature_extractor = pipeline(
|
||||||
|
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="pt"
|
||||||
|
)
|
||||||
|
img = prepare_img()
|
||||||
|
outputs = feature_extractor(img, return_tensors=True)
|
||||||
|
self.assertTrue(torch.is_tensor(outputs))
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_return_tensors_tf(self):
|
||||||
|
feature_extractor = pipeline(
|
||||||
|
task="image-feature-extraction", model="hf-internal-testing/tiny-random-vit", framework="tf"
|
||||||
|
)
|
||||||
|
img = prepare_img()
|
||||||
|
outputs = feature_extractor(img, return_tensors=True)
|
||||||
|
self.assertTrue(tf.is_tensor(outputs))
|
||||||
|
|
||||||
|
def get_test_pipeline(self, model, tokenizer, processor):
|
||||||
|
if processor is None:
|
||||||
|
self.skipTest("No image processor")
|
||||||
|
|
||||||
|
elif type(model.config) in TOKENIZER_MAPPING:
|
||||||
|
self.skipTest("This is a bimodal model, we need to find a more consistent way to switch on those models.")
|
||||||
|
|
||||||
|
elif model.config.is_encoder_decoder:
|
||||||
|
self.skipTest(
|
||||||
|
"""encoder_decoder models are trickier for this pipeline.
|
||||||
|
Do we want encoder + decoder inputs to get some featues?
|
||||||
|
Do we want encoder only features ?
|
||||||
|
For now ignore those.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_extractor = ImageFeatureExtractionPipeline(model=model, image_processor=processor)
|
||||||
|
img = prepare_img()
|
||||||
|
return feature_extractor, [img, img]
|
||||||
|
|
||||||
|
def run_pipeline_test(self, feature_extractor, examples):
|
||||||
|
imgs = examples
|
||||||
|
outputs = feature_extractor(imgs[0])
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), 1)
|
||||||
|
|
||||||
|
outputs = feature_extractor(imgs)
|
||||||
|
self.assertEqual(len(outputs), 2)
|
@ -39,6 +39,7 @@ from .pipelines.test_pipelines_document_question_answering import DocumentQuesti
|
|||||||
from .pipelines.test_pipelines_feature_extraction import FeatureExtractionPipelineTests
|
from .pipelines.test_pipelines_feature_extraction import FeatureExtractionPipelineTests
|
||||||
from .pipelines.test_pipelines_fill_mask import FillMaskPipelineTests
|
from .pipelines.test_pipelines_fill_mask import FillMaskPipelineTests
|
||||||
from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
|
from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests
|
||||||
|
from .pipelines.test_pipelines_image_feature_extraction import ImageFeatureExtractionPipelineTests
|
||||||
from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
|
from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests
|
||||||
from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
|
from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests
|
||||||
from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
|
from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests
|
||||||
@ -70,6 +71,7 @@ pipeline_test_mapping = {
|
|||||||
"feature-extraction": {"test": FeatureExtractionPipelineTests},
|
"feature-extraction": {"test": FeatureExtractionPipelineTests},
|
||||||
"fill-mask": {"test": FillMaskPipelineTests},
|
"fill-mask": {"test": FillMaskPipelineTests},
|
||||||
"image-classification": {"test": ImageClassificationPipelineTests},
|
"image-classification": {"test": ImageClassificationPipelineTests},
|
||||||
|
"image-feature-extraction": {"test": ImageFeatureExtractionPipelineTests},
|
||||||
"image-segmentation": {"test": ImageSegmentationPipelineTests},
|
"image-segmentation": {"test": ImageSegmentationPipelineTests},
|
||||||
"image-to-image": {"test": ImageToImagePipelineTests},
|
"image-to-image": {"test": ImageToImagePipelineTests},
|
||||||
"image-to-text": {"test": ImageToTextPipelineTests},
|
"image-to-text": {"test": ImageToTextPipelineTests},
|
||||||
@ -374,6 +376,13 @@ class PipelineTesterMixin:
|
|||||||
def test_pipeline_image_to_text(self):
|
def test_pipeline_image_to_text(self):
|
||||||
self.run_task_tests(task="image-to-text")
|
self.run_task_tests(task="image-to-text")
|
||||||
|
|
||||||
|
@is_pipeline_test
|
||||||
|
@require_timm
|
||||||
|
@require_vision
|
||||||
|
@require_torch
|
||||||
|
def test_pipeline_image_feature_extraction(self):
|
||||||
|
self.run_task_tests(task="image-feature-extraction")
|
||||||
|
|
||||||
@unittest.skip(reason="`run_pipeline_test` is currently not implemented.")
|
@unittest.skip(reason="`run_pipeline_test` is currently not implemented.")
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
@require_vision
|
@require_vision
|
||||||
|
@ -324,6 +324,7 @@ OBJECTS_TO_IGNORE = [
|
|||||||
"IdeficsConfig",
|
"IdeficsConfig",
|
||||||
"IdeficsProcessor",
|
"IdeficsProcessor",
|
||||||
"ImageClassificationPipeline",
|
"ImageClassificationPipeline",
|
||||||
|
"ImageFeatureExtractionPipeline",
|
||||||
"ImageGPTConfig",
|
"ImageGPTConfig",
|
||||||
"ImageSegmentationPipeline",
|
"ImageSegmentationPipeline",
|
||||||
"ImageToImagePipeline",
|
"ImageToImagePipeline",
|
||||||
|
Loading…
Reference in New Issue
Block a user