[Whisper] Pipeline: handle long form generation (#35750)

* handle long form generation

* add warning

* correct incorrect in-place token change

* update test to catch edge case

* make style

* update warning

* add doc
eustlb 2025-06-26 16:33:31 +02:00 committed by GitHub
parent 02ecdcfc0f
commit cfff7ca9a2
4 changed files with 64 additions and 17 deletions
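
For context, the change targets Whisper's sequential long-form decoding, where audio longer than 30 seconds is transcribed window by window and the tokens of the previous segment can be fed back as decoder context. A minimal usage sketch of that path follows; the checkpoint name, the synthetic waveform, and the sampling rate are assumptions for illustration, not part of this commit.

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Assumed checkpoint and a synthetic >30 s waveform, for illustration only.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16_000 * 45, dtype=np.float32)  # 45 s of silence at 16 kHz
inputs = processor(
    audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,          # keep the full long-form input
    padding="longest",
    return_attention_mask=True,
)

# Long-form path touched by this commit: timestamps on, conditioning on the
# previously generated segment's tokens.
generated_ids = model.generate(
    inputs.input_features,
    attention_mask=inputs.attention_mask,
    return_timestamps=True,
    condition_on_prev_tokens=True,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))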


@@ -136,7 +136,18 @@ def _pad_to_max_length(
     cut_off_length=None,
     return_token_timestamps=False,
     force_unique_generate_call=False,
+    skip_ending_double_timestamps=False,
+    timestamp_begin=None,
 ):
+    """
+    skip_ending_double_timestamps: when the segment ended with two timestamp tokens, whether to ignore the last timestamp token
+    see https://github.com/huggingface/transformers/pull/35750
+
+    _pad_to_max_length is used in different contexts:
+    1. At the end of generation: we need to keep both ending timestamp tokens in the segment (see https://github.com/huggingface/transformers/pull/34537).
+    2. In the middle of generation, e.g. when condition_on_prev_tokens=True and we want to use the last generated tokens as decoder_input_ids:
+       we must skip one of the double ending timestamp tokens (see https://github.com/huggingface/transformers/pull/35750).
+    """
     max_total_length = 0
     sequences = []
     token_timestamps_list = []
@@ -166,7 +177,17 @@ def _pad_to_max_length(
     for current_segment_list in current_segments:
         if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
-            sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+            sequences_list = []
+            for d in current_segment_list:
+                if skip_ending_double_timestamps and len(d["tokens"]) > 2 and d["tokens"][-2] >= timestamp_begin:
+                    # the segment finishes with two timestamp tokens
+                    # we need to ignore the last timestamp token
+                    # see https://github.com/huggingface/transformers/pull/34537
+                    sequences_list.append(d["tokens"][:-1])
+                else:
+                    sequences_list.append(d["tokens"])
+            sequence = torch.cat(sequences_list, dim=-1)

             if return_token_timestamps:
                 token_timestamps = torch.cat(
                     [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
@@ -1809,14 +1830,6 @@ class WhisperGenerationMixin(GenerationMixin):
            # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
            active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]

-           for segments in active_segments:
-               for seg in segments:
-                   if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
-                       # the segment finishes with two timestamp tokens
-                       # we need to ignore the last timestamp token
-                       # see https://github.com/huggingface/transformers/pull/34537
-                       seg["tokens"] = seg["tokens"][:-1]
-
            if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
                prev_ids = prompt_ids
            else:
@@ -1833,6 +1846,8 @@ class WhisperGenerationMixin(GenerationMixin):
                padding=padding,
                bos_token_tensor=prev_ids,
                cut_off_length=cut_off_length,
+               skip_ending_double_timestamps=True,
+               timestamp_begin=timestamp_begin,
            )
            decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
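
For illustration, here is a standalone sketch of the rule that the new `skip_ending_double_timestamps` path applies inside `_pad_to_max_length`; the helper name and the example token ids are made up, with 50365 standing in for `timestamp_begin`. Unlike the block removed above, the trimming now happens on a local copy of the tokens instead of rewriting `seg["tokens"]` in the stored segments, which is the "incorrect in-place token change" mentioned in the commit message.

import torch

def _trim_double_ending_timestamp(tokens: torch.Tensor, timestamp_begin: int) -> torch.Tensor:
    # Mirrors the condition above: if the second-to-last token is already a
    # timestamp token, the segment ends with a timestamp pair and the last
    # token is dropped before the sequence is reused as previous context.
    if len(tokens) > 2 and tokens[-2] >= timestamp_begin:
        return tokens[:-1]
    return tokens

timestamp_begin = 50365  # assumed value, for illustration
segment = torch.tensor([50365, 2221, 13, 50629, 50630])  # <|0.00|> text ... <|5.28|><|5.30|>
print(_trim_double_ending_timestamp(segment, timestamp_begin))
# tensor([50365, 2221, 13, 50629]) -> only one ending timestamp is kept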


@@ -910,7 +910,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
        return token_ids


-def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
+def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500):
    """
    Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
    the various options not allowed in other seq2seq models
@@ -962,6 +962,12 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
        last_timestamp = None
        first_timestamp = timestamp_begin

+       # long form generation: we need to handle the case where the call to generate returns concatenated segments,
+       # with underlying multiple calls to generate
+       cur_max_timestamp = 0.0
+       prev_segments_len = 0.0
+       penultimate_timestamp = 0.0
+
        if "stride" in output:
            chunk_len, stride_left, stride_right = output["stride"]
            # Offset the timings to account for the other `model_outputs`.
@@ -1024,7 +1030,24 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
                    pass
            elif token >= timestamp_begin:
                # 3/ Timestamp token
-               time = (token - timestamp_begin) * time_precision + time_offset
+               timestamp = float((token - timestamp_begin) * time_precision)
+               if timestamp < cur_max_timestamp:
+                   # next segment has started
+                   last_was_single_ending = i >= 2 and not (
+                       token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+                   )
+                   if last_was_single_ending:
+                       prev_segments_len += time_precision * segment_size
+                   else:
+                       cur_max_timestamp = penultimate_timestamp
+                       prev_segments_len += penultimate_timestamp
+
+               penultimate_timestamp = cur_max_timestamp
+               cur_max_timestamp = timestamp
+
+               time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len
                time = round(time, 2)
                if last_timestamp and token >= last_timestamp:
                    # Whisper outputted a timestamp token, but it falls within
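
To see what the new bookkeeping buys, here is a simplified numeric sketch with invented timestamp values (it ignores the single- versus double-ending-timestamp distinction handled above): whenever a decoded timestamp is lower than the running maximum, a new 30-second window has started, and every later timestamp is shifted by the accumulated length of the previous windows.

# Simplified sketch of the offset logic, with assumed values.
time_precision = 0.02   # seconds per timestamp token
segment_size = 1500     # feature frames per 30 s window (the new default argument)

decoded = [0.0, 5.2, 29.98, 0.0, 3.4]  # hypothetical within-window timestamps

cur_max_timestamp = 0.0
prev_segments_len = 0.0
absolute = []
for timestamp in decoded:
    if timestamp < cur_max_timestamp:
        # the timestamps wrapped around: a new segment started, so shift by
        # one full window (the single-ending-timestamp case above)
        prev_segments_len += time_precision * segment_size  # += 30.0 s
    cur_max_timestamp = timestamp
    absolute.append(round(timestamp + prev_segments_len, 2))

print(absolute)  # [0.0, 5.2, 29.98, 30.0, 33.4]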


@@ -283,13 +283,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
        # No parameters on this pipeline right now
        preprocess_params = {}
        if chunk_length_s is not None:
-           if self.type == "seq2seq" and not ignore_warning:
-               logger.warning(
-                   "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
-                   " be entirely accurate and will have caveats. More information:"
-                   " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
-                   " ignore_warning=True)"
-               )
+           if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
+               type_warning = (
+                   "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
+                   " be entirely accurate and will have caveats. More information:"
+                   " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
+                   " ignore_warning=True)."
+               )
+               if self.type == "seq2seq_whisper":
+                   type_warning += (
+                       " To use Whisper for long-form transcription, rather use the model's `generate` method directly, "
+                       "as the model relies on its own chunking mechanism (cf. Whisper original paper, section 3.8. "
+                       "Long-form Transcription)."
+                   )
+               logger.warning(type_warning)
            preprocess_params["chunk_length_s"] = chunk_length_s
        if stride_length_s is not None:
            preprocess_params["stride_length_s"] = stride_length_s


@@ -2031,11 +2031,13 @@ class WhisperModelIntegrationTests(unittest.TestCase):
        ).input_features
        input_features = input_features.to(torch_device)

-       generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
+       generated_ids = model.generate(
+           input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True
+       ).to("cpu")

        # fmt: off
        EXPECTED_OUTPUT = torch.tensor([
-           50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430
+           [50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50431]
        ])
        # fmt: on
        torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
@@ -2078,7 +2080,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
                },
                {
                    "text": (" and can discover"),
-                   "timestamp": (28.68, 29.98),
+                   "timestamp": (28.68, 30.0),
                },
            ],
        }