Fixing missing arguments for TransfoXL tokenizer when using TextGenerationPipeline (#5465)

* overriding _parse_and_tokenize in `TextGenerationPipeine` to allow for TransfoXl tokenizer arguments
This commit is contained in:
Teven 2020-07-02 13:53:33 +02:00 committed by GitHub
parent 6726416e4a
commit c6a510c6fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -615,6 +615,28 @@ class TextGenerationPipeline(Pipeline):
"TFCTRLLMHeadModel",
]
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
"""
Parse arguments and tokenize
"""
# Parse arguments
if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else:
tokenizer_kwargs = {}
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer(
inputs,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
padding=padding,
**tokenizer_kwargs,
)
return inputs
def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):