fix beam_search behavior when sampling (#3106)

* fix beam_search behavior when sampling

* delete print

* make correct style
This commit is contained in:
Patrick von Platen 2020-03-04 15:30:51 +01:00 committed by GitHub
parent e9e6efdc45
commit 6701fb7859
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -564,7 +564,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
model.eval()
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"error_msgs": error_msgs,
}
return model, loading_info
return model
@ -941,7 +945,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
# Greedy decoding it is made sure that only words of the first beam are considered to avoid sampling the exact same words three times
if do_sample is False:
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
@ -967,19 +974,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
scores = scores / temperature
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# Top-p/top-k filtering
scores = top_k_top_p_filtering(
scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
_scores = top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together to sample from all beam_idxs
_scores = _scores.contiguous().view(
batch_size, num_beams * vocab_size
) # (batch_size, num_beams * vocab_size)
# Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2) # (batch_size * num_beams, 2)
next_words = torch.multinomial(
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
) # (batch_size, num_beams * 2)
# Compute next scores
_scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = torch.gather(_scores, -1, next_words) # (batch_size * num_beams, 2)
next_scores = _scores + beam_scores[:, None].expand_as(_scores) # (batch_size * num_beams, 2)
# Match shape of greedy beam search
next_words = next_words.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
next_scores = next_scores.view(batch_size, 2 * num_beams) # (batch_size, 2 * num_beams)
next_scores = torch.gather(_scores, -1, next_words) # (batch_size, num_beams * 2)
else:
# do greedy beam search
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
@ -1026,7 +1042,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# add to generated hypotheses if end of sentence or last iteration
if eos_token_ids is not None and word_id.item() in eos_token_ids:
generated_hyps[batch_idx].add(
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
)
else:
# add next predicted word if it is not eos_token