Add mbart-large-cc25, support translation finetuning (#5129)

improve unittests for finetuning, especially w.r.t testing frozen parameters
fix freeze_embeds for T5
add streamlit setup.cfg
Sam Shleifer 2020-07-07 13:23:01 -04:00 committed by GitHub
parent 141492448b
commit 353b8f1e7a
14 changed files with 521 additions and 204 deletions

View File

@ -39,6 +39,18 @@ BartTokenizer
:members:
MBartTokenizer
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartTokenizer
:members: build_inputs_with_special_tokens, prepare_translation_batch
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForConditionalGeneration
:members: generate, forward
BartModel
~~~~~~~~~~~~~
@ -62,10 +74,3 @@ BartForQuestionAnswering
:members: forward
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForConditionalGeneration
:members: generate, forward
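
A minimal usage sketch for the documented `MBartTokenizer` methods (network access to download the checkpoint is assumed):

```python
from transformers import MBartTokenizer

# Sketch: encode a source/target pair with prepare_translation_batch.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
batch = tokenizer.prepare_translation_batch(
    ["UN Chief Says There Is No Military Solution in Syria"],
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
)
# Keys: input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
print(sorted(batch.keys()))
```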

View File

@ -1,10 +1,10 @@
import faiss
import nlp
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch
import streamlit as st
import transformers
from eli5_utils import (
embed_questions_for_retrieval,

View File

@ -41,6 +41,28 @@ If you are using your own data, it must be formatted as one directory with 6 files
The `.source` files are the input, the `.target` files are the desired output.
### Tips and Tricks
General Tips:
- since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is to fork transformers, clone your fork, and run `pip install -e .` before you get started.
- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default works best).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')` (see the loading snippet after these tips).
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
- Read scripts before you run them!
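
For the checkpoint tip above, a minimal loading sketch (the `output_dir` value is illustrative):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Illustrative output_dir; adjust to your own finetuning run.
output_dir = "xsum_frozen_embs"
model = BartForConditionalGeneration.from_pretrained(f"{output_dir}/best_tfmr")
# Assuming the tokenizer was saved alongside the model weights.
tokenizer = BartTokenizer.from_pretrained(f"{output_dir}/best_tfmr")
```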
Summarization Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA V100.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
### Summarization Finetuning
Run/modify `finetune.sh`
@ -58,25 +80,20 @@ The following command should work on a 16GB GPU:
*Note*: The following tips mostly apply to summarization finetuning.
Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA V100.
- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork transformers, clone your fork, and run `pip install -e .` before you get started.
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default works best).
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
### Translation Finetuning
#### Finetuning Outputs
First, follow the wmt_en_ro download instructions.
Then you can finetune mbart_cc25 on English-Romanian with the following command.
**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it.
```bash
export ENRO_DIR=${PWD}/wmt_en_ro # may need adjusting depending on where you downloaded the data
export BS=4
export GAS=8
./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
```
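
Once training finishes, a hedged inference sketch against the saved transformers checkpoint (the `cc25_v1_frozen/best_tfmr` path follows the command above; passing `decoder_start_token_id` mirrors what `finetune.py` does for mBART):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("cc25_v1_frozen/best_tfmr")
tokenizer = MBartTokenizer.from_pretrained("cc25_v1_frozen/best_tfmr")

batch = tokenizer.prepare_translation_batch(["UN Chief Says There Is No Military Solution in Syria"])
with torch.no_grad():
    generated = model.generate(
        **batch,
        decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"],  # mBART decodes from the target lang code
    )
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```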
### Finetuning Outputs
As you train, `output_dir` will fill with files that look something like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:

View File

@ -14,11 +14,12 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup
from transformers import MBartTokenizer, get_linear_schedule_with_warmup
try:
from .utils import (
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
@ -47,6 +48,7 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
assert_all_frozen,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
@ -92,9 +94,12 @@ class SummarizationModule(BaseTransformer):
if self.hparams.freeze_embeds:
self.freeze_embeds()
if self.hparams.freeze_encoder:
freeze_params(self.model.model.encoder) # TODO: this will break for t5
freeze_params(self.model.get_encoder())
assert_all_frozen(self.model.get_encoder())
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@ -160,7 +165,12 @@ class SummarizationModule(BaseTransformer):
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
t0 = time.time()
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
generated_ids = self.model.generate(
input_ids=source_ids,
attention_mask=source_mask,
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
)
gen_time = (time.time() - t0) / source_ids.shape[0]
preds = self.ids_to_clean_text(generated_ids)
target = self.ids_to_clean_text(y)
@ -276,6 +286,9 @@ class SummarizationModule(BaseTransformer):
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
return parser
@ -285,6 +298,13 @@ class TranslationModule(SummarizationModule):
metric_names = ["bleu"]
val_metric = "bleu"
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.dataset_kwargs["src_lang"] = hparams.src_lang
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)
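
`calc_generative_metrics` delegates to `calculate_bleu_score` from `utils`; a sketch of the kind of sacrebleu wrapper it is (the body here is an assumption, not the repo's exact code):

```python
from typing import Dict, List

from sacrebleu import corpus_bleu

def calculate_bleu_score(output_lns: List[str], refs_lns: List[str]) -> Dict:
    # corpus_bleu takes the system outputs plus one or more reference streams.
    return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}
```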

View File

@ -1,18 +1,13 @@
export OUTPUT_DIR_NAME=t5
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_name_or_path=t5-large \
--data_dir=$CNN_DIR \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--train_batch_size=$BS \
--eval_batch_size=$BS \
--output_dir=$OUTPUT_DIR \
--do_train $@
--max_source_length=512 \
--val_check_interval=0.1 --n_val=200 \
--do_train --do_predict \
$@

View File

@ -223,10 +223,30 @@ def test_finetune(model):
output_dir=output_dir,
do_predict=True,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
main(args)
module = main(args)
input_embeds = module.model.get_input_embeddings()
assert not input_embeds.weight.requires_grad
if model == T5_TINY:
lm_head = module.model.lm_head
assert not lm_head.weight.requires_grad
assert (lm_head.weight == input_embeds.weight).all().item()
else:
bart = module.model.model
embed_pos = bart.decoder.embed_positions
assert not embed_pos.weight.requires_grad
assert not bart.shared.weight.requires_grad
# check that embeds are the same
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
assert bart.decoder.embed_tokens == bart.shared
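
These assertions rely on the `freeze_params`/`assert_all_frozen` helpers imported from `utils` above; a minimal sketch of the contract they enforce (the bodies here are assumptions):

```python
from typing import Iterable

from torch import nn

def freeze_params(model: nn.Module) -> None:
    """Disable gradients for every parameter of the module."""
    for par in model.parameters():
        par.requires_grad = False

def grad_status(model: nn.Module) -> Iterable[bool]:
    return (par.requires_grad for par in model.parameters())

def assert_all_frozen(model: nn.Module) -> None:
    """Fail loudly if any parameter is still trainable."""
    model_grads = list(grad_status(model))
    assert not any(model_grads), f"{sum(model_grads)}/{len(model_grads)} params are still trainable"
```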
@pytest.mark.parametrize(
@ -239,7 +259,12 @@ def test_dataset(tok):
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=20,
max_target_length=trunc_target,
tgt_lang="ro_RO",
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:

View File

@ -0,0 +1,21 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.1 \
--n_val 500 \
--adam_eps 1e-06 \
--num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
--freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
--max_source_length=300 --max_target_length 300 --val_max_target_length=300 --test_max_target_length 300 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--model_name_or_path facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 \
--logger wandb --sortish_sampler \
$@

View File

@ -14,6 +14,8 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(
tokenizer,
@ -25,6 +27,7 @@ def encode_file(
prefix="",
tok_name="",
):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
if not overwrite_cache and cache_path.exists():
try:
@ -46,8 +49,8 @@ def encode_file(
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
add_prefix_space=True,
return_tensors=return_tensors,
**extra_kw,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
@ -87,9 +90,14 @@ class SummarizationDataset(Dataset):
n_obs=None,
overwrite_cache=False,
prefix="",
src_lang=None,
tgt_lang=None,
):
super().__init__()
# FIXME: the rstrip logic strips all the chars, it seems.
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
if hasattr(tokenizer, "set_lang") and src_lang is not None:
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
@ -100,7 +108,8 @@ class SummarizationDataset(Dataset):
)
tgt_path = os.path.join(data_dir, type_path + ".target")
if hasattr(tokenizer, "set_lang"):
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
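
A hedged construction sketch for the translation case (paths and lengths are illustrative; `SummarizationDataset` is defined in this same module):

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
dataset = SummarizationDataset(
    tokenizer,
    data_dir="wmt_en_ro",       # must contain train.source / train.target etc.
    type_path="train",
    max_source_length=300,
    max_target_length=300,
    src_lang="en_XX",           # applied via tokenizer.set_lang before encoding .source
    tgt_lang="ro_RO",           # required whenever the tokenizer has set_lang
)
```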
@ -224,8 +233,8 @@ def get_git_info():
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):
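
A quick usage sketch for the new `use_stemmer` flag (illustrative strings; the returned mapping's value type follows whatever aggregation closes out the function):

```python
preds = ["the cat sat on the mat"]
refs = ["a cat was sitting on the mat"]

with_stemming = calculate_rouge(preds, refs)  # use_stemmer=True by default
without_stemming = calculate_rouge(preds, refs, use_stemmer=False)
print(with_stemming["rouge1"], without_stemming["rouge1"])
```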

View File

@ -26,6 +26,7 @@ known_third_party =
sacrebleu
seqeval
sklearn
streamlit
tensorboardX
tensorflow
tensorflow_datasets

View File

@ -55,15 +55,16 @@ class BartTokenizerFast(RobertaTokenizerFast):
}
_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"]
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
class MBartTokenizer(XLMRobertaTokenizer):
"""
This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
Other tokenizer methods like encode do not work properly.
The tokenization method is <tokens> <eos> <language code>. There is no BOS token.
Other tokenizer methods like ``encode`` do not work properly.
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
``<language code> <tokens> <eos>``` for target language documents.
Examples::
@ -109,24 +110,84 @@ class MBartTokenizer(XLMRobertaTokenizer):
}
id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
cur_lang_code = lang_code_to_id["en_XX"]
prefix_tokens: List[int] = []
suffix_tokens: List[int] = []
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
self._additional_special_tokens = list(self.lang_code_to_id.keys())
self.reset_special_tokens()
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
special_tokens = [self.eos_token_id, self.cur_lang_code]
def reset_special_tokens(self) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
self.prefix_tokens = []
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
An MBART sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
BOS is never used.
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return token_ids_0 + special_tokens
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + special_tokens
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` methods.
Args:
token_ids_0 (:obj:`List[int]`):
List of ids.
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Set to True if the token list is already formatted with special tokens for the model
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1] * len(self.suffix_tokens)
if token_ids_1 is None:
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def set_lang(self, lang: str) -> None:
"""Set the current language code in order to call tokenizer properly."""
self.cur_lang_code = self.lang_code_to_id[lang]
self.prefix_tokens = [self.cur_lang_code]
self.suffix_tokens = [self.eos_token_id]
def prepare_translation_batch(
self,
@ -135,44 +196,45 @@ class MBartTokenizer(XLMRobertaTokenizer):
tgt_texts: Optional[List[str]] = None,
tgt_lang: str = "ro_RO",
max_length: Optional[int] = None,
pad_to_max_length: bool = True,
padding: str = "longest",
return_tensors: str = "pt",
) -> BatchEncoding:
"""
"""Prepare a batch that can be passed directly to an instance of MBartModel.
Arguments:
src_texts: list of src language texts
src_lang: default en_XX (english)
src_lang: default en_XX (English), the language we are translating from
tgt_texts: list of tgt language texts
tgt_lang: default ro_RO (romanian)
max_length: (None) defer to config (1024 for mbart-large-en-ro)
pad_to_max_length: (bool)
tgt_lang: default ro_RO (Romanian), the language we are translating to
max_length: default None, which defers to the config value (1024 for facebook/mbart-large*)
padding: strategy for padding input_ids and decoder_input_ids. Should be ``max_length`` or ``longest``.
Returns:
dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor.
:obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
"""
if max_length is None:
max_length = self.max_len
self.cur_lang_code = self.lang_code_to_id[src_lang]
model_inputs: BatchEncoding = self.batch_encode_plus(
model_inputs: BatchEncoding = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
padding=padding,
truncation=True,
)
if tgt_texts is None:
return model_inputs
self.cur_lang_code = self.lang_code_to_id[tgt_lang]
decoder_inputs: BatchEncoding = self.batch_encode_plus(
self.set_lang(tgt_lang)
decoder_inputs: BatchEncoding = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
truncation=True,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
self.cur_lang_code = self.lang_code_to_id[src_lang]
self.reset_special_tokens() # sets to src_lang
return model_inputs
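
The prefix/suffix mechanics above, made concrete (a sketch; token ids come from the pretrained vocab):

```python
from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
ids = tok.encode("UN Chief", add_special_tokens=False)

# Default (source-language) layout: X <eos> <cur_lang_code>
assert tok.build_inputs_with_special_tokens(ids) == ids + [tok.eos_token_id, tok.cur_lang_code]

# After set_lang, the layout flips to the decoder-side format: <lang_code> X <eos>
tok.set_lang("ro_RO")
assert tok.build_inputs_with_special_tokens(ids) == [tok.lang_code_to_id["ro_RO"]] + ids + [tok.eos_token_id]

# reset_special_tokens() restores no prefix and suffix = [eos, cur_lang_code]
tok.reset_special_tokens()
```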

View File

@ -19,7 +19,6 @@ import unittest
import timeout_decorator # noqa
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@ -31,7 +30,6 @@ if is_torch_available():
from transformers import (
AutoModel,
AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartModel,
BartForConditionalGeneration,
@ -39,7 +37,6 @@ if is_torch_available():
BartForQuestionAnswering,
BartConfig,
BartTokenizer,
BatchEncoding,
pipeline,
)
from transformers.modeling_bart import (
@ -202,140 +199,6 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
tiny(**inputs_dict)
EN_CODE = 250004
@require_torch
class MBartIntegrationTests(unittest.TestCase):
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
""" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
]
tgt_text = [
"Şeful ONU declară că nu există o soluţie militară în Siria",
'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
@classmethod
def setUpClass(cls):
checkpoint_name = "facebook/mbart-large-en-ro"
cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
cls.pad_token_id = 1
return cls
@cached_property
def model(self):
"""Only load the model if needed."""
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
if "cuda" in torch_device:
model = model.half()
return model
@slow
@unittest.skip("This has been failing since June 20th at least.")
def test_enro_forward(self):
model = self.model
net_input = {
"input_ids": _long_tensor(
[
[3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
[64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
]
),
"decoder_input_ids": _long_tensor(
[
[250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
[250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
]
),
}
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
with torch.no_grad():
logits, *other_stuff = model(**net_input)
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
result_slice = logits[0, 0, :3]
_assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)
@slow
def test_enro_generate(self):
batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text[0], decoded[0])
self.assertEqual(self.tgt_text[1], decoded[1])
def test_mbart_enro_config(self):
mbart_models = ["facebook/mbart-large-en-ro"]
expected = {"scale_embedding": True, "output_past": True}
for name in mbart_models:
config = BartConfig.from_pretrained(name)
self.assertTrue(config.is_valid_mbart())
for k, v in expected.items():
try:
self.assertEqual(v, getattr(config, k))
except AssertionError as e:
e.args += (name, k)
raise
def test_mbart_fast_forward(self):
config = BartConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
add_final_layer_norm=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
def test_enro_tokenizer_prepare_translation_batch(self):
batch = self.tokenizer.prepare_translation_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -2]) # EOS
def test_enro_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
def test_enro_tokenizer_decode_ignores_language_codes(self):
self.assertIn(250020, self.tokenizer.all_special_ids)
generated_ids = [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
self.assertEqual(result, expected_romanian)
self.assertNotIn(self.tokenizer.eos_token, result)
def test_enro_tokenizer_truncation(self):
src_text = ["this is gunna be a long sentence " * 20]
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_translation_batch(
src_text, return_tensors=None, max_length=desired_max_length
).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(len(ids), desired_max_length)
@require_torch
class BartHeadTests(unittest.TestCase):
vocab_size = 99

View File

@ -0,0 +1,142 @@
import unittest
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bart import TOLERANCE, _assert_tensors_equal, _long_tensor
if is_torch_available():
import torch
from transformers import (
AutoModelForSeq2SeqLM,
BartConfig,
BartForConditionalGeneration,
BatchEncoding,
AutoTokenizer,
)
EN_CODE = 250004
RO_CODE = 250020
@require_torch
class AbstractMBartIntegrationTest(unittest.TestCase):
checkpoint_name = None
@classmethod
def setUpClass(cls):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
cls.pad_token_id = 1
return cls
@cached_property
def model(self):
"""Only load the model if needed."""
model = AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)
if "cuda" in torch_device:
model = model.half()
return model
@require_torch
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
checkpoint_name = "facebook/mbart-large-en-ro"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
""" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
]
tgt_text = [
"Şeful ONU declară că nu există o soluţie militară în Siria",
'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
@slow
@unittest.skip("This has been failing since June 20th at least.")
def test_enro_forward(self):
model = self.model
net_input = {
"input_ids": _long_tensor(
[
[3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
[64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
]
),
"decoder_input_ids": _long_tensor(
[
[250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
[250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
]
),
}
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
with torch.no_grad():
logits, *other_stuff = model(**net_input)
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
result_slice = logits[0, 0, :3]
_assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)
@slow
def test_enro_generate(self):
batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text[0], decoded[0])
self.assertEqual(self.tgt_text[1], decoded[1])
def test_mbart_enro_config(self):
mbart_models = ["facebook/mbart-large-en-ro"]
expected = {"scale_embedding": True, "output_past": True}
for name in mbart_models:
config = BartConfig.from_pretrained(name)
self.assertTrue(config.is_valid_mbart())
for k, v in expected.items():
try:
self.assertEqual(v, getattr(config, k))
except AssertionError as e:
e.args += (name, k)
raise
def test_mbart_fast_forward(self):
config = BartConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
add_final_layer_norm=True,
)
lm_model = BartForConditionalGeneration(config).to(torch_device)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
checkpoint_name = "facebook/mbart-large-cc25"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
" I ate lunch twice yesterday",
]
tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]
@unittest.skip("This test is broken, still generates english")
def test_cc25_generate(self):
inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
translated_tokens = self.model.generate(
input_ids=inputs["input_ids"].to(torch_device),
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text[0], decoded[0])

View File

@ -903,6 +903,7 @@ class TokenizerTesterMixin:
tokenizer.padding_side = "right"
encoded_sequence = tokenizer.encode(sequence)
sequence_length = len(encoded_sequence)
# FIXME: the next line should use padding="max_length" to avoid a warning
padded_sequence = tokenizer.encode(
sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
)

View File

@ -0,0 +1,156 @@
import unittest
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
from transformers.testing_utils import require_torch
from .test_tokenization_common import TokenizerTesterMixin
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
EN_CODE = 250004
RO_CODE = 250020
class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MBartTokenizer
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def test_full_tokenizer(self):
tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids,
[
value + tokenizer.fairseq_offset
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
# ^ unk: 2 + 1 = 3 unk: 2 + 1 = 3 ^
],
)
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
@require_torch
class MBartEnroIntegrationTest(unittest.TestCase):
checkpoint_name = "facebook/mbart-large-en-ro"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
""" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
]
tgt_text = [
"Şeful ONU declară că nu există o soluţie militară în Siria",
'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]
@classmethod
def setUpClass(cls):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
cls.pad_token_id = 1
return cls
def test_enro_tokenizer_prepare_translation_batch(self):
batch = self.tokenizer.prepare_translation_batch(
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 14), batch.input_ids.shape)
self.assertEqual((2, 14), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
def test_enro_tokenizer_batch_encode_plus(self):
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
self.assertListEqual(self.expected_src_tokens, ids)
def test_enro_tokenizer_decode_ignores_language_codes(self):
self.assertIn(RO_CODE, self.tokenizer.all_special_ids)
generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
self.assertEqual(result, expected_romanian)
self.assertNotIn(self.tokenizer.eos_token, result)
def test_enro_tokenizer_truncation(self):
src_text = ["this is gunna be a long sentence " * 20]
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer.prepare_translation_batch(
src_text, return_tensors=None, max_length=desired_max_length
).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(len(ids), desired_max_length)