Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-23 14:29:01 +06:00)
Add option to log only once in multinode training (#11819)

* Add option to log only once in multinode training
* Use an alternate property
This commit is contained in:
parent b8344a274f
commit f086652b16
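The diff below applies the same two-line change to each PyTorch example script and adds the supporting `local_process_index`, `should_log`, and `log_on_each_node` members to `Trainer`/`TrainingArguments`. A minimal sketch of the logging setup the scripts end up with; the `TrainingArguments(output_dir="output")` line is only a stand-in for the arguments the scripts actually parse with `HfArgumentParser`:

import logging
import sys

import transformers
from transformers import TrainingArguments

# Stand-in for the arguments the example scripts parse with HfArgumentParser.
training_args = TrainingArguments(output_dir="output")

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# should_log is True on the main process of each node by default
# (log_on_each_node=True); pass log_on_each_node=False to log only once,
# on the global main process.
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

# Set the Transformers verbosity only where this process is expected to log.
if training_args.should_log:
    transformers.utils.logging.set_verbosity_info()
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()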
@@ -44,7 +44,7 @@ from transformers import (
     set_seed,
 )
 from transformers.testing_utils import CaptureLogger
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -202,7 +202,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -210,7 +210,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

@@ -43,7 +43,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -211,7 +211,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -219,7 +219,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

@@ -39,7 +39,7 @@ from transformers import (
     XLNetLMHeadModel,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -208,7 +208,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -216,7 +216,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

@@ -41,7 +41,7 @@ from transformers import (
 )
 from transformers.file_utils import PaddingStrategy
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -235,7 +235,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -243,7 +243,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from utils_qa import postprocess_qa_predictions

@@ -228,7 +228,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -236,7 +236,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

|
@ -39,7 +39,7 @@ from transformers import (
|
|||||||
default_data_collator,
|
default_data_collator,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
from utils_qa import postprocess_qa_predictions_with_beam_search
|
from utils_qa import postprocess_qa_predictions_with_beam_search
|
||||||
|
|
||||||
@ -227,7 +227,7 @@ def main():
|
|||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
handlers=[logging.StreamHandler(sys.stdout)],
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
)
|
)
|
||||||
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
|
||||||
|
|
||||||
# Log on each process the small summary:
|
# Log on each process the small summary:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -235,7 +235,7 @@ def main():
|
|||||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||||
)
|
)
|
||||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||||
if is_main_process(training_args.local_rank):
|
if training_args.should_log:
|
||||||
transformers.utils.logging.set_verbosity_info()
|
transformers.utils.logging.set_verbosity_info()
|
||||||
transformers.utils.logging.enable_default_handler()
|
transformers.utils.logging.enable_default_handler()
|
||||||
transformers.utils.logging.enable_explicit_format()
|
transformers.utils.logging.enable_explicit_format()
|
||||||
|
@ -41,7 +41,7 @@ from transformers import (
|
|||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import is_offline_mode
|
from transformers.file_utils import is_offline_mode
|
||||||
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
|
|
||||||
|
|
||||||
@ -284,7 +284,7 @@ def main():
|
|||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
handlers=[logging.StreamHandler(sys.stdout)],
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
)
|
)
|
||||||
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
|
||||||
|
|
||||||
# Log on each process the small summary:
|
# Log on each process the small summary:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -292,7 +292,7 @@ def main():
|
|||||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||||
)
|
)
|
||||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||||
if is_main_process(training_args.local_rank):
|
if training_args.should_log:
|
||||||
transformers.utils.logging.set_verbosity_info()
|
transformers.utils.logging.set_verbosity_info()
|
||||||
logger.info(f"Training/evaluation parameters {training_args}")
|
logger.info(f"Training/evaluation parameters {training_args}")
|
||||||
|
|
||||||
|
@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -216,7 +216,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -224,7 +224,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -186,7 +186,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -195,7 +195,7 @@ def main():
     )

     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

@@ -40,7 +40,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -201,7 +201,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -209,7 +209,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()

@@ -44,7 +44,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -268,7 +268,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -276,7 +276,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
     logger.info(f"Training/evaluation parameters {training_args}")

@@ -1781,21 +1781,16 @@ class Trainer:
         Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
         machines) main process.
         """
-        if is_torch_tpu_available():
-            return xm.is_master_ordinal(local=True)
-        elif is_sagemaker_mp_enabled():
-            return smp.local_rank() == 0
-        else:
-            return self.args.local_rank in [-1, 0]
+        return self.args.local_process_index == 0

     def is_world_process_zero(self) -> bool:
         """
         Whether or not this process is the global main process (when training in a distributed fashion on several
         machines, this is only going to be :obj:`True` for one process).
         """
-        if is_torch_tpu_available():
-            return xm.is_master_ordinal(local=False)
-        elif is_sagemaker_mp_enabled():
+        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+        # process index.
+        if is_sagemaker_mp_enabled():
             return smp.rank() == 0
         else:
             return self.args.process_index == 0

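With this change, `Trainer.is_local_process_zero` simply delegates to the new `TrainingArguments.local_process_index` property instead of re-deriving the local rank per backend. A usage sketch, with `trainer` assumed to be an already-constructed `Trainer` inside a launched training script:

# Sketch only: `trainer` is assumed to exist (a transformers.Trainer instance).
if trainer.is_local_process_zero():
    # True once per node: args.local_process_index == 0 on every machine.
    print("local main process on this node")

if trainer.is_world_process_zero():
    # True exactly once in the whole job: args.process_index == 0
    # (or smp.rank() == 0 under SageMaker Model Parallelism).
    trainer.save_model()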
@@ -316,6 +316,8 @@ class TrainingArguments:
         :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
         the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
         details.
+        log_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            In multinode distributed training, whether to log once per node, or only on the main node.
     """

     output_dir: str = field(
@@ -559,6 +561,12 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    log_on_each_node: bool = field(
+        default=True,
+        metadata={
+            "help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
+        },
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
     mp_parameters: str = field(
         default="",
@@ -834,7 +842,7 @@ class TrainingArguments:
     @torch_required
     def process_index(self):
         """
-        The number of processes used in parallel.
+        The index of the current process used.
         """
         if is_torch_tpu_available():
             return xm.get_ordinal()
@@ -846,6 +854,35 @@ class TrainingArguments:
             return torch.distributed.get_rank()
         return 0

+    @property
+    @torch_required
+    def local_process_index(self):
+        """
+        The index of the local process used.
+        """
+        if is_torch_tpu_available():
+            return xm.get_ordinal(local=True)
+        elif is_sagemaker_mp_enabled():
+            return smp.local_rank()
+        elif is_sagemaker_dp_enabled():
+            return sm_dist.get_rank()
+        elif self.local_rank != -1:
+            return self.local_rank
+        return 0
+
+    @property
+    def should_log(self):
+        """
+        Whether or not the current process should produce log.
+        """
+        if self.log_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if is_sagemaker_mp_enabled():
+                return smp.rank() == 0
+            else:
+                return self.process_index == 0
+
     @property
     def place_model_on_device(self):
         """

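Together, `log_on_each_node` and the new `should_log` property determine which ranks emit INFO-level logs. A minimal sketch (single process here, so both cases evaluate on rank 0; under a multinode launcher, `log_on_each_node=False` leaves `should_log` true only on the global main process):

from transformers import TrainingArguments

# Default: log once per node (the local main process of every machine).
args_default = TrainingArguments(output_dir="out")
print(args_default.log_on_each_node)  # True
print(args_default.should_log)        # True where local_process_index == 0

# Log only once for the whole job (global main process only).
args_once = TrainingArguments(output_dir="out", log_on_each_node=False)
print(args_once.should_log)           # True only where process_index == 0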
@@ -43,7 +43,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint


 logger = logging.getLogger(__name__)
@@ -226,7 +226,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -234,7 +234,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
     logger.info(f"Training/evaluation parameters {training_args}")

@@ -42,7 +42,7 @@ from transformers import (  # Trainer,; TrainingArguments,
 # Will import SageMaker Model parallelism specific Trainer
 from transformers.sagemaker import SageMakerTrainer as Trainer
 from transformers.sagemaker import SageMakerTrainingArguments as TrainingArguments
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version


@@ -210,7 +210,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -218,7 +218,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()