Adding handle_long_generation parameter for text-generation pipeline. (#14118)

* Adding `handle_long_generation` parameter for the `text-generation` pipeline.

* More error handling

* Fixing tests by dropping TF support for this functionality: it needs `max_new_tokens` to make it possible to understand the user's intent. Otherwise `max_length == tokenizer.model_max_length < input_ids.shape[0]` and the number of new tokens cannot be inferred (see the sketch after these notes).

* Fixing doc?

* Doc?

* Remove link from doc.

* Caught an issue on RoBERTa.

* Damn doc.

* Non-BC proposal?

* Cleaning the fix?

* Finally using only a test override.

* Don't need to modify this.

* Bad print.
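To make the TF limitation above concrete, here is a minimal sketch (plain Python, hypothetical sizes, not part of the commit) of why `max_length` alone cannot express the user's intent once the prompt is longer than the model capacity, while `max_new_tokens` can:

    # Hypothetical sizes, for illustration only.
    model_max_length = 1024   # stand-in for tokenizer.model_max_length
    prompt_length = 2000      # tokenized prompt, longer than the model capacity

    # With only `max_length` (which defaults to the model capacity), the number of
    # new tokens must be inferred, and that inference breaks down:
    inferred_new_tokens = model_max_length - prompt_length
    print(inferred_new_tokens)  # -976: impossible to tell how much the user wants to generate

    # With an explicit `max_new_tokens`, the intent is unambiguous and the prompt
    # can be truncated from the left to leave room for generation:
    max_new_tokens = 20
    keep_length = model_max_length - max_new_tokens
    print(keep_length)  # 1004 prompt tokens kept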
Author: Nicolas Patry, 2021-10-29 15:29:28 +02:00 (committed by GitHub)
Parent: d37f1fb8ba
Commit: dc540dd316
4 changed files with 68 additions and 4 deletions


@@ -254,7 +254,7 @@ class ReformerEmbeddings(nn.Module):
         if position_ids.shape[-1] > self.max_position_embeddings:
             raise ValueError(
-                f"Sequence Length: {position_ids.shape[-1]} has to be larger equal than "
+                f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than "
                 f"config.max_position_embeddings {self.max_position_embeddings}."
             )


@@ -75,6 +75,7 @@ class TextGenerationPipeline(Pipeline):
         return_type=None,
         clean_up_tokenization_spaces=None,
         prefix=None,
+        handle_long_generation=None,
         **generate_kwargs
     ):
         preprocess_params = {}
@@ -85,14 +86,24 @@ class TextGenerationPipeline(Pipeline):
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
             prefix_length = prefix_inputs["input_ids"].shape[-1]
-            if "max_length" in generate_kwargs:
+            if "max_new_tokens" in generate_kwargs:
+                pass
+            elif "max_length" in generate_kwargs:
                 generate_kwargs["max_length"] += prefix_length
             else:
                 generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
+
             if "min_length" in generate_kwargs:
                 generate_kwargs["min_length"] += prefix_length
+
+        if handle_long_generation is not None:
+            if handle_long_generation not in {"hole"}:
+                raise ValueError(
+                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected [None, 'hole']"
+                )
+            preprocess_params["handle_long_generation"] = handle_long_generation
+
+        preprocess_params.update(generate_kwargs)
         forward_params = generate_kwargs

         postprocess_params = {}
@@ -136,6 +147,16 @@ class TextGenerationPipeline(Pipeline):
                 Whether or not to clean up the potential extra spaces in the text output.
             prefix (:obj:`str`, `optional`):
                 Prefix added to prompt.
+            handle_long_generation (:obj:`str`, `optional`):
+                By default, this pipeline does not handle long generation (generation that exceeds the model maximum
+                length in one form or another). There is no perfect way to address this (more info:
+                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
+                strategies to work around that problem depending on your use case.
+
+                - :obj:`None`: default strategy where nothing in particular happens
+                - :obj:`"hole"`: Truncates the left of the input and leaves a gap wide enough to let generation happen
+                  (might truncate a lot of the prompt, and not suitable when the generation exceeds the model capacity)
+
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
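For readers of the docstring above, a minimal usage sketch of the new parameter (the `gpt2` checkpoint and the prompt are illustrative choices, not part of the commit):

    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2")

    # A prompt far longer than the model's maximum length. Without
    # handle_long_generation="hole" this would error out; with it, the left side
    # of the prompt is truncated so that 20 new tokens can still be generated.
    long_prompt = "This is a test " * 500
    outputs = generator(long_prompt, handle_long_generation="hole", max_new_tokens=20)
    print(outputs[0]["generated_text"][-200:])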
@@ -149,11 +170,31 @@ class TextGenerationPipeline(Pipeline):
         """
         return super().__call__(text_inputs, **kwargs)

-    def preprocess(self, prompt_text, prefix=""):
+    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
         inputs = self.tokenizer(
             prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework
         )
         inputs["prompt_text"] = prompt_text
+
+        if handle_long_generation == "hole":
+            cur_len = inputs["input_ids"].shape[-1]
+            if "max_new_tokens" in generate_kwargs:
+                new_tokens = generate_kwargs["max_new_tokens"]
+            else:
+                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
+                if new_tokens < 0:
+                    raise ValueError("We cannot infer how many new tokens are expected")
+            if cur_len + new_tokens > self.tokenizer.model_max_length:
+                keep_length = self.tokenizer.model_max_length - new_tokens
+                if keep_length <= 0:
+                    raise ValueError(
+                        "We cannot use `hole` to handle this generation: the number of desired tokens exceeds the model's max length"
+                    )
+
+                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]
+
         return inputs

     def _forward(self, model_inputs, **generate_kwargs):
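The `hole` logic added to `preprocess` above boils down to a left-truncation of the encoded prompt. A standalone sketch of the same arithmetic (shapes and sizes are illustrative; PyTorch is assumed, since the feature is PT-only here):

    import torch

    model_max_length = 1024   # stand-in for tokenizer.model_max_length
    max_new_tokens = 20

    input_ids = torch.randint(0, 50257, (1, 2000))  # an over-long encoded prompt

    cur_len = input_ids.shape[-1]
    if cur_len + max_new_tokens > model_max_length:
        # Keep only the rightmost tokens, leaving a "hole" of max_new_tokens
        # positions free for generation.
        keep_length = model_max_length - max_new_tokens
        if keep_length <= 0:
            raise ValueError("the requested number of new tokens exceeds the model's max length")
        input_ids = input_ids[:, -keep_length:]

    print(input_ids.shape)  # torch.Size([1, 1004])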


@ -143,7 +143,9 @@ class PipelineTestCaseMeta(type):
try: try:
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint) tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
# XLNet actually defines it as -1. # XLNet actually defines it as -1.
if ( if model.config.__class__.__name__ == "RobertaConfig":
tokenizer.model_max_length = model.config.max_position_embeddings - 2
elif (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings > 0 and model.config.max_position_embeddings > 0
): ):
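A short note on the RoBERTa special case above: RoBERTa-style models reserve two position slots (the padding index plus its offset), so the usable sequence length is `max_position_embeddings - 2`. A quick check, assuming the standard `roberta-base` checkpoint (not part of the commit):

    from transformers import RobertaConfig, RobertaTokenizerFast

    config = RobertaConfig.from_pretrained("roberta-base")
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    print(config.max_position_embeddings)       # 514
    print(tokenizer.model_max_length)           # 512
    print(config.max_position_embeddings - 2)   # 512, which is why the test subtracts 2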


@@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             else:
                 with self.assertRaises((ValueError, AssertionError)):
                     outputs = text_generator("")
+            if text_generator.framework == "tf":
+                # TF generation does not support max_new_tokens, and it's impossible
+                # to control long generation with only max_length without
+                # fancy calculation, dismissing tests for now.
+                return
+
+            # We don't care about infinite range models.
+            # They already work.
+            if tokenizer.model_max_length < 10000:
+                # Handling of large generations
+                with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
+                    text_generator("This is a test" * 500, max_new_tokens=20)
+
+                outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
+                # Hole strategy cannot work
+                with self.assertRaises(ValueError):
+                    text_generator(
+                        "This is a test" * 500,
+                        handle_long_generation="hole",
+                        max_new_tokens=tokenizer.model_max_length + 10,
+                    )