From f086652b16e59bece9571fb9a266557ad3181b2a Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Tue, 25 May 2021 08:03:43 -0400
Subject: [PATCH] Add option to log only once in multinode training (#11819)

* Add option to log only once in multinode training

* Use an alternate property
---
 examples/pytorch/language-modeling/run_clm.py |  6 +--
 examples/pytorch/language-modeling/run_mlm.py |  6 +--
 examples/pytorch/language-modeling/run_plm.py |  6 +--
 examples/pytorch/multiple-choice/run_swag.py  |  6 +--
 examples/pytorch/question-answering/run_qa.py |  6 +--
 .../question-answering/run_qa_beam_search.py  |  6 +--
 .../summarization/run_summarization.py        |  6 +--
 .../pytorch/text-classification/run_glue.py   |  6 +--
 .../pytorch/text-classification/run_xnli.py   |  6 +--
 .../pytorch/token-classification/run_ner.py   |  6 +--
 .../pytorch/translation/run_translation.py    |  6 +--
 src/transformers/trainer.py                   | 13 ++-----
 src/transformers/training_args.py             | 39 ++++++++++++++++++-
 .../run_{{cookiecutter.example_shortcut}}.py  |  6 +--
 .../pytorch/run_glue_model_parallelism.py     |  6 +--
 15 files changed, 81 insertions(+), 49 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 9d6e40c58a0..7aed40ed837 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -44,7 +44,7 @@ from transformers import (
     set_seed,
 )
 from transformers.testing_utils import CaptureLogger
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -202,7 +202,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -210,7 +210,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index 9085e6fe0c8..32a4bb537fb 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -43,7 +43,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -211,7 +211,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -219,7 +219,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py
index 38f57768edf..f5cace2b6b0 100755
--- a/examples/pytorch/language-modeling/run_plm.py
+++ b/examples/pytorch/language-modeling/run_plm.py
@@ -39,7 +39,7 @@ from transformers import (
     XLNetLMHeadModel,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -208,7 +208,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -216,7 +216,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py
index 3c9bfce866d..4caa0bb5af3 100755
--- a/examples/pytorch/multiple-choice/run_swag.py
+++ b/examples/pytorch/multiple-choice/run_swag.py
@@ -41,7 +41,7 @@ from transformers import (
 )
 from transformers.file_utils import PaddingStrategy
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -235,7 +235,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -243,7 +243,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/question-answering/run_qa.py b/examples/pytorch/question-answering/run_qa.py
index 0a48770a694..27155208be5 100755
--- a/examples/pytorch/question-answering/run_qa.py
+++ b/examples/pytorch/question-answering/run_qa.py
@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from utils_qa import postprocess_qa_predictions
 
@@ -228,7 +228,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -236,7 +236,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/question-answering/run_qa_beam_search.py b/examples/pytorch/question-answering/run_qa_beam_search.py
index e097b5bea74..9cd1f39258d 100755
--- a/examples/pytorch/question-answering/run_qa_beam_search.py
+++ b/examples/pytorch/question-answering/run_qa_beam_search.py
@@ -39,7 +39,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from utils_qa import postprocess_qa_predictions_with_beam_search
 
@@ -227,7 +227,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -235,7 +235,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py
index 690dede77c8..eebf5264ee5 100755
--- a/examples/pytorch/summarization/run_summarization.py
+++ b/examples/pytorch/summarization/run_summarization.py
@@ -41,7 +41,7 @@ from transformers import (
     set_seed,
 )
 from transformers.file_utils import is_offline_mode
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -284,7 +284,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -292,7 +292,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
 
     logger.info(f"Training/evaluation parameters {training_args}")
diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index 5953aa6cdcf..1b08def9c62 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -216,7 +216,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -224,7 +224,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/text-classification/run_xnli.py b/examples/pytorch/text-classification/run_xnli.py
index 6327c8f8d81..a409d283b45 100755
--- a/examples/pytorch/text-classification/run_xnli.py
+++ b/examples/pytorch/text-classification/run_xnli.py
@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -186,7 +186,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -195,7 +195,7 @@ def main():
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/token-classification/run_ner.py b/examples/pytorch/token-classification/run_ner.py
index 4ff79088cef..f0f69f9e39b 100755
--- a/examples/pytorch/token-classification/run_ner.py
+++ b/examples/pytorch/token-classification/run_ner.py
@@ -40,7 +40,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -201,7 +201,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -209,7 +209,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py
index ed880b2e399..ea7a35719aa 100755
--- a/examples/pytorch/translation/run_translation.py
+++ b/examples/pytorch/translation/run_translation.py
@@ -44,7 +44,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -268,7 +268,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -276,7 +276,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
 
     logger.info(f"Training/evaluation parameters {training_args}")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 70836cac716..aa85ed8ab95 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1781,21 +1781,16 @@ class Trainer:
         Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
         several machines) main process.
         """
-        if is_torch_tpu_available():
-            return xm.is_master_ordinal(local=True)
-        elif is_sagemaker_mp_enabled():
-            return smp.local_rank() == 0
-        else:
-            return self.args.local_rank in [-1, 0]
+        return self.args.local_process_index == 0
 
     def is_world_process_zero(self) -> bool:
         """
         Whether or not this process is the global main process (when training in a distributed fashion on several
         machines, this is only going to be :obj:`True` for one process).
         """
-        if is_torch_tpu_available():
-            return xm.is_master_ordinal(local=False)
-        elif is_sagemaker_mp_enabled():
+        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+        # process index.
+        if is_sagemaker_mp_enabled():
             return smp.rank() == 0
         else:
             return self.args.process_index == 0
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index e1cb62cbab8..677afe4974c 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -316,6 +316,8 @@ class TrainingArguments:
             :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
             the `example scripts `__ for more
             details.
+        log_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            In multinode distributed training, whether to log once per node, or only on the main node.
     """
 
     output_dir: str = field(
@@ -559,6 +561,12 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    log_on_each_node: bool = field(
+        default=True,
+        metadata={
+            "help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
+        },
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
     mp_parameters: str = field(
         default="",
@@ -834,7 +842,7 @@ class TrainingArguments:
     @torch_required
     def process_index(self):
         """
-        The number of processes used in parallel.
+        The index of the current process used.
         """
         if is_torch_tpu_available():
             return xm.get_ordinal()
@@ -846,6 +854,35 @@ class TrainingArguments:
             return torch.distributed.get_rank()
         return 0
 
+    @property
+    @torch_required
+    def local_process_index(self):
+        """
+        The index of the local process used.
+        """
+        if is_torch_tpu_available():
+            return xm.get_ordinal(local=True)
+        elif is_sagemaker_mp_enabled():
+            return smp.local_rank()
+        elif is_sagemaker_dp_enabled():
+            return sm_dist.get_rank()
+        elif self.local_rank != -1:
+            return self.local_rank
+        return 0
+
+    @property
+    def should_log(self):
+        """
+        Whether or not the current process should produce log.
+        """
+        if self.log_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if is_sagemaker_mp_enabled():
+                return smp.rank() == 0
+            else:
+                return self.process_index == 0
+
     @property
     def place_model_on_device(self):
         """
diff --git a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
index 48590fe1671..a7af2159832 100755
--- a/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
+++ b/templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
@@ -43,7 +43,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 
 
 logger = logging.getLogger(__name__)
@@ -226,7 +226,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -234,7 +234,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
 
     logger.info(f"Training/evaluation parameters {training_args}")
diff --git a/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
index 1476a687a90..2021392930d 100644
--- a/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
+++ b/tests/sagemaker/scripts/pytorch/run_glue_model_parallelism.py
@@ -42,7 +42,7 @@ from transformers import (  # Trainer,; TrainingArguments,
 # Will import SageMaker Model parallelism specific Trainer
 from transformers.sagemaker import SageMakerTrainer as Trainer
 from transformers.sagemaker import SageMakerTrainingArguments as TrainingArguments
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 
 
@@ -210,7 +210,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
 
     # Log on each process the small summary:
     logger.warning(
@@ -218,7 +218,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
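
A minimal usage sketch (not part of the patch above, and assuming a transformers build that includes it): log_on_each_node defaults to True, so the first process on every node logs at INFO level; passing --log_on_each_node False restricts that to the global main process, and a script gates its own logging on training_args.should_log, exactly as the example scripts in the diff do.

    import logging
    import sys

    import transformers
    from transformers import HfArgumentParser, TrainingArguments

    logger = logging.getLogger(__name__)

    # Parse the usual training arguments; --output_dir is required, and
    # --log_on_each_node False makes only the global main process log.
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # INFO on processes that should log, WARN everywhere else.
    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()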