Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-30 17:52:35 +06:00
[MLU] Fix FA2 check error, remove deepspeed-mlu deps. (#36159)
* add Cambricon MLUs support
* fix mlu device rng state
* up for quality check
* up mlu to support fp16
* fix mlu device dependency error
* fix mlu device dependency error
* enable mlu device for bf16
* fix mlu device memory tracker
* Cambricon support SDPA and flash_attn
* MLU devices : Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu
* Fix mlu FA2 check. Remove deepspeed-mlu check. add mlu tests support.
* fix testing errors.
* Merge branch 'hf/main' into main
* fix get_device_count error.
* fix mlu testing utils.
* fix code quality and style.
* switch to @require_torch_multi_accelerator
This commit is contained in:
parent ad63d20dff
commit d0b65bb479
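The change applied throughout the hunks below is a runtime branch on `is_torch_mlu_available()` so that device counts, seeds, and cache management come from `torch.mlu` rather than `torch.cuda`. A minimal sketch of that pattern (illustrative only; `pick_device_count` is a hypothetical helper, not part of this diff):

# Minimal sketch of the pattern applied throughout this diff (illustrative only;
# `pick_device_count` is a hypothetical helper, not part of transformers).
from transformers import is_torch_available, is_torch_mlu_available

def pick_device_count() -> int:
    if not is_torch_available():
        return 0
    import torch

    if is_torch_mlu_available():
        return torch.mlu.device_count()  # Cambricon MLU backend (requires torch_mlu)
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return 0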
@@ -22,7 +22,7 @@ import weakref
 from functools import partialmethod

 from ..dependency_versions_check import dep_version_check
-from ..utils import is_accelerate_available, is_torch_available, is_torch_mlu_available, logging
+from ..utils import is_accelerate_available, is_torch_available, logging


 if is_torch_available():
@@ -40,9 +40,6 @@ def is_deepspeed_available():
     # AND checking it has an author field in the metadata that is HuggingFace.
     if package_exists:
         try:
-            if is_torch_mlu_available():
-                _ = importlib_metadata.metadata("deepspeed-mlu")
-                return True
             _ = importlib_metadata.metadata("deepspeed")
             return True
         except importlib_metadata.PackageNotFoundError:
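With the `deepspeed-mlu` probe removed, availability rests solely on the metadata of the `deepspeed` distribution. A standalone sketch of the mechanism the simplified check relies on (illustrative, not the library code):

# Illustrative sketch: importlib.metadata raises PackageNotFoundError
# when a distribution is absent, which is what the check above relies on.
import importlib.metadata as importlib_metadata

def has_distribution(name: str) -> bool:
    try:
        importlib_metadata.metadata(name)
        return True
    except importlib_metadata.PackageNotFoundError:
        return False

print(has_distribution("deepspeed"))      # True only if the deepspeed wheel is installed
print(has_distribution("deepspeed-mlu"))  # no longer consulted by is_deepspeed_available()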
@@ -103,6 +103,7 @@ from .utils import (
     is_safetensors_available,
     is_torch_flex_attn_available,
     is_torch_greater_or_equal,
+    is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_sdpa_available,
     is_torch_xla_available,
@@ -2323,12 +2324,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin

         # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
         # or the model may be initialized under the context manager `with torch.device("cuda"):`.
-        if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
+        if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
             if torch.cuda.is_available():
                 logger.warning_once(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
                     " after initializing it on CPU with `model.to('cuda')`."
                 )
+            elif is_torch_mlu_available():
+                logger.warning_once(
+                    "You are attempting to use Flash Attention 2.0 with a model not initialized on MLU. Make sure to move the model to MLU"
+                    " after initializing it on CPU with `model.to('mlu')`."
+                )
             else:
                 raise ValueError(
                     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
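The comment in this hunk explains the check: `torch.empty(0).device.type` reflects the current default device, so a model built under `torch.set_default_device` or a device context manager is not warned about. A small illustration (assumes a machine with a CUDA GPU):

# Small illustration of the device check above (assumes a machine with a CUDA GPU).
import torch

print(torch.empty(0).device.type)  # "cpu" with stock defaults

with torch.device("cuda"):
    print(torch.empty(0).device.type)  # "cuda": new tensors land on the context device

torch.set_default_device("cuda")
print(torch.empty(0).device.type)  # "cuda" again, so the FA2 warning is skipped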
@@ -144,6 +144,7 @@ from .utils import (
     is_torch_fp16_available_on_device,
     is_torch_greater_or_equal,
     is_torch_hpu_available,
+    is_torch_mlu_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_sdpa_available,
@@ -940,6 +941,10 @@ if is_torch_available():
             raise ValueError(
                 f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
             )
+        if torch_device == "mlu" and not is_torch_mlu_available():
+            raise ValueError(
+                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but MLU is unavailable. Please double-check your testing environment."
+            )
         if torch_device == "hpu" and not is_torch_hpu_available():
             raise ValueError(
                 f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment."
@@ -956,6 +961,8 @@ if is_torch_available():
             torch_device = "cuda"
         elif _run_third_party_device_tests and is_torch_npu_available():
             torch_device = "npu"
+        elif _run_third_party_device_tests and is_torch_mlu_available():
+            torch_device = "mlu"
         elif _run_third_party_device_tests and is_torch_hpu_available():
             torch_device = "hpu"
         elif _run_third_party_device_tests and is_torch_xpu_available():
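With the two hunks above, `mlu` is both a valid `TRANSFORMERS_TEST_DEVICE` override and an auto-detected backend. A usage sketch (illustrative; the override must be set before `transformers.testing_utils` is imported):

# Usage sketch (illustrative): forcing the test suite onto the MLU backend.
# An unavailable backend now raises the MLU-specific ValueError added above.
import os

os.environ["TRANSFORMERS_TEST_DEVICE"] = "mlu"

from transformers.testing_utils import torch_device  # noqa: E402

print(torch_device)  # "mlu" when torch_mlu is installed and functional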
@@ -2927,9 +2934,21 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
 if is_torch_available():
     # Mappings from device names to callable functions to support device agnostic
     # testing.
-    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
-    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
-    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1}
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "cpu": torch.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "cpu": None,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "cpu": lambda: 0,
+        "default": lambda: 1,
+    }
 else:
     BACKEND_MANUAL_SEED = {"default": None}
     BACKEND_EMPTY_CACHE = {"default": None}
@@ -2939,6 +2958,11 @@ if is_torch_hpu_available():
     BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
     BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count

+if is_torch_mlu_available():
+    BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache
+    BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed
+    BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count
+
 if is_torch_npu_available():
     BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
     BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
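These tables feed `_device_agnostic_dispatch` (its signature is visible in the hunk header above). A simplified stand-in showing how the new `mlu` entries get picked up (not the library's exact implementation):

# Simplified stand-in for _device_agnostic_dispatch (illustrative only).
from typing import Callable

def dispatch(device: str, table: dict[str, Callable], *args):
    # Pick the backend-specific callable, falling back to the "default" entry.
    fn = table.get(device, table["default"])
    if fn is None:          # e.g. empty_cache is a no-op on CPU
        return None
    return fn(*args)

# With the "mlu" entries registered above, dispatch("mlu", BACKEND_MANUAL_SEED, 42)
# resolves to torch.mlu.manual_seed(42).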
@@ -15,12 +15,12 @@
 import argparse
 from typing import Any, Callable

-from transformers import is_torch_available
+from transformers import is_torch_available, is_torch_mlu_available
 from transformers.testing_utils import (
     TestCasePlus,
     execute_subprocess_async,
     get_torch_dist_unique_port,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
 )


@@ -46,7 +46,11 @@ if is_torch_available():
         """Manage the creation and destruction of the distributed process group for the wrapped function."""

         def wrapped(*args: Any, **kwargs: Any) -> Any:
-            torch.distributed.init_process_group(world_size=torch.cuda.device_count())
+            if is_torch_mlu_available():
+                device_count = torch.mlu.device_count()
+            else:
+                device_count = torch.cuda.device_count()
+            torch.distributed.init_process_group(world_size=device_count)
             try:
                 return func(*args, **kwargs)
             finally:
@@ -56,7 +60,10 @@ if is_torch_available():

     @manage_process_group
     def fsdp_generate():
-        torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        if is_torch_mlu_available():
+            torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        else:
+            torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))

         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

@@ -79,11 +86,14 @@ if is_torch_available():

     @manage_process_group
     def fsdp2_generate():
-        torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        if is_torch_mlu_available():
+            torch.mlu.set_device(device := torch.device(rank := torch.distributed.get_rank()))
+        else:
+            torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))

         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

-        mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
+        mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
         for submodule in model.modules():
             if isinstance(submodule, GPT2Block):
                 fully_shard(submodule, mesh=mesh)
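Replacing the hard-coded `"cuda"` with `device.type` lets the device mesh follow whichever backend the rank was pinned to. A minimal illustration (assumes a CUDA machine and an already-initialized process group):

# Minimal illustration (assumes a CUDA machine and an initialized process group).
import torch
from torch.distributed.device_mesh import init_device_mesh

device = torch.device("cuda", torch.distributed.get_rank())
print(device.type)  # "cuda" here; "mlu" on a Cambricon box pinned via torch.mlu.set_device

mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))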
@@ -102,9 +112,13 @@ if is_torch_available():


 class TestFSDPGeneration(TestCasePlus):
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_fsdp_generate(self):
-        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
+        if is_torch_mlu_available():
+            device_count = torch.mlu.device_count()
+        else:
+            device_count = torch.cuda.device_count()
+        distributed_args = f"""--nproc_per_node={device_count}
             --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_fsdp.py
         """.split()
@@ -113,9 +127,13 @@ class TestFSDPGeneration(TestCasePlus):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call

-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_fsdp2_generate(self):
-        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
+        if is_torch_mlu_available():
+            device_count = torch.mlu.device_count()
+        else:
+            device_count = torch.cuda.device_count()
+        distributed_args = f"""--nproc_per_node={device_count}
             --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_fsdp.py
         """.split()