[Whisper Tokenizer] Make more user-friendly (#19921)

* [Whisper Tokenizer] Make more user-friendly

* use property

* make indexing rigorous

* small clean-up

* tests

* skip seq2seq tests

* remove multilingual arg

* reorder args

* collapse to one function

Co-authored-by: ArthurZucker <arthur@huggingface.co>

* option to override attributes

Co-authored-by: ArthurZucker <arthur@huggingface.co>

* add to docs

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make comment more clear

Co-authored-by: sgugger <sylvain@huggingface.co>

* don't add special tokens in get_decoder_prompt_ids

* add test for set_prefix_tokens

Co-authored-by: ArthurZucker <arthur@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: sgugger <sylvain@huggingface.co>
This commit is contained in:
Sanchit Gandhi 2022-11-03 14:22:40 +00:00 committed by GitHub
parent 790ff2544a
commit 06d488061f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 277 additions and 59 deletions

View File

@ -39,6 +39,7 @@ The original code can be found [here](https://github.com/openai/whisper).
## WhisperTokenizer
[[autodoc]] WhisperTokenizer
- set_prefix_tokens
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences

View File

@ -70,7 +70,7 @@ class WhisperProcessor(ProcessorMixin):
forced_decoder_tokens += f"<|{task}|>"
forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else ""
ids = self.tokenizer.encode(forced_decoder_tokens)
ids = self.tokenizer.encode(forced_decoder_tokens, add_special_tokens=False)
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)]
return forced_decoder_ids

View File

@ -89,9 +89,130 @@ def get_pairs(word):
return pairs
# Maps each language code used in the language tokens (e.g. "es" for `<|es|>`) to its lowercase English name.
# NOTE: insertion order is significant. `WhisperTokenizer.prefix_tokens` computes a language token id as
# `bos_token_id + 1 + <position of the code in this dict>`, so entries must not be reordered, removed or
# inserted in the middle of the table.
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
# Reverse lookup: lowercase language name -> language code.
# Built by inverting LANGUAGES, then extended with a few alternative names (aliases) for the same codes.
TO_LANGUAGE_CODE = {name: code for code, name in LANGUAGES.items()}
TO_LANGUAGE_CODE.update(
    {
        "burmese": "my",
        "valencian": "ca",
        "flemish": "nl",
        "haitian": "ht",
        "letzeburgesch": "lb",
        "pushto": "ps",
        "panjabi": "pa",
        "moldavian": "ro",
        "moldovan": "ro",
        "sinhalese": "si",
        "castilian": "es",
    }
)
# Task names accepted by `WhisperTokenizer.prefix_tokens`; any other non-None task raises a ValueError there.
TASK_IDS = ["translate", "transcribe"]
class WhisperTokenizer(PreTrainedTokenizer):
"""
Construct an Whisper tokenizer.
Construct a Whisper tokenizer.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
the superclass for more information regarding such methods.
@ -109,16 +230,22 @@ class WhisperTokenizer(PreTrainedTokenizer):
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
bos_token (`str`, *optional*, defaults to `"<|startoftranscript|>"`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
other word.
add_bos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as
any other word.
language (`str`, *optional*):
The language of the transcription text. The corresponding language id token is appended to the start of the
sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token
`"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only.
task (`str`, *optional*):
Task identifier to append at the start of sequence (if any). This should be used for multilingual
fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation.
predict_timestamps (`bool`, *optional*, defaults to `False`):
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
"""
vocab_files_names = VOCAB_FILES_NAMES
@ -133,11 +260,13 @@ class WhisperTokenizer(PreTrainedTokenizer):
normalizer_file=None,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
bos_token="<|startoftranscript|>",
eos_token="<|endoftext|>",
pad_token=None,
add_prefix_space=False,
add_bos_token=False,
language=None,
task=None,
predict_timestamps=False,
**kwargs
):
@ -152,10 +281,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
add_bos_token=add_bos_token,
**kwargs,
)
self.add_bos_token = add_bos_token
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
@ -179,6 +306,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.language = language
self.task = task
self.predict_timestamps = predict_timestamps
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
@ -231,27 +362,76 @@ class WhisperTokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.build_inputs_with_special_tokens with GPT2 -> Whisper
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None):
    """
    Update the settings that decide which prefix tokens are placed at the start of the label sequence.

    Every argument is optional: an argument left as `None` keeps the value currently stored on the tokenizer, so
    a single setting can be changed without repeating the others. The values are only stored here; they are
    validated when the prefix tokens are next built. Example:

    ```python
    >>> # instantiate the tokenizer and set the prefix token to Spanish
    >>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
    >>> # now switch the prefix token from Spanish to French
    >>> tokenizer.set_prefix_tokens(language="french")
    ```

    Args:
        language (`str`, *optional*, defaults to `None`):
            The language of the transcription text.
        task (`str`, *optional*, defaults to `None`):
            Task identifier placed at the start of the sequence (if any).
        predict_timestamps (`bool`, *optional*, defaults to `None`):
            If truthy, the `<|notimestamps|>` token is left out of the prefix.
    """
    # For each setting: fall back to the stored value when nothing new was passed, then store the result.
    if language is None:
        language = self.language
    self.language = language

    if task is None:
        task = self.task
    self.task = task

    if predict_timestamps is None:
        predict_timestamps = self.predict_timestamps
    self.predict_timestamps = predict_timestamps
