From a6b51e7341d702127a4a45f37439640840b5abf0 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 23 May 2025 11:05:44 +0100
Subject: [PATCH] [Whisper + beam search] fix usage of `beam_indices` (#38259)

* tmp

* fix test_tiny_token_timestamp_batch_generation

* better comments

* test

* comments

* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
---
 .../models/whisper/generation_whisper.py      | 44 ++++++++++++-------
 tests/models/whisper/test_modeling_whisper.py |  1 -
 2 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 4c29d456bf9..db362355b87 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -231,42 +231,52 @@ class WhisperGenerationMixin(GenerationMixin):
             tensor containing the timestamps in seconds for each predicted token
         """
         # Create a list with `decoder_layers` elements, each a tensor of shape
-        # (batch size, attention_heads, output length, input length).
+        # (batch size * num beams, attention_heads, output length, input length).
         cross_attentions = []
         for i in range(self.config.decoder_layers):
             cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
 
         # Select specific cross-attention layers and heads. This is a tensor
-        # of shape (batch size, num selected, output length, input length).
+        # of shape (batch size * num beams, num selected heads, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
 
         weight_length = None
 
         if "beam_indices" in generate_outputs:
-            # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
-            # since the beam search strategy chooses the most probable sequences at the end of the search.
-            # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
+            # If beam search was used, the sequence length of the outputs may not be the real sequence length:
+            # beam search may end up returning a sequence that finished a few steps earlier while decoding.
+            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
+            # the right `output_length`
+
+            # get the real sequence length of the longest sequence, crop the beam_indices to the real length
             weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
-            weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
+            beam_indices = generate_outputs.beam_indices[:, :weight_length]
 
-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
-            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
-            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
-            # we actually shift the beam indices here
-            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
-
-            weights = weights[:, :, :weight_length]
+            # The first forward pass (prefill) may have processed more than one token and, therefore, contain
+            # cross-attention weights for several tokens.
+            # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
+            if num_input_ids is not None and num_input_ids > 1:
+                # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
+                weight_length += num_input_ids - 1
+                beam_indices_first_step_unrolled = (
+                    torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
+                    * (beam_indices[:, 0:1])
+                )
+                unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
+            else:
+                unrolled_beam_indices = beam_indices
 
             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
+            unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
 
-            # Select the cross attention from the right beam for each output sequences
+            # Select the cross attention from the right beam for each output sequence, up to the real sequence
+            # length (`weight_length`)
             weights = torch.stack(
                 [
-                    torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
-                    for i in range(beam_indices.shape[1])
+                    torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
+                    for i in range(unrolled_beam_indices.shape[1])
                 ],
                 dim=2,
             )
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 2085d9f2844..37e459db983 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -2155,7 +2155,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
         # task id and lang id prompts should not have timestamp tokens
         self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
-        self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
 
     @slow
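For intuition, here is a minimal, self-contained sketch of the gather logic the patch introduces, under assumed toy sizes (2 returned sequences, 3 beams each, 1 selected head, 4 encoder frames); `expand` stands in for the patch's `torch.ones(...) * beam_indices[:, 0:1]` broadcast, and every shape below is a hypothetical example, not taken from the test suite:

```python
import torch

# Hypothetical sizes: 2 returned sequences with 3 beams each -> 6 weight rows,
# 1 selected attention head, 4 encoder frames, and a 3-token prefill.
num_input_ids = 3
beam_indices = torch.tensor([[0, 2, 2, -1, -1],  # -1 pads steps after the sequence finished
                             [3, 3, 4, 5, -1]])
# The prefill contributes `num_input_ids` output columns in a single step, so
# `weights` has `num_input_ids - 1` more columns than `beam_indices` has steps.
out_length = beam_indices.shape[1] + num_input_ids - 1  # 7
weights = torch.randn(6, 1, out_length, 4)  # (batch * beams, heads, output len, input len)

# Crop to the real length of the longest returned sequence.
weight_length = (beam_indices != -1).sum(-1).max()  # 4
beam_indices = beam_indices[:, :weight_length]

# Unroll the first step: all prefill columns belong to the beam recorded at step 0.
if num_input_ids > 1:
    weight_length += num_input_ids - 1  # 6
    first_step = beam_indices[:, 0:1].expand(-1, num_input_ids - 1)
    beam_indices = torch.cat([first_step, beam_indices], dim=-1)

# Leftover -1 entries (finished sequences) would make index_select fail; 0 is a safe dummy.
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)

# For each output column, select the weights of the beam that was actually kept.
selected = torch.stack(
    [
        torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
        for i in range(beam_indices.shape[1])
    ],
    dim=2,
)
print(selected.shape)  # torch.Size([2, 1, 6, 4])
```

Unrolling the first step makes the gather index one-to-one with the output columns of `weights`, so the prefill columns are selected from the correct beam instead of being shifted or cropped away, which is what the removed `beam_indices` shifting attempted to do.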