[tests] make test_trainer_log_level_replica run on accelerators with more than 2 devices (#29609)

Add a new n_gpus_to_use argument and pin it to 2 for the distributed log-level test, so its stderr assertions still hold on hosts exposing more than 2 devices.
Fanli Lin 2024-03-14 01:44:35 +08:00 committed by GitHub
parent 3b6e95ec7f
commit a7e5e15472


@@ -62,6 +62,7 @@ class TestTrainerExt(TestCasePlus):
         do_train=True,
         do_eval=True,
         do_predict=True,
+        n_gpus_to_use=None,
     ):
         output_dir = self.run_trainer(
             eval_steps=1,
@@ -74,6 +75,7 @@ class TestTrainerExt(TestCasePlus):
             do_train=do_train,
             do_eval=do_eval,
             do_predict=do_predict,
+            n_gpus_to_use=n_gpus_to_use,
         )
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
@@ -138,7 +140,13 @@ class TestTrainerExt(TestCasePlus):
         }
         data = experiments[experiment_id]
-        kwargs = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
+        kwargs = {
+            "distributed": True,
+            "predict_with_generate": False,
+            "do_eval": False,
+            "do_predict": False,
+            "n_gpus_to_use": 2,
+        }
         log_info_string = "Running training"
         with CaptureStderr() as cl:
             self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
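
For context on why pinning the process count matters: the test asserts on replica-rank log output captured from stderr, and the volume of that output scales with the number of launched processes. Below is a minimal sketch of the pattern the forwarded n_gpus_to_use value presumably feeds into, where None falls back to every visible device; the helper names (get_device_count, build_launch_cmd) are illustrative, not the repository's actual API.

    # Hypothetical sketch: pinning the world size of a distributed test launch.
    # Names are illustrative; only the n_gpus_to_use semantics mirror the diff.
    from typing import List, Optional

    import torch


    def get_device_count() -> int:
        """Count visible accelerator devices (CUDA shown for simplicity)."""
        return torch.cuda.device_count() if torch.cuda.is_available() else 0


    def build_launch_cmd(script: str, n_gpus_to_use: Optional[int] = None) -> List[str]:
        """Build a torchrun command; None means 'use every visible device'."""
        if n_gpus_to_use is None:
            n_gpus_to_use = get_device_count()
        return ["torchrun", f"--nproc_per_node={n_gpus_to_use}", script]


    # With n_gpus_to_use=2 the command is identical on 2-, 4-, or 8-device
    # hosts, so log-count assertions see a fixed number of replica processes.
    print(build_launch_cmd("run_translation.py", n_gpus_to_use=2))

With a default of None, other callers keep their old behavior of using all available devices, while this one test fixes the world size at 2, which is presumably why only the distributed experiment's kwargs pass "n_gpus_to_use": 2.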