[MLU] Fix FA2 check error, remove deepspeed-mlu deps. (#36159)

* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker

* Cambricon: support SDPA and flash_attn

* MLU devices: check if `mlu` is available via a `cndev`-based check which won't trigger the drivers and leave the MLU uninitialized (a sketch of this style of check follows the commit metadata below)

* Fix MLU FA2 check. Remove deepspeed-mlu check. Add MLU test support.

* fix testing errors.

* Merge branch 'hf/main' into main

* fix get_device_count error.

* fix mlu testing utils.

* fix code quality and style.

* switch to @require_torch_multi_accelerator
Author: huismiling · 2025-03-31 17:02:49 +08:00 (committed by GitHub)
Parent: ad63d20dff
Commit: d0b65bb479
4 changed files with 63 additions and 18 deletions
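
The `cndev`-based availability check referenced in the commit message is not part of this diff, but for context, here is a minimal sketch of that style of check. It is an illustration only: it assumes the `torch_mlu` extension registers a `torch.mlu` backend and that a `PYTORCH_CNDEV_BASED_MLU_CHECK` environment variable (analogous to PyTorch's `PYTORCH_NVML_BASED_CUDA_CHECK`) makes `torch.mlu.is_available()` answer via cndev without initializing the driver; the real `is_torch_mlu_available` may differ.

```python
import importlib.util
import os
from functools import lru_cache


@lru_cache
def mlu_is_available() -> bool:
    """Hypothetical cndev-style check: report MLU availability without initializing the device runtime."""
    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_mlu") is None:
        return False

    import torch
    import torch_mlu  # noqa: F401  # assumed to register the `torch.mlu` backend

    previous = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK")
    try:
        # Assumption: this flag asks the backend to query cndev instead of waking the driver.
        os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = "1"
        return torch.mlu.is_available()
    finally:
        if previous is None:
            os.environ.pop("PYTORCH_CNDEV_BASED_MLU_CHECK", None)
        else:
            os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = previous
```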

View File

@@ -22,7 +22,7 @@ import weakref
 from functools import partialmethod
 from ..dependency_versions_check import dep_version_check
-from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_available, logging
+from ..utils import is_accelerate_available, is_torch_available, logging
 if is_torch_available():
@@ -40,9 +40,6 @@ def is_deepspeed_available():
     # AND checking it has an author field in the metadata that is HuggingFace.
     if package_exists:
         try:
-            if is_torch_mlu_available():
-                _ = importlib_metadata.metadata("deepspeed-mlu")
-                return True
             _ = importlib_metadata.metadata("deepspeed")
             return True
         except importlib_metadata.PackageNotFoundError:
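
With the `deepspeed-mlu` branch removed, `is_deepspeed_available()` performs a single `deepspeed` metadata lookup on every platform. Roughly, the resulting function reads as below; this is a sketch reconstructed from the hunk, assuming `package_exists` comes from an `importlib.util.find_spec("deepspeed")` probe as in the surrounding code.

```python
import importlib.metadata as importlib_metadata
import importlib.util


def is_deepspeed_available() -> bool:
    # Sketch of the post-change check: one `deepspeed` metadata lookup,
    # with no special case for a separate deepspeed-mlu distribution.
    package_exists = importlib.util.find_spec("deepspeed") is not None
    if package_exists:
        try:
            _ = importlib_metadata.metadata("deepspeed")
            return True
        except importlib_metadata.PackageNotFoundError:
            return False
    return False
```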

View File

@@ -103,6 +103,7 @@ from .utils import (
     is_safetensors_available,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
+    is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_sdpa_available,
     is_torch_xla_available,
@@ -2323,12 +2324,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
         # or the model may be initialized under the context manager `with torch.device("cuda"):`.
-        if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
+        if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
             if torch.cuda.is_available():
                 logger.warning_once(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                     " after initializing it on CPU with `model.to('cuda')`."
                 )
+            elif is_torch_mlu_available():
+                logger.warning_once(
+                    "You are attempting to use Flash Attention 2.0 with a model not initialized on MLU. Make sure to move the model to MLU"
+                    " after initializing it on CPU with `model.to('mlu')`."
+                )
             else:
                 raise ValueError(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "

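The widened device check above means a model whose default device is `mlu` passes the FA2 initialization check, while a CPU-initialized model on an MLU host now gets the MLU-specific warning instead of the CUDA one. A hedged usage sketch (assumes a Cambricon host with `torch_mlu` and an MLU-capable flash-attention build; the checkpoint name is only an example):

```python
import torch
from transformers import AutoModelForCausalLM

# Initializing directly on the MLU device satisfies the check above;
# initializing on CPU instead emits the new warning and expects a later model.to("mlu").
with torch.device("mlu"):
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-gpt2",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
    )
```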
View File

@@ -144,6 +144,7 @@ from .utils import (
     is_torch_fp16_available_on_device,
     is_torch_greater_or_equal,
     is_torch_hpu_available,
+    is_torch_mlu_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_sdpa_available,
@@ -940,6 +941,10 @@ if is_torch_available():
             raise ValueError(
                 f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
             )
+        if torch_device == "mlu" and not is_torch_mlu_available():
+            raise ValueError(
+                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but MLU is unavailable. Please double-check your testing environment."
+            )
         if torch_device == "hpu" and not is_torch_hpu_available():
             raise ValueError(
                 f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment."
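
With this validation in place, the MLU test device is selected the same way as the other third-party backends. A minimal sketch of driving it via the environment variable (only meaningful on a host where `is_torch_mlu_available()` returns True):

```python
import os

# testing_utils reads TRANSFORMERS_TEST_DEVICE at import time, so set it first.
os.environ["TRANSFORMERS_TEST_DEVICE"] = "mlu"

from transformers.testing_utils import torch_device  # noqa: E402

# "mlu" on a Cambricon host; on any other machine the import raises the ValueError added above.
print(torch_device)
```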
@@ -956,6 +961,8 @@
         torch_device = "cuda"
     elif _run_third_party_device_tests and is_torch_npu_available():
         torch_device = "npu"
+    elif _run_third_party_device_tests and is_torch_mlu_available():
+        torch_device = "mlu"
     elif _run_third_party_device_tests and is_torch_hpu_available():
         torch_device = "hpu"
     elif _run_third_party_device_tests and is_torch_xpu_available():
@@ -2927,9 +2934,21 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
 if is_torch_available():
     # Mappings from device names to callable functions to support device agnostic
     # testing.
-    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
-    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
-    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1}
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "cpu": torch.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "cpu": None,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "cpu": lambda: 0,
+        "default": lambda: 1,
+    }
 else:
     BACKEND_MANUAL_SEED = {"default": None}
     BACKEND_EMPTY_CACHE = {"default": None}
@@ -2939,6 +2958,11 @@ if is_torch_hpu_available():
     BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
     BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count
+if is_torch_mlu_available():
+    BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache
+    BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed
+    BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count
 if is_torch_npu_available():
     BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
     BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
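
For reference, a simplified stand-in for the `_device_agnostic_dispatch` helper named in the hunk header, showing how the tables above are consumed; the real implementation in `testing_utils` may differ in detail.

```python
import torch


def backend_manual_seed(device: str, seed: int) -> None:
    # Look up the callable registered for the active test device, falling back to "default".
    table = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
    if hasattr(torch, "mlu"):  # registered by the torch_mlu extension (assumed)
        table["mlu"] = torch.mlu.manual_seed
    fn = table.get(device, table["default"])
    if fn is not None:
        fn(seed)


backend_manual_seed("cpu", 0)  # resolves to torch.manual_seed here
# With TRANSFORMERS_TEST_DEVICE=mlu on an MLU host, device="mlu" resolves to torch.mlu.manual_seed.
```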

View File

@@ -15,12 +15,12 @@
 import argparse
 from typing import Any, Callable
-from transformers import is_torch_available
+from transformers import is_torch_available, is_torch_mlu_available
 from transformers.testing_utils import (
     TestCasePlus,
     execute_subprocess_async,
     get_torch_dist_unique_port,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
 )
@@ -46,7 +46,11 @@ if is_torch_available():
         """Manage the creation and destruction of the distributed process group for the wrapped function."""
         def wrapped(*args: Any, **kwargs: Any) -> Any:
-            torch.distributed.init_process_group(world_size=torch.cuda.device_count())
+            if is_torch_mlu_available():
+                device_count = torch.mlu.device_count()
+            else:
+                device_count = torch.cuda.device_count()
+            torch.distributed.init_process_group(world_size=device_count)
             try:
                 return func(*args, **kwargs)
             finally:
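
The MLU-vs-CUDA device-count selection introduced here recurs throughout this test file. For illustration only, a small helper capturing the pattern (hypothetical; the patch itself keeps the inline if/else):

```python
import torch
from transformers import is_torch_available, is_torch_mlu_available


def accelerator_device_count() -> int:
    """Number of attached accelerators, preferring MLU when torch_mlu is present."""
    if not is_torch_available():
        return 0
    if is_torch_mlu_available():
        return torch.mlu.device_count()
    return torch.cuda.device_count()
```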
@@ -56,7 +60,10 @@ if is_torch_available():
     @manage_process_group
     def fsdp_generate():
-        torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        if is_torch_mlu_available():
+            torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        else:
+            torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
@@ -79,11 +86,14 @@ if is_torch_available():
     @manage_process_group
     def fsdp2_generate():
-        torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        if is_torch_mlu_available():
+            torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        else:
+            torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
-        mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
+        mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
         for submodule in model.modules():
             if isinstance(submodule, GPT2Block):
                 fully_shard(submodule, mesh=mesh)
@@ -102,9 +112,13 @@ if is_torch_available():
 class TestFSDPGeneration(TestCasePlus):
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_fsdp_generate(self):
-        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
+        if is_torch_mlu_available():
+            device_count = torch.mlu.device_count()
+        else:
+            device_count = torch.cuda.device_count()
+        distributed_args = f"""--nproc_per_node={device_count}
             --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_fsdp.py
         """.split()
@@ -113,9 +127,13 @@ class TestFSDPGeneration(TestCasePlus):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_fsdp2_generate(self):
-        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
+        if is_torch_mlu_available():
+            device_count = torch.mlu.device_count()
+        else:
+            device_count = torch.cuda.device_count()
+        distributed_args = f"""--nproc_per_node={device_count}
             --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_fsdp.py
         """.split()