From 2933894985b8b69fb65c5f0e7676f2be88f965b9 Mon Sep 17 00:00:00 2001
From: Yuan Wu
Date: Mon, 28 Apr 2025 21:47:16 +0800
Subject: [PATCH] Fix error of HPU TP (#37782)

* Fix error of HPU TP

Signed-off-by: yuanwu

* Add the distributed init for HPU

Signed-off-by: yuanwu

* Fix make style errors

Signed-off-by: yuanwu

---------

Signed-off-by: yuanwu
---
 src/transformers/modeling_utils.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0ee4182f98c..199adf825b7 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4108,6 +4108,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 elif device_type == "xpu":
                     torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
                     torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
+                elif device_type == "hpu":
+                    torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
+                    torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
 
             except Exception as e:
                 raise EnvironmentError(
@@ -4118,6 +4121,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             # Get device with index assuming equal number of devices per host
             if device_type == "xpu":
                 index = torch.xpu.current_device()
+            elif device_type == "hpu":
+                index = torch.hpu.current_device()
             else:
                 index = None if device_type == "cpu" else torch.cuda.current_device()
             tp_device = torch.device(device_type, index)
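
For context (not part of the commit): a minimal sketch of how the new "hpu" branch is reached. Tensor parallelism in transformers is requested via from_pretrained(..., tp_plan="auto"), and the patched code reads the RANK, WORLD_SIZE, and LOCAL_RANK variables that torchrun sets. The checkpoint id below is illustrative, and the habana_frameworks imports assume a Gaudi host with the Habana PyTorch bridge installed.

# Minimal sketch (assumption-heavy): exercising the HPU tensor-parallel init
# path added above. Launch with, e.g.:
#   torchrun --nproc-per-node=8 run_tp_hpu.py
# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK, which the patched code reads
# before calling torch.distributed.init_process_group("hccl", ...).
import torch
import habana_frameworks.torch.core  # noqa: F401  # assumed: registers the "hpu" device with torch
import habana_frameworks.torch.distributed.hccl  # noqa: F401  # assumed: registers the HCCL backend

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint; any TP-capable model works
    tp_plan="auto",  # requests tensor parallelism, which triggers the init path in this diff
)
print(f"rank {torch.distributed.get_rank()}: shard placed on {model.device}")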