Add mbart-large-cc25, support translation finetuning (#5129)
* improve unittests for finetuning, especially w.r.t. testing frozen parameters
* fix freeze_embeds for T5
* add streamlit to setup.cfg
parent: 141492448b
commit: 353b8f1e7a
BART docs:

```diff
@@ -39,6 +39,18 @@ BartTokenizer
     :members:


+MBartTokenizer
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MBartTokenizer
+    :members: build_inputs_with_special_tokens, prepare_translation_batch
+
+
+BartForConditionalGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BartForConditionalGeneration
+    :members: generate, forward
+
 BartModel
 ~~~~~~~~~~~~~

@@ -62,10 +74,3 @@ BartForQuestionAnswering
     :members: forward


-BartForConditionalGeneration
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: transformers.BartForConditionalGeneration
-    :members: generate, forward
-
-
```
ELI5 demo app (import order):

```diff
@@ -1,10 +1,10 @@
 import faiss
 import nlp
 import numpy as np
+import streamlit as st
 import torch
 from elasticsearch import Elasticsearch

-import streamlit as st
 import transformers
 from eli5_utils import (
     embed_questions_for_retrieval,
```
seq2seq README:

```diff
@@ -41,6 +41,28 @@ If you are using your own data, it must be formatted as one directory with 6 files
 The `.source` files are the input, the `.target` files are the desired output.


+### Tips and Tricks
+
+General Tips:
+- since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is fork transformers, clone your fork, and run `pip install -e .` before you get started.
+- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below)
+- `fp16_opt_level=O1` (the default) works best.
+- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
+Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
+- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
+- This warning can be safely ignored:
+> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
+- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
+- Read scripts before you run them!
+
+Summarization Tips:
+- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
+- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
+- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground-truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
+- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
+- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
+- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
+(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)

 ### Summarization Finetuning
 Run/modify `finetune.sh`
```
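The `best_tfmr` tip above can be exercised like this; a minimal sketch, with an illustrative output path and the assumption that the base checkpoint's tokenizer is reused:

```python
from transformers import BartForConditionalGeneration, BartTokenizer

output_dir = "xsum_results"  # whatever was passed as --output_dir
model = BartForConditionalGeneration.from_pretrained(f"{output_dir}/best_tfmr")
# Assumption: reuse the tokenizer of the checkpoint you finetuned from.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-xsum")
```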
````diff
@@ -58,25 +80,20 @@ The following command should work on a 16GB GPU:

 *Note*: The following tips mostly apply to summarization finetuning.

-Tips:
-- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
-- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
-- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
-- `fp16_opt_level=O1` (the default) works best.
-- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
-(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
-- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
-Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
-- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
-- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
-- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground-truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
-- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
-- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
-- This warning can be safely ignored:
-> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
-- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
+### Translation Finetuning

-#### Finetuning Outputs
+First, follow the wmt_en_ro download instructions.
+Then you can finetune mbart_cc25 on English-Romanian with the following command.
+**Recommendation:** Read and potentially modify the fairly opinionated defaults in the `train_mbart_cc25_enro.sh` script before running it.
+```bash
+export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded
+export BS=4
+export GAS=8
+./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
+```
+
+
+### Finetuning Outputs
 As you train, `output_dir` will be filled with files that look kind of like this (comments are mine).
 Some of them are metrics, some are checkpoints, and some are metadata. Here is a quick tour:
````
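Before launching the command above, it is worth sanity-checking the data directory against the layout described at the top of the README (one directory with six `.source`/`.target` files; the `train`/`val`/`test` prefixes are assumed from the flags used elsewhere in this diff):

```bash
ls $ENRO_DIR
# expected: train.source train.target val.source val.target test.source test.target
```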
|
@ -14,11 +14,12 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||||
from transformers import get_linear_schedule_with_warmup
|
from transformers import MBartTokenizer, get_linear_schedule_with_warmup
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
assert_all_frozen,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
SummarizationDataset,
|
SummarizationDataset,
|
||||||
lmap,
|
lmap,
|
||||||
@ -47,6 +48,7 @@ except ImportError:
|
|||||||
get_git_info,
|
get_git_info,
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
calculate_bleu_score,
|
calculate_bleu_score,
|
||||||
|
assert_all_frozen,
|
||||||
)
|
)
|
||||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||||
|
|
||||||
@ -92,9 +94,12 @@ class SummarizationModule(BaseTransformer):
|
|||||||
if self.hparams.freeze_embeds:
|
if self.hparams.freeze_embeds:
|
||||||
self.freeze_embeds()
|
self.freeze_embeds()
|
||||||
if self.hparams.freeze_encoder:
|
if self.hparams.freeze_encoder:
|
||||||
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
freeze_params(self.model.get_encoder())
|
||||||
|
assert_all_frozen(self.model.get_encoder())
|
||||||
|
|
||||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||||
self.num_workers = hparams.num_workers
|
self.num_workers = hparams.num_workers
|
||||||
|
self.decoder_start_token_id = None
|
||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
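`freeze_params` and `assert_all_frozen` live in `utils.py`; their bodies are not shown in this diff. A minimal sketch of what they need to do, in plain PyTorch:

```python
from torch import nn

def freeze_params(model: nn.Module):
    # Disable gradient computation for every parameter of the module.
    for par in model.parameters():
        par.requires_grad = False

def assert_all_frozen(model: nn.Module):
    # Fail loudly if any parameter is still trainable after freezing.
    grads = [p.requires_grad for p in model.parameters()]
    assert not any(grads), f"{sum(grads)}/{len(grads)} parameters are still trainable"
```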
`examples/seq2seq/finetune.py` (validation step and CLI args):

```diff
@@ -160,7 +165,12 @@ class SummarizationModule(BaseTransformer):
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
-        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
+        generated_ids = self.model.generate(
+            input_ids=source_ids,
+            attention_mask=source_mask,
+            use_cache=True,
+            decoder_start_token_id=self.decoder_start_token_id,
+        )
         gen_time = (time.time() - t0) / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
@@ -276,6 +286,9 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument(
             "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
         )
+        parser.add_argument("--src_lang", type=str, default="", required=False)
+        parser.add_argument("--tgt_lang", type=str, default="", required=False)

         return parser

```
`examples/seq2seq/finetune.py` (TranslationModule):

```diff
@@ -285,6 +298,13 @@ class TranslationModule(SummarizationModule):
     metric_names = ["bleu"]
     val_metric = "bleu"

+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        self.dataset_kwargs["src_lang"] = hparams.src_lang
+        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
+        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
+            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+
     def calc_generative_metrics(self, preds, target) -> dict:
         return calculate_bleu_score(preds, target)

```
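`facebook/mbart-large-cc25` ships without a `decoder_start_token_id` in its config, so `TranslationModule` derives one from the target language code. The lookup it performs, shown standalone (the resulting id matches `RO_CODE` in the new tests below):

```python
from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
ro = tok.lang_code_to_id["ro_RO"]  # 250020
# Generation then starts from the target language code:
# model.generate(input_ids=..., decoder_start_token_id=ro)
```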
The T5 finetuning shell script:

```diff
@@ -1,18 +1,13 @@
-export OUTPUT_DIR_NAME=t5
-export CURRENT_DIR=${PWD}
-export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
-
-# Make output directory if it doesn't exist
-mkdir -p $OUTPUT_DIR
-
 # Add parent directory to python path to access lightning_base.py
 export PYTHONPATH="../":"${PYTHONPATH}"

 python finetune.py \
---data_dir=./cnn-dailymail/cnn_dm \
---model_name_or_path=t5-large \
+--data_dir=$CNN_DIR \
 --learning_rate=3e-5 \
---train_batch_size=4 \
---eval_batch_size=4 \
+--train_batch_size=$BS \
+--eval_batch_size=$BS \
 --output_dir=$OUTPUT_DIR \
---do_train $@
+--max_source_length=512 \
+--val_check_interval=0.1 --n_val=200 \
+--do_train --do_predict \
+$@
```
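After this change the script reads `$CNN_DIR`, `$BS` and (still) `$OUTPUT_DIR` from the environment, and the model must now be supplied via `$@`. A hypothetical invocation; the script name and model choice are assumptions, not part of the diff:

```bash
export CNN_DIR=${PWD}/cnn_dm
export BS=4
export OUTPUT_DIR=t5_cnn_results
./finetune_t5.sh --model_name_or_path=t5-base
```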
|
@ -223,10 +223,30 @@ def test_finetune(model):
|
|||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
do_predict=True,
|
do_predict=True,
|
||||||
task=task,
|
task=task,
|
||||||
|
src_lang="en_XX",
|
||||||
|
tgt_lang="ro_RO",
|
||||||
|
freeze_encoder=True,
|
||||||
|
freeze_embeds=True,
|
||||||
)
|
)
|
||||||
assert "n_train" in args_d
|
assert "n_train" in args_d
|
||||||
args = argparse.Namespace(**args_d)
|
args = argparse.Namespace(**args_d)
|
||||||
main(args)
|
module = main(args)
|
||||||
|
|
||||||
|
input_embeds = module.model.get_input_embeddings()
|
||||||
|
assert not input_embeds.weight.requires_grad
|
||||||
|
if model == T5_TINY:
|
||||||
|
lm_head = module.model.lm_head
|
||||||
|
assert not lm_head.weight.requires_grad
|
||||||
|
assert (lm_head.weight == input_embeds.weight).all().item()
|
||||||
|
|
||||||
|
else:
|
||||||
|
bart = module.model.model
|
||||||
|
embed_pos = bart.decoder.embed_positions
|
||||||
|
assert not embed_pos.weight.requires_grad
|
||||||
|
assert not bart.shared.weight.requires_grad
|
||||||
|
# check that embeds are the same
|
||||||
|
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
|
||||||
|
assert bart.decoder.embed_tokens == bart.shared
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -239,7 +259,12 @@ def test_dataset(tok):
|
|||||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
trunc_target = 4
|
trunc_target = 4
|
||||||
train_dataset = SummarizationDataset(
|
train_dataset = SummarizationDataset(
|
||||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
tokenizer,
|
||||||
|
data_dir=tmp_dir,
|
||||||
|
type_path="train",
|
||||||
|
max_source_length=20,
|
||||||
|
max_target_length=trunc_target,
|
||||||
|
tgt_lang="ro_RO",
|
||||||
)
|
)
|
||||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
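The T5 branch of the test relies on T5 tying its LM head to the input embeddings, so freezing the shared embedding freezes the head as well. That tying can be sketched with plain PyTorch (tiny random tensors, no real checkpoint):

```python
from torch import nn

embed = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed.weight              # weight tying, as in T5

embed.weight.requires_grad = False         # what freeze_embeds does
assert not lm_head.weight.requires_grad    # the tied head is frozen too
assert (lm_head.weight == embed.weight).all().item()
```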
New file: `examples/seq2seq/train_mbart_cc25_enro.sh` (21 lines, executable):

```bash
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"

python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.1 \
--n_val 500 \
--adam_eps 1e-06 \
--num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
--freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
--max_source_length=300 --max_target_length 300 --val_max_target_length=300 --test_max_target_length 300 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--model_name_or_path facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 \
--logger wandb --sortish_sampler \
$@
```
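With the README's suggested `BS=4` and `GAS=8` on a single GPU, the effective batch size per optimizer step works out to:

```python
BS, GAS, n_gpus = 4, 8, 1            # values from the README snippet above
effective_batch = BS * GAS * n_gpus  # 32 sentence pairs per optimizer step
```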
`examples/seq2seq/utils.py` (tokenization cache):

```diff
@@ -14,6 +14,8 @@ from torch import nn
 from torch.utils.data import Dataset, Sampler
 from tqdm import tqdm
+
+from transformers import BartTokenizer


 def encode_file(
     tokenizer,
@@ -25,6 +27,7 @@ def encode_file(
     prefix="",
     tok_name="",
 ):
+    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
    if not overwrite_cache and cache_path.exists():
        try:
@@ -46,8 +49,8 @@ def encode_file(
                 max_length=max_length,
                 padding="max_length" if pad_to_max_length else None,
                 truncation=True,
-                add_prefix_space=True,
                 return_tensors=return_tensors,
+                **extra_kw,
             )
             assert tokenized.input_ids.shape[1] == max_length
             examples.append(tokenized)
```
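`add_prefix_space` only makes sense for the byte-level BPE used by `BartTokenizer`; sentencepiece tokenizers such as `MBartTokenizer` do not accept it, which is what `extra_kw` guards. The conditional in isolation, mirroring the call in the diff:

```python
from transformers import BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large")
extra_kw = {"add_prefix_space": True} if isinstance(tok, BartTokenizer) else {}
batch = tok.batch_encode_plus(["Hello world"], return_tensors="pt", **extra_kw)
# For an MBartTokenizer, extra_kw would be {} and the same call still works.
```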
`examples/seq2seq/utils.py` (SummarizationDataset):

```diff
@@ -87,9 +90,14 @@ class SummarizationDataset(Dataset):
         n_obs=None,
         overwrite_cache=False,
         prefix="",
+        src_lang=None,
+        tgt_lang=None,
     ):
         super().__init__()
+        # FIXME: the rstrip logic strips all the chars, it seems.
         tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
+        if hasattr(tokenizer, "set_lang") and src_lang is not None:
+            tokenizer.set_lang(src_lang)  # HACK: only applies to mbart
         self.source = encode_file(
             tokenizer,
             os.path.join(data_dir, type_path + ".source"),
@@ -100,7 +108,8 @@ class SummarizationDataset(Dataset):
         )
         tgt_path = os.path.join(data_dir, type_path + ".target")
         if hasattr(tokenizer, "set_lang"):
-            tokenizer.set_lang("ro_RO")  # HACK: only applies to mbart
+            assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
+            tokenizer.set_lang(tgt_lang)  # HACK: only applies to mbart
         self.target = encode_file(
             tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
         )
```
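The new FIXME is warranted: `str.rstrip` strips a trailing *character set*, not a suffix, so the cache-name derivation mangles class names:

```python
>>> "mbarttokenizer".rstrip("tokenizer")  # removes every trailing char in {t,o,k,e,n,i,z,r}
'mbar'
>>> "barttokenizer".rstrip("tokenizer")
'ba'
```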
`examples/seq2seq/utils.py` (calculate_rouge):

```diff
@@ -224,8 +233,8 @@ def get_git_info():
 ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


-def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
-    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
+def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
+    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
     aggregator = scoring.BootstrapAggregator()

     for reference_ln, output_ln in zip(reference_lns, output_lns):
```
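Stemming stays on by default but can now be switched off, e.g. to match an evaluation setup that scores exact surface forms. A usage sketch (the return value is a dict keyed by `ROUGE_KEYS`):

```python
scores = calculate_rouge(
    output_lns=["the cat sat on the mats"],
    reference_lns=["the cat sat on the mat"],
    use_stemmer=False,
)
```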
`setup.cfg`:

```diff
@@ -26,6 +26,7 @@ known_third_party =
     sacrebleu
     seqeval
     sklearn
+    streamlit
     tensorboardX
     tensorflow
     tensorflow_datasets
```
MBartTokenizer (docstring and checkpoint list):

```diff
@@ -55,15 +55,16 @@ class BartTokenizerFast(RobertaTokenizerFast):
 }


-_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"]
+_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
 SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"


 class MBartTokenizer(XLMRobertaTokenizer):
     """
     This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
-    Other tokenizer methods like encode do not work properly.
-    The tokenization method is <tokens> <eos> <language code>. There is no BOS token.
+    Other tokenizer methods like ``encode`` do not work properly.
+    The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
+    ``<language code> <tokens> <eos>`` for target language documents.

     Examples::

```
MBartTokenizer (special-token handling):

```diff
@@ -109,24 +110,84 @@ class MBartTokenizer(XLMRobertaTokenizer):
     }
     id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
     cur_lang_code = lang_code_to_id["en_XX"]
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
+        self.reset_special_tokens()

-    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
-        """Build model inputs from a sequence by appending eos_token_id."""
-        special_tokens = [self.eos_token_id, self.cur_lang_code]
+    def reset_special_tokens(self) -> None:
+        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
+        self.prefix_tokens = []
+        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
+        by concatenating and adding special tokens. The special tokens depend on calling set_lang.
+        An MBART sequence has the following format, where ``X`` represents the sequence:
+        - ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
+        - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
+        BOS is never used.
+        Pairs of sequences are not the expected use case, but they will be handled without a separator.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs to which the special tokens will be added
+            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+        """
         if token_ids_1 is None:
-            return token_ids_0 + special_tokens
+            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
         # We don't expect to process pairs, but leave the pair logic for API consistency
-        return token_ids_0 + token_ids_1 + special_tokens
+        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer ``prepare_for_model`` methods.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of ids.
+            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Set to True if the token list is already formatted with special tokens for the model
+
+        Returns:
+            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if the provided sequence of "
+                    "ids is already formatted with special tokens for the model."
+                )
+            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
+        prefix_ones = [1] * len(self.prefix_tokens)
+        suffix_ones = [1] * len(self.suffix_tokens)
+        if token_ids_1 is None:
+            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

     def set_lang(self, lang: str) -> None:
         """Set the current language code in order to call tokenizer properly."""
         self.cur_lang_code = self.lang_code_to_id[lang]
+        self.prefix_tokens = [self.cur_lang_code]
+        self.suffix_tokens = [self.eos_token_id]

     def prepare_translation_batch(
         self,
```
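Concretely, the two formats produced by the new prefix/suffix machinery (token ids 8274 and 127873 are sample ids taken from the tests; `eos=2`, `en_XX=250004`, `ro_RO=250020`):

```python
from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
body = [8274, 127873]                       # already-tokenized ids

tok.reset_special_tokens()                  # source format
tok.build_inputs_with_special_tokens(body)  # -> [8274, 127873, 2, 250004]

tok.set_lang("ro_RO")                       # target format
tok.build_inputs_with_special_tokens(body)  # -> [250020, 8274, 127873, 2]
```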
MBartTokenizer (prepare_translation_batch):

```diff
@@ -135,44 +196,45 @@ class MBartTokenizer(XLMRobertaTokenizer):
         tgt_texts: Optional[List[str]] = None,
         tgt_lang: str = "ro_RO",
         max_length: Optional[int] = None,
-        pad_to_max_length: bool = True,
+        padding: str = "longest",
         return_tensors: str = "pt",
     ) -> BatchEncoding:
-        """
+        """Prepare a batch that can be passed directly to an instance of MBartModel.
         Arguments:
             src_texts: list of src language texts
-            src_lang: default en_XX (english)
+            src_lang: default en_XX (english), the language we are translating from
             tgt_texts: list of tgt language texts
-            tgt_lang: default ro_RO (romanian)
+            tgt_lang: default ro_RO (romanian), the language we are translating to
-            max_length: (None) defer to config (1024 for mbart-large-en-ro)
+            max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*)
-            pad_to_max_length: (bool)
+            padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
+
         Returns:
-            dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor.
+            :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
         """
         if max_length is None:
             max_length = self.max_len
         self.cur_lang_code = self.lang_code_to_id[src_lang]
-        model_inputs: BatchEncoding = self.batch_encode_plus(
+        model_inputs: BatchEncoding = self(
             src_texts,
             add_special_tokens=True,
             return_tensors=return_tensors,
             max_length=max_length,
-            pad_to_max_length=pad_to_max_length,
+            padding=padding,
             truncation=True,
         )
         if tgt_texts is None:
             return model_inputs
-        self.cur_lang_code = self.lang_code_to_id[tgt_lang]
+        self.set_lang(tgt_lang)
-        decoder_inputs: BatchEncoding = self.batch_encode_plus(
+        decoder_inputs: BatchEncoding = self(
             tgt_texts,
             add_special_tokens=True,
             return_tensors=return_tensors,
+            padding=padding,
             max_length=max_length,
-            pad_to_max_length=pad_to_max_length,
             truncation=True,
         )
         for k, v in decoder_inputs.items():
             model_inputs[f"decoder_{k}"] = v
         self.cur_lang_code = self.lang_code_to_id[src_lang]
+        self.reset_special_tokens()  # sets to src_lang
         return model_inputs
```
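End to end, the method is meant to be used like this; a sketch based on the docstring above and the new tests:

```python
from transformers import BartForConditionalGeneration, MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")

batch = tok.prepare_translation_batch(
    src_texts=[" UN Chief Says There Is No Military Solution in Syria"],
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
)
# input_ids end with [eos, en_XX]; decoder_input_ids carry the ro_RO code.
outputs = model(**batch)  # keys are chosen to line up with the forward signature
```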
`tests/test_modeling_bart.py` (imports):

```diff
@@ -19,7 +19,6 @@ import unittest
 import timeout_decorator  # noqa

 from transformers import is_torch_available
-from transformers.file_utils import cached_property
 from transformers.testing_utils import require_torch, slow, torch_device

 from .test_configuration_common import ConfigTester
@@ -31,7 +30,6 @@ if is_torch_available():
     from transformers import (
         AutoModel,
         AutoModelForSequenceClassification,
-        AutoModelForSeq2SeqLM,
         AutoTokenizer,
         BartModel,
         BartForConditionalGeneration,
@@ -39,7 +37,6 @@ if is_torch_available():
         BartForQuestionAnswering,
         BartConfig,
         BartTokenizer,
-        BatchEncoding,
         pipeline,
     )
     from transformers.modeling_bart import (
```
`tests/test_modeling_bart.py` (MBart tests moved out):

```diff
@@ -202,140 +199,6 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         tiny(**inputs_dict)


-EN_CODE = 250004
-
-
-@require_torch
-class MBartIntegrationTests(unittest.TestCase):
-    ...  # ~125 removed lines elided: the fixtures and test methods of this class move,
-    ...  # essentially verbatim, into the new tests/test_modeling_mbart.py and
-    ...  # tests/test_tokenization_mbart.py shown below


 @require_torch
 class BartHeadTests(unittest.TestCase):
     vocab_size = 99
```
New file: `tests/test_modeling_mbart.py` (142 lines):

```python
import unittest

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_modeling_bart import TOLERANCE, _assert_tensors_equal, _long_tensor


if is_torch_available():
    import torch
    from transformers import (
        AutoModelForSeq2SeqLM,
        BartConfig,
        BartForConditionalGeneration,
        BatchEncoding,
        AutoTokenizer,
    )


EN_CODE = 250004
RO_CODE = 250020


@require_torch
class AbstractMBartIntegrationTest(unittest.TestCase):

    checkpoint_name = None

    @classmethod
    def setUpClass(cls):
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
        cls.pad_token_id = 1
        return cls

    @cached_property
    def model(self):
        """Only load the model if needed."""
        model = AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)
        if "cuda" in torch_device:
            model = model.half()
        return model


@require_torch
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
    checkpoint_name = "facebook/mbart-large-en-ro"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
    ]
    tgt_text = [
        "Şeful ONU declară că nu există o soluţie militară în Siria",
        'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
    ]
    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]

    @slow
    @unittest.skip("This has been failing since June 20th at least.")
    def test_enro_forward(self):
        model = self.model
        net_input = {
            "input_ids": _long_tensor(
                [
                    [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
                    [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
                ]
            ),
            "decoder_input_ids": _long_tensor(
                [
                    [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
                    [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
                ]
            ),
        }
        net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
        with torch.no_grad():
            logits, *other_stuff = model(**net_input)

        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
        result_slice = logits[0, 0, :3]
        _assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)

    @slow
    def test_enro_generate(self):
        batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])
        self.assertEqual(self.tgt_text[1], decoded[1])

    def test_mbart_enro_config(self):
        mbart_models = ["facebook/mbart-large-en-ro"]
        expected = {"scale_embedding": True, "output_past": True}
        for name in mbart_models:
            config = BartConfig.from_pretrained(name)
            self.assertTrue(config.is_valid_mbart())
            for k, v in expected.items():
                try:
                    self.assertEqual(v, getattr(config, k))
                except AssertionError as e:
                    e.args += (name, k)
                    raise

    def test_mbart_fast_forward(self):
        config = BartConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            add_final_layer_norm=True,
        )
        lm_model = BartForConditionalGeneration(config).to(torch_device)
        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
        loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
        expected_shape = (*summary.shape, config.vocab_size)
        self.assertEqual(logits.shape, expected_shape)


class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
    checkpoint_name = "facebook/mbart-large-cc25"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        " I ate lunch twice yesterday",
    ]
    tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]

    @unittest.skip("This test is broken, still generates english")
    def test_cc25_generate(self):
        inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
        translated_tokens = self.model.generate(
            input_ids=inputs["input_ids"].to(torch_device),
            decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
        )
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])
```
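The new module can be run in isolation with standard pytest filtering; the slow integration tests are gated behind the usual `RUN_SLOW` switch:

```bash
RUN_SLOW=1 pytest tests/test_modeling_mbart.py -k "enro_generate"
```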
`tests/test_tokenization_common.py`:

```diff
@@ -903,6 +903,7 @@ class TokenizerTesterMixin:
         tokenizer.padding_side = "right"
         encoded_sequence = tokenizer.encode(sequence)
         sequence_length = len(encoded_sequence)
+        # FIXME: the next line should use padding="max_length" to avoid a warning
         padded_sequence = tokenizer.encode(
             sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
         )
```
New file: `tests/test_tokenization_mbart.py` (156 lines):

```python
import unittest

from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
from transformers.testing_utils import require_torch

from .test_tokenization_common import TokenizerTesterMixin
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE


EN_CODE = 250004
RO_CODE = 250020


class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = MBartTokenizer

    def setUp(self):
        super().setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True)
        tokenizer.save_pretrained(self.tmpdirname)

    def test_full_tokenizer(self):
        tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
        )

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids,
            [
                value + tokenizer.fairseq_offset
                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
                # ^ unk: 2 + 1 = 3            unk: 2 + 1 = 3 ^
            ],
        )

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )


@require_torch
class MBartEnroIntegrationTest(unittest.TestCase):
    checkpoint_name = "facebook/mbart-large-en-ro"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
    ]
    tgt_text = [
        "Şeful ONU declară că nu există o soluţie militară în Siria",
        'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
    ]
    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]

    @classmethod
    def setUpClass(cls):
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
        cls.pad_token_id = 1
        return cls

    def test_enro_tokenizer_prepare_translation_batch(self):
        batch = self.tokenizer.prepare_translation_batch(
            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
        )
        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 14), batch.input_ids.shape)
        self.assertEqual((2, 14), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
        # Test that special tokens are reset
        self.assertEqual(self.tokenizer.prefix_tokens, [])
        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])

    def test_enro_tokenizer_batch_encode_plus(self):
        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
        self.assertListEqual(self.expected_src_tokens, ids)

    def test_enro_tokenizer_decode_ignores_language_codes(self):
        self.assertIn(RO_CODE, self.tokenizer.all_special_ids)
        generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
        self.assertEqual(result, expected_romanian)
        self.assertNotIn(self.tokenizer.eos_token, result)

    def test_enro_tokenizer_truncation(self):
        src_text = ["this is gunna be a long sentence " * 20]
        assert isinstance(src_text[0], str)
        desired_max_length = 10
        ids = self.tokenizer.prepare_translation_batch(
            src_text, return_tensors=None, max_length=desired_max_length
        ).input_ids[0]
        self.assertEqual(ids[-2], 2)
        self.assertEqual(ids[-1], EN_CODE)
        self.assertEqual(len(ids), desired_max_length)
```