Seq2seq trainer (#9241)

* Add label smoothing in Trainer

* Add options for scheduler and Adafactor in Trainer

* Put Seq2SeqTrainer in the main lib

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments and adapt scripts

* Documentation

* Move test not using script to tests folder

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sylvain Gugger 2020-12-22 11:33:44 -05:00 committed by GitHub
parent 1fc7119181
commit 490b39e614
20 changed files with 655 additions and 166 deletions


@@ -43,6 +43,10 @@ Schedules
Learning Rate Schedules (Pytorch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: transformers.SchedulerType
.. autofunction:: transformers.get_scheduler
.. autofunction:: transformers.get_constant_schedule


@@ -63,6 +63,13 @@ Trainer
:members:
Seq2SeqTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Seq2SeqTrainer
:members: evaluate, predict
TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -77,6 +84,13 @@ TrainingArguments
:members:
Seq2SeqTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Seq2SeqTrainingArguments
:members:
TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -20,9 +20,16 @@ from dataclasses import dataclass, field
from typing import Optional
import transformers
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
HfArgumentParser,
MBartTokenizer,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.trainer_utils import EvaluationStrategy, is_main_process
from transformers.training_args import ParallelMode
from utils import (
@@ -86,14 +93,14 @@ class DataTrainingArguments:
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
max_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
eval_max_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
@@ -101,13 +108,6 @@ class DataTrainingArguments:
" This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
@@ -233,7 +233,7 @@ def main():
type_path="train",
data_dir=data_args.data_dir,
n_obs=data_args.n_train,
max_target_length=data_args.max_target_length,
max_target_length=data_args.max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -246,7 +246,7 @@ def main():
type_path="val",
data_dir=data_args.data_dir,
n_obs=data_args.n_val,
max_target_length=data_args.val_max_target_length,
max_target_length=data_args.eval_max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -259,7 +259,7 @@ def main():
type_path="test",
data_dir=data_args.data_dir,
n_obs=data_args.n_test,
max_target_length=data_args.test_max_target_length,
max_target_length=data_args.eval_max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -273,13 +273,12 @@ def main():
)
trainer = Seq2SeqTrainer(
model=model,
config=config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=compute_metrics_fn,
data_args=data_args,
tokenizer=tokenizer,
)
all_metrics = {}
@@ -310,7 +309,9 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(metric_key_prefix="val")
metrics = trainer.evaluate(
metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
)
metrics["val_n_objs"] = data_args.n_val
metrics["val_loss"] = round(metrics["val_loss"], 4)
@@ -322,7 +323,12 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
test_output = trainer.predict(
test_dataset=test_dataset,
metric_key_prefix="test",
max_length=data_args.eval_max_length,
num_beams=data_args.eval_beams,
)
metrics = test_output.metrics
metrics["test_n_objs"] = data_args.n_test


@@ -17,8 +17,7 @@ import sys
import unittest
from unittest.mock import patch
from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_apex_available, is_datasets_available
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
TestCasePlus,
@@ -31,8 +30,7 @@ from transformers.testing_utils import (
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer
from .finetune_trainer import main
set_seed(42)
@@ -120,119 +118,6 @@ class TestFinetuneTrainer(TestCasePlus):
assert "test_generations.txt" in contents
assert "test_results.json" in contents
@slow
def test_finetune_bert2bert(self):
if not is_datasets_available():
return
import datasets
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.max_length = 128
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16))
rouge = datasets.load_metric("rouge")
batch_size = 4
def _map_to_encoder_decoder_inputs(batch):
# Tokenizer will automatically set [BOS] <text> [EOS]
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
batch["input_ids"] = inputs.input_ids
batch["attention_mask"] = inputs.attention_mask
batch["decoder_input_ids"] = outputs.input_ids
batch["labels"] = outputs.input_ids.copy()
batch["labels"] = [
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
]
batch["decoder_attention_mask"] = outputs.attention_mask
assert all([len(x) == 512 for x in inputs.input_ids])
assert all([len(x) == 128 for x in outputs.input_ids])
return batch
def _compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
# all unnecessary tokens are removed
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
"rouge2"
].mid
return {
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}
# map train dataset
train_dataset = train_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
train_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
# same for validation dataset
val_dataset = val_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
val_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
output_dir = self.get_auto_remove_tmp_dir()
training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
predict_with_generate=True,
evaluation_strategy="steps",
do_train=True,
do_eval=True,
warmup_steps=0,
eval_steps=2,
logging_steps=2,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=bert2bert,
args=training_args,
compute_metrics=_compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)
# start training
trainer.train()
def run_trainer(
self,
eval_steps: int,
@@ -252,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus):
--n_train 8
--n_val 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
--max_length {max_len}
--eval_max_length {max_len}
--do_train
--do_eval
--do_predict


@@ -29,7 +29,7 @@ python finetune_trainer.py \
--freeze_encoder --freeze_embeds \
--num_train_epochs=6 \
--save_steps 3000 --eval_steps 3000 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --logging_first_step \


@@ -30,7 +30,7 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
--num_train_epochs=6 \
--save_steps 500 --eval_steps 500 \
--logging_first_step --logging_steps 200 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--do_train --do_eval \
--evaluation_strategy steps \
--prediction_loss_only \


@@ -32,7 +32,7 @@ python finetune_trainer.py \
--num_train_epochs=2 \
--save_steps 3000 --eval_steps 3000 \
--logging_first_step \
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--max_length 56 --eval_max_length $MAX_TGT_LEN \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --sortish_sampler \


@@ -24,8 +24,7 @@ python finetune_trainer.py \
--src_lang en_XX --tgt_lang ro_RO \
--freeze_embeds \
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
--max_source_length 128 --max_target_length 128 \
--val_max_target_length 128 --test_max_target_length 128 \
--max_source_length 128 --max_length 128 --eval_max_length 128 \
--sortish_sampler \
--num_train_epochs 6 \
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \


@@ -330,7 +330,7 @@ class Seq2SeqDataCollator:
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
max_target_length=self.data_args.max_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
**self.dataset_kwargs,


@@ -287,8 +287,9 @@ from .trainer_callback import (
TrainerControl,
TrainerState,
)
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
from .utils import logging
@@ -682,11 +683,13 @@ if is_torch_available():
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
# Trainer
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer
else:
from .utils.dummy_pt_objects import *


@@ -15,12 +15,13 @@
"""PyTorch optimization for BERT model."""
import math
from typing import Callable, Iterable, Tuple
from typing import Callable, Iterable, Optional, Tuple, Union
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from .trainer_utils import SchedulerType
from .utils import logging
@@ -215,6 +216,56 @@ def get_polynomial_decay_schedule_with_warmup(
return LambdaLR(optimizer, lr_lambda, last_epoch)
TYPE_TO_SCHEDULER_FUNCTION = {
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
SchedulerType.CONSTANT: get_constant_schedule,
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
}
def get_scheduler(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
):
"""
Unified API to get any scheduler from its name.
Args:
name (:obj:`str` or `:obj:`SchedulerType`):
The name of the scheduler to use.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (:obj:`int`, `optional`):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional); the function will raise an error if it is unset and the scheduler type requires it.
num_training_steps (:obj:`int`, `optional`):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional); the function will raise an error if it is unset and the scheduler type requires it.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
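As a quick illustration of the new unified API, here is a minimal sketch; the model, learning rate, and step counts below are placeholder values chosen for the example, not part of the commit:

```python
import torch

from transformers import AdamW, get_scheduler

# Any torch module works; a tiny linear layer keeps the sketch self-contained.
model = torch.nn.Linear(10, 2)
optimizer = AdamW(model.parameters(), lr=5e-5)

# "linear" is looked up in TYPE_TO_SCHEDULER_FUNCTION above; "constant" would
# need neither num_warmup_steps nor num_training_steps.
lr_scheduler = get_scheduler(
    "linear",
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)

for step in range(1000):
    optimizer.step()      # a real loop would run a forward/backward pass first
    lr_scheduler.step()
```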
class AdamW(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization


@@ -57,7 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import AdamW, get_linear_schedule_with_warmup
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
@@ -70,6 +70,7 @@ from .trainer_callback import (
)
from .trainer_pt_utils import (
DistributedTensorGatherer,
LabelSmoother,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
@@ -320,6 +321,12 @@ class Trainer:
)
self.use_apex = True
# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
else:
self.label_smoother = None
self.state = TrainerState()
self.control = TrainerControl()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
@@ -507,24 +514,32 @@ class Trainer:
"weight_decay": 0.0,
},
]
optimizer_cls = Adafactor if self.args.adafactor else AdamW
if self.args.adafactor:
optimizer_cls = Adafactor
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
else:
optimizer_cls = AdamW
optimizer_kwargs = {
"betas": (self.args.adam_beta1, self.args.adam_beta2),
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_dpp:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=AdamW,
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
eps=self.args.adam_epsilon,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = AdamW(
optimizer_grouped_parameters,
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
eps=self.args.adam_epsilon,
)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
self.lr_scheduler = get_scheduler(
self.args.lr_scheduler_type,
self.optimizer,
num_warmup_steps=self.args.warmup_steps,
num_training_steps=num_training_steps,
)
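From the user side this code path is driven by the new ``adafactor`` and ``lr_scheduler_type`` training arguments; a minimal sketch with illustrative values:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    adafactor=True,              # Adafactor(scale_parameter=False, relative_step=False) instead of AdamW
    lr_scheduler_type="cosine",  # any SchedulerType value; the default stays "linear"
    warmup_steps=500,
    learning_rate=3e-4,
)
# A Trainer built with these args creates the optimizer and scheduler as shown above.
```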
def num_examples(self, dataloader: DataLoader) -> int:
@@ -1168,8 +1183,12 @@ class Trainer:
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
if self.label_smoother is not None and "labels" in inputs:
return self.label_smoother(outputs, inputs["labels"])
else:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
def is_local_process_zero(self) -> bool:
"""
@@ -1556,11 +1575,13 @@ class Trainer:
else:
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None and "labels" in inputs:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
if isinstance(outputs, dict):
loss = outputs["loss"].mean().detach()
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None


@@ -19,6 +19,7 @@ Torch utilities for the Trainer class.
import math
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
@@ -360,3 +361,32 @@ class DistributedTensorGatherer:
if self._offsets[0] != self.process_length:
logger.warn("Not all data has been set. Are you sure you passed all values?")
return nested_truncate(self._storage, self.num_samples)
@dataclass
class LabelSmoother:
"""
Adds label-smoothing on a pre-computed output from a Transformers model.
Args:
epsilon (:obj:`float`, `optional`, defaults to 0.1):
The label smoothing factor.
ignore_index (:obj:`int`, `optional`, defaults to -100):
The index in the labels to ignore when computing the loss.
"""
epsilon: float = 0.1
ignore_index: int = -100
def __call__(self, model_output, labels):
model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
# Look at the ignored index and mask the corresponding log_probs.
padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
log_probs.masked_fill_(padding_mask, 0.0)
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
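A self-contained sketch of how the smoother is applied; random tensors stand in for a real model's output and the shapes are arbitrary:

```python
import torch

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import LabelSmoother

num_labels = 12
logits = torch.randn(4, 5, num_labels)
labels = torch.randint(0, num_labels, (4, 5))
labels[0, 0] = -100  # a padded position, ignored by the smoother

# cross_entropy also ignores -100 by default, so `loss` matches what a model would return.
loss = torch.nn.functional.cross_entropy(logits.view(-1, num_labels), labels.view(-1))
output = SequenceClassifierOutput(loss=loss, logits=logits)

# Interpolates between the model's own loss and the uniform (smoothed) loss.
smoothed_loss = LabelSmoother(epsilon=0.1)(output, labels)
```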


@@ -0,0 +1,231 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from torch.utils.data.dataset import Dataset
from .file_utils import is_torch_tpu_available
from .trainer import Trainer
from .trainer_pt_utils import get_tpu_sampler
from .trainer_utils import PredictionOutput
from .training_args import ParallelMode
from .utils import logging
if version.parse(torch.__version__) >= version.parse("1.6"):
from torch.cuda.amp import autocast
logger = logging.get_logger(__name__)
class Seq2SeqTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
self.train_dataset.make_sortish_sampler(
self.args.per_device_train_batch_size,
distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
)
return (
RandomSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
max_length: Optional[int] = None,
num_beams: Optional[int] = None,
) -> Dict[str, float]:
"""
Run evaluation and return the metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init :obj:`compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
:obj:`__len__` method.
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is ``"eval"`` (default)
max_length (:obj:`int`, `optional`):
The maximum target length to use when predicting with the generate method.
num_beams (:obj:`int`, `optional`):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
self._max_length = max_length
self._num_beams = num_beams
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict(
self,
test_dataset: Dataset,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
max_length: Optional[int] = None,
num_beams: Optional[int] = None,
) -> PredictionOutput:
"""
Run prediction and return predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
will also return metrics, like in :obj:`evaluate()`.
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is ``"eval"`` (default)
max_length (:obj:`int`, `optional`):
The maximum target length to use when predicting with the generate method.
num_beams (:obj:`int`, `optional`):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
.. note::
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
padding in a token classification task) the predictions will be padded (on the right) to allow for
concatenation into one array. The padding index is -100.
Returns: `NamedTuple` A namedtuple with the following keys:
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
- label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels).
"""
self._max_length = max_length
self._num_beams = num_beams
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using :obj:`inputs`.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
labels (each being optional).
"""
if not self.args.predict_with_generate or prediction_loss_only:
return super().prediction_step(model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)
gen_kwargs = {
"max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
}
generated_tokens = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
with torch.no_grad():
if self.use_amp:
with autocast():
outputs = model(**inputs)
else:
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
else:
loss = None
if self.args.prediction_loss_only:
return (loss, None, None)
labels = inputs["labels"]
if labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
return (loss, generated_tokens, labels)
def _pad_tensors_to_max_len(self, tensor, max_length):
if self.tokenizer is None:
raise ValueError(
f"Tensor need to be padded to `max_length={max_length}` but no tokenzier was passed when creating "
"this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
)
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, : tensor.shape[-1]] = tensor
return padded_tensor
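Putting the pieces together, here is a sketch of end-to-end usage; the checkpoint name, the toy in-memory dataset, and the generation settings are placeholders chosen for illustration, not part of the commit:

```python
import torch
from torch.utils.data import Dataset

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


class ToyPairs(Dataset):
    """Tiny in-memory dataset yielding input_ids/attention_mask/labels dicts."""

    def __init__(self, tokenizer, n=8):
        src = ["translate English to German: Hello world."] * n
        tgt = ["Hallo Welt."] * n
        self.enc = tokenizer(src, padding="max_length", truncation=True, max_length=32)
        self.labels = tokenizer(tgt, padding="max_length", truncation=True, max_length=16)["input_ids"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        item = {k: torch.tensor(v[i]) for k, v in self.enc.items()}
        item["labels"] = torch.tensor(self.labels[i])
        return item


tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,  # route evaluation through the prediction_step above
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=ToyPairs(tokenizer),
    eval_dataset=ToyPairs(tokenizer, n=4),
    tokenizer=tokenizer,  # needed by _pad_tensors_to_max_len
)

# The new generation kwargs override model.config.max_length / num_beams at eval time.
metrics = trainer.evaluate(max_length=32, num_beams=4)
predictions = trainer.predict(ToyPairs(tokenizer, n=4), max_length=32, num_beams=4)
```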


@@ -201,3 +201,12 @@ def speed_metrics(split, start_time, num_samples=None):
samples_per_second = 1 / (runtime / num_samples)
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
return result
class SchedulerType(ExplicitEnum):
LINEAR = "linear"
COSINE = "cosine"
COSINE_WITH_RESTARTS = "cosine_with_restarts"
POLYNOMIAL = "polynomial"
CONSTANT = "constant"
CONSTANT_WITH_WARMUP = "constant_with_warmup"


@@ -19,7 +19,7 @@ from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .trainer_utils import EvaluationStrategy
from .trainer_utils import EvaluationStrategy, SchedulerType
from .utils import logging
@@ -121,6 +121,9 @@ class TrainingArguments:
max_steps (:obj:`int`, `optional`, defaults to -1):
If set to a positive number, the total number of training steps to perform. Overrides
:obj:`num_train_epochs`.
lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
values.
warmup_steps (:obj:`int`, `optional`, defaults to 0):
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
logging_dir (:obj:`str`, `optional`):
@@ -217,6 +220,13 @@ class TrainingArguments:
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.
label_smoothing_factor (:obj:`float`, `optional`, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing; otherwise, the underlying one-hot-encoded
labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
label_smoothing_factor + label_smoothing_factor/num_labels` respectively.
adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
:class:`~transformers.AdamW`.
"""
output_dir: str = field(
@@ -246,7 +256,7 @@ class TrainingArguments:
)
evaluation_strategy: EvaluationStrategy = field(
default="no",
metadata={"help": "Run evaluation during training at each logging step."},
metadata={"help": "The evaluation strategy to use."},
)
prediction_loss_only: bool = field(
default=False,
@@ -296,6 +306,10 @@ class TrainingArguments:
default=-1,
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
)
lr_scheduler_type: SchedulerType = field(
default="linear",
metadata={"help": "The scheduler type to use."},
)
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
@@ -392,11 +406,16 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
)
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
self.do_eval = True
if self.eval_steps is None:


@@ -0,0 +1,42 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from .file_utils import add_start_docstrings
from .training_args import TrainingArguments
logger = logging.getLogger(__name__)
@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class Seq2SeqTrainingArguments(TrainingArguments):
"""
sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use a `sortish sampler` or not. Only possible if the underlying datasets are `Seq2SeqDataset` for
now but will become generally available in the near future.
It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness for
the training set.
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
"""
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
predict_with_generate: bool = field(
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
)


@@ -2279,6 +2279,10 @@ def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
requires_pytorch(get_polynomial_decay_schedule_with_warmup)
def get_scheduler(*args, **kwargs):
requires_pytorch(get_scheduler)
class Trainer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@@ -2286,3 +2290,8 @@ class Trainer:
def torch_distributed_zero_first(*args, **kwargs):
requires_pytorch(torch_distributed_zero_first)
class Seq2SeqTrainer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)


@@ -0,0 +1,135 @@
# coding=utf-8
# Copyright 2020 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import TestCasePlus, require_datasets, slow
if is_datasets_available():
import datasets
class Seq2seqTrainerTester(TestCasePlus):
@slow
@require_datasets
def test_finetune_bert2bert(self):
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.max_length = 128
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16))
rouge = datasets.load_metric("rouge")
batch_size = 4
def _map_to_encoder_decoder_inputs(batch):
# Tokenizer will automatically set [BOS] <text> [EOS]
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
batch["input_ids"] = inputs.input_ids
batch["attention_mask"] = inputs.attention_mask
batch["decoder_input_ids"] = outputs.input_ids
batch["labels"] = outputs.input_ids.copy()
batch["labels"] = [
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
]
batch["decoder_attention_mask"] = outputs.attention_mask
assert all([len(x) == 512 for x in inputs.input_ids])
assert all([len(x) == 128 for x in outputs.input_ids])
return batch
def _compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
# all unnecessary tokens are removed
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
"rouge2"
].mid
return {
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}
# map train dataset
train_dataset = train_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
train_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
# same for validation dataset
val_dataset = val_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
val_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
output_dir = self.get_auto_remove_tmp_dir()
training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
predict_with_generate=True,
evaluation_strategy="steps",
do_train=True,
do_eval=True,
warmup_steps=0,
eval_steps=2,
logging_steps=2,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=bert2bert,
args=training_args,
compute_metrics=_compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
)
# start training
trainer.train()


@@ -22,7 +22,10 @@ from transformers.testing_utils import require_torch
if is_torch_available():
from transformers.trainer_pt_utils import DistributedTensorGatherer
import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother
@require_torch
@@ -56,3 +59,31 @@ class TrainerUtilsTest(unittest.TestCase):
self.assertTrue(np.array_equal(result[0], predictions))
self.assertTrue(np.array_equal(result[1][0], predictions))
self.assertTrue(np.array_equal(result[1][1], predictions))
def test_label_smoothing(self):
epsilon = 0.1
num_labels = 12
random_logits = torch.randn(4, 5, num_labels)
random_labels = torch.randint(0, num_labels, (4, 5))
loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean()
self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
# With a few -100 labels
random_labels[0, 1] = -100
random_labels[2, 1] = -100
random_labels[2, 3] = -100
loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
# Mask the log probs with the -100 labels
log_probs[0, 1] = 0.0
log_probs[2, 1] = 0.0
log_probs[2, 3] = 0.0
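# 17 = 4 * 5 label positions minus the 3 set to -100 above; LabelSmoother averages over the
# num_labels dimension and divides by the number of non-ignored positions, hence num_labels * 17.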
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17)
self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))