Clean up diffs in Trainer/TFTrainer (#5417)

* Cleanup and unify Trainer/TFTrainer

* Forgot to adapt TFTrainingArgs

* In TF scripts, rename n_gpu -> n_replicas

* Update src/transformers/training_args.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* Formatting

* Fix typo

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Sylvain Gugger 2020-07-01 11:00:20 -04:00 committed by GitHub
parent 43cb03a93d
commit 734a28a767
10 changed files with 109 additions and 56 deletions

View File

@@ -108,7 +108,10 @@ def main():
level=logging.INFO,
)
logger.warning(
"device: %s, n_gpu: %s, 16-bits training: %s", training_args.device, training_args.n_gpu, training_args.fp16,
"device: %s, n_replicas: %s, 16-bits training: %s",
training_args.device,
training_args.n_replicas,
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

View File

@@ -137,9 +137,9 @@ def main():
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
"n_replicas: %s, distributed training: %s, 16-bits training: %s",
training_args.n_replicas,
bool(training_args.n_replicas > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

View File

@@ -131,9 +131,9 @@ def main():
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
"n_replicas: %s, distributed training: %s, 16-bits training: %s",
training_args.n_replicas,
bool(training_args.n_replicas > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

View File

@@ -109,9 +109,9 @@ def main():
level=logging.INFO,
)
logger.info(
"n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.n_gpu,
bool(training_args.n_gpu > 1),
"n_replicas: %s, distributed training: %s, 16-bits training: %s",
training_args.n_replicas,
bool(training_args.n_replicas > 1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

View File

@@ -155,7 +155,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
# Trainer
from .trainer_utils import EvalPrediction
from .trainer_utils import EvalPrediction, set_seed
from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments
@@ -397,7 +397,7 @@ if is_torch_available():
)
# Trainer
from .trainer import Trainer, set_seed, torch_distributed_zero_first
from .trainer import Trainer, torch_distributed_zero_first
from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
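
With `set_seed` moved out of `trainer.py` and `trainer_tf.py` into `trainer_utils.py`, a single framework-agnostic helper is now re-exported from the top-level package. A minimal sketch of the updated import and call, assuming a version that includes this commit:

```python
# Minimal sketch of the new, framework-agnostic import path.
from transformers import set_seed  # re-exported from trainer_utils

set_seed(42)  # seeds `random`, `numpy`, and torch/TF if they are installed
```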

View File

@@ -1,7 +1,6 @@
import logging
import math
import os
import random
import re
import shutil
import warnings
@@ -23,7 +22,14 @@ from .data.data_collator import DataCollator, default_data_collator
from .file_utils import is_apex_available, is_torch_tpu_available
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
EvalPrediction,
PredictionOutput,
TrainOutput,
is_wandb_available,
set_seed,
)
from .training_args import TrainingArguments
@@ -60,20 +66,6 @@ if is_wandb_available():
logger = logging.getLogger(__name__)
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy`` and ``torch``.
Args:
seed (:obj:`int`): The seed to set.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
@@ -541,8 +533,8 @@ class Trainer:
self._log(logs)
if self.args.evaluate_during_training:
self.evaluate()
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
self.evaluate()
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
# In all cases (even distributed/parallel), self.model is always a reference
@@ -573,7 +565,7 @@ class Trainer:
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
train_iterator.close()
break
if self.args.tpu_metrics_debug:
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
@@ -754,7 +746,7 @@ class Trainer:
self._log(output.metrics)
if self.args.tpu_metrics_debug:
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
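
With the loop change above, evaluation during training is now gated on `eval_steps` instead of firing at every logging step, and `--debug` triggers the same TPU metrics report as the deprecated `--tpu_metrics_debug`. A hedged sketch of how the relevant arguments fit together (values are illustrative; model and dataset wiring omitted):

```python
from transformers import TrainingArguments

# Illustrative values only: evaluate every 500 optimizer steps, log every 100,
# checkpoint every 500. `eval_steps` defaults to 1000.
args = TrainingArguments(
    output_dir="out",
    evaluate_during_training=True,
    logging_steps=100,
    eval_steps=500,
    save_steps=500,
)
```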

View File

@@ -3,7 +3,6 @@
import logging
import math
import os
import random
from typing import Callable, Dict, Optional, Tuple
import numpy as np
@@ -11,7 +10,7 @@ import tensorflow as tf
from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available, set_seed
from .training_args_tf import TFTrainingArguments
@@ -22,12 +21,6 @@ if is_wandb_available():
logger = logging.getLogger(__name__)
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
class TFTrainer:
"""
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
@@ -256,7 +249,7 @@ class TFTrainer:
if isinstance(labels, tuple):
labels = labels[0]
if self.args.n_gpu > 1:
if self.args.n_replicas > 1:
for val in logits.values:
if preds is None:
preds = val.numpy()
@@ -542,7 +535,7 @@ class TFTrainer:
loss, logits = outputs[:2]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
loss += sum(self.model.losses) * (1.0 / self.args.n_replicas)
return loss, logits
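
The `1.0 / n_replicas` factor on `model.losses` keeps regularization terms from being counted once per replica when the distribution strategy later sums the per-replica losses. A small illustration of that convention outside the trainer (the strategy is simply whatever TensorFlow reports on the current machine):

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()      # default (single-replica) strategy
n_replicas = strategy.num_replicas_in_sync   # what TFTrainingArguments.n_replicas exposes

regularization = tf.constant(0.3)            # stand-in for sum(model.losses)
per_replica_share = regularization * (1.0 / n_replicas)
# Summing per_replica_share across all replicas recovers `regularization` exactly once.
```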

View File

@@ -1,8 +1,11 @@
import os
import random
from typing import Dict, NamedTuple, Optional
import numpy as np
from .file_utils import is_tf_available, is_torch_available
try:
import wandb
@@ -21,6 +24,28 @@ def is_wandb_available():
return _has_wandb
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
(if installed).
Args:
seed (:obj:`int`): The seed to set.
"""
random.seed(seed)
np.random.seed(seed)
if is_torch_available():
import torch
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
if is_tf_available():
import tensorflow as tf
tf.random.set_seed(seed)
class EvalPrediction(NamedTuple):
"""
Evaluation output (always contains labels), to be used to compute metrics.
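
Because the unified helper checks `is_torch_available()` and `is_tf_available()` at call time, a single call seeds whichever backends are importable in the environment. A small reproducibility sketch (only NumPy is exercised, so it runs regardless of which backend is present):

```python
import numpy as np
from transformers.trainer_utils import set_seed

set_seed(123)
first = np.random.rand(3)

set_seed(123)
second = np.random.rand(3)

assert np.allclose(first, second)  # identical draws after re-seeding
```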

View File

@@ -97,11 +97,13 @@ class TrainingArguments:
During distributed training, the rank of the process.
tpu_num_cores (:obj:`int`, `optional`):
When training on TPU, the number of TPU cores (automatically passed by launcher script).
tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
When training on TPU, whether to print debug metrics or not.
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not.
eval_steps (:obj:`int`, `optional`, defaults to 1000):
Number of update steps between two evaluations.
past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@@ -202,11 +204,16 @@ class TrainingArguments:
tpu_num_cores: Optional[int] = field(
default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
)
tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
tpu_metrics_debug: bool = field(
default=False,
metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"},
)
debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
past_index: int = field(
default=-1,
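
Both `debug` and `eval_steps` are plain dataclass fields, so they show up on the command line through `HfArgumentParser` like every other training argument, while `--tpu_metrics_debug` keeps working but is documented as deprecated. A hedged sketch of the parsing (flag values are illustrative):

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# Equivalent CLI: python your_script.py --output_dir out --debug --eval_steps 250
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--debug", "--eval_steps", "250"]
)
print(training_args.debug)              # True
print(training_args.eval_steps)         # 250
print(training_args.tpu_metrics_debug)  # False -- prefer --debug going forward
```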

View File

@@ -1,4 +1,5 @@
import logging
import warnings
from dataclasses import dataclass, field
from typing import Tuple
@@ -80,11 +81,13 @@ class TFTrainingArguments(TrainingArguments):
During distributed training, the rank of the process.
tpu_num_cores (:obj:`int`, `optional`):
When training on TPU, the number of TPU cores (automatically passed by launcher script).
tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
When training on TPU, whether to print debug metrics or not.
debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to activate the trace to record computation graphs and profiling information or not.
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not.
eval_steps (:obj:`int`, `optional`, defaults to 1000):
Number of update steps between two evaluations.
past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@@ -92,19 +95,11 @@ class TFTrainingArguments(TrainingArguments):
at the next training step under the keyword argument ``mems``.
tpu_name (:obj:`str`, `optional`):
The name of the TPU the process is running on.
eval_steps (:obj:`int`, `optional`, defaults to 1000):
Number of update steps before two evaluations.
debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to activate the trace to record computation graphs and profiling information or not.
"""
tpu_name: str = field(
default=None, metadata={"help": "Name of TPU"},
)
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
debug: bool = field(
default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
)
@cached_property
@tf_required
@@ -150,8 +145,46 @@ class TFTrainingArguments(TrainingArguments):
@property
@tf_required
def n_gpu(self) -> int:
def n_replicas(self) -> int:
"""
The number of replicas (GPUs or TPU cores) used in this training.
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
"""
return self._setup_strategy.num_replicas_in_sync
@property
def train_batch_size(self) -> int:
"""
The actual batch size for training (may differ from :obj:`per_gpu_train_batch_size` in distributed training).
"""
if self.per_gpu_train_batch_size:
logger.warning(
"Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
"version. Using `--per_device_train_batch_size` is preferred."
)
per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
return per_device_batch_size * max(1, self.n_replicas)
@property
def eval_batch_size(self) -> int:
"""
The actual batch size for evaluation (may differ from :obj:`per_gpu_eval_batch_size` in distributed training).
"""
if self.per_gpu_eval_batch_size:
logger.warning(
"Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
"version. Using `--per_device_eval_batch_size` is preferred."
)
per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
return per_device_batch_size * max(1, self.n_replicas)
@property
@tf_required
def n_gpu(self) -> int:
"""
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
"""
warnings.warn(
"The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
FutureWarning,
)
return self._setup_strategy.num_replicas_in_sync
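
The new batch-size properties are simply the per-device values multiplied by the replica count, and `n_gpu` now only forwards to `n_replicas` behind a `FutureWarning`. A hedged sketch of the arithmetic (on a single-GPU or CPU-only machine `n_replicas` is 1; on a TPU v3-8 the same settings would give 64 and 128):

```python
from transformers import TFTrainingArguments

args = TFTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
)

print(args.n_replicas)        # replicas detected by the distribution strategy
print(args.train_batch_size)  # per_device_train_batch_size * max(1, n_replicas)
print(args.eval_batch_size)   # per_device_eval_batch_size * max(1, n_replicas)
print(args.n_gpu)             # same value as n_replicas, but emits a FutureWarning
```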