[examples/seq2seq]: add --label_smoothing option (#5919)

Sam Shleifer 2020-07-21 16:51:39 -04:00 committed by GitHub
parent 95d1962b9c
commit 5b193b39b0
7 changed files with 132 additions and 46 deletions

View File

@ -27,8 +27,18 @@ this should make a directory called `cnn_dm/` with files like `test.source`.
```
WMT16 English-Romanian Translation Data:
This dataset comes in two formats. The "packed" version merges short training examples into examples of <200 tokens to increase GPU utilization; this also improves validation performance.
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro_packed_train_200.tgz
tar -xzvf wmt_en_ro_packed_train_200.tgz
export ENRO_DIR=wmt_en_ro_packed_train_200
```
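To make "packing" concrete, here is a minimal sketch of the idea (this is not the repo's actual `pack_dataset.py` implementation; the helper name, the 200-token budget check via `tokenizer.tokenize`, and the example paths are purely illustrative):

```python
def pack_examples(lines, tokenizer, max_tokens=200):
    """Greedily merge consecutive examples so each packed example stays within max_tokens."""
    packed, current = [], ""
    for line in lines:
        candidate = f"{current} {line}".strip() if current else line
        if len(tokenizer.tokenize(candidate)) <= max_tokens:
            current = candidate          # still under budget: keep merging
        else:
            if current:
                packed.append(current)   # flush the finished packed example
            current = line               # start a new one from this line
    if current:
        packed.append(current)
    return packed

# Illustrative usage:
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
# packed_src = pack_examples(open("wmt_en_ro/train.source").read().splitlines(), tok)
```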
The original data can also be downloaded with this command:
```bash
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
@ -84,16 +94,31 @@ The following command should work on a 16GB GPU:
First, follow the wmt_en_ro download instructions.
Then you can finetune mbart_cc25 on English-Romanian with the following command.
**Recommendation:** Read and potentially modify the fairly opinionated defaults in the `train_mbart_cc25_enro.sh` script before running it.
Best performing command:
```bash
export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded
export MAX_LEN=128
# optionally
export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above
# export WANDB_PROJECT="MT" # optional
export MAX_LEN=200
export BS=4
export GAS=8
./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
export GAS=8 # gradient accumulation steps
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler
```
This should take < 2h/epoch on a 16GB V100 and reach a validation BLEU (`val_avg_bleu`) above 25, which you can check in wandb or in `metrics.json`.
To get results in line with fairseq, you need to do some post-processing.
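If you are not watching wandb, a quick way to check progress is to read `metrics.json` directly. A minimal sketch, assuming the metric ends up under a `val_avg_bleu` key somewhere in the file (the snippet just walks whatever JSON structure it finds) and using the output dir from the command above:

```python
import json

def collect_metric(obj, key="val_avg_bleu"):
    """Recursively gather every numeric value stored under `key` in the parsed JSON."""
    hits = []
    if isinstance(obj, dict):
        for k, v in obj.items():
            if k == key and isinstance(v, (int, float)):
                hits.append(v)
            else:
                hits.extend(collect_metric(v, key))
    elif isinstance(obj, list):
        for item in obj:
            hits.extend(collect_metric(item, key))
    return hits

with open("enro_finetune_baseline/metrics.json") as f:
    scores = collect_metric(json.load(f))
print(f"best val_avg_bleu so far: {max(scores):.2f}" if scores else "no val_avg_bleu logged yet")
```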
Multi-GPU command (using 8 GPUs as an example):
```bash
export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above
# export WANDB_PROJECT="MT" # optional
export MAX_LEN=200
export BS=4
export GAS=1 # gradient accumulation steps
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb
```
### Finetuning Outputs
As you train, `output_dir` will be filled with files that look something like this (comments are mine).
Some of them are metrics, some are checkpoints, and some are metadata. Here is a quick tour:
@ -108,7 +133,7 @@ output_dir
│   ├── tokenizer_config.json
│   └── vocab.json
├── git_log.json # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score.
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT)
├── metrics.json # new validation metrics will continually be appended to this
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json

View File

@ -5,7 +5,7 @@ from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
@ -90,3 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback
def get_early_stopping_callback(metric, patience):
return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,)
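# Illustrative usage (example values, not part of this file): finetune.py builds this callback from
# its --early_stopping_patience flag, e.g.
#   es_callback = get_early_stopping_callback(metric="bleu", patience=4)
# which monitors the logged "val_bleu" value and stops training after 4 validation runs without improvement.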

View File

@ -33,9 +33,10 @@ try:
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
label_smoothed_nll_loss,
)
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
except ImportError:
from utils import (
Seq2SeqDataset,
@ -52,8 +53,9 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
label_smoothed_nll_loss,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
logger = logging.getLogger(__name__)
@ -128,12 +130,22 @@ class SummarizationModule(BaseTransformer):
def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
y_ids = y[:, :-1].contiguous()
lm_labels = y[:, 1:].clone()
lm_labels[y[:, 1:] == pad_token_id] = -100
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
loss = outputs[0]
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
lm_labels = target_ids[:, 1:].clone() # why clone?
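# (Teacher forcing: the decoder input is the target with its last token dropped, and lm_labels is the
#  target shifted left by one, so the model is trained to predict the next token at every position.
#  .clone() just gives an independent copy so the batch tensor cannot be modified in place later.)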
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
lm_logits = outputs[0]
assert lm_logits.shape[-1] == self.model.config.vocab_size
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
)
return (loss,)
def training_step(self, batch, batch_idx) -> Dict:
@ -290,8 +302,16 @@ class SummarizationModule(BaseTransformer):
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
)
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
parser.add_argument(
"--early_stopping_patience",
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
)
return parser
@ -335,17 +355,24 @@ def main(args, model=None) -> SummarizationModule:
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=dataset)
project = os.environ.get("WANDB_PROJECT", dataset)
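# (The optional WANDB_PROJECT env var overrides the project name; otherwise runs fall back to the
#  dataset directory name, as before.)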
logger = WandbLogger(name=model.output_dir.name, project=project)
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
if args.early_stopping_patience >= 0:
es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
else:
es_callback = False
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=logger,
# TODO: early stopping callback seems messed up
)

View File

@ -12,14 +12,14 @@ import torch
from pytest import param
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, MBartTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
@ -27,7 +27,8 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing_eps": 0.2,
"label_smoothing": 0.2,
"early_stopping_patience": 2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
@ -142,6 +143,26 @@ class TestSummarizationDistiller(unittest.TestCase):
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
def test_loss_fn(self):
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
lm_labels = target_ids[:, 1:].clone() # why clone?
model_computed_loss = model(
input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
).loss
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
smoothed_loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
)
with self.assertRaises(AssertionError):
# TODO: understand why this breaks
self.assertEqual(nll_loss, model_computed_loss)
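# (Most likely a normalization mismatch: the model's built-in CrossEntropyLoss averages over non-pad
#  tokens, while label_smoothed_nll_loss sums the per-token NLL and divides by its `bs`, so the two
#  numbers are not expected to match exactly.)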
@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self):
updates = dict(
@ -156,6 +177,8 @@ class TestSummarizationDistiller(unittest.TestCase):
def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing_eps=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
@ -215,6 +238,8 @@ def test_run_eval_bart(model):
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(

View File

@ -4,17 +4,17 @@ export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.1 \
--val_check_interval 0.25 \
--adam_eps 1e-06 \
--num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
--freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--model_name_or_path facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 \
--logger_name wandb --sortish_sampler \
--freeze_embeds \
--early_stopping_patience 4 \
--model_name_or_path facebook/mbart-large-cc25 \
$@
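# The trailing $@ forwards any extra flags from the caller, e.g.
# ./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1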

View File

@ -1,18 +0,0 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
# Need to export N_GPUS=
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus $N_GPUS \
--do_train \
--val_check_interval 0.25 \
--adam_eps 1e-06 \
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--tokenizer facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 --freeze_encoder --freeze_embeds \
$@

View File

@ -19,6 +19,29 @@ from torch.utils.data import Dataset, Sampler
from transformers import BartTokenizer
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
"""From fairseq"""
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
bs = pad_mask.long().sum()
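# (note: when padding is present this is the number of padded positions in the batch, not the batch size)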
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
bs = lprobs.shape[0]
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1)
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss / bs, nll_loss / bs
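# In other words: for vocabulary size V, the smoothed target puts weight (1 - epsilon) on the gold token
# plus epsilon / V spread uniformly over the whole vocabulary (gold token included), so the loss above is
# (1 - epsilon) * NLL(gold) + (epsilon / V) * sum_over_vocab(-log p), with both returned values divided by bs.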
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
return tokenizer(
@ -144,8 +167,8 @@ class MBartDataset(Seq2SeqDataset):
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
return {
"tgt_texts": source_line,
"src_texts": tgt_line,
"tgt_texts": tgt_line,
"src_texts": source_line,
}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]: