[Pipeline, Generation] tf generation pipeline bug (#4217)

* fix PR

* move tests to correct place
This commit is contained in:
Patrick von Platen 2020-05-08 14:30:05 +02:00 committed by GitHub
parent 8bf7312654
commit cf08830c28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 1 deletions

View File

@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
# Class names (both PyTorch and TensorFlow variants) that this pipeline
# accepts; __call__ compares `self.model.__class__.__name__` against this
# list and raises NotImplementedError for anything else. The list is also
# interpolated into that error message, so its order is user-visible.
ALLOWED_MODELS = [
    "XLNetLMHeadModel",
    "TransfoXLLMHeadModel",
    "ReformerModelWithLMHead",
    "GPT2LMHeadModel",
    "OpenAIGPTLMHeadModel",
    "CTRLLMHeadModel",
    # TensorFlow counterparts added for TF generation support.
    "TFXLNetLMHeadModel",
    "TFTransfoXLLMHeadModel",
    "TFGPT2LMHeadModel",
    "TFOpenAIGPTLMHeadModel",
    "TFCTRLLMHeadModel",
]
def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):
if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
raise NotImplementedError(
"Generation is currently not supported for {}. Please select a model from {} for generation.".format(
self.model.__class__.__name__, self.ALLOWED_MODELS
)
)
text_inputs = self._args_parser(*args)
results = []
@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
result = []
for generated_sequence in output_sequences:
generated_sequence = generated_sequence.tolist()
generated_sequence = generated_sequence.numpy().tolist()
record = {}
if return_tensors:
record["generated_token_ids"] = generated_sequence

View File

@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
("xlnet-base-cased", "xlnet-base-cased"),
}
# (model identifier, tokenizer identifier) pairs exercised by the TensorFlow
# text-generation test below; mirrors TEXT_GENERATION_FINETUNED_MODELS for TF.
TF_TEXT_GENERATION_FINETUNED_MODELS = {
    ("gpt2", "gpt2"),
    ("xlnet-base-cased", "xlnet-base-cased"),
}
# ((model identifier, tokenizer kwargs), tokenizer identifier, config) triples
# used by the fill-mask pipeline tests.
FILL_MASK_FINETUNED_MODELS = [
    (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp, valid_inputs, invalid_inputs, {},
)
@require_tf
def test_tf_text_generation(self):
    """Smoke-test the text-generation pipeline with the TensorFlow backend.

    Runs each (model, tokenizer) pair from TF_TEXT_GENERATION_FINETUNED_MODELS
    through the shared mono-column harness with both a single string and a
    batch of strings; None must be rejected as invalid input.
    """
    single_and_batched = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
    rejected = [None]
    for model_name, tokenizer_name in TF_TEXT_GENERATION_FINETUNED_MODELS:
        generator = pipeline(
            task="text-generation", model=model_name, tokenizer=tokenizer_name, framework="tf"
        )
        self._test_mono_column_pipeline(generator, single_and_batched, rejected, {})
class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):