[WIP] Enable reproducibility for distributed training (#16907)

* add seed worker and set_deterministic_seed_for_cuda function to enforce reproducibility

* change function name to enable determinism, add docstrings, reproducibility support for tf

* change function name to enable_determinism_for_distributed_training

* revert changes in set_seed and call set_seed within enable_full_determinism

* add one positional argument for seed_worker function

* add full_determinism flag in training args and call enable_full_determinism when it is true

* add enable_full_determinism to documentation

* apply make fixup after the last commit

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Author: hasan salim kanmaz, 2022-05-11 15:37:13 +02:00, committed by GitHub
parent 5229744b26
commit c33f6046c3
6 changed files with 60 additions and 6 deletions
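
For orientation, a minimal sketch of the intended usage of the new flag (the `output_dir` and seed values are illustrative):

```python
from transformers import TrainingArguments

# With full_determinism=True, Trainer.__init__ seeds via
# enable_full_determinism(args.seed) instead of set_seed(args.seed).
args = TrainingArguments(output_dir="out", seed=42, full_determinism=True)
print(args.full_determinism)  # True
```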

docs/source/en/internal/trainer_utils.mdx

@@ -22,6 +22,8 @@ Most of those are only useful if you are studying the code of the Trainer in the

 [[autodoc]] IntervalStrategy

+[[autodoc]] enable_full_determinism
+
 [[autodoc]] set_seed

 [[autodoc]] torch_distributed_zero_first

src/transformers/__init__.py

@@ -372,7 +372,7 @@ _import_structure = {
         "TrainerControl",
         "TrainerState",
     ],
-    "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
+    "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "enable_full_determinism", "set_seed"],
     "training_args": ["TrainingArguments"],
     "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
     "training_args_tf": ["TFTrainingArguments"],
@@ -2810,7 +2810,7 @@ if TYPE_CHECKING:
         TrainerControl,
         TrainerState,
     )
-    from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
+    from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, enable_full_determinism, set_seed
     from .training_args import TrainingArguments
     from .training_args_seq2seq import Seq2SeqTrainingArguments
     from .training_args_tf import TFTrainingArguments
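
With the export above, the helper becomes importable from the package root; a quick sanity check:

```python
from transformers import enable_full_determinism, set_seed

enable_full_determinism(42)  # seeds random/numpy/torch/tf and switches on deterministic ops
set_seed(42)                 # the lighter, seeding-only helper remains available
```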

src/transformers/trainer.py

@@ -115,10 +115,12 @@ from .trainer_utils import (
     default_compute_objective,
     default_hp_space,
     denumpify_detensorize,
+    enable_full_determinism,
     find_executable_batch_size,
     get_last_checkpoint,
     has_length,
     number_of_arguments,
+    seed_worker,
     set_seed,
     speed_metrics,
 )
@@ -300,7 +302,7 @@ class Trainer:
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model_init.
-        set_seed(self.args.seed)
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
         self.is_in_train = False
@@ -746,6 +748,7 @@ class Trainer:
             drop_last=self.args.dataloader_drop_last,
             num_workers=self.args.dataloader_num_workers,
             pin_memory=self.args.dataloader_pin_memory,
+            worker_init_fn=seed_worker,
         )

     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
@@ -1254,7 +1257,7 @@ class Trainer:
         model_reloaded = False
         if self.model_init is not None:
             # Seed must be set before instantiating the model when using model_init.
-            set_seed(args.seed)
+            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
             self.model = self.call_model_init(trial)
             model_reloaded = True
         # Reinitializes optimizer and scheduler
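
Outside the Trainer, the same worker seeding can be wired up by hand. A minimal sketch (the toy dataset is illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from transformers.trainer_utils import seed_worker

# seed_worker reads the per-worker torch seed (torch.initial_seed()) and
# propagates it to the `random` and `numpy` generators via set_seed, so
# augmentation randomness inside workers repeats across runs that start
# from the same base seed.
dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
loader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=seed_worker)
batches = list(loader)
```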

src/transformers/trainer_tf.py

@@ -34,7 +34,14 @@ from tensorflow.python.distribute.values import PerReplica

 from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    EvalPrediction,
+    IntervalStrategy,
+    PredictionOutput,
+    enable_full_determinism,
+    set_seed,
+)
 from .training_args_tf import TFTrainingArguments
 from .utils import logging

@@ -134,7 +141,7 @@ class TFTrainer:
                 "see https://www.comet.ml/docs/python-sdk/huggingface/"
             )

-        set_seed(self.args.seed)
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)

     def get_train_tfdataset(self) -> tf.data.Dataset:
         """

src/transformers/trainer_utils.py

@@ -47,6 +47,39 @@ if is_tf_available():
     import tensorflow as tf


+def seed_worker(_):
+    """
+    Helper function to set worker seed during Dataloader initialization.
+    """
+    worker_seed = torch.initial_seed() % 2**32
+    set_seed(worker_seed)
+
+
+def enable_full_determinism(seed: int):
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+    """
+    # set seed first
+    set_seed(seed)
+
+    if is_torch_available():
+        # Enable PyTorch deterministic mode. This potentially requires either the environment
+        # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+        # depending on the CUDA version, so we set them both here
+        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        torch.use_deterministic_algorithms(True)
+
+        # Enable CUDNN deterministic mode
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+    if is_tf_available():
+        tf.config.experimental.enable_op_determinism()
+
+
 def set_seed(seed: int):
     """
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
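
A small check of the seeding behavior of the new helper (this exercises only RNG state, not CUDA kernel determinism):

```python
import torch

from transformers import enable_full_determinism

enable_full_determinism(1234)
a = torch.randn(4, 4)

enable_full_determinism(1234)  # re-seeding restores the same RNG state
b = torch.randn(4, 4)

assert torch.equal(a, b)  # identical draws under the same seed
```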

src/transformers/training_args.py

@@ -448,6 +448,9 @@ class TrainingArguments:
         auto_find_batch_size (`bool`, *optional*, defaults to `False`)
             Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
             CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
+        full_determinism (`bool`, *optional*, defaults to `False`)
+            If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+            distributed training
     """

     output_dir: str = field(
@@ -816,6 +819,12 @@ class TrainingArguments:
             "help": "Whether to automatically decrease the batch size in half and rerun the training loop again each time a CUDA Out-of-Memory was reached"
         },
     )
+    full_determinism: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed training"
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
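
Because `TrainingArguments` fields double as CLI flags through `HfArgumentParser`, the option can also be enabled from the command line; a sketch (the argv list stands in for real command-line input):

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# Equivalent to: python train.py --output_dir out --seed 42 --full_determinism
(args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--seed", "42", "--full_determinism"]
)
print(args.full_determinism)  # True
```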