Make pipeline able to load processor (#32514)

* Refactor get_test_pipeline

* Fixup

* Fixing tests

* Add processor loading in tests

* Restructure processors loading

* Add processor to the pipeline

* Move model loading to the top of the test

* Update `get_test_pipeline`

* Fixup

* Add class-based flags for loading processors

* Change `is_pipeline_test_to_skip` signature

* Skip t5 failing test for slow tokenizer

* Fixup

* Fix copies for T5

* Fix typo

* Add try/except for tokenizer loading (kosmos-2 case)

* Fixup

* Llama no longer fails for long generation

* Revert processor pass in text-generation test

* Fix docs

* Switch back to json file for image processors and feature extractors

* Add processor type check

* Remove except for tokenizers

* Fix docstring

* Fix empty lists for tests

* Fixup

* Fix load check

* Ensure we have non-empty test cases

* Update src/transformers/pipelines/__init__.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update src/transformers/pipelines/base.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Rework comment

* Better docs, add note about pipeline components

* Change warning to error raise

* Fixup

* Refine pipeline docs

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
Pavel Iakubovskii 2024-10-09 16:46:11 +01:00 committed by GitHub
parent 4fb28703ad
commit 48461c0fe2
91 changed files with 1312 additions and 241 deletions

View File

@ -28,7 +28,9 @@ from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage
from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
CONFIG_NAME,
@ -556,6 +558,7 @@ def pipeline(
tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
image_processor: Optional[Union[str, BaseImageProcessor]] = None,
processor: Optional[Union[str, ProcessorMixin]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = True,
@ -571,11 +574,19 @@ def pipeline(
"""
Utility factory method to build a [`Pipeline`].
Pipelines are made of:
A pipeline consists of:
- A [tokenizer](tokenizer) in charge of mapping raw textual input to token.
- A [model](model) to make predictions from the inputs.
- Some (optional) post processing for enhancing model's output.
- One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
[image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
- A [model](model) that generates predictions from the inputs.
- Optional post-processing steps to refine the model's output, which can also be handled by processors.
<Tip>
Although `tokenizer`, `feature_extractor`, `image_processor`, and `processor` are all optional arguments, they
should not all be specified at once. If these components are not provided, `pipeline` will try to load the
required ones automatically. If you want to provide them explicitly, refer to the documentation of the specific
pipeline to see which components it requires.
</Tip>
Args:
task (`str`):
@ -644,6 +655,25 @@ def pipeline(
`model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
is a string). However, if `config` is also not given or not a string, then the default feature extractor
for the given `task` will be loaded.
image_processor (`str` or [`BaseImageProcessor`], *optional*):
The image processor that will be used by the pipeline to preprocess images for the model. This can be a
model identifier or an actual image processor inheriting from [`BaseImageProcessor`].
Image processors are used for Vision models and multi-modal models that require image inputs. Multi-modal
models will also require a tokenizer to be passed.
If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
`model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
a string).
processor (`str` or [`ProcessorMixin`], *optional*):
The processor that will be used by the pipeline to preprocess data for the model. This can be a model
identifier or an actual processor inheriting from [`ProcessorMixin`].
Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
requires both text and image inputs.
If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
framework (`str`, *optional*):
The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
installed.
@ -905,13 +935,17 @@ def pipeline(
model_config = model.config
hub_kwargs["_commit_hash"] = model.config._commit_hash
load_tokenizer = (
type(model_config) in TOKENIZER_MAPPING
or model_config.tokenizer_class is not None
or isinstance(tokenizer, str)
)
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None
# Check whether the pipeline class requires these components to be loaded
load_tokenizer = load_tokenizer and pipeline_class._load_tokenizer
load_feature_extractor = load_feature_extractor and pipeline_class._load_feature_extractor
load_image_processor = load_image_processor and pipeline_class._load_image_processor
load_processor = load_processor and pipeline_class._load_processor
# If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while
# `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some
@ -1074,6 +1108,31 @@ def pipeline(
if not is_pyctcdecode_available():
logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode")
if load_processor:
# Try to infer processor from model or config name (if provided as str)
if processor is None:
if isinstance(model_name, str):
processor = model_name
elif isinstance(config, str):
processor = config
else:
# Impossible to guess what is the right processor here
raise Exception(
"Impossible to guess which processor to use. "
"Please provide a processor instance or a path/identifier "
"to a processor."
)
# Instantiate processor if needed
if isinstance(processor, (str, tuple)):
processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
if not isinstance(processor, ProcessorMixin):
raise TypeError(
"Processor was loaded, but it is not an instance of `ProcessorMixin`. "
f"Got type `{type(processor)}` instead. Please check that you specified "
"correct pipeline task for the model and model has processor implemented and saved."
)
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
@ -1099,4 +1158,7 @@ def pipeline(
if device is not None:
kwargs["device"] = device
if processor is not None:
kwargs["processor"] = processor
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
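A rough usage sketch of the new `processor` argument (placeholder checkpoint id and illustrative task, not taken from this PR):

from transformers import AutoProcessor, pipeline

# Placeholder checkpoint id: any multi-modal model with a saved processor would do.
checkpoint = "org/some-multimodal-checkpoint"

# `processor` may be passed as a string (resolved via `AutoProcessor.from_pretrained`)
# or as an already-instantiated `ProcessorMixin`, as sketched here. Whether a pipeline
# auto-loads a processor when none is passed depends on the class-level `_load_processor`
# flag added in `base.py` further down.
processor = AutoProcessor.from_pretrained(checkpoint)
pipe = pipeline(task="image-to-text", model=checkpoint, processor=processor)  # illustrative task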

View File

@ -34,6 +34,7 @@ from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..image_processing_utils import BaseImageProcessor
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..processing_utils import ProcessorMixin
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
ModelOutput,
@ -716,6 +717,7 @@ def build_pipeline_init_args(
has_tokenizer: bool = False,
has_feature_extractor: bool = False,
has_image_processor: bool = False,
has_processor: bool = False,
supports_binary_output: bool = True,
) -> str:
docstring = r"""
@ -738,6 +740,12 @@ def build_pipeline_init_args(
image_processor ([`BaseImageProcessor`]):
The image processor that will be used by the pipeline to encode data for the model. This object inherits from
[`BaseImageProcessor`]."""
if has_processor:
docstring += r"""
processor ([`ProcessorMixin`]):
The processor that will be used by the pipeline to encode data for the model. This object inherits from
[`ProcessorMixin`]. A processor is a composite object that may contain a `tokenizer`, `feature_extractor`, and
`image_processor`."""
docstring += r"""
modelcard (`str` or [`ModelCard`], *optional*):
Model card attributed to the model for this pipeline.
@ -774,7 +782,11 @@ def build_pipeline_init_args(
PIPELINE_INIT_ARGS = build_pipeline_init_args(
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True
has_tokenizer=True,
has_feature_extractor=True,
has_image_processor=True,
has_processor=True,
supports_binary_output=True,
)
@ -787,7 +799,11 @@ if is_torch_available():
)
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_feature_extractor=True, has_image_processor=True))
@add_end_docstrings(
build_pipeline_init_args(
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True
)
)
class Pipeline(_ScikitCompat, PushToHubMixin):
"""
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
@ -805,6 +821,22 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
constructor argument. If set to `True`, the output will be stored in the pickle format.
"""
# Historically, pipelines have worked with `tokenizer`, `feature_extractor`, and `image_processor`
# as separate processing components. While we now have a `processor` class that combines them, some pipelines
# might still operate with these components separately.
# With the addition of `processor` to `pipeline`, we want to avoid:
# - loading `processor` for pipelines that still work with `image_processor` and `tokenizer` separately;
# - loading `image_processor`/`tokenizer` as a separate component while we operate only with `processor`,
# because `processor` will load the required sub-components by itself.
# The flags below allow granular control over which components are loaded and are set to be backward
# compatible with the current pipelines logic. You may override these flags when creating your pipeline. For
# example, for the `zero-shot-object-detection` pipeline, which operates with a `processor`, you should set
# `_load_processor=True` and all the remaining flags to `False` to avoid loading components unnecessarily.
_load_processor = False
_load_image_processor = True
_load_feature_extractor = True
_load_tokenizer = True
default_input_names = None
def __init__(
@ -813,6 +845,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
tokenizer: Optional[PreTrainedTokenizer] = None,
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
image_processor: Optional[BaseImageProcessor] = None,
processor: Optional[ProcessorMixin] = None,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
task: str = "",
@ -830,6 +863,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
self.tokenizer = tokenizer
self.feature_extractor = feature_extractor
self.image_processor = image_processor
self.processor = processor
self.modelcard = modelcard
self.framework = framework
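A minimal sketch of a custom pipeline that relies only on the composite processor, using the class-level flags introduced above (hypothetical class, not part of this PR):

from transformers import Pipeline

class MyProcessorOnlyPipeline(Pipeline):
    # Hypothetical pipeline: only the composite `processor` is loaded by the factory;
    # the individual components are skipped.
    _load_processor = True
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = False

    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        # Assumption: the processor accepts the raw inputs directly.
        return self.processor(inputs, return_tensors="pt")

    def _forward(self, model_inputs, **kwargs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs, **kwargs):
        return model_outputs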

View File

@ -436,9 +436,16 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "FeatureExtractionPipelineTests":
if pipeline_test_case_name == "FeatureExtractionPipelineTests":
return True
return False

View File

@ -167,9 +167,16 @@ class ASTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "AudioClassificationPipelineTests":
if pipeline_test_case_name == "AudioClassificationPipelineTests":
return True
return False

View File

@ -276,9 +276,16 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -236,9 +236,16 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return pipeline_test_casse_name == "TextGenerationPipelineTests"
return pipeline_test_case_name == "TextGenerationPipelineTests"
def setUp(self):
self.model_tester = BlenderbotSmallModelTester(self)

View File

@ -321,9 +321,16 @@ class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, Flax
all_generative_model_classes = (FlaxBlenderbotSmallForConditionalGeneration,) if is_flax_available() else ()
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return pipeline_test_casse_name == "TextGenerationPipelineTests"
return pipeline_test_case_name == "TextGenerationPipelineTests"
def setUp(self):
self.model_tester = FlaxBlenderbotSmallModelTester(self)

View File

@ -198,9 +198,16 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
test_onnx = False
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return pipeline_test_casse_name == "TextGenerationPipelineTests"
return pipeline_test_case_name == "TextGenerationPipelineTests"
def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self)

View File

@ -295,7 +295,14 @@ class BrosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# BROS requires `bbox` in the inputs which doesn't fit into the above 2 pipelines' input formats.
# see https://github.com/huggingface/transformers/pull/26294
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -22,7 +22,14 @@ from transformers.testing_utils import custom_tokenizers
class CpmTokenizationTest(unittest.TestCase):
# There is no `CpmModel`
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -211,9 +211,16 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
# `CTRLConfig` was never used in pipeline tests, either because of a missing checkpoint or because a tiny
# config could not be created.

View File

@ -189,9 +189,16 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
# `CTRLConfig` was never used in pipeline tests, either because of a missing checkpoint or because a tiny
# config could not be created.

View File

@ -312,7 +312,14 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -392,10 +392,17 @@ class FlaubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -309,10 +309,17 @@ class TFFlaubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -298,9 +298,16 @@ class FNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -326,7 +326,14 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -382,10 +382,17 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -322,10 +322,17 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -260,9 +260,16 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# TODO: `image-to-text` pipeline for this model needs Processor.
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return pipeline_test_casse_name == "ImageToTextPipelineTests"
return pipeline_test_case_name == "ImageToTextPipelineTests"
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict)

View File

@ -292,7 +292,14 @@ class LayoutLMv3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
# `DocumentQuestionAnsweringPipeline` is expected to work with this model, but it combines the text and visual
# embedding along the sequence dimension (dim 1), which causes an error during post-processing as `p_mask` has

View File

@ -288,7 +288,14 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -302,9 +302,16 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -245,7 +245,14 @@ class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -331,10 +331,17 @@ class LongformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -303,10 +303,17 @@ class TFLongformerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -621,9 +621,16 @@ class LukeModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name in ["QAPipelineTests", "ZeroShotClassificationPipelineTests"]:
if pipeline_test_case_name in ["QAPipelineTests", "ZeroShotClassificationPipelineTests"]:
return True
return False

View File

@ -258,9 +258,16 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "TranslationPipelineTests":
if pipeline_test_case_name == "TranslationPipelineTests":
# Get `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
# `M2M100Config` was never used in pipeline tests: cannot create a simple tokenizer.
return True

View File

@ -301,7 +301,14 @@ class MarkupLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
# ValueError: Nodes must be of type `List[str]` (single pretokenized example), or `List[List[str]]`
# (batch of pretokenized examples).

View File

@ -252,9 +252,16 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -175,9 +175,16 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name != "FeatureExtractionPipelineTests":
if pipeline_test_case_name != "FeatureExtractionPipelineTests":
# Exception encountered when calling layer '...'
return True

View File

@ -313,7 +313,14 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -266,7 +266,14 @@ class TFMistralModelTest(TFModelTesterMixin, TFGenerationIntegrationTests, Pipel
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -313,7 +313,14 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -582,7 +582,14 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
def is_pipeline_test_to_skip(
self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if tokenizer_name is None:
return True
@ -1055,6 +1062,26 @@ class MT5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if tokenizer_name is None:
return True
# `MT5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False
@require_torch
@require_sentencepiece

View File

@ -441,10 +441,17 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -266,7 +266,14 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
# Saving the slow tokenizer after saving the fast tokenizer causes loading of the latter to hang forever.
return True

View File

@ -247,9 +247,16 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "FeatureExtractionPipelineTests":
if pipeline_test_case_name == "FeatureExtractionPipelineTests":
return True
return False

View File

@ -211,9 +211,16 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
# `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a
# tiny config could not be created.

View File

@ -217,9 +217,16 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Tes
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
# `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a
# tiny config could not be created.

View File

@ -221,10 +221,17 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -304,7 +304,14 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -351,7 +351,14 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -245,9 +245,16 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "TranslationPipelineTests":
if pipeline_test_case_name == "TranslationPipelineTests":
# Get `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
# `PLBartConfig` was never used in pipeline tests: cannot create a simple tokenizer.
return True

View File

@ -906,9 +906,16 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "TextGenerationPipelineTests":
if pipeline_test_case_name == "TextGenerationPipelineTests":
# Get `ValueError: AttributeError: 'NoneType' object has no attribute 'new_ones'` or `AssertionError`.
# `ProphetNetConfig` was never used in pipeline tests: cannot create a simple
# tokenizer.

View File

@ -322,7 +322,14 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -349,7 +349,14 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -305,7 +305,14 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -727,10 +727,17 @@ class ReformerLSHAttnModelTest(
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -587,9 +587,16 @@ class RoCBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name in [
if pipeline_test_case_name in [
"FillMaskPipelineTests",
"FeatureExtractionPipelineTests",
"TextClassificationPipelineTests",

View File

@ -275,9 +275,16 @@ class TFRoFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
# TODO: add `prepare_inputs_for_generation` for `TFRoFormerForCausalLM`
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "TextGenerationPipelineTests":
if pipeline_test_case_name == "TextGenerationPipelineTests":
return True
return False

View File

@ -299,7 +299,14 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -290,7 +290,14 @@ class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -225,11 +225,18 @@ class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# TODO: Fix the failed tests when this model gets more usage
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests":
if pipeline_test_case_name == "QAPipelineTests":
return True
elif pipeline_test_casse_name == "FeatureExtractionPipelineTests" and tokenizer_name.endswith("Fast"):
elif pipeline_test_case_name == "FeatureExtractionPipelineTests" and tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -311,7 +311,14 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -585,7 +585,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
def is_pipeline_test_to_skip(
self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if tokenizer_name is None:
return True
@ -1056,6 +1063,26 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if tokenizer_name is None:
return True
# `T5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False
def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])

View File

@ -492,7 +492,14 @@ class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -447,7 +447,14 @@ class TFTapasModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True

View File

@ -322,9 +322,16 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False
@ -723,6 +730,26 @@ class UMT5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if tokenizer_name is None:
return True
# `UMT5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False
@require_torch
@require_sentencepiece

View File

@ -397,9 +397,16 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name in [
if pipeline_test_case_name in [
"AutomaticSpeechRecognitionPipelineTests",
"AudioClassificationPipelineTests",
]:

View File

@ -312,10 +312,17 @@ class TFXLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -393,10 +393,17 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):

View File

@ -392,9 +392,16 @@ class XLMRobertaXLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -372,7 +372,14 @@ class TFXLNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
# Exception encountered when calling layer '...'
return True

View File

@ -543,9 +543,16 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -386,9 +386,16 @@ class XmodModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
return True
return False

View File

@ -37,9 +37,22 @@ class AudioClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
audio_classifier = AudioClassificationPipeline(
model=model, feature_extractor=processor, torch_dtype=torch_dtype
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
# test with a raw waveform

View File

@ -67,14 +67,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
+ (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
if tokenizer is None:
# Side effect of no Fast Tokenizer class for these models, so skipping
# But the slow tokenizer test should still run as they're quite small
self.skipTest(reason="No tokenizer available")
speech_recognizer = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
# test with a raw waveform

View File

@ -58,8 +58,23 @@ def hashimage(image: Image) -> str:
class DepthEstimationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
depth_estimator = DepthEstimationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
depth_estimator = DepthEstimationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return depth_estimator, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",

View File

@ -15,7 +15,7 @@
import unittest
from transformers import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, AutoTokenizer, is_vision_available
from transformers.pipelines import pipeline
from transformers.pipelines import DocumentQuestionAnsweringPipeline, pipeline
from transformers.pipelines.document_question_answering import apply_tesseract
from transformers.testing_utils import (
is_pipeline_test,
@ -61,12 +61,21 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase):
@require_pytesseract
@require_vision
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
dqa_pipeline = pipeline(
"document-question-answering",
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
dqa_pipeline = DocumentQuestionAnsweringPipeline(
model=model,
tokenizer=tokenizer,
image_processor=processor,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)

View File

@ -174,7 +174,15 @@ class FeatureExtractionPipelineTests(unittest.TestCase):
raise TypeError("We expect lists of floats, nothing else")
return shape
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
if tokenizer is None:
self.skipTest(reason="No tokenizer")
elif (
@ -193,10 +201,15 @@ class FeatureExtractionPipelineTests(unittest.TestCase):
For now ignore those.
"""
)
feature_extractor = FeatureExtractionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
feature_extractor_pipeline = FeatureExtractionPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return feature_extractor, ["This is a test", "This is another test"]
return feature_extractor_pipeline, ["This is a test", "This is another test"]
def run_pipeline_test(self, feature_extractor, examples):
outputs = feature_extractor("This is a test")

View File

@ -251,11 +251,26 @@ class FillMaskPipelineTests(unittest.TestCase):
unmasker.tokenizer.pad_token = None
self.run_pipeline_test(unmasker, [])
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
if tokenizer is None or tokenizer.mask_token_id is None:
self.skipTest(reason="The provided tokenizer has no mask token (probably reformer or wav2vec2)")
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
fill_masker = FillMaskPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
examples = [
f"This is another {tokenizer.mask_token} test",
]

View File

@ -58,9 +58,23 @@ class ImageClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
image_classifier = ImageClassificationPipeline(
model=model, image_processor=processor, top_k=2, torch_dtype=torch_dtype
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
top_k=2,
)
examples = [
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),

View File

@ -157,8 +157,16 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
outputs = feature_extractor(img, return_tensors=True)
self.assertTrue(tf.is_tensor(outputs))
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
if processor is None:
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
if image_processor is None:
self.skipTest(reason="No image processor")
elif type(model.config) in TOKENIZER_MAPPING:
@ -175,11 +183,16 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
"""
)
feature_extractor = ImageFeatureExtractionPipeline(
model=model, image_processor=processor, torch_dtype=torch_dtype
feature_extractor_pipeline = ImageFeatureExtractionPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
img = prepare_img()
return feature_extractor, [img, img]
return feature_extractor_pipeline, [img, img]
def run_pipeline_test(self, feature_extractor, examples):
imgs = examples

View File

@ -89,8 +89,23 @@ class ImageSegmentationPipelineTests(unittest.TestCase):
+ (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
image_segmenter = ImageSegmentationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
image_segmenter = ImageSegmentationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return image_segmenter, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",

View File

@ -47,9 +47,22 @@ class ImageToTextPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
pipe = ImageToTextPipeline(
model=model, tokenizer=tokenizer, image_processor=processor, torch_dtype=torch_dtype
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
examples = [
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),

View File

@ -67,8 +67,23 @@ class MaskGenerationPipelineTests(unittest.TestCase):
(list(TF_MODEL_FOR_MASK_GENERATION_MAPPING.items()) if TF_MODEL_FOR_MASK_GENERATION_MAPPING else [])
)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
image_segmenter = MaskGenerationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return image_segmenter, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",

View File

@ -56,8 +56,23 @@ else:
class ObjectDetectionPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
object_detector = ObjectDetectionPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
object_detector = ObjectDetectionPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
def run_pipeline_test(self, object_detector, examples):

View File

@ -50,12 +50,27 @@ class QAPipelineTests(unittest.TestCase):
config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
}
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
if isinstance(model.config, LxmertConfig):
# This is a bimodal model; we need to find a more consistent way
# to switch on those models.
return None, None
question_answerer = QuestionAnsweringPipeline(model, tokenizer, torch_dtype=torch_dtype)
question_answerer = QuestionAnsweringPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
examples = [
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},

View File

@ -32,8 +32,23 @@ class SummarizationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
summarizer = SummarizationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
summarizer = SummarizationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return summarizer, ["(CNN)The Palestinian Authority officially became", "Some other text"]
def run_pipeline_test(self, summarizer, _):

View File

@ -35,8 +35,23 @@ class Text2TextGenerationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
generator = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
generator = Text2TextGenerationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return generator, ["Something to write", "Something else"]
def run_pipeline_test(self, generator, _):

View File

@ -179,8 +179,23 @@ class TextClassificationPipelineTests(unittest.TestCase):
outputs = text_classifier("Birds are a type of animal")
self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
text_classifier = TextClassificationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return text_classifier, ["HuggingFace is in", "This is another test"]
def run_pipeline_test(self, text_classifier, _):

View File

@ -377,8 +377,23 @@ class TextGenerationPipelineTests(unittest.TestCase):
],
)
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
text_generator = TextGenerationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return text_generator, ["This is a test", "Another test"]
def test_stop_sequence_stopping_criteria(self):
@ -471,6 +486,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
"GPTNeoXForCausalLM",
"GPTNeoXJapaneseForCausalLM",
"FuyuForCausalLM",
"LlamaForCausalLM",
]
if (
tokenizer.model_max_length < 10000

View File

@ -250,8 +250,23 @@ class TextToAudioPipelineTests(unittest.TestCase):
outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)
self.assertListEqual(outputs["audio"].tolist(), audio.tolist())
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
speech_generator = TextToAudioPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return speech_generator, ["This is a test", "Another test"]
def run_pipeline_test(self, speech_generator, _):

View File

@ -61,8 +61,23 @@ class TokenClassificationPipelineTests(unittest.TestCase):
config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
}
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
token_classifier = TokenClassificationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return token_classifier, ["A simple string", "A simple string that is quite a bit longer"]
def run_pipeline_test(self, token_classifier, _):

View File

@ -35,14 +35,36 @@ class TranslationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
if isinstance(model.config, MBartConfig):
src_lang, tgt_lang = list(tokenizer.lang_code_to_id.keys())[:2]
translator = TranslationPipeline(
model=model, tokenizer=tokenizer, src_lang=src_lang, tgt_lang=tgt_lang, torch_dtype=torch_dtype
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
src_lang=src_lang,
tgt_lang=tgt_lang,
)
else:
translator = TranslationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
translator = TranslationPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
)
return translator, ["Some string", "Some other text"]
def run_pipeline_test(self, translator, _):

View File

@ -38,12 +38,26 @@ from .test_pipelines_common import ANY
class VideoClassificationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
example_video_filepath = hf_hub_download(
repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset"
)
video_classifier = VideoClassificationPipeline(
model=model, image_processor=processor, top_k=2, torch_dtype=torch_dtype
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
top_k=2,
)
examples = [
example_video_filepath,

View File

@ -55,9 +55,19 @@ else:
class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
vqa_pipeline = pipeline(
"visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa", torch_dtype=torch_dtype
"visual-question-answering",
model="hf-internal-testing/tiny-vilt-random-vqa",
torch_dtype=torch_dtype,
)
examples = [
{

View File

@ -53,9 +53,23 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
}
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
classifier = ZeroShotClassificationPipeline(
model=model, tokenizer=tokenizer, candidate_labels=["polics", "health"], torch_dtype=torch_dtype
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
image_processor=image_processor,
processor=processor,
torch_dtype=torch_dtype,
candidate_labels=["polics", "health"],
)
return classifier, ["Who are you voting for in 2020?", "My stomach hurts."]

View File

@ -43,7 +43,15 @@ else:
class ZeroShotObjectDetectionPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
def get_test_pipeline(
self,
model,
tokenizer=None,
image_processor=None,
feature_extractor=None,
processor=None,
torch_dtype="float32",
):
object_detector = pipeline(
"zero-shot-object-detection",
model="hf-internal-testing/tiny-random-owlvit-object-detection",

View File

@ -36,6 +36,7 @@ from huggingface_hub import (
ZeroShotImageClassificationInput,
)
from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
from transformers.pipelines import (
AudioClassificationPipeline,
AutomaticSpeechRecognitionPipeline,
@ -183,11 +184,14 @@ class PipelineTesterMixin:
model_architectures = self.pipeline_model_mapping[task]
if not isinstance(model_architectures, tuple):
model_architectures = (model_architectures,)
if not isinstance(model_architectures, tuple):
raise TypeError(f"`model_architectures` must be a tuple. Got {type(model_architectures)} instead.")
# We are going to run tests for multiple model architectures; some of them might be skipped.
# With this flag we control whether at least one model was tested or all were skipped.
at_least_one_model_is_tested = False
for model_architecture in model_architectures:
model_arch_name = model_architecture.__name__
model_type = model_architecture.config_class.model_type
# Get the canonical name
for _prefix in ["Flax", "TF"]:
@ -195,31 +199,72 @@ class PipelineTesterMixin:
model_arch_name = model_arch_name[len(_prefix) :]
break
tokenizer_names = []
processor_names = []
commit = None
if model_arch_name in tiny_model_summary:
if model_arch_name not in tiny_model_summary:
continue
tokenizer_names = tiny_model_summary[model_arch_name]["tokenizer_classes"]
processor_names = tiny_model_summary[model_arch_name]["processor_classes"]
if "sha" in tiny_model_summary[model_arch_name]:
# Sort image processors and feature extractors from tiny-models json file
image_processor_names = []
feature_extractor_names = []
processor_classes = tiny_model_summary[model_arch_name]["processor_classes"]
for cls_name in processor_classes:
if "ImageProcessor" in cls_name:
image_processor_names.append(cls_name)
elif "FeatureExtractor" in cls_name:
feature_extractor_names.append(cls_name)
else:
raise ValueError(f"Unknown processor class: {cls_name}")
# Processor classes are not in tiny models JSON file, so extract them from the mapping
# each model type is mapped to a single class name, e.g. "XxxProcessor"
processor_names = PROCESSOR_MAPPING_NAMES.get(model_type, None)
if not isinstance(processor_names, (list, tuple)):
processor_names = [processor_names]
commit = None
if model_arch_name in tiny_model_summary and "sha" in tiny_model_summary[model_arch_name]:
commit = tiny_model_summary[model_arch_name]["sha"]
# Adding `None` (if empty) so we can generate tests
tokenizer_names = [None] if len(tokenizer_names) == 0 else tokenizer_names
processor_names = [None] if len(processor_names) == 0 else processor_names
repo_name = f"tiny-random-{model_arch_name}"
if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
repo_name = model_arch_name
self.run_model_pipeline_tests(
task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype
task,
repo_name,
model_architecture,
tokenizer_names=tokenizer_names,
image_processor_names=image_processor_names,
feature_extractor_names=feature_extractor_names,
processor_names=processor_names,
commit=commit,
torch_dtype=torch_dtype,
)
if task in task_to_pipeline_and_spec_mapping:
pipeline, hub_spec = task_to_pipeline_and_spec_mapping[task]
compare_pipeline_args_to_hub_spec(pipeline, hub_spec)
at_least_one_model_is_tested = True
if not at_least_one_model_is_tested:
self.skipTest(
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find any "
f"model architecture in the tiny models JSON file for `{task}`."
)
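The splitting of `processor_classes` above assumes the tiny-models JSON file only lists tokenizer and image-processor/feature-extractor class names, while `ProcessorMixin` subclasses are resolved separately from `PROCESSOR_MAPPING_NAMES` keyed by `model_type`. A self-contained restatement of that splitting step, with real transformers class names used purely as illustration:

from typing import List, Tuple

def split_processor_classes(processor_classes: List[str]) -> Tuple[List[str], List[str]]:
    # Sort the class names from the tiny-models JSON into the two component kinds.
    image_processor_names, feature_extractor_names = [], []
    for cls_name in processor_classes:
        if "ImageProcessor" in cls_name:
            image_processor_names.append(cls_name)
        elif "FeatureExtractor" in cls_name:
            feature_extractor_names.append(cls_name)
        else:
            raise ValueError(f"Unknown processor class: {cls_name}")
    return image_processor_names, feature_extractor_names

print(split_processor_classes(["ViltImageProcessor", "Wav2Vec2FeatureExtractor"]))
# (['ViltImageProcessor'], ['Wav2Vec2FeatureExtractor'])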
def run_model_pipeline_tests(
self, task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype="float32"
self,
task,
repo_name,
model_architecture,
tokenizer_names,
image_processor_names,
feature_extractor_names,
processor_names,
commit,
torch_dtype="float32",
):
"""Run pipeline tests for a specific `task` with the give model class and tokenizer/processor class names
@ -232,8 +277,12 @@ class PipelineTesterMixin:
A subclass of `PreTrainedModel`.
tokenizer_names (`List[str]`):
A list of names of a subclasses of `PreTrainedTokenizerFast` or `PreTrainedTokenizer`.
image_processor_names (`List[str]`):
A list of names of subclasses of `BaseImageProcessor`.
feature_extractor_names (`List[str]`):
A list of names of subclasses of `FeatureExtractionMixin`.
processor_names (`List[str]`):
A list of names of subclasses of `BaseImageProcessor` or `FeatureExtractionMixin`.
A list of names of subclasses of `ProcessorMixin`.
commit (`str`):
The commit hash of the model repository on the Hub.
torch_dtype (`str`, `optional`, defaults to `'float32'`):
@ -243,27 +292,73 @@ class PipelineTesterMixin:
# `run_pipeline_test`.
pipeline_test_class_name = pipeline_test_mapping[task]["test"].__name__
for tokenizer_name in tokenizer_names:
for processor_name in processor_names:
if self.is_pipeline_test_to_skip(
# If no image processor or feature extractor is found, we still need to test the pipeline with None
# otherwise for any empty list we might skip all the tests
tokenizer_names = tokenizer_names or [None]
image_processor_names = image_processor_names or [None]
feature_extractor_names = feature_extractor_names or [None]
processor_names = processor_names or [None]
test_cases = [
{
"tokenizer_name": tokenizer_name,
"image_processor_name": image_processor_name,
"feature_extractor_name": feature_extractor_name,
"processor_name": processor_name,
}
for tokenizer_name in tokenizer_names
for image_processor_name in image_processor_names
for feature_extractor_name in feature_extractor_names
for processor_name in processor_names
]
for test_case in test_cases:
tokenizer_name = test_case["tokenizer_name"]
image_processor_name = test_case["image_processor_name"]
feature_extractor_name = test_case["feature_extractor_name"]
processor_name = test_case["processor_name"]
do_skip_test_case = self.is_pipeline_test_to_skip(
pipeline_test_class_name,
model_architecture.config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
)
if do_skip_test_case:
logger.warning(
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: test is "
f"currently known to fail for: model `{model_architecture.__name__}` | tokenizer "
f"`{tokenizer_name}` | processor `{processor_name}`."
f"`{tokenizer_name}` | image processor `{image_processor_name}` | feature extractor {feature_extractor_name}."
)
continue
self.run_pipeline_test(
task, repo_name, model_architecture, tokenizer_name, processor_name, commit, torch_dtype
task,
repo_name,
model_architecture,
tokenizer_name=tokenizer_name,
image_processor_name=image_processor_name,
feature_extractor_name=feature_extractor_name,
processor_name=processor_name,
commit=commit,
torch_dtype=torch_dtype,
)
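The nested comprehension that builds `test_cases` above is just a Cartesian product over the four component-name lists (each padded with `None` so nothing is silently skipped). An equivalent sketch with `itertools.product`, using made-up component names; the diff itself keeps the comprehension form:

from itertools import product

tokenizer_names = ["BertTokenizerFast"]  # illustrative values only
image_processor_names = [None]
feature_extractor_names = [None]
processor_names = [None]

test_cases = [
    {
        "tokenizer_name": tok,
        "image_processor_name": img,
        "feature_extractor_name": feat,
        "processor_name": proc,
    }
    for tok, img, feat, proc in product(
        tokenizer_names, image_processor_names, feature_extractor_names, processor_names
    )
]
print(len(test_cases))  # 1 combination for these lists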
def run_pipeline_test(
self, task, repo_name, model_architecture, tokenizer_name, processor_name, commit, torch_dtype="float32"
self,
task,
repo_name,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
commit,
torch_dtype="float32",
):
"""Run pipeline tests for a specific `task` with the give model class and tokenizer/processor class name
@ -278,43 +373,24 @@ class PipelineTesterMixin:
A subclass of `PreTrainedModel`.
tokenizer_name (`str`):
The name of a subclass of `PreTrainedTokenizerFast` or `PreTrainedTokenizer`.
image_processor_name (`str`):
The name of a subclass of `BaseImageProcessor`.
feature_extractor_name (`str`):
The name of a subclass of `FeatureExtractionMixin`.
processor_name (`str`):
The name of a subclass of `BaseImageProcessor` or `FeatureExtractionMixin`.
The name of a subclass of `ProcessorMixin`.
commit (`str`):
The commit hash of the model repository on the Hub.
torch_dtype (`str`, `optional`, defaults to `'float32'`):
The torch dtype to use for the model. Can be used for FP16/other precision inference.
"""
repo_id = f"{TRANSFORMERS_TINY_MODEL_PATH}/{repo_name}"
if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
model_type = model_architecture.config_class.model_type
if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
repo_id = os.path.join(TRANSFORMERS_TINY_MODEL_PATH, model_type, repo_name)
tokenizer = None
if tokenizer_name is not None:
tokenizer_class = getattr(transformers_module, tokenizer_name)
tokenizer = tokenizer_class.from_pretrained(repo_id, revision=commit)
processor = None
if processor_name is not None:
processor_class = getattr(transformers_module, processor_name)
# If the required packages (like `Pillow` or `torchaudio`) are not installed, this will fail.
try:
processor = processor_class.from_pretrained(repo_id, revision=commit)
except Exception:
logger.warning(
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not load the "
f"processor from `{repo_id}` with `{processor_name}`."
)
self.skipTest(f"Could not load the processor from {repo_id} with {processor_name}.")
# TODO: Maybe not upload such problematic tiny models to Hub.
if tokenizer is None and processor is None:
logger.warning(
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find or load "
f"any tokenizer / processor from `{repo_id}`."
)
self.skipTest(f"Could not find or load any tokenizer / processor from {repo_id}.")
# -------------------- Load model --------------------
# TODO: We should check if a model file is on the Hub repo. instead.
try:
@ -326,19 +402,57 @@ class PipelineTesterMixin:
)
self.skipTest(f"Could not find or load the model from {repo_id} with {model_architecture}.")
# -------------------- Load tokenizer --------------------
tokenizer = None
if tokenizer_name is not None:
tokenizer_class = getattr(transformers_module, tokenizer_name)
tokenizer = tokenizer_class.from_pretrained(repo_id, revision=commit)
# -------------------- Load processors --------------------
processors = {}
for key, name in zip(
["image_processor", "feature_extractor", "processor"],
[image_processor_name, feature_extractor_name, processor_name],
):
if name is not None:
try:
# Can fail if some extra dependencies are not installed
processor_class = getattr(transformers_module, name)
processor = processor_class.from_pretrained(repo_id, revision=commit)
processors[key] = processor
except Exception:
logger.warning(
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: "
f"Could not load the {key} from `{repo_id}` with `{name}`."
)
self.skipTest(f"Could not load the {key} from {repo_id} with {name}.")
# ---------------------------------------------------------
# TODO: Maybe not upload such problematic tiny models to Hub.
if tokenizer is None and "image_processor" not in processors and "feature_extractor" not in processors:
logger.warning(
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find or load "
f"any tokenizer / image processor / feature extractor from `{repo_id}`."
)
self.skipTest(f"Could not find or load any tokenizer / processor from {repo_id}.")
pipeline_test_class_name = pipeline_test_mapping[task]["test"].__name__
if self.is_pipeline_test_to_skip_more(pipeline_test_class_name, model.config, model, tokenizer, processor):
if self.is_pipeline_test_to_skip_more(pipeline_test_class_name, model.config, model, tokenizer, **processors):
logger.warning(
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: test is "
f"currently known to fail for: model `{model_architecture.__name__}` | tokenizer "
f"`{tokenizer_name}` | processor `{processor_name}`."
f"`{tokenizer_name}` | image processor `{image_processor_name}` | feature extractor `{feature_extractor_name}`."
)
self.skipTest(
f"Test is known to fail for: model `{model_architecture.__name__}` | tokenizer `{tokenizer_name}` | processor `{processor_name}`."
f"Test is known to fail for: model `{model_architecture.__name__}` | tokenizer `{tokenizer_name}` "
f"| image processor `{image_processor_name}` | feature extractor `{feature_extractor_name}`."
)
# validate
validate_test_components(self, task, model, tokenizer, processor)
validate_test_components(model, tokenizer)
if hasattr(model, "eval"):
model = model.eval()
@ -347,7 +461,7 @@ class PipelineTesterMixin:
# `run_pipeline_test`.
task_test = pipeline_test_mapping[task]["test"]()
pipeline, examples = task_test.get_test_pipeline(model, tokenizer, processor, torch_dtype=torch_dtype)
pipeline, examples = task_test.get_test_pipeline(model, tokenizer, **processors, torch_dtype=torch_dtype)
if pipeline is None:
# The test can disable itself, but it should be very marginal
# Concerns: Wav2Vec2ForCTC without tokenizer test (a FastTokenizer doesn't exist)
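Each loaded component is stored in `processors` under the exact keyword it will be passed as, so `**processors` expands straight into the new keyword-based `get_test_pipeline` signature and any missing component falls back to its `None` default. A small sketch of that call shape, with an illustrative dict and a stub function:

# Illustrative only: how the `processors` dict lines up with the keyword-based signature.
processors = {"image_processor": "<loaded image processor>"}

def get_test_pipeline(model, tokenizer=None, image_processor=None,
                      feature_extractor=None, processor=None, torch_dtype="float32"):
    return (model, tokenizer, image_processor, feature_extractor, processor, torch_dtype)

print(get_test_pipeline("model", "tokenizer", **processors, torch_dtype="float16"))
# ('model', 'tokenizer', '<loaded image processor>', None, None, 'float16')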
@ -674,7 +788,14 @@ class PipelineTesterMixin:
# This contains the test cases to be skipped without model architecture being involved.
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
"""Skip some tests based on the classes or their names without the instantiated objects.
@ -682,7 +803,7 @@ class PipelineTesterMixin:
"""
# No fix is required for this case.
if (
pipeline_test_casse_name == "DocumentQuestionAnsweringPipelineTests"
pipeline_test_case_name == "DocumentQuestionAnsweringPipelineTests"
and tokenizer_name is not None
and not tokenizer_name.endswith("Fast")
):
@ -691,11 +812,20 @@ class PipelineTesterMixin:
return False
def is_pipeline_test_to_skip_more(self, pipeline_test_casse_name, config, model, tokenizer, processor): # noqa
def is_pipeline_test_to_skip_more(
self,
pipeline_test_case_name,
config,
model,
tokenizer,
image_processor=None,
feature_extractor=None,
processor=None,
): # noqa
"""Skip some more tests based on the information from the instantiated objects."""
# No fix is required for this case.
if (
pipeline_test_casse_name == "QAPipelineTests"
pipeline_test_case_name == "QAPipelineTests"
and tokenizer is not None
and getattr(tokenizer, "pad_token", None) is None
and not tokenizer.__class__.__name__.endswith("Fast")
@ -706,7 +836,7 @@ class PipelineTesterMixin:
return False
def validate_test_components(test_case, task, model, tokenizer, processor):
def validate_test_components(model, tokenizer):
# TODO: Move this to tiny model creation script
# head-specific (within a model type) necessary changes to the config
# 1. for `BlenderbotForCausalLM`
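Model-specific test classes that override these hooks have to adopt the widened signatures above. A hedged sketch of such an override; the class name and the skip condition are hypothetical, only the parameter list mirrors the diff:

class SomeModelPipelineTestOverrides:
    # Hypothetical override in a model's test file, matching the new hook signature.
    def is_pipeline_test_to_skip(
        self,
        pipeline_test_case_name,
        config_class,
        model_architecture,
        tokenizer_name,
        image_processor_name,
        feature_extractor_name,
        processor_name,
    ):
        # e.g. skip fill-mask tests when only a slow tokenizer is available
        if pipeline_test_case_name == "FillMaskPipelineTests" and tokenizer_name is not None:
            return not tokenizer_name.endswith("Fast")
        return False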