mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
comment out stuff
This commit is contained in:
parent
7a11e925cf
commit
421216997b
@ -521,18 +521,18 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
|
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
|
||||||
]
|
]
|
||||||
|
|
||||||
hypotheses_batch_bart = hf.generate_bart(
|
# hypotheses_batch_bart = hf.generate_bart(
|
||||||
input_ids=dct["input_ids"].to(torch_device),
|
# input_ids=dct["input_ids"].to(torch_device),
|
||||||
attention_mask=dct["attention_mask"].to(torch_device),
|
# attention_mask=dct["attention_mask"].to(torch_device),
|
||||||
num_beams=4,
|
# num_beams=4,
|
||||||
length_penalty=2.0,
|
# length_penalty=2.0,
|
||||||
max_length=max_length,
|
# max_length=max_length,
|
||||||
min_len=min_length,
|
# min_len=min_length,
|
||||||
no_repeat_ngram_size=3,
|
# no_repeat_ngram_size=3,
|
||||||
)
|
# )
|
||||||
decoded_bart = [
|
# decoded_bart = [
|
||||||
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch_bart
|
# tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch_bart
|
||||||
]
|
# ]
|
||||||
|
|
||||||
ipdb.set_trace()
|
ipdb.set_trace()
|
||||||
|
|
||||||
@ -540,9 +540,9 @@ class BartModelIntegrationTest(unittest.TestCase):
|
|||||||
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
||||||
decoded,
|
decoded,
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
# self.assertListEqual(
|
||||||
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
|
||||||
decoded_bart,
|
# decoded_bart,
|
||||||
)
|
# )
|
||||||
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
|
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
|
||||||
# TODO(SS): add test case that hits max_length
|
# TODO(SS): add test case that hits max_length
|
||||||
|
Loading…
Reference in New Issue
Block a user