mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[Pipeline, Generation] tf generation pipeline bug (#4217)
* fix PR * move tests to correct place
This commit is contained in:
parent
8bf7312654
commit
cf08830c28
@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||||
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
||||||
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
||||||
|
|
||||||
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
|
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||||
(except for Alexei and Maria) are discovered.
|
(except for Alexei and Maria) are discovered.
|
||||||
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
||||||
@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
||||||
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
||||||
|
|
||||||
|
ALLOWED_MODELS = [
|
||||||
|
"XLNetLMHeadModel",
|
||||||
|
"TransfoXLLMHeadModel",
|
||||||
|
"ReformerModelWithLMHead",
|
||||||
|
"GPT2LMHeadModel",
|
||||||
|
"OpenAIGPTLMHeadModel",
|
||||||
|
"CTRLLMHeadModel",
|
||||||
|
"TFXLNetLMHeadModel",
|
||||||
|
"TFTransfoXLLMHeadModel",
|
||||||
|
"TFGPT2LMHeadModel",
|
||||||
|
"TFOpenAIGPTLMHeadModel",
|
||||||
|
"TFCTRLLMHeadModel",
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||||
):
|
):
|
||||||
|
if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Generation is currently not supported for {}. Please select a model from {} for generation.".format(
|
||||||
|
self.model.__class__.__name__, self.ALLOWED_MODELS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
text_inputs = self._args_parser(*args)
|
text_inputs = self._args_parser(*args)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
|
|
||||||
result = []
|
result = []
|
||||||
for generated_sequence in output_sequences:
|
for generated_sequence in output_sequences:
|
||||||
generated_sequence = generated_sequence.tolist()
|
generated_sequence = generated_sequence.numpy().tolist()
|
||||||
record = {}
|
record = {}
|
||||||
if return_tensors:
|
if return_tensors:
|
||||||
record["generated_token_ids"] = generated_sequence
|
record["generated_token_ids"] = generated_sequence
|
||||||
|
@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
|
|||||||
("xlnet-base-cased", "xlnet-base-cased"),
|
("xlnet-base-cased", "xlnet-base-cased"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_TEXT_GENERATION_FINETUNED_MODELS = {
|
||||||
|
("gpt2", "gpt2"),
|
||||||
|
("xlnet-base-cased", "xlnet-base-cased"),
|
||||||
|
}
|
||||||
|
|
||||||
FILL_MASK_FINETUNED_MODELS = [
|
FILL_MASK_FINETUNED_MODELS = [
|
||||||
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
|
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
|
||||||
]
|
]
|
||||||
@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
nlp, valid_inputs, invalid_inputs, {},
|
nlp, valid_inputs, invalid_inputs, {},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_tf_text_generation(self):
|
||||||
|
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||||
|
invalid_inputs = [None]
|
||||||
|
for model, tokenizer in TF_TEXT_GENERATION_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task="text-generation", model=model, tokenizer=tokenizer, framework="tf")
|
||||||
|
self._test_mono_column_pipeline(
|
||||||
|
nlp, valid_inputs, invalid_inputs, {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiColumnInputTestCase(unittest.TestCase):
|
class MultiColumnInputTestCase(unittest.TestCase):
|
||||||
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
||||||
|
Loading…
Reference in New Issue
Block a user