mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Pipeline, Generation] tf generation pipeline bug (#4217)
* fix PR * move tests to correct place
This commit is contained in:
parent
8bf7312654
commit
cf08830c28
@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
||||
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
||||
|
||||
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||
(except for Alexei and Maria) are discovered.
|
||||
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
||||
@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
|
||||
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
||||
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
||||
|
||||
# Model class names (PyTorch variants and their "TF"-prefixed TensorFlow
# counterparts) that this pipeline accepts; `__call__` compares
# `self.model.__class__.__name__` against this list and raises
# NotImplementedError for anything else.
ALLOWED_MODELS = [
    "XLNetLMHeadModel",
    "TransfoXLLMHeadModel",
    "ReformerModelWithLMHead",
    "GPT2LMHeadModel",
    "OpenAIGPTLMHeadModel",
    "CTRLLMHeadModel",
    "TFXLNetLMHeadModel",
    "TFTransfoXLLMHeadModel",
    "TFGPT2LMHeadModel",
    "TFOpenAIGPTLMHeadModel",
    "TFCTRLLMHeadModel",
]
|
||||
|
||||
def __call__(
|
||||
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||
):
|
||||
if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
|
||||
raise NotImplementedError(
|
||||
"Generation is currently not supported for {}. Please select a model from {} for generation.".format(
|
||||
self.model.__class__.__name__, self.ALLOWED_MODELS
|
||||
)
|
||||
)
|
||||
|
||||
text_inputs = self._args_parser(*args)
|
||||
|
||||
results = []
|
||||
@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
|
||||
|
||||
result = []
|
||||
for generated_sequence in output_sequences:
|
||||
generated_sequence = generated_sequence.tolist()
|
||||
generated_sequence = generated_sequence.numpy().tolist()
|
||||
record = {}
|
||||
if return_tensors:
|
||||
record["generated_token_ids"] = generated_sequence
|
||||
|
@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
|
||||
("xlnet-base-cased", "xlnet-base-cased"),
|
||||
}
|
||||
|
||||
# (model, tokenizer) identifier pairs used to exercise the TensorFlow
# text-generation pipeline; both members of each pair are the same hub id.
TF_TEXT_GENERATION_FINETUNED_MODELS = {(name, name) for name in ("gpt2", "xlnet-base-cased")}
|
||||
|
||||
# Fixtures for the fill-mask pipeline tests. Each entry is a 3-tuple:
# ((model_id, tokenizer_kwargs), model_id, config) — presumably
# (tokenizer spec, model name, optional config); verify against the
# test that consumes this table.
FILL_MASK_FINETUNED_MODELS = [
    (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
|
||||
@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
nlp, valid_inputs, invalid_inputs, {},
|
||||
)
|
||||
|
||||
@require_tf
def test_tf_text_generation(self):
    """Smoke-test the TensorFlow text-generation pipeline.

    For every (model, tokenizer) pair in TF_TEXT_GENERATION_FINETUNED_MODELS,
    build a TF-framework pipeline and run the shared mono-column checks on
    valid inputs (a single string and a list of strings) and on an invalid
    input (None).
    """
    valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
    invalid_inputs = [None]
    for model_name, tokenizer_name in TF_TEXT_GENERATION_FINETUNED_MODELS:
        tf_nlp = pipeline(task="text-generation", model=model_name, tokenizer=tokenizer_name, framework="tf")
        self._test_mono_column_pipeline(
            tf_nlp, valid_inputs, invalid_inputs, {},
        )
|
||||
|
||||
|
||||
class MultiColumnInputTestCase(unittest.TestCase):
|
||||
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
||||
|
Loading…
Reference in New Issue
Block a user