mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix beam_scores
shape when token scores shape changes after logits_processor
(#25980)
This commit is contained in:
parent
a796f7eea6
commit
0fced06788
@ -3038,7 +3038,9 @@ class GenerationMixin:
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
||||
next_token_scores_processed
|
||||
)
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
if return_dict_in_generate:
|
||||
@ -3363,7 +3365,9 @@ class GenerationMixin:
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
||||
next_token_scores_processed
|
||||
)
|
||||
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
|
||||
# (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
|
||||
# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
|
||||
@ -4080,7 +4084,9 @@ class GenerationMixin:
|
||||
|
||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
||||
next_token_scores_processed
|
||||
)
|
||||
|
||||
scores_for_all_vocab = next_token_scores.clone()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user