from_pretrained should handle xpu case (#37382)

* from_pretrained should handle xpu case

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fmt

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi · 2025-04-10 19:23:17 +08:00 · committed by GitHub
parent 4f139f5a50
commit ae5ce22664

@@ -3989,6 +3989,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
             elif device_type == "cpu":
                 cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
                 torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
+            elif device_type == "xpu":
+                torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
+                torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
         except Exception as e:
             raise EnvironmentError(
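
For context, this first hunk extends the backend dispatch used when from_pretrained initializes a tensor-parallel process group: CPU picks oneCCL or gloo, and Intel XPU now gets its own branch. Below is a minimal standalone sketch of that dispatch; the helper name is hypothetical, the CUDA branch (not visible in the hunk) is assumed to use NCCL as elsewhere in transformers, the "ccl" backend assumes oneccl_bindings_for_pytorch has been imported so the backend is registered, and torch.xpu requires a PyTorch build with XPU support.

import os

import torch
import torch.distributed as dist


def init_process_group_for(device_type: str, rank: int, world_size: int) -> None:
    # Hypothetical helper mirroring the dispatch this commit touches.
    if device_type == "cuda":
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    elif device_type == "cpu":
        # oneCCL when CCL_WORKER_COUNT is set (e.g. under Intel MPI), else gloo.
        # "ccl" requires oneccl_bindings_for_pytorch to have been imported.
        backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
        dist.init_process_group(backend, rank=rank, world_size=world_size)
    elif device_type == "xpu":
        # Intel GPU: use the oneCCL backend, then pin this rank to its local
        # XPU so collectives and later tensor placement hit the right card.
        dist.init_process_group("ccl", rank=rank, world_size=world_size)
        torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
    else:
        raise ValueError(f"unsupported device type: {device_type}")
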
@@ -3997,7 +4000,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
             ) from e
         # Get device with index assuming equal number of devices per host
-        index = None if device_type == "cpu" else torch.cuda.current_device()
+        if device_type == "xpu":
+            index = torch.xpu.current_device()
+        else:
+            index = None if device_type == "cpu" else torch.cuda.current_device()
         tp_device = torch.device(device_type, index)
         if index is not None and index > 0:
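
The second hunk fixes the device-index lookup that follows: previously everything except CPU queried torch.cuda.current_device(), which raises when no CUDA runtime is available, so XPU now asks its own runtime. A small illustrative sketch under the same assumptions (hypothetical helper name, PyTorch with XPU support):

import torch


def resolve_tp_device(device_type: str) -> torch.device:
    # Hypothetical helper mirroring the index selection in this hunk.
    if device_type == "xpu":
        # Query the XPU runtime directly instead of going through torch.cuda.
        index = torch.xpu.current_device()
    else:
        # CPU has no meaningful device index; CUDA returns whichever device
        # was bound earlier (e.g. via torch.cuda.set_device(LOCAL_RANK)).
        index = None if device_type == "cpu" else torch.cuda.current_device()
    return torch.device(device_type, index)

The resulting tp_device carries an index on accelerators, so the trailing "index is not None and index > 0" check can treat the first device on each host differently from the rest.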