diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index f542a47a088..b8f039ad936 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -2403,6 +2403,9 @@ class GenerationMixin: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need + if output_scores: + processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) @@ -2411,9 +2414,6 @@ class GenerationMixin: # indices of beams of current group among all sentences in batch batch_group_indices = [] - if output_scores: - processed_score = torch.zeros_like(outputs.logits[:, -1, :]) - for batch_idx in range(batch_size): batch_group_indices.extend( [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]