mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[BART] Delete redundant unit test (#3302)
This commit is contained in:
parent
b2028cc26b
commit
b2c1a447fe
@ -381,7 +381,7 @@ TOLERANCE = 1e-4
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartModelIntegrationTest(unittest.TestCase):
|
||||
class BartModelIntegrationTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = BartModel.from_pretrained("bart-large").to(torch_device)
|
||||
@ -431,25 +431,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_cnn_summarization_same_as_fairseq_easy(self):
|
||||
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
|
||||
tok = BartTokenizer.from_pretrained("bart-large")
|
||||
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
|
||||
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
||||
extra_len = 20
|
||||
gen_tokens = hf.generate(
|
||||
tokens,
|
||||
num_beams=4,
|
||||
max_length=extra_len + 2,
|
||||
do_sample=False,
|
||||
decoder_start_token_id=hf.config.eos_token_ids[0],
|
||||
) # repetition_penalty=10.,
|
||||
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
|
||||
generated = [tok.decode(g,) for g in gen_tokens]
|
||||
self.assertEqual(expected_result, generated[0])
|
||||
|
||||
@slow
|
||||
def test_cnn_summarization_same_as_fairseq_hard(self):
|
||||
def test_cnn_summarization_same_as_fairseq(self):
|
||||
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
|
||||
tok = BartTokenizer.from_pretrained("bart-large")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user