diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index ac982202a45..e272281d898 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -14,8 +14,10 @@
 # limitations under the License.
 
 import logging
+from typing import List, Optional
 
 from .tokenization_roberta import RobertaTokenizer
+from .tokenization_utils import BatchEncoding
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 
 
@@ -47,6 +49,104 @@ SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-la
 
 
 class MBartTokenizer(XLMRobertaTokenizer):
+    """
+    This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
+    Other tokenizer methods like ``encode`` do not work properly.
+    The tokenization method is ``<tokens> <eos> <language code>``; there is no BOS token.
+
+    Examples::
+
+        from transformers import MBartTokenizer
+        tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
+        example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
+        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
+        batch: dict = tokenizer.prepare_translation_batch(
+            [example_english_phrase], src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=[expected_translation_romanian]
+        )
+    """
+
     vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
     max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
     pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
+    lang_code_to_id = {  # NOTE(SS): resize embeddings will break this
+        "ar_AR": 250001,
+        "cs_CZ": 250002,
+        "de_DE": 250003,
+        "en_XX": 250004,
+        "es_XX": 250005,
+        "et_EE": 250006,
+        "fi_FI": 250007,
+        "fr_XX": 250008,
+        "gu_IN": 250009,
+        "hi_IN": 250010,
+        "it_IT": 250011,
+        "ja_XX": 250012,
+        "kk_KZ": 250013,
+        "ko_KR": 250014,
+        "lt_LT": 250015,
+        "lv_LV": 250016,
+        "my_MM": 250017,
+        "ne_NP": 250018,
+        "nl_XX": 250019,
+        "ro_RO": 250020,
+        "ru_RU": 250021,
+        "si_LK": 250022,
+        "tr_TR": 250023,
+        "vi_VN": 250024,
+        "zh_CN": 250025,
+    }
+    cur_lang_code = lang_code_to_id["en_XX"]
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """Build model inputs from a sequence by appending ``eos_token_id`` and the current language code."""
+        special_tokens = [self.eos_token_id, self.cur_lang_code]
+        if token_ids_1 is None:
+            return token_ids_0 + special_tokens
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + special_tokens
+
+    def prepare_translation_batch(
+        self,
+        src_texts: List[str],
+        src_lang: str = "en_XX",
+        tgt_texts: Optional[List[str]] = None,
+        tgt_lang: str = "ro_RO",
+        max_length: Optional[int] = None,
+        pad_to_max_length: bool = True,
+        return_tensors: str = "pt",
+    ) -> BatchEncoding:
+        """
+        Arguments:
+            src_texts: list of src language texts
+            src_lang: default en_XX (english)
+            tgt_texts: list of tgt language texts
+            tgt_lang: default ro_RO (romanian)
+            max_length: (default None) defer to self.max_len (1024 for mbart-large-en-ro)
+            pad_to_max_length: (bool) whether to pad each example up to max_length
+
+        Returns:
+            dict with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
+            (the decoder keys only when tgt_texts is passed); each value is a torch.Tensor.
+ """ + if max_length is None: + max_length = self.max_len + self.cur_lang_code = self.lang_code_to_id[src_lang] + model_inputs: BatchEncoding = self.batch_encode_plus( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + pad_to_max_length=pad_to_max_length, + ) + if tgt_texts is None: + return model_inputs + self.cur_lang_code = self.lang_code_to_id[tgt_lang] + decoder_inputs: BatchEncoding = self.batch_encode_plus( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + pad_to_max_length=pad_to_max_length, + ) + for k, v in decoder_inputs.items(): + model_inputs[f"decoder_{k}"] = v + self.cur_lang_code = self.lang_code_to_id[src_lang] + return model_inputs diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 366b7de2fb5..8e5daa96fd7 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -19,6 +19,7 @@ import unittest import timeout_decorator # noqa from transformers import is_torch_available +from transformers.file_utils import cached_property from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor @@ -37,6 +38,7 @@ if is_torch_available(): BartConfig, BartTokenizer, MBartTokenizer, + BatchEncoding, ) from transformers.modeling_bart import ( BART_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -197,15 +199,37 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): tiny(**inputs_dict) +EN_CODE = 250004 + + @require_torch -class BartTranslationTests(unittest.TestCase): - _model = None +class MBartIntegrationTests(unittest.TestCase): + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + " I ate lunch twice yesterday", + ] + tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"] + expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE] @classmethod def setUpClass(cls): checkpoint_name = "facebook/mbart-large-en-ro" cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name) cls.pad_token_id = 1 + return cls + + @cached_property + def model(self): + """Only load the model if needed.""" + + model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device) + if "cuda" in torch_device: + model = model.half() + return model + + @slow + def test_enro_forward(self): + model = self.model net_input = { "input_ids": _long_tensor( [ @@ -221,24 +245,9 @@ class BartTranslationTests(unittest.TestCase): ), "generation_mode": False, } - net_input["attention_mask"] = net_input["input_ids"].ne(cls.pad_token_id) - cls.net_input = net_input - - return cls - - @property - def model(self): - """Only load the model if needed.""" - if self._model is None: - model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro") - self._model = model.to(torch_device) - return self._model - - @slow - def test_enro_forward(self): - model = self.model + net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id) with torch.no_grad(): - logits, *other_stuff = model(**self.net_input) + logits, *other_stuff = model(**net_input) expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=torch_device) result_slice = logits[0][0][:3] @@ -246,19 +255,10 @@ class BartTranslationTests(unittest.TestCase): @slow def test_enro_generate(self): - model = self.model - # example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" - # inputs: dict = 
-        expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-
-        inputs = {
-            "input_ids": torch.LongTensor(
-                [[8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]]  # 250004
-            )
-        }
-        translated_tokens = model.generate(input_ids=inputs["input_ids"].to(torch_device), num_beams=5,)
+        inputs: dict = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
+        translated_tokens = self.model.generate(input_ids=inputs["input_ids"].to(torch_device))
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
-        self.assertEqual(expected_translation_romanian, decoded[0])
+        self.assertEqual(self.tgt_text[0], decoded[0])
 
     def test_mbart_enro_config(self):
         mbart_models = ["facebook/mbart-large-en-ro"]
@@ -273,13 +273,6 @@ class BartTranslationTests(unittest.TestCase):
                     e.args += (name, k)
                     raise
 
-    def test_enro_tokenizer(self):
-        raw = "UN Chief Says There Is No Military Solution in Syria"
-        ids = self.tokenizer.batch_encode_plus([raw])["input_ids"][0]
-        expected_result = [0, 8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2]
-        # TODO(SS): should be [8274, ..., 2, 250020]
-        self.assertListEqual(expected_result, ids)
-
     def test_mbart_fast_forward(self):
         config = BartConfig(
             vocab_size=99,
@@ -301,6 +294,36 @@ class BartTranslationTests(unittest.TestCase):
         self.assertEqual(logits.shape, expected_shape)
 
 
+@require_torch
+class MBartTokenizerTests(MBartIntegrationTests):
+    def test_enro_tokenizer_prepare_translation_batch(self):
+        batch = self.tokenizer.prepare_translation_batch(
+            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
+        )
+        self.assertIsInstance(batch, BatchEncoding)
+
+        self.assertEqual((2, 14), batch.input_ids.shape)
+        self.assertEqual((2, 14), batch.attention_mask.shape)
+        result = batch.input_ids.tolist()[0]
+        self.assertListEqual(self.expected_src_tokens, result)
+        self.assertEqual(2, batch.decoder_input_ids[0, -2])  # EOS
+
+    def test_enro_tokenizer_batch_encode_plus(self):
+        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
+        self.assertListEqual(self.expected_src_tokens, ids)
+
+    def test_enro_tokenizer_truncation(self):
+        src_text = ["this is gunna be a long sentence " * 20]
+        assert isinstance(src_text[0], str)
+        desired_max_length = 10
+        ids = self.tokenizer.prepare_translation_batch(
+            src_text, return_tensors=None, max_length=desired_max_length
+        ).input_ids[0]
+        self.assertEqual(ids[-2], 2)
+        self.assertEqual(ids[-1], EN_CODE)
+        self.assertEqual(len(ids), desired_max_length)
+
+
 @require_torch
 class BartHeadTests(unittest.TestCase):
     vocab_size = 99
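
For reviewers, a minimal end-to-end sketch of the new API, stitched together from the docstring example and test_enro_generate. It assumes the facebook/mbart-large-en-ro checkpoint is downloadable; num_beams=5 is carried over from the old test and is not required.

    import torch
    from transformers import BartForConditionalGeneration, MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

    # Encoder inputs end with </s> (id 2) followed by the source language code
    # (250004 for en_XX); see build_inputs_with_special_tokens above.
    batch = tokenizer.prepare_translation_batch(
        [" UN Chief Says There Is No Military Solution in Syria"], src_lang="en_XX", tgt_lang="ro_RO"
    )

    with torch.no_grad():
        translated = model.generate(input_ids=batch["input_ids"], num_beams=5)
    print(tokenizer.batch_decode(translated, skip_special_tokens=True)[0])
    # Expected: "Şeful ONU declară că nu există o soluţie militară în Siria"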
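
The new tokenizer tests all hinge on the suffix layout, which can be sanity-checked without downloading anything (the ids below are the ones from expected_src_tokens and lang_code_to_id above):

    EOS = 2           # </s>
    EN_CODE = 250004  # lang_code_to_id["en_XX"]

    # build_inputs_with_special_tokens appends [eos, cur_lang_code], so source
    # sequences come out as <tokens> </s> <src_lang_code> ...
    src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712]
    assert src_tokens + [EOS, EN_CODE] == [
        8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, 250004
    ]
    # ... and prepare_translation_batch swaps cur_lang_code to the target code before
    # encoding tgt_texts, so decoder_input_ids end with </s> <tgt_lang_code> instead.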
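
One note on the test refactor: the hand-rolled _model caching is replaced by cached_property from transformers.file_utils, so the checkpoint is only downloaded when a @slow test actually touches self.model (and is cast to fp16 on CUDA). A minimal sketch of the pattern, with a hypothetical Demo class standing in for the test case:

    from transformers.file_utils import cached_property

    class Demo:
        @cached_property
        def model(self):
            # Stand-in for the expensive from_pretrained() call; runs once per instance.
            print("loading...")
            return object()

    d = Demo()
    assert d.model is d.model  # second access hits the cache, no reload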