[tests] make test_trainer_log_level_replica run on accelerators with more than 2 devices (#29609)

Adds a new n_gpus_to_use arg so the test can cap the number of devices it launches on.
commit a7e5e15472 (parent 3b6e95ec7f)
@@ -62,6 +62,7 @@ class TestTrainerExt(TestCasePlus):
         do_train=True,
         do_eval=True,
         do_predict=True,
+        n_gpus_to_use=None,
     ):
         output_dir = self.run_trainer(
             eval_steps=1,
@@ -74,6 +75,7 @@ class TestTrainerExt(TestCasePlus):
             do_train=do_train,
             do_eval=do_eval,
             do_predict=do_predict,
+            n_gpus_to_use=n_gpus_to_use,
         )
 
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
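
The first two hunks thread the new n_gpus_to_use argument from run_seq2seq_quick down into run_trainer, which assembles the distributed launch command. A minimal sketch of the pattern, assuming a torchrun-style launcher; build_launch_cmd and the torch.cuda.device_count() fallback here are illustrative, not the test suite's actual helpers:

import torch

def build_launch_cmd(script, n_gpus_to_use=None):
    # Cap the process count when a limit is given; otherwise fall back to
    # every visible accelerator, which on an 8-GPU host means 8 processes.
    n_gpus = n_gpus_to_use if n_gpus_to_use is not None else torch.cuda.device_count()
    return ["torchrun", f"--nproc_per_node={n_gpus}", script]

# Pinning the count keeps a 2-process test deterministic on larger machines.
print(build_launch_cmd("run_translation.py", n_gpus_to_use=2))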
@@ -138,7 +140,13 @@ class TestTrainerExt(TestCasePlus):
         }
 
         data = experiments[experiment_id]
-        kwargs = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
+        kwargs = {
+            "distributed": True,
+            "predict_with_generate": False,
+            "do_eval": False,
+            "do_predict": False,
+            "n_gpus_to_use": 2,
+        }
         log_info_string = "Running training"
         with CaptureStderr() as cl:
             self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
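
The third hunk pins the log-level experiments to exactly two processes instead of however many the host exposes. The test compares counts of captured log lines, and with log_level_replica only the main process logs at full verbosity, so each extra replica shifts the totals; a hedged sketch of that arithmetic, with expected_log_lines as an illustrative name, not a function from the test:

def expected_log_lines(world_size, main_lines, replica_lines):
    # Rank 0 logs at log_level; each of the other (world_size - 1) ranks
    # logs at log_level_replica, so the totals grow with world size.
    return main_lines + (world_size - 1) * replica_lines

# With the hard-coded world size of 2 there is exactly one replica, so the
# counts baked into the experiments dict hold on any machine:
assert expected_log_lines(2, main_lines=5, replica_lines=1) == 6
# Without the cap, an 8-accelerator host would inflate the count:
assert expected_log_lines(8, main_lines=5, replica_lines=1) == 12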