Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fix beam_search behavior when sampling (#3106)

* fix beam_search behavior when sampling
* delete print
* make correct style
parent e9e6efdc45
commit 6701fb7859
@@ -564,7 +564,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         model.eval()

         if output_loading_info:
-            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
+            loading_info = {
+                "missing_keys": missing_keys,
+                "unexpected_keys": unexpected_keys,
+                "error_msgs": error_msgs,
+            }
             return model, loading_info

         return model
@@ -941,7 +945,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):

         # scores for each sentence in the beam
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores[:, 1:] = -1e9
+
+        # for greedy decoding it is made sure that only words of the first beam are considered to avoid sampling the exact same words three times
+        if do_sample is False:
+            beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

         # cache compute states
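Why the -1e9 initialization is now applied only when not sampling: at the first step every beam holds the same context, so all beams produce the same distribution; masking beams 1..num_beams-1 keeps the greedy top-k from returning num_beams copies of the same word, while the sampling path gets its diversity from the multinomial draw instead. A minimal sketch of the greedy case (toy sizes and illustrative names, not the library code):

import torch

# toy sizes for illustration only
batch_size, num_beams, vocab_size = 1, 3, 5

beam_scores = torch.zeros(batch_size, num_beams)
beam_scores[:, 1:] = -1e9  # greedy case: only beam 0 may contribute first-step candidates

# at step 0 every beam sees the same context, so its log-probs are identical
step_logprobs = torch.log_softmax(torch.randn(vocab_size), dim=-1)
scores = step_logprobs.repeat(num_beams, 1) + beam_scores.view(-1)[:, None]  # (num_beams, vocab_size)

# the greedy path picks the 2 * num_beams best (beam, word) pairs from the joint view
joint = scores.view(batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(joint, 2 * num_beams, dim=1)
print(next_words // vocab_size)  # beam ids: all 0, i.e. no duplicated hypotheses across beams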
@@ -967,19 +974,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
                     scores = scores / temperature
+
+                scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+
                 # Top-p/top-k filtering
-                scores = top_k_top_p_filtering(
-                    scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
+                _scores = top_k_top_p_filtering(
+                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
+
+                # re-organize to group the beam together to sample from all beam_idxs
+                _scores = _scores.contiguous().view(
+                    batch_size, num_beams * vocab_size
+                )  # (batch_size, num_beams * vocab_size)
+
                 # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
-                next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
+                next_words = torch.multinomial(
+                    F.softmax(_scores, dim=-1), num_samples=2 * num_beams
+                )  # (batch_size, num_beams * 2)
+
                 # Compute next scores
-                _scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
-                _scores = torch.gather(_scores, -1, next_words)  # (batch_size * num_beams, 2)
-                next_scores = _scores + beam_scores[:, None].expand_as(_scores)  # (batch_size * num_beams, 2)
-                # Match shape of greedy beam search
-                next_words = next_words.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
-                next_scores = next_scores.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
+                next_scores = torch.gather(_scores, -1, next_words)  # (batch_size, num_beams * 2)
+
             else:
                 # do greedy beam search
                 scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
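The sampling branch now mirrors the greedy branch's shapes: beam scores are added to the log-probs before top-k/top-p filtering, the per-beam scores are flattened to (batch_size, num_beams * vocab_size), and 2 * num_beams candidates are drawn per batch item in one multinomial call, so all beams compete in the same draw instead of each beam sampling two words on its own. A standalone sketch of the new flow with dummy logits (the top-k/top-p step is omitted; names are illustrative, not the library function itself):

import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 2, 3, 10
scores = torch.randn(batch_size * num_beams, vocab_size)   # model logits for the current step
beam_scores = torch.zeros(batch_size * num_beams)          # running log-prob of each beam

scores = F.log_softmax(scores, dim=-1)                     # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores)  # add beam history before filtering/sampling

# re-organize so one draw sees all beams of a batch item at once
_scores = _scores.contiguous().view(batch_size, num_beams * vocab_size)

# 2 * num_beams samples per batch item, matching the output shape of the greedy branch
next_words = torch.multinomial(F.softmax(_scores, dim=-1), num_samples=2 * num_beams)
next_scores = torch.gather(_scores, -1, next_words)        # (batch_size, 2 * num_beams)

beam_id = next_words // vocab_size   # which beam each candidate extends
word_id = next_words % vocab_size    # which word was sampled for it

Because the draw is over the joint (beam, word) space, beams whose running scores are higher are proportionally more likely to be extended, which is the sampled counterpart of what the greedy branch does with topk.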
@@ -1026,7 +1042,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     # add to generated hypotheses if end of sentence or last iteration
                     if eos_token_ids is not None and word_id.item() in eos_token_ids:
                         generated_hyps[batch_idx].add(
-                            input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
+                            input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
                         )
                     else:
                         # add next predicted word if it is not eos_token