MBartForConditionalGeneration (#6441)
* add MBartForConditionalGeneration
* style
* rebase and fixes
* add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS
* fix docs
* don't ignore mbart
* doc
* fix mbart fairseq link
* put mbart before bart
* apply doc suggestions
Parent: 05810cd80a
Commit: 680f1337c3
@@ -126,7 +126,9 @@ conversion utilities for the following models:

    Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
23. `Pegasus <https://github.com/google-research/pegasus>`_ (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization
    <https://arxiv.org/abs/1912.08777>`_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
24. `Other community models <https://huggingface.co/models>`_, contributed by the `community
24. `MBart <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`_ (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation <https://arxiv.org/abs/2001.08210>`_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov,
    Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
25. `Other community models <https://huggingface.co/models>`_, contributed by the `community
    <https://huggingface.co/users>`_.

.. toctree::

@@ -208,6 +210,7 @@ conversion utilities for the following models:

    model_doc/mobilebert
    model_doc/dpr
    model_doc/pegasus
    model_doc/mbart
    internal/modeling_utils
    internal/tokenization_utils
    internal/pipelines_utils

@@ -49,13 +49,6 @@ BartTokenizer

    :members:


MBartTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MBartTokenizer
    :members: build_inputs_with_special_tokens, prepare_seq2seq_batch



BartModel
~~~~~~~~~~~~~

docs/source/model_doc/mbart.rst (new file, 37 lines)
@@ -0,0 +1,37 @@
MBart
----------------------------------------------------

**DISCLAIMER:** If you see something strange,
file a `GitHub Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@sshleifer


Overview
~~~~~~~~~~~~~~~~~~~~~

The MBart model was presented in `Multilingual Denoising Pre-training for Neural Machine Translation <https://arxiv.org/abs/2001.08210>`_ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov,
Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. According to the abstract,

MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. mBART is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages, while previous approaches have focused only on the encoder, decoder, or reconstructing parts of the text.

The authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/mbart>`__


MBartConfig
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MBartConfig
    :members:


MBartTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MBartTokenizer
    :members: build_inputs_with_special_tokens, prepare_seq2seq_batch


MBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MBartForConditionalGeneration
    :members: generate, forward

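The documentation page above names every public entry point added by this commit. A minimal end-to-end translation sketch built only from those entry points (the checkpoint name, ``prepare_seq2seq_batch`` and ``generate`` all appear in this diff; generation settings are left at their defaults)::

    # Hedged usage sketch: English -> Romanian translation with the new classes.
    from transformers import MBartForConditionalGeneration, MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

    batch = tokenizer.prepare_seq2seq_batch(
        ["UN Chief Says There Is No Military Solution in Syria"],
        src_lang="en_XX",
        tgt_lang="ro_RO",
        return_tensors="pt",
    )
    generated_ids = model.generate(
        input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
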
@@ -22,7 +22,7 @@ import logging

# Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from .configuration_bart import BartConfig, MBartConfig
from .configuration_bart import BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig

@@ -34,6 +34,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBartConfig
from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig

@@ -131,7 +132,7 @@ from .pipelines import (

# Tokenizers
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_camembert import CamembertTokenizer

@@ -149,6 +150,7 @@ from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_mbart import MBartTokenizer
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer

@@ -298,6 +300,7 @@ if is_torch_available():
        BartForQuestionAnswering,
        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_mbart import MBartForConditionalGeneration
    from .modeling_marian import MarianMTModel
    from .tokenization_marian import MarianTokenizer
    from .modeling_roberta import (
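With the import shuffle above, the three MBart entry points are exposed at the package root (the model only when PyTorch is available). A quick sanity-check sketch; the module paths simply follow the new files introduced in this commit::

    from transformers import MBartConfig, MBartForConditionalGeneration, MBartTokenizer

    print(MBartConfig.__module__)                    # transformers.configuration_mbart
    print(MBartTokenizer.__module__)                 # transformers.tokenization_mbart
    print(MBartForConditionalGeneration.__module__)  # transformers.modeling_mbart
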
@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict

from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig

@@ -30,6 +30,7 @@ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Flau
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig

@@ -52,6 +53,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    for pretrained_map in [
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,

@@ -32,6 +32,7 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
    "yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}

BART_CONFIG_ARGS_DOC = r"""
    Args:
        vocab_size (:obj:`int`, optional, defaults to 50265):

@@ -209,8 +210,3 @@ class BartConfig(PretrainedConfig):
        if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
            logger.info("This configuration is a mixture of MBART and BART settings")
        return False


class MBartConfig(BartConfig):
    model_type = "mbart"
    """See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""


src/transformers/configuration_mbart.py (new file, 32 lines)
@@ -0,0 +1,32 @@
# coding=utf-8
# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MBART configuration """


import logging

from .configuration_bart import BartConfig


logger = logging.getLogger(__name__)

MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
    "facebook/mbart-large-cc25": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-cc25/config.json",
}


class MBartConfig(BartConfig):
    model_type = "mbart"
    """See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""

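``MBartConfig`` adds nothing beyond its ``model_type`` and its own archive map; the mBART-specific settings are the flags checked by ``BartConfig.is_valid_mbart`` in the hunk above. A hedged inspection sketch (the expected attribute values mirror the ``expected`` dict in the integration test further down; treat the assertions as expectations, not guarantees)::

    from transformers import MBartConfig

    config = MBartConfig.from_pretrained("facebook/mbart-large-en-ro")
    assert config.model_type == "mbart"
    assert config.is_valid_mbart()  # normalize_before, add_final_layer_norm, scale_embedding all set
    assert config.scale_embedding and config.output_past
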
@@ -32,6 +32,7 @@ from .configuration_auto import (
    FlaubertConfig,
    GPT2Config,
    LongformerConfig,
    MBartConfig,
    MobileBertConfig,
    OpenAIGPTConfig,
    PegasusConfig,

@@ -116,6 +117,7 @@ from .modeling_longformer import (
    LongformerModel,
)
from .modeling_marian import MarianMTModel
from .modeling_mbart import MBartForConditionalGeneration
from .modeling_mobilebert import (
    MobileBertForMaskedLM,
    MobileBertForMultipleChoice,

@@ -289,6 +291,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
        (T5Config, T5ForConditionalGeneration),
        (PegasusConfig, PegasusForConditionalGeneration),
        (MarianConfig, MarianMTModel),
        (MBartConfig, MBartForConditionalGeneration),
        (BartConfig, BartForConditionalGeneration),
        (EncoderDecoderConfig, EncoderDecoderModel),
    ]

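Because ``MBartConfig`` is registered ahead of ``BartConfig`` in the mapping above (the commit message's "put mbart before bart"), the Auto classes should now resolve mBART checkpoints to the new classes. A hedged sketch of that dispatch; the expected types assume the "mbart" pattern wins over "bart" during resolution::

    from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

    config = AutoConfig.from_pretrained("facebook/mbart-large-en-ro")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro")
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")

    print(type(config).__name__)     # expected: MBartConfig
    print(type(model).__name__)      # expected: MBartForConditionalGeneration
    print(type(tokenizer).__name__)  # expected: MBartTokenizer
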

src/transformers/modeling_mbart.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from .configuration_mbart import MBartConfig
from .file_utils import add_start_docstrings
from .modeling_bart import BartForConditionalGeneration


_CONFIG_FOR_DOC = "MBartConfig"
_TOKENIZER_FOR_DOC = "MBartTokenizer"

MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/mbart-large-cc25",
    "facebook/mbart-large-en-ro",
    # See all multilingual BART models at https://huggingface.co/models?filter=mbart
]

MBART_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.MBartConfig`): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for machine translation.", MBART_START_DOCSTRING
)
class MBartForConditionalGeneration(BartForConditionalGeneration):
    """
    This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = MBartConfig

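``MBartForConditionalGeneration`` is a thin subclass that only swaps in ``MBartConfig``, so it can be exercised exactly like BART. A tiny randomly initialized forward-pass sketch, modelled on ``test_mbart_fast_forward`` further down; the dimensions not shown in that hunk (decoder layers, attention heads, feed-forward sizes) are assumptions::

    import torch

    from transformers import MBartConfig, MBartForConditionalGeneration

    config = MBartConfig(
        vocab_size=99,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,            # assumption, not shown in the test hunk
        encoder_attention_heads=2,   # assumption
        decoder_attention_heads=2,   # assumption
        encoder_ffn_dim=32,          # assumption
        decoder_ffn_dim=32,          # assumption
        add_final_layer_norm=True,
        return_dict=True,
    )
    model = MBartForConditionalGeneration(config)

    context = torch.tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]])
    summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]])
    outputs = model(input_ids=context, decoder_input_ids=summary)
    assert outputs.logits.shape == (*summary.shape, config.vocab_size)
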
@@ -46,7 +46,7 @@ from .configuration_auto import (
)
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast, MBartTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_camembert import CamembertTokenizer

@@ -57,6 +57,7 @@ from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_marian import MarianTokenizer
from .tokenization_mbart import MBartTokenizer
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer

@@ -14,13 +14,8 @@
# limitations under the License.

import logging
from typing import List, Optional

from .file_utils import add_start_docstrings_to_callable
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .tokenization_xlm_roberta import XLMRobertaTokenizer


logger = logging.getLogger(__name__)

@@ -55,258 +50,3 @@ class BartTokenizerFast(RobertaTokenizerFast):
    "vocab_file": {m: vocab_url for m in _all_bart_models},
    "merges_file": {m: merges_url for m in _all_bart_models},
}

(The remainder of this hunk deletes the `_all_mbart_models` / `SPM_URL` / `FAIRSEQ_LANGUAGE_CODES` constants and the entire `MBartTokenizer` class from tokenization_bart.py; the identical code is added in the new file src/transformers/tokenization_mbart.py, shown in full below.)

src/transformers/tokenization_mbart.py (new file, 279 lines)
@@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Optional

from .file_utils import add_start_docstrings_to_callable
from .tokenization_utils import BatchEncoding
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .tokenization_xlm_roberta import XLMRobertaTokenizer


logger = logging.getLogger(__name__)

_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"

FAIRSEQ_LANGUAGE_CODES = [
    "ar_AR",
    "cs_CZ",
    "de_DE",
    "en_XX",
    "es_XX",
    "et_EE",
    "fi_FI",
    "fr_XX",
    "gu_IN",
    "hi_IN",
    "it_IT",
    "ja_XX",
    "kk_KZ",
    "ko_KR",
    "lt_LT",
    "lv_LV",
    "my_MM",
    "ne_NP",
    "nl_XX",
    "ro_RO",
    "ru_RU",
    "si_LK",
    "tr_TR",
    "vi_VN",
    "zh_CN",
]


class MBartTokenizer(XLMRobertaTokenizer):
    """
    This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs.
    Other tokenizer methods like ``encode`` do not work properly.
    The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
    ``<language code> <tokens> <eos>`` for target language documents.

    Examples::

        >>> from transformers import MBartTokenizer
        >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
        >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
        >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
        >>> batch: dict = tokenizer.prepare_seq2seq_batch(
        ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
        ... )

    """

    vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
    max_model_input_sizes = {m: 1024 for m in _all_mbart_models}
    pretrained_vocab_files_map = {"vocab_file": {m: SPM_URL for m in _all_mbart_models}}

    prefix_tokens: List[int] = []
    suffix_tokens: List[int] = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.sp_model_size = len(self.sp_model)
        self.lang_code_to_id = {
            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
        }
        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
        self.cur_lang_code = self.lang_code_to_id["en_XX"]
        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset

        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
        self._additional_special_tokens = list(self.lang_code_to_id.keys())
        self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens. The special tokens depend on calling set_lang.
        An MBART sequence has the following format, where ``X`` represents the sequence:

        - ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
        - ``decoder_input_ids`` (for decoder) ``[tgt_lang_code] X [eos]``

        BOS is never used.
        Pairs of sequences are not the expected use case, but they will be handled without a separator.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` methods.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1] * len(self.suffix_tokens)
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        src_lang: str = "en_XX",
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "ro_RO",
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        truncation: bool = True,
        padding: str = "longest",
        return_tensors: str = "pt",
        **kwargs,
    ) -> BatchEncoding:
        """Prepare a batch that can be passed directly to an instance of MBartModel.

        Arguments:
            src_texts (:obj:`list`):
                List of documents to summarize or source language texts.
            src_lang (:obj:`str`, `optional`, defaults to :obj:`'en_XX'`):
                The language we are translating from (English by default).
            tgt_texts (:obj:`list`, `optional`):
                List of target language texts or summaries.
            tgt_lang (:obj:`str`, `optional`, defaults to :obj:`'ro_RO'`):
                The language we are translating to (Romanian by default).
            max_length (:obj:`int`, `optional`):
                Controls the maximum length for encoder inputs (documents to summarize or source language texts).
                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
                length is required by one of the truncation/padding parameters. If the model has no specific maximum
                input length (like XLNet) truncation/padding to a maximum length will be deactivated.
            max_target_length (:obj:`int`, `optional`):
                Controls the maximum length of decoder inputs (target language texts or summaries).
                If left unset or set to :obj:`None`, this will use the max_length value.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
                Activates and controls padding. Accepts the following values:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
                  different lengths).
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
                Activates and controls truncation. Accepts the following values:

                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
                  if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
                  the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
                  to the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batch with
                  sequence lengths greater than the model maximum admissible input size).

        Return:
            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:

            - **input_ids** -- List of token ids to be fed to the encoder.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
              This does not include the causal mask, which is built by the model.

            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``
            will only be returned if tgt_texts is passed. Otherwise, input_ids and attention_mask will be the only keys.

        """
        if max_length is None:
            max_length = self.max_len
        self.set_src_lang_special_tokens(src_lang)
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            **kwargs,
        )
        if tgt_texts is None:
            return model_inputs
        # Process tgt_texts
        if max_target_length is None:
            max_target_length = max_length
        self.set_tgt_lang_special_tokens(tgt_lang)
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=True,
            **kwargs,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v

        self.set_src_lang_special_tokens(src_lang)  # sets to src_lang
        return model_inputs

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. Prefix=[tgt_lang_code], suffix=[eos]."""
        self.cur_lang_code = self.lang_code_to_id[lang]
        self.prefix_tokens = [self.cur_lang_code]
        self.suffix_tokens = [self.eos_token_id]
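To make the prefix/suffix bookkeeping above concrete, a hedged illustration of the layout ``prepare_seq2seq_batch`` is documented to produce (the example strings are the ones from the class docstring; the assertions restate the documented layout rather than verified output)::

    from transformers import MBartTokenizer

    tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    batch = tok.prepare_seq2seq_batch(
        ["UN Chief Says There Is No Military Solution in Syria"],
        src_lang="en_XX",
        tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
        tgt_lang="ro_RO",
    )

    src_ids = batch["input_ids"][0].tolist()
    tgt_ids = batch["decoder_input_ids"][0].tolist()
    # source: <tokens> <eos> <src_lang_code>
    assert src_ids[-2:] == [tok.eos_token_id, tok.lang_code_to_id["en_XX"]]
    # target: <tgt_lang_code> <tokens> <eos>
    assert tgt_ids[0] == tok.lang_code_to_id["ro_RO"] and tgt_ids[-1] == tok.eos_token_id
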
@@ -11,8 +11,8 @@ if is_torch_available():
    import torch
    from transformers import (
        AutoModelForSeq2SeqLM,
        BartConfig,
        BartForConditionalGeneration,
        MBartConfig,
        MBartForConditionalGeneration,
        BatchEncoding,
        AutoTokenizer,
    )

@@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
        mbart_models = ["facebook/mbart-large-en-ro"]
        expected = {"scale_embedding": True, "output_past": True}
        for name in mbart_models:
            config = BartConfig.from_pretrained(name)
            config = MBartConfig.from_pretrained(name)
            self.assertTrue(config.is_valid_mbart())
            for k, v in expected.items():
                try:

@@ -102,7 +102,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
                    raise

    def test_mbart_fast_forward(self):
        config = BartConfig(
        config = MBartConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,

@@ -115,7 +115,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
            add_final_layer_norm=True,
            return_dict=True,
        )
        lm_model = BartForConditionalGeneration(config).to(torch_device)
        lm_model = MBartForConditionalGeneration(config).to(torch_device)
        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
        result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
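``test_mbart_fast_forward`` above passes ``labels`` straight to the model, which is also the shape of a fine-tuning step. A hedged sketch of one training step on the pretrained checkpoint (the optimizer, learning rate, and simplistic label handling are illustrative, not part of this commit)::

    import torch

    from transformers import MBartForConditionalGeneration, MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    batch = tokenizer.prepare_seq2seq_batch(
        ["UN Chief Says There Is No Military Solution in Syria"],
        tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
        src_lang="en_XX",
        tgt_lang="ro_RO",
    )
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=batch["decoder_input_ids"],
        labels=batch["decoder_input_ids"],  # simplistic, mirrors the fast-forward test; real training shifts targets
    )
    loss = outputs[0]  # first element is the LM loss when labels are supplied
    loss.backward()
    optimizer.step()
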
@@ -30,6 +30,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
    "test_modeling_tf_xlm_roberta.py",
    "test_modeling_xlm_roberta.py",
    "test_modeling_pegasus.py",
    "test_modeling_mbart.py",
]

# Update this list for models that are not documented with a comment explaining the reason it should not be.

@@ -106,7 +107,6 @@ def get_model_test_files():
        "test_modeling_common",
        "test_modeling_encoder_decoder",
        "test_modeling_marian",
        "test_modeling_mbart",
        "test_modeling_tf_common",
    ]
    test_files = []