AutoTokenizer supports mbart-large-en-ro (#5121)

commit 84be482f66 (parent 2db1e2f415)
Sam Shleifer, 2020-06-18 20:47:37 -04:00, committed by GitHub
4 changed files with 12 additions and 6 deletions

src/transformers/configuration_auto.py

@@ -19,7 +19,7 @@ import logging
 from collections import OrderedDict
 
 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
-from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
+from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
 from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
 from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
 from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
@@ -80,6 +80,7 @@ CONFIG_MAPPING = OrderedDict(
         ("camembert", CamembertConfig,),
         ("xlm-roberta", XLMRobertaConfig,),
         ("marian", MarianConfig,),
+        ("mbart", MBartConfig,),
         ("bart", BartConfig,),
         ("reformer", ReformerConfig,),
         ("longformer", LongformerConfig,),

src/transformers/configuration_bart.py

@@ -133,3 +133,7 @@ class BartConfig(PretrainedConfig):
         if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
             logger.info("This configuration is a mixture of MBART and BART settings")
         return False
+
+
+class MBartConfig(BartConfig):
+    model_type = "mbart"
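
MBartConfig adds no new fields; it only overrides model_type so the Auto mappings can tell mBART checkpoints apart from plain BART ones. A small sketch of how that behaves (MBartConfigSketch is a stand-in name, not the library class):

    from transformers import BartConfig

    class MBartConfigSketch(BartConfig):
        model_type = "mbart"

    cfg = MBartConfigSketch()
    assert cfg.model_type == "mbart"
    assert isinstance(cfg, BartConfig)  # every BartConfig attribute is inherited unchanged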

src/transformers/tokenization_auto.py

@@ -30,6 +30,7 @@ from .configuration_auto import (
     FlaubertConfig,
     GPT2Config,
     LongformerConfig,
+    MBartConfig,
     OpenAIGPTConfig,
     ReformerConfig,
     RetriBertConfig,
@@ -43,7 +44,7 @@ from .configuration_auto import (
 from .configuration_marian import MarianConfig
 from .configuration_utils import PretrainedConfig
 from .tokenization_albert import AlbertTokenizer
-from .tokenization_bart import BartTokenizer
+from .tokenization_bart import BartTokenizer, MBartTokenizer
 from .tokenization_bert import BertTokenizer, BertTokenizerFast
 from .tokenization_bert_japanese import BertJapaneseTokenizer
 from .tokenization_camembert import CamembertTokenizer
@@ -75,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
         (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
         (AlbertConfig, (AlbertTokenizer, None)),
         (CamembertConfig, (CamembertTokenizer, None)),
+        (MBartConfig, (MBartTokenizer, None)),
         (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
         (MarianConfig, (MarianTokenizer, None)),
         (BartConfig, (BartTokenizer, None)),
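
TOKENIZER_MAPPING keys on the config class and stores a (slow tokenizer, fast tokenizer) pair; mBART only registers the sentencepiece-based MBartTokenizer, so the fast slot stays None. In this version the lookup walks the OrderedDict and matches by isinstance, which is why the MBartConfig entry also sits above BartConfig, its parent class. A hedged check of the end result (it downloads the checkpoint's config and vocabulary files):

    from transformers import AutoConfig, AutoTokenizer, MBartTokenizer

    config = AutoConfig.from_pretrained("facebook/mbart-large-en-ro")
    assert type(config).__name__ == "MBartConfig"

    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    assert isinstance(tokenizer, MBartTokenizer)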

tests/test_modeling_bart.py

@@ -31,6 +31,7 @@ if is_torch_available():
     from transformers import (
         AutoModel,
         AutoModelForSequenceClassification,
+        AutoModelForSeq2SeqLM,
         AutoTokenizer,
         BartModel,
         BartForConditionalGeneration,
@@ -38,7 +39,6 @@ if is_torch_available():
         BartForQuestionAnswering,
         BartConfig,
         BartTokenizer,
-        MBartTokenizer,
         BatchEncoding,
         pipeline,
     )
@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         checkpoint_name = "facebook/mbart-large-en-ro"
-        cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
+        cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
         cls.pad_token_id = 1
         return cls
 
     @cached_property
     def model(self):
         """Only load the model if needed."""
-        model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
+        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
         if "cuda" in torch_device:
             model = model.half()
         return model
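
Taken together, the four files mean a caller never has to name MBartTokenizer or BartForConditionalGeneration for this checkpoint. A hedged usage sketch mirroring the updated test (it downloads the full checkpoint; the input sentence is illustrative, and translation quality still depends on mBART's language-code handling, which this change does not touch):

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")      # resolves to MBartTokenizer
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro")  # BART-style seq2seq model

    # encode/generate/decode are the generic tokenizer and model APIs, nothing mBART-specific.
    input_ids = tokenizer.encode("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(input_ids, num_beams=5, early_stopping=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))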