Enable TruncationStrategy override for pipelines (#9432)

* Enable TruncationStrategy override for pipelines

* Update isort.

* Fixing test

* Fixing text_generation pipeline.

* Using same DummyTok as other PR for easier merge later.

* Some more import guards.

* Remove bogus file.

* Do not pass `generate_kwargs` to `_parse_and_tokenize`.
@patrickvonplaten

* Removed DummyTok.

* Doc quality.
Nicolas Patry 2021-01-11 15:23:28 +01:00 committed by GitHub
parent 8d25df2c7a
commit d20e9c7299
6 changed files with 94 additions and 25 deletions

src/transformers/pipelines/base.py

@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
-from ..tokenization_utils import PreTrainedTokenizer
+from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
 from ..utils import logging
@@ -577,7 +577,9 @@ class Pipeline(_ScikitCompat):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
             )

-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize
         """
@@ -587,6 +589,7 @@ class Pipeline(_ScikitCompat):
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
+            truncation=truncation,
         )
         return inputs
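Combined with the per-pipeline changes below, this lets callers override the truncation strategy per call. A minimal usage sketch, assuming a summarization checkpoint (the model name is illustrative, not part of this commit):

    from transformers import pipeline
    from transformers.tokenization_utils import TruncationStrategy

    # Illustrative checkpoint; any summarization model exercises the same path.
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

    long_text = "The quick brown fox jumps over the lazy dog. " * 500

    # The default stays DO_NOT_TRUNCATE, so existing callers are unaffected;
    # passing a strategy clips the input to the tokenizer's model_max_length
    # instead of letting an over-long sequence fail inside the model.
    summary = summarizer(long_text, truncation=TruncationStrategy.ONLY_FIRST)
    print(summary[0]["summary_text"])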

src/transformers/pipelines/conversational.py

@@ -2,6 +2,7 @@ import uuid
 from typing import List, Optional, Union

 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -317,12 +318,14 @@ class ConversationalPipeline(Pipeline):
         else:
             return output

-    def _parse_and_tokenize(self, inputs, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize, adding an EOS token at the end of the user input
         """
         # Parse arguments
-        inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
+        inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
         for input in inputs:
             input.append(self.tokenizer.eos_token_id)
         return inputs

src/transformers/pipelines/text2text_generation.py

@@ -1,4 +1,5 @@
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -50,7 +51,13 @@ class Text2TextGenerationPipeline(Pipeline):
         return True

     def __call__(
-        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+        **generate_kwargs
     ):
         r"""
         Generate the output text(s) using text(s) given as inputs.
@@ -64,6 +71,10 @@ class Text2TextGenerationPipeline(Pipeline):
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
+            truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
+                The truncation strategy for the tokenization within the pipeline.
+                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
+                to truncate the input to fit the model's max_length instead of throwing an error down the line.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -96,7 +107,7 @@ class Text2TextGenerationPipeline(Pipeline):
         )

         with self.device_placement():
-            inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)
+            inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)

             if self.framework == "pt":
                 inputs = self.ensure_tensor_on_device(**inputs)
@@ -108,9 +119,6 @@ class Text2TextGenerationPipeline(Pipeline):
             max_length = generate_kwargs.get("max_length", self.model.config.max_length)
             self.check_inputs(input_length, min_length, max_length)

-            # truncation should be used by _parse_and_tokenize
-            generate_kwargs.pop("truncation", None)
             generations = self.model.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
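With the docstring addition above, opting into truncation on this pipeline is one keyword away. A hedged sketch (the checkpoint is illustrative; T5 expects a task prefix):

    from transformers import pipeline
    from transformers.tokenization_utils import TruncationStrategy

    # Illustrative checkpoint, not part of this commit.
    text2text = pipeline("text2text-generation", model="t5-small")

    # Inputs longer than the model's maximum length are now clipped at
    # tokenization time rather than raising deep inside generate().
    out = text2text(
        "summarize: " + "The quick brown fox jumps over the lazy dog. " * 300,
        truncation=TruncationStrategy.ONLY_FIRST,
    )
    print(out[0]["generated_text"])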

src/transformers/pipelines/text_generation.py

@@ -50,25 +50,15 @@ class TextGenerationPipeline(Pipeline):
         self.check_model_type(self.ALLOWED_MODELS)

     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, *args, **kwargs):
         """
         Parse arguments and tokenize
         """
         # Parse arguments
         if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
-            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
-        else:
-            tokenizer_kwargs = {}
-
-        inputs = self.tokenizer(
-            inputs,
-            add_special_tokens=add_special_tokens,
-            return_tensors=self.framework,
-            padding=padding,
-            **tokenizer_kwargs,
-        )
+            kwargs.update({"add_space_before_punct_symbol": True})

-        return inputs
+        return super()._parse_and_tokenize(*args, **kwargs)

     def __call__(
         self,
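The net effect of the delegation: the Transfo-XL spacing quirk is folded into `kwargs`, and everything else, including the new `truncation` keyword, is handled once in the base class. A quick sketch of the unchanged user-facing behaviour (checkpoint illustrative):

    from transformers import pipeline

    # Illustrative checkpoint; any causal LM goes through the same path.
    generator = pipeline("text-generation", model="gpt2")

    # Tokenization keywords now flow through the shared base implementation
    # instead of a duplicated tokenizer call in this subclass.
    print(generator("Hello, I'm a language model,", max_length=20)[0]["generated_text"])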

src/transformers/pipelines/zero_shot_classification.py

@@ -3,6 +3,7 @@ from typing import List, Union
 import numpy as np

 from ..file_utils import add_end_docstrings
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
@@ -78,7 +79,14 @@ class ZeroShotClassificationPipeline(Pipeline):
         return -1

     def _parse_and_tokenize(
-        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
+        self,
+        sequences,
+        candidate_labels,
+        hypothesis_template,
+        padding=True,
+        add_special_tokens=True,
+        truncation=TruncationStrategy.ONLY_FIRST,
+        **kwargs
     ):
         """
         Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
@@ -89,7 +97,7 @@ class ZeroShotClassificationPipeline(Pipeline):
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
-            truncation="only_first",
+            truncation=truncation,
         )
         return inputs
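`ONLY_FIRST` remains this pipeline's default because each input is a (premise, hypothesis) pair and only the premise should ever be clipped; truncating the hypothesis would destroy the label being scored. A hedged sketch (checkpoint illustrative):

    from transformers import pipeline

    # Illustrative NLI checkpoint commonly used for zero-shot classification.
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

    # An over-long premise is truncated, while the hypothesis built from each
    # candidate label survives intact, keeping the entailment scores meaningful.
    result = classifier(
        "Who are you voting for in 2020? " * 200,
        candidate_labels=["politics", "sport", "cooking"],
    )
    print(result["labels"][0])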

tests/test_pipelines_summarization.py

@@ -14,15 +14,72 @@
 import unittest

-from transformers import pipeline
+from transformers import AutoTokenizer, is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.tokenization_utils import TruncationStrategy

 from .test_pipelines_common import MonoInputPipelineCommonMixin

+if is_torch_available():
+    import torch
+
+    from transformers.models.bart import BartConfig, BartForConditionalGeneration
+
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0

+
+class SimpleSummarizationPipelineTests(unittest.TestCase):
+    @require_torch
+    def test_input_too_long(self):
+        torch.manual_seed(0)
+        config = BartConfig(
+            vocab_size=257,
+            d_model=32,
+            encoder_layers=1,
+            decoder_layers=1,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            # So any text > 4 should raise an exception
+            max_position_embeddings=4,
+            encoder_attention_heads=1,
+            decoder_attention_heads=1,
+            max_length=4,
+            min_length=1,
+        )
+        model = BartForConditionalGeneration(config)
+        # Bias output towards L
+        V, C = model.lm_head.weight.shape
+
+        bias = torch.zeros(V, requires_grad=True)
+        bias[76] = 10
+
+        model.lm_head.bias = torch.nn.Parameter(bias)
+
+        # # Generated with:
+        # import tempfile
+        # from tokenizers import Tokenizer, models
+        # from transformers import PreTrainedTokenizerFast
+        # model_max_length = 4
+        # vocab = [(chr(i), i) for i in range(256)]
+        # tokenizer = Tokenizer(models.Unigram(vocab))
+        # with tempfile.NamedTemporaryFile() as f:
+        #     tokenizer.save(f.name)
+        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length)
+        # real_tokenizer._tokenizer.save("tokenizer.json")
+        # # + add missing config.json with albert as model_type
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test")
+        nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
+
+        with self.assertLogs("transformers", level="WARNING"):
+            with self.assertRaises(IndexError):
+                _ = nlp("This is a test")
+
+        output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST)
+        # 2 is default BOS from Bart.
+        self.assertEqual(output, [{"summary_text": "\x02 L L L"}])
+
+
 class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "summarization"
     pipeline_running_kwargs = {"num_beams": 2, "min_length": 2, "max_length": 5}
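A note on the test's expected output, grounded in the diff above: with `max_position_embeddings=4`, any input longer than 4 tokens raises an `IndexError` under the default `DO_NOT_TRUNCATE`, while passing `ONLY_FIRST` clips it and generation succeeds. The assertion `[{"summary_text": "\x02 L L L"}]` then decodes as token 2 (Bart's default BOS, rendered `\x02` by the byte-valued toy vocabulary) followed by three "L" tokens forced by the `lm_head` bias, since `chr(76) == "L"`.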