[BART] Delete redundant unit test (#3302)

This commit is contained in:
Sam Shleifer 2020-03-16 23:09:10 -04:00 committed by GitHub
parent b2028cc26b
commit b2c1a447fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -381,7 +381,7 @@ TOLERANCE = 1e-4
@require_torch
class BartModelIntegrationTest(unittest.TestCase):
class BartModelIntegrationTests(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = BartModel.from_pretrained("bart-large").to(torch_device)
@ -431,25 +431,7 @@ class BartModelIntegrationTest(unittest.TestCase):
self.assertIsNotNone(model)
@slow
def test_cnn_summarization_same_as_fairseq_easy(self):
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large")
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
extra_len = 20
gen_tokens = hf.generate(
tokens,
num_beams=4,
max_length=extra_len + 2,
do_sample=False,
decoder_start_token_id=hf.config.eos_token_ids[0],
) # repetition_penalty=10.,
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
generated = [tok.decode(g,) for g in gen_tokens]
self.assertEqual(expected_result, generated[0])
@slow
def test_cnn_summarization_same_as_fairseq_hard(self):
def test_cnn_summarization_same_as_fairseq(self):
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large")