diff --git a/docs/source/main_classes/optimizer_schedules.rst b/docs/source/main_classes/optimizer_schedules.rst index b53d682375f..71cf1925742 100644 --- a/docs/source/main_classes/optimizer_schedules.rst +++ b/docs/source/main_classes/optimizer_schedules.rst @@ -43,6 +43,10 @@ Schedules Learning Rate Schedules (Pytorch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: transformers.SchedulerType + +.. autofunction:: transformers.get_scheduler + .. autofunction:: transformers.get_constant_schedule diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index a1ac8abd058..2bd666a1907 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -63,6 +63,13 @@ Trainer :members: +Seq2SeqTrainer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Seq2SeqTrainer + :members: evaluate, predict + + TFTrainer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -77,6 +84,13 @@ TrainingArguments :members: +Seq2SeqTrainingArguments +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Seq2SeqTrainingArguments + :members: + + TFTrainingArguments ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py index 30cc30353a5..02c57f45026 100755 --- a/examples/seq2seq/finetune_trainer.py +++ b/examples/seq2seq/finetune_trainer.py @@ -20,9 +20,16 @@ from dataclasses import dataclass, field from typing import Optional import transformers -from seq2seq_trainer import Seq2SeqTrainer -from seq2seq_training_args import Seq2SeqTrainingArguments -from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed +from transformers import ( + AutoConfig, + AutoModelForSeq2SeqLM, + AutoTokenizer, + HfArgumentParser, + MBartTokenizer, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + set_seed, +) from transformers.trainer_utils import EvaluationStrategy, is_main_process from transformers.training_args import ParallelMode from utils import ( @@ -86,14 +93,14 @@ class DataTrainingArguments: "than this will be truncated, sequences shorter will be padded." }, ) - max_target_length: Optional[int] = field( + max_length: Optional[int] = field( default=128, metadata={ "help": "The maximum total sequence length for target text after tokenization. Sequences longer " "than this will be truncated, sequences shorter will be padded." }, ) - val_max_target_length: Optional[int] = field( + eval_max_length: Optional[int] = field( default=142, metadata={ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " @@ -101,13 +108,6 @@ class DataTrainingArguments: " This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``" }, ) - test_max_target_length: Optional[int] = field( - default=142, - metadata={ - "help": "The maximum total sequence length for test target text after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." 
- }, - ) n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."}) n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."}) n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."}) @@ -233,7 +233,7 @@ def main(): type_path="train", data_dir=data_args.data_dir, n_obs=data_args.n_train, - max_target_length=data_args.max_target_length, + max_target_length=data_args.max_length, max_source_length=data_args.max_source_length, prefix=model.config.prefix or "", ) @@ -246,7 +246,7 @@ def main(): type_path="val", data_dir=data_args.data_dir, n_obs=data_args.n_val, - max_target_length=data_args.val_max_target_length, + max_target_length=data_args.eval_max_length, max_source_length=data_args.max_source_length, prefix=model.config.prefix or "", ) @@ -259,7 +259,7 @@ def main(): type_path="test", data_dir=data_args.data_dir, n_obs=data_args.n_test, - max_target_length=data_args.test_max_target_length, + max_target_length=data_args.eval_max_length, max_source_length=data_args.max_source_length, prefix=model.config.prefix or "", ) @@ -273,13 +273,12 @@ def main(): ) trainer = Seq2SeqTrainer( model=model, - config=config, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores), compute_metrics=compute_metrics_fn, - data_args=data_args, + tokenizer=tokenizer, ) all_metrics = {} @@ -310,7 +309,9 @@ def main(): if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate(metric_key_prefix="val") + metrics = trainer.evaluate( + metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams + ) metrics["val_n_objs"] = data_args.n_val metrics["val_loss"] = round(metrics["val_loss"], 4) @@ -322,7 +323,12 @@ def main(): if training_args.do_predict: logger.info("*** Predict ***") - test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test") + test_output = trainer.predict( + test_dataset=test_dataset, + metric_key_prefix="test", + max_length=data_args.eval_max_length, + num_beams=data_args.eval_beams, + ) metrics = test_output.metrics metrics["test_n_objs"] = data_args.n_test diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 24e56f752e8..9ce347ed894 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -17,8 +17,7 @@ import sys import unittest from unittest.mock import patch -from transformers import BertTokenizer, EncoderDecoderModel -from transformers.file_utils import is_apex_available, is_datasets_available +from transformers.file_utils import is_apex_available from transformers.integrations import is_fairscale_available from transformers.testing_utils import ( TestCasePlus, @@ -31,8 +30,7 @@ from transformers.testing_utils import ( from transformers.trainer_callback import TrainerState from transformers.trainer_utils import set_seed -from .finetune_trainer import Seq2SeqTrainingArguments, main -from .seq2seq_trainer import Seq2SeqTrainer +from .finetune_trainer import main set_seed(42) @@ -120,119 +118,6 @@ class TestFinetuneTrainer(TestCasePlus): assert "test_generations.txt" in contents assert "test_results.json" in contents - @slow - def test_finetune_bert2bert(self): - if not is_datasets_available(): - return - - import datasets - - bert2bert = 
EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny") - tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") - - bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size - bert2bert.config.eos_token_id = tokenizer.sep_token_id - bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id - bert2bert.config.max_length = 128 - - train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]") - val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]") - - train_dataset = train_dataset.select(range(32)) - val_dataset = val_dataset.select(range(16)) - - rouge = datasets.load_metric("rouge") - - batch_size = 4 - - def _map_to_encoder_decoder_inputs(batch): - # Tokenizer will automatically set [BOS] [EOS] - inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512) - outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128) - batch["input_ids"] = inputs.input_ids - batch["attention_mask"] = inputs.attention_mask - - batch["decoder_input_ids"] = outputs.input_ids - batch["labels"] = outputs.input_ids.copy() - batch["labels"] = [ - [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"] - ] - batch["decoder_attention_mask"] = outputs.attention_mask - - assert all([len(x) == 512 for x in inputs.input_ids]) - assert all([len(x) == 128 for x in outputs.input_ids]) - - return batch - - def _compute_metrics(pred): - labels_ids = pred.label_ids - pred_ids = pred.predictions - - # all unnecessary tokens are removed - pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) - label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) - - rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[ - "rouge2" - ].mid - - return { - "rouge2_precision": round(rouge_output.precision, 4), - "rouge2_recall": round(rouge_output.recall, 4), - "rouge2_fmeasure": round(rouge_output.fmeasure, 4), - } - - # map train dataset - train_dataset = train_dataset.map( - _map_to_encoder_decoder_inputs, - batched=True, - batch_size=batch_size, - remove_columns=["article", "highlights"], - ) - train_dataset.set_format( - type="torch", - columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"], - ) - - # same for validation dataset - val_dataset = val_dataset.map( - _map_to_encoder_decoder_inputs, - batched=True, - batch_size=batch_size, - remove_columns=["article", "highlights"], - ) - val_dataset.set_format( - type="torch", - columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"], - ) - - output_dir = self.get_auto_remove_tmp_dir() - - training_args = Seq2SeqTrainingArguments( - output_dir=output_dir, - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - predict_with_generate=True, - evaluation_strategy="steps", - do_train=True, - do_eval=True, - warmup_steps=0, - eval_steps=2, - logging_steps=2, - ) - - # instantiate trainer - trainer = Seq2SeqTrainer( - model=bert2bert, - args=training_args, - compute_metrics=_compute_metrics, - train_dataset=train_dataset, - eval_dataset=val_dataset, - ) - - # start training - trainer.train() - def run_trainer( self, eval_steps: int, @@ -252,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus): --n_train 8 --n_val 8 --max_source_length {max_len} - --max_target_length {max_len} - 
--val_max_target_length {max_len} + --max_length {max_len} + --eval_max_length {max_len} --do_train --do_eval --do_predict diff --git a/examples/seq2seq/train_distil_marian_enro.sh b/examples/seq2seq/train_distil_marian_enro.sh index f09fd875efd..78bc6776cb9 100644 --- a/examples/seq2seq/train_distil_marian_enro.sh +++ b/examples/seq2seq/train_distil_marian_enro.sh @@ -29,7 +29,7 @@ python finetune_trainer.py \ --freeze_encoder --freeze_embeds \ --num_train_epochs=6 \ --save_steps 3000 --eval_steps 3000 \ - --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ + --max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \ --do_train --do_eval --do_predict \ --evaluation_strategy steps \ --predict_with_generate --logging_first_step \ diff --git a/examples/seq2seq/train_distil_marian_enro_tpu.sh b/examples/seq2seq/train_distil_marian_enro_tpu.sh index 271a8cf3ebc..7239d83a770 100644 --- a/examples/seq2seq/train_distil_marian_enro_tpu.sh +++ b/examples/seq2seq/train_distil_marian_enro_tpu.sh @@ -30,7 +30,7 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \ --num_train_epochs=6 \ --save_steps 500 --eval_steps 500 \ --logging_first_step --logging_steps 200 \ - --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ + --max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \ --do_train --do_eval \ --evaluation_strategy steps \ --prediction_loss_only \ diff --git a/examples/seq2seq/train_distilbart_cnn.sh b/examples/seq2seq/train_distilbart_cnn.sh index e89394adc64..70b4ff9bf09 100644 --- a/examples/seq2seq/train_distilbart_cnn.sh +++ b/examples/seq2seq/train_distilbart_cnn.sh @@ -32,7 +32,7 @@ python finetune_trainer.py \ --num_train_epochs=2 \ --save_steps 3000 --eval_steps 3000 \ --logging_first_step \ - --max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \ + --max_length 56 --eval_max_length $MAX_TGT_LEN \ --do_train --do_eval --do_predict \ --evaluation_strategy steps \ --predict_with_generate --sortish_sampler \ diff --git a/examples/seq2seq/train_mbart_cc25_enro.sh b/examples/seq2seq/train_mbart_cc25_enro.sh index cccae914c32..cfce15c1ec0 100644 --- a/examples/seq2seq/train_mbart_cc25_enro.sh +++ b/examples/seq2seq/train_mbart_cc25_enro.sh @@ -24,8 +24,7 @@ python finetune_trainer.py \ --src_lang en_XX --tgt_lang ro_RO \ --freeze_embeds \ --per_device_train_batch_size=4 --per_device_eval_batch_size=4 \ - --max_source_length 128 --max_target_length 128 \ - --val_max_target_length 128 --test_max_target_length 128 \ + --max_source_length 128 --max_length 128 --eval_max_length 128 \ --sortish_sampler \ --num_train_epochs 6 \ --save_steps 25000 --eval_steps 25000 --logging_steps 1000 \ diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 437cdf2e632..0b74bfd57fc 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -330,7 +330,7 @@ class Seq2SeqDataCollator: [x["src_texts"] for x in batch], tgt_texts=[x["tgt_texts"] for x in batch], max_length=self.data_args.max_source_length, - max_target_length=self.data_args.max_target_length, + max_target_length=self.data_args.max_length, padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack return_tensors="pt", **self.dataset_kwargs, diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 580318abaa2..789090a12a4 100755 --- 
a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -287,8 +287,9 @@ from .trainer_callback import ( TrainerControl, TrainerState, ) -from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed +from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed from .training_args import TrainingArguments +from .training_args_seq2seq import Seq2SeqTrainingArguments from .training_args_tf import TFTrainingArguments from .utils import logging @@ -682,11 +683,13 @@ if is_torch_available(): get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, + get_scheduler, ) # Trainer from .trainer import Trainer from .trainer_pt_utils import torch_distributed_zero_first + from .trainer_seq2seq import Seq2SeqTrainer else: from .utils.dummy_pt_objects import * diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 935d2924b4d..9e65710d2ff 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -15,12 +15,13 @@ """PyTorch optimization for BERT model.""" import math -from typing import Callable, Iterable, Tuple +from typing import Callable, Iterable, Optional, Tuple, Union import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR +from .trainer_utils import SchedulerType from .utils import logging @@ -215,6 +216,56 @@ def get_polynomial_decay_schedule_with_warmup( return LambdaLR(optimizer, lr_lambda, last_epoch) +TYPE_TO_SCHEDULER_FUNCTION = { + SchedulerType.LINEAR: get_linear_schedule_with_warmup, + SchedulerType.COSINE: get_cosine_schedule_with_warmup, + SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, + SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, + SchedulerType.CONSTANT: get_constant_schedule, + SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, +} + + +def get_scheduler( + name: Union[str, SchedulerType], + optimizer: Optimizer, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, +): + """ + Unified API to get any scheduler from its name. + + Args: + name (:obj:`str` or `:obj:`SchedulerType`): + The name of the scheduler to use. + optimizer (:obj:`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (:obj:`int`, `optional`): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (:obj:`int`, `optional`): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. 
+ """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + + class AdamW(Optimizer): """ Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 10a911bd274..6756e591656 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -57,7 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available from .modeling_utils import PreTrainedModel from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING -from .optimization import AdamW, get_linear_schedule_with_warmup +from .optimization import Adafactor, AdamW, get_scheduler from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( CallbackHandler, @@ -70,6 +70,7 @@ from .trainer_callback import ( ) from .trainer_pt_utils import ( DistributedTensorGatherer, + LabelSmoother, SequentialDistributedSampler, distributed_broadcast_scalars, distributed_concat, @@ -320,6 +321,12 @@ class Trainer: ) self.use_apex = True + # Label smoothing + if self.args.label_smoothing_factor != 0: + self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) + else: + self.label_smoother = None + self.state = TrainerState() self.control = TrainerControl() # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the @@ -507,24 +514,32 @@ class Trainer: "weight_decay": 0.0, }, ] + optimizer_cls = Adafactor if self.args.adafactor else AdamW + if self.args.adafactor: + optimizer_cls = Adafactor + optimizer_kwargs = {"scale_parameter": False, "relative_step": False} + else: + optimizer_cls = AdamW + optimizer_kwargs = { + "betas": (self.args.adam_beta1, self.args.adam_beta2), + "eps": self.args.adam_epsilon, + } + optimizer_kwargs["lr"] = self.args.learning_rate if self.sharded_dpp: self.optimizer = OSS( params=optimizer_grouped_parameters, - optim=AdamW, - lr=self.args.learning_rate, - betas=(self.args.adam_beta1, self.args.adam_beta2), - eps=self.args.adam_epsilon, + optim=optimizer_cls, + **optimizer_kwargs, ) else: - self.optimizer = AdamW( - optimizer_grouped_parameters, - lr=self.args.learning_rate, - betas=(self.args.adam_beta1, self.args.adam_beta2), - eps=self.args.adam_epsilon, - ) + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if self.lr_scheduler is None: - self.lr_scheduler = get_linear_schedule_with_warmup( - self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps + self.lr_scheduler = get_scheduler( + self.args.lr_scheduler_type, + self.optimizer, + num_warmup_steps=self.args.warmup_steps, + 
num_training_steps=num_training_steps, ) def num_examples(self, dataloader: DataLoader) -> int: @@ -1168,8 +1183,12 @@ class Trainer: # TODO: this needs to be fixed and made cleaner later. if self.args.past_index >= 0: self._past = outputs[self.args.past_index] - # We don't use .loss here since the model may return tuples instead of ModelOutput. - return outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + if self.label_smoother is not None and "labels" in inputs: + return self.label_smoother(outputs, inputs["labels"]) + else: + # We don't use .loss here since the model may return tuples instead of ModelOutput. + return outputs["loss"] if isinstance(outputs, dict) else outputs[0] def is_local_process_zero(self) -> bool: """ @@ -1556,11 +1575,13 @@ class Trainer: else: outputs = model(**inputs) if has_labels: + if self.label_smoother is not None and "labels" in inputs: + loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() + else: + loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() if isinstance(outputs, dict): - loss = outputs["loss"].mean().detach() logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) else: - loss = outputs[0].mean().detach() logits = outputs[1:] else: loss = None diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 5cb45eb7bd3..89d51f5c4c3 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -19,6 +19,7 @@ Torch utilities for the Trainer class. import math import warnings from contextlib import contextmanager +from dataclasses import dataclass from typing import List, Optional, Union import numpy as np @@ -360,3 +361,32 @@ class DistributedTensorGatherer: if self._offsets[0] != self.process_length: logger.warn("Not all data has been set. Are you sure you passed all values?") return nested_truncate(self._storage, self.num_samples) + + +@dataclass +class LabelSmoother: + """ + Adds label-smoothing on a pre-computed output from a Transformers model. + + Args: + epsilon (:obj:`float`, `optional`, defaults to 0.1): + The label smoothing factor. + ignore_index (:obj:`int`, `optional`, defaults to -100): + The index in the labels to ignore when computing the loss. + """ + + epsilon: float = 0.1 + ignore_index: int = -100 + + def __call__(self, model_output, labels): + model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0] + logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1] + log_probs = -torch.nn.functional.log_softmax(logits, dim=-1) + + # Look at the ignored index and mask the corresponding log_probs. + padding_mask = labels.unsqueeze(-1).eq(self.ignore_index) + log_probs.masked_fill_(padding_mask, 0.0) + + # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded): + smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum()) + return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py new file mode 100644 index 00000000000..4fc7eb59b64 --- /dev/null +++ b/src/transformers/trainer_seq2seq.py @@ -0,0 +1,231 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from packaging import version +from torch import nn +from torch.utils.data import DistributedSampler, RandomSampler +from torch.utils.data.dataset import Dataset + +from .file_utils import is_torch_tpu_available +from .trainer import Trainer +from .trainer_pt_utils import get_tpu_sampler +from .trainer_utils import PredictionOutput +from .training_args import ParallelMode +from .utils import logging + + +if version.parse(torch.__version__) >= version.parse("1.6"): + from torch.cuda.amp import autocast + + +logger = logging.get_logger(__name__) + + +class Seq2SeqTrainer(Trainer): + def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: + if isinstance(self.train_dataset, torch.utils.data.IterableDataset): + return None + elif is_torch_tpu_available(): + return get_tpu_sampler(self.train_dataset) + else: + if self.args.sortish_sampler: + self.train_dataset.make_sortish_sampler( + self.args.per_device_train_batch_size, + distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED), + ) + + return ( + RandomSampler(self.train_dataset) + if self.args.local_rank == -1 + else DistributedSampler(self.train_dataset) + ) + + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + max_length: Optional[int] = None, + num_beams: Optional[int] = None, + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init :obj:`compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (:obj:`Dataset`, `optional`): + Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, + columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the + :obj:`__len__` method. + ignore_keys (:obj:`List[str]`, `optional`): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is ``"eval"`` (default) + max_length (:obj:`int`, `optional`): + The maximum target length to use when predicting with the generate method. + num_beams (:obj:`int`, `optional`): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. 
+ """ + self._max_length = max_length + self._num_beams = num_beams + return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def predict( + self, + test_dataset: Dataset, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + max_length: Optional[int] = None, + num_beams: Optional[int] = None, + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in :obj:`evaluate()`. + + Args: + test_dataset (:obj:`Dataset`): + Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the + ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__` + ignore_keys (:obj:`List[str]`, `optional`): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is ``"eval"`` (default) + max_length (:obj:`int`, `optional`): + The maximum target length to use when predicting with the generate method. + num_beams (:obj:`int`, `optional`): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + + .. note:: + + If your predictions or labels have different sequence lengths (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + + Returns: `NamedTuple` A namedtuple with the following keys: + + - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`. + - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some). + - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset + contained labels). + """ + self._max_length = max_length + self._num_beams = num_beams + return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on :obj:`model` using obj:`inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (:obj:`nn.Module`): + The model to evaluate. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument :obj:`labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (:obj:`bool`): + Whether or not to return the loss only. + + Return: + Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and + labels (each being optional). 
+        """
+
+        if not self.args.predict_with_generate or prediction_loss_only:
+            return super().prediction_step(
+                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            )
+
+        has_labels = "labels" in inputs
+        inputs = self._prepare_inputs(inputs)
+
+        gen_kwargs = {
+            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
+            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
+        }
+
+        generated_tokens = self.model.generate(
+            inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            **gen_kwargs,
+        )
+        # in case the batch is shorter than max length, the output should be padded
+        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+
+        with torch.no_grad():
+            if self.use_amp:
+                with autocast():
+                    outputs = model(**inputs)
+            else:
+                outputs = model(**inputs)
+            if has_labels:
+                if self.label_smoother is not None:
+                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                else:
+                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+            else:
+                loss = None
+
+        if self.args.prediction_loss_only:
+            return (loss, None, None)
+
+        labels = inputs["labels"]
+        if labels.shape[-1] < gen_kwargs["max_length"]:
+            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+
+        return (loss, generated_tokens, labels)
+
+    def _pad_tensors_to_max_len(self, tensor, max_length):
+        if self.tokenizer is None:
+            raise ValueError(
+                f"Tensor needs to be padded to `max_length={max_length}` but no tokenizer was passed when creating "
+                "this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
+            )
+        # If PAD token is not defined at least EOS token has to be defined
+        pad_token_id = (
+            self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+        )
+
+        padded_tensor = pad_token_id * torch.ones(
+            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
+        )
+        padded_tensor[:, : tensor.shape[-1]] = tensor
+        return padded_tensor
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 4dc7874a0d5..da59d534b10 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -201,3 +201,12 @@ def speed_metrics(split, start_time, num_samples=None):
         samples_per_second = 1 / (runtime / num_samples)
         result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
     return result
+
+
+class SchedulerType(ExplicitEnum):
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 50a214d58e7..9d78ce41fe3 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -19,7 +19,7 @@ from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
 from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
-from .trainer_utils import EvaluationStrategy
+from .trainer_utils import EvaluationStrategy, SchedulerType
 from .utils import logging
 
 
@@ -121,6 +121,9 @@ class TrainingArguments:
         max_steps (:obj:`int`, `optional`, defaults to -1):
             If set to a positive number, the total number of training steps to perform. Overrides
             :obj:`num_train_epochs`.
+ lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`): + The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible + values. warmup_steps (:obj:`int`, `optional`, defaults to 0): Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. logging_dir (:obj:`str`, `optional`): @@ -217,6 +220,13 @@ class TrainingArguments: sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`): Use Sharded DDP training from `FairScale `__ (in distributed training only). This is an experimental feature. + label_smoothing_factor (:obj:`float`, `optional`, defaults to 0.0): + The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded + labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 - + label_smoothing_factor + label_smoothing_factor/num_labels` respectively. + adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of + :class:`~transformers.AdamW`. """ output_dir: str = field( @@ -246,7 +256,7 @@ class TrainingArguments: ) evaluation_strategy: EvaluationStrategy = field( default="no", - metadata={"help": "Run evaluation during training at each logging step."}, + metadata={"help": "The evaluation strategy to use."}, ) prediction_loss_only: bool = field( default=False, @@ -296,6 +306,10 @@ class TrainingArguments: default=-1, metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, ) + lr_scheduler_type: SchedulerType = field( + default="linear", + metadata={"help": "The scheduler type to use."}, + ) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."}) @@ -392,11 +406,16 @@ class TrainingArguments: default=False, metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."}, ) + label_smoothing_factor: float = field( + default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} + ) + adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."}) def __post_init__(self): if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy) + self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type) if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO: self.do_eval = True if self.eval_steps is None: diff --git a/src/transformers/training_args_seq2seq.py b/src/transformers/training_args_seq2seq.py new file mode 100644 index 00000000000..8527fda1fdd --- /dev/null +++ b/src/transformers/training_args_seq2seq.py @@ -0,0 +1,42 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field + +from .file_utils import add_start_docstrings +from .training_args import TrainingArguments + + +logger = logging.getLogger(__name__) + + +@dataclass +@add_start_docstrings(TrainingArguments.__doc__) +class Seq2SeqTrainingArguments(TrainingArguments): + """ + sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use a `sortish sampler` or not. Only possible if the underlying datasets are `Seq2SeqDataset` for + now but will become generally available in the near future. + + It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness for + the training set. + predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use generate to calculate generative metrics (ROUGE, BLEU). + """ + + sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) + predict_with_generate: bool = field( + default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 97669eff742..c2ed563ef05 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2279,6 +2279,10 @@ def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): requires_pytorch(get_polynomial_decay_schedule_with_warmup) +def get_scheduler(*args, **kwargs): + requires_pytorch(get_scheduler) + + class Trainer: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -2286,3 +2290,8 @@ class Trainer: def torch_distributed_zero_first(*args, **kwargs): requires_pytorch(torch_distributed_zero_first) + + +class Seq2SeqTrainer: + def __init__(self, *args, **kwargs): + requires_pytorch(self) diff --git a/tests/test_trainer_seq2seq.py b/tests/test_trainer_seq2seq.py new file mode 100644 index 00000000000..286ac5c2ad2 --- /dev/null +++ b/tests/test_trainer_seq2seq.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2020 the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments +from transformers.file_utils import is_datasets_available +from transformers.testing_utils import TestCasePlus, require_datasets, slow + + +if is_datasets_available(): + import datasets + + +class Seq2seqTrainerTester(TestCasePlus): + @slow + @require_datasets + def test_finetune_bert2bert(self): + + bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny") + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + + bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size + bert2bert.config.eos_token_id = tokenizer.sep_token_id + bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id + bert2bert.config.max_length = 128 + + train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]") + val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]") + + train_dataset = train_dataset.select(range(32)) + val_dataset = val_dataset.select(range(16)) + + rouge = datasets.load_metric("rouge") + + batch_size = 4 + + def _map_to_encoder_decoder_inputs(batch): + # Tokenizer will automatically set [BOS] [EOS] + inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512) + outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128) + batch["input_ids"] = inputs.input_ids + batch["attention_mask"] = inputs.attention_mask + + batch["decoder_input_ids"] = outputs.input_ids + batch["labels"] = outputs.input_ids.copy() + batch["labels"] = [ + [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"] + ] + batch["decoder_attention_mask"] = outputs.attention_mask + + assert all([len(x) == 512 for x in inputs.input_ids]) + assert all([len(x) == 128 for x in outputs.input_ids]) + + return batch + + def _compute_metrics(pred): + labels_ids = pred.label_ids + pred_ids = pred.predictions + + # all unnecessary tokens are removed + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) + + rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[ + "rouge2" + ].mid + + return { + "rouge2_precision": round(rouge_output.precision, 4), + "rouge2_recall": round(rouge_output.recall, 4), + "rouge2_fmeasure": round(rouge_output.fmeasure, 4), + } + + # map train dataset + train_dataset = train_dataset.map( + _map_to_encoder_decoder_inputs, + batched=True, + batch_size=batch_size, + remove_columns=["article", "highlights"], + ) + train_dataset.set_format( + type="torch", + columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"], + ) + + # same for validation dataset + val_dataset = val_dataset.map( + _map_to_encoder_decoder_inputs, + batched=True, + batch_size=batch_size, + remove_columns=["article", "highlights"], + ) + val_dataset.set_format( + type="torch", + columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"], + ) + + output_dir = self.get_auto_remove_tmp_dir() + + training_args = Seq2SeqTrainingArguments( + output_dir=output_dir, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + predict_with_generate=True, + evaluation_strategy="steps", + do_train=True, + do_eval=True, + warmup_steps=0, + eval_steps=2, + logging_steps=2, + ) + + # 
instantiate trainer + trainer = Seq2SeqTrainer( + model=bert2bert, + args=training_args, + compute_metrics=_compute_metrics, + train_dataset=train_dataset, + eval_dataset=val_dataset, + tokenizer=tokenizer, + ) + + # start training + trainer.train() diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index 91fe33fa478..a62e95ac6ba 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -22,7 +22,10 @@ from transformers.testing_utils import require_torch if is_torch_available(): - from transformers.trainer_pt_utils import DistributedTensorGatherer + import torch + + from transformers.modeling_outputs import SequenceClassifierOutput + from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother @require_torch @@ -56,3 +59,31 @@ class TrainerUtilsTest(unittest.TestCase): self.assertTrue(np.array_equal(result[0], predictions)) self.assertTrue(np.array_equal(result[1][0], predictions)) self.assertTrue(np.array_equal(result[1][1], predictions)) + + def test_label_smoothing(self): + epsilon = 0.1 + num_labels = 12 + random_logits = torch.randn(4, 5, num_labels) + random_labels = torch.randint(0, num_labels, (4, 5)) + loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1)) + model_output = SequenceClassifierOutput(loss=loss, logits=random_logits) + label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels) + log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1) + expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean() + self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss)) + + # With a few -100 labels + random_labels[0, 1] = -100 + random_labels[2, 1] = -100 + random_labels[2, 3] = -100 + + loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1)) + model_output = SequenceClassifierOutput(loss=loss, logits=random_logits) + label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels) + log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1) + # Mask the log probs with the -100 labels + log_probs[0, 1] = 0.0 + log_probs[2, 1] = 0.0 + log_probs[2, 3] = 0.0 + expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17) + self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
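
A usage sketch for the unified scheduler API added in src/transformers/optimization.py above. The snippet is illustrative only (the tiny linear model, learning rate, and step counts are made-up placeholders), but every call uses the `get_scheduler` signature introduced by this patch.

import torch

from transformers import AdamW, SchedulerType, get_scheduler

# A tiny linear layer stands in for a real model; any iterable of parameters works.
model = torch.nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=5e-5)

# The scheduler name can be passed as a plain string or as a SchedulerType member.
linear = get_scheduler("linear", optimizer, num_warmup_steps=100, num_training_steps=1000)
cosine = get_scheduler(SchedulerType.COSINE, optimizer, num_warmup_steps=100, num_training_steps=1000)

# "constant" needs neither step argument; "constant_with_warmup" only needs num_warmup_steps.
constant = get_scheduler("constant", optimizer)
constant_warmup = get_scheduler("constant_with_warmup", optimizer, num_warmup_steps=100)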
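
A minimal end-to-end sketch of the `Seq2SeqTrainer` / `Seq2SeqTrainingArguments` pair exported by this patch. The `t5-small` checkpoint and the two-example toy dataset are arbitrary illustrations, not part of the patch; the point is that `predict_with_generate`, `label_smoothing_factor`, and the per-call `max_length`/`num_beams` arguments replace the old `data_args`-based plumbing.

import torch
from torch.utils.data import Dataset

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments

model_name = "t5-small"  # any seq2seq checkpoint works; t5-small is only an example
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class ToySeq2SeqDataset(Dataset):
    # A tiny in-memory dataset so the sketch is self-contained.
    def __init__(self, tokenizer, sources, targets, max_len=32):
        self.inputs = tokenizer(sources, padding="max_length", truncation=True, max_length=max_len)
        labels = tokenizer(targets, padding="max_length", truncation=True, max_length=max_len).input_ids
        # Pad positions are set to -100 so the loss (and LabelSmoother) ignores them.
        self.labels = [[tok if tok != tokenizer.pad_token_id else -100 for tok in seq] for seq in labels]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(values[idx]) for key, values in self.inputs.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item


eval_dataset = ToySeq2SeqDataset(
    tokenizer,
    sources=["summarize: The cat sat on the mat.", "summarize: The dog barked all night."],
    targets=["A cat sat.", "A dog barked."],
)

training_args = Seq2SeqTrainingArguments(
    output_dir="seq2seq_sketch",
    per_device_eval_batch_size=2,
    predict_with_generate=True,    # evaluation runs model.generate() instead of a plain forward pass
    label_smoothing_factor=0.1,    # exercises the LabelSmoother added to Trainer in this patch
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,  # needed so generated tokens and labels can be padded to max_length
)

# Generation parameters are now passed per call instead of being read from data_args.
metrics = trainer.evaluate(metric_key_prefix="val", max_length=32, num_beams=2)
print(metrics["val_loss"])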
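
Finally, a small self-check of what the new `LabelSmoother` computes, mirroring the unit test above; the tensor shapes and epsilon are arbitrary. When no position carries the -100 ignore index, the smoothed loss is simply an interpolation between the model's own loss and the mean negative log-probability over all classes.

import torch

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import LabelSmoother

epsilon, num_labels = 0.1, 12
logits = torch.randn(4, 5, num_labels)
labels = torch.randint(0, num_labels, (4, 5))

# The model's usual cross-entropy loss, packaged the way a Transformers model returns it.
nll = torch.nn.functional.cross_entropy(logits.view(-1, num_labels), labels.view(-1))
smoothed = LabelSmoother(epsilon=epsilon)(SequenceClassifierOutput(loss=nll, logits=logits), labels)

# Equivalent closed form when nothing is masked out.
expected = (1 - epsilon) * nll + epsilon * (-torch.nn.functional.log_softmax(logits, dim=-1)).mean()
assert torch.allclose(smoothed, expected)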