mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adding handle_long_generation
paramters for text-generation
pipeline. (#14118)
* Adding `handle_long_generation` paramters for `text-generation` pipeline. * More error handling * Fixing tests by dropping tf support on this functionality, it needs `max_new_tokens` to make it possible to understand user's intent. Otherwise, `max_length` == `tokenizer.model_max_length` < input_ids.shape[0]. * Fixing doc ? * Doc ? * Remove link from doc. * Catched an issue on roberta. * Damn doc. * Non BC proposal ? * Cleaning the fix ? * Finally using only a test override. * Don't need to modify this. * Bad print.
This commit is contained in:
parent
d37f1fb8ba
commit
dc540dd316
@ -254,7 +254,7 @@ class ReformerEmbeddings(nn.Module):
|
||||
|
||||
if position_ids.shape[-1] > self.max_position_embeddings:
|
||||
raise ValueError(
|
||||
f"Sequence Length: {position_ids.shape[-1]} has to be larger equal than "
|
||||
f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than "
|
||||
f"config.max_position_embeddings {self.max_position_embeddings}."
|
||||
)
|
||||
|
||||
|
@ -75,6 +75,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
return_type=None,
|
||||
clean_up_tokenization_spaces=None,
|
||||
prefix=None,
|
||||
handle_long_generation=None,
|
||||
**generate_kwargs
|
||||
):
|
||||
preprocess_params = {}
|
||||
@ -85,14 +86,24 @@ class TextGenerationPipeline(Pipeline):
|
||||
prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
|
||||
)
|
||||
prefix_length = prefix_inputs["input_ids"].shape[-1]
|
||||
if "max_length" in generate_kwargs:
|
||||
|
||||
if "max_new_tokens" in generate_kwargs:
|
||||
pass
|
||||
elif "max_length" in generate_kwargs:
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
else:
|
||||
generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
|
||||
|
||||
if "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
if handle_long_generation is not None:
|
||||
if handle_long_generation not in {"hole"}:
|
||||
raise ValueError(
|
||||
f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected [None, 'hole']"
|
||||
)
|
||||
preprocess_params["handle_long_generation"] = handle_long_generation
|
||||
|
||||
preprocess_params.update(generate_kwargs)
|
||||
forward_params = generate_kwargs
|
||||
|
||||
postprocess_params = {}
|
||||
@ -136,6 +147,16 @@ class TextGenerationPipeline(Pipeline):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
prefix (:obj:`str`, `optional`):
|
||||
Prefix added to prompt.
|
||||
handle_long_generation (:obj:`str`, `optional`):
|
||||
By default, this pipelines does not handle long generation (ones that exceed in one form or the other
|
||||
the model maximum length). There is no perfect way to adress this (more info
|
||||
:https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
|
||||
strategies to work around that problem depending on your use case.
|
||||
|
||||
- :obj:`None` : default strategy where nothing in particular happens
|
||||
- :obj:`"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
|
||||
truncate a lot of the prompt and not suitable when generation exceed the model capacity)
|
||||
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||
@ -149,11 +170,31 @@ class TextGenerationPipeline(Pipeline):
|
||||
"""
|
||||
return super().__call__(text_inputs, **kwargs)
|
||||
|
||||
def preprocess(self, prompt_text, prefix=""):
|
||||
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
||||
inputs = self.tokenizer(
|
||||
prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework
|
||||
)
|
||||
inputs["prompt_text"] = prompt_text
|
||||
|
||||
if handle_long_generation == "hole":
|
||||
cur_len = inputs["input_ids"].shape[-1]
|
||||
if "max_new_tokens" in generate_kwargs:
|
||||
new_tokens = generate_kwargs["max_new_tokens"]
|
||||
else:
|
||||
new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
|
||||
if new_tokens < 0:
|
||||
raise ValueError("We cannot infer how many new tokens are expected")
|
||||
if cur_len + new_tokens > self.tokenizer.model_max_length:
|
||||
keep_length = self.tokenizer.model_max_length - new_tokens
|
||||
if keep_length <= 0:
|
||||
raise ValueError(
|
||||
"We cannot use `hole` to handle this generation the number of desired tokens exceeds the models max length"
|
||||
)
|
||||
|
||||
inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]
|
||||
|
||||
return inputs
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
|
@ -143,7 +143,9 @@ class PipelineTestCaseMeta(type):
|
||||
try:
|
||||
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
|
||||
# XLNet actually defines it as -1.
|
||||
if (
|
||||
if model.config.__class__.__name__ == "RobertaConfig":
|
||||
tokenizer.model_max_length = model.config.max_position_embeddings - 2
|
||||
elif (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings > 0
|
||||
):
|
||||
|
@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
|
||||
else:
|
||||
with self.assertRaises((ValueError, AssertionError)):
|
||||
outputs = text_generator("")
|
||||
|
||||
if text_generator.framework == "tf":
|
||||
# TF generation does not support max_new_tokens, and it's impossible
|
||||
# to control long generation with only max_length without
|
||||
# fancy calculation, dismissing tests for now.
|
||||
return
|
||||
# We don't care about infinite range models.
|
||||
# They already work.
|
||||
if tokenizer.model_max_length < 10000:
|
||||
# Handling of large generations
|
||||
with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
|
||||
text_generator("This is a test" * 500, max_new_tokens=20)
|
||||
|
||||
outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
|
||||
# Hole strategy cannot work
|
||||
with self.assertRaises(ValueError):
|
||||
text_generator(
|
||||
"This is a test" * 500,
|
||||
handle_long_generation="hole",
|
||||
max_new_tokens=tokenizer.model_max_length + 10,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user