Make pipeline able to load processor (#32514)

* Refactor get_test_pipeline

* Fixup

* Fixing tests

* Add processor loading in tests

* Restructure processors loading

* Add processor to the pipeline

* Move model loading to the top of the test

* Update `get_test_pipeline`

* Fixup

* Add class-based flags for loading processors

* Change `is_pipeline_test_to_skip` signature

* Skip t5 failing test for slow tokenizer

* Fixup

* Fix copies for T5

* Fix typo

* Add try/except for tokenizer loading (kosmos-2 case)

* Fixup

* Llama does not fail for long generation

* Revert processor pass in text-generation test

* Fix docs

* Switch back to json file for image processors and feature extractors

* Add processor type check

* Remove except for tokenizers

* Fix docstring

* Fix empty lists for tests

* Fixup

* Fix load check

* Ensure we have non-empty test cases

* Update src/transformers/pipelines/__init__.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update src/transformers/pipelines/base.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Rework comment

* Better docs, add note about pipeline components

* Change warning to error raise

* Fixup

* Refine pipeline docs

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
Pavel Iakubovskii, 2024-10-09 16:46:11 +01:00, committed by GitHub
parent 4fb28703ad
commit 48461c0fe2
91 changed files with 1312 additions and 241 deletions

src/transformers/pipelines/__init__.py

@@ -28,7 +28,9 @@ from ..models.auto.configuration_auto import AutoConfig
 from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
 from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
 from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage
+from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor
 from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
+from ..processing_utils import ProcessorMixin
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import (
     CONFIG_NAME,
@@ -556,6 +558,7 @@ def pipeline(
     tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
     feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
     image_processor: Optional[Union[str, BaseImageProcessor]] = None,
+    processor: Optional[Union[str, ProcessorMixin]] = None,
     framework: Optional[str] = None,
     revision: Optional[str] = None,
     use_fast: bool = True,
@@ -571,11 +574,19 @@
     """
     Utility factory method to build a [`Pipeline`].

-    Pipelines are made of:
-
-        - A [tokenizer](tokenizer) in charge of mapping raw textual input to token.
-        - A [model](model) to make predictions from the inputs.
-        - Some (optional) post processing for enhancing model's output.
+    A pipeline consists of:
+
+        - One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
+          [image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
+        - A [model](model) that generates predictions from the inputs.
+        - Optional post-processing steps to refine the model's output, which can also be handled by processors.
+
+    <Tip>
+
+    While there are such optional arguments as `tokenizer`, `feature_extractor`, `image_processor`, and `processor`,
+    they shouldn't be specified all at once. If these components are not provided, `pipeline` will try to load
+    required ones automatically. In case you want to provide these components explicitly, please refer to a
+    specific pipeline in order to get more details regarding what components are required.
+
+    </Tip>

     Args:
         task (`str`):
@@ -644,6 +655,25 @@
             `model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
             is a string). However, if `config` is also not given or not a string, then the default feature extractor
             for the given `task` will be loaded.
+        image_processor (`str` or [`BaseImageProcessor`], *optional*):
+            The image processor that will be used by the pipeline to preprocess images for the model. This can be a
+            model identifier or an actual image processor inheriting from [`BaseImageProcessor`].
+
+            Image processors are used for Vision models and multi-modal models that require image inputs. Multi-modal
+            models will also require a tokenizer to be passed.
+
+            If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
+            `model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
+            a string).
+        processor (`str` or [`ProcessorMixin`], *optional*):
+            The processor that will be used by the pipeline to preprocess data for the model. This can be a model
+            identifier or an actual processor inheriting from [`ProcessorMixin`].
+
+            Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
+            requires both text and image inputs.
+
+            If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
+            is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
         framework (`str`, *optional*):
             The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
             installed.
@@ -905,13 +935,17 @@
     model_config = model.config
     hub_kwargs["_commit_hash"] = model.config._commit_hash
-    load_tokenizer = (
-        type(model_config) in TOKENIZER_MAPPING
-        or model_config.tokenizer_class is not None
-        or isinstance(tokenizer, str)
-    )
+    load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
     load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
+    load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None
+
+    # Check that pipeline class required loading
+    load_tokenizer = load_tokenizer and pipeline_class._load_tokenizer
+    load_feature_extractor = load_feature_extractor and pipeline_class._load_feature_extractor
+    load_image_processor = load_image_processor and pipeline_class._load_image_processor
+    load_processor = load_processor and pipeline_class._load_processor

     # If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while
     # `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some
@@ -1074,6 +1108,31 @@
         if not is_pyctcdecode_available():
             logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode")

+    if load_processor:
+        # Try to infer processor from model or config name (if provided as str)
+        if processor is None:
+            if isinstance(model_name, str):
+                processor = model_name
+            elif isinstance(config, str):
+                processor = config
+            else:
+                # Impossible to guess what is the right processor here
+                raise Exception(
+                    "Impossible to guess which processor to use. "
+                    "Please provide a processor instance or a path/identifier "
+                    "to a processor."
+                )
+
+        # Instantiate processor if needed
+        if isinstance(processor, (str, tuple)):
+            processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+            if not isinstance(processor, ProcessorMixin):
+                raise TypeError(
+                    "Processor was loaded, but it is not an instance of `ProcessorMixin`. "
+                    f"Got type `{type(processor)}` instead. Please check that you specified "
+                    "correct pipeline task for the model and model has processor implemented and saved."
+                )
+
     if task == "translation" and model.config.task_specific_params:
         for key in model.config.task_specific_params:
             if key.startswith("translation"):
@@ -1099,4 +1158,7 @@
     if device is not None:
         kwargs["device"] = device

+    if processor is not None:
+        kwargs["processor"] = processor
+
     return pipeline_class(model=model, framework=framework, task=task, **kwargs)
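As an editorial aside (not part of the commit): a minimal sketch of how the new `processor` argument might be used; the checkpoint name and task below are placeholders, and whether a pipeline actually requires a processor depends on the specific pipeline class.

```python
from transformers import AutoProcessor, pipeline

# Hypothetical multi-modal checkpoint, used here only as an illustration.
checkpoint = "my-org/my-multimodal-model"

# The processor can be passed explicitly, either as an identifier string or an instance...
processor = AutoProcessor.from_pretrained(checkpoint)
pipe = pipeline("image-to-text", model=checkpoint, processor=processor)

# ...or left out, in which case `pipeline` tries to load it automatically when the model's
# config maps to a processor and the pipeline class opts in via `_load_processor = True`.
pipe = pipeline("image-to-text", model=checkpoint)
```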

src/transformers/pipelines/base.py

@@ -34,6 +34,7 @@ from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..image_processing_utils import BaseImageProcessor
 from ..modelcard import ModelCard
 from ..models.auto.configuration_auto import AutoConfig
+from ..processing_utils import ProcessorMixin
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import (
     ModelOutput,
@@ -716,6 +717,7 @@ def build_pipeline_init_args(
     has_tokenizer: bool = False,
     has_feature_extractor: bool = False,
     has_image_processor: bool = False,
+    has_processor: bool = False,
     supports_binary_output: bool = True,
 ) -> str:
     docstring = r"""
@@ -738,6 +740,12 @@ def build_pipeline_init_args(
         image_processor ([`BaseImageProcessor`]):
             The image processor that will be used by the pipeline to encode data for the model. This object inherits from
             [`BaseImageProcessor`]."""
+    if has_processor:
+        docstring += r"""
+        processor ([`ProcessorMixin`]):
+            The processor that will be used by the pipeline to encode data for the model. This object inherits from
+            [`ProcessorMixin`]. Processor is a composite object that might contain `tokenizer`, `feature_extractor`, and
+            `image_processor`."""
     docstring += r"""
         modelcard (`str` or [`ModelCard`], *optional*):
             Model card attributed to the model for this pipeline.
@@ -774,7 +782,11 @@
 PIPELINE_INIT_ARGS = build_pipeline_init_args(
-    has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True
+    has_tokenizer=True,
+    has_feature_extractor=True,
+    has_image_processor=True,
+    has_processor=True,
+    supports_binary_output=True,
 )
@@ -787,7 +799,11 @@ if is_torch_available():
     )


-@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_feature_extractor=True, has_image_processor=True))
+@add_end_docstrings(
+    build_pipeline_init_args(
+        has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True
+    )
+)
 class Pipeline(_ScikitCompat, PushToHubMixin):
     """
     The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
@@ -805,6 +821,22 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
         constructor argument. If set to `True`, the output will be stored in the pickle format.
     """

+    # Historically we have pipelines working with `tokenizer`, `feature_extractor`, and `image_processor`
+    # as separate processing components. While we have `processor` class that combines them, some pipelines
+    # might still operate with these components separately.
+    # With the addition of `processor` to `pipeline`, we want to avoid:
+    #   - loading `processor` for pipelines that still work with `image_processor` and `tokenizer` separately;
+    #   - loading `image_processor`/`tokenizer` as a separate component while we operate only with `processor`,
+    #     because `processor` will load required sub-components by itself.
+    # Below flags allow granular control over loading components and set to be backward compatible with current
+    # pipelines logic. You may override these flags when creating your pipeline. For example, for
+    # `zero-shot-object-detection` pipeline which operates with `processor` you should set `_load_processor=True`
+    # and all the rest flags to `False` to avoid unnecessary loading of the components.
+    _load_processor = False
+    _load_image_processor = True
+    _load_feature_extractor = True
+    _load_tokenizer = True
+
     default_input_names = None

     def __init__(
@@ -813,6 +845,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
         tokenizer: Optional[PreTrainedTokenizer] = None,
         feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
         image_processor: Optional[BaseImageProcessor] = None,
+        processor: Optional[ProcessorMixin] = None,
         modelcard: Optional[ModelCard] = None,
         framework: Optional[str] = None,
         task: str = "",
@@ -830,6 +863,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
         self.tokenizer = tokenizer
         self.feature_extractor = feature_extractor
         self.image_processor = image_processor
+        self.processor = processor
         self.modelcard = modelcard
         self.framework = framework
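Another aside (not part of the commit): a sketch of how a custom pipeline might opt into processor-only loading via the class flags added above. The class name, input keys, and checkpoint are hypothetical; the flag names and the abstract `Pipeline` methods come from the diff and the existing base class.

```python
from transformers import Pipeline


class MyMultiModalPipeline(Pipeline):
    # Operate on `processor` only: skip loading the standalone tokenizer/image processor,
    # since the processor bundles the sub-components it needs.
    _load_processor = True
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = False

    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs):
        # `self.processor` is populated by the base class when processor loading is enabled.
        return self.processor(text=inputs["text"], images=inputs["image"], return_tensors="pt")

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        return model_outputs
```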

@@ -436,9 +436,16 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "FeatureExtractionPipelineTests":
+        if pipeline_test_case_name == "FeatureExtractionPipelineTests":
             return True
         return False

@@ -167,9 +167,16 @@ class ASTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "AudioClassificationPipelineTests":
+        if pipeline_test_case_name == "AudioClassificationPipelineTests":
             return True
         return False

@@ -276,9 +276,16 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False

@@ -236,9 +236,16 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        return pipeline_test_casse_name == "TextGenerationPipelineTests"
+        return pipeline_test_case_name == "TextGenerationPipelineTests"

     def setUp(self):
         self.model_tester = BlenderbotSmallModelTester(self)

@@ -321,9 +321,16 @@ class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, Flax
     all_generative_model_classes = (FlaxBlenderbotSmallForConditionalGeneration,) if is_flax_available() else ()

     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        return pipeline_test_casse_name == "TextGenerationPipelineTests"
+        return pipeline_test_case_name == "TextGenerationPipelineTests"

     def setUp(self):
         self.model_tester = FlaxBlenderbotSmallModelTester(self)

@@ -198,9 +198,16 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
     test_onnx = False

     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        return pipeline_test_casse_name == "TextGenerationPipelineTests"
+        return pipeline_test_case_name == "TextGenerationPipelineTests"

     def setUp(self):
         self.model_tester = TFBlenderbotSmallModelTester(self)

@@ -295,7 +295,14 @@ class BrosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # BROS requires `bbox` in the inputs which doesn't fit into the above 2 pipelines' input formats.
     # see https://github.com/huggingface/transformers/pull/26294
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -22,7 +22,14 @@ from transformers.testing_utils import custom_tokenizers
 class CpmTokenizationTest(unittest.TestCase):
     # There is no `CpmModel`
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -211,9 +211,16 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
+        if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
             # Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
             # `CTRLConfig` was never used in pipeline tests, either because of a missing checkpoint or because a tiny
             # config could not be created.

@@ -189,9 +189,16 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
+        if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
             # Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
             # `CTRLConfig` was never used in pipeline tests, either because of a missing checkpoint or because a tiny
             # config could not be created.

@@ -312,7 +312,14 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -392,10 +392,17 @@ class FlaubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -309,10 +309,17 @@ class TFFlaubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -298,9 +298,16 @@ class FNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False

@@ -326,7 +326,14 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -382,10 +382,17 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -322,10 +322,17 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -260,9 +260,16 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     # TODO: `image-to-text` pipeline for this model needs Processor.
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        return pipeline_test_casse_name == "ImageToTextPipelineTests"
+        return pipeline_test_case_name == "ImageToTextPipelineTests"

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = copy.deepcopy(inputs_dict)

@@ -292,7 +292,14 @@ class LayoutLMv3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         # `DocumentQuestionAnsweringPipeline` is expected to work with this model, but it combines the text and visual
         # embedding along the sequence dimension (dim 1), which causes an error during post-processing as `p_mask` has

@@ -288,7 +288,14 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -302,9 +302,16 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False

@@ -245,7 +245,14 @@ class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -331,10 +331,17 @@ class LongformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -303,10 +303,17 @@ class TFLongformerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -621,9 +621,16 @@ class LukeModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name in ["QAPipelineTests", "ZeroShotClassificationPipelineTests"]:
+        if pipeline_test_case_name in ["QAPipelineTests", "ZeroShotClassificationPipelineTests"]:
             return True
         return False

@@ -258,9 +258,16 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "TranslationPipelineTests":
+        if pipeline_test_case_name == "TranslationPipelineTests":
             # Get `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
             # `M2M100Config` was never used in pipeline tests: cannot create a simple tokenizer.
             return True

@@ -301,7 +301,14 @@ class MarkupLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         # ValueError: Nodes must be of type `List[str]` (single pretokenized example), or `List[List[str]]`
         # (batch of pretokenized examples).

@@ -252,9 +252,16 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False

@@ -175,9 +175,16 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name != "FeatureExtractionPipelineTests":
+        if pipeline_test_case_name != "FeatureExtractionPipelineTests":
             # Exception encountered when calling layer '...'
             return True

@@ -313,7 +313,14 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -266,7 +266,14 @@ class TFMistralModelTest(TFModelTesterMixin, TFGenerationIntegrationTests, Pipel
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -313,7 +313,14 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -582,7 +582,14 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
     # `src/transformers/data/processors/squad.py` (where this test fails for this model)
     def is_pipeline_test_to_skip(
-        self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if tokenizer_name is None:
             return True
@@ -1055,6 +1062,26 @@ class MT5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)

+    def is_pipeline_test_to_skip(
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
+    ):
+        if tokenizer_name is None:
+            return True
+
+        # `MT5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
+        # `src/transformers/data/processors/squad.py` (where this test fails for this model)
+        if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
+            return True
+
+        return False
+

 @require_torch
 @require_sentencepiece

@@ -441,10 +441,17 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -266,7 +266,14 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         # Saving the slow tokenizer after saving the fast tokenizer causes the loading of the later hanging forever.
         return True

@@ -247,9 +247,16 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "FeatureExtractionPipelineTests":
+        if pipeline_test_case_name == "FeatureExtractionPipelineTests":
             return True
         return False

@@ -211,9 +211,16 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
+        if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
             # Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
             # `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a
             # tiny config could not be created.

@@ -217,9 +217,16 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Tes
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
+        if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
             # Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
             # `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a
             # tiny config could not be created.

@@ -221,10 +221,17 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -304,7 +304,14 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -351,7 +351,14 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -245,9 +245,16 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "TranslationPipelineTests":
+        if pipeline_test_case_name == "TranslationPipelineTests":
             # Get `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
             # `PLBartConfig` was never used in pipeline tests: cannot create a simple tokenizer.
             return True

@@ -906,9 +906,16 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "TextGenerationPipelineTests":
+        if pipeline_test_case_name == "TextGenerationPipelineTests":
             # Get `ValueError: AttributeError: 'NoneType' object has no attribute 'new_ones'` or `AssertionError`.
             # `ProphetNetConfig` was never used in pipeline tests: cannot create a simple
             # tokenizer.

@@ -322,7 +322,14 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -349,7 +349,14 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -305,7 +305,14 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True

@@ -727,10 +727,17 @@ class ReformerLSHAttnModelTest(
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):

@@ -587,9 +587,16 @@ class RoCBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name in [
+        if pipeline_test_case_name in [
             "FillMaskPipelineTests",
             "FeatureExtractionPipelineTests",
             "TextClassificationPipelineTests",

View File

@ -275,9 +275,16 @@ class TFRoFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
# TODO: add `prepare_inputs_for_generation` for `TFRoFormerForCausalLM` # TODO: add `prepare_inputs_for_generation` for `TFRoFormerForCausalLM`
def is_pipeline_test_to_skip( def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
): ):
if pipeline_test_casse_name == "TextGenerationPipelineTests": if pipeline_test_case_name == "TextGenerationPipelineTests":
return True return True
return False return False

View File

@ -299,7 +299,14 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip( def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
): ):
return True return True

View File

@ -290,7 +290,14 @@ class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip( def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
): ):
return True return True

View File

@@ -225,11 +225,18 @@ class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
     # TODO: Fix the failed tests when this model gets more usage
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests":
+        if pipeline_test_case_name == "QAPipelineTests":
             return True
-        elif pipeline_test_casse_name == "FeatureExtractionPipelineTests" and tokenizer_name.endswith("Fast"):
+        elif pipeline_test_case_name == "FeatureExtractionPipelineTests" and tokenizer_name.endswith("Fast"):
             return True
         return False
@@ -311,7 +311,14 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True
@@ -585,7 +585,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
     # `src/transformers/data/processors/squad.py` (where this test fails for this model)
     def is_pipeline_test_to_skip(
-        self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if tokenizer_name is None:
             return True
@@ -1056,6 +1063,26 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
+    def is_pipeline_test_to_skip(
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
+    ):
+        if tokenizer_name is None:
+            return True
+        # `T5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
+        # `src/transformers/data/processors/squad.py` (where this test fails for this model)
+        if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
+            return True
+        return False
 def use_task_specific_params(model, task):
     model.config.update(model.config.task_specific_params[task])
@@ -492,7 +492,14 @@ class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True
@@ -447,7 +447,14 @@ class TFTapasModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         return True
@@ -322,9 +322,16 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
     # `src/transformers/data/processors/squad.py` (where this test fails for this model)
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False
@@ -723,6 +730,26 @@ class UMT5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
+    def is_pipeline_test_to_skip(
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
+    ):
+        if tokenizer_name is None:
+            return True
+        # `UMT5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
+        # `src/transformers/data/processors/squad.py` (where this test fails for this model)
+        if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
+            return True
+        return False
 @require_torch
 @require_sentencepiece
@@ -397,9 +397,16 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name in [
+        if pipeline_test_case_name in [
             "AutomaticSpeechRecognitionPipelineTests",
             "AudioClassificationPipelineTests",
         ]:
@@ -312,10 +312,17 @@ class TFXLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):
@@ -393,10 +393,17 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):
@@ -392,9 +392,16 @@ class XLMRobertaXLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False
@@ -372,7 +372,14 @@ class TFXLNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         # Exception encountered when calling layer '...'
         return True
@@ -543,9 +543,16 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False
@@ -386,9 +386,16 @@ class XmodModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
-        if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
+        if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
             return True
         return False
@@ -37,9 +37,22 @@ class AudioClassificationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
     tf_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         audio_classifier = AudioClassificationPipeline(
-            model=model, feature_extractor=processor, torch_dtype=torch_dtype
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
         )
         # test with a raw waveform
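Note (not part of the diff): the `get_test_pipeline` updates in the pipeline tests below all share one shape; a minimal, dependency-free sketch of it follows. The class name is hypothetical and a plain dict stands in for the task-specific pipeline class so the sketch runs on its own.

```python
# Hypothetical sketch of the common `get_test_pipeline` shape after this change:
# every pre-processing component arrives as a separate keyword argument and is
# forwarded to the pipeline constructor.
class ExamplePipelineTests:
    def get_test_pipeline(
        self,
        model,
        tokenizer=None,
        image_processor=None,
        feature_extractor=None,
        processor=None,
        torch_dtype="float32",
    ):
        # In a real test this would construct the task-specific pipeline class
        # (AudioClassificationPipeline, FillMaskPipeline, ...); a dict stands in here.
        pipe = {
            "model": model,
            "tokenizer": tokenizer,
            "feature_extractor": feature_extractor,
            "image_processor": image_processor,
            "processor": processor,
            "torch_dtype": torch_dtype,
        }
        return pipe, ["example input 1", "example input 2"]
```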
@@ -67,14 +67,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         + (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
     )
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         if tokenizer is None:
             # Side effect of no Fast Tokenizer class for these model, so skipping
             # But the slow tokenizer test should still run as they're quite small
             self.skipTest(reason="No tokenizer available")
         speech_recognizer = AutomaticSpeechRecognitionPipeline(
-            model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
         )
         # test with a raw waveform
@@ -58,8 +58,23 @@ def hashimage(image: Image) -> str:
 class DepthEstimationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        depth_estimator = DepthEstimationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        depth_estimator = DepthEstimationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return depth_estimator, [
             "./tests/fixtures/tests_samples/COCO/000000039769.png",
             "./tests/fixtures/tests_samples/COCO/000000039769.png",
@@ -15,7 +15,7 @@
 import unittest
 from transformers import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, AutoTokenizer, is_vision_available
-from transformers.pipelines import pipeline
+from transformers.pipelines import DocumentQuestionAnsweringPipeline, pipeline
 from transformers.pipelines.document_question_answering import apply_tesseract
 from transformers.testing_utils import (
     is_pipeline_test,
@@ -61,12 +61,21 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase):
     @require_pytesseract
     @require_vision
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        dqa_pipeline = pipeline(
-            "document-question-answering",
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        dqa_pipeline = DocumentQuestionAnsweringPipeline(
             model=model,
             tokenizer=tokenizer,
-            image_processor=processor,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
             torch_dtype=torch_dtype,
         )
@@ -174,7 +174,15 @@ class FeatureExtractionPipelineTests(unittest.TestCase):
                 raise TypeError("We expect lists of floats, nothing else")
             return shape
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         if tokenizer is None:
             self.skipTest(reason="No tokenizer")
         elif (
@@ -193,10 +201,15 @@ class FeatureExtractionPipelineTests(unittest.TestCase):
                 For now ignore those.
                 """
             )
-        feature_extractor = FeatureExtractionPipeline(
-            model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
+        feature_extractor_pipeline = FeatureExtractionPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
         )
-        return feature_extractor, ["This is a test", "This is another test"]
+        return feature_extractor_pipeline, ["This is a test", "This is another test"]
     def run_pipeline_test(self, feature_extractor, examples):
         outputs = feature_extractor("This is a test")
@@ -251,11 +251,26 @@ class FillMaskPipelineTests(unittest.TestCase):
         unmasker.tokenizer.pad_token = None
         self.run_pipeline_test(unmasker, [])
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         if tokenizer is None or tokenizer.mask_token_id is None:
             self.skipTest(reason="The provided tokenizer has no mask token, (probably reformer or wav2vec2)")
-        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+        fill_masker = FillMaskPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         examples = [
             f"This is another {tokenizer.mask_token} test",
         ]
@@ -58,9 +58,23 @@ class ImageClassificationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
     tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         image_classifier = ImageClassificationPipeline(
-            model=model, image_processor=processor, top_k=2, torch_dtype=torch_dtype
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+            top_k=2,
         )
         examples = [
             Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
@@ -157,8 +157,16 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
         outputs = feature_extractor(img, return_tensors=True)
         self.assertTrue(tf.is_tensor(outputs))
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        if processor is None:
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        if image_processor is None:
             self.skipTest(reason="No image processor")
         elif type(model.config) in TOKENIZER_MAPPING:
@@ -175,11 +183,16 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
                 """
             )
-        feature_extractor = ImageFeatureExtractionPipeline(
-            model=model, image_processor=processor, torch_dtype=torch_dtype
+        feature_extractor_pipeline = ImageFeatureExtractionPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
         )
         img = prepare_img()
-        return feature_extractor, [img, img]
+        return feature_extractor_pipeline, [img, img]
     def run_pipeline_test(self, feature_extractor, examples):
         imgs = examples
@@ -89,8 +89,23 @@ class ImageSegmentationPipelineTests(unittest.TestCase):
         + (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
     )
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        image_segmenter = ImageSegmentationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        image_segmenter = ImageSegmentationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return image_segmenter, [
             "./tests/fixtures/tests_samples/COCO/000000039769.png",
             "./tests/fixtures/tests_samples/COCO/000000039769.png",
@@ -47,9 +47,22 @@ class ImageToTextPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
     tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         pipe = ImageToTextPipeline(
-            model=model, tokenizer=tokenizer, image_processor=processor, torch_dtype=torch_dtype
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
         )
         examples = [
             Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
@@ -67,8 +67,23 @@ class MaskGenerationPipelineTests(unittest.TestCase):
         (list(TF_MODEL_FOR_MASK_GENERATION_MAPPING.items()) if TF_MODEL_FOR_MASK_GENERATION_MAPPING else [])
     )
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        image_segmenter = MaskGenerationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return image_segmenter, [
             "./tests/fixtures/tests_samples/COCO/000000039769.png",
             "./tests/fixtures/tests_samples/COCO/000000039769.png",
@@ -56,8 +56,23 @@ else:
 class ObjectDetectionPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        object_detector = ObjectDetectionPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        object_detector = ObjectDetectionPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
     def run_pipeline_test(self, object_detector, examples):
@@ -50,12 +50,27 @@ class QAPipelineTests(unittest.TestCase):
         config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
     }
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         if isinstance(model.config, LxmertConfig):
             # This is an bimodal model, we need to find a more consistent way
             # to switch on those models.
             return None, None
-        question_answerer = QuestionAnsweringPipeline(model, tokenizer, torch_dtype=torch_dtype)
+        question_answerer = QuestionAnsweringPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         examples = [
             {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
@@ -32,8 +32,23 @@ class SummarizationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
     tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        summarizer = SummarizationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        summarizer = SummarizationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return summarizer, ["(CNN)The Palestinian Authority officially became", "Some other text"]
     def run_pipeline_test(self, summarizer, _):
@@ -35,8 +35,23 @@ class Text2TextGenerationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
     tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        generator = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        generator = Text2TextGenerationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return generator, ["Something to write", "Something else"]
     def run_pipeline_test(self, generator, _):
@@ -179,8 +179,23 @@ class TextClassificationPipelineTests(unittest.TestCase):
         outputs = text_classifier("Birds are a type of animal")
         self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        text_classifier = TextClassificationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return text_classifier, ["HuggingFace is in", "This is another test"]
     def run_pipeline_test(self, text_classifier, _):
@@ -377,8 +377,23 @@ class TextGenerationPipelineTests(unittest.TestCase):
         ],
     )
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        text_generator = TextGenerationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return text_generator, ["This is a test", "Another test"]
     def test_stop_sequence_stopping_criteria(self):
@@ -471,6 +486,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
             "GPTNeoXForCausalLM",
             "GPTNeoXJapaneseForCausalLM",
             "FuyuForCausalLM",
+            "LlamaForCausalLM",
         ]
         if (
             tokenizer.model_max_length < 10000
@@ -250,8 +250,23 @@ class TextToAudioPipelineTests(unittest.TestCase):
         outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)
         self.assertListEqual(outputs["audio"].tolist(), audio.tolist())
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        speech_generator = TextToAudioPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return speech_generator, ["This is a test", "Another test"]
     def run_pipeline_test(self, speech_generator, _):
@@ -61,8 +61,23 @@ class TokenClassificationPipelineTests(unittest.TestCase):
         config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
     }
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
-        token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
+        token_classifier = TokenClassificationPipeline(
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+        )
         return token_classifier, ["A simple string", "A simple string that is quite a bit longer"]
     def run_pipeline_test(self, token_classifier, _):
@@ -35,14 +35,36 @@ class TranslationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
     tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         if isinstance(model.config, MBartConfig):
             src_lang, tgt_lang = list(tokenizer.lang_code_to_id.keys())[:2]
             translator = TranslationPipeline(
-                model=model, tokenizer=tokenizer, src_lang=src_lang, tgt_lang=tgt_lang, torch_dtype=torch_dtype
+                model=model,
+                tokenizer=tokenizer,
+                feature_extractor=feature_extractor,
+                image_processor=image_processor,
+                processor=processor,
+                torch_dtype=torch_dtype,
+                src_lang=src_lang,
+                tgt_lang=tgt_lang,
             )
         else:
-            translator = TranslationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
+            translator = TranslationPipeline(
+                model=model,
+                tokenizer=tokenizer,
+                feature_extractor=feature_extractor,
+                image_processor=image_processor,
+                processor=processor,
+                torch_dtype=torch_dtype,
+            )
         return translator, ["Some string", "Some other text"]
     def run_pipeline_test(self, translator, _):
@@ -38,12 +38,26 @@ from .test_pipelines_common import ANY
 class VideoClassificationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         example_video_filepath = hf_hub_download(
             repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset"
         )
         video_classifier = VideoClassificationPipeline(
-            model=model, image_processor=processor, top_k=2, torch_dtype=torch_dtype
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+            top_k=2,
         )
         examples = [
             example_video_filepath,
@@ -55,9 +55,19 @@ else:
 class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         vqa_pipeline = pipeline(
-            "visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa", torch_dtype=torch_dtype
+            "visual-question-answering",
+            model="hf-internal-testing/tiny-vilt-random-vqa",
+            torch_dtype=torch_dtype,
         )
         examples = [
             {
@@ -53,9 +53,23 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
         config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
     }
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         classifier = ZeroShotClassificationPipeline(
-            model=model, tokenizer=tokenizer, candidate_labels=["polics", "health"], torch_dtype=torch_dtype
+            model=model,
+            tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
+            image_processor=image_processor,
+            processor=processor,
+            torch_dtype=torch_dtype,
+            candidate_labels=["polics", "health"],
         )
         return classifier, ["Who are you voting for in 2020?", "My stomach hurts."]
@@ -43,7 +43,15 @@ else:
 class ZeroShotObjectDetectionPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
-    def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
+    def get_test_pipeline(
+        self,
+        model,
+        tokenizer=None,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+        torch_dtype="float32",
+    ):
         object_detector = pipeline(
             "zero-shot-object-detection",
             model="hf-internal-testing/tiny-random-owlvit-object-detection",
@@ -36,6 +36,7 @@ from huggingface_hub import (
     ZeroShotImageClassificationInput,
 )
+from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
 from transformers.pipelines import (
     AudioClassificationPipeline,
     AutomaticSpeechRecognitionPipeline,
@@ -183,11 +184,14 @@ class PipelineTesterMixin:
         model_architectures = self.pipeline_model_mapping[task]
         if not isinstance(model_architectures, tuple):
             model_architectures = (model_architectures,)
-        if not isinstance(model_architectures, tuple):
-            raise TypeError(f"`model_architectures` must be a tuple. Got {type(model_architectures)} instead.")
+        # We are going to run tests for multiple model architectures, some of them might be skipped
+        # with this flag we are control if at least one model were tested or all were skipped
+        at_least_one_model_is_tested = False
         for model_architecture in model_architectures:
             model_arch_name = model_architecture.__name__
+            model_type = model_architecture.config_class.model_type
             # Get the canonical name
             for _prefix in ["Flax", "TF"]:
@@ -195,31 +199,72 @@ class PipelineTesterMixin:
                 model_arch_name = model_arch_name[len(_prefix) :]
                 break
-            tokenizer_names = []
-            processor_names = []
-            commit = None
-            if model_arch_name in tiny_model_summary:
-                tokenizer_names = tiny_model_summary[model_arch_name]["tokenizer_classes"]
-                processor_names = tiny_model_summary[model_arch_name]["processor_classes"]
-                if "sha" in tiny_model_summary[model_arch_name]:
-                    commit = tiny_model_summary[model_arch_name]["sha"]
-            # Adding `None` (if empty) so we can generate tests
-            tokenizer_names = [None] if len(tokenizer_names) == 0 else tokenizer_names
-            processor_names = [None] if len(processor_names) == 0 else processor_names
+            if model_arch_name not in tiny_model_summary:
+                continue
+            tokenizer_names = tiny_model_summary[model_arch_name]["tokenizer_classes"]
+            # Sort image processors and feature extractors from tiny-models json file
+            image_processor_names = []
+            feature_extractor_names = []
+            processor_classes = tiny_model_summary[model_arch_name]["processor_classes"]
+            for cls_name in processor_classes:
+                if "ImageProcessor" in cls_name:
+                    image_processor_names.append(cls_name)
+                elif "FeatureExtractor" in cls_name:
+                    feature_extractor_names.append(cls_name)
+                else:
+                    raise ValueError(f"Unknown processor class: {cls_name}")
+            # Processor classes are not in tiny models JSON file, so extract them from the mapping
+            # processors are mapped to instance, e.g. "XxxProcessor"
+            processor_names = PROCESSOR_MAPPING_NAMES.get(model_type, None)
+            if not isinstance(processor_names, (list, tuple)):
+                processor_names = [processor_names]
+            commit = None
+            if model_arch_name in tiny_model_summary and "sha" in tiny_model_summary[model_arch_name]:
+                commit = tiny_model_summary[model_arch_name]["sha"]
             repo_name = f"tiny-random-{model_arch_name}"
             if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
                 repo_name = model_arch_name
             self.run_model_pipeline_tests(
-                task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype
+                task,
+                repo_name,
+                model_architecture,
+                tokenizer_names=tokenizer_names,
+                image_processor_names=image_processor_names,
+                feature_extractor_names=feature_extractor_names,
+                processor_names=processor_names,
+                commit=commit,
+                torch_dtype=torch_dtype,
            )
             if task in task_to_pipeline_and_spec_mapping:
                 pipeline, hub_spec = task_to_pipeline_and_spec_mapping[task]
                 compare_pipeline_args_to_hub_spec(pipeline, hub_spec)
+            at_least_one_model_is_tested = True
+        if not at_least_one_model_is_tested:
+            self.skipTest(
+                f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find any "
+                f"model architecture in the tiny models JSON file for `{task}`."
+            )
     def run_model_pipeline_tests(
-        self, task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype="float32"
+        self,
+        task,
+        repo_name,
+        model_architecture,
+        tokenizer_names,
+        image_processor_names,
+        feature_extractor_names,
+        processor_names,
+        commit,
+        torch_dtype="float32",
     ):
         """Run pipeline tests for a specific `task` with the give model class and tokenizer/processor class names
@@ -232,8 +277,12 @@ class PipelineTesterMixin:
                 A subclass of `PretrainedModel` or `PretrainedModel`.
             tokenizer_names (`List[str]`):
                 A list of names of a subclasses of `PreTrainedTokenizerFast` or `PreTrainedTokenizer`.
+            image_processor_names (`List[str]`):
+                A list of names of subclasses of `BaseImageProcessor`.
+            feature_extractor_names (`List[str]`):
+                A list of names of subclasses of `FeatureExtractionMixin`.
             processor_names (`List[str]`):
-                A list of names of subclasses of `BaseImageProcessor` or `FeatureExtractionMixin`.
+                A list of names of subclasses of `ProcessorMixin`.
             commit (`str`):
                 The commit hash of the model repository on the Hub.
             torch_dtype (`str`, `optional`, defaults to `'float32'`):
@@ -243,27 +292,73 @@ class PipelineTesterMixin:
         # `run_pipeline_test`.
         pipeline_test_class_name = pipeline_test_mapping[task]["test"].__name__
-        for tokenizer_name in tokenizer_names:
-            for processor_name in processor_names:
-                if self.is_pipeline_test_to_skip(
+        # If no image processor or feature extractor is found, we still need to test the pipeline with None
+        # otherwise for any empty list we might skip all the tests
+        tokenizer_names = tokenizer_names or [None]
+        image_processor_names = image_processor_names or [None]
+        feature_extractor_names = feature_extractor_names or [None]
+        processor_names = processor_names or [None]
+        test_cases = [
+            {
+                "tokenizer_name": tokenizer_name,
+                "image_processor_name": image_processor_name,
+                "feature_extractor_name": feature_extractor_name,
+                "processor_name": processor_name,
+            }
+            for tokenizer_name in tokenizer_names
+            for image_processor_name in image_processor_names
+            for feature_extractor_name in feature_extractor_names
+            for processor_name in processor_names
+        ]
+        for test_case in test_cases:
+            tokenizer_name = test_case["tokenizer_name"]
+            image_processor_name = test_case["image_processor_name"]
+            feature_extractor_name = test_case["feature_extractor_name"]
+            processor_name = test_case["processor_name"]
+            do_skip_test_case = self.is_pipeline_test_to_skip(
                 pipeline_test_class_name,
                 model_architecture.config_class,
                 model_architecture,
                 tokenizer_name,
+                image_processor_name,
+                feature_extractor_name,
                 processor_name,
-                ):
+            )
+            if do_skip_test_case:
                 logger.warning(
                     f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: test is "
                     f"currently known to fail for: model `{model_architecture.__name__}` | tokenizer "
-                    f"`{tokenizer_name}` | processor `{processor_name}`."
+                    f"`{tokenizer_name}` | image processor `{image_processor_name}` | feature extractor {feature_extractor_name}."
                 )
                 continue
             self.run_pipeline_test(
-                task, repo_name, model_architecture, tokenizer_name, processor_name, commit, torch_dtype
+                task,
+                repo_name,
+                model_architecture,
+                tokenizer_name=tokenizer_name,
+                image_processor_name=image_processor_name,
+                feature_extractor_name=feature_extractor_name,
+                processor_name=processor_name,
+                commit=commit,
+                torch_dtype=torch_dtype,
            )
     def run_pipeline_test(
-        self, task, repo_name, model_architecture, tokenizer_name, processor_name, commit, torch_dtype="float32"
+        self,
+        task,
+        repo_name,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
+        commit,
+        torch_dtype="float32",
     ):
         """Run pipeline tests for a specific `task` with the give model class and tokenizer/processor class name
@@ -278,43 +373,24 @@ class PipelineTesterMixin:
                 A subclass of `PretrainedModel` or `PretrainedModel`.
             tokenizer_name (`str`):
                 The name of a subclass of `PreTrainedTokenizerFast` or `PreTrainedTokenizer`.
+            image_processor_name (`str`):
+                The name of a subclass of `BaseImageProcessor`.
+            feature_extractor_name (`str`):
+                The name of a subclass of `FeatureExtractionMixin`.
             processor_name (`str`):
-                The name of a subclass of `BaseImageProcessor` or `FeatureExtractionMixin`.
+                The name of a subclass of `ProcessorMixin`.
             commit (`str`):
                 The commit hash of the model repository on the Hub.
             torch_dtype (`str`, `optional`, defaults to `'float32'`):
                 The torch dtype to use for the model. Can be used for FP16/other precision inference.
         """
         repo_id = f"{TRANSFORMERS_TINY_MODEL_PATH}/{repo_name}"
-        if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
-            model_type = model_architecture.config_class.model_type
+        model_type = model_architecture.config_class.model_type
+        if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
             repo_id = os.path.join(TRANSFORMERS_TINY_MODEL_PATH, model_type, repo_name)
-        tokenizer = None
-        if tokenizer_name is not None:
-            tokenizer_class = getattr(transformers_module, tokenizer_name)
-            tokenizer = tokenizer_class.from_pretrained(repo_id, revision=commit)
-        processor = None
-        if processor_name is not None:
-            processor_class = getattr(transformers_module, processor_name)
-            # If the required packages (like `Pillow` or `torchaudio`) are not installed, this will fail.
-            try:
-                processor = processor_class.from_pretrained(repo_id, revision=commit)
-            except Exception:
-                logger.warning(
-                    f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not load the "
-                    f"processor from `{repo_id}` with `{processor_name}`."
-                )
-                self.skipTest(f"Could not load the processor from {repo_id} with {processor_name}.")
-        # TODO: Maybe not upload such problematic tiny models to Hub.
-        if tokenizer is None and processor is None:
-            logger.warning(
-                f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find or load "
-                f"any tokenizer / processor from `{repo_id}`."
-            )
-            self.skipTest(f"Could not find or load any tokenizer / processor from {repo_id}.")
+        # -------------------- Load model --------------------
         # TODO: We should check if a model file is on the Hub repo. instead.
         try:
@@ -326,19 +402,57 @@ class PipelineTesterMixin:
             )
             self.skipTest(f"Could not find or load the model from {repo_id} with {model_architecture}.")
+        # -------------------- Load tokenizer --------------------
+        tokenizer = None
+        if tokenizer_name is not None:
+            tokenizer_class = getattr(transformers_module, tokenizer_name)
+            tokenizer = tokenizer_class.from_pretrained(repo_id, revision=commit)
+        # -------------------- Load processors --------------------
+        processors = {}
+        for key, name in zip(
+            ["image_processor", "feature_extractor", "processor"],
+            [image_processor_name, feature_extractor_name, processor_name],
+        ):
+            if name is not None:
+                try:
+                    # Can fail if some extra dependencies are not installed
+                    processor_class = getattr(transformers_module, name)
+                    processor = processor_class.from_pretrained(repo_id, revision=commit)
+                    processors[key] = processor
+                except Exception:
+                    logger.warning(
+                        f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: "
+                        f"Could not load the {key} from `{repo_id}` with `{name}`."
+                    )
+                    self.skipTest(f"Could not load the {key} from {repo_id} with {name}.")
+        # ---------------------------------------------------------
+        # TODO: Maybe not upload such problematic tiny models to Hub.
+        if tokenizer is None and "image_processor" not in processors and "feature_extractor" not in processors:
+            logger.warning(
+                f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find or load "
+                f"any tokenizer / image processor / feature extractor from `{repo_id}`."
+            )
+            self.skipTest(f"Could not find or load any tokenizer / processor from {repo_id}.")
         pipeline_test_class_name = pipeline_test_mapping[task]["test"].__name__
-        if self.is_pipeline_test_to_skip_more(pipeline_test_class_name, model.config, model, tokenizer, processor):
+        if self.is_pipeline_test_to_skip_more(pipeline_test_class_name, model.config, model, tokenizer, **processors):
             logger.warning(
                 f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: test is "
                 f"currently known to fail for: model `{model_architecture.__name__}` | tokenizer "
-                f"`{tokenizer_name}` | processor `{processor_name}`."
+                f"`{tokenizer_name}` | image processor `{image_processor_name}` | feature extractor `{feature_extractor_name}`."
             )
             self.skipTest(
-                f"Test is known to fail for: model `{model_architecture.__name__}` | tokenizer `{tokenizer_name}` | processor `{processor_name}`."
+                f"Test is known to fail for: model `{model_architecture.__name__}` | tokenizer `{tokenizer_name}` "
+                f"| image processor `{image_processor_name}` | feature extractor `{feature_extractor_name}`."
             )
         # validate
-        validate_test_components(self, task, model, tokenizer, processor)
+        validate_test_components(model, tokenizer)
         if hasattr(model, "eval"):
             model = model.eval()
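Note (not part of the diff): a pure-Python sketch of the component-loading pattern added above. The stub class, the lookup dict standing in for `transformers_module`, and the repo name are all made up, and the `self.skipTest(...)` call is replaced by a flag so the sketch runs on its own.

```python
# Pure-Python sketch (no transformers, no Hub access) of the loading pattern:
# each requested component name is resolved to a class, `from_pretrained` is
# attempted, and any failure skips the test case instead of erroring out.
class _StubImageProcessor:
    @classmethod
    def from_pretrained(cls, repo_id, revision=None):
        return cls()  # stand-in for a real `from_pretrained` download

transformers_module = {"StubImageProcessor": _StubImageProcessor}  # name -> class lookup

repo_id, commit = "tiny-random-StubModel", None  # hypothetical repo
requested = {"image_processor": "StubImageProcessor", "feature_extractor": None, "processor": None}

processors = {}
skipped = False
for key, name in requested.items():
    if name is None:
        continue
    try:
        processor_class = transformers_module[name]
        processors[key] = processor_class.from_pretrained(repo_id, revision=commit)
    except Exception:
        skipped = True  # in the mixin: self.skipTest(...)
        break

print(sorted(processors))  # ['image_processor']
```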
@@ -347,7 +461,7 @@ class PipelineTesterMixin:
         # `run_pipeline_test`.
         task_test = pipeline_test_mapping[task]["test"]()
-        pipeline, examples = task_test.get_test_pipeline(model, tokenizer, processor, torch_dtype=torch_dtype)
+        pipeline, examples = task_test.get_test_pipeline(model, tokenizer, **processors, torch_dtype=torch_dtype)
         if pipeline is None:
             # The test can disable itself, but it should be very marginal
             # Concerns: Wav2Vec2ForCTC without tokenizer test (FastTokenizer don't exist)
@@ -674,7 +788,14 @@ class PipelineTesterMixin:
     # This contains the test cases to be skipped without model architecture being involved.
     def is_pipeline_test_to_skip(
-        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
+        self,
+        pipeline_test_case_name,
+        config_class,
+        model_architecture,
+        tokenizer_name,
+        image_processor_name,
+        feature_extractor_name,
+        processor_name,
     ):
         """Skip some tests based on the classes or their names without the instantiated objects.
@@ -682,7 +803,7 @@ class PipelineTesterMixin:
         """
         # No fix is required for this case.
         if (
-            pipeline_test_casse_name == "DocumentQuestionAnsweringPipelineTests"
+            pipeline_test_case_name == "DocumentQuestionAnsweringPipelineTests"
             and tokenizer_name is not None
             and not tokenizer_name.endswith("Fast")
         ):
@@ -691,11 +812,20 @@ class PipelineTesterMixin:
         return False
-    def is_pipeline_test_to_skip_more(self, pipeline_test_casse_name, config, model, tokenizer, processor):  # noqa
+    def is_pipeline_test_to_skip_more(
+        self,
+        pipeline_test_case_name,
+        config,
+        model,
+        tokenizer,
+        image_processor=None,
+        feature_extractor=None,
+        processor=None,
+    ):  # noqa
         """Skip some more tests based on the information from the instantiated objects."""
         # No fix is required for this case.
         if (
-            pipeline_test_casse_name == "QAPipelineTests"
+            pipeline_test_case_name == "QAPipelineTests"
             and tokenizer is not None
             and getattr(tokenizer, "pad_token", None) is None
             and not tokenizer.__class__.__name__.endswith("Fast")
@@ -706,7 +836,7 @@ class PipelineTesterMixin:
         return False
-def validate_test_components(test_case, task, model, tokenizer, processor):
+def validate_test_components(model, tokenizer):
     # TODO: Move this to tiny model creation script
     # head-specific (within a model type) necessary changes to the config
     # 1. for `BlenderbotForCausalLM`