mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Enable TruncationStrategy override for pipelines (#9432)
* Enable TruncationStrategy override for pipelines * Update isort. * Fixing test * Fixing text_generation pipeline. * Using same DummyTok as other PR for easier merge later. * Some more import guards. * Remove bogus file. * Do not pass `generate_kwargs` to `_parse_and_tokenize`. @patrickvonplaten * Removed DummyTok. * Doc quality.
This commit is contained in:
parent
8d25df2c7a
commit
d20e9c7299
@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||
from ..modelcard import ModelCard
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
@ -577,7 +577,9 @@ class Pipeline(_ScikitCompat):
|
||||
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
|
||||
)
|
||||
|
||||
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
|
||||
def _parse_and_tokenize(
|
||||
self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
|
||||
):
|
||||
"""
|
||||
Parse arguments and tokenize
|
||||
"""
|
||||
@ -587,6 +589,7 @@ class Pipeline(_ScikitCompat):
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=self.framework,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
@ -2,6 +2,7 @@ import uuid
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||
from ..tokenization_utils import TruncationStrategy
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
@ -317,12 +318,14 @@ class ConversationalPipeline(Pipeline):
|
||||
else:
|
||||
return output
|
||||
|
||||
def _parse_and_tokenize(self, inputs, **kwargs):
|
||||
def _parse_and_tokenize(
|
||||
self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
|
||||
):
|
||||
"""
|
||||
Parse arguments and tokenize, adding an EOS token at the end of the user input
|
||||
"""
|
||||
# Parse arguments
|
||||
inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
|
||||
inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
|
||||
for input in inputs:
|
||||
input.append(self.tokenizer.eos_token_id)
|
||||
return inputs
|
||||
|
@ -1,4 +1,5 @@
|
||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||
from ..tokenization_utils import TruncationStrategy
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
@ -50,7 +51,13 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
return True
|
||||
|
||||
def __call__(
|
||||
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||
self,
|
||||
*args,
|
||||
return_tensors=False,
|
||||
return_text=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
**generate_kwargs
|
||||
):
|
||||
r"""
|
||||
Generate the output text(s) using text(s) given as inputs.
|
||||
@ -64,6 +71,10 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
Whether or not to include the decoded texts in the outputs.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
|
||||
The truncation strategy for the tokenization within the pipeline.
|
||||
:obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
|
||||
to truncate the input to fit the model's max_length instead of throwing an error down the line.
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||
@ -96,7 +107,7 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
)
|
||||
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)
|
||||
inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)
|
||||
|
||||
if self.framework == "pt":
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
@ -108,9 +119,6 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
self.check_inputs(input_length, min_length, max_length)
|
||||
|
||||
# truncation should be used by _parse_and_tokenize
|
||||
generate_kwargs.pop("truncation", None)
|
||||
|
||||
generations = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
|
@ -50,25 +50,15 @@ class TextGenerationPipeline(Pipeline):
|
||||
self.check_model_type(self.ALLOWED_MODELS)
|
||||
|
||||
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
|
||||
|
||||
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
|
||||
def _parse_and_tokenize(self, *args, **kwargs):
|
||||
"""
|
||||
Parse arguments and tokenize
|
||||
"""
|
||||
# Parse arguments
|
||||
if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
|
||||
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
|
||||
else:
|
||||
tokenizer_kwargs = {}
|
||||
inputs = self.tokenizer(
|
||||
inputs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=self.framework,
|
||||
padding=padding,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
kwargs.update({"add_space_before_punct_symbol": True})
|
||||
|
||||
return inputs
|
||||
return super()._parse_and_tokenize(*args, **kwargs)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -3,6 +3,7 @@ from typing import List, Union
|
||||
import numpy as np
|
||||
|
||||
from ..file_utils import add_end_docstrings
|
||||
from ..tokenization_utils import TruncationStrategy
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
|
||||
|
||||
@ -78,7 +79,14 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
return -1
|
||||
|
||||
def _parse_and_tokenize(
|
||||
self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
|
||||
self,
|
||||
sequences,
|
||||
candidate_labels,
|
||||
hypothesis_template,
|
||||
padding=True,
|
||||
add_special_tokens=True,
|
||||
truncation=TruncationStrategy.ONLY_FIRST,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
|
||||
@ -89,7 +97,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_tensors=self.framework,
|
||||
padding=padding,
|
||||
truncation="only_first",
|
||||
truncation=truncation,
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
@ -14,15 +14,72 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers import AutoTokenizer, is_torch_available, pipeline
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.tokenization_utils import TruncationStrategy
|
||||
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.bart import BartConfig, BartForConditionalGeneration
|
||||
|
||||
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||
|
||||
|
||||
class SimpleSummarizationPipelineTests(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_input_too_long(self):
|
||||
torch.manual_seed(0)
|
||||
config = BartConfig(
|
||||
vocab_size=257,
|
||||
d_model=32,
|
||||
encoder_layers=1,
|
||||
decoder_layers=1,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
# So any text > 4 should raise an exception
|
||||
max_position_embeddings=4,
|
||||
encoder_attention_heads=1,
|
||||
decoder_attention_heads=1,
|
||||
max_length=4,
|
||||
min_length=1,
|
||||
)
|
||||
model = BartForConditionalGeneration(config)
|
||||
# Bias output towards L
|
||||
V, C = model.lm_head.weight.shape
|
||||
|
||||
bias = torch.zeros(V, requires_grad=True)
|
||||
bias[76] = 10
|
||||
|
||||
model.lm_head.bias = torch.nn.Parameter(bias)
|
||||
|
||||
# # Generated with:
|
||||
# import tempfile
|
||||
# from tokenizers import Tokenizer, models
|
||||
# from transformers import PreTrainedTokenizerFast
|
||||
# model_max_length = 4
|
||||
# vocab = [(chr(i), i) for i in range(256)]
|
||||
# tokenizer = Tokenizer(models.Unigram(vocab))
|
||||
# with tempfile.NamedTemporaryFile() as f:
|
||||
# tokenizer.save(f.name)
|
||||
# real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length)
|
||||
# real_tokenizer._tokenizer.save("tokenizer.json")
|
||||
# # + add missing config.json with albert as model_type
|
||||
tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test")
|
||||
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING"):
|
||||
with self.assertRaises(IndexError):
|
||||
_ = nlp("This is a test")
|
||||
|
||||
output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST)
|
||||
# 2 is default BOS from Bart.
|
||||
self.assertEqual(output, [{"summary_text": "\x02 L L L"}])
|
||||
|
||||
|
||||
class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "summarization"
|
||||
pipeline_running_kwargs = {"num_beams": 2, "min_length": 2, "max_length": 5}
|
||||
|
Loading…
Reference in New Issue
Block a user