From ebe54135890bf06d88609cfbbd26de02f12e387b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 22 Jun 2021 11:13:23 -0700 Subject: [PATCH] [trainer] 2 bug fixes and a rename (#12309) * bug fixes and a rename * add extended DDP test --- docs/source/main_classes/trainer.rst | 2 +- .../pytorch/translation/run_translation.py | 2 +- src/transformers/trainer.py | 2 +- src/transformers/training_args.py | 17 +++- tests/deepspeed/test_deepspeed.py | 1 - tests/extended/test_trainer_ext.py | 98 +++++++++++++++++-- tests/test_trainer.py | 8 +- 7 files changed, 112 insertions(+), 18 deletions(-) diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index 866665eacf5..21586e8772f 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -153,7 +153,7 @@ Here is an example of how this can be used in an application: ) # set the main code and the modules it uses to the same log-level according to the node - log_level = training_args.get_node_log_level() + log_level = training_args.get_process_log_level() logger.setLevel(log_level) datasets.utils.logging.set_verbosity(log_level) transformers.utils.logging.set_verbosity(log_level) diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index 44111800442..0274dbce17b 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -245,7 +245,7 @@ def main(): handlers=[logging.StreamHandler(sys.stdout)], ) - log_level = training_args.get_node_log_level() + log_level = training_args.get_process_log_level() logger.setLevel(log_level) datasets.utils.logging.set_verbosity(log_level) transformers.utils.logging.set_verbosity(log_level) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 60c344be7a5..55fcb4af01e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -291,7 +291,7 @@ class Trainer: self._memory_tracker.start() # set the correct log level depending on the node - log_level = args.get_node_log_level() + log_level = args.get_process_log_level() logging.set_verbosity(log_level) # force device and distributed setup init explicitly diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 69f1693abb7..74791cd7e18 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -603,7 +603,9 @@ class TrainingArguments: if env_local_rank != -1 and env_local_rank != self.local_rank: self.local_rank = env_local_rank + # convert to int self.log_level = trainer_log_levels[self.log_level] + self.log_level_replica = trainer_log_levels[self.log_level_replica] # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home @@ -914,7 +916,20 @@ class TrainingArguments: else: return self.process_index == 0 - def get_node_log_level(self): + def get_process_log_level(self): + """ + Returns the log level to be used depending on whether this process is the main process of node 0, main process + of node non-0, or a non-main process. + + For the main process the log level defaults to ``logging.INFO`` unless overridden by ``log_level`` argument. + + For the replica processes the log level defaults to ``logging.WARNING`` unless overridden by + ``log_level_replica`` argument. + + The choice between the main and replica process settings is made according to the return value of + ``should_log``. + """ + log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica return log_level_main_node if self.should_log else log_level_replica_node diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 74a2928c3ec..e699b110f0a 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -765,7 +765,6 @@ class TestDeepSpeedWithLauncher(TestCasePlus): --eval_steps {eval_steps} --group_by_length --label_smoothing_factor 0.1 - --adafactor --source_lang en --target_lang ro --report_to none diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py index 93ef0ddb555..a0a328cf091 100644 --- a/tests/extended/test_trainer_ext.py +++ b/tests/extended/test_trainer_ext.py @@ -14,6 +14,7 @@ import math import os +import re import sys import unittest from unittest.mock import patch @@ -21,6 +22,7 @@ from unittest.mock import patch from transformers.file_utils import is_apex_available from transformers.integrations import is_fairscale_available from transformers.testing_utils import ( + CaptureStderr, ExtendSysPath, TestCasePlus, execute_subprocess_async, @@ -68,7 +70,15 @@ def require_apex(test_case): class TestTrainerExt(TestCasePlus): - def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True): + def run_seq2seq_quick( + self, + distributed=False, + extra_args_str=None, + predict_with_generate=True, + do_train=True, + do_eval=True, + do_predict=True, + ): output_dir = self.run_trainer( eval_steps=1, max_len=12, @@ -77,8 +87,15 @@ class TestTrainerExt(TestCasePlus): distributed=distributed, extra_args_str=extra_args_str, predict_with_generate=predict_with_generate, + do_train=do_train, + do_eval=do_eval, + do_predict=do_predict, ) logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history + + if not do_eval: + return + eval_metrics = [log for log in logs if "eval_loss" in log.keys()] first_step_stats = eval_metrics[0] @@ -145,6 +162,49 @@ class TestTrainerExt(TestCasePlus): # to reproduce the problem set distributed=False self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex") + @require_torch_multi_gpu + def test_trainer_log_level_replica(self): + log_info_string = "Running training" + kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False) + + # test with the default log_level - should be info and thus log info once + with CaptureStderr() as cl: + self.run_seq2seq_quick( + **kwargs, + extra_args_str="", + ) + n_matches = len(re.findall(log_info_string, cl.err)) + self.assertEqual(n_matches, 1) + + # test with low log_level and log_level_replica - should be noisy on all processes + # now the info string should appear twice on 2 processes + with CaptureStderr() as cl: + self.run_seq2seq_quick( + **kwargs, + extra_args_str="--log_level debug --log_level_replica debug", + ) + n_matches = len(re.findall(log_info_string, cl.err)) + self.assertEqual(n_matches, 2) + + # test with high log_level and low log_level_replica + # now the info string should appear once only on the replica + with CaptureStderr() as cl: + self.run_seq2seq_quick( + **kwargs, + extra_args_str="--log_level error --log_level_replica debug", + ) + n_matches = len(re.findall(log_info_string, cl.err)) + self.assertEqual(n_matches, 1) + + # test with high log_level and log_level_replica - should be quiet on all processes + with CaptureStderr() as cl: + self.run_seq2seq_quick( + **kwargs, + extra_args_str="--log_level error --log_level_replica error", + ) + n_matches = len(re.findall(log_info_string, cl.err)) + self.assertEqual(n_matches, 0) + @slow def test_run_seq2seq_slow(self): output_dir = self.run_trainer( @@ -181,10 +241,13 @@ class TestTrainerExt(TestCasePlus): distributed: bool = False, extra_args_str: str = None, predict_with_generate: bool = True, + do_train: bool = True, + do_eval: bool = True, + do_predict: bool = True, ): data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro" output_dir = self.get_auto_remove_tmp_dir() - args = f""" + args_train = f""" --model_name_or_path {model_name} --train_file {data_dir}/train.json --validation_file {data_dir}/val.json @@ -192,21 +255,14 @@ class TestTrainerExt(TestCasePlus): --output_dir {output_dir} --overwrite_output_dir --max_train_samples 8 - --max_eval_samples 8 --max_source_length {max_len} --max_target_length {max_len} - --val_max_target_length {max_len} --do_train - --do_eval - --do_predict --num_train_epochs {str(num_train_epochs)} --per_device_train_batch_size 4 - --per_device_eval_batch_size 4 --learning_rate {learning_rate} --warmup_steps 8 - --evaluation_strategy steps --logging_steps 0 - --eval_steps {str(eval_steps)} --save_steps {str(eval_steps)} --group_by_length --label_smoothing_factor 0.1 @@ -214,6 +270,30 @@ class TestTrainerExt(TestCasePlus): --target_lang ro_RO --source_lang en_XX """ + + args_eval = f""" + --do_eval + --per_device_eval_batch_size 4 + --max_eval_samples 8 + --val_max_target_length {max_len} + --evaluation_strategy steps + --eval_steps {str(eval_steps)} + """ + + args_predict = """ + --do_predict + """ + + args = "" + if do_train: + args += args_train + + if do_eval: + args += args_eval + + if do_predict: + args += args_predict + if predict_with_generate: args += "--predict_with_generate" diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 7107cea56df..2dc7108d4d5 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -665,23 +665,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(np.all(seen[expected.shape[0] :] == -100)) def test_log_level(self): - # testing only --log_level (--log_level_replica requires multiple nodes) + # testing only --log_level (--log_level_replica requires multiple gpus and DDP and is tested elsewhere) logger = logging.get_logger() log_info_string = "Running training" - # test with the default log level - should be info and thus log + # test with the default log_level - should be info and thus log on the main process with CaptureLogger(logger) as cl: trainer = get_regression_trainer() trainer.train() self.assertIn(log_info_string, cl.out) - # test with low log level - lower than info + # test with low log_level - lower than info with CaptureLogger(logger) as cl: trainer = get_regression_trainer(log_level="debug") trainer.train() self.assertIn(log_info_string, cl.out) - # test with high log level - should be quiet + # test with high log_level - should be quiet with CaptureLogger(logger) as cl: trainer = get_regression_trainer(log_level="error") trainer.train()