Adding support for truncation parameter on feature-extraction pipeline. (#14193)
* Adding support for `truncation` parameter on `feature-extraction` pipeline. Fixes #14183
* Fixing tests on ibert, longformer, and roberta.
* Rebase fix.
parent 27b1516d32
commit dec759e7e8
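For context, a minimal usage sketch of the new parameter (the checkpoint name and the input below are placeholders, not part of this commit):

from transformers import pipeline

# Any encoder checkpoint works here; distilbert-base-cased is just an example.
feature_extractor = pipeline("feature-extraction", model="distilbert-base-cased")

# There was previously no way to truncate long inputs through this pipeline;
# truncation=True is now forwarded to the tokenizer during preprocessing.
features = feature_extractor("This is a test " * 200, truncation=True)
print(len(features[0]))  # number of (truncated) tokens in the single input sequence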
@@ -41,12 +41,19 @@ class FeatureExtractionPipeline(Pipeline):
         the associated CUDA device id.
     """

-    def _sanitize_parameters(self, **kwargs):
-        return {}, {}, {}
+    def _sanitize_parameters(self, truncation=None, **kwargs):
+        preprocess_params = {}
+        if truncation is not None:
+            preprocess_params["truncation"] = truncation
+        return preprocess_params, {}, {}

-    def preprocess(self, inputs) -> Dict[str, GenericTensor]:
+    def preprocess(self, inputs, truncation=None) -> Dict[str, GenericTensor]:
         return_tensors = self.framework
-        model_inputs = self.tokenizer(inputs, return_tensors=return_tensors)
+        if truncation is None:
+            kwargs = {}
+        else:
+            kwargs = {"truncation": truncation}
+        model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **kwargs)
         return model_inputs

     def _forward(self, model_inputs):
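The three dicts returned by `_sanitize_parameters` are routed to `preprocess`, `_forward`, and `postprocess` respectively. A simplified sketch of that flow (not the actual base-class code):

def run_single(pipe, inputs, **kwargs):
    # Split call-time kwargs into per-stage parameters, then run the three stages.
    preprocess_params, forward_params, postprocess_params = pipe._sanitize_parameters(**kwargs)
    model_inputs = pipe.preprocess(inputs, **preprocess_params)
    model_outputs = pipe._forward(model_inputs, **forward_params)
    return pipe.postprocess(model_outputs, **postprocess_params)

# With this change, run_single(fe, "very long text ...", truncation=True) ends up calling
# fe.preprocess(..., truncation=True), which passes truncation through to the tokenizer.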
@@ -22,7 +22,15 @@ from abc import abstractmethod
 from functools import lru_cache
 from unittest import skipIf

-from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer, pipeline
+from transformers import (
+    FEATURE_EXTRACTOR_MAPPING,
+    TOKENIZER_MAPPING,
+    AutoFeatureExtractor,
+    AutoTokenizer,
+    IBertConfig,
+    RobertaConfig,
+    pipeline,
+)
 from transformers.pipelines.base import _pad
 from transformers.testing_utils import is_pipeline_test, require_torch

@@ -143,7 +151,7 @@ class PipelineTestCaseMeta(type):
             try:
                 tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
                 # XLNet actually defines it as -1.
-                if model.config.__class__.__name__ == "RobertaConfig":
+                if isinstance(model.config, (RobertaConfig, IBertConfig)):
                     tokenizer.model_max_length = model.config.max_position_embeddings - 2
                 elif (
                     hasattr(model.config, "max_position_embeddings")
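The `- 2` reflects the position-offset convention that RoBERTa uses and I-BERT inherits: position ids start past the padding index, so only `max_position_embeddings - 2` positions are usable. A small illustration (514 is the usual roberta-base value; the exact numbers are only illustrative):

from transformers import IBertConfig, RobertaConfig

for config in (RobertaConfig(max_position_embeddings=514), IBertConfig(max_position_embeddings=514)):
    # Both configs hit the isinstance branch above, capping the tokenizer at 512 usable positions.
    assert isinstance(config, (RobertaConfig, IBertConfig))
    print(config.__class__.__name__, config.max_position_embeddings - 2)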
@@ -105,3 +105,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         outputs = feature_extractor(["This is a test", "Another longer test"])
         shape = self.get_shape(outputs)
         self.assertEqual(shape[0], 2)
+
+        outputs = feature_extractor("This is a test" * 100, truncation=True)
+        shape = self.get_shape(outputs)
+        self.assertEqual(shape[0], 1)
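Outside the test harness, the same behaviour can be checked at the tokenizer level; a sketch, with a placeholder checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
encoded = tokenizer("This is a test" * 100, truncation=True)
# With truncation enabled, the encoding is capped at the tokenizer's maximum length.
assert len(encoded["input_ids"]) <= tokenizer.model_max_length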