TF: add beam search tests (#16202)

This commit is contained in:
Joao Gante 2022-03-16 15:44:33 +00:00 committed by GitHub
parent 190994573a
commit 204c54d411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 0 deletions

View File

@ -521,6 +521,34 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
]
self.assertListEqual(output_strings, expected_output_string)
@slow
def test_lm_generate_greedy_distilgpt2_beam_search_special(self):
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
sentences = ["Today is a beautiful day and", "Yesterday was"]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
generation_kwargs = {
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
"no_repeat_ngram_size": 2,
"do_sample": False,
"repetition_penalty": 1.3,
"num_beams": 2,
}
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = [
"Today is a beautiful day and I hope you enjoy it.\nI am very happy to announce that",
"Yesterday was the first time I've ever seen a game where you can play with",
]
self.assertListEqual(output_strings, expected_output_string)
@slow
def test_lm_generate_gpt2(self):
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

View File

@ -548,6 +548,29 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings)
@slow
def test_beam_search_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentences = ["I really love my", "Translate English to German: the transformers are truly amazing"]
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
generation_kwargs = {
"bad_words_ids": [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids],
"no_repeat_ngram_size": 3,
"do_sample": False,
"repetition_penalty": 2.2,
"num_beams": 4,
}
output_ids = model.generate(input_ids, **generation_kwargs)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
self.assertListEqual(expected_output_string, output_strings)
@require_tf
@require_sentencepiece