Mirror of https://github.com/huggingface/transformers.git
AutoTokenizer supports mbart-large-en-ro (#5121)
parent: 2db1e2f415
commit: 84be482f66
src/transformers/configuration_auto.py

@@ -19,7 +19,7 @@ import logging
 from collections import OrderedDict

 from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
-from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
+from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, MBartConfig
 from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
 from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
 from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
@@ -80,6 +80,7 @@ CONFIG_MAPPING = OrderedDict(
         ("camembert", CamembertConfig,),
         ("xlm-roberta", XLMRobertaConfig,),
         ("marian", MarianConfig,),
+        ("mbart", MBartConfig,),
         ("bart", BartConfig,),
         ("reformer", ReformerConfig,),
         ("longformer", LongformerConfig,),
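Why the new entry sits above ("bart", ...): at this point in the library, AutoConfig resolved a checkpoint by scanning CONFIG_MAPPING in insertion order and taking the first key that occurs as a substring of the model name, so "mbart" must be tried before "bart" or a name like "facebook/mbart-large-en-ro" would match the plain BART config first. A minimal sketch of that lookup, where the resolve helper and the two-entry mapping are illustrative stand-ins rather than the library's actual code:

    from collections import OrderedDict

    # Illustrative, ordered subset of CONFIG_MAPPING; strings stand in for config classes.
    CONFIG_MAPPING = OrderedDict([("mbart", "MBartConfig"), ("bart", "BartConfig")])

    def resolve(model_name_or_path):
        # First substring match wins, which is why "mbart" must precede "bart".
        for pattern, config_class in CONFIG_MAPPING.items():
            if pattern in model_name_or_path:
                return config_class
        raise ValueError(f"Unrecognized model identifier: {model_name_or_path}")

    print(resolve("facebook/mbart-large-en-ro"))  # -> MBartConfig, not BartConfig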
src/transformers/configuration_bart.py

@@ -133,3 +133,7 @@ class BartConfig(PretrainedConfig):
         if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
             logger.info("This configuration is a mixture of MBART and BART settings")
             return False
+
+
+class MBartConfig(BartConfig):
+    model_type = "mbart"
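MBartConfig introduces no new hyperparameters; overriding model_type is all the Auto machinery needs to tell an mBART checkpoint apart from a plain BART one. A hedged usage sketch (assumes network access to the model hub):

    from transformers import AutoConfig

    # After this change, the "mbart" key resolves before "bart", so the
    # returned object should be an MBartConfig rather than a BartConfig.
    config = AutoConfig.from_pretrained("facebook/mbart-large-en-ro")
    print(type(config).__name__)  # expected: MBartConfig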
src/transformers/tokenization_auto.py

@@ -30,6 +30,7 @@ from .configuration_auto import (
     FlaubertConfig,
     GPT2Config,
     LongformerConfig,
+    MBartConfig,
     OpenAIGPTConfig,
     ReformerConfig,
     RetriBertConfig,
@@ -43,7 +44,7 @@ from .configuration_auto import (
 from .configuration_marian import MarianConfig
 from .configuration_utils import PretrainedConfig
 from .tokenization_albert import AlbertTokenizer
-from .tokenization_bart import BartTokenizer
+from .tokenization_bart import BartTokenizer, MBartTokenizer
 from .tokenization_bert import BertTokenizer, BertTokenizerFast
 from .tokenization_bert_japanese import BertJapaneseTokenizer
 from .tokenization_camembert import CamembertTokenizer
@@ -75,6 +76,7 @@ TOKENIZER_MAPPING = OrderedDict(
         (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
         (AlbertConfig, (AlbertTokenizer, None)),
         (CamembertConfig, (CamembertTokenizer, None)),
+        (MBartConfig, (MBartTokenizer, None)),
         (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
         (MarianConfig, (MarianTokenizer, None)),
         (BartConfig, (BartTokenizer, None)),
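With the config and tokenizer mappings both registered, AutoTokenizer can dispatch on the resolved config class; the None in the new mapping entry means no fast tokenizer is registered for mBART here. A hedged usage sketch (assumes network access to the model hub):

    from transformers import AutoTokenizer

    # AutoTokenizer first resolves the checkpoint to MBartConfig, then looks that
    # config class up in TOKENIZER_MAPPING to pick the tokenizer class.
    tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    print(type(tokenizer).__name__)  # expected: MBartTokenizer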
tests/test_modeling_bart.py

@@ -31,6 +31,7 @@ if is_torch_available():
     from transformers import (
         AutoModel,
         AutoModelForSequenceClassification,
+        AutoModelForSeq2SeqLM,
         AutoTokenizer,
         BartModel,
         BartForConditionalGeneration,
@@ -38,7 +39,6 @@ if is_torch_available():
         BartForQuestionAnswering,
         BartConfig,
         BartTokenizer,
-        MBartTokenizer,
         BatchEncoding,
         pipeline,
     )
@@ -218,15 +218,14 @@ class MBartIntegrationTests(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         checkpoint_name = "facebook/mbart-large-en-ro"
-        cls.tokenizer = MBartTokenizer.from_pretrained(checkpoint_name)
+        cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
         cls.pad_token_id = 1
         return cls

     @cached_property
     def model(self):
         """Only load the model if needed."""
-
-        model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
+        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
         if "cuda" in torch_device:
             model = model.half()
         return model
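One detail worth noting in the test: the model is loaded lazily via cached_property and cast to half precision only when a CUDA device is available, since fp16 inference is generally unsupported or very slow on CPU. A minimal sketch of the same guard (the Linear layer and the device pick are illustrative, mirroring the test's torch_device convention):

    import torch

    # Illustrative device string matching the test's torch_device convention.
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model():
        model = torch.nn.Linear(4, 4).to(torch_device)
        if "cuda" in torch_device:
            # fp16 halves memory and speeds up GPU inference; skip it on CPU.
            model = model.half()
        return model

    print(next(load_model().parameters()).dtype)  # float16 on GPU, float32 on CPU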