[Whisper] 🚨 Fix whisper decoding 🚨 (#34135)

* do not remove decoder_input_ids for the first segment

* do not remove eos token in generate_with_fallback

* when removing padding tokens, do not remove eos token

* remove eos token in generate (and not in generate_with_fallback!)

* reconcile short-form / long-form behavior

* correct avg_logprobs calculation

* handle eos token in segments

* handle decoder_input_ids and eos token in _prepare_decoder_input_ids

* fix incorrect time precision

* always remove eos token

* always remove decoder_input_ids

* no need to handle decoder_input_ids and eos token

* no need to remove decoder_input_ids

* no need to handle eos token

* fix num_beams in _retrieve_logit_processors

* remove TODO inconsistency

* no need to add eos token

* last_timestamp_pos should indeed be timestamp token pos

* patch generate to enable compatibility with GenerationTesterMixin tests

* adapt test_generate_continue_from_past_key_values

* adapt test_prompt_lookup_decoding_matches_greedy_search

* adapt generic GenerationMixin tests to whisper's generate

* fix speculative decoding

* fix

* [run-slow] whisper

* change HF_HUB_TOKEN for require_read_token

* [run-slow] whisper

* prioritize kwargs over generation_config

* remove unnecessary args

* [run-slow] whisper

* update tests

* [run-slow] whisper

* add comment

* update test

* [run-slow] whisper

* update test + revert require_read_token

* docstring updates

* revert tokenizer decode args change

* do not use a patch + docstring updates

* [run-slow] whisper

* make

* [run-slow] whisper

* add a flag to force unique call to generate

* test update

* [run-slow] whisper

* add force_unique_generate_call arg

* do not use a patch

* correct the timestamps for the pad tokens

* docstring update

* docstring update

* docstring update

* update TF tests

* add require_read_token

* [run-slow] whisper

* test reset dynamo

* [run-slow] whisper

* fix

* [run-slow] whisper

* avoid iterating twice on current_segments

* [run-slow] whisper

* [run-slow] whisper

---------

Co-authored-by: Eustache Le Bihan <eustlb@users.noreply.huggingface.co>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
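
For context, a minimal usage sketch of the behavior this PR touches; it is not part of the diff, and the checkpoint name and the silent dummy audio are illustrative assumptions:

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# short-form input: <= 30 s of audio, padded by the feature extractor to 3000 mel frames
input_features = processor(np.zeros(16_000 * 5), sampling_rate=16_000, return_tensors="pt").input_features

# default short-form decoding: the returned sequences exclude the decoder input ids and the eos token
sequences = model.generate(input_features)

# force_unique_generate_call guarantees a single underlying call to GenerationMixin's generate,
# so the returned ids keep the decoder input ids and the eos token (useful for assisted decoding
# and for the generic generation tests)
sequences_full = model.generate(input_features, force_unique_generate_call=True)
print(processor.batch_decode(sequences_full, skip_special_tokens=True))
```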

@ -133,9 +133,12 @@ def _pad_to_max_length(
padding="longest",
bos_token_tensor=None,
cut_off_length=None,
return_token_timestamps=False,
force_unique_generate_call=False,
):
max_total_length = 0
sequences = []
token_timestamps_list = []
if padding_side not in ["right", "left"]:
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
@ -145,31 +148,74 @@ def _pad_to_max_length(
elif padding == "max_length" and cut_off_length is None:
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
if force_unique_generate_call:
sequences_list = []
timestamps_list = []
for segments in current_segments:
result = segments[0]["result"]
sequences_list.append(result if isinstance(result, torch.Tensor) else result["sequences"])
if return_token_timestamps:
timestamps_list.append(result["token_timestamps"])
sequences = torch.stack(sequences_list, dim=0)
if return_token_timestamps:
token_timestamps = torch.stack(timestamps_list, dim=0)
return sequences, token_timestamps
return sequences
for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
if return_token_timestamps:
token_timestamps = torch.cat(
[d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
dim=-1,
)
if cut_off_length is not None:
sequence = sequence[-cut_off_length:]
if return_token_timestamps:
token_timestamps = token_timestamps[-cut_off_length:]
if bos_token_tensor is not None:
sequence = torch.cat([bos_token_tensor, sequence])
if return_token_timestamps:
token_timestamps = torch.cat(
[torch.ones_like(bos_token_tensor, device=device) * 0.0, token_timestamps]
)
sequences.append(sequence)
if return_token_timestamps:
token_timestamps_list.append(token_timestamps)
max_total_length = max(max_total_length, len(sequences[-1]))
elif bos_token_tensor is not None:
sequences.append(bos_token_tensor)
if return_token_timestamps:
token_timestamps_list.append(torch.ones_like(bos_token_tensor, device=device) * 0.0)
else:
sequences.append(torch.tensor([], device=device))
if return_token_timestamps:
token_timestamps_list.append(torch.tensor([], device=device))
max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
if return_token_timestamps:
token_timestamps_list[i] = F.pad(
token_timestamps_list[i],
pad=pad,
value=token_timestamps_list[i][-1] if len(token_timestamps_list[i]) > 0 else 0.0,
)
sequences = torch.stack(sequences, dim=0)
return sequences
if return_token_timestamps:
token_timestamps = torch.stack(token_timestamps_list, dim=0)
return sequences, token_timestamps
else:
return sequences
class WhisperGenerationMixin(GenerationMixin):
@ -312,6 +358,7 @@ class WhisperGenerationMixin(GenerationMixin):
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
force_unique_generate_call: Optional[bool] = None,
**kwargs,
):
"""
@ -432,27 +479,39 @@ class WhisperGenerationMixin(GenerationMixin):
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
`return_segments` is set True. In this case the generation outputs of each segment is added to each
segment.
force_unique_generate_call (`bool`, *optional*):
Whether to force a single call to the underlying GenerationMixin's `generate` method. This is useful for assisted decoding and for testing purposes,
since it guarantees that only one call to `generate` is made and that the decoder input token ids and the end-of-sequence token id are therefore returned.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
[`~utils.ModelOutput`] or `Dict[str, Any]` or `torch.LongTensor`:
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
A:
- [`~utils.ModelOutput`] when `return_dict_in_generate=True` and (`return_timestamps=False` or `force_unique_generate_call=True`), including the decoder input ids and end of sequence id.
- `Dict[str, Any]` when (`return_dict_in_generate=True` and `return_timestamps=True`) or `return_segments=True` or `return_token_timestamps=True`.
- `torch.LongTensor` in all other cases, excluding the decoder input ids and end of sequence id.
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
The possible [`~utils.ModelOutput`] types are:
- [`~utils.GenerateEncoderDecoderOutput`]
- [`~utils.GenerateBeamEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
`segments` is a list of lists (one list per batch element) of `segment`.
A `segment` is a dictionary with keys `start`, `end`, `tokens`, `idxs`, and `result`.
- `start`: the start timestamp of the segment.
- `end`: the end timestamp of the segment.
- `tokens`: the tokens of the segment, excluding the decoder input ids and end of sequence id.
- `idxs`: the start (included) and end (excluded) indices of the `tokens` of the segment in the underlying call to GenerationMixin's `generate` (present in `result`).
- `result`: the result of the underlying call to GenerationMixin's `generate`.
else only the generated output sequence ids are returned.
When `return_timestamps=True`, `return_dict_in_generate=True` applies to each call of the underlying GenerationMixin's `generate`, with outputs stored in `result` of each `segment`.
Example:
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. It is necessary to set `return_timestamps=True`.
Indeed, long-form transcription uses a sequential algorithm based on timestamp predictions, with heuristics like compression ratio threshold, log probability threshold and temperature fallback. This algorithm is described in [the Whisper original paper](https://cdn.openai.com/papers/whisper.pdf), section *3.8. Long-form Transcription*.
```python
>>> import torch
@ -483,7 +542,9 @@ class WhisperGenerationMixin(GenerationMixin):
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
```
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
- *Shortform transcription*: If the passed mel input features correspond to <= 30 seconds of audio, there are two possibilities:
- `return_timestamps=False`: the whole audio will be transcribed with a single call to GenerationMixin's generate.
- `return_timestamps=True`: the audio will be transcribed using the same logic as long-form transcription.
```python
>>> import torch
@ -570,11 +631,21 @@ class WhisperGenerationMixin(GenerationMixin):
# 3. Retrieve logits processors
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
begin_index = init_tokens.shape[1]
num_beams = kwargs.get(
"num_beams",
generation_config.num_beams
if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
else 1,
)
if "assistant_model" in kwargs:
# speculative decoding: the model should be able to return eos token
generation_config.begin_suppress_tokens = None
logits_processor = self._retrieve_logit_processors(
generation_config=generation_config,
logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token
num_beams=kwargs.get("num_beams", 1),
num_beams=num_beams,
device=device,
)
@ -618,6 +689,19 @@ class WhisperGenerationMixin(GenerationMixin):
batch_size=cur_bsz,
generation_config=generation_config,
)
# 5bis speculative decoding: ensure the assistant model makes only one call to generate and therefore returns decoder input token ids and eos token id
# we set a flag in the generation config to force the model to make only one call to generate and return the decoder input token ids and eos token id
if "assistant_model" in kwargs:
assistant_model = kwargs["assistant_model"]
assistant_model.generation_config.force_unique_generate_call = True
if force_unique_generate_call is None:
if hasattr(generation_config, "force_unique_generate_call"):
force_unique_generate_call = generation_config.force_unique_generate_call
elif hasattr(self.generation_config, "force_unique_generate_call"):
force_unique_generate_call = self.generation_config.force_unique_generate_call
else:
force_unique_generate_call = False
# 6 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
@ -729,14 +813,15 @@ class WhisperGenerationMixin(GenerationMixin):
prev_idx=prev_i,
idx=i,
return_token_timestamps=return_token_timestamps,
decoder_input_ids=decoder_input_ids,
)
seek[prev_i] += segment_offset
current_segments[prev_i] += segments
if is_shortform:
seek[prev_i] += max_frames[i]
else:
seek[prev_i] += segment_offset
if force_unique_generate_call:
break
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
@ -746,51 +831,62 @@ class WhisperGenerationMixin(GenerationMixin):
else current_segments
)
sequences = _pad_to_max_length(
final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
)
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": final_segments}
if is_shortform:
# add eos token:
if generation_config.max_new_tokens is None and generation_config.max_length is None:
eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
sequences = torch.cat([sequences, eos_tokens], dim=-1)
if return_token_timestamps:
outputs = {}
outputs["sequences"] = sequences
outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
else:
outputs = sequences
if return_dict_in_generate and generation_config.return_dict_in_generate:
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
if num_return_sequences > 1:
if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
dict_outputs.encoder_attentions = tuple(
dict_outputs.encoder_attentions[i][::num_return_sequences]
for i in range(len(dict_outputs.encoder_attentions))
)
if (
hasattr(dict_outputs, "encoder_hidden_states")
and dict_outputs.encoder_hidden_states is not None
):
dict_outputs.encoder_hidden_states = tuple(
dict_outputs.encoder_hidden_states[i][::num_return_sequences]
for i in range(len(dict_outputs.encoder_hidden_states))
)
if return_token_timestamps:
dict_outputs["token_timestamps"] = outputs["token_timestamps"]
return dict_outputs
# if return_dict_in_generate=True and we forced a unique call to generate or return_timestamps=False, meaning we are sure only one call to generate has been made,
# -> we can return a ModelOutput
# otherwise, return_dict_in_generate is applied in the 'result' of each segment in final_segments
if (
return_dict_in_generate
and generation_config.return_dict_in_generate
and (force_unique_generate_call or not return_timestamps)
):
# only one call to generate_with_fallback, we can return a ModelOutput
outputs = self._stack_split_outputs(seek_outputs, model_output_type, self.device, kwargs)
if num_return_sequences > 1:
if hasattr(outputs, "encoder_attentions") and outputs.encoder_attentions is not None:
outputs.encoder_attentions = tuple(
outputs.encoder_attentions[i][::num_return_sequences]
for i in range(len(outputs.encoder_attentions))
)
if hasattr(outputs, "encoder_hidden_states") and outputs.encoder_hidden_states is not None:
outputs.encoder_hidden_states = tuple(
outputs.encoder_hidden_states[i][::num_return_sequences]
for i in range(len(outputs.encoder_hidden_states))
)
return outputs
return sequences
padded_outputs = _pad_to_max_length(
current_segments=final_segments,
pad_token_id=generation_config.pad_token_id,
device=self.device,
padding_side="right",
return_token_timestamps=return_token_timestamps,
force_unique_generate_call=force_unique_generate_call,
)
if return_dict_in_generate and generation_config.return_dict_in_generate:
logger.warning_once(
"You have passed `return_dict_in_generate=True` and `return_timestamps=True`, this automatically sets `return_segments=True` to access the resuls of the underlying calls to GenerationMixin's generate in the returned `segments`."
)
return_segments = True
elif not return_segments and not return_token_timestamps:
return padded_outputs
if return_token_timestamps:
sequences, token_timestamps = padded_outputs
outputs = {
"sequences": sequences,
"token_timestamps": token_timestamps,
}
else:
sequences = padded_outputs
outputs = {
"sequences": sequences,
}
if return_segments:
outputs["segments"] = final_segments
return outputs
def generate_with_fallback(
self,
@ -886,22 +982,14 @@ class WhisperGenerationMixin(GenerationMixin):
new_decoder_attention_mask = []
for i, seek_sequence in enumerate(seek_sequences):
# make sure we cut a predicted EOS token if we are not finished with the generation yet
prev_i = batch_idx_map[fallback_index_map[i]]
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
# remove eos token id
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
if return_token_timestamps and not is_shortform:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
# remove all padding tokens
# remove all padding tokens, except for the eos token
if seek_sequence[-1] == generation_config.pad_token_id:
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
if return_token_timestamps and not is_shortform:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
if generation_config.pad_token_id == generation_config.eos_token_id:
# we do not remove the eos token id since it is needed for avg logprob calculation in _need_fallback
num_paddings -= 1
if num_paddings != 0:
seek_sequence = seek_sequence[:-num_paddings]
# check which sequences in batch need fallback & which should be skipped
needs_fallback[i], should_skip[i] = self._need_fallback(
@ -914,6 +1002,10 @@ class WhisperGenerationMixin(GenerationMixin):
temperature,
)
# remove eos token
if seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
seek_sequence_list[fallback_index_map[i]] = seek_sequence
seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
is_low_temperature = temperature is None or temperature < 0.5
@ -956,14 +1048,19 @@ class WhisperGenerationMixin(GenerationMixin):
return current_segments
def _postprocess_outputs(
self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
self,
seek_outputs,
decoder_input_ids,
return_token_timestamps,
generation_config,
is_shortform,
):
# remove all previously passed decoder input ids
start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
# should happen only if it is the first generated segment
start_idx = decoder_input_ids.shape[-1]
if isinstance(seek_outputs, torch.Tensor):
seek_outputs = seek_outputs[:, start_idx:]
return seek_outputs, seek_outputs
return seek_outputs[:, start_idx:], seek_outputs
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
@ -973,9 +1070,6 @@ class WhisperGenerationMixin(GenerationMixin):
num_frames=num_frames,
num_input_ids=decoder_input_ids.shape[-1],
)
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
if beam_indices is not None and key == "scores":
@ -1011,7 +1105,7 @@ class WhisperGenerationMixin(GenerationMixin):
return values[batch_idx].cpu()
sequence_tokens = seek_outputs["sequences"]
sequence_tokens = seek_outputs["sequences"][:, start_idx:]
seek_outputs = [
{
k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices"))
@ -1026,7 +1120,7 @@ class WhisperGenerationMixin(GenerationMixin):
# Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
outputs = {}
for key in seek_outputs[0].keys():
if key in ["sequences", "beam_indices"]:
if key in ["sequences", "beam_indices", "token_timestamps"]:
outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
elif key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
outputs[key] = tuple(
@ -1057,6 +1151,10 @@ class WhisperGenerationMixin(GenerationMixin):
else:
outputs[key] = None
token_timestamps = outputs.get("token_timestamps", None)
if token_timestamps is not None:
model_output_type = dict
return model_output_type(**outputs)
def _need_fallback(
@ -1083,7 +1181,9 @@ class WhisperGenerationMixin(GenerationMixin):
else:
scores = seek_outputs[index]["scores"]
logprobs = self._retrieve_avg_logprobs(
scores, seek_sequence, generation_config.eos_token_id, temperature
scores,
seek_sequence,
temperature,
)
if logprobs < generation_config.logprob_threshold:
@ -1179,13 +1279,6 @@ class WhisperGenerationMixin(GenerationMixin):
if no_speech_threshold is not None:
logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
# when passing temperature as a list it cannot just be ignored => throw error in this case
if isinstance(temperature, (list, tuple)):
raise ValueError(
f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
)
@staticmethod
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
@ -1768,7 +1861,7 @@ class WhisperGenerationMixin(GenerationMixin):
return compression_ratio
@staticmethod
def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
def _retrieve_avg_logprobs(scores, tokens, temperature):
rescale_temperature = temperature if temperature > 0.0 else 1
scores = torch.stack(scores).to(tokens.device)
@ -1780,10 +1873,10 @@ class WhisperGenerationMixin(GenerationMixin):
logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
# retrieve logprob of selected tokens and sum
sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
# don't remove the eos token logprob! it counts in avg_logprob calculation in the original implementation
sum_logprobs = sum(logprobs[i][tokens[i]] for i in range(logprobs.shape[0]))
avg_logprobs = sum_logprobs / (length + 1)
avg_logprobs = sum_logprobs / len(tokens)
return avg_logprobs
@staticmethod
@ -1799,6 +1892,7 @@ class WhisperGenerationMixin(GenerationMixin):
prev_idx,
idx,
return_token_timestamps,
decoder_input_ids,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
@ -1807,6 +1901,7 @@ class WhisperGenerationMixin(GenerationMixin):
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
idx_offset = decoder_input_ids.shape[-1]
device = seek_sequence.device
# If whisper predicted an "end of segment" via a timestamp token, let's go over each
@ -1838,12 +1933,13 @@ class WhisperGenerationMixin(GenerationMixin):
+ end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
* time_precision,
"tokens": sliced_tokens,
"idxs": (idx_offset + last_slice, idx_offset + current_slice),
"result": seek_outputs[idx],
}
)
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
token_timestamps[idx_offset + last_slice : idx_offset + current_slice] + time_offset[prev_idx]
)
last_slice = current_slice
@ -1871,11 +1967,14 @@ class WhisperGenerationMixin(GenerationMixin):
"start": time_offset[prev_idx],
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
"tokens": seek_sequence,
"idxs": (idx_offset, idx_offset + len(seek_sequence)),
"result": seek_outputs[idx],
}
]
if return_token_timestamps:
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segments[-1]["token_timestamps"] = (
token_timestamps[idx_offset : idx_offset + len(seek_sequence)] + time_offset[prev_idx]
)
segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset
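
To make the `segments` layout described in the updated docstring concrete, here is a hedged long-form sketch; it is not part of the diff, and the checkpoint and the silent one-minute input are stand-ins:

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# one minute of (silent) audio stands in for a real long-form recording (> 30 s)
raw_audio = np.zeros(16_000 * 60)
input_features = processor(
    raw_audio, sampling_rate=16_000, return_tensors="pt", truncation=False, padding="longest"
).input_features

# long-form decoding requires return_timestamps=True; return_segments=True additionally exposes
# the per-segment results of the underlying GenerationMixin generate calls
outputs = model.generate(input_features, return_timestamps=True, return_segments=True)

for segment in outputs["segments"][0]:  # segments of the first batch element
    text = processor.decode(segment["tokens"], skip_special_tokens=True)
    start, end = float(segment["start"]), float(segment["end"])
    # segment["idxs"] locates these tokens inside the raw generate output stored in segment["result"]
    print(f"[{start:6.2f}s -> {end:6.2f}s] {text}")
```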


@ -17,14 +17,22 @@
from __future__ import annotations
import inspect
import os
import tempfile
import traceback
import unittest
import numpy as np
from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
from transformers import GenerationConfig, WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
from transformers.testing_utils import (
is_tf_available,
require_read_token,
require_tf,
require_tokenizers,
run_test_in_subprocess,
slow,
)
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available
@ -749,7 +757,9 @@ def _test_large_generation(in_queue, out_queue, timeout):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@ -772,13 +782,29 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
ds = load_dataset("legacy-datasets/common_voice", "ja", split="test", streaming=True, trust_remote_code=True)
# update generation config
generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
token = os.getenv("HF_HUB_READ_TOKEN", True)
ds = load_dataset(
"mozilla-foundation/common_voice_6_1",
"ja",
split="test",
streaming=True,
trust_remote_code=True,
token=token,
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
input_speech = next(iter(ds))["audio"]["array"]
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
language="<|ja|>",
task="transcribe",
generation_config=generation_config,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@ -786,7 +812,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
input_features,
do_sample=False,
max_length=20,
language="<|en|>",
task="transcribe",
generation_config=generation_config,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@ -794,7 +825,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
input_features,
do_sample=False,
max_length=20,
language="<|ja|>",
task="translate",
generation_config=generation_config,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@ -825,10 +861,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
# fmt: off
EXPECTED_IDS = [
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
[50258, 50259, 50359, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
[50258, 50259, 50359, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
[50258, 50259, 50359, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
[50258, 50259, 50359, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
]
# fmt: on
@ -836,10 +872,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
# fmt: off
EXPECTED_TRANSCRIPT = [
" Mr. Quilter is the apostle of the middle classes and we are glad to",
" Mr. Quilter is the apostle of the middle classes and we are glad",
" Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
" He tells us that at this festive season of the year, with Christmas and roast",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all"
]
# fmt: on
@ -1009,6 +1045,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)
@slow
@require_read_token
def test_large_generation_multilingual(self):
run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)


@ -445,6 +445,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
self.maxDiff = 3000
def prepare_config_and_inputs_for_generate(self, batch_size=2):
config, inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
inputs_dict["force_unique_generate_call"] = True
return config, inputs_dict
def test_config(self):
self.config_tester.run_common_tests()
@ -1891,8 +1896,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
"ja",
split="test",
streaming=True,
token=token,
trust_remote_code=True,
token=token,
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
@ -2144,11 +2149,16 @@ class WhisperModelIntegrationTests(unittest.TestCase):
},
{
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
"timestamp": (39.80, 45.36),
# "timestamp": (39.80, 45.36),
# above is the expected output on A100.
# on CI T4s, due to slight differences in floating point operations, the expected output is below
"timestamp": (39.80, 45.38),
},
{
"text": " can discover in it but little of rocky Ithaca.",
"timestamp": (45.36, 49.0),
# "timestamp": (45.36, 49.0),
# see above
"timestamp": (45.38, 49.0),
},
{
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
@ -2275,20 +2285,20 @@ class WhisperModelIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400],
[0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400],
[0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200],
[0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000],
[0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800],
[0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600]
[0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200]
])
# fmt: on
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
@slow
def test_large_token_timestamp_generation(self):
def test_small_token_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.to(torch_device)
input_speech = self._load_datasamples(4)
@ -2305,10 +2315,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
[0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
[0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000],
[0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800]
[0.0000, 0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
[0.0000, 0.0000, 0.7600, 1.0000, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.1800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
[0.0000, 0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600],
[0.0000, 0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1400, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800]
])
# fmt: on
@ -3331,6 +3341,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
torch._dynamo.config.cache_size_limit = 4
torch._dynamo.reset()
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
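
Finally, a hedged sketch of the token-timestamp path exercised by the tests above; it is not from the diff, and it assumes a checkpoint (here `openai/whisper-tiny`) whose generation config declares `alignment_heads`:

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# short-form silent dummy input, padded to 3000 mel frames
input_features = processor(np.zeros(16_000 * 5), sampling_rate=16_000, return_tensors="pt").input_features

# token-level timestamps are derived from the cross-attention of the alignment heads
outputs = model.generate(input_features, return_token_timestamps=True)

tokens = outputs["sequences"][0]
timestamps = outputs["token_timestamps"][0]  # one timestamp (in seconds) per returned token
for token_id, ts in zip(tokens, timestamps):
    print(f"{float(ts):6.2f}s  {processor.decode([int(token_id)])}")
```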