From 8f08318769c15fdb6b64418cbb070e8a8b405ffb Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Sat, 10 May 2025 04:52:41 +0800 Subject: [PATCH] enable generation fsdp/utils cases on XPU (#38009) * enable generation fsdp/utils test cases on XPU Signed-off-by: Yao Matrix * fix style Signed-off-by: Yao Matrix * xx Signed-off-by: Yao Matrix * use backend_xx APIs Signed-off-by: Yao Matrix * fix style Signed-off-by: Yao Matrix --------- Signed-off-by: Yao Matrix --- src/transformers/testing_utils.py | 15 +++++++++++ src/transformers/utils/__init__.py | 1 + tests/generation/test_fsdp.py | 40 ++++++++++++++---------------- tests/generation/test_utils.py | 27 ++++++++++---------- 4 files changed, 48 insertions(+), 35 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 93f9be67822..b3734d97887 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -2986,6 +2986,11 @@ if is_torch_available(): "cpu": 0, "default": 0, } + BACKEND_TORCH_ACCELERATOR_MODULE = { + "cuda": torch.cuda, + "cpu": None, + "default": None, + } else: BACKEND_MANUAL_SEED = {"default": None} BACKEND_EMPTY_CACHE = {"default": None} @@ -2993,21 +2998,25 @@ else: BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None} BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0} BACKEND_MEMORY_ALLOCATED = {"default": 0} + BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None} if is_torch_hpu_available(): BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count + BACKEND_TORCH_ACCELERATOR_MODULE["hpu"] = torch.hpu if is_torch_mlu_available(): BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count + BACKEND_TORCH_ACCELERATOR_MODULE["mlu"] = torch.mlu if is_torch_npu_available(): BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count + BACKEND_TORCH_ACCELERATOR_MODULE["npu"] = torch.npu if is_torch_xpu_available(): BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache @@ -3016,6 +3025,8 @@ if is_torch_xpu_available(): BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated + BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu + if is_torch_xla_available(): BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache @@ -3047,6 +3058,10 @@ def backend_memory_allocated(device: str): return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED) +def backend_torch_accelerator_module(device: str): + return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE) + + if is_torch_available(): # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries # into device to function mappings. 
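Note on the new helper above: `backend_torch_accelerator_module(device)` dispatches to the backend-specific torch module (`torch.cuda`, `torch.xpu`, `torch.npu`, `torch.mlu`, `torch.hpu`) for the active test device, and returns `None` on CPU-only setups. Below is a minimal sketch (not part of the patch) of how a device-agnostic test could use it together with the existing `backend_device_count` helper; the device index and seed value are purely illustrative.

    from transformers.testing_utils import (
        backend_device_count,
        backend_torch_accelerator_module,
        torch_device,
    )

    # Resolve the accelerator module for whichever backend the test device uses.
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    if torch_accelerator_module is not None:  # None when torch_device is "cpu"
        # Select the first accelerator of that backend without hard-coding torch.cuda / torch.mlu.
        torch_accelerator_module.set_device(0)
        # Seed the backend in a device-agnostic way.
        torch_accelerator_module.manual_seed(0)

    # Device count for the active backend, e.g. to size a distributed world.
    world_size = backend_device_count(torch_device)

This is the pattern the test_fsdp.py changes below follow, replacing the per-backend `if is_torch_mlu_available(): ... else: torch.cuda...` branches with a single dispatch call.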
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index bb2f7e890fb..90ccef3137d 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -140,6 +140,7 @@ from .import_utils import ( is_bitsandbytes_available, is_bitsandbytes_multi_backend_available, is_bs4_available, + is_ccl_available, is_coloredlogs_available, is_compressed_tensors_available, is_cv2_available, diff --git a/tests/generation/test_fsdp.py b/tests/generation/test_fsdp.py index 2f4c77078f8..9ecb4315731 100644 --- a/tests/generation/test_fsdp.py +++ b/tests/generation/test_fsdp.py @@ -15,19 +15,29 @@ import argparse from typing import Any, Callable -from transformers import is_torch_available, is_torch_mlu_available +from transformers import is_torch_available, is_torch_xpu_available from transformers.testing_utils import ( TestCasePlus, + backend_device_count, + backend_torch_accelerator_module, execute_subprocess_async, get_torch_dist_unique_port, require_torch_multi_accelerator, + torch_device, ) +from transformers.utils import is_ccl_available, is_ipex_available if is_torch_available(): import functools import torch + + if is_torch_xpu_available(): + if is_ipex_available(): + import intel_extension_for_pytorch # noqa: F401 + if is_ccl_available(): + import oneccl_bindings_for_pytorch # noqa: F401 import torch.distributed from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method from torch.distributed.device_mesh import init_device_mesh @@ -46,10 +56,7 @@ if is_torch_available(): """Manage the creation and destruction of the distributed process group for the wrapped function.""" def wrapped(*args: Any, **kwargs: Any) -> Any: - if is_torch_mlu_available(): - device_count = torch.mlu.device_count() - else: - device_count = torch.cuda.device_count() + device_count = backend_device_count(torch_device) torch.distributed.init_process_group(world_size=device_count) try: return func(*args, **kwargs) @@ -60,10 +67,8 @@ if is_torch_available(): @manage_process_group def fsdp_generate(): - if is_torch_mlu_available(): - torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank())) - else: - torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank())) + torch_accelerator_module = backend_torch_accelerator_module(torch_device) + torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank())) model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device) @@ -86,10 +91,8 @@ if is_torch_available(): @manage_process_group def fsdp2_generate(): - if is_torch_mlu_available(): - torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank())) - else: - torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank())) + torch_accelerator_module = backend_torch_accelerator_module(torch_device) + torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank())) model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device) @@ -114,10 +117,7 @@ if is_torch_available(): class TestFSDPGeneration(TestCasePlus): @require_torch_multi_accelerator def test_fsdp_generate(self): - if is_torch_mlu_available(): - device_count = torch.mlu.device_count() - else: - device_count = torch.cuda.device_count() + device_count = backend_device_count(torch_device) distributed_args = f"""--nproc_per_node={device_count} 
--master_port={get_torch_dist_unique_port()} {self.test_file_dir}/test_fsdp.py @@ -129,10 +129,8 @@ class TestFSDPGeneration(TestCasePlus): @require_torch_multi_accelerator def test_fsdp2_generate(self): - if is_torch_mlu_available(): - device_count = torch.mlu.device_count() - else: - device_count = torch.cuda.device_count() + device_count = backend_device_count(torch_device) + distributed_args = f"""--nproc_per_node={device_count} --master_port={get_torch_dist_unique_port()} {self.test_file_dir}/test_fsdp.py diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 20edc1c8973..3807c84dade 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -41,7 +41,6 @@ from transformers.testing_utils import ( require_torch_gpu, require_torch_greater_or_equal, require_torch_multi_accelerator, - require_torch_multi_gpu, require_torch_sdpa, set_config_for_less_flaky_test, set_model_for_less_flaky_test, @@ -2954,7 +2953,7 @@ class GenerationIntegrationTests(unittest.TestCase): def test_stop_sequence_stopping_criteria(self): prompt = """Hello I believe in""" generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart") - output = generator(prompt) + output = generator(prompt, max_new_tokens=10) self.assertEqual( output, [{"generated_text": ("Hello I believe in we we we we we we we we we")}], @@ -3860,7 +3859,7 @@ class GenerationIntegrationTests(unittest.TestCase): @slow @require_torch_multi_accelerator - def test_assisted_decoding_in_different_gpu(self): + def test_assisted_decoding_in_different_accelerator(self): device_0 = f"{torch_device}:0" if torch_device != "cpu" else "cpu" device_1 = f"{torch_device}:1" if torch_device != "cpu" else "cpu" model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(device_0) @@ -3885,7 +3884,7 @@ class GenerationIntegrationTests(unittest.TestCase): @slow @require_torch_accelerator - def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self): + def test_assisted_decoding_model_in_accelerator_assistant_in_cpu(self): model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( torch_device ) @@ -3970,10 +3969,10 @@ class GenerationIntegrationTests(unittest.TestCase): self.assertTrue((expected_out == predicted_out).all().item()) @pytest.mark.generate - @require_torch_multi_gpu - def test_generate_with_static_cache_multi_gpu(self): + @require_torch_multi_accelerator + def test_generate_with_static_cache_multi_accelerator(self): """ - Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus. + Tests if the static cache has been set correctly and if generate works correctly when we are using multi-accelerators. """ # need to split manually as auto doesn't work well with unbalanced model device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} @@ -4005,10 +4004,10 @@ class GenerationIntegrationTests(unittest.TestCase): self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1)) @pytest.mark.generate - @require_torch_multi_gpu - def test_generate_multi_gpu_causal_mask(self): + @require_torch_multi_accelerator + def test_generate_multi_accelerator_causal_mask(self): """ - Tests that cache position device doesn't clash with causal mask device when we are using multi-gpus. + Tests that cache position device doesn't clash with causal mask device when we are using multi-accelerators. 
In real life this happens only when the multimodal encoder is big, so `embed_tokens` gets allocated to the next device. The error will be triggered whenever a batched input is used, so that `causal_mask` is actually prepared instead of being `None`. @@ -4033,10 +4032,10 @@ class GenerationIntegrationTests(unittest.TestCase): _ = model.generate(**inputs, max_new_tokens=20) @pytest.mark.generate - @require_torch_multi_gpu - def test_init_static_cache_multi_gpu(self): + @require_torch_multi_accelerator + def test_init_static_cache_multi_accelerator(self): """ - Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup. + Tests if the static cache has been set correctly when we initialize it manually in a multi-accelerator setup. """ # need to split manually as auto doesn't work well with unbalanced model device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0} @@ -4870,7 +4869,7 @@ class GenerationIntegrationTests(unittest.TestCase): @require_read_token @slow - @require_torch_gpu + @require_torch_accelerator def test_cache_device_map_with_vision_layer_device_map(self): """ Test that the cache device map is correctly set when the vision layer has a device map. Regression test for