From 680f1337c3c894b681ebbd24c18122a61e33715b Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 14 Aug 2020 12:51:16 +0530 Subject: [PATCH] MBartForConditionalGeneration (#6441) * add MBartForConditionalGeneration * style * rebase and fixes * add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS * fix docs * don't ignore mbart * doc * fix mbart fairseq link * put mbart before bart * apply doc suggestions --- docs/source/index.rst | 5 +- docs/source/model_doc/bart.rst | 7 - docs/source/model_doc/mbart.rst | 37 ++++ src/transformers/__init__.py | 7 +- src/transformers/configuration_auto.py | 4 +- src/transformers/configuration_bart.py | 6 +- src/transformers/configuration_mbart.py | 32 +++ src/transformers/modeling_auto.py | 3 + src/transformers/modeling_mbart.py | 38 ++++ src/transformers/tokenization_auto.py | 3 +- src/transformers/tokenization_bart.py | 260 ---------------------- src/transformers/tokenization_mbart.py | 279 ++++++++++++++++++++++++ tests/test_modeling_mbart.py | 10 +- utils/check_repo.py | 2 +- 14 files changed, 410 insertions(+), 283 deletions(-) create mode 100644 docs/source/model_doc/mbart.rst create mode 100644 src/transformers/configuration_mbart.py create mode 100644 src/transformers/modeling_mbart.py create mode 100644 src/transformers/tokenization_mbart.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 6e2cc51c229..2d1385af0f0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -126,7 +126,9 @@ conversion utilities for the following models: Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 23. `Pegasus `_ (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -24. `Other community models `_, contributed by the `community +24. `MBart `_ (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov + Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. +25. `Other community models `_, contributed by the `community `_. .. toctree:: @@ -208,6 +210,7 @@ conversion utilities for the following models: model_doc/mobilebert model_doc/dpr model_doc/pegasus + model_doc/mbart internal/modeling_utils internal/tokenization_utils internal/pipelines_utils diff --git a/docs/source/model_doc/bart.rst b/docs/source/model_doc/bart.rst index 81d138232dd..42318003cd2 100644 --- a/docs/source/model_doc/bart.rst +++ b/docs/source/model_doc/bart.rst @@ -49,13 +49,6 @@ BartTokenizer :members: -MBartTokenizer -~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.MBartTokenizer - :members: build_inputs_with_special_tokens, prepare_seq2seq_batch - - BartModel ~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst new file mode 100644 index 00000000000..7305fce9412 --- /dev/null +++ b/docs/source/model_doc/mbart.rst @@ -0,0 +1,37 @@ +MBart +---------------------------------------------------- +**DISCLAIMER:** If you see something strange, +file a `Github Issue `__ and assign +@sshleifer + +Overview +~~~~~~~~~~~~~~~~~~~~~ +The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation `_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov +Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 
According to the abstract, + +MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. mBART is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages, while previous approaches have focused only on the encoder, decoder, or reconstructing parts of the text. + +The Authors' code can be found `here `__ + + +MBartConfig +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartConfig + :members: + + +MBartTokenizer +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartTokenizer + :members: build_inputs_with_special_tokens, prepare_seq2seq_batch + + +MBartForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartForConditionalGeneration + :members: generate, forward + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b9ddf890ebe..781737eaf05 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -22,7 +22,7 @@ import logging # Configurations from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig -from .configuration_bart import BartConfig, MBartConfig +from .configuration_bart import BartConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig @@ -34,6 +34,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig from .configuration_marian import MarianConfig +from .configuration_mbart import MBartConfig from .configuration_mmbt import MMBTConfig from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig @@ -131,7 +132,7 @@ from .pipelines import ( # Tokenizers from .tokenization_albert import AlbertTokenizer from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer -from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer +from .tokenization_bart import BartTokenizer, BartTokenizerFast from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer from .tokenization_camembert import CamembertTokenizer @@ -149,6 +150,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast +from .tokenization_mbart import MBartTokenizer from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_pegasus import PegasusTokenizer @@ -298,6 +300,7 @@ if is_torch_available(): BartForQuestionAnswering, BART_PRETRAINED_MODEL_ARCHIVE_LIST, ) + from .modeling_mbart import MBartForConditionalGeneration from .modeling_marian 
import MarianMTModel from .tokenization_marian import MarianTokenizer from .modeling_roberta import ( diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 1b1d32a019d..62090e931ad 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -19,7 +19,7 @@ import logging from collections import OrderedDict from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig -from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig +from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig @@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig from .configuration_marian import MarianConfig +from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig from .configuration_mobilebert import MobileBertConfig from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from .configuration_pegasus import PegasusConfig @@ -52,6 +53,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( for pretrained_map in [ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BART_PRETRAINED_CONFIG_ARCHIVE_MAP, + MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 42d69ba4464..8edcb4e69ce 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -32,6 +32,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", "yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json", } + BART_CONFIG_ARGS_DOC = r""" Args: vocab_size (:obj:`int`, optional, defaults to 50265): @@ -209,8 +210,3 @@ class BartConfig(PretrainedConfig): if self.normalize_before or self.add_final_layer_norm or self.scale_embedding: logger.info("This configuration is a mixture of MBART and BART settings") return False - - -class MBartConfig(BartConfig): - model_type = "mbart" - """See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json.""" diff --git a/src/transformers/configuration_mbart.py b/src/transformers/configuration_mbart.py new file mode 100644 index 00000000000..a01fef69143 --- /dev/null +++ b/src/transformers/configuration_mbart.py @@ -0,0 +1,32 @@ +# coding=utf-8 +# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MBART configuration """ + +import logging + +from .configuration_bart import BartConfig + + +logger = logging.getLogger(__name__) + +MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", + "facebook/mbart-large-cc25": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-cc25/config.json", +} + + +class MBartConfig(BartConfig): + model_type = "mbart" + """See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json.""" diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 7a45267703a..02088f565c8 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -32,6 +32,7 @@ from .configuration_auto import ( FlaubertConfig, GPT2Config, LongformerConfig, + MBartConfig, MobileBertConfig, OpenAIGPTConfig, PegasusConfig, @@ -116,6 +117,7 @@ from .modeling_longformer import ( LongformerModel, ) from .modeling_marian import MarianMTModel +from .modeling_mbart import MBartForConditionalGeneration from .modeling_mobilebert import ( MobileBertForMaskedLM, MobileBertForMultipleChoice, @@ -289,6 +291,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( (T5Config, T5ForConditionalGeneration), (PegasusConfig, PegasusForConditionalGeneration), (MarianConfig, MarianMTModel), + (MBartConfig, MBartForConditionalGeneration), (BartConfig, BartForConditionalGeneration), (EncoderDecoderConfig, EncoderDecoderModel), ] diff --git a/src/transformers/modeling_mbart.py b/src/transformers/modeling_mbart.py new file mode 100644 index 00000000000..60e47fe3152 --- /dev/null +++ b/src/transformers/modeling_mbart.py @@ -0,0 +1,38 @@ +from .configuration_mbart import MBartConfig +from .file_utils import add_start_docstrings +from .modeling_bart import BartForConditionalGeneration + + +_CONFIG_FOR_DOC = "MBartConfig" +_TOKENIZER_FOR_DOC = "MBartTokenizer" + +MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/mbart-large-cc25", + "facebook/mbart-large-en-ro", + # See all multilingual BART models at https://huggingface.co/models?filter=mbart +] + +MBART_START_DOCSTRING = r""" + + This model is a PyTorch `torch.nn.Module `__ sub-class. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.MBartConfig`): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + + +@add_start_docstrings( + "The BART Model with a language modeling head. Can be used for machine translation.", MBART_START_DOCSTRING +) +class MBartForConditionalGeneration(BartForConditionalGeneration): + """ + This class overrides :class:`~transformers.BartForConditionalGeneration`. 
Please check the + superclass for the appropriate documentation alongside usage examples. + """ + + config_class = MBartConfig diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index cdede96cf5b..8fee16f1bc9 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -46,7 +46,7 @@ from .configuration_auto import ( ) from .configuration_utils import PretrainedConfig from .tokenization_albert import AlbertTokenizer -from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer +from .tokenization_bart import BartTokenizer, BartTokenizerFast from .tokenization_bert import BertTokenizer, BertTokenizerFast from .tokenization_bert_japanese import BertJapaneseTokenizer from .tokenization_camembert import CamembertTokenizer @@ -57,6 +57,7 @@ from .tokenization_flaubert import FlaubertTokenizer from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast from .tokenization_marian import MarianTokenizer +from .tokenization_mbart import MBartTokenizer from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast from .tokenization_pegasus import PegasusTokenizer diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 2348ce86d66..e72d4963b27 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -14,13 +14,8 @@ # limitations under the License. import logging -from typing import List, Optional -from .file_utils import add_start_docstrings_to_callable from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast -from .tokenization_utils import BatchEncoding -from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING -from .tokenization_xlm_roberta import XLMRobertaTokenizer logger = logging.getLogger(__name__) @@ -55,258 +50,3 @@ class BartTokenizerFast(RobertaTokenizerFast): "vocab_file": {m: vocab_url for m in _all_bart_models}, "merges_file": {m: merges_url for m in _all_bart_models}, } - - -_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"] -SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model" - -FAIRSEQ_LANGUAGE_CODES = [ - "ar_AR", - "cs_CZ", - "de_DE", - "en_XX", - "es_XX", - "et_EE", - "fi_FI", - "fr_XX", - "gu_IN", - "hi_IN", - "it_IT", - "ja_XX", - "kk_KZ", - "ko_KR", - "lt_LT", - "lv_LV", - "my_MM", - "ne_NP", - "nl_XX", - "ro_RO", - "ru_RU", - "si_LK", - "tr_TR", - "vi_VN", - "zh_CN", -] - - -class MBartTokenizer(XLMRobertaTokenizer): - """ - This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs. - Other tokenizer methods like ``encode`` do not work properly. - The tokenization method is `` `` for source language documents, and - `` ``` for target language documents. - - Examples:: - - >>> from transformers import MBartTokenizer - >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro') - >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" - >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria" - >>> batch: dict = tokenizer.prepare_seq2seq_batch( - ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian - ... 
) - - """ - - vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"} - max_model_input_sizes = {m: 1024 for m in _all_mbart_models} - pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}} - - prefix_tokens: List[int] = [] - suffix_tokens: List[int] = [] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.sp_model_size = len(self.sp_model) - self.lang_code_to_id = { - code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES) - } - self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()} - self.cur_lang_code = self.lang_code_to_id["en_XX"] - self.fairseq_tokens_to_ids[""] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset - - self.fairseq_tokens_to_ids.update(self.lang_code_to_id) - self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} - self._additional_special_tokens = list(self.lang_code_to_id.keys()) - self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks - by concatenating and adding special tokens. The special tokens depend on calling set_lang. - An MBART sequence has the following format, where ``X`` represents the sequence: - - ``input_ids`` (for encoder) ``X [eos, src_lang_code]`` - - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]`` - BOS is never used. - Pairs of sequences are not the expected use case, but they will be handled without a separator. - - Args: - token_ids_0 (:obj:`List[int]`): - List of IDs to which the special tokens will be added - token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): - Optional second list of IDs for sequence pairs. - - Returns: - :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. - """ - if token_ids_1 is None: - return self.prefix_tokens + token_ids_0 + self.suffix_tokens - # We don't expect to process pairs, but leave the pair logic for API consistency - return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: - """ - Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer ``prepare_for_model`` methods. - - Args: - token_ids_0 (:obj:`List[int]`): - List of ids. - token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): - Set to True if the token list is already formatted with special tokens for the model - - Returns: - :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - - if already_has_special_tokens: - if token_ids_1 is not None: - raise ValueError( - "You should not supply a second sequence if the provided sequence of " - "ids is already formated with special tokens for the model." 
- ) - return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) - prefix_ones = [1] * len(self.prefix_tokens) - suffix_ones = [1] * len(self.suffix_tokens) - if token_ids_1 is None: - return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones - return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones - - @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) - def prepare_seq2seq_batch( - self, - src_texts: List[str], - src_lang: str = "en_XX", - tgt_texts: Optional[List[str]] = None, - tgt_lang: str = "ro_RO", - max_length: Optional[int] = None, - max_target_length: Optional[int] = None, - truncation: bool = True, - padding: str = "longest", - return_tensors: str = "pt", - **kwargs, - ) -> BatchEncoding: - """Prepare a batch that can be passed directly to an instance of MBartModel. - - Arguments: - src_texts: (:obj:`list`): - list of documents to summarize or source language texts - src_lang: (:obj:`str`, `optional`, default='en_XX'): - default en_XX (english), the language we are translating from - tgt_texts: (:obj:`list`, `optional`): - list of tgt language texts or summaries. - tgt_lang: (:obj:`str`, `optional`, default='ro_RO'): - default ro_RO (romanian), the language we are translating to - max_length (:obj:`int`, `optional`): - Controls the maximum length for encoder inputs (documents to summarize or source language texts) - If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum - length is required by one of the truncation/padding parameters. If the model has no specific maximum - input length (like XLNet) truncation/padding to a maximum length will be deactivated. - max_target_length (:obj:`int`, `optional`): - Controls the maximum length of decoder inputs (target language texts or summaries) - If left unset or set to :obj:`None`, this will use the max_length value. - padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): - Activates and controls padding. Accepts the following values: - - * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a - single sequence if provided). - * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the - maximum acceptable input length for the model if that argument is not provided. - * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of - different lengths). - return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): - If set, will return tensors instead of list of python integers. Acceptable values are: - - * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. - * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. - * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. - truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): - Activates and controls truncation. Accepts the following values: - - * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument - :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not - provided. 
This will truncate token by token, removing a token from the longest sequence in the pair - if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to - the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or - to the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with - sequence lengths greater than the model maximum admissible input size). - - Return: - :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - - - **input_ids** -- List of token ids to be fed to the encoder. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **decoder_input_ids** -- List of token ids to be fed to the decoder. - - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. - This does not include causal mask, which is built by the model. - - The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, - will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. - - """ - if max_length is None: - max_length = self.max_len - self.set_src_lang_special_tokens(src_lang) - model_inputs: BatchEncoding = self( - src_texts, - add_special_tokens=True, - return_tensors=return_tensors, - max_length=max_length, - padding=padding, - truncation=truncation, - **kwargs, - ) - if tgt_texts is None: - return model_inputs - # Process tgt_texts - if max_target_length is None: - max_target_length = max_length - self.set_tgt_lang_special_tokens(tgt_lang) - decoder_inputs: BatchEncoding = self( - tgt_texts, - add_special_tokens=True, - return_tensors=return_tensors, - padding=padding, - max_length=max_target_length, - truncation=True, - **kwargs, - ) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v - - self.set_src_lang_special_tokens(src_lang) # sets to src_lang - return model_inputs - - def set_src_lang_special_tokens(self, src_lang) -> None: - """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code].""" - self.cur_lang_code = self.lang_code_to_id[src_lang] - self.prefix_tokens = [] - self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] - - def set_tgt_lang_special_tokens(self, lang: str) -> None: - """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos].""" - self.cur_lang_code = self.lang_code_to_id[lang] - self.prefix_tokens = [self.cur_lang_code] - self.suffix_tokens = [self.eos_token_id] diff --git a/src/transformers/tokenization_mbart.py b/src/transformers/tokenization_mbart.py new file mode 100644 index 00000000000..e74a765340f --- /dev/null +++ b/src/transformers/tokenization_mbart.py @@ -0,0 +1,279 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Optional + +from .file_utils import add_start_docstrings_to_callable +from .tokenization_utils import BatchEncoding +from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING +from .tokenization_xlm_roberta import XLMRobertaTokenizer + + +logger = logging.getLogger(__name__) + +_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"] +SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model" + +FAIRSEQ_LANGUAGE_CODES = [ + "ar_AR", + "cs_CZ", + "de_DE", + "en_XX", + "es_XX", + "et_EE", + "fi_FI", + "fr_XX", + "gu_IN", + "hi_IN", + "it_IT", + "ja_XX", + "kk_KZ", + "ko_KR", + "lt_LT", + "lv_LV", + "my_MM", + "ne_NP", + "nl_XX", + "ro_RO", + "ru_RU", + "si_LK", + "tr_TR", + "vi_VN", + "zh_CN", +] + + +class MBartTokenizer(XLMRobertaTokenizer): + """ + This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs. + Other tokenizer methods like ``encode`` do not work properly. + The tokenization method is `` `` for source language documents, and + `` ``` for target language documents. + + Examples:: + + >>> from transformers import MBartTokenizer + >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro') + >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" + >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria" + >>> batch: dict = tokenizer.prepare_seq2seq_batch( + ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian + ... ) + + """ + + vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"} + max_model_input_sizes = {m: 1024 for m in _all_mbart_models} + pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}} + + prefix_tokens: List[int] = [] + suffix_tokens: List[int] = [] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.sp_model_size = len(self.sp_model) + self.lang_code_to_id = { + code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES) + } + self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()} + self.cur_lang_code = self.lang_code_to_id["en_XX"] + self.fairseq_tokens_to_ids[""] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + + self.fairseq_tokens_to_ids.update(self.lang_code_to_id) + self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} + self._additional_special_tokens = list(self.lang_code_to_id.keys()) + self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. 
The special tokens depend on calling set_lang. + An MBART sequence has the following format, where ``X`` represents the sequence: + - ``input_ids`` (for encoder) ``X [eos, src_lang_code]`` + - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]`` + BOS is never used. + Pairs of sequences are not the expected use case, but they will be handled without a separator. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` methods. + + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + + @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) + def prepare_seq2seq_batch( + self, + src_texts: List[str], + src_lang: str = "en_XX", + tgt_texts: Optional[List[str]] = None, + tgt_lang: str = "ro_RO", + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + truncation: bool = True, + padding: str = "longest", + return_tensors: str = "pt", + **kwargs, + ) -> BatchEncoding: + """Prepare a batch that can be passed directly to an instance of MBartModel. + + Arguments: + src_texts: (:obj:`list`): + list of documents to summarize or source language texts + src_lang: (:obj:`str`, `optional`, default='en_XX'): + default en_XX (english), the language we are translating from + tgt_texts: (:obj:`list`, `optional`): + list of tgt language texts or summaries. 
+ tgt_lang: (:obj:`str`, `optional`, default='ro_RO'): + default ro_RO (romanian), the language we are translating to + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts) + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries) + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **decoder_input_ids** -- List of token ids to be fed to the decoder. 
+ - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. + This does not include causal mask, which is built by the model. + + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + + """ + if max_length is None: + max_length = self.max_len + self.set_src_lang_special_tokens(src_lang) + model_inputs: BatchEncoding = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + self.set_tgt_lang_special_tokens(tgt_lang) + decoder_inputs: BatchEncoding = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=True, + **kwargs, + ) + for k, v in decoder_inputs.items(): + model_inputs[f"decoder_{k}"] = v + + self.set_src_lang_special_tokens(src_lang) # sets to src_lang + return model_inputs + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code].""" + self.cur_lang_code = self.lang_code_to_id[src_lang] + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos].""" + self.cur_lang_code = self.lang_code_to_id[lang] + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 65ab6c1fb9b..810ceae74e6 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -11,8 +11,8 @@ if is_torch_available(): import torch from transformers import ( AutoModelForSeq2SeqLM, - BartConfig, - BartForConditionalGeneration, + MBartConfig, + MBartForConditionalGeneration, BatchEncoding, AutoTokenizer, ) @@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): mbart_models = ["facebook/mbart-large-en-ro"] expected = {"scale_embedding": True, "output_past": True} for name in mbart_models: - config = BartConfig.from_pretrained(name) + config = MBartConfig.from_pretrained(name) self.assertTrue(config.is_valid_mbart()) for k, v in expected.items(): try: @@ -102,7 +102,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): raise def test_mbart_fast_forward(self): - config = BartConfig( + config = MBartConfig( vocab_size=99, d_model=24, encoder_layers=2, @@ -115,7 +115,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): add_final_layer_norm=True, return_dict=True, ) - lm_model = BartForConditionalGeneration(config).to(torch_device) + lm_model = MBartForConditionalGeneration(config).to(torch_device) context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary) diff --git a/utils/check_repo.py b/utils/check_repo.py index afc5abd9d69..9a3154ec313 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -30,6 +30,7 @@ 
TEST_FILES_WITH_NO_COMMON_TESTS = [ "test_modeling_tf_xlm_roberta.py", "test_modeling_xlm_roberta.py", "test_modeling_pegasus.py", + "test_modeling_mbart.py", ] # Update this list for models that are not documented with a comment explaining the reason it should not be. @@ -106,7 +107,6 @@ def get_model_test_files(): "test_modeling_common", "test_modeling_encoder_decoder", "test_modeling_marian", - "test_modeling_mbart", "test_modeling_tf_common", ] test_files = []
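
Usage sketch: the snippet below strings together the pieces this patch adds — ``MBartConfig``, ``MBartTokenizer.prepare_seq2seq_batch`` and ``MBartForConditionalGeneration`` — for English-to-Romanian translation. It is a minimal illustration, not an official example: it assumes the ``facebook/mbart-large-en-ro`` checkpoint is reachable and that ``generate`` is inherited unchanged from ``BartForConditionalGeneration``, with the decoder start token taken from the checkpoint's config::

    >>> from transformers import MBartForConditionalGeneration, MBartTokenizer

    >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

    >>> src_text = "UN Chief Says There Is No Military Solution in Syria"
    >>> # prepare_seq2seq_batch appends [eos, src_lang_code] to the source text and
    >>> # returns PyTorch tensors by default (return_tensors="pt")
    >>> batch = tokenizer.prepare_seq2seq_batch([src_text], src_lang="en_XX", tgt_lang="ro_RO")

    >>> # the decoder start token is assumed to come from the checkpoint's config; it can
    >>> # also be passed explicitly, e.g. decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"]
    >>> translated_tokens = model.generate(**batch)
    >>> print(tokenizer.batch_decode(translated_tokens, skip_special_tokens=True))

Note on the design: since ``MBartForConditionalGeneration`` only overrides ``config_class`` and ``MBartConfig`` only overrides ``model_type``, the MBart-specific behaviour in this patch lives almost entirely in ``MBartTokenizer`` (the language codes and the eos/lang-code prefix/suffix logic); the encoder-decoder weights and forward pass are BART's.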