diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index b4a73d68a34..9efa6d40750 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -615,6 +615,28 @@ class TextGenerationPipeline(Pipeline): "TFCTRLLMHeadModel", ] + # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments + + def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs): + """ + Parse arguments and tokenize + """ + # Parse arguments + if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]: + tokenizer_kwargs = {"add_space_before_punct_symbol": True} + else: + tokenizer_kwargs = {} + inputs = self._args_parser(*args, **kwargs) + inputs = self.tokenizer( + inputs, + add_special_tokens=add_special_tokens, + return_tensors=self.framework, + padding=padding, + **tokenizer_kwargs, + ) + + return inputs + def __call__( self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs ):