finalized PR

This commit is contained in:
patrickvonplaten 2020-03-07 10:55:23 +01:00 committed by Patrick von Platen
parent 2acfe63964
commit d880a5fbde
2 changed files with 11 additions and 13 deletions

View File

@ -798,7 +798,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
@ -989,10 +989,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if unfinished_sents.max() == 0:
break
# extend attention_mask for new generated input
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
cur_len = cur_len + 1
@ -1078,7 +1078,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
if (
self.config.is_encoder_decoder
): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search?
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# set eos token prob to zero if min_length is not reached
@ -1205,10 +1207,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if past:
past = self._reorder_cache(past, beam_idx)
# extend attention_mask for new generated input
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# update current length
@ -1270,7 +1272,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
if self.config.is_encoder_decoder:
# do not return first <BOS> token
# do not return first <EOS> token
return decoded[:, 1:]
return decoded

View File

@ -453,9 +453,7 @@ class BartModelIntegrationTest(unittest.TestCase):
EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."
dct = tok.batch_encode_plus(
# [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
[IRAN_ARTICLE, ARTICLE_SUBWAY],
# [FRANCE_ARTICLE, SHORTER_ARTICLE],
[FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
max_length=1024,
pad_to_max_length=True,
return_tensors="pt",
@ -482,9 +480,7 @@ class BartModelIntegrationTest(unittest.TestCase):
]
self.assertListEqual(
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
[EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
# [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded,
)
# TODO(SS): run fairseq again with num_beams=2, min_len=20.