Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02)
[Whisper] Pipeline: handle long form generation (#35750)
* handle long form generation
* add warning
* correct incorrect in-place token change
* update test to catch edge case
* make style
* update warning
* add doc
This commit is contained in:
parent 02ecdcfc0f
commit cfff7ca9a2
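This PR makes the automatic-speech-recognition pipeline aware of Whisper's sequential long-form generation: `_pad_to_max_length` learns to drop a duplicated ending timestamp token when building `decoder_input_ids` for `condition_on_prev_tokens`, `_decode_asr` learns to offset timestamps across the concatenated segments returned by a single `generate` call, and the pipeline's `chunk_length_s` warning now also covers Whisper. A minimal sketch of the pipeline usage this targets, assuming the `openai/whisper-tiny` checkpoint and a local `long_audio.wav` file (both illustrative, not part of the PR):

```python
from transformers import pipeline

# Hypothetical checkpoint and audio path, chosen only for illustration.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Without chunk_length_s, audio longer than 30 s goes through Whisper's own
# sequential long-form generation, which returns several concatenated segments
# from one generate() call; this PR teaches the pipeline to decode them correctly.
result = asr("long_audio.wav", return_timestamps=True)
print(result["text"])
print(result["chunks"])  # segment-level timestamps
```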
@@ -136,7 +136,18 @@ def _pad_to_max_length(
     cut_off_length=None,
     return_token_timestamps=False,
     force_unique_generate_call=False,
+    skip_ending_double_timestamps=False,
+    timestamp_begin=None,
 ):
+    """
+    skip_ending_double_timestamps: when the segment ended with two timestamp tokens, whether to ignore the last timestamp token
+    see https://github.com/huggingface/transformers/pull/35750
+
+    _pad_to_max_length is used in different contexts:
+    1. At the end of generation: we need to keep both ending timestamp tokens in the segment (see https://github.com/huggingface/transformers/pull/34537).
+    2. In the middle of generation, e.g. when condition_on_prev_tokens=True and we want to use the last generated tokens as decoder_input_ids:
+       we must skip one of the double ending timestamp tokens (see https://github.com/huggingface/transformers/pull/35750).
+    """
     max_total_length = 0
     sequences = []
     token_timestamps_list = []
@@ -166,7 +177,17 @@ def _pad_to_max_length(

     for current_segment_list in current_segments:
         if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
-            sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
+            sequences_list = []
+            for d in current_segment_list:
+                if skip_ending_double_timestamps and len(d["tokens"]) > 2 and d["tokens"][-2] >= timestamp_begin:
+                    # the segment finishes with two timestamp tokens
+                    # we need to ignore the last timestamp token
+                    # see https://github.com/huggingface/transformers/pull/34537
+                    sequences_list.append(d["tokens"][:-1])
+                else:
+                    sequences_list.append(d["tokens"])
+            sequence = torch.cat(sequences_list, dim=-1)
+
             if return_token_timestamps:
                 token_timestamps = torch.cat(
                     [d["result"]["token_timestamps"][d["idxs"][0] : d["idxs"][1]] for d in current_segment_list],
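The new keyword arguments are easiest to see on a toy input. Below is a standalone sketch of the trimming loop added above; the `timestamp_begin` value and the segment tensors are invented for illustration and are not taken from the PR:

```python
import torch

# Illustrative value; in Whisper, timestamp_begin is the id of the first timestamp token.
timestamp_begin = 100
current_segment_list = [
    {"tokens": torch.tensor([1, 2, 3, 104, 105])},  # ends with two timestamp tokens
    {"tokens": torch.tensor([4, 5, 6, 110])},       # ends with a single timestamp token
]

skip_ending_double_timestamps = True
sequences_list = []
for d in current_segment_list:
    if skip_ending_double_timestamps and len(d["tokens"]) > 2 and d["tokens"][-2] >= timestamp_begin:
        # segment ends with two timestamp tokens: drop the trailing one
        sequences_list.append(d["tokens"][:-1])
    else:
        sequences_list.append(d["tokens"])

sequence = torch.cat(sequences_list, dim=-1)
print(sequence)  # tensor([1, 2, 3, 104, 4, 5, 6, 110])
```

When the same helper runs at the very end of generation (`skip_ending_double_timestamps=False`), both ending timestamps are kept, matching context 1 of the docstring.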
@@ -1809,14 +1830,6 @@ class WhisperGenerationMixin(GenerationMixin):
             # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
             active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]

-            for segments in active_segments:
-                for seg in segments:
-                    if len(seg["tokens"]) > 2 and seg["tokens"][-2] >= timestamp_begin:
-                        # the segment finishes with two timestamp tokens
-                        # we need to ignore the last timestamp token
-                        # see https://github.com/huggingface/transformers/pull/34537
-                        seg["tokens"] = seg["tokens"][:-1]
-
             if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
                 prev_ids = prompt_ids
             else:
@@ -1833,6 +1846,8 @@ class WhisperGenerationMixin(GenerationMixin):
                 padding=padding,
                 bos_token_tensor=prev_ids,
                 cut_off_length=cut_off_length,
+                skip_ending_double_timestamps=True,
+                timestamp_begin=timestamp_begin,
             )
             decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)

@@ -910,7 +910,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         return token_ids


-def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
+def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500):
     """
     Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
     the various options not allowed in other seq2seq models
@@ -962,6 +962,12 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
         last_timestamp = None
         first_timestamp = timestamp_begin

+        # long form generation: we need to handle the case where the call to generate returns concatenated segments,
+        # with underlying multiple calls to generate
+        cur_max_timestamp = 0.0
+        prev_segments_len = 0.0
+        penultimate_timestamp = 0.0
+
         if "stride" in output:
             chunk_len, stride_left, stride_right = output["stride"]
             # Offset the timings to account for the other `model_outputs`.
@@ -1024,7 +1030,24 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
                     pass
             elif token >= timestamp_begin:
                 # 3/ Timestamp token
-                time = (token - timestamp_begin) * time_precision + time_offset
+                timestamp = float((token - timestamp_begin) * time_precision)
+                if timestamp < cur_max_timestamp:
+                    # next segment has started
+                    last_was_single_ending = i >= 2 and not (
+                        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+                    )
+                    if last_was_single_ending:
+                        prev_segments_len += time_precision * segment_size
+                    else:
+                        cur_max_timestamp = penultimate_timestamp
+                        prev_segments_len += penultimate_timestamp
+
+                penultimate_timestamp = cur_max_timestamp
+                cur_max_timestamp = timestamp
+
+                time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len
+
                 time = round(time, 2)
                 if last_timestamp and token >= last_timestamp:
                     # Whisper outputted a timestamp token, but it falls within
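In words, `_decode_asr` now tracks where each 30-second window of a single long-form `generate` call ends, so timestamps that restart near 0.00 in the next window can still be reported as absolute times. The sketch below restates that bookkeeping on plain floats; `offset_timestamps` and its `double_ending_flags` argument are invented for illustration (the real code derives the same information from the token ids):

```python
time_precision = 0.02   # seconds per timestamp step
segment_size = 1500     # steps per 30-second window

def offset_timestamps(rel_times, double_ending_flags):
    """rel_times: timestamps (seconds) relative to their own window, in order.
    double_ending_flags[i]: whether the two tokens right before this one were both
    timestamp tokens, i.e. the previous window closed with a double timestamp."""
    cur_max_timestamp = 0.0
    penultimate_timestamp = 0.0
    prev_segments_len = 0.0
    absolute = []
    for i, timestamp in enumerate(rel_times):
        if timestamp < cur_max_timestamp:
            # relative time went backwards: a new window has started
            if not double_ending_flags[i]:
                prev_segments_len += time_precision * segment_size  # assume a full 30 s window
            else:
                cur_max_timestamp = penultimate_timestamp
                prev_segments_len += penultimate_timestamp
        penultimate_timestamp = cur_max_timestamp
        cur_max_timestamp = timestamp
        absolute.append(round(timestamp + prev_segments_len, 2))
    return absolute

# A first window ending at 29.98 s with a single timestamp, then a second window:
print(offset_timestamps([0.0, 29.98, 0.0, 5.0], [False, False, False, False]))
# -> [0.0, 29.98, 30.0, 35.0]
```

A single ending timestamp means the previous window is assumed to cover the full `segment_size * time_precision = 30` seconds; a double ending timestamp means its true length is the last timestamp itself, which is why the `else` branch rolls back to `penultimate_timestamp`.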
@@ -283,13 +283,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         # No parameters on this pipeline right now
         preprocess_params = {}
         if chunk_length_s is not None:
-            if self.type == "seq2seq" and not ignore_warning:
-                logger.warning(
+            if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
+                type_warning = (
                     "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
                     " be entirely accurate and will have caveats. More information:"
                     " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
-                    " ignore_warning=True)"
+                    " ignore_warning=True)."
                 )
+                if self.type == "seq2seq_whisper":
+                    type_warning += (
+                        " To use Whisper for long-form transcription, rather use the model's `generate` method directly "
+                        "as the model relies on its own chunking mechanism (cf. Whisper original paper, section 3.8. "
+                        "Long-form Transcription)."
+                    )
+                logger.warning(type_warning)
             preprocess_params["chunk_length_s"] = chunk_length_s
         if stride_length_s is not None:
             preprocess_params["stride_length_s"] = stride_length_s
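The extended warning points users to Whisper's built-in long-form decoding instead of the pipeline's chunking. A minimal sketch of that recommended path, assuming the `openai/whisper-tiny` checkpoint and a silent placeholder waveform standing in for real audio:

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Placeholder: 60 seconds of silence at 16 kHz; substitute a real waveform.
raw_audio = np.zeros(16_000 * 60, dtype=np.float32)

# Keep the full recording instead of truncating to 30 s.
inputs = processor(
    raw_audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)

# Sequential long-form generation (Whisper paper, section 3.8): the model chunks
# the audio itself and can condition each window on the previous one.
generated_ids = model.generate(
    **inputs,
    return_timestamps=True,
    condition_on_prev_tokens=True,
)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
```

It is exactly the concatenated output of such a call that the `_decode_asr` changes above learn to split back into correctly offset timestamps.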
@@ -2031,11 +2031,13 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         ).input_features
         input_features = input_features.to(torch_device)

-        generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
+        generated_ids = model.generate(
+            input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True
+        ).to("cpu")

         # fmt: off
         EXPECTED_OUTPUT = torch.tensor([
-            50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430
+            [50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50431]
         ])
         # fmt: on
         torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
@@ -2078,7 +2080,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
                     },
                     {
                         "text": (" and can discover"),
-                        "timestamp": (28.68, 29.98),
+                        "timestamp": (28.68, 30.0),
                     },
                 ],
             }