mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Make pipeline
able to load processor
(#32514)
* Refactor get_test_pipeline * Fixup * Fixing tests * Add processor loading in tests * Restructure processors loading * Add processor to the pipeline * Move model loading on tom of the test * Update `get_test_pipeline` * Fixup * Add class-based flags for loading processors * Change `is_pipeline_test_to_skip` signature * Skip t5 failing test for slow tokenizer * Fixup * Fix copies for T5 * Fix typo * Add try/except for tokenizer loading (kosmos-2 case) * Fixup * Llama not fails for long generation * Revert processor pass in text-generation test * Fix docs * Switch back to json file for image processors and feature extractors * Add processor type check * Remove except for tokenizers * Fix docstring * Fix empty lists for tests * Fixup * Fix load check * Ensure we have non-empty test cases * Update src/transformers/pipelines/__init__.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Update src/transformers/pipelines/base.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Rework comment * Better docs, add note about pipeline components * Change warning to error raise * Fixup * Refine pipeline docs --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
parent
4fb28703ad
commit
48461c0fe2
@ -28,7 +28,9 @@ from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor
|
||||
from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage
|
||||
from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor
|
||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from ..processing_utils import ProcessorMixin
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
@ -556,6 +558,7 @@ def pipeline(
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None,
|
||||
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
|
||||
image_processor: Optional[Union[str, BaseImageProcessor]] = None,
|
||||
processor: Optional[Union[str, ProcessorMixin]] = None,
|
||||
framework: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
use_fast: bool = True,
|
||||
@ -571,11 +574,19 @@ def pipeline(
|
||||
"""
|
||||
Utility factory method to build a [`Pipeline`].
|
||||
|
||||
Pipelines are made of:
|
||||
A pipeline consists of:
|
||||
|
||||
- A [tokenizer](tokenizer) in charge of mapping raw textual input to token.
|
||||
- A [model](model) to make predictions from the inputs.
|
||||
- Some (optional) post processing for enhancing model's output.
|
||||
- One or more components for pre-processing model inputs, such as a [tokenizer](tokenizer),
|
||||
[image_processor](image_processor), [feature_extractor](feature_extractor), or [processor](processors).
|
||||
- A [model](model) that generates predictions from the inputs.
|
||||
- Optional post-processing steps to refine the model's output, which can also be handled by processors.
|
||||
|
||||
<Tip>
|
||||
While there are such optional arguments as `tokenizer`, `feature_extractor`, `image_processor`, and `processor`,
|
||||
they shouldn't be specified all at once. If these components are not provided, `pipeline` will try to load
|
||||
required ones automatically. In case you want to provide these components explicitly, please refer to a
|
||||
specific pipeline in order to get more details regarding what components are required.
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
task (`str`):
|
||||
@ -644,6 +655,25 @@ def pipeline(
|
||||
`model` is not specified or not a string, then the default feature extractor for `config` is loaded (if it
|
||||
is a string). However, if `config` is also not given or not a string, then the default feature extractor
|
||||
for the given `task` will be loaded.
|
||||
image_processor (`str` or [`BaseImageProcessor`], *optional*):
|
||||
The image processor that will be used by the pipeline to preprocess images for the model. This can be a
|
||||
model identifier or an actual image processor inheriting from [`BaseImageProcessor`].
|
||||
|
||||
Image processors are used for Vision models and multi-modal models that require image inputs. Multi-modal
|
||||
models will also require a tokenizer to be passed.
|
||||
|
||||
If not provided, the default image processor for the given `model` will be loaded (if it is a string). If
|
||||
`model` is not specified or not a string, then the default image processor for `config` is loaded (if it is
|
||||
a string).
|
||||
processor (`str` or [`ProcessorMixin`], *optional*):
|
||||
The processor that will be used by the pipeline to preprocess data for the model. This can be a model
|
||||
identifier or an actual processor inheriting from [`ProcessorMixin`].
|
||||
|
||||
Processors are used for multi-modal models that require multi-modal inputs, for example, a model that
|
||||
requires both text and image inputs.
|
||||
|
||||
If not provided, the default processor for the given `model` will be loaded (if it is a string). If `model`
|
||||
is not specified or not a string, then the default processor for `config` is loaded (if it is a string).
|
||||
framework (`str`, *optional*):
|
||||
The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
@ -905,13 +935,17 @@ def pipeline(
|
||||
|
||||
model_config = model.config
|
||||
hub_kwargs["_commit_hash"] = model.config._commit_hash
|
||||
load_tokenizer = (
|
||||
type(model_config) in TOKENIZER_MAPPING
|
||||
or model_config.tokenizer_class is not None
|
||||
or isinstance(tokenizer, str)
|
||||
)
|
||||
|
||||
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
||||
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
||||
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
|
||||
load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None
|
||||
|
||||
# Check that pipeline class required loading
|
||||
load_tokenizer = load_tokenizer and pipeline_class._load_tokenizer
|
||||
load_feature_extractor = load_feature_extractor and pipeline_class._load_feature_extractor
|
||||
load_image_processor = load_image_processor and pipeline_class._load_image_processor
|
||||
load_processor = load_processor and pipeline_class._load_processor
|
||||
|
||||
# If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while
|
||||
# `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some
|
||||
@ -1074,6 +1108,31 @@ def pipeline(
|
||||
if not is_pyctcdecode_available():
|
||||
logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode")
|
||||
|
||||
if load_processor:
|
||||
# Try to infer processor from model or config name (if provided as str)
|
||||
if processor is None:
|
||||
if isinstance(model_name, str):
|
||||
processor = model_name
|
||||
elif isinstance(config, str):
|
||||
processor = config
|
||||
else:
|
||||
# Impossible to guess what is the right processor here
|
||||
raise Exception(
|
||||
"Impossible to guess which processor to use. "
|
||||
"Please provide a processor instance or a path/identifier "
|
||||
"to a processor."
|
||||
)
|
||||
|
||||
# Instantiate processor if needed
|
||||
if isinstance(processor, (str, tuple)):
|
||||
processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
||||
if not isinstance(processor, ProcessorMixin):
|
||||
raise TypeError(
|
||||
"Processor was loaded, but it is not an instance of `ProcessorMixin`. "
|
||||
f"Got type `{type(processor)}` instead. Please check that you specified "
|
||||
"correct pipeline task for the model and model has processor implemented and saved."
|
||||
)
|
||||
|
||||
if task == "translation" and model.config.task_specific_params:
|
||||
for key in model.config.task_specific_params:
|
||||
if key.startswith("translation"):
|
||||
@ -1099,4 +1158,7 @@ def pipeline(
|
||||
if device is not None:
|
||||
kwargs["device"] = device
|
||||
|
||||
if processor is not None:
|
||||
kwargs["processor"] = processor
|
||||
|
||||
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
|
||||
|
@ -34,6 +34,7 @@ from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
from ..image_processing_utils import BaseImageProcessor
|
||||
from ..modelcard import ModelCard
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..processing_utils import ProcessorMixin
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import (
|
||||
ModelOutput,
|
||||
@ -716,6 +717,7 @@ def build_pipeline_init_args(
|
||||
has_tokenizer: bool = False,
|
||||
has_feature_extractor: bool = False,
|
||||
has_image_processor: bool = False,
|
||||
has_processor: bool = False,
|
||||
supports_binary_output: bool = True,
|
||||
) -> str:
|
||||
docstring = r"""
|
||||
@ -738,6 +740,12 @@ def build_pipeline_init_args(
|
||||
image_processor ([`BaseImageProcessor`]):
|
||||
The image processor that will be used by the pipeline to encode data for the model. This object inherits from
|
||||
[`BaseImageProcessor`]."""
|
||||
if has_processor:
|
||||
docstring += r"""
|
||||
processor ([`ProcessorMixin`]):
|
||||
The processor that will be used by the pipeline to encode data for the model. This object inherits from
|
||||
[`ProcessorMixin`]. Processor is a composite object that might contain `tokenizer`, `feature_extractor`, and
|
||||
`image_processor`."""
|
||||
docstring += r"""
|
||||
modelcard (`str` or [`ModelCard`], *optional*):
|
||||
Model card attributed to the model for this pipeline.
|
||||
@ -774,7 +782,11 @@ def build_pipeline_init_args(
|
||||
|
||||
|
||||
PIPELINE_INIT_ARGS = build_pipeline_init_args(
|
||||
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True
|
||||
has_tokenizer=True,
|
||||
has_feature_extractor=True,
|
||||
has_image_processor=True,
|
||||
has_processor=True,
|
||||
supports_binary_output=True,
|
||||
)
|
||||
|
||||
|
||||
@ -787,7 +799,11 @@ if is_torch_available():
|
||||
)
|
||||
|
||||
|
||||
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_feature_extractor=True, has_image_processor=True))
|
||||
@add_end_docstrings(
|
||||
build_pipeline_init_args(
|
||||
has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True
|
||||
)
|
||||
)
|
||||
class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
"""
|
||||
The Pipeline class is the class from which all pipelines inherit. Refer to this class for methods shared across
|
||||
@ -805,6 +821,22 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
constructor argument. If set to `True`, the output will be stored in the pickle format.
|
||||
"""
|
||||
|
||||
# Historically we have pipelines working with `tokenizer`, `feature_extractor`, and `image_processor`
|
||||
# as separate processing components. While we have `processor` class that combines them, some pipelines
|
||||
# might still operate with these components separately.
|
||||
# With the addition of `processor` to `pipeline`, we want to avoid:
|
||||
# - loading `processor` for pipelines that still work with `image_processor` and `tokenizer` separately;
|
||||
# - loading `image_processor`/`tokenizer` as a separate component while we operate only with `processor`,
|
||||
# because `processor` will load required sub-components by itself.
|
||||
# Below flags allow granular control over loading components and set to be backward compatible with current
|
||||
# pipelines logic. You may override these flags when creating your pipeline. For example, for
|
||||
# `zero-shot-object-detection` pipeline which operates with `processor` you should set `_load_processor=True`
|
||||
# and all the rest flags to `False` to avoid unnecessary loading of the components.
|
||||
_load_processor = False
|
||||
_load_image_processor = True
|
||||
_load_feature_extractor = True
|
||||
_load_tokenizer = True
|
||||
|
||||
default_input_names = None
|
||||
|
||||
def __init__(
|
||||
@ -813,6 +845,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
tokenizer: Optional[PreTrainedTokenizer] = None,
|
||||
feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
|
||||
image_processor: Optional[BaseImageProcessor] = None,
|
||||
processor: Optional[ProcessorMixin] = None,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
task: str = "",
|
||||
@ -830,6 +863,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
||||
self.tokenizer = tokenizer
|
||||
self.feature_extractor = feature_extractor
|
||||
self.image_processor = image_processor
|
||||
self.processor = processor
|
||||
self.modelcard = modelcard
|
||||
self.framework = framework
|
||||
|
||||
|
@ -436,9 +436,16 @@ class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "FeatureExtractionPipelineTests":
|
||||
if pipeline_test_case_name == "FeatureExtractionPipelineTests":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -167,9 +167,16 @@ class ASTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "AudioClassificationPipelineTests":
|
||||
if pipeline_test_case_name == "AudioClassificationPipelineTests":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -276,9 +276,16 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -236,9 +236,16 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return pipeline_test_casse_name == "TextGenerationPipelineTests"
|
||||
return pipeline_test_case_name == "TextGenerationPipelineTests"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BlenderbotSmallModelTester(self)
|
||||
|
@ -321,9 +321,16 @@ class FlaxBlenderbotSmallModelTest(FlaxModelTesterMixin, unittest.TestCase, Flax
|
||||
all_generative_model_classes = (FlaxBlenderbotSmallForConditionalGeneration,) if is_flax_available() else ()
|
||||
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return pipeline_test_casse_name == "TextGenerationPipelineTests"
|
||||
return pipeline_test_case_name == "TextGenerationPipelineTests"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxBlenderbotSmallModelTester(self)
|
||||
|
@ -198,9 +198,16 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, PipelineTesterMixin, unitte
|
||||
test_onnx = False
|
||||
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return pipeline_test_casse_name == "TextGenerationPipelineTests"
|
||||
return pipeline_test_case_name == "TextGenerationPipelineTests"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||
|
@ -295,7 +295,14 @@ class BrosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
# BROS requires `bbox` in the inputs which doesn't fit into the above 2 pipelines' input formats.
|
||||
# see https://github.com/huggingface/transformers/pull/26294
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -22,7 +22,14 @@ from transformers.testing_utils import custom_tokenizers
|
||||
class CpmTokenizationTest(unittest.TestCase):
|
||||
# There is no `CpmModel`
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -211,9 +211,16 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
|
||||
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
|
||||
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
|
||||
# `CTRLConfig` was never used in pipeline tests, either because of a missing checkpoint or because a tiny
|
||||
# config could not be created.
|
||||
|
@ -189,9 +189,16 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
|
||||
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
|
||||
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
|
||||
# `CTRLConfig` was never used in pipeline tests, either because of a missing checkpoint or because a tiny
|
||||
# config could not be created.
|
||||
|
@ -312,7 +312,14 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -392,10 +392,17 @@ class FlaubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -309,10 +309,17 @@ class TFFlaubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -298,9 +298,16 @@ class FNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -326,7 +326,14 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -382,10 +382,17 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -322,10 +322,17 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -260,9 +260,16 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
# TODO: `image-to-text` pipeline for this model needs Processor.
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return pipeline_test_casse_name == "ImageToTextPipelineTests"
|
||||
return pipeline_test_case_name == "ImageToTextPipelineTests"
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
@ -292,7 +292,14 @@ class LayoutLMv3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
# `DocumentQuestionAnsweringPipeline` is expected to work with this model, but it combines the text and visual
|
||||
# embedding along the sequence dimension (dim 1), which causes an error during post-processing as `p_mask` has
|
||||
|
@ -288,7 +288,14 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -302,9 +302,16 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -245,7 +245,14 @@ class LiltModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -331,10 +331,17 @@ class LongformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -303,10 +303,17 @@ class TFLongformerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -621,9 +621,16 @@ class LukeModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name in ["QAPipelineTests", "ZeroShotClassificationPipelineTests"]:
|
||||
if pipeline_test_case_name in ["QAPipelineTests", "ZeroShotClassificationPipelineTests"]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -258,9 +258,16 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "TranslationPipelineTests":
|
||||
if pipeline_test_case_name == "TranslationPipelineTests":
|
||||
# Get `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
|
||||
# `M2M100Config` was never used in pipeline tests: cannot create a simple tokenizer.
|
||||
return True
|
||||
|
@ -301,7 +301,14 @@ class MarkupLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
# ValueError: Nodes must be of type `List[str]` (single pretokenized example), or `List[List[str]]`
|
||||
# (batch of pretokenized examples).
|
||||
|
@ -252,9 +252,16 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -175,9 +175,16 @@ class TFMBartModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name != "FeatureExtractionPipelineTests":
|
||||
if pipeline_test_case_name != "FeatureExtractionPipelineTests":
|
||||
# Exception encountered when calling layer '...'
|
||||
return True
|
||||
|
||||
|
@ -313,7 +313,14 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -266,7 +266,14 @@ class TFMistralModelTest(TFModelTesterMixin, TFGenerationIntegrationTests, Pipel
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -313,7 +313,14 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -582,7 +582,14 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
|
||||
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if tokenizer_name is None:
|
||||
return True
|
||||
@ -1055,6 +1062,26 @@ class MT5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
|
||||
|
||||
def is_pipeline_test_to_skip(
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if tokenizer_name is None:
|
||||
return True
|
||||
|
||||
# `MT5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
|
||||
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
|
||||
if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
@ -441,10 +441,17 @@ class MvpModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -266,7 +266,14 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
# Saving the slow tokenizer after saving the fast tokenizer causes the loading of the later hanging forever.
|
||||
return True
|
||||
|
@ -247,9 +247,16 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "FeatureExtractionPipelineTests":
|
||||
if pipeline_test_case_name == "FeatureExtractionPipelineTests":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -211,9 +211,16 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
|
||||
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
|
||||
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
|
||||
# `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a
|
||||
# tiny config could not be created.
|
||||
|
@ -217,9 +217,16 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Tes
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "ZeroShotClassificationPipelineTests":
|
||||
if pipeline_test_case_name == "ZeroShotClassificationPipelineTests":
|
||||
# Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
|
||||
# `OpenAIGPTConfig` was never used in pipeline tests, either because of a missing checkpoint or because a
|
||||
# tiny config could not be created.
|
||||
|
@ -221,10 +221,17 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -304,7 +304,14 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -351,7 +351,14 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -245,9 +245,16 @@ class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "TranslationPipelineTests":
|
||||
if pipeline_test_case_name == "TranslationPipelineTests":
|
||||
# Get `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
|
||||
# `PLBartConfig` was never used in pipeline tests: cannot create a simple tokenizer.
|
||||
return True
|
||||
|
@ -906,9 +906,16 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "TextGenerationPipelineTests":
|
||||
if pipeline_test_case_name == "TextGenerationPipelineTests":
|
||||
# Get `ValueError: AttributeError: 'NoneType' object has no attribute 'new_ones'` or `AssertionError`.
|
||||
# `ProphetNetConfig` was never used in pipeline tests: cannot create a simple
|
||||
# tokenizer.
|
||||
|
@ -322,7 +322,14 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -349,7 +349,14 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -305,7 +305,14 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -727,10 +727,17 @@ class ReformerLSHAttnModelTest(
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -587,9 +587,16 @@ class RoCBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name in [
|
||||
if pipeline_test_case_name in [
|
||||
"FillMaskPipelineTests",
|
||||
"FeatureExtractionPipelineTests",
|
||||
"TextClassificationPipelineTests",
|
||||
|
@ -275,9 +275,16 @@ class TFRoFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
|
||||
# TODO: add `prepare_inputs_for_generation` for `TFRoFormerForCausalLM`
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "TextGenerationPipelineTests":
|
||||
if pipeline_test_case_name == "TextGenerationPipelineTests":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -299,7 +299,14 @@ class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -290,7 +290,14 @@ class TFSamModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
# TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -225,11 +225,18 @@ class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
|
||||
# TODO: Fix the failed tests when this model gets more usage
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests":
|
||||
if pipeline_test_case_name == "QAPipelineTests":
|
||||
return True
|
||||
elif pipeline_test_casse_name == "FeatureExtractionPipelineTests" and tokenizer_name.endswith("Fast"):
|
||||
elif pipeline_test_case_name == "FeatureExtractionPipelineTests" and tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -311,7 +311,14 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
|
||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -585,7 +585,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
|
||||
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if tokenizer_name is None:
|
||||
return True
|
||||
@ -1056,6 +1063,26 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
|
||||
|
||||
def is_pipeline_test_to_skip(
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if tokenizer_name is None:
|
||||
return True
|
||||
|
||||
# `T5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
|
||||
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
|
||||
if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def use_task_specific_params(model, task):
|
||||
model.config.update(model.config.task_specific_params[task])
|
||||
|
@ -492,7 +492,14 @@ class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -447,7 +447,14 @@ class TFTapasModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -322,9 +322,16 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
|
||||
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -723,6 +730,26 @@ class UMT5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
|
||||
|
||||
def is_pipeline_test_to_skip(
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if tokenizer_name is None:
|
||||
return True
|
||||
|
||||
# `UMT5EncoderOnlyModelTest` is not working well with slow tokenizers (for some models) and we don't want to touch the file
|
||||
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
|
||||
if pipeline_test_case_name == "TokenClassificationPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
@ -397,9 +397,16 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name in [
|
||||
if pipeline_test_case_name in [
|
||||
"AutomaticSpeechRecognitionPipelineTests",
|
||||
"AudioClassificationPipelineTests",
|
||||
]:
|
||||
|
@ -312,10 +312,17 @@ class TFXLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -393,10 +393,17 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
|
@ -392,9 +392,16 @@ class XLMRobertaXLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTes
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -372,7 +372,14 @@ class TFXLNetModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
# Exception encountered when calling layer '...'
|
||||
return True
|
||||
|
@ -543,9 +543,16 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -386,9 +386,16 @@ class XmodModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -37,9 +37,22 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
audio_classifier = AudioClassificationPipeline(
|
||||
model=model, feature_extractor=processor, torch_dtype=torch_dtype
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
# test with a raw waveform
|
||||
|
@ -67,14 +67,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
+ (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
|
||||
)
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
if tokenizer is None:
|
||||
# Side effect of no Fast Tokenizer class for these model, so skipping
|
||||
# But the slow tokenizer test should still run as they're quite small
|
||||
self.skipTest(reason="No tokenizer available")
|
||||
|
||||
speech_recognizer = AutomaticSpeechRecognitionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
# test with a raw waveform
|
||||
|
@ -58,8 +58,23 @@ def hashimage(image: Image) -> str:
|
||||
class DepthEstimationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
depth_estimator = DepthEstimationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
depth_estimator = DepthEstimationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return depth_estimator, [
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
|
@ -15,7 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, AutoTokenizer, is_vision_available
|
||||
from transformers.pipelines import pipeline
|
||||
from transformers.pipelines import DocumentQuestionAnsweringPipeline, pipeline
|
||||
from transformers.pipelines.document_question_answering import apply_tesseract
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
@ -61,12 +61,21 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase):
|
||||
|
||||
@require_pytesseract
|
||||
@require_vision
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
dqa_pipeline = pipeline(
|
||||
"document-question-answering",
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
dqa_pipeline = DocumentQuestionAnsweringPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
image_processor=processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
|
@ -174,7 +174,15 @@ class FeatureExtractionPipelineTests(unittest.TestCase):
|
||||
raise TypeError("We expect lists of floats, nothing else")
|
||||
return shape
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
if tokenizer is None:
|
||||
self.skipTest(reason="No tokenizer")
|
||||
elif (
|
||||
@ -193,10 +201,15 @@ class FeatureExtractionPipelineTests(unittest.TestCase):
|
||||
For now ignore those.
|
||||
"""
|
||||
)
|
||||
feature_extractor = FeatureExtractionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=processor, torch_dtype=torch_dtype
|
||||
feature_extractor_pipeline = FeatureExtractionPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return feature_extractor, ["This is a test", "This is another test"]
|
||||
return feature_extractor_pipeline, ["This is a test", "This is another test"]
|
||||
|
||||
def run_pipeline_test(self, feature_extractor, examples):
|
||||
outputs = feature_extractor("This is a test")
|
||||
|
@ -251,11 +251,26 @@ class FillMaskPipelineTests(unittest.TestCase):
|
||||
unmasker.tokenizer.pad_token = None
|
||||
self.run_pipeline_test(unmasker, [])
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
if tokenizer is None or tokenizer.mask_token_id is None:
|
||||
self.skipTest(reason="The provided tokenizer has no mask token, (probably reformer or wav2vec2)")
|
||||
|
||||
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
fill_masker = FillMaskPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
examples = [
|
||||
f"This is another {tokenizer.mask_token} test",
|
||||
]
|
||||
|
@ -58,9 +58,23 @@ class ImageClassificationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
image_classifier = ImageClassificationPipeline(
|
||||
model=model, image_processor=processor, top_k=2, torch_dtype=torch_dtype
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
top_k=2,
|
||||
)
|
||||
examples = [
|
||||
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
||||
|
@ -157,8 +157,16 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||
outputs = feature_extractor(img, return_tensors=True)
|
||||
self.assertTrue(tf.is_tensor(outputs))
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
if processor is None:
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
if image_processor is None:
|
||||
self.skipTest(reason="No image processor")
|
||||
|
||||
elif type(model.config) in TOKENIZER_MAPPING:
|
||||
@ -175,11 +183,16 @@ class ImageFeatureExtractionPipelineTests(unittest.TestCase):
|
||||
"""
|
||||
)
|
||||
|
||||
feature_extractor = ImageFeatureExtractionPipeline(
|
||||
model=model, image_processor=processor, torch_dtype=torch_dtype
|
||||
feature_extractor_pipeline = ImageFeatureExtractionPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
img = prepare_img()
|
||||
return feature_extractor, [img, img]
|
||||
return feature_extractor_pipeline, [img, img]
|
||||
|
||||
def run_pipeline_test(self, feature_extractor, examples):
|
||||
imgs = examples
|
||||
|
@ -89,8 +89,23 @@ class ImageSegmentationPipelineTests(unittest.TestCase):
|
||||
+ (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
|
||||
)
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
image_segmenter = ImageSegmentationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
image_segmenter = ImageSegmentationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return image_segmenter, [
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
|
@ -47,9 +47,22 @@ class ImageToTextPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
pipe = ImageToTextPipeline(
|
||||
model=model, tokenizer=tokenizer, image_processor=processor, torch_dtype=torch_dtype
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
examples = [
|
||||
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
|
||||
|
@ -67,8 +67,23 @@ class MaskGenerationPipelineTests(unittest.TestCase):
|
||||
(list(TF_MODEL_FOR_MASK_GENERATION_MAPPING.items()) if TF_MODEL_FOR_MASK_GENERATION_MAPPING else [])
|
||||
)
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
image_segmenter = MaskGenerationPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
image_segmenter = MaskGenerationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return image_segmenter, [
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||
|
@ -56,8 +56,23 @@ else:
|
||||
class ObjectDetectionPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
object_detector = ObjectDetectionPipeline(model=model, image_processor=processor, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
object_detector = ObjectDetectionPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
|
||||
|
||||
def run_pipeline_test(self, object_detector, examples):
|
||||
|
@ -50,12 +50,27 @@ class QAPipelineTests(unittest.TestCase):
|
||||
config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
|
||||
}
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
if isinstance(model.config, LxmertConfig):
|
||||
# This is an bimodal model, we need to find a more consistent way
|
||||
# to switch on those models.
|
||||
return None, None
|
||||
question_answerer = QuestionAnsweringPipeline(model, tokenizer, torch_dtype=torch_dtype)
|
||||
question_answerer = QuestionAnsweringPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
examples = [
|
||||
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
|
||||
|
@ -32,8 +32,23 @@ class SummarizationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
summarizer = SummarizationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
summarizer = SummarizationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return summarizer, ["(CNN)The Palestinian Authority officially became", "Some other text"]
|
||||
|
||||
def run_pipeline_test(self, summarizer, _):
|
||||
|
@ -35,8 +35,23 @@ class Text2TextGenerationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
generator = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
generator = Text2TextGenerationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return generator, ["Something to write", "Something else"]
|
||||
|
||||
def run_pipeline_test(self, generator, _):
|
||||
|
@ -179,8 +179,23 @@ class TextClassificationPipelineTests(unittest.TestCase):
|
||||
outputs = text_classifier("Birds are a type of animal")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
text_classifier = TextClassificationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return text_classifier, ["HuggingFace is in", "This is another test"]
|
||||
|
||||
def run_pipeline_test(self, text_classifier, _):
|
||||
|
@ -377,8 +377,23 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
text_generator = TextGenerationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return text_generator, ["This is a test", "Another test"]
|
||||
|
||||
def test_stop_sequence_stopping_criteria(self):
|
||||
@ -471,6 +486,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
"GPTNeoXForCausalLM",
|
||||
"GPTNeoXJapaneseForCausalLM",
|
||||
"FuyuForCausalLM",
|
||||
"LlamaForCausalLM",
|
||||
]
|
||||
if (
|
||||
tokenizer.model_max_length < 10000
|
||||
|
@ -250,8 +250,23 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)
|
||||
self.assertListEqual(outputs["audio"].tolist(), audio.tolist())
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
speech_generator = TextToAudioPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return speech_generator, ["This is a test", "Another test"]
|
||||
|
||||
def run_pipeline_test(self, speech_generator, _):
|
||||
|
@ -61,8 +61,23 @@ class TokenClassificationPipelineTests(unittest.TestCase):
|
||||
config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
|
||||
}
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
token_classifier = TokenClassificationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return token_classifier, ["A simple string", "A simple string that is quite a bit longer"]
|
||||
|
||||
def run_pipeline_test(self, token_classifier, _):
|
||||
|
@ -35,14 +35,36 @@ class TranslationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
if isinstance(model.config, MBartConfig):
|
||||
src_lang, tgt_lang = list(tokenizer.lang_code_to_id.keys())[:2]
|
||||
translator = TranslationPipeline(
|
||||
model=model, tokenizer=tokenizer, src_lang=src_lang, tgt_lang=tgt_lang, torch_dtype=torch_dtype
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
)
|
||||
else:
|
||||
translator = TranslationPipeline(model=model, tokenizer=tokenizer, torch_dtype=torch_dtype)
|
||||
translator = TranslationPipeline(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
return translator, ["Some string", "Some other text"]
|
||||
|
||||
def run_pipeline_test(self, translator, _):
|
||||
|
@ -38,12 +38,26 @@ from .test_pipelines_common import ANY
|
||||
class VideoClassificationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
example_video_filepath = hf_hub_download(
|
||||
repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset"
|
||||
)
|
||||
video_classifier = VideoClassificationPipeline(
|
||||
model=model, image_processor=processor, top_k=2, torch_dtype=torch_dtype
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
top_k=2,
|
||||
)
|
||||
examples = [
|
||||
example_video_filepath,
|
||||
|
@ -55,9 +55,19 @@ else:
|
||||
class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
vqa_pipeline = pipeline(
|
||||
"visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa", torch_dtype=torch_dtype
|
||||
"visual-question-answering",
|
||||
model="hf-internal-testing/tiny-vilt-random-vqa",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
examples = [
|
||||
{
|
||||
|
@ -53,9 +53,23 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase):
|
||||
config: model for config, model in tf_model_mapping.items() if config.__name__ not in _TO_SKIP
|
||||
}
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
classifier = ZeroShotClassificationPipeline(
|
||||
model=model, tokenizer=tokenizer, candidate_labels=["polics", "health"], torch_dtype=torch_dtype
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
processor=processor,
|
||||
torch_dtype=torch_dtype,
|
||||
candidate_labels=["polics", "health"],
|
||||
)
|
||||
return classifier, ["Who are you voting for in 2020?", "My stomach hurts."]
|
||||
|
||||
|
@ -43,7 +43,15 @@ else:
|
||||
class ZeroShotObjectDetectionPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"):
|
||||
def get_test_pipeline(
|
||||
self,
|
||||
model,
|
||||
tokenizer=None,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
object_detector = pipeline(
|
||||
"zero-shot-object-detection",
|
||||
model="hf-internal-testing/tiny-random-owlvit-object-detection",
|
||||
|
@ -36,6 +36,7 @@ from huggingface_hub import (
|
||||
ZeroShotImageClassificationInput,
|
||||
)
|
||||
|
||||
from transformers.models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
|
||||
from transformers.pipelines import (
|
||||
AudioClassificationPipeline,
|
||||
AutomaticSpeechRecognitionPipeline,
|
||||
@ -183,11 +184,14 @@ class PipelineTesterMixin:
|
||||
model_architectures = self.pipeline_model_mapping[task]
|
||||
if not isinstance(model_architectures, tuple):
|
||||
model_architectures = (model_architectures,)
|
||||
if not isinstance(model_architectures, tuple):
|
||||
raise TypeError(f"`model_architectures` must be a tuple. Got {type(model_architectures)} instead.")
|
||||
|
||||
# We are going to run tests for multiple model architectures, some of them might be skipped
|
||||
# with this flag we are control if at least one model were tested or all were skipped
|
||||
at_least_one_model_is_tested = False
|
||||
|
||||
for model_architecture in model_architectures:
|
||||
model_arch_name = model_architecture.__name__
|
||||
model_type = model_architecture.config_class.model_type
|
||||
|
||||
# Get the canonical name
|
||||
for _prefix in ["Flax", "TF"]:
|
||||
@ -195,31 +199,72 @@ class PipelineTesterMixin:
|
||||
model_arch_name = model_arch_name[len(_prefix) :]
|
||||
break
|
||||
|
||||
tokenizer_names = []
|
||||
processor_names = []
|
||||
if model_arch_name not in tiny_model_summary:
|
||||
continue
|
||||
|
||||
tokenizer_names = tiny_model_summary[model_arch_name]["tokenizer_classes"]
|
||||
|
||||
# Sort image processors and feature extractors from tiny-models json file
|
||||
image_processor_names = []
|
||||
feature_extractor_names = []
|
||||
|
||||
processor_classes = tiny_model_summary[model_arch_name]["processor_classes"]
|
||||
for cls_name in processor_classes:
|
||||
if "ImageProcessor" in cls_name:
|
||||
image_processor_names.append(cls_name)
|
||||
elif "FeatureExtractor" in cls_name:
|
||||
feature_extractor_names.append(cls_name)
|
||||
else:
|
||||
raise ValueError(f"Unknown processor class: {cls_name}")
|
||||
|
||||
# Processor classes are not in tiny models JSON file, so extract them from the mapping
|
||||
# processors are mapped to instance, e.g. "XxxProcessor"
|
||||
processor_names = PROCESSOR_MAPPING_NAMES.get(model_type, None)
|
||||
if not isinstance(processor_names, (list, tuple)):
|
||||
processor_names = [processor_names]
|
||||
|
||||
commit = None
|
||||
if model_arch_name in tiny_model_summary:
|
||||
tokenizer_names = tiny_model_summary[model_arch_name]["tokenizer_classes"]
|
||||
processor_names = tiny_model_summary[model_arch_name]["processor_classes"]
|
||||
if "sha" in tiny_model_summary[model_arch_name]:
|
||||
commit = tiny_model_summary[model_arch_name]["sha"]
|
||||
# Adding `None` (if empty) so we can generate tests
|
||||
tokenizer_names = [None] if len(tokenizer_names) == 0 else tokenizer_names
|
||||
processor_names = [None] if len(processor_names) == 0 else processor_names
|
||||
if model_arch_name in tiny_model_summary and "sha" in tiny_model_summary[model_arch_name]:
|
||||
commit = tiny_model_summary[model_arch_name]["sha"]
|
||||
|
||||
repo_name = f"tiny-random-{model_arch_name}"
|
||||
if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
|
||||
repo_name = model_arch_name
|
||||
|
||||
self.run_model_pipeline_tests(
|
||||
task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype
|
||||
task,
|
||||
repo_name,
|
||||
model_architecture,
|
||||
tokenizer_names=tokenizer_names,
|
||||
image_processor_names=image_processor_names,
|
||||
feature_extractor_names=feature_extractor_names,
|
||||
processor_names=processor_names,
|
||||
commit=commit,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
if task in task_to_pipeline_and_spec_mapping:
|
||||
pipeline, hub_spec = task_to_pipeline_and_spec_mapping[task]
|
||||
compare_pipeline_args_to_hub_spec(pipeline, hub_spec)
|
||||
|
||||
at_least_one_model_is_tested = True
|
||||
|
||||
if not at_least_one_model_is_tested:
|
||||
self.skipTest(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find any "
|
||||
f"model architecture in the tiny models JSON file for `{task}`."
|
||||
)
|
||||
|
||||
def run_model_pipeline_tests(
|
||||
self, task, repo_name, model_architecture, tokenizer_names, processor_names, commit, torch_dtype="float32"
|
||||
self,
|
||||
task,
|
||||
repo_name,
|
||||
model_architecture,
|
||||
tokenizer_names,
|
||||
image_processor_names,
|
||||
feature_extractor_names,
|
||||
processor_names,
|
||||
commit,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
"""Run pipeline tests for a specific `task` with the give model class and tokenizer/processor class names
|
||||
|
||||
@ -232,8 +277,12 @@ class PipelineTesterMixin:
|
||||
A subclass of `PretrainedModel` or `PretrainedModel`.
|
||||
tokenizer_names (`List[str]`):
|
||||
A list of names of a subclasses of `PreTrainedTokenizerFast` or `PreTrainedTokenizer`.
|
||||
image_processor_names (`List[str]`):
|
||||
A list of names of subclasses of `BaseImageProcessor`.
|
||||
feature_extractor_names (`List[str]`):
|
||||
A list of names of subclasses of `FeatureExtractionMixin`.
|
||||
processor_names (`List[str]`):
|
||||
A list of names of subclasses of `BaseImageProcessor` or `FeatureExtractionMixin`.
|
||||
A list of names of subclasses of `ProcessorMixin`.
|
||||
commit (`str`):
|
||||
The commit hash of the model repository on the Hub.
|
||||
torch_dtype (`str`, `optional`, defaults to `'float32'`):
|
||||
@ -243,27 +292,73 @@ class PipelineTesterMixin:
|
||||
# `run_pipeline_test`.
|
||||
pipeline_test_class_name = pipeline_test_mapping[task]["test"].__name__
|
||||
|
||||
for tokenizer_name in tokenizer_names:
|
||||
for processor_name in processor_names:
|
||||
if self.is_pipeline_test_to_skip(
|
||||
pipeline_test_class_name,
|
||||
model_architecture.config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
processor_name,
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: test is "
|
||||
f"currently known to fail for: model `{model_architecture.__name__}` | tokenizer "
|
||||
f"`{tokenizer_name}` | processor `{processor_name}`."
|
||||
)
|
||||
continue
|
||||
self.run_pipeline_test(
|
||||
task, repo_name, model_architecture, tokenizer_name, processor_name, commit, torch_dtype
|
||||
# If no image processor or feature extractor is found, we still need to test the pipeline with None
|
||||
# otherwise for any empty list we might skip all the tests
|
||||
tokenizer_names = tokenizer_names or [None]
|
||||
image_processor_names = image_processor_names or [None]
|
||||
feature_extractor_names = feature_extractor_names or [None]
|
||||
processor_names = processor_names or [None]
|
||||
|
||||
test_cases = [
|
||||
{
|
||||
"tokenizer_name": tokenizer_name,
|
||||
"image_processor_name": image_processor_name,
|
||||
"feature_extractor_name": feature_extractor_name,
|
||||
"processor_name": processor_name,
|
||||
}
|
||||
for tokenizer_name in tokenizer_names
|
||||
for image_processor_name in image_processor_names
|
||||
for feature_extractor_name in feature_extractor_names
|
||||
for processor_name in processor_names
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
tokenizer_name = test_case["tokenizer_name"]
|
||||
image_processor_name = test_case["image_processor_name"]
|
||||
feature_extractor_name = test_case["feature_extractor_name"]
|
||||
processor_name = test_case["processor_name"]
|
||||
|
||||
do_skip_test_case = self.is_pipeline_test_to_skip(
|
||||
pipeline_test_class_name,
|
||||
model_architecture.config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
)
|
||||
|
||||
if do_skip_test_case:
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: test is "
|
||||
f"currently known to fail for: model `{model_architecture.__name__}` | tokenizer "
|
||||
f"`{tokenizer_name}` | image processor `{image_processor_name}` | feature extractor {feature_extractor_name}."
|
||||
)
|
||||
continue
|
||||
|
||||
self.run_pipeline_test(
|
||||
task,
|
||||
repo_name,
|
||||
model_architecture,
|
||||
tokenizer_name=tokenizer_name,
|
||||
image_processor_name=image_processor_name,
|
||||
feature_extractor_name=feature_extractor_name,
|
||||
processor_name=processor_name,
|
||||
commit=commit,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
def run_pipeline_test(
|
||||
self, task, repo_name, model_architecture, tokenizer_name, processor_name, commit, torch_dtype="float32"
|
||||
self,
|
||||
task,
|
||||
repo_name,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
commit,
|
||||
torch_dtype="float32",
|
||||
):
|
||||
"""Run pipeline tests for a specific `task` with the give model class and tokenizer/processor class name
|
||||
|
||||
@ -278,43 +373,24 @@ class PipelineTesterMixin:
|
||||
A subclass of `PretrainedModel` or `PretrainedModel`.
|
||||
tokenizer_name (`str`):
|
||||
The name of a subclass of `PreTrainedTokenizerFast` or `PreTrainedTokenizer`.
|
||||
image_processor_name (`str`):
|
||||
The name of a subclass of `BaseImageProcessor`.
|
||||
feature_extractor_name (`str`):
|
||||
The name of a subclass of `FeatureExtractionMixin`.
|
||||
processor_name (`str`):
|
||||
The name of a subclass of `BaseImageProcessor` or `FeatureExtractionMixin`.
|
||||
The name of a subclass of `ProcessorMixin`.
|
||||
commit (`str`):
|
||||
The commit hash of the model repository on the Hub.
|
||||
torch_dtype (`str`, `optional`, defaults to `'float32'`):
|
||||
The torch dtype to use for the model. Can be used for FP16/other precision inference.
|
||||
"""
|
||||
repo_id = f"{TRANSFORMERS_TINY_MODEL_PATH}/{repo_name}"
|
||||
model_type = model_architecture.config_class.model_type
|
||||
|
||||
if TRANSFORMERS_TINY_MODEL_PATH != "hf-internal-testing":
|
||||
model_type = model_architecture.config_class.model_type
|
||||
repo_id = os.path.join(TRANSFORMERS_TINY_MODEL_PATH, model_type, repo_name)
|
||||
|
||||
tokenizer = None
|
||||
if tokenizer_name is not None:
|
||||
tokenizer_class = getattr(transformers_module, tokenizer_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(repo_id, revision=commit)
|
||||
|
||||
processor = None
|
||||
if processor_name is not None:
|
||||
processor_class = getattr(transformers_module, processor_name)
|
||||
# If the required packages (like `Pillow` or `torchaudio`) are not installed, this will fail.
|
||||
try:
|
||||
processor = processor_class.from_pretrained(repo_id, revision=commit)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not load the "
|
||||
f"processor from `{repo_id}` with `{processor_name}`."
|
||||
)
|
||||
self.skipTest(f"Could not load the processor from {repo_id} with {processor_name}.")
|
||||
|
||||
# TODO: Maybe not upload such problematic tiny models to Hub.
|
||||
if tokenizer is None and processor is None:
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find or load "
|
||||
f"any tokenizer / processor from `{repo_id}`."
|
||||
)
|
||||
self.skipTest(f"Could not find or load any tokenizer / processor from {repo_id}.")
|
||||
# -------------------- Load model --------------------
|
||||
|
||||
# TODO: We should check if a model file is on the Hub repo. instead.
|
||||
try:
|
||||
@ -326,19 +402,57 @@ class PipelineTesterMixin:
|
||||
)
|
||||
self.skipTest(f"Could not find or load the model from {repo_id} with {model_architecture}.")
|
||||
|
||||
# -------------------- Load tokenizer --------------------
|
||||
|
||||
tokenizer = None
|
||||
if tokenizer_name is not None:
|
||||
tokenizer_class = getattr(transformers_module, tokenizer_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(repo_id, revision=commit)
|
||||
|
||||
# -------------------- Load processors --------------------
|
||||
|
||||
processors = {}
|
||||
for key, name in zip(
|
||||
["image_processor", "feature_extractor", "processor"],
|
||||
[image_processor_name, feature_extractor_name, processor_name],
|
||||
):
|
||||
if name is not None:
|
||||
try:
|
||||
# Can fail if some extra dependencies are not installed
|
||||
processor_class = getattr(transformers_module, name)
|
||||
processor = processor_class.from_pretrained(repo_id, revision=commit)
|
||||
processors[key] = processor
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: "
|
||||
f"Could not load the {key} from `{repo_id}` with `{name}`."
|
||||
)
|
||||
self.skipTest(f"Could not load the {key} from {repo_id} with {name}.")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
|
||||
# TODO: Maybe not upload such problematic tiny models to Hub.
|
||||
if tokenizer is None and "image_processor" not in processors and "feature_extractor" not in processors:
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: Could not find or load "
|
||||
f"any tokenizer / image processor / feature extractor from `{repo_id}`."
|
||||
)
|
||||
self.skipTest(f"Could not find or load any tokenizer / processor from {repo_id}.")
|
||||
|
||||
pipeline_test_class_name = pipeline_test_mapping[task]["test"].__name__
|
||||
if self.is_pipeline_test_to_skip_more(pipeline_test_class_name, model.config, model, tokenizer, processor):
|
||||
if self.is_pipeline_test_to_skip_more(pipeline_test_class_name, model.config, model, tokenizer, **processors):
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}::test_pipeline_{task.replace('-', '_')}_{torch_dtype} is skipped: test is "
|
||||
f"currently known to fail for: model `{model_architecture.__name__}` | tokenizer "
|
||||
f"`{tokenizer_name}` | processor `{processor_name}`."
|
||||
f"`{tokenizer_name}` | image processor `{image_processor_name}` | feature extractor `{feature_extractor_name}`."
|
||||
)
|
||||
self.skipTest(
|
||||
f"Test is known to fail for: model `{model_architecture.__name__}` | tokenizer `{tokenizer_name}` | processor `{processor_name}`."
|
||||
f"Test is known to fail for: model `{model_architecture.__name__}` | tokenizer `{tokenizer_name}` "
|
||||
f"| image processor `{image_processor_name}` | feature extractor `{feature_extractor_name}`."
|
||||
)
|
||||
|
||||
# validate
|
||||
validate_test_components(self, task, model, tokenizer, processor)
|
||||
validate_test_components(model, tokenizer)
|
||||
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
@ -347,7 +461,7 @@ class PipelineTesterMixin:
|
||||
# `run_pipeline_test`.
|
||||
task_test = pipeline_test_mapping[task]["test"]()
|
||||
|
||||
pipeline, examples = task_test.get_test_pipeline(model, tokenizer, processor, torch_dtype=torch_dtype)
|
||||
pipeline, examples = task_test.get_test_pipeline(model, tokenizer, **processors, torch_dtype=torch_dtype)
|
||||
if pipeline is None:
|
||||
# The test can disable itself, but it should be very marginal
|
||||
# Concerns: Wav2Vec2ForCTC without tokenizer test (FastTokenizer don't exist)
|
||||
@ -674,7 +788,14 @@ class PipelineTesterMixin:
|
||||
|
||||
# This contains the test cases to be skipped without model architecture being involved.
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config_class,
|
||||
model_architecture,
|
||||
tokenizer_name,
|
||||
image_processor_name,
|
||||
feature_extractor_name,
|
||||
processor_name,
|
||||
):
|
||||
"""Skip some tests based on the classes or their names without the instantiated objects.
|
||||
|
||||
@ -682,7 +803,7 @@ class PipelineTesterMixin:
|
||||
"""
|
||||
# No fix is required for this case.
|
||||
if (
|
||||
pipeline_test_casse_name == "DocumentQuestionAnsweringPipelineTests"
|
||||
pipeline_test_case_name == "DocumentQuestionAnsweringPipelineTests"
|
||||
and tokenizer_name is not None
|
||||
and not tokenizer_name.endswith("Fast")
|
||||
):
|
||||
@ -691,11 +812,20 @@ class PipelineTesterMixin:
|
||||
|
||||
return False
|
||||
|
||||
def is_pipeline_test_to_skip_more(self, pipeline_test_casse_name, config, model, tokenizer, processor): # noqa
|
||||
def is_pipeline_test_to_skip_more(
|
||||
self,
|
||||
pipeline_test_case_name,
|
||||
config,
|
||||
model,
|
||||
tokenizer,
|
||||
image_processor=None,
|
||||
feature_extractor=None,
|
||||
processor=None,
|
||||
): # noqa
|
||||
"""Skip some more tests based on the information from the instantiated objects."""
|
||||
# No fix is required for this case.
|
||||
if (
|
||||
pipeline_test_casse_name == "QAPipelineTests"
|
||||
pipeline_test_case_name == "QAPipelineTests"
|
||||
and tokenizer is not None
|
||||
and getattr(tokenizer, "pad_token", None) is None
|
||||
and not tokenizer.__class__.__name__.endswith("Fast")
|
||||
@ -706,7 +836,7 @@ class PipelineTesterMixin:
|
||||
return False
|
||||
|
||||
|
||||
def validate_test_components(test_case, task, model, tokenizer, processor):
|
||||
def validate_test_components(model, tokenizer):
|
||||
# TODO: Move this to tiny model creation script
|
||||
# head-specific (within a model type) necessary changes to the config
|
||||
# 1. for `BlenderbotForCausalLM`
|
||||
|
Loading…
Reference in New Issue
Block a user