[Whisper Tokenizer] Make more user-friendly (#19921)
* [Whisper Tokenizer] Make more user-friendly
* use property
* make indexing rigorous
* small clean-up
* tests
* skip seq2seq tests
* remove multilingual arg
* reorder args
* collapse to one function
Co-authored-by: ArthurZucker <arthur@huggingface.co>
* option to override attributes
Co-authored-by: ArthurZucker <arthur@huggingface.co>
* add to docs
* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* make comment more clear
Co-authored-by: sgugger <sylvain@huggingface.co>
* don't add special tokens in get_decoder_prompt_ids
* add test for set_prefix_tokens
Co-authored-by: ArthurZucker <arthur@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: sgugger <sylvain@huggingface.co>
This commit is contained in:
parent 790ff2544a
commit 06d488061f
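Taken together, the changes below let the tokenizer carry its own prefix configuration (language, task, timestamps) instead of requiring users to splice special tokens in by hand. A minimal sketch of the intended workflow, assuming the `openai/whisper-tiny` checkpoint and the example sentence used in the tests further down (the sketch itself is not part of the diff):

```python
from transformers import WhisperTokenizer

# language/task now live on the tokenizer and control the special-token prefix
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="spanish", task="transcribe"
)

# encoded labels come back as
# <|startoftranscript|><|es|><|transcribe|><|notimestamps|> ... <|endoftext|>
labels = tokenizer("El gato se sentó").input_ids

# the prefix can be switched later without reloading the tokenizer
tokenizer.set_prefix_tokens(language="french", task="translate")
```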
@ -39,6 +39,7 @@ The original code can be found [here](https://github.com/openai/whisper).

## WhisperTokenizer

[[autodoc]] WhisperTokenizer
    - set_prefix_tokens
    - build_inputs_with_special_tokens
    - get_special_tokens_mask
    - create_token_type_ids_from_sequences

@ -70,7 +70,7 @@ class WhisperProcessor(ProcessorMixin):

forced_decoder_tokens += f"<|{task}|>"
forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else ""
ids = self.tokenizer.encode(forced_decoder_tokens)
ids = self.tokenizer.encode(forced_decoder_tokens, add_special_tokens=False)
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)]
return forced_decoder_ids

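The functional change in this hunk is the `add_special_tokens=False` flag, so the prompt ids contain only the `<|...|>` control tokens rather than a full label sequence. A hedged sketch of how the returned `(position, token_id)` pairs are typically consumed; the processor checkpoint and the downstream `generate` usage are assumptions, not part of this diff:

```python
from transformers import WhisperProcessor

# hypothetical usage of the method shown above
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# e.g. [(1, id of <|es|>), (2, id of <|transcribe|>), (3, id of <|notimestamps|>)]
forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="es")

# these pairs are intended to be passed to the model's generate() call
# (e.g. forced_decoder_ids=forced_decoder_ids) so decoding starts with this prefix
```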
@ -89,9 +89,130 @@ def get_pairs(word):

return pairs

LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"iw": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}

TASK_IDS = ["translate", "transcribe"]

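These mappings let callers pass either a language code or a full language name, with the alias table resolving alternative names to the same code. A small illustration, derived directly from the dictionaries above:

```python
# given the LANGUAGES / TO_LANGUAGE_CODE / TASK_IDS definitions above
assert LANGUAGES["es"] == "spanish"
assert TO_LANGUAGE_CODE["spanish"] == "es"
assert TO_LANGUAGE_CODE["castilian"] == "es"  # aliases map to the same code
assert TASK_IDS == ["translate", "transcribe"]
```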
class WhisperTokenizer(PreTrainedTokenizer):
"""
Construct an Whisper tokenizer.
Construct a Whisper tokenizer.

This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
the superclass for more information regarding such methods.

@ -109,16 +230,22 @@ class WhisperTokenizer(PreTrainedTokenizer):

unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
bos_token (`str`, *optional*, defaults to `"<|startoftranscript|>"`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
other word.
add_bos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as
any other word.
language (`str`, *optional*):
The language of the transcription text. The corresponding language id token is appended to the start of the
sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token
`"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only.
task (`str`, *optional*):
Task identifier to append at the start of sequence (if any). This should be used for mulitlingual
fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation.
predict_timestamps (`bool`, *optional*, defaults to `False`):
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
"""

vocab_files_names = VOCAB_FILES_NAMES

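As the docstring notes, `predict_timestamps=True` omits the `<|notimestamps|>` token from the prefix so that timestamp tokens may be generated. A hedged illustration of the two settings (the prefix strings follow the docstring; passing `predict_timestamps` through `from_pretrained` kwargs is an assumption based on the constructor below):

```python
from transformers import WhisperTokenizer

# default: timestamps are suppressed, so <|notimestamps|> is part of the prefix
tok_default = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="english", task="transcribe"
)
# prefix: <|startoftranscript|><|en|><|transcribe|><|notimestamps|>

# predict_timestamps=True drops the <|notimestamps|> token
tok_timestamps = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="english", task="transcribe", predict_timestamps=True
)
# prefix: <|startoftranscript|><|en|><|transcribe|>
```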
@ -133,11 +260,13 @@ class WhisperTokenizer(PreTrainedTokenizer):

normalizer_file=None,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
bos_token="<|startoftranscript|>",
eos_token="<|endoftext|>",
pad_token=None,
add_prefix_space=False,
add_bos_token=False,
language=None,
task=None,
predict_timestamps=False,
**kwargs
):

@ -152,10 +281,8 @@ class WhisperTokenizer(PreTrainedTokenizer):

eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
add_bos_token=add_bos_token,
**kwargs,
)
self.add_bos_token = add_bos_token

with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)

@ -179,6 +306,10 @@ class WhisperTokenizer(PreTrainedTokenizer):

# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

self.language = language
self.task = task
self.predict_timestamps = predict_timestamps

def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)

@ -231,27 +362,76 @@ class WhisperTokenizer(PreTrainedTokenizer):

self.cache[token] = word
return word

# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.build_inputs_with_special_tokens with GPT2 -> Whisper
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None):
"""
Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to
update the prefix tokens as required when fine-tuning. Example:

output = bos_token_ids + token_ids_0

```python
>>> # instantiate the tokenizer and set the prefix token to Spanish
>>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
>>> # now switch the prefix token from Spanish to French
>>> tokenizer.set_prefix_tokens(language="french")
```

Args:
language (`str`, *optional*, defaults to `None`):
The language of the transcription text.
task (`str`, *optional*, defaults to `None`):
Task identifier to append at the start of sequence (if any).
predict_timestamps (`bool`, *optional*, defaults to `None`):
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
"""
self.language = language if language is not None else self.language
self.task = task if task is not None else self.task
self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps

@property
def prefix_tokens(self) -> List[int]:
all_special_ids = self.all_special_ids
bos_token_id = all_special_ids[-106]
translate_token_id = all_special_ids[-6]
transcribe_token_id = all_special_ids[-5]
notimestamps_token_id = all_special_ids[-1]
langs = tuple(LANGUAGES.keys())

if self.language is not None:
self.language = self.language.lower()
if self.language in TO_LANGUAGE_CODE:
language_id = TO_LANGUAGE_CODE[self.language]
else:
raise ValueError(
f"Unsupported language: {self.language}. Language should be in: {TO_LANGUAGE_CODE.keys()}"
)

if self.task is not None:
if self.task not in TASK_IDS:
raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")

bos_sequence = [bos_token_id]
if self.language is not None:
bos_sequence.append(bos_token_id + 1 + langs.index(language_id))
if self.task is not None:
bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
if not self.predict_timestamps:
bos_sequence.append(notimestamps_token_id)
return bos_sequence

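The `prefix_tokens` property resolves the configured language and task into concrete ids by offsetting from `<|startoftranscript|>`. A hedged example of the sequence it should produce for Spanish translation; the ids quoted come from the named constants in the updated test file near the end of this diff:

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="spanish", task="translate"
)
# expected: <|startoftranscript|> <|es|> <|translate|> <|notimestamps|>
# i.e. [50258, 50262, 50358, 50363] for this checkpoint (see START_OF_TRANSCRIPT,
# ES_CODE, TRANSLATE, NOTIMESTAMPS in the tests below)
print(tokenizer.prefix_tokens)
```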
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
if token_ids_1 is None:
return output
return self.prefix_tokens + token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id]

return output + bos_token_ids + token_ids_1

# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask with GPT2 -> Whisper
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.

Args:
token_ids_0 (`List[int]`):

@ -264,19 +444,17 @@ class WhisperTokenizer(PreTrainedTokenizer):

Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""

if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)

if not self.add_bos_token:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
)

prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1]
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0))
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
def _tokenize(self, text):

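With the prefix scheme, the special-tokens mask marks every prefix token and the trailing `<|endoftext|>`. A small sketch of the expected shape, following the `prefix_ones`/`suffix_ones` logic above (an assumption, not taken from the tests):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="spanish", task="translate"
)
text_ids = tokenizer("El gato", add_special_tokens=False).input_ids
mask = tokenizer.get_special_tokens_mask(text_ids)
# -> [1, 1, 1, 1] + [0] * len(text_ids) + [1]
#    ones for <|startoftranscript|><|es|><|translate|><|notimestamps|>, zeros for
#    the text tokens, and a final one for <|endoftext|>
```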
@ -20,14 +20,20 @@ from transformers.testing_utils import slow

from ...test_tokenization_common import TokenizerTesterMixin

EN_CODE = 50258
ES_CODE = 50256
ES_CODE = 50262
EN_CODE = 50259
END_OF_TRANSCRIPT = 50257
START_OF_TRANSCRIPT = 50258
TRANSLATE = 50358
TRANSCRIBE = 50359
NOTIMESTAMPS = 50363


class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = WhisperTokenizer
test_rust_tokenizer = False
test_sentencepiece = False
test_seq2seq = False

def setUp(self):
super().setUp()

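The named constants above replace the magic numbers previously hard-coded in the expected outputs; they are the vocabulary ids of Whisper's control tokens in the multilingual checkpoint. A hedged sanity check of that correspondence, using the constants defined in the hunk above:

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# assumed id <-> token correspondence for the constants defined above
assert tokenizer.convert_ids_to_tokens(START_OF_TRANSCRIPT) == "<|startoftranscript|>"
assert tokenizer.convert_ids_to_tokens(ES_CODE) == "<|es|>"
assert tokenizer.convert_ids_to_tokens(NOTIMESTAMPS) == "<|notimestamps|>"
```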
@ -101,13 +107,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):

class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"

transcript = (
"'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting"
" than his matter.<|endoftext|>'"
)
clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter."
french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente"

@classmethod
def setUpClass(cls):
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)

@ -115,15 +114,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):

def test_tokenizer_equivalence(self):
text = "다람쥐 헌 쳇바퀴에 타고파"
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko")
gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean")
monolingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")

gpt2_tokens = gpt2_tokenizer.encode(text)
multilingual_tokens = multilingual_tokenizer.encode(text)
monolingual_tokens = monolingual_tokenizer.encode(text, add_special_tokens=False)
multilingual_tokens = multilingual_tokenizer.encode(text, add_special_tokens=False)

assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert monolingual_tokenizer.decode(monolingual_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens)
assert len(monolingual_tokens) > len(multilingual_tokens)

# fmt: off
EXPECTED_ENG = [

@ -138,35 +137,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):

]
# fmt: on

self.assertListEqual(gpt2_tokens, EXPECTED_ENG)
self.assertListEqual(monolingual_tokens, EXPECTED_ENG)
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)

def test_tokenizer_special(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>"
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="english", task="transcribe"
)
text = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"

multilingual_tokens = multilingual_tokenizer.encode(text)

# fmt: off
# format: <|startoftranscript|> <|lang-id|> <|task|> <|notimestamps|> ... transcription ids ... <|endoftext|>
EXPECTED_MULTI = [
50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6,
1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357,
22174, 1556, 778, 25792, 83, 50256
START_OF_TRANSCRIPT, EN_CODE, TRANSCRIBE, NOTIMESTAMPS, 7057, 0, 1012, 366, 291,
2633, 30, 508, 6, 1301, 287, 6, 36107, 631, 220, 11178,
115, 15567, 871, 44393, END_OF_TRANSCRIPT
]
EXPECTED_SPECIAL_TEXT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hey! How are you feeling? "
"J'ai l'impression que 郷さん est prêt<|endoftext|>"
)
# fmt: on

self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)

self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens))
special_transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=False)
self.assertEqual(special_transcript, EXPECTED_SPECIAL_TEXT)

transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)

EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
self.assertEqual(transcript, EXPECTED_JAP)
self.assertEqual(transcript, text)

def test_vocab_size(self):
self.assertEqual(self.tokenizer.vocab_size, 50257)

# Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]

@ -176,15 +182,48 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):

self.assertNotIn(self.tokenizer.eos_token, result)

def test_batch_encoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"]
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="spanish", task="translate"
)
batch = ["El gato ", "El gato se sentó"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids

# fmt: off
EXPECTED_MULTI = [
[50258, 50362, 50256, 50256, 50256, 50256],
[50258, 50362, 40, 716, 1654, 326]
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 220,
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 369,
2279, 812, END_OF_TRANSCRIPT]
]
# fmt: on

self.assertListEqual(batch_output, EXPECTED_MULTI)

def test_set_prefix_tokens(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
"openai/whisper-tiny", language="spanish", task="translate"
)

# change the language prefix token from Spanish to English
multilingual_tokenizer.set_prefix_tokens(language="english")

batch = ["the cat", "the cat sat"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids

# fmt: off
EXPECTED_MULTI = [
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
3227, END_OF_TRANSCRIPT]
]
# fmt: on

self.assertListEqual(batch_output, EXPECTED_MULTI)

def test_batch_encoding_decoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
batch = ["hola güey", "que onda"]
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
self.assertListEqual(batch, transcription)