mirror of https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
MbartTokenizer: do not hardcode vocab size (#5998)
parent 6e16195510
commit 9827d666eb
@@ -58,6 +58,34 @@ class BartTokenizerFast(RobertaTokenizerFast):
 
 _all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
 SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
 
+FAIRSEQ_LANGUAGE_CODES = [
+    "ar_AR",
+    "cs_CZ",
+    "de_DE",
+    "en_XX",
+    "es_XX",
+    "et_EE",
+    "fi_FI",
+    "fr_XX",
+    "gu_IN",
+    "hi_IN",
+    "it_IT",
+    "ja_XX",
+    "kk_KZ",
+    "ko_KR",
+    "lt_LT",
+    "lv_LV",
+    "my_MM",
+    "ne_NP",
+    "nl_XX",
+    "ro_RO",
+    "ru_RU",
+    "si_LK",
+    "tr_TR",
+    "vi_VN",
+    "zh_CN",
+]
+
 
 class MBartTokenizer(XLMRobertaTokenizer):
     """
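The list above only names the language codes; their ids are derived at runtime in the next hunk. A minimal standalone sketch of that arithmetic, assuming example sizes chosen so the result matches the old hardcoded table (they are not values read from a real checkpoint), and a trimmed code list:

CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "ro_RO", "zh_CN"]  # trimmed for brevity


def build_lang_code_to_id(sp_model_size: int, fairseq_offset: int) -> dict:
    # Mirrors the dict comprehension added to MBartTokenizer.__init__ below:
    # language codes get ids appended right after the SentencePiece vocabulary.
    return {code: sp_model_size + i + fairseq_offset for i, code in enumerate(CODES)}


# With these example sizes the old table is reproduced: "ar_AR" -> 250001, "cs_CZ" -> 250002, ...
print(build_lang_code_to_id(sp_model_size=250000, fairseq_offset=1)["ar_AR"])  # 250001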
@@ -81,40 +109,20 @@ class MBartTokenizer(XLMRobertaTokenizer):
 
     vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
     max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
     pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
-    lang_code_to_id = {  # NOTE(SS): resize embeddings will break this
-        "ar_AR": 250001,
-        "cs_CZ": 250002,
-        "de_DE": 250003,
-        "en_XX": 250004,
-        "es_XX": 250005,
-        "et_EE": 250006,
-        "fi_FI": 250007,
-        "fr_XX": 250008,
-        "gu_IN": 250009,
-        "hi_IN": 250010,
-        "it_IT": 250011,
-        "ja_XX": 250012,
-        "kk_KZ": 250013,
-        "ko_KR": 250014,
-        "lt_LT": 250015,
-        "lv_LV": 250016,
-        "my_MM": 250017,
-        "ne_NP": 250018,
-        "nl_XX": 250019,
-        "ro_RO": 250020,
-        "ru_RU": 250021,
-        "si_LK": 250022,
-        "tr_TR": 250023,
-        "vi_VN": 250024,
-        "zh_CN": 250025,
-    }
-    id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
-    cur_lang_code = lang_code_to_id["en_XX"]
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
+        self.sp_model_size = len(self.sp_model)
+        self.lang_code_to_id = {
+            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+        }
+        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
+        self.cur_lang_code = self.lang_code_to_id["en_XX"]
+
         self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
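A rough usage sketch of the attributes the reworked __init__ now derives instead of hardcoding; it assumes the facebook/mbart-large-en-ro files can be downloaded, and uses exactly the attribute names added in the hunk above:

from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

ro_id = tok.lang_code_to_id["ro_RO"]                       # computed from len(sp_model) + offset, not hardcoded
print(ro_id, tok.id_to_lang_code[ro_id])                   # round-trips back to "ro_RO"
print(tok.cur_lang_code == tok.lang_code_to_id["en_XX"])   # True: "en_XX" is the initial source language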
@@ -113,10 +113,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):
 
     @classmethod
    def setUpClass(cls):
-        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
+        cls.tokenizer: MBartTokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
         cls.pad_token_id = 1
         return cls
 
+    def check_language_codes(self):
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
+
     def test_enro_tokenizer_prepare_translation_batch(self):
         batch = self.tokenizer.prepare_translation_batch(
             self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
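For completeness, a hedged sketch of the call the integration test exercises; the sentences are illustrative (not the test fixtures), the exact keys of the returned batch are not shown in this diff, and the default tensor return requires torch to be installed:

from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

batch = tok.prepare_translation_batch(
    ["UN Chief Says There Is No Military Solution in Syria"],                    # English source
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],    # Romanian target
)
print(list(batch.keys()))  # inspect the encoded fields the model consumes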