From 2804fff8393dbda5098b8c9f5e36235e89c50023 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Thu, 6 Aug 2020 14:58:38 -0400
Subject: [PATCH] [s2s] Use prepare_translation_batch for Marian finetuning
 (#6293)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 examples/seq2seq/README.md                |  2 +-
 examples/seq2seq/finetune.py              | 10 +++++-----
 examples/seq2seq/test_seq2seq_examples.py | 13 ++++++++-----
 examples/seq2seq/utils.py                 |  4 +++-
 src/transformers/tokenization_marian.py   |  5 +++++
 5 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md
index dd026784169..63b5b078204 100644
--- a/examples/seq2seq/README.md
+++ b/examples/seq2seq/README.md
@@ -63,7 +63,7 @@ Summarization Tips:
 (It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
 
 **Update 2020-07-18**
-Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
+Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_translation_batch` method. For those that do (like Marian, MBart), `TranslationDataset` should be used.
 
 A new dataset is needed to support multilingual tasks.
diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index c7138295460..702d71ba579 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -14,7 +14,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
+from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
 
 
 try:
@@ -32,7 +32,7 @@ try:
         ROUGE_KEYS,
         calculate_bleu_score,
         Seq2SeqDataset,
-        MBartDataset,
+        TranslationDataset,
         label_smoothed_nll_loss,
     )
@@ -40,7 +40,7 @@ try:
 except ImportError:
     from utils import (
         Seq2SeqDataset,
-        MBartDataset,
+        TranslationDataset,
         assert_all_frozen,
         use_task_specific_params,
         lmap,
@@ -108,8 +108,8 @@ class SummarizationModule(BaseTransformer):
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
             self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        if isinstance(self.tokenizer, MBartTokenizer):
-            self.dataset_class = MBartDataset
+        if isinstance(self.tokenizer, (MBartTokenizer, MarianTokenizer)):
+            self.dataset_class = TranslationDataset
         else:
             self.dataset_class = Seq2SeqDataset
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 7473e0a64bd..06719446d6c 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -14,14 +14,14 @@ from pytest import param
 from torch.utils.data import DataLoader
 
 import lightning_base
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from transformers.testing_utils import require_multigpu
 
 from .distillation import distill_main, evaluate_checkpoint
 from .finetune import SummarizationModule, main
 from .pack_dataset import pack_data_dir
 from .run_eval import generate_summaries_or_translations, run_generate
-from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
+from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -406,8 +406,9 @@ def test_pack_dataset():
     assert orig_paths == new_paths
 
 
-def test_mbart_dataset_truncation():
-    tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
+@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
+def test_mbart_dataset_truncation(tok_name):
+    tokenizer = AutoTokenizer.from_pretrained(tok_name)
     tmp_dir = make_test_data_dir()
     max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
     max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
@@ -416,7 +417,7 @@ def test_mbart_dataset_truncation():
     assert max_len_target > max_src_len  # Truncated
     assert max_len_source > max_src_len
     src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
-    train_dataset = MBartDataset(
+    train_dataset = TranslationDataset(
         tokenizer,
         data_dir=tmp_dir,
         type_path="train",
@@ -433,6 +434,8 @@ def test_mbart_dataset_truncation():
         assert batch["input_ids"].shape[1] == max_src_len
         # show that targets are the same len
         assert batch["decoder_input_ids"].shape[1] == max_tgt_len
+        if tok_name == MARIAN_TINY:
+            continue  # Marian has no language codes, so the checks below do not apply
         # check language codes in correct place
         assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
         assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 7d9288333c9..1c13c0aa28e 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -146,7 +146,9 @@ class Seq2SeqDataset(Dataset):
         return SortishSampler(self.src_lens, batch_size)
 
 
-class MBartDataset(Seq2SeqDataset):
+class TranslationDataset(Seq2SeqDataset):
+    """A dataset that calls `prepare_translation_batch`."""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.max_source_length != self.max_target_length:
diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py
index 46ff3ff457c..211dfda8a2b 100644
--- a/src/transformers/tokenization_marian.py
+++ b/src/transformers/tokenization_marian.py
@@ -127,10 +127,12 @@ class MarianTokenizer(PreTrainedTokenizer):
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
         pad_to_max_length: bool = True,
         return_tensors: str = "pt",
         truncation_strategy="only_first",
         padding="longest",
+        **unused,
     ) -> BatchEncoding:
         """Prepare model inputs for translation. For best performance, translate one sentence at a time.
         Arguments:
@@ -162,6 +164,9 @@ class MarianTokenizer(PreTrainedTokenizer):
         if tgt_texts is None:
             return model_inputs
 
+        if max_target_length is not None:
+            tokenizer_kwargs["max_length"] = max_target_length
+
         self.current_spm = self.spm_target
         decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
         for k, v in decoder_inputs.items():
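
A minimal usage sketch of the new `max_target_length` argument follows. It is illustrative only, not part of the patch: the checkpoint name and example sentences are assumptions, and the loose shape assertions avoid depending on the exact padding behavior of this version.

# Sketch: max_target_length caps the target side independently of max_length,
# which caps the source side (see the tokenization_marian.py hunk above).
from transformers import MarianTokenizer

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
batch = tokenizer.prepare_translation_batch(
    src_texts=["I am a small frog and I like to hop around the pond."],
    tgt_texts=["Ich bin ein kleiner Frosch und huepfe gern um den Teich."],
    max_length=8,          # source ids are truncated to at most 8 tokens
    max_target_length=4,   # new: target ids are truncated to at most 4 tokens
)
# Target tensors come back under decoder_-prefixed keys, as the test hunk shows.
assert batch["input_ids"].shape[1] <= 8
assert batch["decoder_input_ids"].shape[1] <= 4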