mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Whisper Tokenizer] Skip special tokens when decoding with timestamps (#23945)
This commit is contained in:
parent
8940d315aa
commit
c9cf337772
@ -491,7 +491,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
||||
return normalizer(text)
|
||||
|
||||
def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
|
||||
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
@ -505,7 +505,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
|
||||
outputs = [
|
||||
s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
|
||||
]
|
||||
return "".join(outputs)
|
||||
|
||||
def _compute_offsets(self, token_ids, time_precision=0.02):
|
||||
@ -593,7 +595,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
**kwargs,
|
||||
)
|
||||
if decode_with_timestamps:
|
||||
text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
|
||||
text = self._decode_with_timestamps(
|
||||
token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
# retrieve offsets
|
||||
if output_offsets:
|
||||
offsets = None
|
||||
|
@ -199,7 +199,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
return super()._encode_plus(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
|
||||
def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
|
||||
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||
given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
@ -213,7 +213,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
|
||||
outputs = [
|
||||
s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
|
||||
]
|
||||
return "".join(outputs)
|
||||
|
||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
|
||||
@ -303,7 +305,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
**kwargs,
|
||||
)
|
||||
if decode_with_timestamps:
|
||||
text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
|
||||
text = self._decode_with_timestamps(
|
||||
token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
# retrieve offsets
|
||||
if output_offsets:
|
||||
offsets = None
|
||||
|
@ -213,6 +213,38 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
|
||||
)
|
||||
|
||||
def test_skip_special_tokens_with_timestamps(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
|
||||
# fmt: off
|
||||
encoded_input = [
|
||||
50258, 50363, 50364, 634, 575, 12525, 22618, 1968, 6144,
|
||||
35617, 20084, 1756, 311, 589, 307, 534, 10281, 934,
|
||||
439, 293, 50676, 50676, 393, 4411, 294, 309, 457,
|
||||
707, 295, 33301, 286, 392, 6628, 13, 50836, 50257,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
expected_with_special_tokens = "<|startoftranscript|><|notimestamps|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>"
|
||||
expected_without_special_tokens = "<|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|>"
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
|
||||
expected_with_special_tokens,
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
|
||||
expected_without_special_tokens,
|
||||
)
|
||||
self.assertEqual(
|
||||
rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
|
||||
expected_with_special_tokens,
|
||||
)
|
||||
self.assertEqual(
|
||||
rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
|
||||
expected_without_special_tokens,
|
||||
)
|
||||
|
||||
def test_fast_tokenizer_get_prompt_ids(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
|
Loading…
Reference in New Issue
Block a user