Fix beam_scores shape when token scores shape changes after logits_processor (#25980)

This commit is contained in:
BakerBunker 2023-09-14 02:12:47 +08:00 committed by GitHub
parent a796f7eea6
commit 0fced06788
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3038,7 +3038,9 @@ class GenerationMixin:
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores_processed
)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
@ -3363,7 +3365,9 @@ class GenerationMixin:
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores_processed
)
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
# (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
@ -4080,7 +4084,9 @@ class GenerationMixin:
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores_processed
)
scores_for_all_vocab = next_token_scores.clone()