mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
correcting group beam search function output score bug (#13211)
This commit is contained in:
parent
f689743e74
commit
b13c6c18d0
@ -2403,6 +2403,9 @@ class GenerationMixin:
|
|||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
continue # don't waste resources running the code we don't need
|
continue # don't waste resources running the code we don't need
|
||||||
|
|
||||||
|
if output_scores:
|
||||||
|
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
|
||||||
|
|
||||||
for beam_group_idx in range(num_beam_groups):
|
for beam_group_idx in range(num_beam_groups):
|
||||||
group_start_idx = beam_group_idx * num_sub_beams
|
group_start_idx = beam_group_idx * num_sub_beams
|
||||||
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||||
@ -2411,9 +2414,6 @@ class GenerationMixin:
|
|||||||
# indices of beams of current group among all sentences in batch
|
# indices of beams of current group among all sentences in batch
|
||||||
batch_group_indices = []
|
batch_group_indices = []
|
||||||
|
|
||||||
if output_scores:
|
|
||||||
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
|
|
||||||
|
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
batch_group_indices.extend(
|
batch_group_indices.extend(
|
||||||
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
||||||
|
Loading…
Reference in New Issue
Block a user