mirror of https://github.com/huggingface/transformers.git
[examples/seq2seq]: add --label_smoothing option (#5919)
This commit is contained in:
parent 95d1962b9c, commit 5b193b39b0
@@ -27,8 +27,18 @@ this should make a directory called `cnn_dm/` with files like `test.source`.
```

WMT16 English-Romanian Translation Data:

This dataset comes in two formats. The "packed" version merges short training examples into examples of <200 tokens to increase GPU utilization (which also improves validation performance).

```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro_packed_train_200.tgz
tar -xzvf wmt_en_ro_packed_train_200.tgz
export ENRO_DIR=wmt_en_ro_packed_train_200
```
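
For intuition, the packing step can be pictured as greedily concatenating consecutive source/target pairs until a ~200-token budget would be exceeded. The sketch below is illustrative only and is not the repo's `pack_dataset.py`; the function name and the whitespace token count are assumptions:

```python
# Minimal packing sketch (illustrative; NOT pack_dataset.py). Assumes parallel
# lists of source/target strings and approximates token counts by whitespace split.
from typing import List, Tuple


def pack_examples(srcs: List[str], tgts: List[str], max_tokens: int = 200) -> Tuple[List[str], List[str]]:
    packed_srcs, packed_tgts = [], []
    cur_src, cur_tgt = "", ""
    for src, tgt in zip(srcs, tgts):
        cand_src = (cur_src + " " + src).strip()
        cand_tgt = (cur_tgt + " " + tgt).strip()
        # Start a new packed example once either side would exceed the budget.
        if cur_src and (len(cand_src.split()) > max_tokens or len(cand_tgt.split()) > max_tokens):
            packed_srcs.append(cur_src)
            packed_tgts.append(cur_tgt)
            cur_src, cur_tgt = src, tgt
        else:
            cur_src, cur_tgt = cand_src, cand_tgt
    if cur_src:
        packed_srcs.append(cur_src)
        packed_tgts.append(cur_tgt)
    return packed_srcs, packed_tgts
```

Fewer, longer examples mean fewer padding tokens per batch, which is where the GPU-utilization gain comes from.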

The original data can also be downloaded with this command:
```bash
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
@@ -84,16 +94,31 @@ The following command should work on a 16GB GPU:

First, follow the wmt_en_ro download instructions.
Then you can finetune mbart_cc25 on English-Romanian with the following command.
**Recommendation:** Read and potentially modify the fairly opinionated defaults in the `train_mbart_cc25_enro.sh` script before running it.

Best performing command:
```bash
export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded
export MAX_LEN=128
# optionally
export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above
# export WANDB_PROJECT="MT" # optional
export MAX_LEN=200
export BS=4
export GAS=8
./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
export GAS=8 # gradient accumulation steps
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler
```
This should take less than 2h/epoch on a 16GB V100 and achieve a `val_avg_bleu` score above 25 (you can check this in wandb or in metrics.json).
To get results in line with fairseq, you need to do some postprocessing.

MultiGPU command
(using 8 GPUs as an example)
```bash
export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above
# export WANDB_PROJECT="MT" # optional
export MAX_LEN=200
export BS=4
export GAS=1 # gradient accumulation steps
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb
```

### Finetuning Outputs
As you train, `output_dir` will be filled with files that look something like this (comments are mine).
Some are metrics, some are checkpoints, and some are metadata. Here is a quick tour:
@@ -108,7 +133,7 @@ output_dir
│ ├── tokenizer_config.json
│ └── vocab.json
├── git_log.json # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score.
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT)
├── metrics.json # new validation metrics will continually be appended to this
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│ ├── config.json
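
Since `metrics.json` is appended to after every validation run, a quick way to eyeball progress is to load it directly. A minimal sketch, assuming the file holds a single JSON object keyed by split (the exact layout depends on the task and version):

```python
import json

# Path is an example; point it at your --output_dir. The key layout is an
# assumption (a dict of per-split metric histories); adjust to what you see.
with open("enro_finetune_baseline/metrics.json") as f:
    metrics = json.load(f)
print(sorted(metrics))  # top-level keys recorded so far
```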
@@ -5,7 +5,7 @@ from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only

@@ -90,3 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback

def get_early_stopping_callback(metric, patience):
return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,)
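
For reference, the helper just builds a standard `EarlyStopping` callback that watches the module's validation metric; a small sketch of an equivalent call and how it is consumed (the metric name "bleu" is an illustrative assumption):

```python
from pytorch_lightning.callbacks import EarlyStopping

# get_early_stopping_callback("bleu", patience=4) builds the equivalent of:
es_callback = EarlyStopping(monitor="val_bleu", mode="max", patience=4, verbose=True)

# finetune.main (see the finetune.py hunk below) then forwards it via
# generic_train(..., early_stopping_callback=es_callback) whenever
# --early_stopping_patience >= 0; patience is counted in validation checks.
```

Note that `mode="max"` assumes a higher-is-better metric such as BLEU or ROUGE.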
@@ -33,9 +33,10 @@ try:
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
label_smoothed_nll_loss,
)

from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
except ImportError:
from utils import (
Seq2SeqDataset,
@@ -52,8 +53,9 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
label_smoothed_nll_loss,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback

logger = logging.getLogger(__name__)
@@ -128,12 +130,22 @@ class SummarizationModule(BaseTransformer):

def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
y_ids = y[:, :-1].contiguous()
lm_labels = y[:, 1:].clone()
lm_labels[y[:, 1:] == pad_token_id] = -100
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
loss = outputs[0]
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
lm_labels = target_ids[:, 1:].clone()  # why clone?
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)

if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
lm_logits = outputs[0]
assert lm_logits.shape[-1] == self.model.config.vocab_size
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
)
return (loss,)
def training_step(self, batch, batch_idx) -> Dict:
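
To answer the two inline questions in the `_step` hunk above: the slices implement the usual teacher-forcing shift, and `.clone()` keeps the labels from aliasing the batch tensor, so any later in-place masking (as the old version did with -100) cannot corrupt the batch. A tiny illustrative example, with token ids borrowed from the test further down:

```python
import torch

# One sequence: <s> A B </s>  (ids are illustrative)
target_ids = torch.tensor([[0, 4, 8, 2]])

decoder_input_ids = target_ids[:, :-1]  # [0, 4, 8]: the decoder is fed <s> A B ...
lm_labels = target_ids[:, 1:].clone()   # [4, 8, 2]: ... and must predict A B </s>
# Without .clone(), lm_labels would share storage with target_ids, so an
# in-place edit (e.g. setting pad positions to -100) would also modify the batch.
```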
@@ -290,8 +302,16 @@ class SummarizationModule(BaseTransformer):
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
)
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
parser.add_argument(
"--early_stopping_patience",
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will affect it.",
)
return parser
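
To see the two new options in isolation, here is a standalone argparse snippet mirroring just these flags; it is not the module's full parser, which has many more required arguments:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument(
    "--early_stopping_patience",
    type=int,
    default=-1,
    required=False,
    help="-1 means never early stop; measured in validation checks, not epochs.",
)

args = parser.parse_args(["--label_smoothing", "0.1", "--early_stopping_patience", "4"])
assert args.label_smoothing == 0.1 and args.early_stopping_patience == 4
```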
@@ -335,17 +355,24 @@ def main(args, model=None) -> SummarizationModule:
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(name=model.output_dir.name, project=dataset)
project = os.environ.get("WANDB_PROJECT", dataset)
logger = WandbLogger(name=model.output_dir.name, project=project)

elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

if args.early_stopping_patience >= 0:
es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
else:
es_callback = False
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=logger,
# TODO: early stopping callback seems messed up
)
@@ -12,14 +12,14 @@ import torch
from pytest import param
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, MBartTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer
from transformers.testing_utils import require_multigpu

from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json

logging.basicConfig(level=logging.DEBUG)
@@ -27,7 +27,8 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing_eps": 0.2,
"label_smoothing": 0.2,
"early_stopping_patience": 2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
@@ -142,6 +143,26 @@ class TestSummarizationDistiller(unittest.TestCase):

evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

def test_loss_fn(self):
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
lm_labels = target_ids[:, 1:].clone()  # why clone?
model_computed_loss = model(
input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
).loss

logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits

lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
smoothed_loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
)
with self.assertRaises(AssertionError):
# TODO: understand why this breaks
self.assertEqual(nll_loss, model_computed_loss)

@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self):
updates = dict(
@@ -156,6 +177,8 @@ class TestSummarizationDistiller(unittest.TestCase):

def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing_eps=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
@@ -215,6 +238,8 @@ def test_run_eval_bart(model):
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
args_d["label_smoothing"] = 0.1 if task == "translation" else 0

tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
@@ -4,17 +4,17 @@ export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.1 \
--val_check_interval 0.25 \
--adam_eps 1e-06 \
--num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
--freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--model_name_or_path facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 \
--logger_name wandb --sortish_sampler \
--freeze_embeds \
--early_stopping_patience 4 \
--model_name_or_path facebook/mbart-large-cc25 \
$@
@@ -1,18 +0,0 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
# Need to export N_GPUS=
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus $N_GPUS \
--do_train \
--val_check_interval 0.25 \
--adam_eps 1e-06 \
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--tokenizer facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 --freeze_encoder --freeze_embeds \
$@
@@ -19,6 +19,29 @@ from torch.utils.data import Dataset, Sampler
from transformers import BartTokenizer


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
"""From fairseq"""
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
bs = pad_mask.long().sum()
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
bs = lprobs.shape[0]

nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1)
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss / bs, nll_loss / bs
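
A quick usage sketch for the helper above on toy tensors (shapes and the pad id are illustrative; with `ignore_index` set, padded positions are zeroed out before the sums):

```python
import torch
import torch.nn.functional as F

# Batch of 2 sequences, 3 timesteps, vocab of 5; pad id 1 appears once in the targets.
logits = torch.randn(2, 3, 5)
target = torch.tensor([[4, 2, 3], [4, 2, 1]])

lprobs = F.log_softmax(logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(lprobs, target, epsilon=0.1, ignore_index=1)
# loss combines (1 - epsilon) * NLL with (epsilon / vocab_size) * (sum of -log-probs),
# and pad positions contribute nothing to either term.
```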

def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
return tokenizer(
@@ -144,8 +167,8 @@ class MBartDataset(Seq2SeqDataset):
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
return {
"tgt_texts": source_line,
"src_texts": tgt_line,
"tgt_texts": tgt_line,
"src_texts": source_line,
}

def collate_fn(self, batch) -> Dict[str, torch.Tensor]: