mirror of https://github.com/huggingface/transformers.git
[whisper] alternative fix for long-form timestamps (#32131)
* [whisper] alternative fix for long-form timestamps
* update test
This commit is contained in:
parent 2b789f27f3
commit 51d15eb1c1
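In long-form (>30 s) transcription the tokenizer previously reported each segment's offsets relative to its own 30-second decoding window, so timestamps appeared to reset to zero at every window boundary. The change below keeps a running `prev_segments_len` and adds it to every offset, so a segment that a later window stamps locally as, say, (0.0, 4.28) is reported as (29.44, 33.72) when the previous windows covered 29.44 seconds (see the sketch after the first diff).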
@@ -587,11 +587,20 @@ class WhisperTokenizer(PreTrainedTokenizer):
             consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
 
         last_slice = np.where(timestamp_tokens)[0][0]
+        cur_max_timestamp = 0
+        prev_segments_len = 0
         for current_slice in consecutive:
             sliced_tokens = token_ids[last_slice:current_slice]
             if len(sliced_tokens) > 1:
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+
+                if start_timestamp_position < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = end_timestamp_position
+
                 # strip timestamp tokens from the text output
                 sliced_tokens = self._preprocess_token_ids(sliced_tokens)
                 text = self._decode(sliced_tokens)
@@ -600,8 +609,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
                     {
                         "text": text,
                         "timestamp": (
-                            start_timestamp_position * time_precision,
-                            end_timestamp_position * time_precision,
+                            (start_timestamp_position + prev_segments_len) * time_precision,
+                            (end_timestamp_position + prev_segments_len) * time_precision,
                         ),
                     }
                 )
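The key change is the `cur_max_timestamp` / `prev_segments_len` bookkeeping: when a segment's start position is lower than the highest timestamp position seen so far, decoding has rolled over into a new 30-second window, so the length of everything decoded before it is added to all subsequent offsets. A minimal, self-contained sketch of that idea (not the library code itself; the positions below are illustrative values in units of `time_precision`, chosen to mirror the expected offsets in the new test):

# Minimal sketch of the bookkeeping added above (not the library implementation).
# Segment positions are in units of time_precision (0.02 s by default).
def accumulate_offsets(segment_positions, time_precision=0.02):
    cur_max_timestamp = 0
    prev_segments_len = 0
    offsets = []
    for start, end in segment_positions:
        if start < cur_max_timestamp:
            # start position went backwards: a new 30-second window has begun,
            # so carry over the length of the windows decoded before it
            prev_segments_len += cur_max_timestamp
        cur_max_timestamp = end
        offsets.append(
            ((start + prev_segments_len) * time_precision, (end + prev_segments_len) * time_precision)
        )
    return offsets

# First window ends at position 1472 (29.44 s); the next window restarts at 0.
print(accumulate_offsets([(0, 328), (328, 562), (1188, 1472), (0, 214)]))
# -> roughly [(0.0, 6.56), (6.56, 11.24), (23.76, 29.44), (29.44, 33.72)]
# Without prev_segments_len the last segment would have been reported as (0.0, 4.28) again.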
@@ -229,11 +229,20 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
 
         last_slice = np.where(timestamp_tokens)[0][0]
+        cur_max_timestamp = 0
+        prev_segments_len = 0
         for current_slice in consecutive:
             sliced_tokens = token_ids[last_slice:current_slice]
             if len(sliced_tokens) > 1:
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+
+                if start_timestamp_position < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = end_timestamp_position
+
                 # strip timestamp tokens from the text output
                 sliced_tokens = self._preprocess_token_ids(sliced_tokens)
                 text = self._decode(sliced_tokens)
@@ -242,8 +251,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                     {
                         "text": text,
                         "timestamp": (
-                            start_timestamp_position * time_precision,
-                            end_timestamp_position * time_precision,
+                            (start_timestamp_position + prev_segments_len) * time_precision,
+                            (end_timestamp_position + prev_segments_len) * time_precision,
                        ),
                     }
                 )
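The fast tokenizer receives the identical change, so `WhisperTokenizer` and `WhisperTokenizerFast` compute the same long-form offsets. As a worked example with the default `time_precision` of 0.02 s: a segment ending at the timestamp token for 6.56 s sits at position 328, and `328 * 0.02` evaluates to `6.5600000000000005` in double precision, which is why the expected values in the new test below carry the extra floating-point digits.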
@@ -2099,6 +2099,65 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
 
+    @slow
+    def test_tiny_longform_timestamps_generation(self):
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+
+        dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
+        sample = dataset[0]["audio"]
+
+        input_features = processor(
+            sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
+        )
+        input_features = input_features.to(torch_device)
+
+        generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True)
+
+        EXPECTED_TRANSCRIPT = [
+            {
+                "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+                "timestamp": (0.0, 6.5600000000000005),
+            },
+            {
+                "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
+                "timestamp": (6.5600000000000005, 11.24),
+            },
+            {
+                "text": " He tells us that at this festive season of the year, with Christmas and roast beef looming",
+                "timestamp": (11.24, 16.88),
+            },
+            {
+                "text": " before us, similarly drawn from eating and its results occur most readily to the mind.",
+                "timestamp": (16.88, 23.76),
+            },
+            {
+                "text": " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and",
+                "timestamp": (23.76, 29.44),
+            },
+            {"text": " can discover in it but little of rocky ithaka.", "timestamp": (29.44, 33.72)},
+            {
+                "text": " Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite itals",
+                "timestamp": (33.72, 40.32),
+            },
+            {"text": " are as national as a jingo poem.", "timestamp": (40.32, 44.72)},
+            {
+                "text": " Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used",
+                "timestamp": (44.72, 50.4),
+            },
+            {"text": " to flash his teeth.", "timestamp": (50.4, 52.96)},
+            {
+                "text": " And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like",
+                "timestamp": (52.96, 58.68),
+            },
+            {"text": " a shampoo and a Turkish bath next man.", "timestamp": (58.68, 61.96)},
+        ]
+
+        transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
+        self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
+
     @slow
     def test_large_timestamp_generation(self):
         set_seed(0)
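For reference, the end-to-end flow exercised by the new test looks roughly like the following when used directly (a minimal usage sketch based on the test above; the model and the roughly one-minute long-form audio sample are just the ones the test uses):

from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# About one minute of audio, i.e. longer than a single 30-second Whisper window
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

inputs = processor(
    sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
)

# return_segments=True makes generate() return a dict whose "sequences" hold the stitched tokens
generated = model.generate(**inputs, return_timestamps=True, return_segments=True)

decoded = processor.batch_decode(generated["sequences"], skip_special_tokens=True, output_offsets=True)
for offset in decoded[0]["offsets"]:
    # with this fix the timestamps keep increasing past the 30-second mark
    print(offset["timestamp"], offset["text"])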