mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
add draft version of propsoed changes for ROGUE score
This commit is contained in:
parent
a5751f7578
commit
41b437ea3a
@ -846,7 +846,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
encoder_inputs = input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
|
||||
# eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
|
||||
bos_token_id,
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
@ -1079,10 +1080,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
if (
|
||||
self.config.is_encoder_decoder and do_sample is False
|
||||
): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
||||
# if (
|
||||
# self.config.is_encoder_decoder and do_sample is False
|
||||
# ): # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
|
||||
# scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
@ -1271,9 +1272,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
# if self.config.is_encoder_decoder:
|
||||
# do not return first <EOS> token
|
||||
return decoded[:, 1:]
|
||||
# return decoded[:, 1:]
|
||||
return decoded
|
||||
|
||||
# force one of token_ids to be generated by setting prob of all other tokens to 0.
|
||||
|
@ -214,6 +214,9 @@ class BartHeadTests(unittest.TestCase):
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
output_past=output_past,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
)
|
||||
return config, input_ids, batch_size
|
||||
|
||||
@ -468,7 +471,8 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
attention_mask=dct["attention_mask"].to(torch_device),
|
||||
num_beams=4,
|
||||
length_penalty=2.0,
|
||||
max_length=max_length + 2,
|
||||
# max_length=max_length + 2,
|
||||
max_length=max_length + 1,
|
||||
min_length=min_length + 1,
|
||||
no_repeat_ngram_size=3,
|
||||
do_sample=False,
|
||||
|
Loading…
Reference in New Issue
Block a user