mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
device agnostic fsdp testing (#27120)
* make fsdp test cases device agnostic * make style
This commit is contained in:
parent
7d8ff3629b
commit
82c7e87987
@ -24,18 +24,19 @@ from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
mockenv_context,
|
||||
require_accelerate,
|
||||
require_fsdp,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_accelerator,
|
||||
require_torch_multi_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import FSDPOption, set_seed
|
||||
from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
|
||||
from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -46,7 +47,7 @@ else:
|
||||
# default torch.distributed port
|
||||
DEFAULT_MASTER_PORT = "10999"
|
||||
dtypes = ["fp16"]
|
||||
if is_torch_bf16_gpu_available():
|
||||
if is_torch_bf16_available_on_device(torch_device):
|
||||
dtypes += ["bf16"]
|
||||
sharding_strategies = ["full_shard", "shard_grad_op"]
|
||||
state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
|
||||
@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
|
||||
# - it won't be able to handle that
|
||||
# 2. for now testing with just 2 gpus max (since some quality tests may give different
|
||||
# results with mode gpus because we use very little data)
|
||||
num_gpus = min(2, get_gpu_count()) if distributed else 1
|
||||
num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
|
||||
master_port = get_master_port(real_launcher=True)
|
||||
if use_accelerate:
|
||||
return f"""accelerate launch
|
||||
@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):
|
||||
|
||||
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_fsdp_version
|
||||
class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
def setUp(self):
|
||||
@ -170,7 +171,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
|
||||
|
||||
@parameterized.expand(params, name_func=_parameterized_custom_name_func)
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
@slow
|
||||
def test_basic_run(self, sharding_strategy, dtype):
|
||||
launcher = get_launcher(distributed=True, use_accelerate=False)
|
||||
@ -182,7 +183,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
@parameterized.expand(dtypes)
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
@slow
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
|
||||
def test_basic_run_with_cpu_offload(self, dtype):
|
||||
@ -195,7 +196,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
@parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
@slow
|
||||
def test_training_and_can_resume_normally(self, state_dict_type):
|
||||
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
|
||||
|
Loading…
Reference in New Issue
Block a user