Mirror of https://github.com/huggingface/transformers.git
add tests for the new sharded ddp fairscale integration (#9177)
This commit is contained in:
parent bf713cdec7
commit 63841c559b
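The new tests exercise the `--sharded_ddp` training flag, which hands data parallelism over to fairscale. As rough orientation only, the wrapping pattern these tests exercise can be sketched as below, assuming fairscale's `OSS` optimizer wrapper and `ShardedDataParallel`; this is not the Trainer's actual integration code.

# Illustrative sketch only, not the Trainer's actual code.
import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS


def wrap_for_sharded_ddp(model: torch.nn.Module, lr: float = 3e-5):
    # Assumes torch.distributed is already initialized (e.g. via
    # torch.distributed.launch, as the tests below do).
    # OSS shards optimizer state across the ranks of the process group.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=lr)
    # ShardedDDP reduces each gradient only to the rank that owns its shard.
    model = ShardedDDP(model, optimizer)
    return model, optimizer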
@@ -14,10 +14,12 @@
 import os
 import sys
+import unittest
 from unittest.mock import patch
 
 from transformers import BertTokenizer, EncoderDecoderModel
 from transformers.file_utils import is_datasets_available
+from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     TestCasePlus,
     execute_subprocess_async,
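For reference, `is_fairscale_available` is the availability helper imported above from `transformers.integrations`. A minimal check of this kind (illustrative, not the library's exact implementation) looks like:

# Illustrative sketch of an availability check, not the exact code in
# transformers.integrations.
import importlib.util


def is_fairscale_available() -> bool:
    # True when the fairscale package can be imported in the current environment.
    return importlib.util.find_spec("fairscale") is not None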
@@ -38,9 +40,20 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 MBART_TINY = "sshleifer/tiny-mbart"
 
 
+# a candidate for testing_utils
+def require_fairscale(test_case):
+    """
+    Decorator marking a test that requires fairscale
+    """
+    if not is_fairscale_available():
+        return unittest.skip("test requires fairscale")(test_case)
+    else:
+        return test_case
+
+
 class TestFinetuneTrainer(TestCasePlus):
-    def finetune_trainer_quick(self, distributed=None):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
+    def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
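`require_fairscale` follows the usual `require_*` pattern from `transformers.testing_utils`: `unittest.skip(reason)` returns a decorator, which is applied to the test on the spot when the dependency is missing. A standalone illustration of the pattern, using a hypothetical dependency name:

# Standalone illustration of the skip-decorator pattern; "widget" is hypothetical.
import unittest


def require_widget(test_case):
    # Skip the decorated test unless the (hypothetical) widget package is installed.
    try:
        import widget  # noqa: F401
    except ImportError:
        return unittest.skip("test requires widget")(test_case)
    return test_case


class ExampleTest(unittest.TestCase):
    @require_widget
    def test_uses_widget(self):
        self.assertTrue(True)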
@@ -59,6 +72,16 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp(self):
         self.finetune_trainer_quick(distributed=True)
 
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_finetune_trainer_ddp_sharded_ddp(self):
+        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp")
+
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
+        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
@@ -195,7 +218,13 @@ class TestFinetuneTrainer(TestCasePlus):
         trainer.train()
 
     def run_trainer(
-        self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
+        self,
+        eval_steps: int,
+        max_len: str,
+        model_name: str,
+        num_train_epochs: int,
+        distributed: bool = False,
+        extra_args_str: str = None,
     ):
         data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
@@ -231,6 +260,9 @@ class TestFinetuneTrainer(TestCasePlus):
         """.split()
         # --eval_beams 2
 
+        if extra_args_str is not None:
+            args.extend(extra_args_str.split())
+
         if distributed:
             n_gpu = get_gpu_count()
             distributed_args = f"""
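The hunk is truncated just as the test starts to build `distributed_args`. Roughly, the new `extra_args_str` flows into the launched command as in the sketch below (illustrative only; paths and values are made up):

# Illustrative sketch of how extra_args_str ends up in the launched command.
import sys

n_gpu = 2  # in the test this comes from get_gpu_count()
args = "finetune_trainer.py --output_dir /tmp/out --fp16".split()
extra_args_str = "--sharded_ddp"
if extra_args_str is not None:
    args.extend(extra_args_str.split())
distributed_args = f"-m torch.distributed.launch --nproc_per_node={n_gpu}".split()
cmd = [sys.executable] + distributed_args + args
print(" ".join(cmd))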