[examples/translation] support mBART-50 and M2M100 fine-tuning (#11170)
* keep a list of multilingual tokenizers
* add forced_bos_token argument
This commit is contained in:
parent fb41f9f50c
commit c161dd56df
@@ -34,6 +34,9 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
+    M2M100Tokenizer,
+    MBart50Tokenizer,
+    MBart50TokenizerFast,
     MBartTokenizer,
     MBartTokenizerFast,
     Seq2SeqTrainer,
@@ -50,6 +53,9 @@ check_min_version("4.6.0.dev0")
 
 logger = logging.getLogger(__name__)
 
+# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
+MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]
+
 
 @dataclass
 class ModelArguments:
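As context for reviewers, here is a minimal sketch of how these multilingual tokenizers consume the src_lang and tgt_lang attributes; the checkpoint and language codes below are illustrative, not part of this diff:

from transformers import MBart50TokenizerFast

# Load an mBART-50 tokenizer; its language codes follow the xx_XX convention.
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# src_lang/tgt_lang must be set before encoding so the correct language
# tokens are added to the source and target sequences.
tokenizer.src_lang = "en_XX"
tokenizer.tgt_lang = "ro_RO"

inputs = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")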
@@ -191,6 +197,14 @@ class DataTrainingArguments:
     source_prefix: Optional[str] = field(
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
+    forced_bos_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`. "
+            "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
+            "needs to be the target language token."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
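To see the new flag end to end, a hedged sketch of parsing it with HfArgumentParser; it assumes run_translation.py is on the Python path so its DataTrainingArguments dataclass can be imported, and the dataset name and language codes are example values:

from transformers import HfArgumentParser
from run_translation import DataTrainingArguments  # assumed import path

parser = HfArgumentParser(DataTrainingArguments)
# Example argv for mBART-50 English -> Romanian fine-tuning.
(data_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--dataset_name", "wmt16",
        "--dataset_config_name", "ro-en",
        "--source_lang", "en_XX",
        "--target_lang", "ro_RO",
        "--forced_bos_token", "ro_RO",
    ]
)
print(data_args.forced_bos_token)  # "ro_RO"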
@@ -325,9 +339,6 @@ def main():
 
     # Set decoder_start_token_id
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
-        assert (
-            data_args.target_lang is not None and data_args.source_lang is not None
-        ), "mBart requires --target_lang and --source_lang"
         if isinstance(tokenizer, MBartTokenizer):
             model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
         else:
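For reference, the lang_code_to_id lookup used above maps a language code string to its token id; a small illustration (the checkpoint name is an example, not taken from this diff):

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
# The integer id of the ro_RO language token, used as decoder_start_token_id.
print(tokenizer.lang_code_to_id["ro_RO"])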
@@ -352,11 +363,21 @@ def main():
 
     # For translation we set the codes of our source and target languages (only useful for mBART, the others will
     # ignore those attributes).
-    if isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
-        if data_args.source_lang is not None:
-            tokenizer.src_lang = data_args.source_lang
-        if data_args.target_lang is not None:
-            tokenizer.tgt_lang = data_args.target_lang
+    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
+        assert data_args.target_lang is not None and data_args.source_lang is not None, (
+            f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
+            "--target_lang arguments."
+        )
+
+        tokenizer.src_lang = data_args.source_lang
+        tokenizer.tgt_lang = data_args.target_lang
+
+        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
+        # as the first generated token. We ask the user to explicitly provide this as the --forced_bos_token argument.
+        forced_bos_token_id = (
+            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
+        )
+        model.config.forced_bos_token_id = forced_bos_token_id
 
     # Get the language codes for input/target.
     source_lang = data_args.source_lang.split("_")[0]
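To illustrate why forced_bos_token_id matters at generation time, a hedged sketch (the checkpoint and language codes are example values, not part of this diff):

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "en_XX"

inputs = tokenizer("The UN chief says there is no military solution in Syria.", return_tensors="pt")
# Forcing ro_RO as the first generated token steers decoding into Romanian;
# without it, a many-to-many model does not know which language to produce.
generated = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))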