mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
MBartTokenizer: add language codes (#3776)
parent 20451195f0
commit 08b59d10e5
@@ -14,8 +14,10 @@
 # limitations under the License.
 
 import logging
+from typing import List, Optional
 
 from .tokenization_roberta import RobertaTokenizer
+from .tokenization_utils import BatchEncoding
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 
 
@@ -47,6 +49,104 @@ SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-la
 
 
 class MBartTokenizer(XLMRobertaTokenizer):
+    """
+    This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
+    Other tokenizer methods like ``encode`` do not work properly.
+    The tokenization method is ``<tokens> <eos> <language code>``. There is no BOS token.
+
+    Examples::
+
+        from transformers import MBartTokenizer
+        tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
+        example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
+        batch: dict = tokenizer.prepare_translation_batch(
+            [example_english_phrase], src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=[expected_translation_romanian]
+        )
+    """
+
     vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
     max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
     pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
+    lang_code_to_id = {  # NOTE(SS): resize embeddings will break this
+        "ar_AR": 250001,
+        "cs_CZ": 250002,
+        "de_DE": 250003,
+        "en_XX": 250004,
+        "es_XX": 250005,
+        "et_EE": 250006,
+        "fi_FI": 250007,
+        "fr_XX": 250008,
+        "gu_IN": 250009,
+        "hi_IN": 250010,
+        "it_IT": 250011,
+        "ja_XX": 250012,
+        "kk_KZ": 250013,
+        "ko_KR": 250014,
+        "lt_LT": 250015,
+        "lv_LV": 250016,
+        "my_MM": 250017,
+        "ne_NP": 250018,
+        "nl_XX": 250019,
+        "ro_RO": 250020,
+        "ru_RU": 250021,
+        "si_LK": 250022,
+        "tr_TR": 250023,
+        "vi_VN": 250024,
+        "zh_CN": 250025,
+    }
+    cur_lang_code = lang_code_to_id["en_XX"]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """Build model inputs from a sequence by appending eos_token_id."""
+        special_tokens = [self.eos_token_id, self.cur_lang_code]
+        if token_ids_1 is None:
+            return token_ids_0 + special_tokens
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + special_tokens
+
+    def prepare_translation_batch(
+        self,
+        src_texts: List[str],
+        src_lang: str = "en_XX",
+        tgt_texts: Optional[List[str]] = None,
+        tgt_lang: str = "ro_RO",
+        max_length: Optional[int] = None,
+        pad_to_max_length: bool = True,
+        return_tensors: str = "pt",
+    ) -> BatchEncoding:
+        """
+        Arguments:
+            src_texts: list of src language texts
+            src_lang: default en_XX (English)
+            tgt_texts: list of tgt language texts
+            tgt_lang: default ro_RO (Romanian)
+            max_length: (None) defer to config (1024 for mbart-large-en-ro)
+            pad_to_max_length: (bool)
+
+        Returns:
+            dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor.
+        """
+        if max_length is None:
+            max_length = self.max_len
+        self.cur_lang_code = self.lang_code_to_id[src_lang]
+        model_inputs: BatchEncoding = self.batch_encode_plus(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            pad_to_max_length=pad_to_max_length,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        self.cur_lang_code = self.lang_code_to_id[tgt_lang]
+        decoder_inputs: BatchEncoding = self.batch_encode_plus(
+            tgt_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            pad_to_max_length=pad_to_max_length,
+        )
+        for k, v in decoder_inputs.items():
+            model_inputs[f"decoder_{k}"] = v
+        self.cur_lang_code = self.lang_code_to_id[src_lang]
+        return model_inputs
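Two details of the hunk above are easy to miss. First, build_inputs_with_special_tokens appends [eos_token_id, cur_lang_code] as a suffix rather than wrapping the sequence RoBERTa-style in <s> ... </s>, which is why the docstring warns that plain encode misbehaves. Second, prepare_translation_batch temporarily flips cur_lang_code to the target language while encoding tgt_texts, then restores the source code. A minimal, dependency-free sketch of the suffix scheme (not part of the commit; the </s> id of 2 is taken from the tests further down, the language ids from the table above):

from typing import Dict, List

EOS_ID = 2  # id of </s> in the mBART vocab (asserted as ids[-2] in the tests below)
LANG_CODE_TO_ID: Dict[str, int] = {"en_XX": 250004, "ro_RO": 250020}  # subset of the table above

def add_mbart_special_tokens(token_ids: List[int], lang: str) -> List[int]:
    """mBART layout: <tokens> <eos> <language code>; no BOS token is added."""
    return token_ids + [EOS_ID, LANG_CODE_TO_ID[lang]]

print(add_mbart_special_tokens([8274, 127873, 51712], "en_XX"))
# -> [8274, 127873, 51712, 2, 250004], matching the tail of expected_src_tokens in the tests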
@@ -19,6 +19,7 @@ import unittest
 import timeout_decorator  # noqa
 
 from transformers import is_torch_available
+from transformers.file_utils import cached_property
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, ids_tensor
@@ -37,6 +38,7 @@ if is_torch_available():
     BartConfig,
     BartTokenizer,
     MBartTokenizer,
+    BatchEncoding,
 )
 from transformers.modeling_bart import (
     BART_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -197,15 +199,37 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         tiny(**inputs_dict)
 
 
+EN_CODE = 250004
+
+
 @require_torch
-class BartTranslationTests(unittest.TestCase):
-    _model = None
+class MBartIntegrationTests(unittest.TestCase):
+    src_text = [
+        " UN Chief Says There Is No Military Solution in Syria",
+        " I ate lunch twice yesterday",
+    ]
+    tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]
+    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
 
     @classmethod
     def setUpClass(cls):
         checkpoint_name = "facebook/mbart-large-en-ro"
         cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
         cls.pad_token_id = 1
+        return cls
+
+    @cached_property
+    def model(self):
+        """Only load the model if needed."""
+        model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
+        if "cuda" in torch_device:
+            model = model.half()
+        return model
+
+    @slow
+    def test_enro_forward(self):
+        model = self.model
         net_input = {
             "input_ids": _long_tensor(
                 [
@@ -221,24 +245,9 @@ class BartTranslationTests(unittest.TestCase):
             ),
             "generation_mode": False,
         }
-        net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id)
-        cls.net_input = net_input
-
-        return cls
-
-    @property
-    def model(self):
-        """Only load the model if needed."""
-        if self._model is None:
-            model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
-            self._model = model.to(torch_device)
-        return self._model
-
-    @slow
-    def test_enro_forward(self):
-        model = self.model
+        net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
         with torch.no_grad():
-            logits, *other_stuff = model(**self.net_input)
+            logits, *other_stuff = model(**net_input)
 
         expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device)
         result_slice = logits[0][0][:3]
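The refactor above drops the hand-rolled _model = None / @property caching in favor of cached_property from transformers.file_utils. As a sketch of the behavior being relied on (not from the commit; the stdlib functools.cached_property, Python 3.8+, stands in for the library's helper here):

import functools

class ModelHolder:
    @functools.cached_property
    def model(self):
        # The body runs once, on first attribute access; the result is then
        # stored on the instance, so later accesses skip the expensive load.
        print("loading model ...")
        return {"weights": "stand-in for BartForConditionalGeneration"}

holder = ModelHolder()
first = holder.model   # prints "loading model ..."
second = holder.model  # served from the per-instance cache, no reload
assert first is second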
@@ -246,19 +255,10 @@ class BartTranslationTests(unittest.TestCase):
 
     @slow
     def test_enro_generate(self):
-        model = self.model
-        # example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
-        # inputs: dict = tokenizer.batch_encode_plus([example_english_phrase], return_tensors="pt",)
-        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-
-        inputs = {
-            "input_ids": torch.LongTensor(
-                [[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]]  # 250004
-            )
-        }
-        translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
+        inputs: dict = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
+        translated_tokens = self.model.generate(input_ids=inputs["input_ids"].to(torch_device))
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
-        self.assertEqual(expected_translation_romanian, decoded[0])
+        self.assertEqual(self.tgt_text[0], decoded[0])
 
     def test_mbart_enro_config(self):
         mbart_models = ["facebook/mbart-large-en-ro"]
@@ -273,13 +273,6 @@ class BartTranslationTests(unittest.TestCase):
                 e.args += (name, k)
                 raise
 
-    def test_enro_tokenizer(self):
-        raw = "UN Chief Says There Is No Military Solution in Syria"
-        ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
-        expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
-        # TODO(SS): should be [8274, ..., 2, 250020]
-        self.assertListEqual(expected_result, ids)
-
     def test_mbart_fast_forward(self):
         config = BartConfig(
             vocab_size=99,
@@ -301,6 +294,36 @@ class BartTranslationTests(unittest.TestCase):
         self.assertEqual(logits.shape, expected_shape)
 
 
+@require_torch
+class MBartTokenizerTests(MBartIntegrationTests):
+    def test_enro_tokenizer_prepare_translation_batch(self):
+        batch = self.tokenizer.prepare_translation_batch(
+            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+
+        self.assertEqual((2, 14), batch.input_ids.shape)
+        self.assertEqual((2, 14), batch.attention_mask.shape)
+        result = batch.input_ids.tolist()[0]
+        self.assertListEqual(self.expected_src_tokens, result)
+        self.assertEqual(2, batch.decoder_input_ids[0, -2])  # EOS
+
+    def test_enro_tokenizer_batch_encode_plus(self):
+        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
+        self.assertListEqual(self.expected_src_tokens, ids)
+
+    def test_enro_tokenizer_truncation(self):
+        src_text = ["this is gunna be a long sentence " * 20]
+        assert isinstance(src_text[0], str)
+        desired_max_length = 10
+        ids = self.tokenizer.prepare_translation_batch(
+            src_text, return_tensors=None, max_length=desired_max_length
+        ).input_ids[0]
+        self.assertEqual(ids[-2], 2)
+        self.assertEqual(ids[-1], EN_CODE)
+        self.assertEqual(len(ids), desired_max_length)
+
+
 @require_torch
 class BartHeadTests(unittest.TestCase):
     vocab_size = 99
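Taken together, the new tests round-trip exactly what the tokenizer docstring promises: encode with language codes, generate, decode, compare. A sketch of that flow outside the test harness (not part of the commit; assumes network access to the facebook/mbart-large-en-ro checkpoint and the API as of this commit):

import torch
from transformers import BartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

src = " UN Chief Says There Is No Military Solution in Syria"
batch = tokenizer.prepare_translation_batch([src], src_lang="en_XX", tgt_lang="ro_RO")
with torch.no_grad():
    translated = model.generate(input_ids=batch["input_ids"])
print(tokenizer.batch_decode(translated, skip_special_tokens=True)[0])
# per test_enro_generate: "Şeful ONU declară că nu există o soluţie militară în Siria"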