Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Patch with accelerate xpu (#25714)
* patch with accelerate xpu
* patch with accelerate xpu
* formatting
* fix tests
* revert ruff unrelated fixes
* revert ruff unrelated fixes
* revert ruff unrelated fixes
* fix test
* review fixes
* review fixes
* black fixed
* review commits
* review commits
* style fix
* use pytorch_utils
* revert markuplm test
This commit is contained in:
parent aa5c94d38d
commit 70a98024b1
@@ -748,6 +748,7 @@ _import_structure = {
         "is_torch_npu_available",
         "is_torch_tpu_available",
         "is_torchvision_available",
+        "is_torch_xpu_available",
         "is_vision_available",
         "logging",
     ],

@@ -4814,6 +4815,7 @@ if TYPE_CHECKING:
         is_torch_neuroncore_available,
         is_torch_npu_available,
         is_torch_tpu_available,
+        is_torch_xpu_available,
         is_torchvision_available,
         is_vision_available,
         logging,

@@ -100,6 +100,7 @@ from .utils import (
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_xpu_available,
     is_torchaudio_available,
     is_torchdynamo_available,
     is_torchvision_available,

@@ -624,6 +625,29 @@ def require_torch_multi_npu(test_case):
     return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)


+def require_torch_xpu(test_case):
+    """
+    Decorator marking a test that requires XPU and IPEX.
+
+    These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match the current PyTorch
+    version.
+    """
+    return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
+
+
+def require_torch_multi_xpu(test_case):
+    """
+    Decorator marking a test that requires a multi-XPU setup with IPEX and at least one XPU device. These tests are
+    skipped on a machine without IPEX or multiple XPUs.
+
+    To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
+    """
+    if not is_torch_xpu_available():
+        return unittest.skip("test requires IPEX and at least one XPU device")(test_case)
+
+    return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
+
+
 if is_torch_available():
     # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
     import torch

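A usage sketch (not part of the diff): a test module can opt into the new markers roughly like this. The test class and assertions are illustrative only; `require_torch_xpu` and `torch_device` come from `transformers.testing_utils`.

import unittest

import torch

from transformers.testing_utils import require_torch_xpu, torch_device


@require_torch_xpu
class XpuSmokeTest(unittest.TestCase):  # hypothetical test, for illustration only
    def test_matmul_on_xpu(self):
        # torch_device resolves to "xpu" when third-party device tests are enabled
        # and is_torch_xpu_available() returns True (see the next hunk).
        a = torch.ones(4, 4, device=torch_device)
        b = torch.eye(4, device=torch_device)
        self.assertTrue(torch.allclose(a @ b, a))
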
@@ -641,6 +665,8 @@ if is_torch_available():
         torch_device = "cuda"
     elif _run_third_party_device_tests and is_torch_npu_available():
         torch_device = "npu"
+    elif _run_third_party_device_tests and is_torch_xpu_available():
+        torch_device = "xpu"
     else:
         torch_device = "cpu"

@@ -38,6 +38,7 @@ from .utils import (
     is_torch_mps_available,
     is_torch_npu_available,
     is_torch_tpu_available,
+    is_torch_xpu_available,
     requires_backends,
 )

@@ -97,6 +98,8 @@ def set_seed(seed: int):
         # ^^ safe to call this function even if cuda is not available
     if is_torch_npu_available():
         torch.npu.manual_seed_all(seed)
+    if is_torch_xpu_available():
+        torch.xpu.manual_seed_all(seed)
     if is_tf_available():
         tf.random.set_seed(seed)

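For context, a minimal sketch of the public entry point this touches: `set_seed` is exported from `transformers`, and with this change it also seeds the XPU generators when IPEX is installed (the extra call is simply skipped on machines without an XPU).

from transformers import set_seed

# One call now covers Python's random, NumPy, torch (CUDA/NPU/XPU when present) and TensorFlow.
set_seed(42)
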
@@ -420,6 +423,11 @@ class TrainerMemoryTracker:
         elif is_torch_mps_available():
             import torch

             self.torch = torch
             self.gpu = {}
+        elif is_torch_xpu_available():
+            import torch
+
+            self.torch = torch
+            self.gpu = {}
         else:

@@ -472,12 +480,19 @@ class TrainerMemoryTracker:
         gc.collect()

         if self.torch is not None:
-            self.torch.cuda.reset_peak_memory_stats()
-            self.torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                self.torch.cuda.reset_peak_memory_stats()
+                self.torch.cuda.empty_cache()
+            elif is_torch_xpu_available():
+                self.torch.xpu.reset_peak_memory_stats()
+                self.torch.xpu.empty_cache()

         # gpu
         if self.torch is not None:
-            self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
+            if torch.cuda.is_available():
+                self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
+            elif is_torch_xpu_available():
+                self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()

         # cpu
         self.cpu_mem_used_at_start = self.cpu_mem_used()

@@ -501,7 +516,10 @@ class TrainerMemoryTracker:
         gc.collect()

         if self.torch is not None:
-            self.torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                self.torch.cuda.empty_cache()
+            elif is_torch_xpu_available():
+                self.torch.xpu.empty_cache()

         # concepts:
         # - alloc_delta: the difference of allocated memory between the end and the start

@@ -510,8 +528,15 @@ class TrainerMemoryTracker:

         # gpu
         if self.torch is not None:
-            self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
-            self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
+            if torch.cuda.is_available():
+                self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
+            elif is_torch_xpu_available():
+                self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
+            else:
+                raise ValueError("No available GPU device found!")
+
             self.gpu[self.cur_stage] = {
                 "begin": self.gpu_mem_used_at_start,
                 "end": self.gpu_mem_used_now,

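The three `TrainerMemoryTracker` hunks above all follow the same dispatch: use the CUDA memory counters when CUDA is available, otherwise fall back to their XPU equivalents. A standalone paraphrase of that pattern, for illustration only (the helper name is made up):

import torch

from transformers.utils import is_torch_xpu_available


def accelerator_memory_allocated() -> int:
    # Illustrative helper: allocated accelerator memory in bytes, CUDA first, then XPU.
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated()
    if is_torch_xpu_available():
        # torch.xpu.* is provided by intel_extension_for_pytorch
        return torch.xpu.memory_allocated()
    raise ValueError("No available GPU device found!")
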
@@ -50,6 +50,7 @@ from .utils import (
     is_torch_npu_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
+    is_torch_xpu_available,
     logging,
     requires_backends,
 )

@@ -194,9 +195,9 @@ class TrainingArguments:
         prediction_loss_only (`bool`, *optional*, defaults to `False`):
             When performing evaluation and generating predictions, only returns the loss.
         per_device_train_batch_size (`int`, *optional*, defaults to 8):
-            The batch size per GPU/TPU/MPS/NPU core/CPU for training.
+            The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.
         per_device_eval_batch_size (`int`, *optional*, defaults to 8):
-            The batch size per GPU/TPU/MPS/NPU core/CPU for evaluation.
+            The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
         gradient_accumulation_steps (`int`, *optional*, defaults to 1):
             Number of updates steps to accumulate the gradients for, before performing a backward/update pass.

@@ -1357,11 +1358,20 @@ class TrainingArguments:
             if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
                 raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
-            elif not self.use_cpu and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
-                # gpu
-                raise ValueError(
-                    "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
-                )
+            elif not self.use_cpu:
+                if torch.cuda.is_available() and not is_torch_bf16_gpu_available():
+                    # gpu
+                    raise ValueError(
+                        "Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
+                    )
+                elif not is_torch_xpu_available():
+                    # xpu
+                    from .pytorch_utils import is_torch_greater_or_equal_than_1_12
+
+                    if not is_torch_greater_or_equal_than_1_12:
+                        raise ValueError(
+                            "Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed"
+                        )

         if self.fp16 and self.bf16:
             raise ValueError("At most one of fp16 and bf16 can be True, but not both")

@@ -1416,6 +1426,7 @@ class TrainingArguments:
             self.framework == "pt"
             and is_torch_available()
             and (self.device.type != "cuda")
+            and (self.device.type != "xpu")
             and (get_xla_device_type(self.device) != "GPU")
             and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")

@@ -1423,7 +1434,7 @@ class TrainingArguments:
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
+                " (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX) or CPU/TPU/NeuronCore devices."
             )

         if self.torchdynamo is not None:

@@ -1779,6 +1790,10 @@ class TrainingArguments:
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
             torch.cuda.set_device(device)
+        elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
+            os.environ["ACCELERATE_USE_XPU"] = "true"
+            device = torch.device("xpu:0")
+            self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
             self.distributed_state = PartialState(_use_sagemaker_dp=True)
             self._n_gpu = 1

@@ -1807,6 +1822,12 @@ class TrainingArguments:
         elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
             # Already set _n_gpu
             pass
+        elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
+            if "ACCELERATE_USE_XPU" not in os.environ:
+                os.environ["ACCELERATE_USE_XPU"] = "true"
+            self._n_gpu = torch.xpu.device_count()
+            device = torch.device("xpu:0")
+            torch.xpu.set_device(device)
         elif self.distributed_state.distributed_type == DistributedType.NO:
             if self.use_mps_device:
                 warnings.warn(

@@ -1824,6 +1845,10 @@ class TrainingArguments:
             elif self.use_cpu:
                 device = torch.device("cpu")
                 self._n_gpu = 0
+            elif is_torch_xpu_available():
+                device = torch.device("xpu:0")
+                torch.xpu.set_device(device)
+                self._n_gpu = 1
             elif is_torch_npu_available():
                 device = torch.device("npu:0")
                 torch.npu.set_device(device)

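Taken together, the device-setup hunks mean that on a machine with IPEX and an XPU, a plain `TrainingArguments` instance resolves to the XPU without extra flags. A hedged sketch of the expected behaviour (the output comments describe what the new branches imply, assuming such a machine):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="xpu-run")  # output_dir value is illustrative
print(args.device)  # expected: device(type="xpu", index=0) when IPEX and an XPU are present
print(args.n_gpu)   # expected: 1 for a single XPU; torch.xpu.device_count() under MULTI_XPU
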
@@ -172,6 +172,7 @@ from .import_utils import (
    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
+   is_torch_xpu_available,
    is_torchaudio_available,
    is_torchdistx_available,
    is_torchdynamo_available,

@@ -528,6 +528,25 @@ def is_ipex_available():
     return True


+@lru_cache
+def is_torch_xpu_available(check_device=False):
+    "Checks if `intel_extension_for_pytorch` is installed and potentially if an XPU is in the environment"
+    if not is_ipex_available():
+        return False
+
+    import intel_extension_for_pytorch  # noqa: F401
+    import torch
+
+    if check_device:
+        try:
+            # Will raise a RuntimeError if no XPU is found
+            _ = torch.xpu.device_count()
+            return torch.xpu.is_available()
+        except RuntimeError:
+            return False
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 def is_bitsandbytes_available():
     if not is_torch_available():
         return False

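A short sketch of how callers might use the new helper for device selection (variable names are illustrative; `check_device=True` additionally probes for a physical XPU, as in the function above):

import torch

from transformers.utils import is_torch_xpu_available

if torch.cuda.is_available():
    device = "cuda"
elif is_torch_xpu_available(check_device=True):
    device = "xpu"
else:
    device = "cpu"

x = torch.randn(2, 8, device=device)
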
@@ -22,6 +22,7 @@ from transformers.testing_utils import (
     execute_subprocess_async,
     get_torch_dist_unique_port,
     require_torch_multi_gpu,
+    require_torch_multi_xpu,
     require_torch_neuroncore,
     require_torch_npu,
 )

@@ -158,6 +159,20 @@ class TestTrainerDistributed(TestCasePlus):
         # successful return here == success - any errors would have caused an error in the sub-call


+@require_torch_multi_xpu
+class TestTrainerDistributedXPU(TestCasePlus):
+    def test_trainer(self):
+        distributed_args = f"""--nproc_per_node={torch.xpu.device_count()}
+            --master_port={get_torch_dist_unique_port()}
+            {self.test_file_dir}/test_trainer_distributed.py
+        """.split()
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = f"--output_dir {output_dir}".split()
+        cmd = ["torchrun"] + distributed_args + args
+        execute_subprocess_async(cmd, env=self.get_env())
+        # successful return here == success - any errors would have caused an error in the sub-call
+
+
 if __name__ == "__main__":
     # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
     #

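For reference, the torchrun invocation the new XPU test assembles is roughly equivalent to the manual launch below; the port, script path, and output directory are illustrative assumptions, not values taken from the diff.

import torch

# Mirrors what TestTrainerDistributedXPU builds before handing it to
# execute_subprocess_async; run from the repository root on a multi-XPU machine.
cmd = [
    "torchrun",
    f"--nproc_per_node={torch.xpu.device_count()}",
    "--master_port=29500",                        # any free port
    "tests/trainer/test_trainer_distributed.py",  # assumed script path
    "--output_dir", "/tmp/xpu_dist_test",         # illustrative output directory
]
print(" ".join(cmd))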