Fix error of HPU TP (#37782)

* Fix error of HPU TP

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add the distributed init for HPU

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Fix `make style` errors

Signed-off-by: yuanwu <yuan.wu@intel.com>

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
Yuan Wu 2025-04-28 21:47:16 +08:00 committed by GitHub
parent da4ff2a5f5
commit 2933894985
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4108,6 +4108,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
elif device_type == "xpu":
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "hpu":
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
except Exception as e:
raise EnvironmentError(
@ -4118,6 +4121,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Get device with index assuming equal number of devices per host
if device_type == "xpu":
index = torch.xpu.current_device()
elif device_type == "hpu":
index = torch.hpu.current_device()
else:
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)