add draft version of propsoed changes for ROGUE score

This commit is contained in:
patrickvonplaten 2020-03-09 00:33:12 +01:00 committed by Patrick von Platen
parent a5751f7578
commit 41b437ea3a
2 changed files with 14 additions and 9 deletions

View File

@ -846,7 +846,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
# eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
bos_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
@ -1079,10 +1080,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if (
self.config.is_encoder_decoder and do_sample is False
): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# if (
# self.config.is_encoder_decoder and do_sample is False
# ): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
# scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
@ -1271,9 +1272,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
if self.config.is_encoder_decoder:
# if self.config.is_encoder_decoder:
# do not return first <EOS> token
return decoded[:, 1:]
# return decoded[:, 1:]
return decoded
# force one of token_ids to be generated by setting prob of all other tokens to 0.

View File

@ -214,6 +214,9 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
output_past=output_past,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
)
return config, input_ids, batch_size
@ -468,7 +471,8 @@ class BartModelIntegrationTest(unittest.TestCase):
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=4,
length_penalty=2.0,
max_length=max_length + 2,
# max_length=max_length + 2,
max_length=max_length + 1,
min_length=min_length + 1,
no_repeat_ngram_size=3,
do_sample=False,