[Whisper Tokenizer] Skip special tokens when decoding with timestamps (#23945)

Author: Sanchit Gandhi · 2023-06-02 15:26:59 +01:00 · committed by GitHub
parent 8940d315aa
commit c9cf337772
3 changed files with 46 additions and 6 deletions

src/transformers/models/whisper/tokenization_whisper.py

@@ -491,7 +491,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)
 
-    def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
+    def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
         given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -505,7 +505,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
                 outputs.append([])
             else:
                 outputs[-1].append(token)
-        outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
+        outputs = [
+            s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
+        ]
         return "".join(outputs)
 
     def _compute_offsets(self, token_ids, time_precision=0.02):
@@ -593,7 +595,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
         if decode_with_timestamps:
-            text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
+            text = self._decode_with_timestamps(
+                token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
+            )
         # retrieve offsets
         if output_offsets:
             offsets = None
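
For reference, `_decode_with_timestamps` splits the token stream on timestamp tokens, decodes each text chunk separately, and now forwards `skip_special_tokens` to those inner `decode()` calls. Below is a minimal standalone sketch of that logic, not the library code itself: `decode_chunk` and `timestamp_begin` are hypothetical stand-ins for the tokenizer internals, and it assumes (as in Whisper) that timestamp ids start right after the last special-token id and map to `(token - timestamp_begin) * time_precision` seconds.

    # Sketch only: decode_chunk and timestamp_begin stand in for the tokenizer internals.
    def decode_with_timestamps(token_ids, decode_chunk, timestamp_begin, skip_special_tokens=False, time_precision=0.02):
        outputs = [[]]
        for token in token_ids:
            if token >= timestamp_begin:
                # Render the timestamp token as e.g. "<|1.08|>" and start a new text chunk.
                outputs.append(f"<|{(token - timestamp_begin) * time_precision:.2f}|>")
                outputs.append([])
            else:
                outputs[-1].append(token)
        # The fix in this commit: thread skip_special_tokens through to the per-chunk decode.
        return "".join(
            s if isinstance(s, str) else decode_chunk(s, skip_special_tokens=skip_special_tokens) for s in outputs
        )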

src/transformers/models/whisper/tokenization_whisper_fast.py

@@ -199,7 +199,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return super()._encode_plus(*args, **kwargs)
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
-    def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
+    def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
         given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -213,7 +213,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 outputs.append([])
             else:
                 outputs[-1].append(token)
-        outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
+        outputs = [
+            s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
+        ]
         return "".join(outputs)
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
@@ -303,7 +305,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
        )
         if decode_with_timestamps:
-            text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
+            text = self._decode_with_timestamps(
+                token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
+            )
         # retrieve offsets
         if output_offsets:
             offsets = None
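
The new test below pins down the user-facing behaviour. As a quick illustration of the call, here is a hedged usage sketch; "openai/whisper-tiny" is only an example checkpoint, and the ids and expected output are the ones used in the test.

    from transformers import WhisperTokenizer

    # Illustrative checkpoint; any multilingual Whisper tokenizer with the same vocab works.
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

    encoded_input = [
        50258, 50363, 50364, 634, 575, 12525, 22618, 1968, 6144,
        35617, 20084, 1756, 311, 589, 307, 534, 10281, 934,
        439, 293, 50676, 50676, 393, 4411, 294, 309, 457,
        707, 295, 33301, 286, 392, 6628, 13, 50836, 50257,
    ]

    # Timestamp markers such as "<|0.00|>" are kept, while special tokens like
    # "<|startoftranscript|>" and "<|endoftext|>" are now dropped.
    print(tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True))
    # -> "<|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|>"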

tests/models/whisper/test_tokenization_whisper.py

@@ -213,6 +213,38 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
             rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
         )
 
+    def test_skip_special_tokens_with_timestamps(self):
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer()
+
+        # fmt: off
+        encoded_input = [
+            50258, 50363, 50364, 634, 575, 12525, 22618, 1968, 6144,
+            35617, 20084, 1756, 311, 589, 307, 534, 10281, 934,
+            439, 293, 50676, 50676, 393, 4411, 294, 309, 457,
+            707, 295, 33301, 286, 392, 6628, 13, 50836, 50257,
+        ]
+        # fmt: on
+
+        expected_with_special_tokens = "<|startoftranscript|><|notimestamps|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>"
+        expected_without_special_tokens = "<|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|>"
+        self.assertEqual(
+            tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
+            expected_with_special_tokens,
+        )
+        self.assertEqual(
+            tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
+            expected_without_special_tokens,
+        )
+        self.assertEqual(
+            rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
+            expected_with_special_tokens,
+        )
+        self.assertEqual(
+            rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
+            expected_without_special_tokens,
+        )
+
     def test_fast_tokenizer_get_prompt_ids(self):
         tokenizer = self.get_tokenizer()
         rust_tokenizer = self.get_rust_tokenizer()