[examples/translation] support mBART-50 and M2M100 fine-tuning (#11170)

* keep a list of multilingual tokenizers

* add forced_bos_token argument
Author: Suraj Patil
Date: 2021-04-09 23:58:42 +05:30
Committed by: GitHub
parent fb41f9f50c
commit c161dd56df

@@ -34,6 +34,9 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
+    M2M100Tokenizer,
+    MBart50Tokenizer,
+    MBart50TokenizerFast,
     MBartTokenizer,
     MBartTokenizerFast,
     Seq2SeqTrainer,
@@ -50,6 +53,9 @@ check_min_version("4.6.0.dev0")
 
 logger = logging.getLogger(__name__)
 
+# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
+MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
+
 
 @dataclass
 class ModelArguments:
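For readers skimming the diff, a minimal sketch of what the src_lang/tgt_lang attributes on these tokenizers do: they control which language-code tokens are added around the encoded text. The checkpoint name and language codes below are illustrative, not part of this commit.

from transformers import MBart50TokenizerFast

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
tokenizer.src_lang = "en_XX"
tokenizer.tgt_lang = "ro_RO"

encoded = tokenizer("UN Chief Says There Is No Military Solution in Syria")
# mBART-50 style tokenizers prepend the source language code, so the
# first token of the encoded input is en_XX.
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"])[:3])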
@@ -191,6 +197,14 @@ class DataTrainingArguments:
     source_prefix: Optional[str] = field(
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
+    forced_bos_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`. "
+            "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
+            "needs to be the target language token."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
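A rough, self-contained illustration of how a --forced_bos_token value is later resolved to a token id through the tokenizer's lang_code_to_id mapping; the checkpoint and language code here are examples only.

from transformers import MBart50TokenizerFast

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
forced_bos_token = "ro_RO"  # what a user would pass on the command line
forced_bos_token_id = tokenizer.lang_code_to_id[forced_bos_token]
print(forced_bos_token, "->", forced_bos_token_id)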
@@ -325,9 +339,6 @@ def main():
     # Set decoder_start_token_id
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
-        assert (
-            data_args.target_lang is not None and data_args.source_lang is not None
-        ), "mBart requires --target_lang and --source_lang"
         if isinstance(tokenizer, MBartTokenizer):
             model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
         else:
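For context, a small standalone sketch of what the mBART branch above computes: the decoder is started with the target language code token. The checkpoint and target language are illustrative.

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
# Slow tokenizer path: look the code up in lang_code_to_id; the fast
# tokenizer path uses convert_tokens_to_ids instead.
decoder_start_token_id = tokenizer.lang_code_to_id["ro_RO"]
print(decoder_start_token_id)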
@@ -352,11 +363,21 @@
     # For translation we set the codes of our source and target languages (only useful for mBART, the others will
     # ignore those attributes).
-    if isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
-        if data_args.source_lang is not None:
-            tokenizer.src_lang = data_args.source_lang
-        if data_args.target_lang is not None:
-            tokenizer.tgt_lang = data_args.target_lang
+    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
+        assert data_args.target_lang is not None and data_args.source_lang is not None, (
+            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
+            "--target_lang arguments."
+        )
+
+        tokenizer.src_lang = data_args.source_lang
+        tokenizer.tgt_lang = data_args.target_lang
+
+        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
+        # as the first generated token. We ask the user to explicitly provide this via the --forced_bos_token argument.
+        forced_bos_token_id = (
+            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
+        )
+        model.config.forced_bos_token_id = forced_bos_token_id
 
     # Get the language codes for input/target.
     source_lang = data_args.source_lang.split("_")[0]
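Taken together, a hedged end-to-end sketch of what forcing the BOS token achieves at generation time; the checkpoint and language codes follow the public mBART-50 examples and are not part of this diff.

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

tokenizer.src_lang = "en_XX"
inputs = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")
# Force ro_RO as the first generated token so the model decodes into Romanian;
# this is exactly what wiring forced_bos_token_id into model.config does.
generated = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))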