From ae5ce226644c8576c9047987e6b1d2e9bdeaed24 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 10 Apr 2025 19:23:17 +0800
Subject: [PATCH] from_pretrained should handle xpu case (#37382)

* from_pretrained should handle xpu case

Signed-off-by: Wang, Yi A

* fmt

Signed-off-by: Wang, Yi A

---------

Signed-off-by: Wang, Yi A
---
 src/transformers/modeling_utils.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b8a7831176f..76526c360c1 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3989,6 +3989,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                     elif device_type == "cpu":
                         cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
                         torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
+                    elif device_type == "xpu":
+                        torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
+                        torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
 
                 except Exception as e:
                     raise EnvironmentError(
@@ -3997,7 +4000,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 ) from e
 
             # Get device with index assuming equal number of devices per host
-            index = None if device_type == "cpu" else torch.cuda.current_device()
+            if device_type == "xpu":
+                index = torch.xpu.current_device()
+            else:
+                index = None if device_type == "cpu" else torch.cuda.current_device()
             tp_device = torch.device(device_type, index)
 
             if index is not None and index > 0:
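
Usage note (not part of the patch): a minimal sketch of how the new xpu branch is exercised. It assumes an Intel GPU host with the oneCCL bindings for PyTorch installed and a launcher such as `torchrun` that sets RANK, WORLD_SIZE, and LOCAL_RANK; the checkpoint name below is a placeholder.

    import os
    from transformers import AutoModelForCausalLM

    # With tp_plan="auto" and torch.distributed not yet initialized,
    # from_pretrained sets up the process group itself. On xpu it now
    # selects the "ccl" backend and binds each process to the device
    # given by LOCAL_RANK, mirroring the existing cuda/nccl path.
    model = AutoModelForCausalLM.from_pretrained(
        "your-org/your-model",  # placeholder checkpoint
        tp_plan="auto",
    )
    print(f"rank {os.environ['RANK']} -> device {model.device}")

Launched per host with, e.g., `torchrun --nproc-per-node=<num-xpus> run_tp.py`.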