mirror of https://github.com/huggingface/transformers.git
Whisper tokenizer word level timestamps (#32197)
* fix _fix_key in PreTrainedModel
* fix _find_longest_common_sequence
* add test
* remove result.json
* nit
* update test
parent 7ffe25f2b9
commit 3fbaaaa64d
@@ -1174,7 +1174,22 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
                     "There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
                 )

-            matches = np.sum(left == right)
+            if token_timestamp_sequences:
+                # Get length of longest subsequence of tokens that match
+                # and have timestamps that are in order
+                matches = sum(
+                    1
+                    for idx, elem in enumerate(left)
+                    if (
+                        elem == right[idx]
+                        and left_token_timestamp_sequence[left_start + idx]
+                        <= token_timestamp_sequences[seq_idx + 1][right_start + idx]
+                    )
+                )
+            else:
+                matches = np.sum(left == right)

             matching = matches / i + eps
             if matches > 1 and matching > max_:
                 max_ = matching
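For context on the change above: a token pair in the overlap of two chunks now only counts as a match when its timestamps are also consistent across the chunk boundary, which disambiguates repeated words. Below is a minimal standalone sketch of that rule, reusing the "Just do it" timestamps from the test further down; the helper name count_ordered_matches and the toy slices are illustrative, not part of the library.

import numpy as np

def count_ordered_matches(left, right, left_ts, right_ts):
    # Count positions where the tokens agree AND the first chunk's
    # timestamp does not exceed the second chunk's timestamp at the
    # same offset (the ordering condition added in this commit).
    return sum(
        1
        for idx, elem in enumerate(left)
        if elem == right[idx] and left_ts[idx] <= right_ts[idx]
    )

# "Just do it" appears in both chunks; the tokens alone match 3/3,
# but the last timestamp pair (7.82 > 7.8) is out of order, so only
# 2 positions count as timestamp-consistent matches.
left = np.array([2329, 466, 340])
right = np.array([2329, 466, 340])
left_ts = np.array([7.12, 7.54, 7.82])
right_ts = np.array([7.12, 7.56, 7.8])

print(count_ordered_matches(left, right, left_ts, right_ts))  # 2

A plain np.sum(left == right) would score this overlap 3, the same as a perfectly aligned repeat elsewhere; weighting matches by timestamp order is what lets the merge pick the correct alignment.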
@@ -338,6 +338,42 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         )
         self.assertEqual(decoded_output_diacritics, expected_output_diacritics)

+    def test_decode_asr_with_word_level_timestamps(self):
+        # fmt: off
+        model_outputs = [
+            {
+                'stride': [10, 0, 5],
+                'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]])
+            },
+            {
+                'stride': [10, 5, 0],
+                'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]])
+            }
+        ]
+        # fmt: on
+
+        tokenizer = WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped")
+        result = tokenizer._decode_asr(
+            model_outputs, return_timestamps="word", return_language=False, time_precision=0.02
+        )
+
+        EXPECTED_OUTPUT = (
+            " Yes, you can! Just do it",
+            {
+                "chunks": [
+                    {"text": " Yes,", "timestamp": (5.18, 5.56)},
+                    {"text": " you", "timestamp": (5.56, 5.84)},
+                    {"text": " can!", "timestamp": (5.84, 7.12)},
+                    {"text": " Just", "timestamp": (7.12, 7.56)},
+                    {"text": " do", "timestamp": (7.56, 7.8)},
+                    {"text": " it", "timestamp": (7.8, 8.72)},
+                ]
+            },
+        )
+        self.assertEqual(result, EXPECTED_OUTPUT)
+

 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"
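As a usage note, the _decode_asr path exercised by this test is normally reached through the high-level ASR pipeline rather than called directly. A hedged sketch of that route follows; the checkpoint and audio file here are placeholders, not taken from this commit.

from transformers import pipeline

# Word-level timestamps via the pipeline; chunking with a stride
# produces the overlapping model_outputs that _decode_asr merges.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    chunk_length_s=30,
)
out = asr("sample.wav", return_timestamps="word")
print(out["text"])
for chunk in out["chunks"]:
    print(chunk["text"], chunk["timestamp"])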