diff --git a/examples/multiple-choice/run_tf_multiple_choice.py b/examples/multiple-choice/run_tf_multiple_choice.py
index 26d0fcbff5d..1eb19e32fe0 100644
--- a/examples/multiple-choice/run_tf_multiple_choice.py
+++ b/examples/multiple-choice/run_tf_multiple_choice.py
@@ -108,7 +108,10 @@ def main():
         level=logging.INFO,
     )
     logger.warning(
-        "device: %s, n_gpu: %s, 16-bits training: %s", training_args.device, training_args.n_gpu, training_args.fp16,
+        "device: %s, n_replicas: %s, 16-bits training: %s",
+        training_args.device,
+        training_args.n_replicas,
+        training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
 
diff --git a/examples/question-answering/run_tf_squad.py b/examples/question-answering/run_tf_squad.py
index 2ba8626ea2f..1c654c32bfa 100644
--- a/examples/question-answering/run_tf_squad.py
+++ b/examples/question-answering/run_tf_squad.py
@@ -137,9 +137,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
diff --git a/examples/text-classification/run_tf_glue.py b/examples/text-classification/run_tf_glue.py
index 6699deba906..a1e4f7a90ae 100644
--- a/examples/text-classification/run_tf_glue.py
+++ b/examples/text-classification/run_tf_glue.py
@@ -131,9 +131,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
diff --git a/examples/token-classification/run_tf_ner.py b/examples/token-classification/run_tf_ner.py
index d294eaebab4..056a24c74fd 100644
--- a/examples/token-classification/run_tf_ner.py
+++ b/examples/token-classification/run_tf_ner.py
@@ -109,9 +109,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 4728f6ff02b..4447d5443dc 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -155,7 +155,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
 
 # Trainer
-from .trainer_utils import EvalPrediction
+from .trainer_utils import EvalPrediction, set_seed
 from .training_args import TrainingArguments
 from .training_args_tf import TFTrainingArguments
 
@@ -397,7 +397,7 @@ if is_torch_available():
     )
 
     # Trainer
-    from .trainer import Trainer, set_seed, torch_distributed_zero_first
+    from .trainer import Trainer, torch_distributed_zero_first
     from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
     from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 7b974814ad4..067f793d12a 100644
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1,7 +1,6 @@
 import logging
 import math
 import os
-import random
 import re
 import shutil
 import warnings
@@ -23,7 +22,14 @@ from .data.data_collator import DataCollator, default_data_collator
 from .file_utils import is_apex_available, is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    EvalPrediction,
+    PredictionOutput,
+    TrainOutput,
+    is_wandb_available,
+    set_seed,
+)
 from .training_args import TrainingArguments
 
@@ -60,20 +66,6 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)
 
 
-def set_seed(seed: int):
-    """
-    Helper function for reproducible behavior to set the seed in ``random``, ``numpy`` and ``torch``.
-
-    Args:
-        seed (:obj:`int`): The seed to set.
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    # ^^ safe to call this function even if cuda is not available
-
-
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     """
@@ -541,8 +533,8 @@ class Trainer:
                         self._log(logs)
 
-                    if self.args.evaluate_during_training:
-                        self.evaluate()
+                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
+                        self.evaluate()
 
                     if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                         # In all cases (even distributed/parallel), self.model is always a reference
@@ -573,7 +565,7 @@ class Trainer:
             if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                 train_iterator.close()
                 break
-            if self.args.tpu_metrics_debug:
+            if self.args.tpu_metrics_debug or self.args.debug:
                 # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                 xm.master_print(met.metrics_report())
 
@@ -754,7 +746,7 @@ class Trainer:
         self._log(output.metrics)
 
-        if self.args.tpu_metrics_debug:
+        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())
diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py
index c61d3661f36..accbf0c7cf2 100644
--- a/src/transformers/trainer_tf.py
+++ b/src/transformers/trainer_tf.py
@@ -3,7 +3,6 @@
 import logging
 import math
 import os
-import random
 from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
@@ -11,7 +10,7 @@ import tensorflow as tf
 
 from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available, set_seed
 from .training_args_tf import TFTrainingArguments
 
