Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-17 11:38:21 +06:00)
enable generation fsdp/utils cases on XPU (#38009)
* enable generation fsdp/utils test cases on XPU
* fix style
* xx
* use backend_xx APIs
* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
Parent: 87e971e14d
Commit: 8f08318769
Changes to the device-agnostic backend mappings in the testing utilities:

@@ -2986,6 +2986,11 @@ if is_torch_available():
         "cpu": 0,
         "default": 0,
     }
+    BACKEND_TORCH_ACCELERATOR_MODULE = {
+        "cuda": torch.cuda,
+        "cpu": None,
+        "default": None,
+    }
 else:
     BACKEND_MANUAL_SEED = {"default": None}
     BACKEND_EMPTY_CACHE = {"default": None}
@@ -2993,21 +2998,25 @@ else:
     BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
     BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
     BACKEND_MEMORY_ALLOCATED = {"default": 0}
+    BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}


 if is_torch_hpu_available():
     BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
     BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count
+    BACKEND_TORCH_ACCELERATOR_MODULE["hpu"] = torch.hpu

 if is_torch_mlu_available():
     BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache
     BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed
     BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count
+    BACKEND_TORCH_ACCELERATOR_MODULE["mlu"] = torch.mlu

 if is_torch_npu_available():
     BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
     BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
     BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count
+    BACKEND_TORCH_ACCELERATOR_MODULE["npu"] = torch.npu

 if is_torch_xpu_available():
     BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
@@ -3016,6 +3025,8 @@ if is_torch_xpu_available():
     BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
     BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
     BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
+    BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu

 if is_torch_xla_available():
     BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
@@ -3047,6 +3058,10 @@ def backend_memory_allocated(device: str):
     return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)


+def backend_torch_accelerator_module(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
+
+
 if is_torch_available():
     # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
     # into device to function mappings.
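The new BACKEND_TORCH_ACCELERATOR_MODULE table plugs into the same dispatch mechanism as the existing BACKEND_* mappings: backend_torch_accelerator_module(device) returns the torch submodule for whichever accelerator the suite is running on (torch.cuda, torch.xpu, torch.npu, ...), or None on CPU. A minimal usage sketch, assuming a transformers build that already contains this change; the calls in the comments are examples of what the returned module exposes, not an exhaustive list:

from transformers.testing_utils import backend_torch_accelerator_module, torch_device

accelerator = backend_torch_accelerator_module(torch_device)

if accelerator is not None:  # None when torch_device is "cpu"
    accelerator.manual_seed(0)         # e.g. torch.cuda.manual_seed or torch.xpu.manual_seed
    print(accelerator.device_count())  # device-agnostic replacement for torch.cuda.device_count()
    accelerator.empty_cache()          # ... and for torch.cuda.empty_cache()

This helper is what lets the FSDP tests below drop their hard-coded torch.mlu/torch.cuda branches.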
Changes to the utils import list:

@@ -140,6 +140,7 @@ from .import_utils import (
     is_bitsandbytes_available,
     is_bitsandbytes_multi_backend_available,
     is_bs4_available,
+    is_ccl_available,
     is_coloredlogs_available,
     is_compressed_tensors_available,
     is_cv2_available,
Changes to the FSDP generation tests (test_fsdp.py):

@@ -15,19 +15,29 @@
 import argparse
 from typing import Any, Callable

-from transformers import is_torch_available, is_torch_mlu_available
+from transformers import is_torch_available, is_torch_xpu_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
+    backend_torch_accelerator_module,
     execute_subprocess_async,
     get_torch_dist_unique_port,
     require_torch_multi_accelerator,
+    torch_device,
 )
+from transformers.utils import is_ccl_available, is_ipex_available


 if is_torch_available():
     import functools

     import torch

+    if is_torch_xpu_available():
+        if is_ipex_available():
+            import intel_extension_for_pytorch  # noqa: F401
+        if is_ccl_available():
+            import oneccl_bindings_for_pytorch  # noqa: F401
     import torch.distributed
     from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
     from torch.distributed.device_mesh import init_device_mesh
@@ -46,10 +56,7 @@ if is_torch_available():
         """Manage the creation and destruction of the distributed process group for the wrapped function."""

         def wrapped(*args: Any, **kwargs: Any) -> Any:
-            if is_torch_mlu_available():
-                device_count = torch.mlu.device_count()
-            else:
-                device_count = torch.cuda.device_count()
+            device_count = backend_device_count(torch_device)
             torch.distributed.init_process_group(world_size=device_count)
             try:
                 return func(*args, **kwargs)
@@ -60,10 +67,8 @@ if is_torch_available():

     @manage_process_group
     def fsdp_generate():
-        if is_torch_mlu_available():
-            torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
-        else:
-            torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        torch_accelerator_module = backend_torch_accelerator_module(torch_device)
+        torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))

         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

@@ -86,10 +91,8 @@ if is_torch_available():

     @manage_process_group
     def fsdp2_generate():
-        if is_torch_mlu_available():
-            torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
-        else:
-            torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        torch_accelerator_module = backend_torch_accelerator_module(torch_device)
+        torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))

         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

@@ -114,10 +117,7 @@ if is_torch_available():
 class TestFSDPGeneration(TestCasePlus):
     @require_torch_multi_accelerator
     def test_fsdp_generate(self):
-        if is_torch_mlu_available():
-            device_count = torch.mlu.device_count()
-        else:
-            device_count = torch.cuda.device_count()
+        device_count = backend_device_count(torch_device)
         distributed_args = f"""--nproc_per_node={device_count}
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/test_fsdp.py
@@ -129,10 +129,8 @@ class TestFSDPGeneration(TestCasePlus):

     @require_torch_multi_accelerator
     def test_fsdp2_generate(self):
-        if is_torch_mlu_available():
-            device_count = torch.mlu.device_count()
-        else:
-            device_count = torch.cuda.device_count()
+        device_count = backend_device_count(torch_device)
         distributed_args = f"""--nproc_per_node={device_count}
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/test_fsdp.py
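In the test file itself, the per-backend branches (is_torch_mlu_available() choosing between torch.mlu and torch.cuda) collapse into a single code path built on the helpers above, which is what enables the XPU runs. A condensed sketch of the resulting pattern, assuming the script is launched with torchrun so the process-group environment variables are already set; the function name run_distributed_step is illustrative, not part of the commit:

import torch
import torch.distributed
from transformers.testing_utils import (
    backend_device_count,
    backend_torch_accelerator_module,
    torch_device,
)


def run_distributed_step():
    # One process per visible accelerator, whatever the backend (CUDA, XPU, MLU, NPU, ...).
    device_count = backend_device_count(torch_device)
    torch.distributed.init_process_group(world_size=device_count)
    # Bind this rank to its device; assumes torch_device names an accelerator, not "cpu".
    torch_accelerator_module = backend_torch_accelerator_module(torch_device)
    torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))
    return device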
Changes to the generation integration tests:

@@ -41,7 +41,6 @@ from transformers.testing_utils import (
     require_torch_gpu,
     require_torch_greater_or_equal,
     require_torch_multi_accelerator,
-    require_torch_multi_gpu,
     require_torch_sdpa,
     set_config_for_less_flaky_test,
     set_model_for_less_flaky_test,
@@ -2954,7 +2953,7 @@ class GenerationIntegrationTests(unittest.TestCase):
     def test_stop_sequence_stopping_criteria(self):
         prompt = """Hello I believe in"""
         generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
-        output = generator(prompt)
+        output = generator(prompt, max_new_tokens=10)
         self.assertEqual(
             output,
             [{"generated_text": ("Hello I believe in we we we we we we we we we")}],
@@ -3860,7 +3859,7 @@ class GenerationIntegrationTests(unittest.TestCase):

     @slow
     @require_torch_multi_accelerator
-    def test_assisted_decoding_in_different_gpu(self):
+    def test_assisted_decoding_in_different_accelerator(self):
         device_0 = f"{torch_device}:0" if torch_device != "cpu" else "cpu"
         device_1 = f"{torch_device}:1" if torch_device != "cpu" else "cpu"
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(device_0)
@@ -3885,7 +3884,7 @@ class GenerationIntegrationTests(unittest.TestCase):

     @slow
     @require_torch_accelerator
-    def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
+    def test_assisted_decoding_model_in_accelerator_assistant_in_cpu(self):
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
             torch_device
         )
@@ -3970,10 +3969,10 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertTrue((expected_out == predicted_out).all().item())

     @pytest.mark.generate
-    @require_torch_multi_gpu
-    def test_generate_with_static_cache_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_generate_with_static_cache_multi_accelerator(self):
         """
-        Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.
+        Tests if the static cache has been set correctly and if generate works correctly when we are using multi-acceleratorss.
         """
         # need to split manually as auto doesn't work well with unbalanced model
         device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
@@ -4005,10 +4004,10 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))

     @pytest.mark.generate
-    @require_torch_multi_gpu
-    def test_generate_multi_gpu_causal_mask(self):
+    @require_torch_multi_accelerator
+    def test_generate_multi_accelerator_causal_mask(self):
         """
-        Tests that cache position device doesn't clash with causal mask device when we are using multi-gpus.
+        Tests that cache position device doesn't clash with causal mask device when we are using multi-accelerators.
         In real life happens only when multimodal encoder size is big, so `embed_tokens` gets allocated to the next device.
         The error will be triggered whenever a bacthed input is used, so that `causal_mask` is actually prepared instead of
         being `None`.
@@ -4033,10 +4032,10 @@ class GenerationIntegrationTests(unittest.TestCase):
         _ = model.generate(**inputs, max_new_tokens=20)

     @pytest.mark.generate
-    @require_torch_multi_gpu
-    def test_init_static_cache_multi_gpu(self):
+    @require_torch_multi_accelerator
+    def test_init_static_cache_multi_accelerator(self):
         """
-        Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.
+        Tests if the static cache has been set correctly when we initialize it manually in a multi-accelerator setup.
         """
         # need to split manually as auto doesn't work well with unbalanced model
         device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
@@ -4870,7 +4869,7 @@ class GenerationIntegrationTests(unittest.TestCase):

     @require_read_token
     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_cache_device_map_with_vision_layer_device_map(self):
         """
         Test that the cache device map is correctly set when the vision layer has a device map. Regression test for
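The generation tests get the same treatment at the decorator and naming level: @require_torch_gpu and @require_torch_multi_gpu become @require_torch_accelerator and @require_torch_multi_accelerator, test names say "accelerator" instead of "gpu", and devices are addressed as f"{torch_device}:0" rather than a hard-coded "cuda:0", so the same tests can run on XPU. A small illustrative skeleton of that pattern (the class, test name, and assertions are examples, not part of the commit; the tiny model checkpoint is the one used in the diff):

import unittest

import torch
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_multi_accelerator, torch_device


class MultiAcceleratorExample(unittest.TestCase):
    @require_torch_multi_accelerator  # skipped unless at least two accelerators (CUDA, XPU, NPU, ...) are visible
    def test_two_devices(self):
        # Build backend-agnostic device strings instead of hard-coding "cuda:0"/"cuda:1".
        device_0 = f"{torch_device}:0" if torch_device != "cpu" else "cpu"
        device_1 = f"{torch_device}:1" if torch_device != "cpu" else "cpu"
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(device_0)
        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(device_1)
        self.assertEqual(model.device, torch.device(device_0))
        self.assertEqual(assistant.device, torch.device(device_1))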