Adding support for the truncation parameter on the feature-extraction pipeline. (#14193)

* Adding support for the `truncation` parameter on the `feature-extraction` pipeline.

Fixes #14183

* Fixing tests on ibert, longformer, and roberta.

* Rebase fix.
Nicolas Patry, 2021-11-03 15:48:00 +01:00, committed by GitHub
commit dec759e7e8 (parent 27b1516d32)
3 changed files with 25 additions and 6 deletions


@@ -41,12 +41,19 @@ class FeatureExtractionPipeline(Pipeline):
             the associated CUDA device id.
     """
 
-    def _sanitize_parameters(self, **kwargs):
-        return {}, {}, {}
+    def _sanitize_parameters(self, truncation=None, **kwargs):
+        preprocess_params = {}
+        if truncation is not None:
+            preprocess_params["truncation"] = truncation
+        return preprocess_params, {}, {}
 
-    def preprocess(self, inputs) -> Dict[str, GenericTensor]:
+    def preprocess(self, inputs, truncation=None) -> Dict[str, GenericTensor]:
         return_tensors = self.framework
-        model_inputs = self.tokenizer(inputs, return_tensors=return_tensors)
+        if truncation is None:
+            kwargs = {}
+        else:
+            kwargs = {"truncation": truncation}
+        model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **kwargs)
         return model_inputs
 
     def _forward(self, model_inputs):
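For illustration, a minimal usage sketch of what the change enables; the `distilbert-base-cased` checkpoint and the `framework="pt"` choice are arbitrary assumptions for the example, not part of this PR:

```python
from transformers import pipeline

# Sketch only: any checkpoint usable with feature-extraction works here.
extractor = pipeline("feature-extraction", model="distilbert-base-cased", framework="pt")

# An input far longer than the model's 512-token window.
long_text = "This is a test " * 1000

# `truncation=True` is now accepted at call time; _sanitize_parameters routes it
# to preprocess(), which forwards it to the tokenizer so the input is capped.
features = extractor(long_text, truncation=True)
print(len(features))     # 1 (batch dimension)
print(len(features[0]))  # tokens kept, at most tokenizer.model_max_length
```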


@@ -22,7 +22,15 @@ from abc import abstractmethod
 from functools import lru_cache
 from unittest import skipIf
 
-from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer, pipeline
+from transformers import (
+    FEATURE_EXTRACTOR_MAPPING,
+    TOKENIZER_MAPPING,
+    AutoFeatureExtractor,
+    AutoTokenizer,
+    IBertConfig,
+    RobertaConfig,
+    pipeline,
+)
 from transformers.pipelines.base import _pad
 from transformers.testing_utils import is_pipeline_test, require_torch
@@ -143,7 +151,7 @@ class PipelineTestCaseMeta(type):
                 try:
                     tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
                     # XLNet actually defines it as -1.
-                    if model.config.__class__.__name__ == "RobertaConfig":
+                    if isinstance(model.config, (RobertaConfig, IBertConfig)):
                         tokenizer.model_max_length = model.config.max_position_embeddings - 2
                     elif (
                         hasattr(model.config, "max_position_embeddings")
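The `- 2` offset exists because RoBERTa-style models (I-BERT included) shift position ids by the padding index, leaving two position-embedding slots unusable. A quick sanity check of that relationship, using `roberta-base` purely as an example checkpoint:

```python
from transformers import AutoConfig, AutoTokenizer

# roberta-base reports 514 position embeddings, but only 512 token positions are
# usable because position ids start at padding_idx + 1 (an offset of 2).
config = AutoConfig.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

assert config.max_position_embeddings == 514
assert tokenizer.model_max_length == 512
assert config.max_position_embeddings - 2 == tokenizer.model_max_length
```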


@@ -105,3 +105,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         outputs = feature_extractor(["This is a test", "Another longer test"])
         shape = self.get_shape(outputs)
         self.assertEqual(shape[0], 2)
+
+        outputs = feature_extractor("This is a test" * 100, truncation=True)
+        shape = self.get_shape(outputs)
+        self.assertEqual(shape[0], 1)
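The new assertion only checks the batch dimension; what it implicitly relies on is the tokenizer-level behaviour that `preprocess` now forwards. A sketch of that behaviour, with the checkpoint name chosen arbitrarily:

```python
from transformers import AutoTokenizer

# With truncation=True the tokenizer caps the encoded sequence at model_max_length
# instead of producing an input longer than the model can handle.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
long_text = "This is a test" * 100

full = tokenizer(long_text)["input_ids"]
capped = tokenizer(long_text, truncation=True)["input_ids"]

assert len(capped) <= tokenizer.model_max_length  # 512 for this checkpoint
assert len(capped) <= len(full)
```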