mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[trainer] 2 bug fixes and a rename (#12309)
* bug fixes and a rename * add extended DDP test
This commit is contained in:
parent
64029abe4c
commit
ebe5413589
@ -153,7 +153,7 @@ Here is an example of how this can be used in an application:
|
||||
)
|
||||
|
||||
# set the main code and the modules it uses to the same log-level according to the node
|
||||
log_level = training_args.get_node_log_level()
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
|
@ -245,7 +245,7 @@ def main():
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_node_log_level()
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
|
@ -291,7 +291,7 @@ class Trainer:
|
||||
self._memory_tracker.start()
|
||||
|
||||
# set the correct log level depending on the node
|
||||
log_level = args.get_node_log_level()
|
||||
log_level = args.get_process_log_level()
|
||||
logging.set_verbosity(log_level)
|
||||
|
||||
# force device and distributed setup init explicitly
|
||||
|
@ -603,7 +603,9 @@ class TrainingArguments:
|
||||
if env_local_rank != -1 and env_local_rank != self.local_rank:
|
||||
self.local_rank = env_local_rank
|
||||
|
||||
# convert to int
|
||||
self.log_level = trainer_log_levels[self.log_level]
|
||||
self.log_level_replica = trainer_log_levels[self.log_level_replica]
|
||||
|
||||
# expand paths, if not os.makedirs("~/bar") will make directory
|
||||
# in the current directory instead of the actual home
|
||||
@ -914,7 +916,20 @@ class TrainingArguments:
|
||||
else:
|
||||
return self.process_index == 0
|
||||
|
||||
def get_node_log_level(self):
|
||||
def get_process_log_level(self):
|
||||
"""
|
||||
Returns the log level to be used depending on whether this process is the main process of node 0, main process
|
||||
of node non-0, or a non-main process.
|
||||
|
||||
For the main process the log level defaults to ``logging.INFO`` unless overridden by ``log_level`` argument.
|
||||
|
||||
For the replica processes the log level defaults to ``logging.WARNING`` unless overridden by
|
||||
``log_level_replica`` argument.
|
||||
|
||||
The choice between the main and replica process settings is made according to the return value of
|
||||
``should_log``.
|
||||
"""
|
||||
|
||||
log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level
|
||||
log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica
|
||||
return log_level_main_node if self.should_log else log_level_replica_node
|
||||
|
@ -765,7 +765,6 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
--eval_steps {eval_steps}
|
||||
--group_by_length
|
||||
--label_smoothing_factor 0.1
|
||||
--adafactor
|
||||
--source_lang en
|
||||
--target_lang ro
|
||||
--report_to none
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
@ -21,6 +22,7 @@ from unittest.mock import patch
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.integrations import is_fairscale_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
ExtendSysPath,
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
@ -68,7 +70,15 @@ def require_apex(test_case):
|
||||
|
||||
|
||||
class TestTrainerExt(TestCasePlus):
|
||||
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
|
||||
def run_seq2seq_quick(
|
||||
self,
|
||||
distributed=False,
|
||||
extra_args_str=None,
|
||||
predict_with_generate=True,
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
do_predict=True,
|
||||
):
|
||||
output_dir = self.run_trainer(
|
||||
eval_steps=1,
|
||||
max_len=12,
|
||||
@ -77,8 +87,15 @@ class TestTrainerExt(TestCasePlus):
|
||||
distributed=distributed,
|
||||
extra_args_str=extra_args_str,
|
||||
predict_with_generate=predict_with_generate,
|
||||
do_train=do_train,
|
||||
do_eval=do_eval,
|
||||
do_predict=do_predict,
|
||||
)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
|
||||
if not do_eval:
|
||||
return
|
||||
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
|
||||
first_step_stats = eval_metrics[0]
|
||||
@ -145,6 +162,49 @@ class TestTrainerExt(TestCasePlus):
|
||||
# to reproduce the problem set distributed=False
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_trainer_log_level_replica(self):
|
||||
log_info_string = "Running training"
|
||||
kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
|
||||
|
||||
# test with the default log_level - should be info and thus log info once
|
||||
with CaptureStderr() as cl:
|
||||
self.run_seq2seq_quick(
|
||||
**kwargs,
|
||||
extra_args_str="",
|
||||
)
|
||||
n_matches = len(re.findall(log_info_string, cl.err))
|
||||
self.assertEqual(n_matches, 1)
|
||||
|
||||
# test with low log_level and log_level_replica - should be noisy on all processes
|
||||
# now the info string should appear twice on 2 processes
|
||||
with CaptureStderr() as cl:
|
||||
self.run_seq2seq_quick(
|
||||
**kwargs,
|
||||
extra_args_str="--log_level debug --log_level_replica debug",
|
||||
)
|
||||
n_matches = len(re.findall(log_info_string, cl.err))
|
||||
self.assertEqual(n_matches, 2)
|
||||
|
||||
# test with high log_level and low log_level_replica
|
||||
# now the info string should appear once only on the replica
|
||||
with CaptureStderr() as cl:
|
||||
self.run_seq2seq_quick(
|
||||
**kwargs,
|
||||
extra_args_str="--log_level error --log_level_replica debug",
|
||||
)
|
||||
n_matches = len(re.findall(log_info_string, cl.err))
|
||||
self.assertEqual(n_matches, 1)
|
||||
|
||||
# test with high log_level and log_level_replica - should be quiet on all processes
|
||||
with CaptureStderr() as cl:
|
||||
self.run_seq2seq_quick(
|
||||
**kwargs,
|
||||
extra_args_str="--log_level error --log_level_replica error",
|
||||
)
|
||||
n_matches = len(re.findall(log_info_string, cl.err))
|
||||
self.assertEqual(n_matches, 0)
|
||||
|
||||
@slow
|
||||
def test_run_seq2seq_slow(self):
|
||||
output_dir = self.run_trainer(
|
||||
@ -181,10 +241,13 @@ class TestTrainerExt(TestCasePlus):
|
||||
distributed: bool = False,
|
||||
extra_args_str: str = None,
|
||||
predict_with_generate: bool = True,
|
||||
do_train: bool = True,
|
||||
do_eval: bool = True,
|
||||
do_predict: bool = True,
|
||||
):
|
||||
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
args_train = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--train_file {data_dir}/train.json
|
||||
--validation_file {data_dir}/val.json
|
||||
@ -192,21 +255,14 @@ class TestTrainerExt(TestCasePlus):
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--max_train_samples 8
|
||||
--max_eval_samples 8
|
||||
--max_source_length {max_len}
|
||||
--max_target_length {max_len}
|
||||
--val_max_target_length {max_len}
|
||||
--do_train
|
||||
--do_eval
|
||||
--do_predict
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--per_device_train_batch_size 4
|
||||
--per_device_eval_batch_size 4
|
||||
--learning_rate {learning_rate}
|
||||
--warmup_steps 8
|
||||
--evaluation_strategy steps
|
||||
--logging_steps 0
|
||||
--eval_steps {str(eval_steps)}
|
||||
--save_steps {str(eval_steps)}
|
||||
--group_by_length
|
||||
--label_smoothing_factor 0.1
|
||||
@ -214,6 +270,30 @@ class TestTrainerExt(TestCasePlus):
|
||||
--target_lang ro_RO
|
||||
--source_lang en_XX
|
||||
"""
|
||||
|
||||
args_eval = f"""
|
||||
--do_eval
|
||||
--per_device_eval_batch_size 4
|
||||
--max_eval_samples 8
|
||||
--val_max_target_length {max_len}
|
||||
--evaluation_strategy steps
|
||||
--eval_steps {str(eval_steps)}
|
||||
"""
|
||||
|
||||
args_predict = """
|
||||
--do_predict
|
||||
"""
|
||||
|
||||
args = ""
|
||||
if do_train:
|
||||
args += args_train
|
||||
|
||||
if do_eval:
|
||||
args += args_eval
|
||||
|
||||
if do_predict:
|
||||
args += args_predict
|
||||
|
||||
if predict_with_generate:
|
||||
args += "--predict_with_generate"
|
||||
|
||||
|
@ -665,23 +665,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertTrue(np.all(seen[expected.shape[0] :] == -100))
|
||||
|
||||
def test_log_level(self):
|
||||
# testing only --log_level (--log_level_replica requires multiple nodes)
|
||||
# testing only --log_level (--log_level_replica requires multiple gpus and DDP and is tested elsewhere)
|
||||
logger = logging.get_logger()
|
||||
log_info_string = "Running training"
|
||||
|
||||
# test with the default log level - should be info and thus log
|
||||
# test with the default log_level - should be info and thus log on the main process
|
||||
with CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer()
|
||||
trainer.train()
|
||||
self.assertIn(log_info_string, cl.out)
|
||||
|
||||
# test with low log level - lower than info
|
||||
# test with low log_level - lower than info
|
||||
with CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer(log_level="debug")
|
||||
trainer.train()
|
||||
self.assertIn(log_info_string, cl.out)
|
||||
|
||||
# test with high log level - should be quiet
|
||||
# test with high log_level - should be quiet
|
||||
with CaptureLogger(logger) as cl:
|
||||
trainer = get_regression_trainer(log_level="error")
|
||||
trainer.train()
|
||||
|
Loading…
Reference in New Issue
Block a user