mBART

mBART is a multilingual machine translation model that pretrains the entire encoder-decoder translation model, unlike previous methods that only pretrained parts of it. The model is trained with a denoising objective that reconstructs the original text from a corrupted version, which lets mBART handle both the source language and the target language to translate into. mBART-50 is pretrained on 25 additional languages.
You can find all the original mBART checkpoints under the AI at Meta organization.
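The sketch below illustrates the denoising objective, assuming the paper's two noise functions (text infilling with Poisson-distributed span lengths and sentence permutation). It is illustrative only, not the actual pretraining code, and the hyperparameter values are assumptions.

```py
import numpy as np

rng = np.random.default_rng(0)

def corrupt(sentences, mask_token="<mask>", lam=3.5, mask_ratio=0.35):
    # Sentence permutation: shuffle the order of the input sentences.
    order = rng.permutation(len(sentences))
    tokens = " ".join(sentences[i] for i in order).split()

    # Text infilling: replace spans of Poisson(lam) length with a single <mask>.
    out, i, masked, budget = [], 0, 0, int(mask_ratio * len(tokens))
    while i < len(tokens):
        span = int(rng.poisson(lam))
        if masked < budget and span > 0 and rng.random() < 0.2:
            out.append(mask_token)
            i += span
            masked += span
        else:
            out.append(tokens[i])
            i += 1
    return " ".join(out)

original = ["UN Chief says there is no military solution in Syria.", "Peace talks continue."]
print(corrupt(original))  # the model is trained to reconstruct the original text
```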
Tip
Click on the mBART models in the right sidebar for more examples of applying mBART to different language tasks.
Note

The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
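A minimal sketch of that pattern, assuming a head mask of shape `(encoder_layers, encoder_attention_heads)` as expected by the model's forward pass:

```py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt", attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "en_XX"
inputs = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")

# 1.0 keeps a head, 0.0 disables it
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # e.g., disable the first head of the first encoder layer

# labels double as decoder inputs here just to run a full forward pass
outputs = model(**inputs, head_mask=head_mask, labels=inputs["input_ids"])
```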
The example below demonstrates how to translate text with [`Pipeline`] or the [`AutoModel`] class.
```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="translation",
    model="facebook/mbart-large-50-many-to-many-mmt",
    device=0,
    torch_dtype=torch.float16,
    src_lang="en_XX",
    tgt_lang="fr_XX",
)
print(pipeline("UN Chief Says There Is No Military Solution in Syria"))
```
```py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

article_en = "UN Chief Says There Is No Military Solution in Syria"
tokenizer.src_lang = "en_XX"
encoded_en = tokenizer(article_en, return_tensors="pt").to("cuda")
generated_tokens = model.generate(
    **encoded_en,
    forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"],
    cache_implementation="static",
)
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
```
Notes

- You can check the full list of language codes with `tokenizer.lang_code_to_id.keys()`.
- mBART requires a special language id token in the source and target text during training. The source text format is `X [eos, src_lang_code]` where `X` is the source text. The target text format is `[tgt_lang_code] X [eos]`. The `bos` token is never used. [`~PreTrainedTokenizerBase.__call__`] encodes the source text format passed as the first argument or with the `text` keyword. The target text format is passed with the `text_target` keyword (see the sketch after this list).
- Set the `decoder_start_token_id` to the target language id for mBART.

  ```py
  import torch
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

  model = AutoModelForSeq2SeqLM.from_pretrained(
      "facebook/mbart-large-en-ro",
      torch_dtype=torch.bfloat16,
      attn_implementation="sdpa",
      device_map="auto",
  )
  tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")

  article = "UN Chief Says There Is No Military Solution in Syria"
  inputs = tokenizer(article, return_tensors="pt").to("cuda")
  translated_tokens = model.generate(
      **inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"]
  )
  tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
  ```

- mBART-50 has a different text format. The language id token is used as the prefix for both the source and target text. The text format is `[lang_code] X [eos]` where `lang_code` is the source language id for the source text and the target language id for the target text, and `X` is the source or target text respectively.
- Set the `eos_token_id` as the `decoder_start_token_id` for mBART-50. The target language id is used as the first generated token by passing `forced_bos_token_id` to [`~GenerationMixin.generate`].

  ```py
  import torch
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

  model = AutoModelForSeq2SeqLM.from_pretrained(
      "facebook/mbart-large-50-many-to-many-mmt",
      torch_dtype=torch.bfloat16,
      attn_implementation="sdpa",
      device_map="auto",
  )
  tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

  article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."
  tokenizer.src_lang = "ar_AR"
  encoded_ar = tokenizer(article_ar, return_tensors="pt").to("cuda")
  generated_tokens = model.generate(
      **encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"]
  )
  tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
  ```
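To make the source and target formats above concrete, here is a minimal sketch of encoding a translation pair with the `text_target` keyword. The Romanian sentence is only an illustrative translation of the English example, not reference data.

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO"
)
print(sorted(tokenizer.lang_code_to_id.keys())[:5])  # inspect available language codes

src = "UN Chief Says There Is No Military Solution in Syria"
tgt = "Şeful ONU declară că nu există o soluţie militară în Siria"

# Source goes in as the first argument, target via `text_target`; the tokenizer
# appends the language id tokens in the formats described above.
batch = tokenizer(src, text_target=tgt, return_tensors="pt")
print(batch["input_ids"])  # X [eos, src_lang_code]
print(batch["labels"])     # target ids; the model shifts these so the decoder sees [tgt_lang_code] X [eos]
```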
MBartConfig
autodoc MBartConfig
MBartTokenizer
autodoc MBartTokenizer - build_inputs_with_special_tokens
MBartTokenizerFast
autodoc MBartTokenizerFast
MBart50Tokenizer
autodoc MBart50Tokenizer
MBart50TokenizerFast
autodoc MBart50TokenizerFast
MBartModel
autodoc MBartModel
MBartForConditionalGeneration
autodoc MBartForConditionalGeneration
MBartForQuestionAnswering
autodoc MBartForQuestionAnswering
MBartForSequenceClassification
autodoc MBartForSequenceClassification
MBartForCausalLM
autodoc MBartForCausalLM - forward
TFMBartModel
autodoc TFMBartModel - call
TFMBartForConditionalGeneration
autodoc TFMBartForConditionalGeneration - call
FlaxMBartModel
autodoc FlaxMBartModel - call - encode - decode
FlaxMBartForConditionalGeneration
autodoc FlaxMBartForConditionalGeneration - call - encode - decode
FlaxMBartForSequenceClassification
autodoc FlaxMBartForSequenceClassification - call - encode - decode
FlaxMBartForQuestionAnswering
autodoc FlaxMBartForQuestionAnswering - call - encode - decode