@property
def prefix_tokens(self) -> List[int]:
    """
    Build the list of token ids placed at the start of every encoded sequence.

    The ids are, in order: the start-of-transcript id, the language id (only when `self.language` is set), the
    task id (only when `self.task` is set) and the no-timestamps id (only when `self.predict_timestamps` is
    falsy).

    `self.language` is matched case-insensitively and may be either a language name or alias from
    `TO_LANGUAGE_CODE` (e.g. `"spanish"`, `"castilian"`) or a language code from `LANGUAGES` (e.g. `"es"`).
    Reading this property never modifies `self.language`.

    Returns:
        `List[int]`: The prefix token ids.

    Raises:
        ValueError: If `self.language` is neither a known language name nor a known language code, or if
            `self.task` is not one of `TASK_IDS`.
    """
    # The special-token ids are read by position, counted from the end of `all_special_ids`.
    # NOTE(review): these offsets presumably match the special-token ordering of the released checkpoints --
    # confirm before using a tokenizer whose special tokens were extended or reordered.
    all_special_ids = self.all_special_ids
    bos_token_id = all_special_ids[-106]
    translate_token_id = all_special_ids[-6]
    transcribe_token_id = all_special_ids[-5]
    notimestamps_token_id = all_special_ids[-1]

    language_code = None
    if self.language is not None:
        # Normalise into a local: a property getter should not write back to the tokenizer.
        language = self.language.lower()
        if language in TO_LANGUAGE_CODE:
            language_code = TO_LANGUAGE_CODE[language]
        elif language in LANGUAGES:
            # Already a language code such as "es".
            language_code = language
        else:
            raise ValueError(
                f"Unsupported language: {self.language}. Language should be one of the names "
                f"{list(TO_LANGUAGE_CODE)} or one of the codes {list(LANGUAGES)}."
            )

    if self.task is not None and self.task not in TASK_IDS:
        raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")

    bos_sequence = [bos_token_id]
    if language_code is not None:
        # This arithmetic assumes the language token ids directly follow the start-of-transcript id, in the
        # same order as the entries of `LANGUAGES`.
        bos_sequence.append(bos_token_id + 1 + tuple(LANGUAGES).index(language_code))
    if self.task is not None:
        bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
    if not self.predict_timestamps:
        bos_sequence.append(notimestamps_token_id)
    return bos_sequence
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
if token_ids_1 is None:
return output
return self.prefix_tokens + token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id]
return output + bos_token_ids + token_ids_1
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask with GPT2 -> Whisper
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
@ -264,19 +444,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if not self.add_bos_token:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
)
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1]
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0))
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
def _tokenize(self, text):

View File

