[tests] make test_trainer_log_level_replica run on accelerators with more than 2 devices (#29609)

add new arg
This commit is contained in:
parent 3b6e95ec7f
commit a7e5e15472
@@ -62,6 +62,7 @@ class TestTrainerExt(TestCasePlus):
         do_train=True,
         do_eval=True,
         do_predict=True,
+        n_gpus_to_use=None,
     ):
         output_dir = self.run_trainer(
             eval_steps=1,
@@ -74,6 +75,7 @@ class TestTrainerExt(TestCasePlus):
             do_train=do_train,
             do_eval=do_eval,
             do_predict=do_predict,
+            n_gpus_to_use=n_gpus_to_use,
         )
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
 
@@ -138,7 +140,13 @@ class TestTrainerExt(TestCasePlus):
         }
 
         data = experiments[experiment_id]
-        kwargs = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
+        kwargs = {
+            "distributed": True,
+            "predict_with_generate": False,
+            "do_eval": False,
+            "do_predict": False,
+            "n_gpus_to_use": 2,
+        }
         log_info_string = "Running training"
         with CaptureStderr() as cl:
             self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
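
The hunks above only thread the new n_gpus_to_use argument through run_seq2seq_quick and into run_trainer; the body of run_trainer is not part of this diff. As a rough, hypothetical sketch (assuming the helper builds a torch.distributed.run command, which this commit does not show), an argument like this typically caps how many worker processes the launcher spawns:

# Hypothetical helper, not part of this commit: illustrates how an
# n_gpus_to_use argument is commonly consumed when building a distributed
# launch command.
import sys

import torch


def build_launcher_cmd(script, n_gpus_to_use=None):
    # Fall back to every visible device when the caller does not pin a count.
    n_procs = n_gpus_to_use if n_gpus_to_use is not None else torch.cuda.device_count()
    # With "n_gpus_to_use": 2 in the test's kwargs, the launcher always spawns
    # exactly two workers, even on hosts exposing more accelerators.
    return [
        sys.executable,
        "-m",
        "torch.distributed.run",
        f"--nproc_per_node={n_procs}",
        script,
    ]

Pinning the count to 2 is what keeps the replica log-level assertions deterministic regardless of how many devices the CI machine exposes.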