Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Clean up diffs in Trainer/TFTrainer (#5417)
* Cleanup and unify Trainer/TFTrainer
* Forgot to adapt TFTrainingArgs
* In tf scripts n_gpu -> n_replicas
* Update src/transformers/training_args.py
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Address review comments
* Formatting
* Fix typo

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 43cb03a93d
commit 734a28a767
@@ -108,7 +108,10 @@ def main():
         level=logging.INFO,
     )
     logger.warning(
-        "device: %s, n_gpu: %s, 16-bits training: %s", training_args.device, training_args.n_gpu, training_args.fp16,
+        "device: %s, n_replicas: %s, 16-bits training: %s",
+        training_args.device,
+        training_args.n_replicas,
+        training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
@@ -137,9 +137,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
@@ -131,9 +131,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
@@ -109,9 +109,9 @@ def main():
         level=logging.INFO,
     )
     logger.info(
-        "n_gpu: %s, distributed training: %s, 16-bits training: %s",
-        training_args.n_gpu,
-        bool(training_args.n_gpu > 1),
+        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
+        training_args.n_replicas,
+        bool(training_args.n_replicas > 1),
         training_args.fp16,
     )
     logger.info("Training/evaluation parameters %s", training_args)
@@ -155,7 +155,7 @@ from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer

 # Trainer
-from .trainer_utils import EvalPrediction
+from .trainer_utils import EvalPrediction, set_seed
 from .training_args import TrainingArguments
 from .training_args_tf import TFTrainingArguments
@@ -397,7 +397,7 @@ if is_torch_available():
     )

     # Trainer
-    from .trainer import Trainer, set_seed, torch_distributed_zero_first
+    from .trainer import Trainer, torch_distributed_zero_first
     from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
     from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
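Reviewer note: after the two hunks above, `set_seed` is exported from the package root via the framework-agnostic `trainer_utils` module instead of the PyTorch-only `trainer` module. A minimal downstream-import sketch, assuming a post-#5417 install (not part of the diff):

# set_seed no longer depends on the torch-gated trainer module.
from transformers import set_seed  # re-exported from transformers.trainer_utils

set_seed(42)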
@@ -1,7 +1,6 @@
 import logging
 import math
 import os
-import random
 import re
 import shutil
 import warnings
@@ -23,7 +22,14 @@ from .data.data_collator import DataCollator, default_data_collator
 from .file_utils import is_apex_available, is_torch_tpu_available
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    EvalPrediction,
+    PredictionOutput,
+    TrainOutput,
+    is_wandb_available,
+    set_seed,
+)
 from .training_args import TrainingArguments
@@ -60,20 +66,6 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)


-def set_seed(seed: int):
-    """
-    Helper function for reproducible behavior to set the seed in ``random``, ``numpy`` and ``torch``.
-
-    Args:
-        seed (:obj:`int`): The seed to set.
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    # ^^ safe to call this function even if cuda is not available
-
-
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     """
@ -541,8 +533,8 @@ class Trainer:
|
||||
|
||||
self._log(logs)
|
||||
|
||||
if self.args.evaluate_during_training:
|
||||
self.evaluate()
|
||||
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
|
||||
self.evaluate()
|
||||
|
||||
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
||||
# In all cases (even distributed/parallel), self.model is always a reference
|
||||
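With this change, `evaluate_during_training` only triggers evaluation on steps that are an exact multiple of the new `eval_steps` argument, rather than at every logging step. A tiny standalone illustration of the gating condition (step values are made up; this is not Trainer code):

# Evaluation fires only when global_step is a multiple of eval_steps
# and evaluate_during_training is enabled.
evaluate_during_training = True
eval_steps = 1000

for global_step in (500, 999, 1000, 1500, 2000):
    should_eval = evaluate_during_training and global_step % eval_steps == 0
    print(global_step, should_eval)  # True only for 1000 and 2000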
@ -573,7 +565,7 @@ class Trainer:
|
||||
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
if self.args.tpu_metrics_debug:
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||
xm.master_print(met.metrics_report())
|
||||
|
||||
@ -754,7 +746,7 @@ class Trainer:
|
||||
|
||||
self._log(output.metrics)
|
||||
|
||||
if self.args.tpu_metrics_debug:
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||
xm.master_print(met.metrics_report())
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -11,7 +10,7 @@ import tensorflow as tf
|
||||
|
||||
from .modeling_tf_utils import TFPreTrainedModel
|
||||
from .optimization_tf import GradientAccumulator, create_optimizer
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available
|
||||
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available, set_seed
|
||||
from .training_args_tf import TFTrainingArguments
|
||||
|
||||
|
||||
@@ -22,12 +21,6 @@ if is_wandb_available():
 logger = logging.getLogger(__name__)


-def set_seed(seed: int):
-    random.seed(seed)
-    np.random.seed(seed)
-    tf.random.set_seed(seed)
-
-
 class TFTrainer:
     """
     TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
@@ -256,7 +249,7 @@ class TFTrainer:
             if isinstance(labels, tuple):
                 labels = labels[0]

-            if self.args.n_gpu > 1:
+            if self.args.n_replicas > 1:
                 for val in logits.values:
                     if preds is None:
                         preds = val.numpy()
@ -542,7 +535,7 @@ class TFTrainer:
|
||||
loss, logits = outputs[:2]
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
|
||||
loss += sum(self.model.losses) * (1.0 / self.args.n_replicas)
|
||||
|
||||
return loss, logits
|
||||
|
||||
|
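The 1/n_replicas factor follows the usual tf.distribute guidance: every replica adds the same regularization losses from `model.losses`, and the strategy sums per-replica losses, so dividing keeps their effective weight equal to a single-replica run. A toy sketch of that arithmetic (illustrative numbers only, not TFTrainer code):

# Each of n_replicas replicas contributes the regularization term once;
# dividing by n_replicas before the cross-replica sum restores the
# single-replica magnitude.
n_replicas = 8
reg_loss_per_replica = 0.5

summed_over_replicas = sum(reg_loss_per_replica / n_replicas for _ in range(n_replicas))
print(summed_over_replicas)  # 0.5, the same as with one replica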
@ -1,8 +1,11 @@
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .file_utils import is_tf_available, is_torch_available
|
||||
|
||||
|
||||
try:
|
||||
import wandb
|
||||
@ -21,6 +24,28 @@ def is_wandb_available():
|
||||
return _has_wandb
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
|
||||
(if installed).
|
||||
|
||||
Args:
|
||||
seed (:obj:`int`): The seed to set.
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
tf.random.set_seed(seed)
|
||||
|
||||
|
||||
class EvalPrediction(NamedTuple):
|
||||
"""
|
||||
Evaluation output (always contains labels), to be used to compute metrics.
|
||||
|
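A quick usage sketch of the relocated helper, assuming a post-#5417 install: it always seeds Python's and NumPy's RNGs and additionally seeds torch and/or TensorFlow when they are importable.

# Minimal reproducibility check for the framework-agnostic set_seed.
import random

from transformers.trainer_utils import set_seed

set_seed(7)
first = random.random()
set_seed(7)
assert first == random.random()  # torch/TF RNGs are reset too when installed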
@@ -97,11 +97,13 @@ class TrainingArguments:
             During distributed training, the rank of the process.
         tpu_num_cores (:obj:`int`, `optional`):
             When training on TPU, the mumber of TPU cores (automatically passed by launcher script).
-        tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
+        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
             When training on TPU, whether to print debug metrics or not.
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
+        eval_steps (:obj:`int`, `optional`, defaults to 1000):
+            Number of update steps between two evaluations.
         past_index (:obj:`int`, `optional`, defaults to -1):
             Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
             make use of the past hidden states for their predictions. If this argument is set to a positive int, the
@ -202,11 +204,16 @@ class TrainingArguments:
|
||||
tpu_num_cores: Optional[int] = field(
|
||||
default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
|
||||
)
|
||||
tpu_metrics_debug: bool = field(default=False, metadata={"help": "TPU: Whether to print debug metrics"})
|
||||
tpu_metrics_debug: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"},
|
||||
)
|
||||
debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"})
|
||||
|
||||
dataloader_drop_last: bool = field(
|
||||
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
|
||||
)
|
||||
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
|
||||
|
||||
past_index: int = field(
|
||||
default=-1,
|
||||
|
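A placeholder sketch of the new fields in use (post-#5417; the `output_dir` value is arbitrary): `debug` is the preferred spelling, `tpu_metrics_debug` is kept only for backward compatibility, and `eval_steps` now lives on the base TrainingArguments.

# Hypothetical construction showing the new/renamed fields.
from transformers import TrainingArguments

args = TrainingArguments(output_dir="./out", eval_steps=500, debug=True)
print(args.eval_steps)         # 500
print(args.debug)              # True
print(args.tpu_metrics_debug)  # False, deprecated in favour of --debug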
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
|
||||
@ -80,11 +81,13 @@ class TFTrainingArguments(TrainingArguments):
|
||||
During distributed training, the rank of the process.
|
||||
tpu_num_cores (:obj:`int`, `optional`):
|
||||
When training on TPU, the mumber of TPU cores (automatically passed by launcher script).
|
||||
tpu_metrics_debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
When training on TPU, whether to print debug metrics or not.
|
||||
debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Wheter to activate the trace to record computation graphs and profiling information or not.
|
||||
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
|
||||
or not.
|
||||
eval_steps (:obj:`int`, `optional`, defaults to 1000):
|
||||
Number of update steps before two evaluations.
|
||||
past_index (:obj:`int`, `optional`, defaults to -1):
|
||||
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
|
||||
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
||||
@ -92,19 +95,11 @@ class TFTrainingArguments(TrainingArguments):
|
||||
at the next training step under the keyword argument ``mems``.
|
||||
tpu_name (:obj:`str`, `optional`):
|
||||
The name of the TPU the process is running on.
|
||||
eval_steps (:obj:`int`, `optional`, defaults to 1000):
|
||||
Number of update steps before two evaluations.
|
||||
debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Wheter to activate the trace to record computation graphs and profiling information or not.
|
||||
"""
|
||||
|
||||
tpu_name: str = field(
|
||||
default=None, metadata={"help": "Name of TPU"},
|
||||
)
|
||||
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
|
||||
debug: bool = field(
|
||||
default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@tf_required
|
||||
@@ -150,8 +145,46 @@ class TFTrainingArguments(TrainingArguments):

     @property
     @tf_required
-    def n_gpu(self) -> int:
+    def n_replicas(self) -> int:
         """
-        The number of replicas (GPUs or TPU cores) used in this training.
+        The number of replicas (CPUs, GPUs or TPU cores) used in this training.
         """
         return self._setup_strategy.num_replicas_in_sync
+
+    @property
+    def train_batch_size(self) -> int:
+        """
+        The actual batch size for training (may differ from :obj:`per_gpu_train_batch_size` in distributed training).
+        """
+        if self.per_gpu_train_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_train_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
+        return per_device_batch_size * max(1, self.n_replicas)
+
+    @property
+    def eval_batch_size(self) -> int:
+        """
+        The actual batch size for evaluation (may differ from :obj:`per_gpu_eval_batch_size` in distributed training).
+        """
+        if self.per_gpu_eval_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_eval_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
+        return per_device_batch_size * max(1, self.n_replicas)
+
+    @property
+    @tf_required
+    def n_gpu(self) -> int:
+        """
+        The number of replicas (CPUs, GPUs or TPU cores) used in this training.
+        """
+        warnings.warn(
+            "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
+            FutureWarning,
+        )
+        return self._setup_strategy.num_replicas_in_sync
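A sketch of the resulting TF-side API, assuming TensorFlow is installed and a post-#5417 version of the library (the batch size value is a placeholder): effective batch sizes are the per-device size times `n_replicas`, and the old `n_gpu` property still works but emits a FutureWarning.

# Hypothetical check of the renamed property and the batch-size arithmetic.
import warnings

from transformers import TFTrainingArguments

args = TFTrainingArguments(output_dir="./out", per_device_train_batch_size=16)

print(args.n_replicas)        # e.g. 1 on a single GPU/CPU, 8 on a v3-8 TPU
print(args.train_batch_size)  # 16 * n_replicas

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = args.n_gpu            # deprecated alias for n_replicas
print(caught[0].category is FutureWarning)  # True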