Adding handle_long_generation parameter for text-generation pipeline. (#14118)

* Adding `handle_long_generation` parameter for `text-generation` pipeline.

* More error handling

* Fixing tests by dropping TF support on this functionality; it needs
`max_new_tokens` to make it possible to understand the user's intent.
Otherwise, `max_length` == `tokenizer.model_max_length` <
`input_ids.shape[0]`.

* Fixing doc?

* Doc?

* Remove link from doc.

* Caught an issue on roberta.

* Damn doc.

* Non-BC proposal?

* Cleaning the fix?

* Finally using only a test override.

* Don't need to modify this.

* Bad print.
Nicolas Patry 2021-10-29 15:29:28 +02:00 committed by GitHub
parent d37f1fb8ba
commit dc540dd316
4 changed files with 68 additions and 4 deletions
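For context, a minimal usage sketch of the option this commit adds, mirroring the new test below; the `gpt2` checkpoint, the prompt, and the printed slice are illustrative only, and a PyTorch backend is assumed (the TF path is skipped in the tests):

from transformers import pipeline

# Build a standard text-generation pipeline (PyTorch backend assumed).
generator = pipeline("text-generation", model="gpt2")

# A prompt far longer than the model's maximum length.
long_prompt = "This is a test " * 500

# "hole" truncates the prompt from the left, leaving room for max_new_tokens.
output = generator(long_prompt, handle_long_generation="hole", max_new_tokens=20)
print(output[0]["generated_text"][-200:])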


@@ -254,7 +254,7 @@ class ReformerEmbeddings(nn.Module):
         if position_ids.shape[-1] > self.max_position_embeddings:
             raise ValueError(
-                f"Sequence Length: {position_ids.shape[-1]} has to be larger equal than "
+                f"Sequence Length: {position_ids.shape[-1]} has to be less than or equal to "
                 f"config.max_position_embeddings {self.max_position_embeddings}."
             )


@@ -75,6 +75,7 @@ class TextGenerationPipeline(Pipeline):
         return_type=None,
         clean_up_tokenization_spaces=None,
         prefix=None,
+        handle_long_generation=None,
         **generate_kwargs
     ):
         preprocess_params = {}
@@ -85,14 +86,24 @@ class TextGenerationPipeline(Pipeline):
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
             prefix_length = prefix_inputs["input_ids"].shape[-1]
-            if "max_length" in generate_kwargs:
+            if "max_new_tokens" in generate_kwargs:
+                pass
+            elif "max_length" in generate_kwargs:
                 generate_kwargs["max_length"] += prefix_length
             else:
                 generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
 
             if "min_length" in generate_kwargs:
                 generate_kwargs["min_length"] += prefix_length
+        if handle_long_generation is not None:
+            if handle_long_generation not in {"hole"}:
+                raise ValueError(
+                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter, expected [None, 'hole']"
+                )
+            preprocess_params["handle_long_generation"] = handle_long_generation
+
         preprocess_params.update(generate_kwargs)
         forward_params = generate_kwargs
         postprocess_params = {}
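In words: when a `prefix` is prepended, the caller's length budget is widened by the prefix length so the prefix does not eat into it, unless `max_new_tokens` was passed (which already counts only newly generated tokens). A small worked example with illustrative numbers (the model default of 20 is an assumption):

# Illustrative numbers only: a 5-token prefix and a caller-supplied max_length of 50.
prefix_length = 5
generate_kwargs = {"max_length": 50, "min_length": 10}

if "max_new_tokens" in generate_kwargs:
    pass  # already counts only new tokens, nothing to adjust
elif "max_length" in generate_kwargs:
    generate_kwargs["max_length"] += prefix_length  # 50 -> 55
else:
    generate_kwargs["max_length"] = 20 + prefix_length  # assumed model default + prefix

if "min_length" in generate_kwargs:
    generate_kwargs["min_length"] += prefix_length  # 10 -> 15

print(generate_kwargs)  # {'max_length': 55, 'min_length': 15}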
@@ -136,6 +147,16 @@ class TextGenerationPipeline(Pipeline):
                 Whether or not to clean up the potential extra spaces in the text output.
             prefix (:obj:`str`, `optional`):
                 Prefix added to prompt.
+            handle_long_generation (:obj:`str`, `optional`):
+                By default, this pipeline does not handle long generation (generations that exceed the model maximum
+                length in one form or another). There is no perfect way to address this (more info:
+                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
+                strategies to work around the problem, depending on your use case.
+
+                - :obj:`None`: default strategy, where nothing in particular happens
+                - :obj:`"hole"`: Truncates the left of the input and leaves a gap wide enough to let generation happen
+                  (this might truncate a lot of the prompt, and is not suitable when the requested generation exceeds
+                  the model capacity)
+
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -149,11 +170,31 @@ class TextGenerationPipeline(Pipeline):
         """
         return super().__call__(text_inputs, **kwargs)
 
-    def preprocess(self, prompt_text, prefix=""):
+    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
         inputs = self.tokenizer(
             prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework
         )
         inputs["prompt_text"] = prompt_text
+
+        if handle_long_generation == "hole":
+            cur_len = inputs["input_ids"].shape[-1]
+            if "max_new_tokens" in generate_kwargs:
+                new_tokens = generate_kwargs["max_new_tokens"]
+            else:
+                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
+                if new_tokens < 0:
+                    raise ValueError("We cannot infer how many new tokens are expected")
+            if cur_len + new_tokens > self.tokenizer.model_max_length:
+                keep_length = self.tokenizer.model_max_length - new_tokens
+                if keep_length <= 0:
+                    raise ValueError(
+                        "We cannot use `hole` to handle this generation: the number of desired tokens exceeds the "
+                        "model's max length"
+                    )
+
+                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]
+
         return inputs
 
     def _forward(self, model_inputs, **generate_kwargs):
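Condensed, the `hole` arithmetic above keeps only the rightmost prompt tokens so that prompt plus requested new tokens still fits in the model window; a standalone sketch (the function name and standalone form are illustrative, not the pipeline's API):

# Standalone sketch of the "hole" truncation; `input_ids` and `attention_mask`
# are [batch, seq_len] tensors, `model_max_length` comes from the tokenizer and
# `new_tokens` from max_new_tokens (or max_length - current length).
def hole_truncate(input_ids, attention_mask, model_max_length, new_tokens):
    cur_len = input_ids.shape[-1]
    if cur_len + new_tokens > model_max_length:
        keep_length = model_max_length - new_tokens
        if keep_length <= 0:
            raise ValueError("The requested new tokens alone exceed the model's max length")
        # Drop tokens from the left, keeping the most recent context.
        input_ids = input_ids[:, -keep_length:]
        attention_mask = attention_mask[:, -keep_length:]
    return input_ids, attention_mask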


@@ -143,7 +143,9 @@ class PipelineTestCaseMeta(type):
                 try:
                     tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
                     # XLNet actually defines it as -1.
-                    if (
+                    if model.config.__class__.__name__ == "RobertaConfig":
+                        tokenizer.model_max_length = model.config.max_position_embeddings - 2
+                    elif (
                         hasattr(model.config, "max_position_embeddings")
                         and model.config.max_position_embeddings > 0
                     ):
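The `RobertaConfig` special case exists because RoBERTa offsets position ids past the padding index, so `max_position_embeddings` is two larger than the usable sequence length. A quick check, assuming the `roberta-base` checkpoint is available:

from transformers import AutoConfig, AutoTokenizer

# roberta-base: max_position_embeddings is 514 while the tokenizer's
# model_max_length is 512, hence the `- 2` in the test override above.
config = AutoConfig.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
assert config.max_position_embeddings == tokenizer.model_max_length + 2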


@@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         else:
             with self.assertRaises((ValueError, AssertionError)):
                 outputs = text_generator("")
+
+        if text_generator.framework == "tf":
+            # TF generation does not support max_new_tokens, and it is impossible to
+            # control long generation with only max_length without fancy calculation,
+            # so we skip these tests for now.
+            return
+
+        # We don't care about infinite-range models; they already work.
+        if tokenizer.model_max_length < 10000:
+            # Handling of large generations
+            with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
+                text_generator("This is a test" * 500, max_new_tokens=20)
+
+            outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
+            # The hole strategy cannot work when the requested new tokens alone exceed the model length.
+            with self.assertRaises(ValueError):
+                text_generator(
+                    "This is a test" * 500,
+                    handle_long_generation="hole",
+                    max_new_tokens=tokenizer.model_max_length + 10,
+                )