[s2s]Use prepare_translation_batch for Marian finetuning (#6293)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Sam Shleifer 2020-08-06 14:58:38 -04:00 committed by GitHub
parent 2f2aa0c89c
commit 2804fff839
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 22 additions and 12 deletions

View File

@ -63,7 +63,7 @@ Summarization Tips:
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
**Update 2018-07-18**
Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_translation_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.**
A new dataset is needed to support multilingual tasks.

View File

@ -14,7 +14,7 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
try:
@ -32,7 +32,7 @@ try:
ROUGE_KEYS,
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
TranslationDataset,
label_smoothed_nll_loss,
)
@ -40,7 +40,7 @@ try:
except ImportError:
from utils import (
Seq2SeqDataset,
MBartDataset,
TranslationDataset,
assert_all_frozen,
use_task_specific_params,
lmap,
@ -108,8 +108,8 @@ class SummarizationModule(BaseTransformer):
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
if isinstance(self.tokenizer, MBartTokenizer):
self.dataset_class = MBartDataset
if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
self.dataset_class = TranslationDataset
else:
self.dataset_class = Seq2SeqDataset

View File

@ -14,14 +14,14 @@ from pytest import param
from torch.utils.data import DataLoader
import lightning_base
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
@ -406,8 +406,9 @@ def test_pack_dataset():
assert orig_paths == new_paths
def test_mbart_dataset_truncation():
tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
def test_mbart_dataset_truncation(tok_name):
tokenizer = AutoTokenizer.from_pretrained(tok_name)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
@ -416,7 +417,7 @@ def test_mbart_dataset_truncation():
assert max_len_target > max_src_len # Truncated
assert max_len_source > max_src_len
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
train_dataset = MBartDataset(
train_dataset = TranslationDataset(
tokenizer,
data_dir=tmp_dir,
type_path="train",
@ -433,6 +434,8 @@ def test_mbart_dataset_truncation():
assert batch["input_ids"].shape[1] == max_src_len
# show that targets are the same len
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
if tok_name == MARIAN_TINY:
continue
# check language codes in correct place
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id

View File

@ -146,7 +146,9 @@ class Seq2SeqDataset(Dataset):
return SortishSampler(self.src_lens, batch_size)
class MBartDataset(Seq2SeqDataset):
class TranslationDataset(Seq2SeqDataset):
"""A dataset that calls prepare_translation_batch."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.max_source_length != self.max_target_length:

View File

@ -127,10 +127,12 @@ class MarianTokenizer(PreTrainedTokenizer):
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
pad_to_max_length: bool = True,
return_tensors: str = "pt",
truncation_strategy="only_first",
padding="longest",
**unused,
) -> BatchEncoding:
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
Arguments:
@ -162,6 +164,9 @@ class MarianTokenizer(PreTrainedTokenizer):
if tgt_texts is None:
return model_inputs
if max_target_length is not None:
tokenizer_kwargs["max_length"] = max_target_length
self.current_spm = self.spm_target
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
for k, v in decoder_inputs.items():