[tests] make test_trainer_log_level_replica run on accelerators with more than 2 devices (#29609)

Add a new n_gpus_to_use arg so the distributed test pins the run to exactly 2 devices instead of using every available accelerator.
Fanli Lin 2024-03-14 01:44:35 +08:00 committed by GitHub
parent 3b6e95ec7f
commit a7e5e15472


@@ -62,6 +62,7 @@ class TestTrainerExt(TestCasePlus):
         do_train=True,
         do_eval=True,
         do_predict=True,
+        n_gpus_to_use=None,
     ):
         output_dir = self.run_trainer(
             eval_steps=1,
@@ -74,6 +75,7 @@ class TestTrainerExt(TestCasePlus):
             do_train=do_train,
             do_eval=do_eval,
             do_predict=do_predict,
+            n_gpus_to_use=n_gpus_to_use,
         )
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
@@ -138,7 +140,13 @@ class TestTrainerExt(TestCasePlus):
         }
         data = experiments[experiment_id]
-        kwargs = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
+        kwargs = {
+            "distributed": True,
+            "predict_with_generate": False,
+            "do_eval": False,
+            "do_predict": False,
+            "n_gpus_to_use": 2,
+        }
         log_info_string = "Running training"
         with CaptureStderr() as cl:
             self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
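For context: the new n_gpus_to_use argument is threaded from run_seq2seq_quick into run_trainer, and pinning it to 2 keeps the replica log-level assertions valid on hosts that expose more than 2 accelerators. The snippet below is a minimal, hypothetical sketch of how such an argument can cap the process count handed to torch.distributed.run; the launch_distributed_training helper and its parameters are assumptions for illustration, not the repository's actual run_trainer code.

# Hypothetical sketch (not the repository's run_trainer): cap the number of
# processes launched by torch.distributed.run via an n_gpus_to_use argument.
import subprocess
import sys

import torch


def launch_distributed_training(script, script_args, n_gpus_to_use=None):
    # Default to every visible device; the test above passes 2 to pin the count.
    n_gpus_to_use = n_gpus_to_use or torch.cuda.device_count()
    cmd = [
        sys.executable,
        "-m",
        "torch.distributed.run",
        f"--nproc_per_node={n_gpus_to_use}",
        script,
        *script_args,
    ]
    subprocess.run(cmd, check=True)

With this shape, passing n_gpus_to_use=2 launches exactly two ranks regardless of how many devices torch.cuda.device_count() reports, which mirrors what the test now requests via "n_gpus_to_use": 2.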