feat: allow prefix for any generative model (#5885)

* feat: allow padding_text for any generative model

* docs(pipelines.py): correct typo

* Update src/transformers/pipelines.py

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* feat: rename padding_text to prefix

* fix: cannot tokenize empty text

* fix: pass prefix arg to pipeline

* test: add prefix to text-generation pipeline

* style: fix style

* style: clean code and make variable names more explicit

* set arg docstring to optional

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Boris Dayma 2020-09-07 02:03:45 -05:00 committed by GitHub
parent ce37be9d94
commit 995a958dd1
3 changed files with 38 additions and 24 deletions
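In short: the XLNet/Transformer-XL padding-text trick becomes a general `prefix` argument that works with any generative model, and its token count is added to `max_length`/`min_length` so the prefix does not eat into the generation budget. A minimal usage sketch (model name and prompts are illustrative, not from the commit):

from transformers import pipeline

# The new `prefix` keyword is prepended to the prompt before tokenization.
generator = pipeline("text-generation", model="gpt2")
outputs = generator("the weather today is", prefix="Weather report: ", max_length=30)
print(outputs[0]["generated_text"])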

View File

@@ -61,7 +61,7 @@ MODEL_CLASSES = {
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
 # in https://github.com/rusiaaman/XLNet-gen#methodology
 # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
+PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
 (except for Alexei and Maria) are discovered.
 The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
 remainder of the story. 1883 Western Siberia,
@@ -122,12 +122,14 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
 def prepare_xlnet_input(args, _, tokenizer, prompt_text):
-    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
+    prompt_text = prefix + prompt_text
     return prompt_text


 def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
-    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
+    prompt_text = prefix + prompt_text
     return prompt_text
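An equivalent, more explicit spelling of the chained conditional above (illustrative only, not part of the commit):

# Precedence: the new --prefix flag wins, then the deprecated --padding_text,
# then the built-in PREFIX article.
if args.prefix:
    prefix = args.prefix
elif args.padding_text:
    prefix = args.padding_text
else:
    prefix = PREFIX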
@@ -182,7 +184,8 @@ def main():
     parser.add_argument("--k", type=int, default=0)
     parser.add_argument("--p", type=float, default=0.9)
-    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
+    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
+    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
     parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
@@ -241,7 +244,8 @@ def main():
             preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
         )
     else:
-        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
+        prefix = args.prefix if args.prefix else args.padding_text
+        encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)

     if encoded_prompt.size()[-1] == 0:

View File src/transformers/pipelines.py

@@ -752,11 +752,11 @@ class TextGenerationPipeline(Pipeline):
     `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
     """

-    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
+    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
     # in https://github.com/rusiaaman/XLNet-gen#methodology
     # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
+    XL_PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
     (except for Alexei and Maria) are discovered.
     The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
     remainder of the story. 1883 Western Siberia,
@@ -765,7 +765,7 @@ class TextGenerationPipeline(Pipeline):
     father initially slaps him for making such an accusation, Rasputin watches as the
     man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
     the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-    with people, even a bishop, begging for his blessing. """
+    with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

     ALLOWED_MODELS = [
         "XLNetLMHeadModel",
@@ -809,7 +809,13 @@ class TextGenerationPipeline(Pipeline):
         return inputs

     def __call__(
-        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        prefix=None,
+        **generate_kwargs
     ):
         """
         Complete the prompt(s) given as inputs.
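Because `prefix` is declared after `*args`, it is keyword-only: a bare positional string would be collected by `*args` as an input text, not a prefix. For illustration (model choice is hypothetical):

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
# Correct: the prefix must be passed by keyword.
print(generator("Hello", prefix="Once upon a time, ", max_length=20))
# generator("Hello", "Once upon a time, ") would pass both strings as prompts,
# not as a prompt plus prefix.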
@@ -823,6 +829,8 @@ class TextGenerationPipeline(Pipeline):
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
+            prefix (:obj:`str`, `optional`):
+                Prefix added to prompt.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate
                 method corresponding to your framework `here <./model.html#generative-models>`__).
@@ -841,27 +849,27 @@ class TextGenerationPipeline(Pipeline):
         for prompt_text in text_inputs:
             # Manage correct placement of the tensors
             with self.device_placement():
-                if self.model.__class__.__name__ in [
+                prefix = prefix if prefix is not None else self.model.config.prefix
+                if prefix is None and self.model.__class__.__name__ in [
                     "XLNetLMHeadModel",
                     "TransfoXLLMHeadModel",
                     "TFXLNetLMHeadModel",
                     "TFTransfoXLLMHeadModel",
                 ]:
-                    # For XLNet and TransformerXL we had an article to the prompt to give more state to the model.
-                    padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
-                    padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
-                    # This impacts max_length and min_length argument that need adjusting.
-                    padding_length = padding["input_ids"].shape[-1]
-                    if "max_length" in generate_kwargs and generate_kwargs["max_length"] is not None:
-                        generate_kwargs["max_length"] += padding_length
-                    if "min_length" in generate_kwargs and generate_kwargs["min_length"] is not None:
-                        generate_kwargs["min_length"] += padding_length
+                    # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
+                    prefix = self.XL_PREFIX

-                    inputs = self._parse_and_tokenize(
-                        padding_text + prompt_text, padding=False, add_special_tokens=False
-                    )
-                else:
-                    inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)
+                if prefix:
+                    prefix_inputs = self._parse_and_tokenize(prefix, padding=False, add_special_tokens=False)
+                    # This impacts max_length and min_length argument that need adjusting.
+                    prefix_length = prefix_inputs["input_ids"].shape[-1]
+                    if generate_kwargs.get("max_length", None) is not None:
+                        generate_kwargs["max_length"] += prefix_length
+                    if generate_kwargs.get("min_length", None) is not None:
+                        generate_kwargs["min_length"] += prefix_length
+
+                prefix = prefix or ""
+                inputs = self._parse_and_tokenize(prefix + prompt_text, padding=False, add_special_tokens=False)

                 # set input_ids to None to allow empty prompt
                 if inputs["input_ids"].shape[-1] == 0:

View File

@@ -424,12 +424,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
         for model_name in TEXT_GENERATION_FINETUNED_MODELS:
             nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="pt")
             self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
+            self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ")

     @require_tf
     def test_tf_text_generation(self):
         for model_name in TEXT_GENERATION_FINETUNED_MODELS:
             nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf")
             self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
+            self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ")

     @slow
     @require_torch
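For readers without the test harness: `_test_mono_column_pipeline` is not shown in this diff, so the behavior the added lines exercise can only be approximated directly (model choice illustrative):

from transformers import pipeline

nlp = pipeline(task="text-generation", model="gpt2", framework="pt")
out = nlp("a quick test", prefix="This is ", max_length=15)
# One result dict per prompt.
assert isinstance(out, list) and "generated_text" in out[0]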