Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
[Whisper + beam search] fix usage of beam_indices (#38259)
* tmp
* fix test_tiny_token_timestamp_batch_generation
* better comments
* test
* comments
* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
parent 3e960e032d
commit a6b51e7341
@@ -231,42 +231,52 @@ class WhisperGenerationMixin(GenerationMixin):
                 tensor containing the timestamps in seconds for each predicted token
         """
         # Create a list with `decoder_layers` elements, each a tensor of shape
-        # (batch size, attention_heads, output length, input length).
+        # (batch size * num beams, attention_heads, output length, input length).
         cross_attentions = []
         for i in range(self.config.decoder_layers):
             cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
 
         # Select specific cross-attention layers and heads. This is a tensor
-        # of shape (batch size, num selected, output length, input length).
+        # of shape (batch size * num beams, num selected heads, output length, input length).
         weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
         weights = weights.permute([1, 0, 2, 3])
 
         weight_length = None
 
         if "beam_indices" in generate_outputs:
-            # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
-            # since the beam search strategy chooses the most probable sequences at the end of the search.
-            # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
+            # If beam search was used, the sequence length of the outputs may not be the real sequence length:
+            # beam search may end up returning a sequence that finished a few steps earlier while decoding.
+            # In that case, the `cross_attentions` weights are too long and we have to make sure that they have
+            # the right `output_length`
+
+            # get the real sequence length of the longest sequence, crop the beam_indices to the real length
             weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
-            weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
-
-            # beam search takes `decoder_input_ids` into account in the `beam_indices` length
-            # but forgot to shift the beam_indices by the number of `decoder_input_ids`
-            beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
-            # we actually shift the beam indices here
-            beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
+            beam_indices = generate_outputs.beam_indices[:, :weight_length]
 
             weights = weights[:, :, :weight_length]
 
+            # The first forward pass (prefill) may have processed more than one token and, therefore, contain
+            # cross-attention weights for several tokens.
+            # Let's unroll the first `beam_indices` accordingly, so we can use it to gather the weights.
+            if num_input_ids is not None and num_input_ids > 1:
+                # `-1`: `beam_indices` can be used as-is to gather the weights when `num_input_ids` is 1
+                weight_length += num_input_ids - 1
+                beam_indices_first_step_unrolled = (
+                    torch.ones(beam_indices.shape[0], num_input_ids - 1, device=beam_indices.device, dtype=torch.long)
+                    * (beam_indices[:, 0:1])
+                )
+                unrolled_beam_indices = torch.cat([beam_indices_first_step_unrolled, beam_indices], dim=-1)
+            else:
+                unrolled_beam_indices = beam_indices
+
             # If beam index is still -1, it means that the associated token id is EOS
             # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
-            beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
+            unrolled_beam_indices = unrolled_beam_indices.masked_fill(unrolled_beam_indices == -1, 0)
 
-            # Select the cross attention from the right beam for each output sequences
+            # Select the cross attention from the right beam for each output sequence, up to the real sequence
+            # length (`weight_length`)
             weights = torch.stack(
                 [
-                    torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
-                    for i in range(beam_indices.shape[1])
+                    torch.index_select(weights[:, :, i, :], dim=0, index=unrolled_beam_indices[:, i])
+                    for i in range(unrolled_beam_indices.shape[1])
                 ],
                 dim=2,
            )
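For context outside the diff, here is a minimal, standalone sketch of the gather the new code performs. It is not part of the patch: the shapes, the dummy tensors, and the helper names (first_step, unrolled, gathered) are made up for illustration; only PyTorch is assumed.

# Sketch: per-step beam gather over cross-attention weights, as in the new code path.
import torch

batch, num_beams, num_heads, out_len, in_len = 2, 3, 4, 5, 6
num_input_ids = 2  # prefill length: tokens processed in the first forward pass

# Stand-in for the selected cross-attention weights,
# shape (batch * num_beams, num_heads, output_length, input_length).
weights = torch.randn(batch * num_beams, num_heads, out_len + num_input_ids - 1, in_len)

# Stand-in for `generate_outputs.beam_indices` after cropping: one beam index
# per generated step; -1 pads steps after a sequence has finished.
beam_indices = torch.randint(0, batch * num_beams, (batch, out_len))
beam_indices[0, -1] = -1  # e.g. the first sequence finished one step early

# Prefill produced cross-attention weights for `num_input_ids` tokens but only
# one beam index, so repeat the first index to cover the extra prefill steps.
first_step = beam_indices[:, 0:1].expand(batch, num_input_ids - 1)
unrolled = torch.cat([first_step, beam_indices], dim=-1)

# -1 would make index_select fail; 0 is a safe placeholder for finished steps.
unrolled = unrolled.masked_fill(unrolled == -1, 0)

# For each step i, pick the weights row of the beam that produced that token.
gathered = torch.stack(
    [torch.index_select(weights[:, :, i, :], dim=0, index=unrolled[:, i]) for i in range(unrolled.shape[1])],
    dim=2,
)
print(gathered.shape)  # torch.Size([2, 4, 6, 6]) -> (batch, num_heads, steps, input_length)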
@@ -2155,7 +2155,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
         # task id and lang id prompts should not have timestamp tokens
         self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
 
-        self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
 
     @slow
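A rough usage sketch of the scenario this test exercises: beam search with multiple return sequences plus token-level timestamps. The checkpoint name and the zero-filled dummy features are assumptions for illustration; real audio would be prepared with WhisperProcessor.

# Sketch: Whisper generate with beam search and token timestamps.
import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Dummy log-mel features standing in for two 30 s processed audio samples.
input_features = torch.zeros(2, model.config.num_mel_bins, 3000)

generate_outputs = model.generate(
    input_features,
    language="en",
    task="transcribe",
    num_beams=2,
    num_return_sequences=2,
    return_token_timestamps=True,
    return_dict_in_generate=True,
)

# The test above checks that every decoded token gets a timestamp, except the
# task id and lang id prompt tokens (hence the `- 2` in the assertion).
print(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)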