[WIP] Enable reproducibility for distributed trainings (#16907)
* add seed_worker and set_deterministic_seed_for_cuda function to enforce reproducibility
* change function name to enable determinism, add docstrings, reproducibility support for tf
* change function name to enable_determinism_for_distributed_training
* revert changes in set_seed and call set_seed within enable_full_determinism
* add one positional argument for seed_worker function
* add full_determinism flag in training args and call enable_full_determinism when it is true
* add enable_full_determinism to documentation
* apply make fixup after the last commit
* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 5229744b26
commit c33f6046c3
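As a quick orientation before the diff: with this change a user opts in to fully deterministic runs through the new `full_determinism` training argument. A minimal usage sketch (the output directory and seed value below are illustrative, not part of the commit):

from transformers import TrainingArguments

# Sketch: with full_determinism=True, the Trainer calls enable_full_determinism(seed)
# instead of set_seed(seed) when it is constructed (and again before model_init).
args = TrainingArguments(
    output_dir="out",        # illustrative path
    seed=42,                 # illustrative seed
    full_determinism=True,   # new flag introduced by this commit, defaults to False
)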
@@ -22,6 +22,8 @@ Most of those are only useful if you are studying the code of the Trainer in the

 [[autodoc]] IntervalStrategy

+[[autodoc]] enable_full_determinism
+
 [[autodoc]] set_seed

 [[autodoc]] torch_distributed_zero_first
@@ -372,7 +372,7 @@ _import_structure = {
         "TrainerControl",
         "TrainerState",
     ],
-    "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
+    "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "enable_full_determinism", "set_seed"],
     "training_args": ["TrainingArguments"],
     "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
     "training_args_tf": ["TFTrainingArguments"],
@@ -2810,7 +2810,7 @@ if TYPE_CHECKING:
         TrainerControl,
         TrainerState,
     )
-    from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
+    from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, enable_full_determinism, set_seed
     from .training_args import TrainingArguments
     from .training_args_seq2seq import Seq2SeqTrainingArguments
     from .training_args_tf import TFTrainingArguments
@@ -115,10 +115,12 @@ from .trainer_utils import (
     default_compute_objective,
     default_hp_space,
     denumpify_detensorize,
+    enable_full_determinism,
     find_executable_batch_size,
     get_last_checkpoint,
     has_length,
     number_of_arguments,
+    seed_worker,
     set_seed,
     speed_metrics,
 )
@@ -300,7 +302,7 @@ class Trainer:
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model
-        set_seed(self.args.seed)
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
         self.is_in_train = False
@@ -746,6 +748,7 @@ class Trainer:
             drop_last=self.args.dataloader_drop_last,
             num_workers=self.args.dataloader_num_workers,
             pin_memory=self.args.dataloader_pin_memory,
+            worker_init_fn=seed_worker,
         )

     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
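The `worker_init_fn=seed_worker` hook matters because each `DataLoader` worker is a separate process with its own random state; `seed_worker` reseeds `random`, `numpy` and `torch` in every worker from `torch.initial_seed()`. A standalone sketch of the same pattern outside the Trainer (the toy dataset is a placeholder):

import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers.trainer_utils import seed_worker  # helper added by this commit

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))  # placeholder toy data
loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=2,
    worker_init_fn=seed_worker,  # reseeds python/numpy/torch inside each worker process
)
for batch in loader:
    pass  # per-worker randomness (e.g. data augmentation) is now reproducible across runs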
@@ -1254,7 +1257,7 @@ class Trainer:
         model_reloaded = False
         if self.model_init is not None:
             # Seed must be set before instantiating the model when using model_init.
-            set_seed(args.seed)
+            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
             self.model = self.call_model_init(trial)
             model_reloaded = True
             # Reinitializes optimizer and scheduler
@@ -34,7 +34,14 @@ from tensorflow.python.distribute.values import PerReplica

 from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    EvalPrediction,
+    IntervalStrategy,
+    PredictionOutput,
+    enable_full_determinism,
+    set_seed,
+)
 from .training_args_tf import TFTrainingArguments
 from .utils import logging

@@ -134,7 +141,7 @@ class TFTrainer:
                 "see https://www.comet.ml/docs/python-sdk/huggingface/"
             )

-        set_seed(self.args.seed)
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)

     def get_train_tfdataset(self) -> tf.data.Dataset:
         """
@@ -47,6 +47,39 @@ if is_tf_available():
     import tensorflow as tf


+def seed_worker(_):
+    """
+    Helper function to set worker seed during Dataloader initialization.
+    """
+    worker_seed = torch.initial_seed() % 2**32
+    set_seed(worker_seed)
+
+
+def enable_full_determinism(seed: int):
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+    """
+    # set seed first
+    set_seed(seed)
+
+    if is_torch_available():
+        # Enable PyTorch deterministic mode. This potentially requires either the environment
+        # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+        # depending on the CUDA version, so we set them both here
+        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        torch.use_deterministic_algorithms(True)
+
+        # Enable CUDNN deterministic mode
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+    if is_tf_available():
+        tf.config.experimental.enable_op_determinism()
+
+
 def set_seed(seed: int):
     """
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
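The new helper is also exported from the top-level `transformers` namespace (see the `__init__.py` hunks above), so it can be called directly. A minimal sketch, assuming a PyTorch install:

from transformers import enable_full_determinism

# Seeds random/numpy/torch (and tf if installed) via set_seed, switches PyTorch into
# deterministic-algorithms mode, makes cuDNN deterministic, and sets the
# CUDA_LAUNCH_BLOCKING / CUBLAS_WORKSPACE_CONFIG environment variables.
# Like set_seed, it should run before the model is instantiated.
enable_full_determinism(42)

Note that `torch.use_deterministic_algorithms(True)` raises an error at runtime for operations that have no deterministic implementation, so full determinism can come with a speed and coverage cost.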
@@ -448,6 +448,9 @@ class TrainingArguments:
         auto_find_batch_size (`bool`, *optional*, defaults to `False`)
             Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
             CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
+        full_determinism (`bool`, *optional*, defaults to `False`)
+            If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+            distributed training
     """

     output_dir: str = field(
@@ -816,6 +819,12 @@ class TrainingArguments:
             "help": "Whether to automatically decrease the batch size in half and rerun the training loop again each time a CUDA Out-of-Memory was reached"
         },
     )
+    full_determinism: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed training"
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
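Because `full_determinism` is a regular `TrainingArguments` dataclass field, example scripts that build their arguments with `HfArgumentParser` should pick it up as a `--full_determinism` command-line flag. A hedged sketch (the argument list is illustrative):

from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# Boolean fields with a False default can be switched on by passing the bare flag.
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--full_determinism"]
)
print(training_args.full_determinism)  # expected: True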