Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Adding `handle_long_generation` parameter for `text-generation` pipeline. (#14118)
* Adding `handle_long_generation` parameter for `text-generation` pipeline.
* More error handling
* Fixing tests by dropping TF support on this functionality: it needs `max_new_tokens` to make it possible to understand the user's intent. Otherwise, `max_length` == `tokenizer.model_max_length` < input_ids.shape[0].
* Fixing doc?
* Doc?
* Remove link from doc.
* Caught an issue on RoBERTa.
* Damn doc.
* Non-BC proposal?
* Cleaning the fix?
* Finally using only a test override.
* Don't need to modify this.
* Bad print.
This commit is contained in:
parent d37f1fb8ba
commit dc540dd316
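
For orientation before the diff, here is a minimal, hedged usage sketch of what the new call-time parameter enables (the model name and prompt are placeholder choices, not part of this commit):

from transformers import pipeline

# Placeholder checkpoint; any causal LM served by the text-generation pipeline works the same way.
generator = pipeline("text-generation", model="gpt2")

# The new parameter: with "hole", an over-long prompt is truncated from the left so that
# `max_new_tokens` still fits inside the model's maximum length instead of erroring out.
outputs = generator("Some very long prompt ... " * 200, handle_long_generation="hole", max_new_tokens=20)
print(outputs[0]["generated_text"][-200:])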
@@ -254,7 +254,7 @@ class ReformerEmbeddings(nn.Module):
         if position_ids.shape[-1] > self.max_position_embeddings:
             raise ValueError(
-                f"Sequence Length: {position_ids.shape[-1]} has to be larger equal than "
+                f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than "
                 f"config.max_position_embeddings {self.max_position_embeddings}."
             )
@@ -75,6 +75,7 @@ class TextGenerationPipeline(Pipeline):
         return_type=None,
         clean_up_tokenization_spaces=None,
         prefix=None,
+        handle_long_generation=None,
         **generate_kwargs
     ):
         preprocess_params = {}
@@ -85,14 +86,24 @@ class TextGenerationPipeline(Pipeline):
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
             prefix_length = prefix_inputs["input_ids"].shape[-1]
-            if "max_length" in generate_kwargs:
+            if "max_new_tokens" in generate_kwargs:
+                pass
+            elif "max_length" in generate_kwargs:
                 generate_kwargs["max_length"] += prefix_length
             else:
                 generate_kwargs["max_length"] = self.model.config.max_length + prefix_length

             if "min_length" in generate_kwargs:
                 generate_kwargs["min_length"] += prefix_length
+        if handle_long_generation is not None:
+            if handle_long_generation not in {"hole"}:
+                raise ValueError(
+                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected [None, 'hole']"
+                )
+            preprocess_params["handle_long_generation"] = handle_long_generation
+
+        preprocess_params.update(generate_kwargs)
         forward_params = generate_kwargs

         postprocess_params = {}
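
To make the `max_length` bookkeeping in the hunk above concrete, here is a small standalone sketch with made-up numbers (not the pipeline code itself):

# Toy numbers: a 5-token prefix is prepended to every prompt.
prefix_length = 5
model_default_max_length = 50   # stand-in for self.model.config.max_length
generate_kwargs = {}            # what the caller passed

if "max_new_tokens" in generate_kwargs:
    pass  # counts only generated tokens, so the prefix does not affect it
elif "max_length" in generate_kwargs:
    generate_kwargs["max_length"] += prefix_length  # max_length counts the whole sequence
else:
    generate_kwargs["max_length"] = model_default_max_length + prefix_length

print(generate_kwargs)  # {'max_length': 55}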
@@ -136,6 +147,16 @@ class TextGenerationPipeline(Pipeline):
                 Whether or not to clean up the potential extra spaces in the text output.
             prefix (:obj:`str`, `optional`):
                 Prefix added to prompt.
+            handle_long_generation (:obj:`str`, `optional`):
+                By default, this pipeline does not handle long generation (generation that exceeds, in one form or
+                another, the model's maximum length). There is no perfect way to address this (more info at
+                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
+                strategies to work around the problem depending on your use case.
+
+                - :obj:`None`: default strategy, where nothing in particular happens
+                - :obj:`"hole"`: truncates the left of the input and leaves a gap wide enough to let generation happen
+                  (this might truncate a lot of the prompt, and is not suitable when the generation exceeds the model capacity)
+
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -149,11 +170,31 @@ class TextGenerationPipeline(Pipeline):
         """
         return super().__call__(text_inputs, **kwargs)

-    def preprocess(self, prompt_text, prefix=""):
+    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
         inputs = self.tokenizer(
             prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework
         )
         inputs["prompt_text"] = prompt_text
+
+        if handle_long_generation == "hole":
+            cur_len = inputs["input_ids"].shape[-1]
+            if "max_new_tokens" in generate_kwargs:
+                new_tokens = generate_kwargs["max_new_tokens"]
+            else:
+                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
+                if new_tokens < 0:
+                    raise ValueError("We cannot infer how many new tokens are expected")
+            if cur_len + new_tokens > self.tokenizer.model_max_length:
+                keep_length = self.tokenizer.model_max_length - new_tokens
+                if keep_length <= 0:
+                    raise ValueError(
+                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the models max length"
+                    )
+
+                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]
+
         return inputs

     def _forward(self, model_inputs, **generate_kwargs):
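
The heart of the "hole" strategy in `preprocess` is simple slicing arithmetic; a minimal, tokenizer-free sketch with assumed toy values (not the pipeline code itself):

import torch

model_max_length = 20                        # stand-in for tokenizer.model_max_length
max_new_tokens = 8                           # room that must be left for generation
input_ids = torch.arange(50).unsqueeze(0)    # fake 50-token prompt, shape (1, 50)

cur_len = input_ids.shape[-1]
if cur_len + max_new_tokens > model_max_length:
    keep_length = model_max_length - max_new_tokens
    if keep_length <= 0:
        raise ValueError("More new tokens requested than the model can hold")
    # Keep only the rightmost tokens, leaving a "hole" of max_new_tokens free slots.
    input_ids = input_ids[:, -keep_length:]

print(input_ids.shape)  # torch.Size([1, 12])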
@@ -143,7 +143,9 @@ class PipelineTestCaseMeta(type):
                 try:
                     tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
                     # XLNet actually defines it as -1.
-                    if (
+                    if model.config.__class__.__name__ == "RobertaConfig":
+                        tokenizer.model_max_length = model.config.max_position_embeddings - 2
+                    elif (
                         hasattr(model.config, "max_position_embeddings")
                         and model.config.max_position_embeddings > 0
                     ):
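
The RoBERTa special case above exists because RoBERTa-style position embeddings reserve two slots for the padding offset, so the usable sequence length is `max_position_embeddings - 2`; a quick sketch of that arithmetic (assuming the default config values):

from transformers import RobertaConfig

config = RobertaConfig()                              # default max_position_embeddings == 514
usable_length = config.max_position_embeddings - 2    # two positions reserved for the padding offset
print(usable_length)                                  # 512, the value the test forces onto the tokenizer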
@@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         else:
             with self.assertRaises((ValueError, AssertionError)):
                 outputs = text_generator("")

+        if text_generator.framework == "tf":
+            # TF generation does not support max_new_tokens, and it's impossible
+            # to control long generation with only max_length without
+            # fancy calculation, dismissing tests for now.
+            return
+        # We don't care about infinite range models.
+        # They already work.
+        if tokenizer.model_max_length < 10000:
+            # Handling of large generations
+            with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
+                text_generator("This is a test" * 500, max_new_tokens=20)
+
+            outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
+            # Hole strategy cannot work
+            with self.assertRaises(ValueError):
+                text_generator(
+                    "This is a test" * 500,
+                    handle_long_generation="hole",
+                    max_new_tokens=tokenizer.model_max_length + 10,
+                )
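
A hedged, caller-side sketch of the behaviour these tests pin down (placeholder model and prompt; the exact exception raised without the new parameter depends on the model):

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")  # placeholder checkpoint
long_prompt = "This is a test" * 500                   # far longer than the model's max length

try:
    # Default strategy: the over-long prompt reaches the model unchanged and typically fails.
    generator(long_prompt, max_new_tokens=20)
except (RuntimeError, IndexError, ValueError) as err:
    print(f"default strategy failed: {type(err).__name__}")

# "hole" strategy: the prompt is left-truncated so generation can proceed.
outputs = generator(long_prompt, handle_long_generation="hole", max_new_tokens=20)
print(len(outputs))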