mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[s2s trainer] fix DP mode (#8823)
* fix DP case on multi-gpu * make executable * test all 3 modes * use the correct check for distributed * dp doesn't need a special case * restore original name * cleanup
This commit is contained in:
parent
d8fc26e919
commit
7f34d75780
2
examples/seq2seq/finetune_trainer.py
Normal file → Executable file
2
examples/seq2seq/finetune_trainer.py
Normal file → Executable file
@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
@ -122,7 +122,8 @@ class Seq2SeqTrainer(Trainer):
|
||||
else:
|
||||
if self.args.sortish_sampler:
|
||||
self.train_dataset.make_sortish_sampler(
|
||||
self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
|
||||
self.args.per_device_train_batch_size,
|
||||
distributed=(self.args.local_rank != -1),
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -4,7 +4,14 @@ from unittest.mock import patch
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel
|
||||
from transformers.file_utils import is_datasets_available
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
@ -18,17 +25,32 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
|
||||
|
||||
class TestFinetuneTrainer(TestCasePlus):
|
||||
def test_finetune_trainer(self):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
|
||||
def finetune_trainer_quick(self, distributed=None):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
assert "eval_bleu" in first_step_stats
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
def test_finetune_trainer_no_dist(self):
|
||||
self.finetune_trainer_quick()
|
||||
|
||||
# the following 2 tests verify that the trainer can handle distributed and non-distributed with n_gpu > 1
|
||||
@require_torch_multi_gpu
|
||||
def test_finetune_trainer_dp(self):
|
||||
self.finetune_trainer_quick(distributed=False)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_finetune_trainer_ddp(self):
|
||||
self.finetune_trainer_quick(distributed=True)
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer_slow(self):
|
||||
# There is a missing call to __init__process_group somewhere
|
||||
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)
|
||||
output_dir = self.run_trainer(
|
||||
eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
|
||||
)
|
||||
|
||||
# Check metrics
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
@ -158,7 +180,9 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
# start training
|
||||
trainer.train()
|
||||
|
||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
def run_trainer(
|
||||
self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
|
||||
):
|
||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
@ -193,8 +217,8 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
""".split()
|
||||
# --eval_beams 2
|
||||
|
||||
n_gpu = get_gpu_count()
|
||||
if n_gpu > 1:
|
||||
if distributed:
|
||||
n_gpu = get_gpu_count()
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={n_gpu}
|
||||
@ -203,7 +227,6 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
else:
|
||||
# 0 or 1 gpu
|
||||
testargs = ["finetune_trainer.py"] + args
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user