[s2s] Use prepare_translation_batch for Marian finetuning (#6293)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

commit 2804fff839 (parent 2f2aa0c89c)
examples/seq2seq/README.md

@@ -63,7 +63,7 @@ Summarization Tips:
 (It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
 
 **Update 2018-07-18**
-Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.**
+Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_translation_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.**
 A new dataset is needed to support multilingual tasks.
 
 
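The README rule is duck-typed rather than model-specific, which a short sketch makes concrete. The helper below is illustrative only and is not part of the commit; `Seq2SeqDataset` and `TranslationDataset` are the classes from examples/seq2seq/utils.py:

def pick_dataset_class(tokenizer):
    # Tokenizers that expose prepare_translation_batch (Marian, MBart)
    # get TranslationDataset; everything else uses Seq2SeqDataset.
    from utils import Seq2SeqDataset, TranslationDataset

    if hasattr(tokenizer, "prepare_translation_batch"):
        return TranslationDataset
    return Seq2SeqDataset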
examples/seq2seq/finetune.py

@@ -14,7 +14,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
+from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
 
 
 try:
@@ -32,7 +32,7 @@ try:
     ROUGE_KEYS,
     calculate_bleu_score,
     Seq2SeqDataset,
-    MBartDataset,
+    TranslationDataset,
     label_smoothed_nll_loss,
 )
 
@@ -40,7 +40,7 @@ try:
 except ImportError:
     from utils import (
         Seq2SeqDataset,
-        MBartDataset,
+        TranslationDataset,
         assert_all_frozen,
         use_task_specific_params,
         lmap,
@@ -108,8 +108,8 @@ class SummarizationModule(BaseTransformer):
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
             self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        if isinstance(self.tokenizer, MBartTokenizer):
-            self.dataset_class = MBartDataset
+        if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
+            self.dataset_class = TranslationDataset
         else:
             self.dataset_class = Seq2SeqDataset
 
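Side note on the new condition: `isinstance` accepts a tuple of types, so the two chained checks could be collapsed into one. A sketch of the equivalent form, not what the commit ships:

from transformers import MarianTokenizer, MBartTokenizer

def needs_translation_dataset(tokenizer):
    # Equivalent to the chained `or` above: True for MBart and Marian tokenizers.
    return isinstance(tokenizer, (MBartTokenizer, MarianTokenizer))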
examples/seq2seq/test_seq2seq_examples.py

@@ -14,14 +14,14 @@ from pytest import param
 from torch.utils.data import DataLoader
 
 import lightning_base
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from transformers.testing_utils import require_multigpu
 
 from .distillation import distill_main, evaluate_checkpoint
 from .finetune import SummarizationModule, main
 from .pack_dataset import pack_data_dir
 from .run_eval import generate_summaries_or_translations, run_generate
-from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
+from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -406,8 +406,9 @@ def test_pack_dataset():
     assert orig_paths == new_paths
 
 
-def test_mbart_dataset_truncation():
-    tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
+@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
+def test_mbart_dataset_truncation(tok_name):
+    tokenizer = AutoTokenizer.from_pretrained(tok_name)
     tmp_dir = make_test_data_dir()
     max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
     max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
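For readers unfamiliar with the list form of `parametrize`: wrapping the argument name in a list and the values in `pytest.param` is equivalent to the plainer single-string spelling. A minimal sketch, with checkpoint names assumed to match the test module's constants:

import pytest

MBART_TINY = "sshleifer/tiny-mbart"            # assumed value of the module constant
MARIAN_TINY = "sshleifer/tiny-marian-en-de"    # assumed value of the module constant

@pytest.mark.parametrize("tok_name", [MBART_TINY, MARIAN_TINY])
def test_runs_once_per_checkpoint(tok_name):
    # pytest calls this body twice, binding tok_name to each value in turn.
    assert isinstance(tok_name, str)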
examples/seq2seq/test_seq2seq_examples.py

@@ -416,7 +417,7 @@ def test_mbart_dataset_truncation():
     assert max_len_target > max_src_len  # Truncated
     assert max_len_source > max_src_len
     src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
-    train_dataset = MBartDataset(
+    train_dataset = TranslationDataset(
         tokenizer,
         data_dir=tmp_dir,
         type_path="train",
@@ -433,6 +434,8 @@ def test_mbart_dataset_truncation():
         assert batch["input_ids"].shape[1] == max_src_len
         # show that targets are the same len
         assert batch["decoder_input_ids"].shape[1] == max_tgt_len
+        if tok_name == MARIAN_TINY:
+            continue
         # check language codes in correct place
         assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
         assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
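The early `continue` exists because the language-code assertions below it are MBart-specific: `MarianTokenizer` has no `lang_code_to_id` mapping, so for the Marian parametrization the test stops after verifying the truncation shapes.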
examples/seq2seq/utils.py

@@ -146,7 +146,9 @@ class Seq2SeqDataset(Dataset):
         return SortishSampler(self.src_lens, batch_size)
 
 
-class MBartDataset(Seq2SeqDataset):
+class TranslationDataset(Seq2SeqDataset):
+    """A dataset that calls prepare_translation_batch."""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.max_source_length != self.max_target_length:
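A rough sketch of what "calls prepare_translation_batch" means for a batch of raw text pairs. This illustrates the pattern under assumed names and is not the actual `TranslationDataset` collate body:

def collate_translation_batch(tokenizer, src_texts, tgt_texts, max_source_length, max_target_length):
    # Delegate tokenization, truncation and padding to the tokenizer, which
    # returns input_ids/attention_mask for the source and decoder_input_ids
    # for the target (see the shape assertions in the test above).
    return tokenizer.prepare_translation_batch(
        src_texts,
        tgt_texts=tgt_texts,
        max_length=max_source_length,
        max_target_length=max_target_length,
    )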
src/transformers/tokenization_marian.py

@@ -127,10 +127,12 @@ class MarianTokenizer(PreTrainedTokenizer):
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
         pad_to_max_length: bool = True,
         return_tensors: str = "pt",
         truncation_strategy="only_first",
         padding="longest",
+        **unused,
     ) -> BatchEncoding:
         """Prepare model inputs for translation. For best performance, translate one sentence at a time.
         Arguments:
@@ -162,6 +164,9 @@ class MarianTokenizer(PreTrainedTokenizer):
         if tgt_texts is None:
             return model_inputs
 
+        if max_target_length is not None:
+            tokenizer_kwargs["max_length"] = max_target_length
+
         self.current_spm = self.spm_target
         decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
         for k, v in decoder_inputs.items():
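Taken together, the two tokenizer hunks let source and target be truncated to different lengths. A hedged usage sketch; the checkpoint is a real Marian model, but any Marian checkpoint would do:

from transformers import MarianTokenizer

tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
batch = tok.prepare_translation_batch(
    src_texts=["I am a small frog."],
    tgt_texts=["Ich bin ein kleiner Frosch."],
    max_length=32,         # caps the source side
    max_target_length=16,  # after this commit, caps the target side separately
)
# batch is a BatchEncoding with input_ids/attention_mask for the source
# and decoder_* tensors for the target, returned as PyTorch tensors.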