MBartForConditionalGeneration (#6441)

* add MBartForConditionalGeneration

* style

* rebase and fixes

* add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS

* fix docs

* don't ignore mbart

* doc

* fix mbart fairseq link

* put mbart before bart

* apply doc suggestions
Suraj Patil 2020-08-14 12:51:16 +05:30 committed by GitHub
parent 05810cd80a
commit 680f1337c3
14 changed files with 410 additions and 283 deletions

View File

@ -126,7 +126,9 @@ conversion utilities for the following models:
Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
23. `Pegasus <https://github.com/google-research/pegasus>`_ (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization
<https://arxiv.org/abs/1912.08777>`_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
24. `Other community models <https://huggingface.co/models>`_, contributed by the `community
24. `MBart <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`_ (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation <https://arxiv.org/abs/2001.08210>`_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov,
Marjan Ghazvininejad, Mike Lewis and Luke Zettlemoyer.
25. `Other community models <https://huggingface.co/models>`_, contributed by the `community
<https://huggingface.co/users>`_.
.. toctree::
@ -208,6 +210,7 @@ conversion utilities for the following models:
model_doc/mobilebert
model_doc/dpr
model_doc/pegasus
model_doc/mbart
internal/modeling_utils
internal/tokenization_utils
internal/pipelines_utils

View File

@ -49,13 +49,6 @@ BartTokenizer
:members:
MBartTokenizer
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartTokenizer
:members: build_inputs_with_special_tokens, prepare_seq2seq_batch
BartModel
~~~~~~~~~~~~~

View File

@ -0,0 +1,37 @@
MBart
----------------------------------------------------
**DISCLAIMER:** If you see something strange,
file a `GitHub Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@sshleifer.
Overview
~~~~~~~~~~~~~~~~~~~~~
The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation <https://arxiv.org/abs/2001.08210>`_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov,
Marjan Ghazvininejad, Mike Lewis and Luke Zettlemoyer. According to the abstract,
MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. mBART is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages, while previous approaches have focused only on the encoder, decoder, or reconstructing parts of the text.
The authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`__.
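A minimal English-to-Romanian translation sketch with the classes documented below (the exact output depends on the
released ``facebook/mbart-large-en-ro`` checkpoint and its generation settings)::

    >>> from transformers import MBartForConditionalGeneration, MBartTokenizer
    >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
    >>> article = "UN Chief Says There Is No Military Solution in Syria"
    >>> # prepare_seq2seq_batch appends [eos, src_lang_code] to the source tokens for us
    >>> batch = tokenizer.prepare_seq2seq_batch([article], src_lang="en_XX", tgt_lang="ro_RO")
    >>> translated_tokens = model.generate(**batch)
    >>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)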
MBartConfig
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartConfig
:members:
MBartTokenizer
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartTokenizer
:members: build_inputs_with_special_tokens, prepare_seq2seq_batch
MBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartForConditionalGeneration
:members: generate, forward

View File

@ -22,7 +22,7 @@ import logging
# Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from .configuration_bart import BartConfig, MBartConfig
from .configuration_bart import BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
@ -34,6 +34,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBartConfig
from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
@ -131,7 +132,7 @@ from .pipelines import (
# Tokenizers
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_camembert import CamembertTokenizer
@ -149,6 +150,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_mbart import MBartTokenizer
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
@ -298,6 +300,7 @@ if is_torch_available():
BartForQuestionAnswering,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_mbart import MBartForConditionalGeneration
from .modeling_marian import MarianMTModel
from .tokenization_marian import MarianTokenizer
from .modeling_roberta import (

View File

@ -19,7 +19,7 @@ import logging
from collections import OrderedDict
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
@ -52,6 +53,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
for pretrained_map in [
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,

View File

@ -32,6 +32,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
"yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}
BART_CONFIG_ARGS_DOC = r"""
Args:
vocab_size (:obj:`int`, optional, defaults to 50265):
@ -209,8 +210,3 @@ class BartConfig(PretrainedConfig):
if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
logger.info("This configuration is a mixture of MBART and BART settings")
return False
class MBartConfig(BartConfig):
model_type = "mbart"
"""See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""

View File

@ -0,0 +1,32 @@
# coding=utf-8
# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MBART configuration """
import logging
from .configuration_bart import BartConfig
logger = logging.getLogger(__name__)
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
"facebook/mbart-large-cc25": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-cc25/config.json",
}
class MBartConfig(BartConfig):
model_type = "mbart"
"""See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""

View File

@ -32,6 +32,7 @@ from .configuration_auto import (
FlaubertConfig,
GPT2Config,
LongformerConfig,
MBartConfig,
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
@ -116,6 +117,7 @@ from .modeling_longformer import (
LongformerModel,
)
from .modeling_marian import MarianMTModel
from .modeling_mbart import MBartForConditionalGeneration
from .modeling_mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
@ -289,6 +291,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
(T5Config, T5ForConditionalGeneration),
(PegasusConfig, PegasusForConditionalGeneration),
(MarianConfig, MarianMTModel),
(MBartConfig, MBartForConditionalGeneration),
(BartConfig, BartForConditionalGeneration),
(EncoderDecoderConfig, EncoderDecoderModel),
]
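Note that ``MBartConfig`` is registered before ``BartConfig``: at the time of this commit the auto classes walk this
``OrderedDict`` and pick the first entry whose config class matches via ``isinstance``, and since ``MBartConfig``
subclasses ``BartConfig``, listing BART first would shadow the MBart entry. A rough sketch of the dispatch, using a
tiny config (sizes borrowed from the new test) so nothing needs to be downloaded:

from transformers import AutoModelForSeq2SeqLM, MBartConfig, MBartForConditionalGeneration

config = MBartConfig(
    vocab_size=99, d_model=24,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = AutoModelForSeq2SeqLM.from_config(config)
assert isinstance(model, MBartForConditionalGeneration)  # not a plain BartForConditionalGeneration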

View File

@ -0,0 +1,38 @@
from .configuration_mbart import MBartConfig
from .file_utils import add_start_docstrings
from .modeling_bart import BartForConditionalGeneration
_CONFIG_FOR_DOC = "MBartConfig"
_TOKENIZER_FOR_DOC = "MBartTokenizer"
MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/mbart-large-cc25",
"facebook/mbart-large-en-ro",
# See all multilingual BART models at https://huggingface.co/models?filter=mbart
]
MBART_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
usage and behavior.
Parameters:
config (:class:`~transformers.MBartConfig`): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for machine translation.", MBART_START_DOCSTRING
)
class MBartForConditionalGeneration(BartForConditionalGeneration):
"""
This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class = MBartConfig
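Because only ``config_class`` is overridden, ``forward`` and ``generate`` behave exactly as in
``BartForConditionalGeneration``. A rough fine-tuning-style forward pass with the batches produced by
``MBartTokenizer.prepare_seq2seq_batch`` might look as follows (the label-shifting convention is borrowed from the
seq2seq example scripts and is an assumption here, not something defined in this module):

from transformers import MBartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
batch = tokenizer.prepare_seq2seq_batch(
    ["UN Chief Says There Is No Military Solution in Syria"],
    src_lang="en_XX",
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    tgt_lang="ro_RO",
)
target_ids = batch["decoder_input_ids"]
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=target_ids[:, :-1],  # teacher forcing: predict token t+1 from tokens up to t
    labels=target_ids[:, 1:],
)
loss = outputs[0]  # the inherited BART forward puts the LM loss first when labels are given
loss.backward()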

View File

@ -46,7 +46,7 @@ from .configuration_auto import (
)
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_camembert import CamembertTokenizer
@ -57,6 +57,7 @@ from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_marian import MarianTokenizer
from .tokenization_mbart import MBartTokenizer
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer

View File

@ -14,13 +14,8 @@
# limitations under the License.
import logging
from typing import List, Optional
from .file_utils import add_start_docstrings_to_callable
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .tokenization_xlm_roberta import XLMRobertaTokenizer
logger = logging.getLogger(__name__)
@ -55,258 +50,3 @@ class BartTokenizerFast(RobertaTokenizerFast):
"vocab_file": {m: vocab_url for m in _all_bart_models},
"merges_file": {m: merges_url for m in _all_bart_models},
}
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
FAIRSEQ_LANGUAGE_CODES = [
"ar_AR",
"cs_CZ",
"de_DE",
"en_XX",
"es_XX",
"et_EE",
"fi_FI",
"fr_XX",
"gu_IN",
"hi_IN",
"it_IT",
"ja_XX",
"kk_KZ",
"ko_KR",
"lt_LT",
"lv_LV",
"my_MM",
"ne_NP",
"nl_XX",
"ro_RO",
"ru_RU",
"si_LK",
"tr_TR",
"vi_VN",
"zh_CN",
]
class MBartTokenizer(XLMRobertaTokenizer):
"""
This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs.
Other tokenizer methods like ``encode`` do not work properly.
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
``<language code> <tokens> <eos>`` for target language documents.
Examples::
>>> from transformers import MBartTokenizer
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
... )
"""
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
prefix_tokens: List[int] = []
suffix_tokens: List[int] = []
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sp_model_size = len(self.sp_model)
self.lang_code_to_id = {
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
}
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
self.cur_lang_code = self.lang_code_to_id["en_XX"]
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on the language set via
``set_src_lang_special_tokens`` / ``set_tgt_lang_special_tokens``.
An MBART sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used.
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` methods.
Args:
token_ids_0 (:obj:`List[int]`):
List of ids.
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Set to True if the token list is already formatted with special tokens for the model
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1] * len(self.suffix_tokens)
if token_ids_1 is None:
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
def prepare_seq2seq_batch(
self,
src_texts: List[str],
src_lang: str = "en_XX",
tgt_texts: Optional[List[str]] = None,
tgt_lang: str = "ro_RO",
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
truncation: bool = True,
padding: str = "longest",
return_tensors: str = "pt",
**kwargs,
) -> BatchEncoding:
"""Prepare a batch that can be passed directly to an instance of MBartModel.
Arguments:
src_texts (:obj:`list`):
List of documents to summarize or source language texts.
src_lang (:obj:`str`, `optional`, defaults to :obj:`'en_XX'`):
The language we are translating from (English).
tgt_texts (:obj:`list`, `optional`):
List of target language texts or summaries.
tgt_lang (:obj:`str`, `optional`, defaults to :obj:`'ro_RO'`):
The language we are translating to (Romanian).
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:`None`, this will use the :obj:`max_length` value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a
single sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batches with
sequence lengths greater than the model maximum admissible input size).
Return:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``
will only be returned if :obj:`tgt_texts` is passed. Otherwise, only ``input_ids`` and ``attention_mask`` will be returned.
"""
if max_length is None:
max_length = self.max_len
self.set_src_lang_special_tokens(src_lang)
model_inputs: BatchEncoding = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
self.set_tgt_lang_special_tokens(tgt_lang)
decoder_inputs: BatchEncoding = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=True,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
return model_inputs
def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]

View File

@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional
from .file_utils import add_start_docstrings_to_callable
from .tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .tokenization_xlm_roberta import XLMRobertaTokenizer
logger = logging.getLogger(__name__)
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
FAIRSEQ_LANGUAGE_CODES = [
"ar_AR",
"cs_CZ",
"de_DE",
"en_XX",
"es_XX",
"et_EE",
"fi_FI",
"fr_XX",
"gu_IN",
"hi_IN",
"it_IT",
"ja_XX",
"kk_KZ",
"ko_KR",
"lt_LT",
"lv_LV",
"my_MM",
"ne_NP",
"nl_XX",
"ro_RO",
"ru_RU",
"si_LK",
"tr_TR",
"vi_VN",
"zh_CN",
]
class MBartTokenizer(XLMRobertaTokenizer):
"""
This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs.
Other tokenizer methods like ``encode`` do not work properly.
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
``<language code> <tokens> <eos>`` for target language documents.
Examples::
>>> from transformers import MBartTokenizer
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
... )
"""
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}
prefix_tokens: List[int] = []
suffix_tokens: List[int] = []
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sp_model_size = len(self.sp_model)
self.lang_code_to_id = {
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
}
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
self.cur_lang_code = self.lang_code_to_id["en_XX"]
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on the language set via
``set_src_lang_special_tokens`` / ``set_tgt_lang_special_tokens``.
An MBART sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used.
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` methods.
Args:
token_ids_0 (:obj:`List[int]`):
List of ids.
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Set to True if the token list is already formatted with special tokens for the model
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1] * len(self.suffix_tokens)
if token_ids_1 is None:
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
def prepare_seq2seq_batch(
self,
src_texts: List[str],
src_lang: str = "en_XX",
tgt_texts: Optional[List[str]] = None,
tgt_lang: str = "ro_RO",
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
truncation: bool = True,
padding: str = "longest",
return_tensors: str = "pt",
**kwargs,
) -> BatchEncoding:
"""Prepare a batch that can be passed directly to an instance of MBartModel.
Arguments:
src_texts (:obj:`list`):
List of documents to summarize or source language texts.
src_lang (:obj:`str`, `optional`, defaults to :obj:`'en_XX'`):
The language we are translating from (English).
tgt_texts (:obj:`list`, `optional`):
List of target language texts or summaries.
tgt_lang (:obj:`str`, `optional`, defaults to :obj:`'ro_RO'`):
The language we are translating to (Romanian).
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:`None`, this will use the :obj:`max_length` value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a
single sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batches with
sequence lengths greater than the model maximum admissible input size).
Return:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``
will only be returned if :obj:`tgt_texts` is passed. Otherwise, only ``input_ids`` and ``attention_mask`` will be returned.
"""
if max_length is None:
max_length = self.max_len
self.set_src_lang_special_tokens(src_lang)
model_inputs: BatchEncoding = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
self.set_tgt_lang_special_tokens(tgt_lang)
decoder_inputs: BatchEncoding = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=True,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
return model_inputs
def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
def set_tgt_lang_special_tokens(self, lang: str) -> None:
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]
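An illustration of the source/target layouts documented in ``build_inputs_with_special_tokens`` (the concrete ids come
from the SentencePiece model; only the layout matters here):

from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
word_ids = tok.convert_tokens_to_ids(tok.tokenize("UN Chief Says There Is No Military Solution in Syria"))

tok.set_src_lang_special_tokens("en_XX")
assert tok.build_inputs_with_special_tokens(word_ids) == word_ids + [tok.eos_token_id, tok.lang_code_to_id["en_XX"]]

tok.set_tgt_lang_special_tokens("ro_RO")
assert tok.build_inputs_with_special_tokens(word_ids) == [tok.lang_code_to_id["ro_RO"]] + word_ids + [tok.eos_token_id]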

View File

@ -11,8 +11,8 @@ if is_torch_available():
import torch
from transformers import (
AutoModelForSeq2SeqLM,
BartConfig,
BartForConditionalGeneration,
MBartConfig,
MBartForConditionalGeneration,
BatchEncoding,
AutoTokenizer,
)
@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
mbart_models = ["facebook/mbart-large-en-ro"]
expected = {"scale_embedding": True, "output_past": True}
for name in mbart_models:
config = BartConfig.from_pretrained(name)
config = MBartConfig.from_pretrained(name)
self.assertTrue(config.is_valid_mbart())
for k, v in expected.items():
try:
@ -102,7 +102,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
raise
def test_mbart_fast_forward(self):
config = BartConfig(
config = MBartConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
@ -115,7 +115,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
add_final_layer_norm=True,
return_dict=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
lm_model = MBartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)

View File

@ -30,6 +30,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
"test_modeling_tf_xlm_roberta.py",
"test_modeling_xlm_roberta.py",
"test_modeling_pegasus.py",
"test_modeling_mbart.py",
]
# Update this list for models that are not documented with a comment explaining the reason it should not be.
@ -106,7 +107,6 @@ def get_model_test_files():
"test_modeling_common",
"test_modeling_encoder_decoder",
"test_modeling_marian",
"test_modeling_mbart",
"test_modeling_tf_common",
]
test_files = []