@@ -22,12 +21,6 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)
 
 
-def set_seed(seed: int):
-    random.seed(seed)
-    np.random.seed(seed)
-    tf.random.set_seed(seed)
-
-
 class TFTrainer:
     """
     TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
@@ -256,7 +249,7 @@ class TFTrainer:
                 if isinstance(labels, tuple):
                     labels = labels[0]
 
-                if self.args.n_gpu > 1:
+                if self.args.n_replicas > 1:
                     for val in logits.values:
                         if preds is None:
                             preds = val.numpy()
@@ -542,7 +535,7 @@ class TFTrainer:
         loss, logits = outputs[:2]
 
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
 
-        loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
+        loss += sum(self.model.losses) * (1.0 / self.args.n_replicas)
 
         return loss, logits
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 619adb1a1fe..1a4e9950723 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -1,8 +1,11 @@
 import os
+import random
 from typing import Dict, NamedTuple, Optional
 
 import numpy as np
 
+from .file_utils import is_tf_available, is_torch_available
+
 try:
     import wandb
 
@@ -21,6 +24,28 @@ def is_wandb_available():
     return _has_wandb
 
 
+def set_seed(seed: int):
+    """
+    Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
+    (if installed).
+
+    Args:
+        seed (:obj:`int`): The seed to set.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    if is_torch_available():
+        import torch
+
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        # ^^ safe to call this function even if cuda is not available
+    if is_tf_available():
+        import tensorflow as tf
+
+        tf.random.set_seed(seed)
+
+
 class EvalPrediction(NamedTuple):
     """
     Evaluation output (always contains labels), to be used to compute metrics.
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index d78ec19dbe5..9609cc9147e 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -97,11 +97,13 @@ class TrainingArguments:
             During distributed training, the rank of the process.
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the mumber of TPU cores (automatically passed by launcher script).
-        tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
+        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
            When training on TPU, whether to print debug metrics or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch
             size) or not.
+        eval_steps (:obj:`int`, `optional`, defaults to 1000):
+            Number of update steps between two evaluations.
         past_index (:obj:`int`, `optional`, defaults to -1):
             Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
             make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@@ -202,11 +204,16 @@ class TrainingArguments:
     tpu_num_cores: Optional[int] = field(
         default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
     )
-    tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
+    tpu_metrics_debug: bool = field(
+        default=False,
+        metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"},
+    )
+    debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
     dataloader_drop_last: bool = field(
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
     )
+    eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
     past_index: int = field(
         default=-1,
diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py
index 942dc138926..666068163a7 100644
--- a/src/transformers/training_args_tf.py
+++ b/src/transformers/training_args_tf.py
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from dataclasses import dataclass, field
 from typing import Tuple
 
@@ -80,11 +81,13 @@ class TFTrainingArguments(TrainingArguments):
             During distributed training, the rank of the process.
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the mumber of TPU cores (automatically passed by launcher script).
-        tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            When training on TPU, whether to print debug metrics or not.
+        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to activate the trace to record computation graphs and profiling information or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch
             size) or not.
+        eval_steps (:obj:`int`, `optional`, defaults to 1000):
+            Number of update steps between two evaluations.
         past_index (:obj:`int`, `optional`, defaults to -1):
             Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
             make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@@ -92,19 +95,11 @@ class TFTrainingArguments(TrainingArguments):
             at the next training step under the keyword argument ``mems``.
         tpu_name (:obj:`str`, `optional`):
             The name of the TPU the process is running on.
-        eval_steps (:obj:`int`, `optional`, defaults to 1000):
-            Number of update steps before two evaluations.
-        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Wheter to activate the trace to record computation graphs and profiling information or not.
""" tpu_name: str = field( default=None, metadata={"help": "Name of TPU"}, ) - eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."}) - debug: bool = field( - default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"} - ) @cached_property @tf_required @@ -150,8 +145,46 @@ class TFTrainingArguments(TrainingArguments): @property @tf_required - def n_gpu(self) -> int: + def n_replicas(self) -> int: """ - The number of replicas (GPUs or TPU cores) used in this training. + The number of replicas (CPUs, GPUs or TPU cores) used in this training. """ return self._setup_strategy.num_replicas_in_sync + + @property + def train_batch_size(self) -> int: + """ + The actual batch size for training (may differ from :obj:`per_gpu_train_batch_size` in distributed training). + """ + if self.per_gpu_train_batch_size: + logger.warning( + "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " + "version. Using `--per_device_train_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size + return per_device_batch_size * max(1, self.n_replicas) + + @property + def eval_batch_size(self) -> int: + """ + The actual batch size for evaluation (may differ from :obj:`per_gpu_eval_batch_size` in distributed training). + """ + if self.per_gpu_eval_batch_size: + logger.warning( + "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " + "version. Using `--per_device_eval_batch_size` is preferred." + ) + per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size + return per_device_batch_size * max(1, self.n_replicas) + + @property + @tf_required + def n_gpu(self) -> int: + """ + The number of replicas (CPUs, GPUs or TPU cores) used in this training. + """ + warnings.warn( + "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.", + FutureWarning, + ) + return self._setup_strategy.num_replicas_in_sync