comment out stuff

This commit is contained in:
Patrick von Platen 2020-03-06 14:51:13 +01:00
parent 7a11e925cf
commit 421216997b

View File

@ -521,18 +521,18 @@ class BartModelIntegrationTest(unittest.TestCase):
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
]
hypotheses_batch_bart = hf.generate_bart(
input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=4,
length_penalty=2.0,
max_length=max_length,
min_len=min_length,
no_repeat_ngram_size=3,
)
decoded_bart = [
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch_bart
]
# hypotheses_batch_bart = hf.generate_bart(
# input_ids=dct["input_ids"].to(torch_device),
# attention_mask=dct["attention_mask"].to(torch_device),
# num_beams=4,
# length_penalty=2.0,
# max_length=max_length,
# min_len=min_length,
# no_repeat_ngram_size=3,
# )
# decoded_bart = [
# tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch_bart
# ]
ipdb.set_trace()
@ -540,9 +540,9 @@ class BartModelIntegrationTest(unittest.TestCase):
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded,
)
self.assertListEqual(
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded_bart,
)
# self.assertListEqual(
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
# decoded_bart,
# )
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length