Patch with accelerate xpu (#25714)

* patch with accelerate xpu

* patch with accelerate xpu

* formatting

* fix tests

* revert ruff unrelated fixes

* revert ruff unrelated fixes

* revert ruff unrelated fixes

* fix test

* review fixes

* review fixes

* black fixed

* review commits

* review commits

* style fix

* use pytorch_utils

* revert markuplm test
This commit is contained in:
Abhilash Majumder 2023-09-05 20:11:42 +05:30 committed by GitHub
parent aa5c94d38d
commit 70a98024b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 127 additions and 14 deletions

View File

@ -748,6 +748,7 @@ _import_structure = {
"is_torch_npu_available",
"is_torch_tpu_available",
"is_torchvision_available",
"is_torch_xpu_available",
"is_vision_available",
"logging",
],
@ -4814,6 +4815,7 @@ if TYPE_CHECKING:
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xpu_available,
is_torchvision_available,
is_vision_available,
logging,

View File

@ -100,6 +100,7 @@ from .utils import (
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdynamo_available,
is_torchvision_available,
@ -624,6 +625,29 @@ def require_torch_multi_npu(test_case):
return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
def require_torch_xpu(test_case):
"""
Decorator marking a test that requires XPU and IPEX.
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
version.
"""
return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
def require_torch_multi_xpu(test_case):
"""
Decorator marking a test that requires a multi-XPU setup with IPEX and atleast one XPU device. These tests are
skipped on a machine without IPEX or multiple XPUs.
To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
"""
if not is_torch_xpu_available():
return unittest.skip("test requires IPEX and atleast one XPU device")(test_case)
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
@ -641,6 +665,8 @@ if is_torch_available():
torch_device = "cuda"
elif _run_third_party_device_tests and is_torch_npu_available():
torch_device = "npu"
elif _run_third_party_device_tests and is_torch_xpu_available():
torch_device = "xpu"
else:
torch_device = "cpu"

View File

@ -38,6 +38,7 @@ from .utils import (
is_torch_mps_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xpu_available,
requires_backends,
)
@ -97,6 +98,8 @@ def set_seed(seed: int):
# ^^ safe to call this function even if cuda is not available
if is_torch_npu_available():
torch.npu.manual_seed_all(seed)
if is_torch_xpu_available():
torch.xpu.manual_seed_all(seed)
if is_tf_available():
tf.random.set_seed(seed)
@ -420,6 +423,11 @@ class TrainerMemoryTracker:
elif is_torch_mps_available():
import torch
self.torch = torch
self.gpu = {}
elif is_torch_xpu_available():
import torch
self.torch = torch
self.gpu = {}
else:
@ -472,12 +480,19 @@ class TrainerMemoryTracker:
gc.collect()
if self.torch is not None:
self.torch.cuda.reset_peak_memory_stats()
self.torch.cuda.empty_cache()
if torch.cuda.is_available():
self.torch.cuda.reset_peak_memory_stats()
self.torch.cuda.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.reset_peak_memory_stats()
self.torch.xpu.empty_cache()
# gpu
if self.torch is not None:
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
if torch.cuda.is_available():
self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
# cpu
self.cpu_mem_used_at_start = self.cpu_mem_used()
@ -501,7 +516,10 @@ class TrainerMemoryTracker:
gc.collect()
if self.torch is not None:
self.torch.cuda.empty_cache()
if torch.cuda.is_available():
self.torch.cuda.empty_cache()
elif is_torch_xpu_available():
self.torch.xpu.empty_cache()
# concepts:
# - alloc_delta: the difference of allocated memory between the end and the start
@ -510,8 +528,15 @@ class TrainerMemoryTracker:
# gpu
if self.torch is not None:
self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
if torch.cuda.is_available():
self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
elif is_torch_xpu_available():
self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
else:
raise ValueError("No available GPU device found!")
self.gpu[self.cur_stage] = {
"begin": self.gpu_mem_used_at_start,
"end": self.gpu_mem_used_now,

View File

@ -50,6 +50,7 @@ from .utils import (
is_torch_npu_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xpu_available,
logging,
requires_backends,
)
@ -194,9 +195,9 @@ class TrainingArguments:
prediction_loss_only (`bool`, *optional*, defaults to `False`):
When performing evaluation and generating predictions, only returns the loss.
per_device_train_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/TPU/MPS/NPU core/CPU for training.
The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.
per_device_eval_batch_size (`int`, *optional*, defaults to 8):
The batch size per GPU/TPU/MPS/NPU core/CPU for evaluation.
The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
gradient_accumulation_steps (`int`, *optional*, defaults to 1):
Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
@ -1357,11 +1358,20 @@ class TrainingArguments:
if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
# cpu
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
elif not self.use_cpu and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
# gpu
raise ValueError(
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
)
elif not self.use_cpu:
if torch.cuda.is_available() and not is_torch_bf16_gpu_available():
# gpu
raise ValueError(
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
)
elif not is_torch_xpu_available():
# xpu
from .pytorch_utils import is_torch_greater_or_equal_than_1_12
if not is_torch_greater_or_equal_than_1_12:
raise ValueError(
"Your setup doesn't support bf16/xpu. You need torch>=1.12, using Intel XPU/GPU with IPEX installed"
)
if self.fp16 and self.bf16:
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
@ -1416,6 +1426,7 @@ class TrainingArguments:
self.framework == "pt"
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) != "GPU")
and (get_xla_device_type(self.device) != "TPU")
and (self.device.type != "cpu")
@ -1423,7 +1434,7 @@ class TrainingArguments:
):
raise ValueError(
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
" (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX) or CPU/TPU/NeuronCore devices."
)
if self.torchdynamo is not None:
@ -1779,6 +1790,10 @@ class TrainingArguments:
device = torch.device("cuda", local_rank)
self._n_gpu = 1
torch.cuda.set_device(device)
elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
os.environ["ACCELERATE_USE_XPU"] = "true"
device = torch.device("xpu:0")
self._n_gpu = 1
elif is_sagemaker_dp_enabled():
self.distributed_state = PartialState(_use_sagemaker_dp=True)
self._n_gpu = 1
@ -1807,6 +1822,12 @@ class TrainingArguments:
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
# Already set _n_gpu
pass
elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
if "ACCELERATE_USE_XPU" not in os.environ:
os.environ["ACCELERATE_USE_XPU"] = "true"
self._n_gpu = torch.xpu.device_count()
device = torch.device("xpu:0")
torch.xpu.set_device(device)
elif self.distributed_state.distributed_type == DistributedType.NO:
if self.use_mps_device:
warnings.warn(
@ -1824,6 +1845,10 @@ class TrainingArguments:
elif self.use_cpu:
device = torch.device("cpu")
self._n_gpu = 0
elif is_torch_xpu_available():
device = torch.device("xpu:0")
torch.xpu.set_device(device)
self._n_gpu = 1
elif is_torch_npu_available():
device = torch.device("npu:0")
torch.npu.set_device(device)

View File

@ -172,6 +172,7 @@ from .import_utils import (
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdistx_available,
is_torchdynamo_available,

View File

@ -528,6 +528,25 @@ def is_ipex_available():
return True
@lru_cache
def is_torch_xpu_available(check_device=False):
"Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
if not is_ipex_available():
return False
import intel_extension_for_pytorch # noqa: F401
import torch
if check_device:
try:
# Will raise a RuntimeError if no XPU is found
_ = torch.xpu.device_count()
return torch.xpu.is_available()
except RuntimeError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()
def is_bitsandbytes_available():
if not is_torch_available():
return False

View File

@ -22,6 +22,7 @@ from transformers.testing_utils import (
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
require_torch_multi_xpu,
require_torch_neuroncore,
require_torch_npu,
)
@ -158,6 +159,20 @@ class TestTrainerDistributed(TestCasePlus):
# successful return here == success - any errors would have caused an error in the sub-call
@require_torch_multi_xpu
class TestTrainerDistributedXPU(TestCasePlus):
def test_trainer(self):
distributed_args = f"""--nproc_per_node={torch.xpu.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_trainer_distributed.py
""".split()
output_dir = self.get_auto_remove_tmp_dir()
args = f"--output_dir {output_dir}".split()
cmd = ["torchrun"] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
#