mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
feat: allow prefix for any generative model (#5885)
* feat: allow padding_text for any generative model
* docs(pipelines.py): correct typo
* Update src/transformers/pipelines.py (Co-authored-by: Sam Shleifer <sshleifer@gmail.com>)
* feat: rename padding_text to prefix
* fix: cannot tokenize empty text
* fix: pass prefix arg to pipeline
* test: add prefix to text-generation pipeline
* style: fix style
* style: clean code and variable name more explicit
* set arg docstring to optional

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
parent
ce37be9d94
commit
995a958dd1
@@ -61,7 +61,7 @@ MODEL_CLASSES = {
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
 # in https://github.com/rusiaaman/XLNet-gen#methodology
 # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
+PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
 (except for Alexei and Maria) are discovered.
 The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
 remainder of the story. 1883 Western Siberia,
@@ -122,12 +122,14 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
     return prompt_text


 def prepare_xlnet_input(args, _, tokenizer, prompt_text):
-    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
+    prompt_text = prefix + prompt_text
     return prompt_text


 def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
-    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
+    prompt_text = prefix + prompt_text
     return prompt_text

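The nested conditional in the added lines resolves the prefix with a right-to-left fallback: --prefix wins, then the deprecated --padding_text, then the built-in PREFIX article. A minimal sketch of the same logic written out explicitly (the helper name is hypothetical, for illustration only):

    def resolve_prefix(args, default_prefix):
        # Hypothetical helper, equivalent to the one-liner in the diff:
        # --prefix takes precedence, then the deprecated --padding_text,
        # then the built-in PREFIX article.
        if args.prefix:
            return args.prefix
        if args.padding_text:  # deprecated alias kept for backward compatibility
            return args.padding_text
        return default_prefix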
@@ -182,7 +184,8 @@ def main():
     parser.add_argument("--k", type=int, default=0)
     parser.add_argument("--p", type=float, default=0.9)

-    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
+    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
+    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
     parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
@@ -241,7 +244,8 @@ def main():
             preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
         )
     else:
-        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
+        prefix = args.prefix if args.prefix else args.padding_text
+        encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)

     if encoded_prompt.size()[-1] == 0:
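For models other than XLNet and Transformer-XL, the new branch simply prepends the resolved prefix before encoding. A quick sketch of the effect (the checkpoint name and strings below are illustrative assumptions, not part of the diff):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint
    prefix = "My favourite food is "  # stands in for args.prefix
    prompt_text = "pizza because"
    # The prefix is plain text prepended before encoding, so generation is
    # conditioned on it exactly as if the user had typed it.
    encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
    print(encoded_prompt.shape)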
@@ -752,11 +752,11 @@ class TextGenerationPipeline(Pipeline):
     `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
     """

-    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
+    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
     # in https://github.com/rusiaaman/XLNet-gen#methodology
     # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

-    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
+    XL_PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
     (except for Alexei and Maria) are discovered.
     The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
     remainder of the story. 1883 Western Siberia,
@@ -765,7 +765,7 @@ class TextGenerationPipeline(Pipeline):
     father initially slaps him for making such an accusation, Rasputin watches as the
     man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
     the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-    with people, even a bishop, begging for his blessing. """
+    with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

     ALLOWED_MODELS = [
         "XLNetLMHeadModel",
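The trailing <eod> </s> <eos> markers compensate for a behavior change visible in the last hunk below: the old code appended self.tokenizer.eos_token to PADDING_TEXT at run time, while the new code uses XL_PREFIX verbatim, so the end-of-document markers of the supported tokenizers are baked into the string itself. Each tokenizer only treats its own marker specially; a quick way to check which marker belongs to which model (checkpoint names here are assumptions):

    from transformers import AutoTokenizer

    # Assumed checkpoints; each tokenizer recognises its own end marker
    # and tokenizes the others as ordinary text.
    for name in ("xlnet-base-cased", "transfo-xl-wt103"):
        tok = AutoTokenizer.from_pretrained(name)
        print(name, "eos token:", tok.eos_token)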
@@ -809,7 +809,13 @@ class TextGenerationPipeline(Pipeline):
         return inputs

     def __call__(
-        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        prefix=None,
+        **generate_kwargs
     ):
         """
         Complete the prompt(s) given as inputs.
@@ -823,6 +829,8 @@ class TextGenerationPipeline(Pipeline):
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
+            prefix (:obj:`str`, `optional`):
+                Prefix added to prompt.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate
                 method corresponding to your framework `here <./model.html#generative-models>`__).
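With prefix exposed on __call__, any generative model can use one, not just XLNet and Transformer-XL. A minimal usage sketch (checkpoint and prompt are assumptions, not part of the diff):

    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2")  # assumed checkpoint
    # The prefix is prepended to the prompt before tokenization; its tokens are
    # not counted against max_length because the pipeline extends it internally.
    outputs = generator("pizza because", prefix="My favourite food is ", max_length=30)
    print(outputs[0]["generated_text"])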
@@ -841,27 +849,27 @@ class TextGenerationPipeline(Pipeline):
         for prompt_text in text_inputs:
             # Manage correct placement of the tensors
             with self.device_placement():
-                if self.model.__class__.__name__ in [
+                prefix = prefix if prefix is not None else self.model.config.prefix
+                if prefix is None and self.model.__class__.__name__ in [
                     "XLNetLMHeadModel",
                     "TransfoXLLMHeadModel",
                     "TFXLNetLMHeadModel",
                     "TFTransfoXLLMHeadModel",
                 ]:
-                    # For XLNet and TransformerXL we had an article to the prompt to give more state to the model.
-                    padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
-                    padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
-                    # This impacts max_length and min_length argument that need adjusting.
-                    padding_length = padding["input_ids"].shape[-1]
-                    if "max_length" in generate_kwargs and generate_kwargs["max_length"] is not None:
-                        generate_kwargs["max_length"] += padding_length
-                    if "min_length" in generate_kwargs and generate_kwargs["min_length"] is not None:
-                        generate_kwargs["min_length"] += padding_length
-
-                    inputs = self._parse_and_tokenize(
-                        padding_text + prompt_text, padding=False, add_special_tokens=False
-                    )
-                else:
-                    inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)
+                    # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
+                    prefix = self.XL_PREFIX
+
+                if prefix:
+                    prefix_inputs = self._parse_and_tokenize(prefix, padding=False, add_special_tokens=False)
+                    # This impacts max_length and min_length argument that need adjusting.
+                    prefix_length = prefix_inputs["input_ids"].shape[-1]
+                    if generate_kwargs.get("max_length", None) is not None:
+                        generate_kwargs["max_length"] += prefix_length
+                    if generate_kwargs.get("min_length", None) is not None:
+                        generate_kwargs["min_length"] += prefix_length
+
+                prefix = prefix or ""
+                inputs = self._parse_and_tokenize(prefix + prompt_text, padding=False, add_special_tokens=False)

                 # set input_ids to None to allow empty prompt
                 if inputs["input_ids"].shape[-1] == 0:
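The length bookkeeping in this hunk is what keeps a long prefix from eating the generation budget: generate's max_length counts prefix tokens too, so the pipeline extends it by the prefix's token count. The arithmetic, with illustrative numbers (assumed counts, not measured values):

    prefix_length = 120  # e.g. the XL_PREFIX article once tokenized (assumed count)
    max_length = 50      # the caller's budget, meant for prompt + continuation
    # Without the adjustment, generation could stop inside the prefix itself;
    # extending the budget preserves the caller's intent.
    max_length += prefix_length
    assert max_length == 170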
@@ -424,12 +424,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
         for model_name in TEXT_GENERATION_FINETUNED_MODELS:
             nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="pt")
             self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
+            self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ")

     @require_tf
     def test_tf_text_generation(self):
         for model_name in TEXT_GENERATION_FINETUNED_MODELS:
             nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf")
             self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
+            self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ")

     @slow
     @require_torch