[trainer] 2 bug fixes and a rename (#12309)

* bug fixes and a rename

* add extended DDP test
Stas Bekman 2021-06-22 11:13:23 -07:00 committed by GitHub
parent 64029abe4c
commit ebe5413589
7 changed files with 112 additions and 18 deletions


@@ -153,7 +153,7 @@ Here is an example of how this can be used in an application:
)
# set the main code and the modules it uses to the same log-level according to the node
- log_level = training_args.get_node_log_level()
+ log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
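For context, the snippet above comes from the recommended application-level setup. A minimal sketch of the full pattern, assuming ``training_args`` is an already-parsed TrainingArguments instance (the surrounding lines are not part of this commit):

import logging
import sys

import datasets
import transformers

logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

# each process picks its own level: main process vs. replicas
log_level = training_args.get_process_log_level()  # training_args assumed parsed elsewhere
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)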


@@ -245,7 +245,7 @@ def main():
handlers=[logging.StreamHandler(sys.stdout)],
)
- log_level = training_args.get_node_log_level()
+ log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)


@@ -291,7 +291,7 @@ class Trainer:
self._memory_tracker.start()
# set the correct log level depending on the node
- log_level = args.get_node_log_level()
+ log_level = args.get_process_log_level()
logging.set_verbosity(log_level)
# force device and distributed setup init explicitly


@@ -603,7 +603,9 @@ class TrainingArguments:
if env_local_rank != -1 and env_local_rank != self.local_rank:
self.local_rank = env_local_rank
# convert to int
self.log_level = trainer_log_levels[self.log_level]
+ self.log_level_replica = trainer_log_levels[self.log_level_replica]
# expand paths; otherwise os.makedirs("~/bar") will create the directory
# inside the current directory instead of under the actual home
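For reference, ``trainer_log_levels`` maps the string choices to the stdlib logging ints, with ``passive`` (-1) meaning "defer to the per-process default". A rough sketch of its shape (an approximation, not the exact source):

import logging

# approximate shape of trainer_log_levels (sketch)
trainer_log_levels = {
    "debug": logging.DEBUG,        # 10
    "info": logging.INFO,          # 20
    "warning": logging.WARNING,    # 30
    "error": logging.ERROR,        # 40
    "critical": logging.CRITICAL,  # 50
    "passive": -1,                 # keep the default for this process
}
assert trainer_log_levels["info"] == 20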
@@ -914,7 +916,20 @@ class TrainingArguments:
else:
return self.process_index == 0
- def get_node_log_level(self):
+ def get_process_log_level(self):
"""
Returns the log level to be used depending on whether this process is the main process of node 0, main process
of node non-0, or a non-main process.
For the main process the log level defaults to ``logging.INFO`` unless overridden by ``log_level`` argument.
For the replica processes the log level defaults to ``logging.WARNING`` unless overridden by
``log_level_replica`` argument.
The choice between the main and replica process settings is made according to the return value of
``should_log``.
"""
log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level
log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica
return log_level_main_node if self.should_log else log_level_replica_node
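Concretely, with the passive defaults (-1) the main process resolves to ``logging.INFO`` (20) and each replica to ``logging.WARNING`` (30). A self-contained sketch of the same decision, using a hypothetical free function in place of the method:

import logging

def resolve_log_level(log_level, log_level_replica, should_log):
    # mirrors get_process_log_level: -1 ("passive") falls back to the defaults
    main = logging.INFO if log_level == -1 else log_level
    replica = logging.WARNING if log_level_replica == -1 else log_level_replica
    return main if should_log else replica

assert resolve_log_level(-1, -1, should_log=True) == logging.INFO       # main process
assert resolve_log_level(-1, -1, should_log=False) == logging.WARNING   # replica
assert resolve_log_level(logging.ERROR, logging.DEBUG, should_log=False) == logging.DEBUG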


@@ -765,7 +765,6 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
--eval_steps {eval_steps}
--group_by_length
--label_smoothing_factor 0.1
- --adafactor
--source_lang en
--target_lang ro
--report_to none


@@ -14,6 +14,7 @@
import math
import os
import re
+ import sys
import unittest
from unittest.mock import patch
@@ -21,6 +22,7 @@ from unittest.mock import patch
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
+ CaptureStderr,
ExtendSysPath,
TestCasePlus,
execute_subprocess_async,
@@ -68,7 +70,15 @@ def require_apex(test_case):
class TestTrainerExt(TestCasePlus):
- def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
+ def run_seq2seq_quick(
+     self,
+     distributed=False,
+     extra_args_str=None,
+     predict_with_generate=True,
+     do_train=True,
+     do_eval=True,
+     do_predict=True,
+ ):
output_dir = self.run_trainer(
eval_steps=1,
max_len=12,
@@ -77,8 +87,15 @@ class TestTrainerExt(TestCasePlus):
distributed=distributed,
extra_args_str=extra_args_str,
predict_with_generate=predict_with_generate,
+ do_train=do_train,
+ do_eval=do_eval,
+ do_predict=do_predict,
)
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
+ if not do_eval:
+     return
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
@@ -145,6 +162,49 @@ class TestTrainerExt(TestCasePlus):
# to reproduce the problem set distributed=False
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
+ @require_torch_multi_gpu
+ def test_trainer_log_level_replica(self):
+     log_info_string = "Running training"
+     kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
+     # test with the default log_level - should be info and thus log info once
+     with CaptureStderr() as cl:
+         self.run_seq2seq_quick(
+             **kwargs,
+             extra_args_str="",
+         )
+     n_matches = len(re.findall(log_info_string, cl.err))
+     self.assertEqual(n_matches, 1)
+     # test with low log_level and log_level_replica - should be noisy on all processes
+     # now the info string should appear twice on 2 processes
+     with CaptureStderr() as cl:
+         self.run_seq2seq_quick(
+             **kwargs,
+             extra_args_str="--log_level debug --log_level_replica debug",
+         )
+     n_matches = len(re.findall(log_info_string, cl.err))
+     self.assertEqual(n_matches, 2)
+     # test with high log_level and low log_level_replica
+     # now the info string should appear once only on the replica
+     with CaptureStderr() as cl:
+         self.run_seq2seq_quick(
+             **kwargs,
+             extra_args_str="--log_level error --log_level_replica debug",
+         )
+     n_matches = len(re.findall(log_info_string, cl.err))
+     self.assertEqual(n_matches, 1)
+     # test with high log_level and log_level_replica - should be quiet on all processes
+     with CaptureStderr() as cl:
+         self.run_seq2seq_quick(
+             **kwargs,
+             extra_args_str="--log_level error --log_level_replica error",
+         )
+     n_matches = len(re.findall(log_info_string, cl.err))
+     self.assertEqual(n_matches, 0)
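The four expected counts follow from the resolution rule above on 2 processes (1 main + 1 replica): "Running training" is logged at INFO, so it surfaces only where the effective level is at or below INFO. A summary of the expectations (an illustrative sketch, not part of the commit):

# (log_level, log_level_replica) -> expected "Running training" matches on 2 processes
expected_matches = {
    ("passive", "passive"): 1,  # main=INFO logs; replica=WARNING stays quiet
    ("debug", "debug"): 2,      # both processes log
    ("error", "debug"): 1,      # main quiet; replica logs
    ("error", "error"): 0,      # both quiet
}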
@slow
def test_run_seq2seq_slow(self):
output_dir = self.run_trainer(
@@ -181,10 +241,13 @@ class TestTrainerExt(TestCasePlus):
distributed: bool = False,
extra_args_str: str = None,
predict_with_generate: bool = True,
+ do_train: bool = True,
+ do_eval: bool = True,
+ do_predict: bool = True,
):
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
- args = f"""
+ args_train = f"""
--model_name_or_path {model_name}
--train_file {data_dir}/train.json
--validation_file {data_dir}/val.json
@@ -192,21 +255,14 @@ class TestTrainerExt(TestCasePlus):
--output_dir {output_dir}
--overwrite_output_dir
--max_train_samples 8
- --max_eval_samples 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
--do_train
- --do_eval
- --do_predict
--num_train_epochs {str(num_train_epochs)}
--per_device_train_batch_size 4
- --per_device_eval_batch_size 4
--learning_rate {learning_rate}
--warmup_steps 8
- --evaluation_strategy steps
--logging_steps 0
- --eval_steps {str(eval_steps)}
--save_steps {str(eval_steps)}
--group_by_length
--label_smoothing_factor 0.1
@@ -214,6 +270,30 @@ class TestTrainerExt(TestCasePlus):
--target_lang ro_RO
--source_lang en_XX
"""
+ args_eval = f"""
+     --do_eval
+     --per_device_eval_batch_size 4
+     --max_eval_samples 8
+     --val_max_target_length {max_len}
+     --evaluation_strategy steps
+     --eval_steps {str(eval_steps)}
+ """
+ args_predict = """
+     --do_predict
+ """
+ args = ""
+ if do_train:
+     args += args_train
+ if do_eval:
+     args += args_eval
+ if do_predict:
+     args += args_predict
+ if predict_with_generate:
+     args += "--predict_with_generate"
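With the argument blocks split out, a caller can run any subset of train/eval/predict. For instance, the new replica-logging test above effectively does a train-only run; a usage sketch (inside a TestTrainerExt test method):

self.run_seq2seq_quick(
    distributed=True,
    predict_with_generate=False,
    do_eval=False,
    do_predict=False,
    extra_args_str="--log_level error --log_level_replica debug",
)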


@@ -665,23 +665,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))
def test_log_level(self):
- # testing only --log_level (--log_level_replica requires multiple nodes)
+ # testing only --log_level (--log_level_replica requires multiple GPUs and DDP and is tested elsewhere)
logger = logging.get_logger()
log_info_string = "Running training"
- # test with the default log level - should be info and thus log
+ # test with the default log_level - should be info and thus log on the main process
with CaptureLogger(logger) as cl:
trainer = get_regression_trainer()
trainer.train()
self.assertIn(log_info_string, cl.out)
- # test with low log level - lower than info
+ # test with low log_level - lower than info
with CaptureLogger(logger) as cl:
trainer = get_regression_trainer(log_level="debug")
trainer.train()
self.assertIn(log_info_string, cl.out)
- # test with high log level - should be quiet
+ # test with high log_level - should be quiet
with CaptureLogger(logger) as cl:
trainer = get_regression_trainer(log_level="error")
trainer.train()