From 0fced067880e2f7785dd2b7dce600f8cbdb691e7 Mon Sep 17 00:00:00 2001 From: BakerBunker <17872844+BakerBunker@users.noreply.github.com> Date: Thu, 14 Sep 2023 02:12:47 +0800 Subject: [PATCH] Fix `beam_scores` shape when token scores shape changes after `logits_processor` (#25980) --- src/transformers/generation/utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 31bb0eca5c0..3b1bef6f040 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3038,7 +3038,9 @@ class GenerationMixin: ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -3363,7 +3365,9 @@ class GenerationMixin: ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers # (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867 @@ -4080,7 +4084,9 @@ class GenerationMixin: next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) scores_for_all_vocab = next_token_scores.clone()