@ -20,14 +20,20 @@ from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
EN_CODE = 50258
ES_CODE = 50256
ES_CODE = 50262
EN_CODE = 50259
END_OF_TRANSCRIPT = 50257
START_OF_TRANSCRIPT = 50258
TRANSLATE = 50358
TRANSCRIBE = 50359
NOTIMESTAMPS = 50363
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = WhisperTokenizer
test_rust_tokenizer = False
test_sentencepiece = False
test_seq2seq = False
def setUp(self):
super().setUp()
@ -101,13 +107,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"
transcript = (
"'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting"
" than his matter.<|endoftext|>'"
)
clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter."
french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente"
@classmethod
def setUpClass(cls):
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)
@ -115,15 +114,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
def test_tokenizer_equivalence(self):
text = "다람쥐 헌 쳇바퀴에 타고파"
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko")
gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean")
monolingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
gpt2_tokens = gpt2_tokenizer.encode(text)
multilingual_tokens = multilingual_tokenizer.encode(text)
monolingual_tokens = monolingual_tokenizer.encode(text, add_special_tokens=False)
multilingual_tokens = multilingual_tokenizer.encode(text, add_special_tokens=False)
assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert monolingual_tokenizer.decode(monolingual_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens)
assert len(monolingual_tokens) > len(multilingual_tokens)
# fmt: off
EXPECTED_ENG = [
@ -138,35 +137,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
]
# fmt: on
self.assertListEqual(gpt2_tokens, EXPECTED_ENG)
self.assertListEqual(monolingual_tokens, EXPECTED_ENG)
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
def test_tokenizer_special(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>"
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="english", task="transcribe"
)
text = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
multilingual_tokens = multilingual_tokenizer.encode(text)
# fmt: off
# format: <|startoftranscript|> <|lang-id|> <|task|> <|notimestamps|> ... transcription ids ... <|endoftext|>
EXPECTED_MULTI = [
50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6,
1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357,
22174, 1556, 778, 25792, 83, 50256
START_OF_TRANSCRIPT, EN_CODE, TRANSCRIBE, NOTIMESTAMPS, 7057, 0, 1012, 366, 291,
2633, 30, 508, 6, 1301, 287, 6, 36107, 631, 220, 11178,
115, 15567, 871, 44393, END_OF_TRANSCRIPT
]
EXPECTED_SPECIAL_TEXT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hey! How are you feeling? "
"J'ai l'impression que 郷さん est prêt<|endoftext|>"
)
# fmt: on
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens))
special_transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=False)
self.assertEqual(special_transcript, EXPECTED_SPECIAL_TEXT)
transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)
EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
self.assertEqual(transcript, EXPECTED_JAP)
self.assertEqual(transcript, text)
def test_vocab_size(self):
    """Check that the tokenizer loaded in `setUpClass` reports a `vocab_size` of 50257."""
    self.assertEqual(self.tokenizer.vocab_size, 50257)
# Copied from transformers.tests.speech_to_text.test_tokenization_speech_to_text.py
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
@ -176,15 +182,48 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
self.assertNotIn(self.tokenizer.eos_token, result)
def test_batch_encoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"]
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="spanish", task="translate"
)
batch = ["El gato ", "El gato se sentó"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
# fmt: off
EXPECTED_MULTI = [
[50258, 50362, 50256, 50256, 50256, 50256],
[50258, 50362, 40, 716, 1654, 326]
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 220,
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 369,
2279, 812, END_OF_TRANSCRIPT]
]
# fmt: on
self.assertListEqual(batch_output, EXPECTED_MULTI)
def test_set_prefix_tokens(self):
    """Check that `set_prefix_tokens` swaps the language prefix id while the task prefix id is kept."""
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish", task="translate")

    # switch the language prefix from Spanish to English; the task ("translate") is not passed again
    tokenizer.set_prefix_tokens(language="english")

    sentences = ["the cat", "the cat sat"]
    encoded = tokenizer.batch_encode_plus(sentences, padding=True)
    input_ids = encoded.input_ids

    # every row starts with the same four prefix ids
    prefix = [START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS]
    expected_ids = [
        # shorter sentence: the second trailing END_OF_TRANSCRIPT is presumably the padding id
        prefix + [3322, 3857, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
        prefix + [3322, 3857, 3227, END_OF_TRANSCRIPT],
    ]

    self.assertListEqual(input_ids, expected_ids)
def test_batch_encoding_decoding(self):
    """
    Round-trip check: encoding a batch with padding and decoding it again with `skip_special_tokens=True`
    must give back the original strings.
    """
    multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
    batch = ["hola güey", "que onda"]
    batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
    transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
    self.assertListEqual(batch, transcription)