correcting group beam search function output score bug (#13211)

This commit is contained in:
sourabh112 2021-08-23 16:57:24 +05:30 committed by GitHub
parent f689743e74
commit b13c6c18d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2403,6 +2403,9 @@ class GenerationMixin:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
@ -2411,9 +2414,6 @@ class GenerationMixin:
# indices of beams of current group among all sentences in batch
batch_group_indices = []
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]