Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Support MUSA (Moore Threads GPU) backend in transformers (#31913)
Add accelerate version check, needs accelerate>=0.33.0
This commit is contained in: parent c1357834e8, commit a22ff36e0e
.github/workflows/build-ci-docker-images.yml (vendored, 2 lines changed)
@@ -74,4 +74,4 @@ jobs:
           slack_channel: "#transformers-ci-circleci-images"
           title: 🤗 New docker images for CircleCI are pushed.
           status: ${{ job.status }}
-          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
+          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
@@ -930,6 +930,7 @@ _import_structure = {
         "is_tokenizers_available",
         "is_torch_available",
         "is_torch_mlu_available",
+        "is_torch_musa_available",
         "is_torch_neuroncore_available",
         "is_torch_npu_available",
         "is_torch_tpu_available",
@@ -5706,6 +5707,7 @@ if TYPE_CHECKING:
         is_tokenizers_available,
         is_torch_available,
         is_torch_mlu_available,
+        is_torch_musa_available,
         is_torch_neuroncore_available,
         is_torch_npu_available,
         is_torch_tpu_available,
@@ -45,6 +45,7 @@ from ..utils import (
     is_torch_cuda_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_npu_available,
     is_torch_xpu_available,
     logging,
@@ -873,6 +874,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
                 self.device = torch.device("cpu")
             elif is_torch_mlu_available():
                 self.device = torch.device(f"mlu:{device}")
+            elif is_torch_musa_available():
+                self.device = torch.device(f"musa:{device}")
             elif is_torch_cuda_available():
                 self.device = torch.device(f"cuda:{device}")
             elif is_torch_npu_available():
@@ -1042,6 +1045,9 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
             elif self.device.type == "mlu":
                 with torch.mlu.device(self.device):
                     yield
+            elif self.device.type == "musa":
+                with torch.musa.device(self.device):
+                    yield
             else:
                 yield
 
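For context, a minimal usage sketch (not part of the diff): with torch_musa installed, an integer device index passed to pipeline() now resolves to a MUSA device, and device_placement() enters torch.musa.device() during inference. The model checkpoint and expected outputs below are illustrative assumptions, not something the diff asserts.

from transformers import pipeline

# Assumes a MUSA-capable host with torch_musa installed; device=0 resolves to "musa:0"
# because the MUSA branch is checked before CUDA in Pipeline.__init__ (see hunk above).
pipe = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # illustrative checkpoint
    device=0,
)
print(pipe.device)  # -> musa:0 on a MUSA machine

# device_placement() now wraps inference in torch.musa.device(self.device) for MUSA pipelines.
with pipe.device_placement():
    print(pipe("MUSA backend support looks good!"))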
@@ -164,6 +164,7 @@ from .utils import (
     is_torch_compile_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_xla_available,
@@ -2894,6 +2895,17 @@ class Trainer:
                         f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
                         "\nThis won't yield the same results as if the training had not been interrupted."
                     )
+        if is_torch_musa_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                torch.musa.set_rng_state_all(checkpoint_rng_state["musa"])
+            else:
+                try:
+                    torch.musa.set_rng_state(checkpoint_rng_state["musa"])
+                except Exception as e:
+                    logger.info(
+                        f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}"
+                        "\nThis won't yield the same results as if the training had not been interrupted."
+                    )
 
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
@@ -2982,6 +2994,12 @@ class Trainer:
             else:
                 rng_states["mlu"] = torch.mlu.random.get_rng_state()
 
+        if is_torch_musa_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["musa"] = torch.musa.get_rng_state_all()
+            else:
+                rng_states["musa"] = torch.musa.get_rng_state()
+
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
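A small sketch of the round trip these two Trainer hunks enable, assuming torch_musa mirrors the torch.cuda RNG-state and tensor-creation API (the get/set calls below are exactly the ones used in the diff; the reproducibility check is our assumption):

import torch
import torch_musa  # noqa: F401  (registers the torch.musa namespace)

state = torch.musa.get_rng_state()           # stored under rng_states["musa"] at checkpoint time
a = torch.randn(4, device="musa")
torch.musa.set_rng_state(state)              # restored from checkpoint_rng_state["musa"] on resume
b = torch.randn(4, device="musa")
assert torch.equal(a, b)                     # sampling replays after the restore (assumed behavior)

# Under ParallelMode.DISTRIBUTED the *_all variants snapshot/restore every MUSA device at once:
states = torch.musa.get_rng_state_all()
torch.musa.set_rng_state_all(states)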
@@ -3351,6 +3369,8 @@ class Trainer:
             torch.xpu.empty_cache()
         elif is_torch_mlu_available():
             torch.mlu.empty_cache()
+        elif is_torch_musa_available():
+            torch.musa.empty_cache()
         elif is_torch_npu_available():
             torch.npu.empty_cache()
         elif is_torch_mps_available(min_version="2.0"):
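The hunk above extends the Trainer's accelerator-agnostic cache flush. A standalone sketch of the same dispatch pattern (the helper name is ours, and is_torch_musa_available is importable only from a build that includes this change):

import torch
from transformers.utils import (
    is_torch_mlu_available,
    is_torch_musa_available,
    is_torch_npu_available,
    is_torch_xpu_available,
)

def empty_device_cache():
    # Mirrors the branch order visible in the diff: XPU, MLU, MUSA, NPU; CUDA as a fallback here.
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
    elif is_torch_mlu_available():
        torch.mlu.empty_cache()
    elif is_torch_musa_available():
        torch.musa.empty_cache()
    elif is_torch_npu_available():
        torch.npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()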
@@ -37,6 +37,7 @@ from .utils import (
     is_torch_cuda_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_npu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
@@ -108,6 +109,8 @@ def set_seed(seed: int, deterministic: bool = False):
             torch.use_deterministic_algorithms(True)
     if is_torch_mlu_available():
         torch.mlu.manual_seed_all(seed)
+    if is_torch_musa_available():
+        torch.musa.manual_seed_all(seed)
     if is_torch_npu_available():
         torch.npu.manual_seed_all(seed)
     if is_torch_xpu_available():
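With this hunk, a single set_seed call also seeds the MUSA generators; caller-side usage is unchanged:

from transformers import set_seed

# On a host where is_torch_musa_available() is True this now also runs
# torch.musa.manual_seed_all(42), in addition to the CPU/CUDA/MLU/NPU/XPU seeding.
set_seed(42)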
@@ -464,7 +467,7 @@ class TrainerMemoryTracker:
 
         import psutil  # noqa
 
-        if is_torch_cuda_available() or is_torch_mlu_available():
+        if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
             import torch
 
             self.torch = torch
@@ -540,6 +543,9 @@
         elif is_torch_mlu_available():
             self.torch.mlu.reset_peak_memory_stats()
             self.torch.mlu.empty_cache()
+        elif is_torch_musa_available():
+            self.torch.musa.reset_peak_memory_stats()
+            self.torch.musa.empty_cache()
         elif is_torch_xpu_available():
             self.torch.xpu.reset_peak_memory_stats()
             self.torch.xpu.empty_cache()
@@ -555,6 +561,8 @@
             self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
         elif is_torch_mlu_available():
             self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
+        elif is_torch_musa_available():
+            self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
         elif is_torch_xpu_available():
             self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
         elif is_torch_npu_available():
@@ -588,6 +596,8 @@
             self.torch.cuda.empty_cache()
         elif is_torch_mlu_available():
             self.torch.mlu.empty_cache()
+        elif is_torch_musa_available():
+            self.torch.musa.empty_cache()
         elif is_torch_xpu_available():
             self.torch.xpu.empty_cache()
         elif is_torch_npu_available():
@@ -608,6 +618,9 @@
         elif is_torch_mlu_available():
             self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
             self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
+        elif is_torch_musa_available():
+            self.gpu_mem_used_now = self.torch.musa.memory_allocated()
+            self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
         elif is_torch_xpu_available():
             self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
             self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
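These TrainerMemoryTracker hunks rely on torch.musa exposing the same memory-introspection calls as torch.cuda. A minimal sketch of the same bookkeeping done by hand (the workload tensor is purely illustrative):

import torch
import torch_musa  # noqa: F401

torch.musa.reset_peak_memory_stats()
torch.musa.empty_cache()
mem_at_start = torch.musa.memory_allocated()   # what the tracker records as gpu_mem_used_at_start

x = torch.randn(1024, 1024, device="musa")     # illustrative workload

mem_now = torch.musa.memory_allocated()        # gpu_mem_used_now
mem_peak = torch.musa.max_memory_allocated()   # gpu_mem_used_peak
print(f"delta={mem_now - mem_at_start} bytes, peak={mem_peak} bytes")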
@@ -49,6 +49,7 @@ from .utils import (
     is_torch_bf16_gpu_available,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tf32_available,
@@ -1090,7 +1091,7 @@ class TrainingArguments:
         default=None,
         metadata={
             "help": "The backend to be used for distributed training",
-            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl"],
+            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl", "mccl"],
         },
     )
     tpu_num_cores: Optional[int] = field(
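Adding "mccl" to the choices lets a multi-GPU MUSA run request the Moore Threads collective backend explicitly. A hedged configuration sketch (the output path and batch size are placeholders):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./musa-run",            # placeholder
    ddp_backend="mccl",                 # the new choice introduced by this diff
    per_device_train_batch_size=8,
)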
@@ -2201,6 +2202,9 @@ class TrainingArguments:
         elif is_torch_mlu_available():
            device = torch.device("mlu:0")
            torch.mlu.set_device(device)
+        elif is_torch_musa_available():
+            device = torch.device("musa:0")
+            torch.musa.set_device(device)
         elif is_torch_npu_available():
             device = torch.device("npu:0")
             torch.npu.set_device(device)
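The branch above makes musa:0 the default single-device target during TrainingArguments device setup. A small sketch of how that would surface in user code (assuming a MUSA host with torch_musa and accelerate>=0.33.0 installed):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="./musa-run")   # placeholder path
# Expected to print musa:0 on a MUSA host; elsewhere the usual CUDA/NPU/MPS/CPU
# resolution order still applies.
print(args.device)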
@@ -201,6 +201,7 @@ from .import_utils import (
     is_torch_fx_proxy,
     is_torch_mlu_available,
     is_torch_mps_available,
+    is_torch_musa_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_sdpa_available,
@@ -677,6 +677,29 @@ def is_torch_mlu_available(check_device=False):
     return hasattr(torch, "mlu") and torch.mlu.is_available()
 
 
+@lru_cache()
+def is_torch_musa_available(check_device=False):
+    "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment"
+    if not _torch_available or importlib.util.find_spec("torch_musa") is None:
+        return False
+
+    import torch
+    import torch_musa  # noqa: F401
+
+    torch_musa_min_version = "0.33.0"
+    if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_musa_min_version):
+        return False
+
+    if check_device:
+        try:
+            # Will raise a RuntimeError if no MUSA is found
+            _ = torch.musa.device_count()
+            return torch.musa.is_available()
+        except RuntimeError:
+            return False
+    return hasattr(torch, "musa") and torch.musa.is_available()
+
+
 def is_torchdynamo_available():
     if not is_torch_available():
         return False
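Finally, a small usage sketch of the new helper for device selection outside the Trainer (the fallback chain is ours, not the library's). Note that, per the commit message and the version gate above, the helper returns False when accelerate is older than 0.33.0:

import torch
from transformers.utils import is_torch_musa_available

if is_torch_musa_available():
    device = torch.device("musa:0")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

inputs = torch.ones(2, 3, device=device)
print(inputs.device)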