[whisper] alternative fix for long-form timestamps (#32131)

* [whisper] alternative fix for long-form timestamps

* update test
This commit is contained in:
Sanchit Gandhi 2024-09-06 11:57:08 +01:00 committed by GitHub
parent 2b789f27f3
commit 51d15eb1c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 81 additions and 4 deletions

View File

@ -587,11 +587,20 @@ class WhisperTokenizer(PreTrainedTokenizer):
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
last_slice = np.where(timestamp_tokens)[0][0]
cur_max_timestamp = 0
prev_segments_len = 0
for current_slice in consecutive:
sliced_tokens = token_ids[last_slice:current_slice]
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
if start_timestamp_position < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp
cur_max_timestamp = end_timestamp_position
# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
@ -600,8 +609,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
(start_timestamp_position + prev_segments_len) * time_precision,
(end_timestamp_position + prev_segments_len) * time_precision,
),
}
)

View File

@ -229,11 +229,20 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
last_slice = np.where(timestamp_tokens)[0][0]
cur_max_timestamp = 0
prev_segments_len = 0
for current_slice in consecutive:
sliced_tokens = token_ids[last_slice:current_slice]
if len(sliced_tokens) > 1:
start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
if start_timestamp_position < cur_max_timestamp:
# next segment has started
prev_segments_len += cur_max_timestamp
cur_max_timestamp = end_timestamp_position
# strip timestamp tokens from the text output
sliced_tokens = self._preprocess_token_ids(sliced_tokens)
text = self._decode(sliced_tokens)
@ -242,8 +251,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
{
"text": text,
"timestamp": (
start_timestamp_position * time_precision,
end_timestamp_position * time_precision,
(start_timestamp_position + prev_segments_len) * time_precision,
(end_timestamp_position + prev_segments_len) * time_precision,
),
}
)

View File

@ -2099,6 +2099,65 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_longform_timestamps_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
input_features = processor(
sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
)
input_features = input_features.to(torch_device)
generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True)
EXPECTED_TRANSCRIPT = [
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"timestamp": (0.0, 6.5600000000000005),
},
{
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
"timestamp": (6.5600000000000005, 11.24),
},
{
"text": " He tells us that at this festive season of the year, with Christmas and roast beef looming",
"timestamp": (11.24, 16.88),
},
{
"text": " before us, similarly drawn from eating and its results occur most readily to the mind.",
"timestamp": (16.88, 23.76),
},
{
"text": " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and",
"timestamp": (23.76, 29.44),
},
{"text": " can discover in it but little of rocky ithaka.", "timestamp": (29.44, 33.72)},
{
"text": " Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite itals",
"timestamp": (33.72, 40.32),
},
{"text": " are as national as a jingo poem.", "timestamp": (40.32, 44.72)},
{
"text": " Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used",
"timestamp": (44.72, 50.4),
},
{"text": " to flash his teeth.", "timestamp": (50.4, 52.96)},
{
"text": " And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like",
"timestamp": (52.96, 58.68),
},
{"text": " a shampoo and a Turkish bath next man.", "timestamp": (58.68, 61.96)},
]
transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
@slow
def test_large_timestamp_generation(self):
set_seed(0)