comment out stuff

This commit is contained in:
Patrick von Platen 2020-03-06 14:51:13 +01:00
parent 7a11e925cf
commit 421216997b

View File

@ -521,18 +521,18 @@ class BartModelIntegrationTest(unittest.TestCase):
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
] ]
hypotheses_batch_bart = hf.generate_bart( # hypotheses_batch_bart = hf.generate_bart(
input_ids=dct["input_ids"].to(torch_device), # input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device), # attention_mask=dct["attention_mask"].to(torch_device),
num_beams=4, # num_beams=4,
length_penalty=2.0, # length_penalty=2.0,
max_length=max_length, # max_length=max_length,
min_len=min_length, # min_len=min_length,
no_repeat_ngram_size=3, # no_repeat_ngram_size=3,
) # )
decoded_bart = [ # decoded_bart = [
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch_bart # tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch_bart
] # ]
ipdb.set_trace() ipdb.set_trace()
@ -540,9 +540,9 @@ class BartModelIntegrationTest(unittest.TestCase):
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded, decoded,
) )
self.assertListEqual( # self.assertListEqual(
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY], # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded_bart, # decoded_bart,
) # )
# TODO(SS): run fairseq again with num_beams=2, min_len=20. # TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length # TODO(SS): add test case that hits max_length