Add mbart-large-cc25, support translation finetuning (#5129)
- improve unittests for finetuning, especially w.r.t. testing frozen parameters
- fix freeze_embeds for T5
- add streamlit to setup.cfg
Parent: 141492448b
Commit: 353b8f1e7a
@@ -39,6 +39,18 @@ BartTokenizer
    :members:


MBartTokenizer
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MBartTokenizer
    :members: build_inputs_with_special_tokens, prepare_translation_batch


BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForConditionalGeneration
    :members: generate, forward


BartModel
~~~~~~~~~~~~~

@@ -62,10 +74,3 @@ BartForQuestionAnswering
    :members: forward


BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForConditionalGeneration
    :members: generate, forward
@@ -1,10 +1,10 @@
import faiss
import nlp
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch

import streamlit as st
import transformers
from eli5_utils import (
    embed_questions_for_retrieval,
@@ -41,6 +41,28 @@ If you are using your own data, it must be formatted as one directory with 6 files
The `.source` files are the input, the `.target` files are the desired output.


### Tips and Tricks

General Tips:
- Since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is to fork transformers, clone your fork, and run `pip install -e .` before you get started.
- Try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size (3hr per epoch with bs=8; see the "xsum_shared_task" command below).
- `fp16_opt_level=O1` (the default) works best.
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`, as shown in the sketch after this list.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
- Read scripts before you run them!

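A minimal sketch (not part of this commit) of the checkpoint-loading tip above: `best_tfmr` is the transformers-format export that `finetune.py` writes next to the lightning checkpoints. The output directory name and input sentence here are hypothetical placeholders.

```python
# Hedged sketch: load the transformers checkpoint saved by finetune.py and summarize one example.
from transformers import BartForConditionalGeneration, BartTokenizer

output_dir = "xsum_results"  # hypothetical value of --output_dir
model = BartForConditionalGeneration.from_pretrained(f"{output_dir}/best_tfmr")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-xsum")

batch = tokenizer(["PG&E scheduled the blackouts in response to forecasts of high winds."], return_tensors="pt")
summary_ids = model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```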
Summarization Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`, as in the sketch after this list.
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)

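A minimal sketch (not part of this commit) of that re-scoring step, using the `calculate_rouge` helper whose signature appears in the `utils.py` hunk further down. The directory names are hypothetical and depend on your `--output_dir` and `--data_dir`; run it from `examples/seq2seq` so `utils` is importable.

```python
# Hedged sketch: re-score saved generations against the untruncated test targets.
from pathlib import Path

from utils import calculate_rouge  # assumes you run from examples/seq2seq

output_dir = "cnn_dm_results"  # hypothetical --output_dir
preds = Path(f"{output_dir}/test_generations.txt").read_text().splitlines()
refs = Path("cnn_dm/test.target").read_text().splitlines()  # hypothetical data_dir layout
print(calculate_rouge(preds, refs))
```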
### Summarization Finetuning
Run/modify `finetune.sh`

@@ -58,25 +80,20 @@ The following command should work on a 16GB GPU:

*Note*: The following tips mostly apply to summarization finetuning.

Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- Since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
- Try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size (3hr/epoch with bs=8; see the "xsum_shared_task" command below).
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
### Translation Finetuning

#### Finetuning Outputs
First, follow the wmt_en_ro download instructions.
Then you can finetune mbart_cc25 on English-Romanian with the following command.
**Recommendation:** Read and potentially modify the fairly opinionated defaults in the `train_mbart_cc25_enro.sh` script before running it.
```bash
export ENRO_DIR=${PWD}/wmt_en_ro  # may need to be fixed depending on where you downloaded
export BS=4
export GAS=8
./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
```

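Once training finishes, the resulting checkpoint can be used for generation. A minimal sketch (not part of this commit), assuming the `--output_dir` from the command above; forcing `decoder_start_token_id` to the Romanian language code mirrors what `TranslationModule` does in the `finetune.py` hunk below.

```python
# Hedged sketch: translate with the finetuned mbart checkpoint saved under best_tfmr.
from transformers import BartForConditionalGeneration, MBartTokenizer

model = BartForConditionalGeneration.from_pretrained("cc25_v1_frozen/best_tfmr")
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")

batch = tokenizer.prepare_translation_batch(["UN Chief Says There Is No Military Solution in Syria"])
translated = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
print(tokenizer.batch_decode(translated, skip_special_tokens=True))
```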
### Finetuning Outputs
As you train, `output_dir` will be filled with files that look kind of like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:

@@ -14,11 +14,12 @@ import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup
from transformers import MBartTokenizer, get_linear_schedule_with_warmup


try:
    from .utils import (
        assert_all_frozen,
        use_task_specific_params,
        SummarizationDataset,
        lmap,

@@ -47,6 +48,7 @@ except ImportError:
        get_git_info,
        ROUGE_KEYS,
        calculate_bleu_score,
        assert_all_frozen,
    )
    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback

@@ -92,9 +94,12 @@ class SummarizationModule(BaseTransformer):
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.model.encoder)  # TODO: this will break for t5
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""

@@ -160,7 +165,12 @@ class SummarizationModule(BaseTransformer):
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        t0 = time.time()
        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
        )
        gen_time = (time.time() - t0) / source_ids.shape[0]
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)

@@ -276,6 +286,9 @@ class SummarizationModule(BaseTransformer):
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
        )
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)

        return parser

@@ -285,6 +298,13 @@ class TranslationModule(SummarizationModule):
    metric_names = ["bleu"]
    val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu_score(preds, target)
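The `freeze_params` and `assert_all_frozen` helpers imported above live in `examples/seq2seq/utils.py`, whose hunks below do not show their bodies. A minimal sketch of the pattern they implement (an assumption, not the actual implementation):

```python
# Hedged sketch of the freezing helpers; the real implementations may differ in detail.
from torch import nn


def freeze_params(module: nn.Module) -> None:
    """Disable gradient updates for every parameter of the module."""
    for p in module.parameters():
        p.requires_grad = False


def assert_all_frozen(module: nn.Module) -> None:
    """Fail loudly if any parameter of the module is still trainable."""
    n_trainable = sum(p.requires_grad for p in module.parameters())
    assert n_trainable == 0, f"{n_trainable} parameters are still trainable"
```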
@@ -1,18 +1,13 @@
export OUTPUT_DIR_NAME=t5
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR

# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

python finetune.py \
    --data_dir=./cnn-dailymail/cnn_dm \
    --model_name_or_path=t5-large \
    --data_dir=$CNN_DIR \
    --learning_rate=3e-5 \
    --train_batch_size=4 \
    --eval_batch_size=4 \
    --train_batch_size=$BS \
    --eval_batch_size=$BS \
    --output_dir=$OUTPUT_DIR \
    --do_train $@
    --max_source_length=512 \
    --val_check_interval=0.1 --n_val=200 \
    --do_train --do_predict \
    $@
@@ -223,10 +223,30 @@ def test_finetune(model):
        output_dir=output_dir,
        do_predict=True,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )
    assert "n_train" in args_d
    args = argparse.Namespace(**args_d)
    main(args)
    module = main(args)

    input_embeds = module.model.get_input_embeddings()
    assert not input_embeds.weight.requires_grad
    if model == T5_TINY:
        lm_head = module.model.lm_head
        assert not lm_head.weight.requires_grad
        assert (lm_head.weight == input_embeds.weight).all().item()

    else:
        bart = module.model.model
        embed_pos = bart.decoder.embed_positions
        assert not embed_pos.weight.requires_grad
        assert not bart.shared.weight.requires_grad
        # check that embeds are the same
        assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
        assert bart.decoder.embed_tokens == bart.shared


@pytest.mark.parametrize(

@@ -239,7 +259,12 @@ def test_dataset(tok):
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc_target = 4
    train_dataset = SummarizationDataset(
        tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
        tokenizer,
        data_dir=tmp_dir,
        type_path="train",
        max_source_length=20,
        max_target_length=trunc_target,
        tgt_lang="ro_RO",
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
examples/seq2seq/train_mbart_cc25_enro.sh (new executable file, 21 lines)
@@ -0,0 +1,21 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"

python finetune.py \
    --learning_rate=3e-5 \
    --fp16 \
    --gpus 1 \
    --do_train \
    --do_predict \
    --val_check_interval 0.1 \
    --n_val 500 \
    --adam_eps 1e-06 \
    --num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
    --freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
    --max_source_length=300 --max_target_length 300 --val_max_target_length=300 --test_max_target_length 300 \
    --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
    --model_name_or_path facebook/mbart-large-cc25 \
    --task translation \
    --warmup_steps 500 \
    --logger wandb --sortish_sampler \
    $@
@@ -14,6 +14,8 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm

from transformers import BartTokenizer


def encode_file(
    tokenizer,

@@ -25,6 +27,7 @@ def encode_file(
    prefix="",
    tok_name="",
):
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
    if not overwrite_cache and cache_path.exists():
        try:

@@ -46,8 +49,8 @@ def encode_file(
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            add_prefix_space=True,
            return_tensors=return_tensors,
            **extra_kw,
        )
        assert tokenized.input_ids.shape[1] == max_length
        examples.append(tokenized)

@@ -87,9 +90,14 @@ class SummarizationDataset(Dataset):
        n_obs=None,
        overwrite_cache=False,
        prefix="",
        src_lang=None,
        tgt_lang=None,
    ):
        super().__init__()
        # FIXME: the rstrip logic strips all the chars, it seems.
        tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
        if hasattr(tokenizer, "set_lang") and src_lang is not None:
            tokenizer.set_lang(src_lang)  # HACK: only applies to mbart
        self.source = encode_file(
            tokenizer,
            os.path.join(data_dir, type_path + ".source"),

@@ -100,7 +108,8 @@ class SummarizationDataset(Dataset):
        )
        tgt_path = os.path.join(data_dir, type_path + ".target")
        if hasattr(tokenizer, "set_lang"):
            tokenizer.set_lang("ro_RO")  # HACK: only applies to mbart
            assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
            tokenizer.set_lang(tgt_lang)  # HACK: only applies to mbart
        self.target = encode_file(
            tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
        )

@@ -224,8 +233,8 @@ def get_git_info():
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
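The FIXME about `rstrip` above points at a standard-library gotcha rather than anything mbart-specific. A minimal illustration (not part of this commit):

```python
# str.rstrip() treats its argument as a set of characters, not a suffix,
# so stripping "tokenizer" from a lower-cased class name removes too much.
name = "MBartTokenizer".lower()      # "mbarttokenizer"
print(name.rstrip("tokenizer"))      # "mba"  -- trailing "r" and "t" are in the strip set too
print(name[: -len("tokenizer")])     # "mbart" -- the suffix removal the code intends
```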
@@ -26,6 +26,7 @@ known_third_party =
    sacrebleu
    seqeval
    sklearn
    streamlit
    tensorboardX
    tensorflow
    tensorflow_datasets
@@ -55,15 +55,16 @@ class BartTokenizerFast(RobertaTokenizerFast):
    }


_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"]
_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"


class MBartTokenizer(XLMRobertaTokenizer):
    """
    This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
    Other tokenizer methods like encode do not work properly.
    The tokenization method is <tokens> <eos> <language code>. There is no BOS token.
    Other tokenizer methods like ``encode`` do not work properly.
    The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
    ``<language code> <tokens> <eos>`` for target language documents.

    Examples::

@@ -109,24 +110,84 @@ class MBartTokenizer(XLMRobertaTokenizer):
    }
    id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
    cur_lang_code = lang_code_to_id["en_XX"]
    prefix_tokens: List[int] = []
    suffix_tokens: List[int] = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
        self._additional_special_tokens = list(self.lang_code_to_id.keys())
        self.reset_special_tokens()

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Build model inputs from a sequence by appending eos_token_id."""
        special_tokens = [self.eos_token_id, self.cur_lang_code]
    def reset_special_tokens(self) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code]."""
        self.prefix_tokens = []
        self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens. The special tokens depend on calling set_lang.
        An MBART sequence has the following format, where ``X`` represents the sequence:
        - ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
        - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
        BOS is never used.
        Pairs of sequences are not the expected use case, but they will be handled without a separator.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return token_ids_0 + special_tokens
            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + token_ids_1 + special_tokens
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` methods.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1] * len(self.suffix_tokens)
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    def set_lang(self, lang: str) -> None:
        """Set the current language code in order to call tokenizer properly."""
        self.cur_lang_code = self.lang_code_to_id[lang]
        self.prefix_tokens = [self.cur_lang_code]
        self.suffix_tokens = [self.eos_token_id]

    def prepare_translation_batch(
        self,

@@ -135,44 +196,45 @@ class MBartTokenizer(XLMRobertaTokenizer):
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "ro_RO",
        max_length: Optional[int] = None,
        pad_to_max_length: bool = True,
        padding: str = "longest",
        return_tensors: str = "pt",
    ) -> BatchEncoding:
        """
        """Prepare a batch that can be passed directly to an instance of MBartModel.
        Arguments:
            src_texts: list of src language texts
            src_lang: default en_XX (english)
            src_lang: default en_XX (english), the language we are translating from
            tgt_texts: list of tgt language texts
            tgt_lang: default ro_RO (romanian)
            max_length: (None) defer to config (1024 for mbart-large-en-ro)
            pad_to_max_length: (bool)
            tgt_lang: default ro_RO (romanian), the language we are translating to
            max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*)
            padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.

        Returns:
            dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor.
            :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
        """
        if max_length is None:
            max_length = self.max_len
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        model_inputs: BatchEncoding = self.batch_encode_plus(
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
            padding=padding,
            truncation=True,
        )
        if tgt_texts is None:
            return model_inputs
        self.cur_lang_code = self.lang_code_to_id[tgt_lang]
        decoder_inputs: BatchEncoding = self.batch_encode_plus(
        self.set_lang(tgt_lang)
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
            truncation=True,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v
        self.cur_lang_code = self.lang_code_to_id[src_lang]
        self.reset_special_tokens()  # sets to src_lang
        return model_inputs
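A minimal usage sketch (not part of this commit) of the tokenizer API added above, assuming the `facebook/mbart-large-en-ro` checkpoint can be downloaded. The exact token ids depend on the sentencepiece vocabulary, but the special-token layout follows the docstring: source side ``X [eos, src_lang_code]``, target side ``[tgt_lang_code] X [eos]``.

```python
# Hedged sketch: build an en->ro batch and inspect the special-token layout.
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
batch = tokenizer.prepare_translation_batch(
    src_texts=["UN Chief Says There Is No Military Solution in Syria"],
    src_lang="en_XX",
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    tgt_lang="ro_RO",
)
print(batch.input_ids[0, -2:])        # ends with [eos_token_id, lang_code_to_id["en_XX"]]
print(batch.decoder_input_ids[0, 0])  # starts with lang_code_to_id["ro_RO"]
```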
@@ -19,7 +19,6 @@ import unittest
import timeout_decorator  # noqa

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester

@@ -31,7 +30,6 @@ if is_torch_available():
    from transformers import (
        AutoModel,
        AutoModelForSequenceClassification,
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        BartModel,
        BartForConditionalGeneration,

@@ -39,7 +37,6 @@ if is_torch_available():
        BartForQuestionAnswering,
        BartConfig,
        BartTokenizer,
        BatchEncoding,
        pipeline,
    )
    from transformers.modeling_bart import (

@@ -202,140 +199,6 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
        tiny(**inputs_dict)


EN_CODE = 250004


@require_torch
class MBartIntegrationTests(unittest.TestCase):
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
    ]
    tgt_text = [
        "Şeful ONU declară că nu există o soluţie militară în Siria",
        'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
    ]

    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]

    @classmethod
    def setUpClass(cls):
        checkpoint_name = "facebook/mbart-large-en-ro"
        cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
        cls.pad_token_id = 1
        return cls

    @cached_property
    def model(self):
        """Only load the model if needed."""
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device)
        if "cuda" in torch_device:
            model = model.half()
        return model

    @slow
    @unittest.skip("This has been failing since June 20th at least.")
    def test_enro_forward(self):
        model = self.model
        net_input = {
            "input_ids": _long_tensor(
                [
                    [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
                    [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
                ]
            ),
            "decoder_input_ids": _long_tensor(
                [
                    [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
                    [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
                ]
            ),
        }
        net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
        with torch.no_grad():
            logits, *other_stuff = model(**net_input)

        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
        result_slice = logits[0, 0, :3]
        _assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)

    @slow
    def test_enro_generate(self):
        batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])
        self.assertEqual(self.tgt_text[1], decoded[1])

    def test_mbart_enro_config(self):
        mbart_models = ["facebook/mbart-large-en-ro"]
        expected = {"scale_embedding": True, "output_past": True}
        for name in mbart_models:
            config = BartConfig.from_pretrained(name)
            self.assertTrue(config.is_valid_mbart())
            for k, v in expected.items():
                try:
                    self.assertEqual(v, getattr(config, k))
                except AssertionError as e:
                    e.args += (name, k)
                    raise

    def test_mbart_fast_forward(self):
        config = BartConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            add_final_layer_norm=True,
        )
        lm_model = BartForConditionalGeneration(config).to(torch_device)
        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
        loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
        expected_shape = (*summary.shape, config.vocab_size)
        self.assertEqual(logits.shape, expected_shape)

    def test_enro_tokenizer_prepare_translation_batch(self):
        batch = self.tokenizer.prepare_translation_batch(
            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
        )
        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 14), batch.input_ids.shape)
        self.assertEqual((2, 14), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -2])  # EOS

    def test_enro_tokenizer_batch_encode_plus(self):
        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
        self.assertListEqual(self.expected_src_tokens, ids)

    def test_enro_tokenizer_decode_ignores_language_codes(self):
        self.assertIn(250020, self.tokenizer.all_special_ids)
        generated_ids = [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
        self.assertEqual(result, expected_romanian)
        self.assertNotIn(self.tokenizer.eos_token, result)

    def test_enro_tokenizer_truncation(self):
        src_text = ["this is gunna be a long sentence " * 20]
        assert isinstance(src_text[0], str)
        desired_max_length = 10
        ids = self.tokenizer.prepare_translation_batch(
            src_text, return_tensors=None, max_length=desired_max_length
        ).input_ids[0]
        self.assertEqual(ids[-2], 2)
        self.assertEqual(ids[-1], EN_CODE)
        self.assertEqual(len(ids), desired_max_length)


@require_torch
class BartHeadTests(unittest.TestCase):
    vocab_size = 99
tests/test_modeling_mbart.py (new file, 142 lines)
@@ -0,0 +1,142 @@
import unittest

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_modeling_bart import TOLERANCE, _assert_tensors_equal, _long_tensor


if is_torch_available():
    import torch
    from transformers import (
        AutoModelForSeq2SeqLM,
        BartConfig,
        BartForConditionalGeneration,
        BatchEncoding,
        AutoTokenizer,
    )


EN_CODE = 250004
RO_CODE = 250020


@require_torch
class AbstractMBartIntegrationTest(unittest.TestCase):

    checkpoint_name = None

    @classmethod
    def setUpClass(cls):
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
        cls.pad_token_id = 1
        return cls

    @cached_property
    def model(self):
        """Only load the model if needed."""
        model = AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)
        if "cuda" in torch_device:
            model = model.half()
        return model


@require_torch
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
    checkpoint_name = "facebook/mbart-large-en-ro"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
    ]
    tgt_text = [
        "Şeful ONU declară că nu există o soluţie militară în Siria",
        'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
    ]
    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]

    @slow
    @unittest.skip("This has been failing since June 20th at least.")
    def test_enro_forward(self):
        model = self.model
        net_input = {
            "input_ids": _long_tensor(
                [
                    [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004],
                    [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004],
                ]
            ),
            "decoder_input_ids": _long_tensor(
                [
                    [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1],
                    [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2],
                ]
            ),
        }
        net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
        with torch.no_grad():
            logits, *other_stuff = model(**net_input)

        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
        result_slice = logits[0, 0, :3]
        _assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)

    @slow
    def test_enro_generate(self):
        batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])
        self.assertEqual(self.tgt_text[1], decoded[1])

    def test_mbart_enro_config(self):
        mbart_models = ["facebook/mbart-large-en-ro"]
        expected = {"scale_embedding": True, "output_past": True}
        for name in mbart_models:
            config = BartConfig.from_pretrained(name)
            self.assertTrue(config.is_valid_mbart())
            for k, v in expected.items():
                try:
                    self.assertEqual(v, getattr(config, k))
                except AssertionError as e:
                    e.args += (name, k)
                    raise

    def test_mbart_fast_forward(self):
        config = BartConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            add_final_layer_norm=True,
        )
        lm_model = BartForConditionalGeneration(config).to(torch_device)
        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
        loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
        expected_shape = (*summary.shape, config.vocab_size)
        self.assertEqual(logits.shape, expected_shape)


class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
    checkpoint_name = "facebook/mbart-large-cc25"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        " I ate lunch twice yesterday",
    ]
    tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]

    @unittest.skip("This test is broken, still generates english")
    def test_cc25_generate(self):
        inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
        translated_tokens = self.model.generate(
            input_ids=inputs["input_ids"].to(torch_device),
            decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
        )
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])
@@ -903,6 +903,7 @@ class TokenizerTesterMixin:
        tokenizer.padding_side = "right"
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)
        # FIXME: the next line should be padding(max_length) to avoid warning
        padded_sequence = tokenizer.encode(
            sequence, max_length=sequence_length + padding_size, pad_to_max_length=True
        )
tests/test_tokenization_mbart.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import unittest

from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
from transformers.testing_utils import require_torch

from .test_tokenization_common import TokenizerTesterMixin
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE


EN_CODE = 250004
RO_CODE = 250020


class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = MBartTokenizer

    def setUp(self):
        super().setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True)
        tokenizer.save_pretrained(self.tmpdirname)

    def test_full_tokenizer(self):
        tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
        )

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids,
            [
                value + tokenizer.fairseq_offset
                for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
                # ^ unk: 2 + 1 = 3     unk: 2 + 1 = 3 ^
            ],
        )

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )


@require_torch
class MBartEnroIntegrationTest(unittest.TestCase):
    checkpoint_name = "facebook/mbart-large-en-ro"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
    ]
    tgt_text = [
        "Şeful ONU declară că nu există o soluţie militară în Siria",
        'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.',
    ]
    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]

    @classmethod
    def setUpClass(cls):
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
        cls.pad_token_id = 1
        return cls

    def test_enro_tokenizer_prepare_translation_batch(self):
        batch = self.tokenizer.prepare_translation_batch(
            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
        )
        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 14), batch.input_ids.shape)
        self.assertEqual((2, 14), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(self.expected_src_tokens, result)
        self.assertEqual(2, batch.decoder_input_ids[0, -1])  # EOS
        # Test that special tokens are reset
        self.assertEqual(self.tokenizer.prefix_tokens, [])
        self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])

    def test_enro_tokenizer_batch_encode_plus(self):
        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
        self.assertListEqual(self.expected_src_tokens, ids)

    def test_enro_tokenizer_decode_ignores_language_codes(self):
        self.assertIn(RO_CODE, self.tokenizer.all_special_ids)
        generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2]
        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
        self.assertEqual(result, expected_romanian)
        self.assertNotIn(self.tokenizer.eos_token, result)

    def test_enro_tokenizer_truncation(self):
        src_text = ["this is gunna be a long sentence " * 20]
        assert isinstance(src_text[0], str)
        desired_max_length = 10
        ids = self.tokenizer.prepare_translation_batch(
            src_text, return_tensors=None, max_length=desired_max_length
        ).input_ids[0]
        self.assertEqual(ids[-2], 2)
        self.assertEqual(ids[-1], EN_CODE)
        self.assertEqual(len(ids), desired_max_length)