Mirror of https://github.com/huggingface/transformers.git
Enable TruncationStrategy override for pipelines (#9432)
* Enable TruncationStrategy override for pipelines
* Update isort.
* Fixing test
* Fixing text_generation pipeline.
* Using same DummyTok as other PR for easier merge later.
* Some more import guards.
* Remove bogus file.
* Do not pass `generate_kwargs` to `_parse_and_tokenize`. @patrickvonplaten
* Removed DummyTok.
* Doc quality.
Parent: 8d25df2c7a
Commit: d20e9c7299
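
With this change a caller can override the truncation strategy for a single pipeline call instead of relying on the tokenizer default. A minimal usage sketch mirroring the new test added below (the checkpoint name is an illustrative assumption, not part of this commit):

from transformers import pipeline
from transformers.tokenization_utils import TruncationStrategy

# Example checkpoint; any summarization model behaves the same way here.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
very_long_text = "This is a test. " * 2000

# The default stays DO_NOT_TRUNCATE; passing ONLY_FIRST trims the input up front
# instead of letting an over-long sequence fail inside generate().
output = summarizer(very_long_text, truncation=TruncationStrategy.ONLY_FIRST)
print(output[0]["summary_text"])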
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
-from ..tokenization_utils import PreTrainedTokenizer
+from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
 from ..utils import logging


@@ -577,7 +577,9 @@ class Pipeline(_ScikitCompat):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
             )

-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize
         """
@@ -587,6 +589,7 @@ class Pipeline(_ScikitCompat):
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
+            truncation=truncation,
         )

         return inputs
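
The base Pipeline._parse_and_tokenize now simply forwards the strategy to the tokenizer call. A standalone sketch of that forwarding (checkpoint and input are illustrative assumptions):

from transformers import AutoTokenizer
from transformers.tokenization_utils import TruncationStrategy

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
encoded = tokenizer(
    ["some very long input " * 200],
    add_special_tokens=True,
    return_tensors="pt",
    padding=True,
    truncation=TruncationStrategy.ONLY_FIRST,  # overriding the new DO_NOT_TRUNCATE default
)
print(encoded["input_ids"].shape)  # capped at the tokenizer's model_max_length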
@@ -2,6 +2,7 @@ import uuid
 from typing import List, Optional, Union

 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline


@@ -317,12 +318,14 @@ class ConversationalPipeline(Pipeline):
         else:
             return output

-    def _parse_and_tokenize(self, inputs, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
         """
         Parse arguments and tokenize, adding an EOS token at the end of the user input
         """
         # Parse arguments
-        inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
+        inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
         for input in inputs:
             input.append(self.tokenizer.eos_token_id)
         return inputs
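
The conversational override keeps its EOS-appending behaviour while exposing the tokenizer arguments it previously hard-coded. A minimal sketch of what it does (the checkpoint is an illustrative assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")  # example checkpoint
texts = ["Hello, how are you?", "Any plans for the weekend?"]

# Encode without special tokens or padding, then close each user input with EOS,
# mirroring the override above.
input_ids = tokenizer(texts, add_special_tokens=False, padding=False).get("input_ids", [])
for ids in input_ids:
    ids.append(tokenizer.eos_token_id)
print(input_ids)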
@@ -1,4 +1,5 @@
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline


@@ -50,7 +51,13 @@ class Text2TextGenerationPipeline(Pipeline):
         return True

     def __call__(
-        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+        **generate_kwargs
     ):
         r"""
         Generate the output text(s) using text(s) given as inputs.
@@ -64,6 +71,10 @@ class Text2TextGenerationPipeline(Pipeline):
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
+            truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
+                The truncation strategy for the tokenization within the pipeline.
+                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
+                to truncate the input to fit the model's max_length instead of throwing an error down the line.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -96,7 +107,7 @@ class Text2TextGenerationPipeline(Pipeline):
             )

         with self.device_placement():
-            inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)
+            inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)

             if self.framework == "pt":
                 inputs = self.ensure_tensor_on_device(**inputs)
@@ -108,9 +119,6 @@ class Text2TextGenerationPipeline(Pipeline):
             max_length = generate_kwargs.get("max_length", self.model.config.max_length)
             self.check_inputs(input_length, min_length, max_length)

-            # truncation should be used by _parse_and_tokenize
-            generate_kwargs.pop("truncation", None)
-
             generations = self.model.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
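
The new docstring entry describes the trade-off between the two strategies; a short sketch of that difference at the tokenizer level (checkpoint and text are illustrative assumptions):

from transformers import AutoTokenizer
from transformers.tokenization_utils import TruncationStrategy

tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # example checkpoint
text = "word " * 2000

kept = tok(text, truncation=TruncationStrategy.DO_NOT_TRUNCATE)
cut = tok(text, truncation=TruncationStrategy.ONLY_FIRST)
# The untruncated encoding keeps every token and may exceed the model's limit;
# the truncated one is capped at the tokenizer's model_max_length.
print(len(kept["input_ids"]), len(cut["input_ids"]))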
@@ -50,25 +50,15 @@ class TextGenerationPipeline(Pipeline):
         self.check_model_type(self.ALLOWED_MODELS)

     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, *args, **kwargs):
         """
         Parse arguments and tokenize
         """
         # Parse arguments
         if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
-            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
-        else:
-            tokenizer_kwargs = {}
-        inputs = self.tokenizer(
-            inputs,
-            add_special_tokens=add_special_tokens,
-            return_tensors=self.framework,
-            padding=padding,
-            **tokenizer_kwargs,
-        )
-
-        return inputs
+            kwargs.update({"add_space_before_punct_symbol": True})
+
+        return super()._parse_and_tokenize(*args, **kwargs)

     def __call__(
         self,
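
The text-generation pipeline no longer re-implements tokenization; it only injects the Transfo-XL-specific keyword and delegates to the parent, so padding and truncation handling live in one place. A minimal sketch of that pattern (the classes below are simplified stand-ins, not the library code):

class BasePipeline:
    # Simplified stand-in for Pipeline._parse_and_tokenize
    def _parse_and_tokenize(self, inputs, padding=True, truncation="do_not_truncate", **kwargs):
        print("tokenizer called with:", {"padding": padding, "truncation": truncation, **kwargs})
        return inputs


class TextGenPipeline(BasePipeline):
    def _parse_and_tokenize(self, *args, **kwargs):
        # Inject the model-specific tokenizer argument, then defer to the parent.
        kwargs.update({"add_space_before_punct_symbol": True})
        return super()._parse_and_tokenize(*args, **kwargs)


TextGenPipeline()._parse_and_tokenize("Hello world", truncation="only_first")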
@@ -3,6 +3,7 @@ from typing import List, Union
 import numpy as np

 from ..file_utils import add_end_docstrings
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline


@@ -78,7 +79,14 @@ class ZeroShotClassificationPipeline(Pipeline):
         return -1

     def _parse_and_tokenize(
-        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
+        self,
+        sequences,
+        candidate_labels,
+        hypothesis_template,
+        padding=True,
+        add_special_tokens=True,
+        truncation=TruncationStrategy.ONLY_FIRST,
+        **kwargs
     ):
         """
         Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
@@ -89,7 +97,7 @@ class ZeroShotClassificationPipeline(Pipeline):
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
-            truncation="only_first",
+            truncation=truncation,
         )

         return inputs
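
For zero-shot classification the default becomes an explicit ONLY_FIRST so the hypothesis built from each candidate label is never cut off. A sketch of pair tokenization with that strategy (checkpoint and texts are illustrative assumptions):

from transformers import AutoTokenizer
from transformers.tokenization_utils import TruncationStrategy

tok = AutoTokenizer.from_pretrained("roberta-large-mnli")  # example checkpoint
premise = "A very long customer review. " * 200
hypothesis = "This example is about politics."

enc = tok(premise, hypothesis, truncation=TruncationStrategy.ONLY_FIRST)
# Only the first sequence (the premise) is truncated; the hypothesis stays intact
# at the end of the encoded pair.
print(len(enc["input_ids"]))
print(tok.decode(enc["input_ids"][-12:]))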
@@ -14,15 +14,72 @@
 import unittest

-from transformers import pipeline
+from transformers import AutoTokenizer, is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.tokenization_utils import TruncationStrategy

 from .test_pipelines_common import MonoInputPipelineCommonMixin


+if is_torch_available():
+    import torch
+
+    from transformers.models.bart import BartConfig, BartForConditionalGeneration
+
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0


+class SimpleSummarizationPipelineTests(unittest.TestCase):
+    @require_torch
+    def test_input_too_long(self):
+        torch.manual_seed(0)
+        config = BartConfig(
+            vocab_size=257,
+            d_model=32,
+            encoder_layers=1,
+            decoder_layers=1,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            # So any text > 4 should raise an exception
+            max_position_embeddings=4,
+            encoder_attention_heads=1,
+            decoder_attention_heads=1,
+            max_length=4,
+            min_length=1,
+        )
+        model = BartForConditionalGeneration(config)
+        # Bias output towards L
+        V, C = model.lm_head.weight.shape
+
+        bias = torch.zeros(V, requires_grad=True)
+        bias[76] = 10
+
+        model.lm_head.bias = torch.nn.Parameter(bias)
+
+        # # Generated with:
+        # import tempfile
+        # from tokenizers import Tokenizer, models
+        # from transformers import PreTrainedTokenizerFast
+        # model_max_length = 4
+        # vocab = [(chr(i), i) for i in range(256)]
+        # tokenizer = Tokenizer(models.Unigram(vocab))
+        # with tempfile.NamedTemporaryFile() as f:
+        #     tokenizer.save(f.name)
+        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length)
+        # real_tokenizer._tokenizer.save("tokenizer.json")
+        # # + add missing config.json with albert as model_type
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test")
+        nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
+
+        with self.assertLogs("transformers", level="WARNING"):
+            with self.assertRaises(IndexError):
+                _ = nlp("This is a test")
+
+        output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST)
+        # 2 is default BOS from Bart.
+        self.assertEqual(output, [{"summary_text": "\x02 L L L"}])
+
+
 class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "summarization"
     pipeline_running_kwargs = {"num_beams": 2, "min_length": 2, "max_length": 5}
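
A note on the expected output in the new test: the toy Unigram vocab maps token id i to chr(i), Bart's default BOS id is 2, and the lm_head bias pushes every generated token towards id 76, which decodes to "L":

print(repr(chr(2)), repr(chr(76)))  # '\x02' 'L'  -> hence the expected "\x02 L L L"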