device agnostic fsdp testing (#27120)

* make fsdp test cases device agnostic

* make style
Hz, Ji authored 2023-11-01 14:17:06 +08:00, committed by GitHub
commit 82c7e87987 (parent 7d8ff3629b)

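At a glance, the diff below swaps the CUDA-only test helpers for backend-agnostic counterparts: get_gpu_count → backend_device_count, require_torch_gpu → require_torch_accelerator, require_torch_multi_gpu → require_torch_multi_accelerator, and is_torch_bf16_gpu_available → is_torch_bf16_available_on_device. A minimal sketch of the resulting pattern, using only names that appear in the diff (the snippet itself is illustrative and not part of the commit):

from transformers.testing_utils import backend_device_count, torch_device
from transformers.utils import is_torch_bf16_available_on_device

# device-agnostic replacements for get_gpu_count() and is_torch_bf16_gpu_available()
num_devices = backend_device_count(torch_device)
dtypes = ["fp16"]
if is_torch_bf16_available_on_device(torch_device):
    dtypes += ["bf16"]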

@@ -24,18 +24,19 @@ from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
+backend_device_count,
execute_subprocess_async,
-get_gpu_count,
mockenv_context,
require_accelerate,
require_fsdp,
-require_torch_gpu,
-require_torch_multi_gpu,
+require_torch_accelerator,
+require_torch_multi_accelerator,
slow,
+torch_device,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import FSDPOption, set_seed
-from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
+from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device
if is_torch_available():
@@ -46,7 +47,7 @@ else:
# default torch.distributed port
DEFAULT_MASTER_PORT = "10999"
dtypes = ["fp16"]
-if is_torch_bf16_gpu_available():
+if is_torch_bf16_available_on_device(torch_device):
dtypes += ["bf16"]
sharding_strategies = ["full_shard", "shard_grad_op"]
state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
@@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
# - it won't be able to handle that
# 2. for now testing with just 2 gpus max (since some quality tests may give different
# results with more gpus because we use very little data)
-num_gpus = min(2, get_gpu_count()) if distributed else 1
+num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
master_port = get_master_port(real_launcher=True)
if use_accelerate:
return f"""accelerate launch
@@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):
@require_accelerate
-@require_torch_gpu
+@require_torch_accelerator
@require_fsdp_version
class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
def setUp(self):
@@ -170,7 +171,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
@slow
def test_basic_run(self, sharding_strategy, dtype):
launcher = get_launcher(distributed=True, use_accelerate=False)
@@ -182,7 +183,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
execute_subprocess_async(cmd, env=self.get_env())
@parameterized.expand(dtypes)
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
@slow
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
def test_basic_run_with_cpu_offload(self, dtype):
@@ -195,7 +196,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
execute_subprocess_async(cmd, env=self.get_env())
@parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
@slow
def test_training_and_can_resume_normally(self, state_dict_type):
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
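
The same pattern makes ad-hoc checks portable across backends. A usage sketch (not part of this commit), assuming a transformers version that ships the helpers referenced above:

from transformers.testing_utils import backend_device_count, torch_device

if backend_device_count(torch_device) < 2:
    # in this case the multi-accelerator FSDP tests above would be skipped
    print(f"only one {torch_device} device available")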