diff --git a/examples/seq2seq/run_seq2seq.py b/examples/seq2seq/run_seq2seq.py index d73f24877c1..0f481d49207 100755 --- a/examples/seq2seq/run_seq2seq.py +++ b/examples/seq2seq/run_seq2seq.py @@ -386,7 +386,7 @@ def main(): # For translation we set the codes of our source and target languages (only useful for mBART, the others will # ignore those attributes). - if data_args.task.startswith("translation"): + if data_args.task.startswith("translation") or isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): if data_args.source_lang is not None: tokenizer.src_lang = data_args.source_lang if data_args.target_lang is not None: