mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
correcting group beam search function output score bug (#13211)
This commit is contained in:
parent
f689743e74
commit
b13c6c18d0
@ -2403,6 +2403,9 @@ class GenerationMixin:
|
||||
cur_len = cur_len + 1
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
if output_scores:
|
||||
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
|
||||
|
||||
for beam_group_idx in range(num_beam_groups):
|
||||
group_start_idx = beam_group_idx * num_sub_beams
|
||||
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||
@ -2411,9 +2414,6 @@ class GenerationMixin:
|
||||
# indices of beams of current group among all sentences in batch
|
||||
batch_group_indices = []
|
||||
|
||||
if output_scores:
|
||||
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
batch_group_indices.extend(
|
||||
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
||||
|
Loading…
Reference in New Issue
Block a user