Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
[trainer] 2 bug fixes and a rename (#12309)

* bug fixes and a rename
* add extended DDP test

This commit is contained in:
parent 64029abe4c
commit ebe5413589
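Before the diff, a minimal usage sketch (not part of this commit) of the renamed helper inside a training script's logging setup; it assumes a transformers version that already has TrainingArguments.log_level / log_level_replica and get_process_log_level(), and the argument values are made up for illustration:

import logging
import sys

import transformers
from transformers import TrainingArguments

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# illustrative values; "tmp_out" is a made-up output dir
training_args = TrainingArguments(output_dir="tmp_out", log_level="info", log_level_replica="warning")

log_level = training_args.get_process_log_level()  # was get_node_log_level() before this commit
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)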
@@ -153,7 +153,7 @@ Here is an example of how this can be used in an application:
     )

     # set the main code and the modules it uses to the same log-level according to the node
-    log_level = training_args.get_node_log_level()
+    log_level = training_args.get_process_log_level()
     logger.setLevel(log_level)
     datasets.utils.logging.set_verbosity(log_level)
     transformers.utils.logging.set_verbosity(log_level)
@@ -245,7 +245,7 @@ def main():
         handlers=[logging.StreamHandler(sys.stdout)],
     )

-    log_level = training_args.get_node_log_level()
+    log_level = training_args.get_process_log_level()
     logger.setLevel(log_level)
     datasets.utils.logging.set_verbosity(log_level)
     transformers.utils.logging.set_verbosity(log_level)
@@ -291,7 +291,7 @@ class Trainer:
         self._memory_tracker.start()

         # set the correct log level depending on the node
-        log_level = args.get_node_log_level()
+        log_level = args.get_process_log_level()
         logging.set_verbosity(log_level)

         # force device and distributed setup init explicitly
@@ -603,7 +603,9 @@ class TrainingArguments:
         if env_local_rank != -1 and env_local_rank != self.local_rank:
             self.local_rank = env_local_rank

+        # convert to int
         self.log_level = trainer_log_levels[self.log_level]
+        self.log_level_replica = trainer_log_levels[self.log_level_replica]

         # expand paths, if not os.makedirs("~/bar") will make directory
         # in the current directory instead of the actual home
@@ -914,7 +916,20 @@ class TrainingArguments:
         else:
             return self.process_index == 0

-    def get_node_log_level(self):
+    def get_process_log_level(self):
+        """
+        Returns the log level to be used depending on whether this process is the main process of node 0, main process
+        of node non-0, or a non-main process.
+
+        For the main process the log level defaults to ``logging.INFO`` unless overridden by ``log_level`` argument.
+
+        For the replica processes the log level defaults to ``logging.WARNING`` unless overridden by
+        ``log_level_replica`` argument.
+
+        The choice between the main and replica process settings is made according to the return value of
+        ``should_log``.
+        """
+
         log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level
         log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica
         return log_level_main_node if self.should_log else log_level_replica_node
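A minimal sketch of the defaults described in the docstring above (assumptions: a single process, so should_log is True, log_level defaulting to "passive", and a made-up output_dir):

import logging

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_out")  # log_level left at "passive", i.e. -1 after __post_init__
assert args.get_process_log_level() == logging.INFO  # main process falls back to INFO

args = TrainingArguments(output_dir="tmp_out", log_level="error")
assert args.get_process_log_level() == logging.ERROR  # an explicit log_level wins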
@@ -765,7 +765,6 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
             --eval_steps {eval_steps}
             --group_by_length
             --label_smoothing_factor 0.1
-            --adafactor
             --source_lang en
             --target_lang ro
             --report_to none
@@ -14,6 +14,7 @@

 import math
 import os
+import re
 import sys
 import unittest
 from unittest.mock import patch
@@ -21,6 +22,7 @@ from unittest.mock import patch
 from transformers.file_utils import is_apex_available
 from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
+    CaptureStderr,
     ExtendSysPath,
     TestCasePlus,
     execute_subprocess_async,
@@ -68,7 +70,15 @@ def require_apex(test_case):


 class TestTrainerExt(TestCasePlus):
-    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
+    def run_seq2seq_quick(
+        self,
+        distributed=False,
+        extra_args_str=None,
+        predict_with_generate=True,
+        do_train=True,
+        do_eval=True,
+        do_predict=True,
+    ):
         output_dir = self.run_trainer(
             eval_steps=1,
             max_len=12,
@@ -77,8 +87,15 @@ class TestTrainerExt(TestCasePlus):
             distributed=distributed,
             extra_args_str=extra_args_str,
             predict_with_generate=predict_with_generate,
+            do_train=do_train,
+            do_eval=do_eval,
+            do_predict=do_predict,
         )
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
+
+        if not do_eval:
+            return
+
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]

         first_step_stats = eval_metrics[0]
@@ -145,6 +162,49 @@ class TestTrainerExt(TestCasePlus):
         # to reproduce the problem set distributed=False
         self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

+    @require_torch_multi_gpu
+    def test_trainer_log_level_replica(self):
+        log_info_string = "Running training"
+        kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
+
+        # test with the default log_level - should be info and thus log info once
+        with CaptureStderr() as cl:
+            self.run_seq2seq_quick(
+                **kwargs,
+                extra_args_str="",
+            )
+        n_matches = len(re.findall(log_info_string, cl.err))
+        self.assertEqual(n_matches, 1)
+
+        # test with low log_level and log_level_replica - should be noisy on all processes
+        # now the info string should appear twice on 2 processes
+        with CaptureStderr() as cl:
+            self.run_seq2seq_quick(
+                **kwargs,
+                extra_args_str="--log_level debug --log_level_replica debug",
+            )
+        n_matches = len(re.findall(log_info_string, cl.err))
+        self.assertEqual(n_matches, 2)
+
+        # test with high log_level and low log_level_replica
+        # now the info string should appear once only on the replica
+        with CaptureStderr() as cl:
+            self.run_seq2seq_quick(
+                **kwargs,
+                extra_args_str="--log_level error --log_level_replica debug",
+            )
+        n_matches = len(re.findall(log_info_string, cl.err))
+        self.assertEqual(n_matches, 1)
+
+        # test with high log_level and log_level_replica - should be quiet on all processes
+        with CaptureStderr() as cl:
+            self.run_seq2seq_quick(
+                **kwargs,
+                extra_args_str="--log_level error --log_level_replica error",
+            )
+        n_matches = len(re.findall(log_info_string, cl.err))
+        self.assertEqual(n_matches, 0)
+
     @slow
     def test_run_seq2seq_slow(self):
         output_dir = self.run_trainer(
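The new test above counts how many times the trainer's "Running training" info line shows up in the stderr captured from both DDP ranks. A minimal sketch of the CaptureStderr pattern it relies on (the helper lives in transformers.testing_utils; the message here is printed by hand purely for illustration):

import sys

from transformers.testing_utils import CaptureStderr

with CaptureStderr() as cl:
    print("Running training", file=sys.stderr)  # stand-in for a worker's log output
assert "Running training" in cl.err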
@@ -181,10 +241,13 @@ class TestTrainerExt(TestCasePlus):
         distributed: bool = False,
         extra_args_str: str = None,
         predict_with_generate: bool = True,
+        do_train: bool = True,
+        do_eval: bool = True,
+        do_predict: bool = True,
     ):
         data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
-        args = f"""
+        args_train = f"""
             --model_name_or_path {model_name}
             --train_file {data_dir}/train.json
             --validation_file {data_dir}/val.json
@@ -192,21 +255,14 @@ class TestTrainerExt(TestCasePlus):
             --output_dir {output_dir}
             --overwrite_output_dir
             --max_train_samples 8
-            --max_eval_samples 8
             --max_source_length {max_len}
             --max_target_length {max_len}
-            --val_max_target_length {max_len}
             --do_train
-            --do_eval
-            --do_predict
             --num_train_epochs {str(num_train_epochs)}
             --per_device_train_batch_size 4
-            --per_device_eval_batch_size 4
             --learning_rate {learning_rate}
             --warmup_steps 8
-            --evaluation_strategy steps
             --logging_steps 0
-            --eval_steps {str(eval_steps)}
             --save_steps {str(eval_steps)}
             --group_by_length
             --label_smoothing_factor 0.1
@@ -214,6 +270,30 @@ class TestTrainerExt(TestCasePlus):
             --target_lang ro_RO
             --source_lang en_XX
         """
+
+        args_eval = f"""
+            --do_eval
+            --per_device_eval_batch_size 4
+            --max_eval_samples 8
+            --val_max_target_length {max_len}
+            --evaluation_strategy steps
+            --eval_steps {str(eval_steps)}
+        """
+
+        args_predict = """
+            --do_predict
+        """
+
+        args = ""
+        if do_train:
+            args += args_train
+
+        if do_eval:
+            args += args_eval
+
+        if do_predict:
+            args += args_predict
+
         if predict_with_generate:
             args += "--predict_with_generate"

@@ -665,23 +665,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(np.all(seen[expected.shape[0] :] == -100))

     def test_log_level(self):
-        # testing only --log_level (--log_level_replica requires multiple nodes)
+        # testing only --log_level (--log_level_replica requires multiple gpus and DDP and is tested elsewhere)
         logger = logging.get_logger()
         log_info_string = "Running training"

-        # test with the default log level - should be info and thus log
+        # test with the default log_level - should be info and thus log on the main process
         with CaptureLogger(logger) as cl:
             trainer = get_regression_trainer()
             trainer.train()
         self.assertIn(log_info_string, cl.out)

-        # test with low log level - lower than info
+        # test with low log_level - lower than info
         with CaptureLogger(logger) as cl:
             trainer = get_regression_trainer(log_level="debug")
             trainer.train()
         self.assertIn(log_info_string, cl.out)

-        # test with high log level - should be quiet
+        # test with high log_level - should be quiet
         with CaptureLogger(logger) as cl:
             trainer = get_regression_trainer(log_level="error")
             trainer.train()
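For completeness, a minimal sketch of the CaptureLogger pattern test_log_level uses (helper from transformers.testing_utils; the verbosity call and message here are illustrative, not taken from the test):

from transformers import logging
from transformers.testing_utils import CaptureLogger

logging.set_verbosity_info()
logger = logging.get_logger()

with CaptureLogger(logger) as cl:
    logger.info("Running training")  # stand-in for the trainer's own info log
assert "Running training" in cl.out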