[Pipeline, Generation] tf generation pipeline bug (#4217)

* fix PR

* move tests to correct place
This commit is contained in:
Patrick von Platen 2020-05-08 14:30:05 +02:00 committed by GitHub
parent 8bf7312654
commit cf08830c28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 1 deletions

View File

@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology # in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered. (except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>""" with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
ALLOWED_MODELS = [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
"ReformerModelWithLMHead",
"GPT2LMHeadModel",
"OpenAIGPTLMHeadModel",
"CTRLLMHeadModel",
"TFXLNetLMHeadModel",
"TFTransfoXLLMHeadModel",
"TFGPT2LMHeadModel",
"TFOpenAIGPTLMHeadModel",
"TFCTRLLMHeadModel",
]
def __call__( def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
): ):
if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
raise NotImplementedError(
"Generation is currently not supported for {}. Please select a model from {} for generation.".format(
self.model.__class__.__name__, self.ALLOWED_MODELS
)
)
text_inputs = self._args_parser(*args) text_inputs = self._args_parser(*args)
results = [] results = []
@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
result = [] result = []
for generated_sequence in output_sequences: for generated_sequence in output_sequences:
generated_sequence = generated_sequence.tolist() generated_sequence = generated_sequence.numpy().tolist()
record = {} record = {}
if return_tensors: if return_tensors:
record["generated_token_ids"] = generated_sequence record["generated_token_ids"] = generated_sequence

View File

@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
("xlnet-base-cased", "xlnet-base-cased"), ("xlnet-base-cased", "xlnet-base-cased"),
} }
TF_TEXT_GENERATION_FINETUNED_MODELS = {
("gpt2", "gpt2"),
("xlnet-base-cased", "xlnet-base-cased"),
}
FILL_MASK_FINETUNED_MODELS = [ FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None), (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
] ]
@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp, valid_inputs, invalid_inputs, {}, nlp, valid_inputs, invalid_inputs, {},
) )
@require_tf
def test_tf_text_generation(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [None]
for model, tokenizer in TF_TEXT_GENERATION_FINETUNED_MODELS:
nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, {},
)
class MultiColumnInputTestCase(unittest.TestCase): class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]): def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):