mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[whisper] move processor test into processor test file 🧹 (#38266)
move processor tests
This commit is contained in:
parent
b26157d64c
commit
aa02a5d902
@ -79,26 +79,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
|
||||
break
|
||||
|
||||
|
||||
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
|
||||
seq_len_left = len(sequence_left)
|
||||
seq_len_right = len(sequence_right)
|
||||
counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)]
|
||||
longest = 0
|
||||
for i in range(seq_len_left):
|
||||
for j in range(seq_len_right):
|
||||
if sequence_left[i] == sequence_right[j]:
|
||||
previous_counter = counter[i][j] + 1
|
||||
counter[i + 1][j + 1] = previous_counter
|
||||
if previous_counter > longest:
|
||||
longest = previous_counter
|
||||
|
||||
counter = np.array(counter)
|
||||
# we return the idx of the first element of the longest common sequence in the left sequence
|
||||
index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1
|
||||
index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1
|
||||
return index_left, index_right, longest
|
||||
|
||||
|
||||
def _find_longest_common_sequence(sequences, tokenizer):
|
||||
# TODO Use a faster algorithm this can probably be done in O(n)
|
||||
# using suffix array.
|
||||
@ -664,109 +644,3 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
for k, v in output.items():
|
||||
extra[k].append(v)
|
||||
return {"text": text, **optional, **extra}
|
||||
|
||||
|
||||
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
|
||||
"""
|
||||
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
|
||||
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
|
||||
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
|
||||
processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
|
||||
properly compute the final `offset`.
|
||||
"""
|
||||
# index of the first timestamp token
|
||||
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
||||
items = []
|
||||
# approximation of the token to time ratio : ~0.2seconds
|
||||
time_precision = feature_extractor.chunk_length / max_source_positions
|
||||
time = 0
|
||||
for seq_idx, item in enumerate(sequences):
|
||||
sequence, stride = item
|
||||
if isinstance(sequence, list):
|
||||
sequence = np.array(sequence)
|
||||
chunk_len, stride_left, stride_right = stride
|
||||
sequence = sequence.squeeze(0)
|
||||
# get rid of the `forced_decoder_idx` that are use to parametrize the generation
|
||||
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
|
||||
sequence = sequence[begin_idx:]
|
||||
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
if seq_idx != 0 and sum(timestamp_tokens) > 0:
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
||||
time -= stride_left + stride_right
|
||||
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
||||
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
||||
# relevant timestamps are in the overlapping part
|
||||
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
|
||||
if relevant_timestamp.shape[0] > 0:
|
||||
relevant_timestamp = (
|
||||
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
|
||||
)
|
||||
# if a big stride is used, we need to check some of the previous items for the best overlap
|
||||
best_match = 0
|
||||
sliced_sequence = []
|
||||
for idx, previous_sequence in enumerate(reversed(items)):
|
||||
previous_tokens = previous_sequence[1:-1]
|
||||
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
|
||||
break # the previous sequence is too far in the past
|
||||
if len(previous_tokens) > 0:
|
||||
# find the longest common sequence between the overlapping parts
|
||||
index_left, index_right, match_length = _fast_find_longest_common_sequence(
|
||||
sequence[1:relevant_timestamp], previous_tokens
|
||||
)
|
||||
# don't do anything if only 1 token was matched
|
||||
if match_length > 1 and match_length > best_match:
|
||||
best_match = match_length
|
||||
best_idx = idx
|
||||
end_of_curr_sequence_idx = (
|
||||
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
|
||||
)
|
||||
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
|
||||
# if all the tokens are matched, suffix
|
||||
if index_left == 0 and match_length == len(previous_tokens):
|
||||
sliced_sequence = np.insert(
|
||||
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
|
||||
)
|
||||
sliced_sequence[-1] = previous_sequence[-1]
|
||||
# if part of the previous sequence is not taken
|
||||
elif index_left >= 0:
|
||||
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
|
||||
# let's insert the missing part of the previous sequence
|
||||
previous_slice = (
|
||||
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
|
||||
)
|
||||
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
|
||||
sliced_sequence[-1] += offset
|
||||
|
||||
if len(sliced_sequence) > 0:
|
||||
items[len(items) - best_idx - 1] = sliced_sequence
|
||||
items = items[: len(items) - best_idx]
|
||||
sequence = sequence[end_of_curr_sequence_idx:]
|
||||
|
||||
# sequence might have changed
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
if sum(timestamp_tokens) > 0:
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = (
|
||||
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
|
||||
)
|
||||
|
||||
if len(consecutive) > 0:
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
|
||||
sliced_tokens = sequence[last_slice:current_slice]
|
||||
duration = sliced_tokens[-1] - sliced_tokens[0]
|
||||
sliced_tokens[0] = actual_offset
|
||||
sliced_tokens[-1] = actual_offset + duration
|
||||
items.append(sliced_tokens)
|
||||
last_slice = current_slice
|
||||
|
||||
time += chunk_len
|
||||
result = []
|
||||
for i in range(len(items)):
|
||||
result += items[i].tolist()
|
||||
return result
|
||||
|
@ -16,6 +16,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import WhisperTokenizer, is_speech_available
|
||||
@ -177,3 +178,296 @@ class WhisperProcessorTest(unittest.TestCase):
|
||||
_test_prompt_error_raised_helper("<|startofprev|> test", "<|startofprev|>")
|
||||
_test_prompt_error_raised_helper("test <|notimestamps|>", "<|notimestamps|>")
|
||||
_test_prompt_error_raised_helper("test <|zh|> test <|transcribe|>", "<|zh|>")
|
||||
|
||||
def test_find_longest_common_subsequence_old(self):
|
||||
"""Test using the old processing functions used in the ASR pipeline, but that serves as a BC reference."""
|
||||
max_source_positions = 1500
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
|
||||
previous_sequence = [[51492, 406, 3163, 1953, 466, 13, 51612, 51612]]
|
||||
self.assertEqual(
|
||||
processor.decode(previous_sequence[0], output_offsets=True),
|
||||
{
|
||||
"text": " not worth thinking about.",
|
||||
"offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
|
||||
},
|
||||
)
|
||||
|
||||
# Merge when the previous sequence is a suffix of the next sequence
|
||||
# fmt: off
|
||||
next_sequences_1 = [
|
||||
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
|
||||
]
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
processor.decode(next_sequences_1[0], output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
|
||||
" small, sharp blow high on his chest.<|endoftext|>"
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (5.0, 9.4),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_1, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 51739, 51739, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
|
||||
)
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " not worth thinking about.", "timestamp": (22.56, 27.5)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (27.5, 31.900000000000002),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Merge when the sequence is in the middle of the 1st next sequence
|
||||
# fmt: off
|
||||
next_sequences_2 = [
|
||||
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
|
||||
]
|
||||
# fmt: on
|
||||
# {'text': ' of spectators, retrievality is not worth thinking about. His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_2, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
|
||||
)
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on"
|
||||
" his chest."
|
||||
),
|
||||
"timestamp": (22.56, 31.900000000000002),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Merge when the previous sequence is not included in the current sequence
|
||||
next_sequences_3 = [[50364, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50584, 50257]] # fmt: skip
|
||||
# {'text': ' His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51832],
|
||||
) # fmt: skip
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (24.96, 29.36),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
# last case is when the sequence is not in the first next predicted start and end of timestamp
|
||||
next_sequences_3 = [
|
||||
[50364, 2812, 9836, 14783, 390, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
|
||||
] # fmt: skip
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 167_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51912]
|
||||
) # fmt: skip
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (24.96, 30.96),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
|
||||
"""Old processing function used in the ASR pipeline."""
|
||||
seq_len_left = len(sequence_left)
|
||||
seq_len_right = len(sequence_right)
|
||||
counter = [[0] * (seq_len_right + 1) for _ in range(seq_len_left + 1)]
|
||||
longest = 0
|
||||
for i in range(seq_len_left):
|
||||
for j in range(seq_len_right):
|
||||
if sequence_left[i] == sequence_right[j]:
|
||||
previous_counter = counter[i][j] + 1
|
||||
counter[i + 1][j + 1] = previous_counter
|
||||
if previous_counter > longest:
|
||||
longest = previous_counter
|
||||
|
||||
counter = np.array(counter)
|
||||
# we return the idx of the first element of the longest common sequence in the left sequence
|
||||
index_left = np.argwhere(counter == longest)[-1][0] - longest if longest != 0 else -1
|
||||
index_right = np.argwhere(counter == longest)[-1][1] - longest if longest != 0 else -1
|
||||
return index_left, index_right, longest
|
||||
|
||||
|
||||
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
|
||||
"""
|
||||
Old processing function used in the ASR pipeline.
|
||||
|
||||
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
|
||||
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
|
||||
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
|
||||
processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
|
||||
properly compute the final `offset`.
|
||||
"""
|
||||
# index of the first timestamp token
|
||||
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
||||
items = []
|
||||
# approximation of the token to time ratio : ~0.2seconds
|
||||
time_precision = feature_extractor.chunk_length / max_source_positions
|
||||
time = 0
|
||||
for seq_idx, item in enumerate(sequences):
|
||||
sequence, stride = item
|
||||
if isinstance(sequence, list):
|
||||
sequence = np.array(sequence)
|
||||
chunk_len, stride_left, stride_right = stride
|
||||
sequence = sequence.squeeze(0)
|
||||
# get rid of the `forced_decoder_idx` that are use to parametrize the generation
|
||||
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
|
||||
sequence = sequence[begin_idx:]
|
||||
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
if seq_idx != 0 and sum(timestamp_tokens) > 0:
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
||||
time -= stride_left + stride_right
|
||||
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
||||
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
||||
# relevant timestamps are in the overlapping part
|
||||
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
|
||||
if relevant_timestamp.shape[0] > 0:
|
||||
relevant_timestamp = (
|
||||
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
|
||||
)
|
||||
# if a big stride is used, we need to check some of the previous items for the best overlap
|
||||
best_match = 0
|
||||
sliced_sequence = []
|
||||
for idx, previous_sequence in enumerate(reversed(items)):
|
||||
previous_tokens = previous_sequence[1:-1]
|
||||
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
|
||||
break # the previous sequence is too far in the past
|
||||
if len(previous_tokens) > 0:
|
||||
# find the longest common sequence between the overlapping parts
|
||||
index_left, index_right, match_length = _fast_find_longest_common_sequence(
|
||||
sequence[1:relevant_timestamp], previous_tokens
|
||||
)
|
||||
# don't do anything if only 1 token was matched
|
||||
if match_length > 1 and match_length > best_match:
|
||||
best_match = match_length
|
||||
best_idx = idx
|
||||
end_of_curr_sequence_idx = (
|
||||
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
|
||||
)
|
||||
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
|
||||
# if all the tokens are matched, suffix
|
||||
if index_left == 0 and match_length == len(previous_tokens):
|
||||
sliced_sequence = np.insert(
|
||||
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
|
||||
)
|
||||
sliced_sequence[-1] = previous_sequence[-1]
|
||||
# if part of the previous sequence is not taken
|
||||
elif index_left >= 0:
|
||||
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
|
||||
# let's insert the missing part of the previous sequence
|
||||
previous_slice = (
|
||||
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
|
||||
)
|
||||
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
|
||||
sliced_sequence[-1] += offset
|
||||
|
||||
if len(sliced_sequence) > 0:
|
||||
items[len(items) - best_idx - 1] = sliced_sequence
|
||||
items = items[: len(items) - best_idx]
|
||||
sequence = sequence[end_of_curr_sequence_idx:]
|
||||
|
||||
# sequence might have changed
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
if sum(timestamp_tokens) > 0:
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = (
|
||||
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
|
||||
)
|
||||
|
||||
if len(consecutive) > 0:
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
|
||||
sliced_tokens = sequence[last_slice:current_slice]
|
||||
duration = sliced_tokens[-1] - sliced_tokens[0]
|
||||
sliced_tokens[0] = actual_offset
|
||||
sliced_tokens[-1] = actual_offset + duration
|
||||
items.append(sliced_tokens)
|
||||
last_slice = current_slice
|
||||
|
||||
time += chunk_len
|
||||
result = []
|
||||
for i in range(len(items)):
|
||||
result += items[i].tolist()
|
||||
return result
|
||||
|
@ -33,7 +33,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||
from transformers.pipelines.audio_utils import chunk_bytes_iter, ffmpeg_microphone_live
|
||||
from transformers.pipelines.automatic_speech_recognition import _find_timestamp_sequence, chunk_iter
|
||||
from transformers.pipelines.automatic_speech_recognition import chunk_iter
|
||||
from transformers.testing_utils import (
|
||||
compare_pipeline_output_to_hub_spec,
|
||||
is_pipeline_test,
|
||||
@ -636,169 +636,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
output = speech_recognizer(ds["audio"], batch_size=2)
|
||||
self.assertEqual(output, EXPECTED_OUTPUT)
|
||||
|
||||
def test_find_longest_common_subsequence(self):
|
||||
max_source_positions = 1500
|
||||
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
|
||||
|
||||
previous_sequence = [[51492, 406, 3163, 1953, 466, 13, 51612, 51612]]
|
||||
self.assertEqual(
|
||||
processor.decode(previous_sequence[0], output_offsets=True),
|
||||
{
|
||||
"text": " not worth thinking about.",
|
||||
"offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
|
||||
},
|
||||
)
|
||||
|
||||
# Merge when the previous sequence is a suffix of the next sequence
|
||||
# fmt: off
|
||||
next_sequences_1 = [
|
||||
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
|
||||
]
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
processor.decode(next_sequences_1[0], output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
|
||||
" small, sharp blow high on his chest.<|endoftext|>"
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (5.0, 9.4),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_1, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 51739, 51739, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
|
||||
)
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " not worth thinking about.", "timestamp": (22.56, 27.5)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (27.5, 31.900000000000002),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Merge when the sequence is in the middle of the 1st next sequence
|
||||
# fmt: off
|
||||
next_sequences_2 = [
|
||||
[50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
|
||||
]
|
||||
# fmt: on
|
||||
# {'text': ' of spectators, retrievality is not worth thinking about. His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_2, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51959],
|
||||
)
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on"
|
||||
" his chest."
|
||||
),
|
||||
"timestamp": (22.56, 31.900000000000002),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Merge when the previous sequence is not included in the current sequence
|
||||
next_sequences_3 = [[50364, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50584, 50257]] # fmt: skip
|
||||
# {'text': ' His instant panic was followed by a small, sharp blow high on his chest.','timestamp': (0.0, 9.4)}
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 120_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51832],
|
||||
) # fmt: skip
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (24.96, 29.36),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
# last case is when the sequence is not in the first next predicted start and end of timestamp
|
||||
next_sequences_3 = [
|
||||
[50364, 2812, 9836, 14783, 390, 406, 3163, 1953, 466, 13, 50634, 50634, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50934]
|
||||
] # fmt: skip
|
||||
merge = _find_timestamp_sequence(
|
||||
[[previous_sequence, (480_000, 0, 0)], [next_sequences_3, (480_000, 167_000, 0)]],
|
||||
processor.tokenizer,
|
||||
processor.feature_extractor,
|
||||
max_source_positions,
|
||||
)
|
||||
self.assertEqual(
|
||||
merge,
|
||||
[51492, 406, 3163, 1953, 466, 13, 51612, 51612, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 51912]
|
||||
) # fmt: skip
|
||||
self.assertEqual(
|
||||
processor.decode(merge, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" not worth thinking about. His instant panic was followed by a small, sharp blow high on his"
|
||||
" chest."
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (24.96, 30.96),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@unittest.skip("TODO (joao, eustache): this test is failing, find the breaking PR and fix the cause or the test")
|
||||
|
Loading…
Reference in New Issue
Block a user