[Whisper + beam search] fix usage of beam_indices (#38259)

* tmp

* fix test_tiny_token_timestamp_batch_generation

* better comments

* test

* comments

* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Joao Gante 2025-05-23 11:05:44 +01:00 committed by GitHub
parent 3e960e032d
commit a6b51e7341
2 changed files with 27 additions and 18 deletions


@@ -231,42 +231,52 @@ class WhisperGenerationMixin(GenerationMixin):
             tensor containing the timestamps in seconds for each predicted token
         """
         # Create a list with `decoder_layers` elements, each a tensor of shape
-        # (batch size, attention_heads, output length, input length).
+        # (batch size * num beams, attention_heads, output length, input length).
         cross_attentions = []
         for i in range(self.config.decoder_layers):
             cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
         # Select specific cross-attention layers and heads. This is a tensor
-        # of shape (batch size, num selected, output length, input length).
+        # of shape (batch size * num beams, num selected heads, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
         weight_length = None
         if "beam_indices" in generate_outputs:
-            # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
-            # since the beam search strategy chooses the most probable sequences at the end of the search.
-            # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
+            # If beam search was used, the sequence length of the outputs may not be the real sequence length:
+            # beam search may end up returning a sequence that finished a few steps earlier while decoding.
+            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
+            # the right `output_length`
+            # get the real sequence length of the longest sequence, crop the beam_indices to the real length
             weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
-            weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
+            beam_indices = generate_outputs.beam_indices[:, :weight_length]
-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
-            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
-            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
-            # we actually shift the beam indices here
-            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
             weights = weights[:, :, :weight_length]
+            # The first forward pass (prefill) may have processed more than one token and, therefore, contain
+            # cross-attention weights for several tokens.
+            # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
+            if num_input_ids is not None and num_input_ids > 1:
+                # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
+                weight_length += num_input_ids - 1
+                beam_indices_first_step_unrolled = (
+                    torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
+                    * (beam_indices[:, 0:1])
+                )
+                unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
+            else:
+                unrolled_beam_indices = beam_indices
             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
+            unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
-            # Select the cross attention from the right beam for each output sequences
+            # Select the cross attention from the right beam for each output sequence, up to the real sequence
+            # length (`weight_length`)
             weights = torch.stack(
                 [
-                    torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
-                    for i in range(beam_indices.shape[1])
+                    torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
+                    for i in range(unrolled_beam_indices.shape[1])
                 ],
                 dim=2,
             )
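Note for readers of this diff: the core of the fix is that, with beam search, the prefill step produces cross-attention weights for every `decoder_input_ids` token while `beam_indices` holds only a single entry for that step, so the first beam index has to be repeated before gathering. Below is a minimal, self-contained sketch of that gather on toy shapes; the names (`first_step`, `unrolled`, `gathered`) and dimensions are made up for illustration and are not the repository code, which does the equivalent with `torch.index_select`.

import torch

batch_beams, heads, num_input_ids, decoded = 4, 2, 3, 5
output_len = num_input_ids + decoded - 1                              # prefill emits weights for every prompt token
weights = torch.randn(batch_beams, heads, output_len, 10)             # (beams, heads, output positions, audio frames)
beam_indices = torch.randint(0, batch_beams, (batch_beams, decoded))  # beam used at each decoding step

# Unroll the first step: its single beam index covers all `num_input_ids` prefill positions.
first_step = beam_indices[:, :1].repeat(1, num_input_ids - 1)
unrolled = torch.cat([first_step, beam_indices], dim=-1)              # (beams, output_len)

# For every output position, keep the weights of the beam the returned sequence was on.
gathered = torch.stack(
    [weights[unrolled[:, i], :, i, :] for i in range(unrolled.shape[1])],
    dim=2,
)
print(gathered.shape)  # torch.Size([4, 2, 7, 10])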


@@ -2155,7 +2155,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         # task id and lang id prompts should not have timestamp tokens
         self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
         self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
     @slow
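For context, a minimal usage sketch of the code path this commit touches: token-level timestamps requested together with beam search. The model id and flags mirror the integration test above, but the snippet itself (including the silent placeholder audio) is an illustrative assumption, not part of the commit.

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16_000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

out = model.generate(
    inputs.input_features,
    num_beams=2,
    return_token_timestamps=True,   # needs cross-attentions, hence the beam_indices handling above
    return_dict_in_generate=True,
)
print(out["token_timestamps"])      # one timestamp (in seconds) per predicted token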