mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Fix error of HPU TP (#37782)
* Fix error of HPU TP Signed-off-by: yuanwu <yuan.wu@intel.com> * Add the init distributed for hpu Signed-off-by: yuanwu <yuan.wu@intel.com> * Fix error of make style Signed-off-by: yuanwu <yuan.wu@intel.com> --------- Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent
da4ff2a5f5
commit
2933894985
@ -4108,6 +4108,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
elif device_type == "xpu":
|
elif device_type == "xpu":
|
||||||
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
|
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
|
||||||
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
|
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||||
|
elif device_type == "hpu":
|
||||||
|
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
|
||||||
|
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
@ -4118,6 +4121,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
# Get device with index assuming equal number of devices per host
|
# Get device with index assuming equal number of devices per host
|
||||||
if device_type == "xpu":
|
if device_type == "xpu":
|
||||||
index = torch.xpu.current_device()
|
index = torch.xpu.current_device()
|
||||||
|
elif device_type == "hpu":
|
||||||
|
index = torch.hpu.current_device()
|
||||||
else:
|
else:
|
||||||
index = None if device_type == "cpu" else torch.cuda.current_device()
|
index = None if device_type == "cpu" else torch.cuda.current_device()
|
||||||
tp_device = torch.device(device_type, index)
|
tp_device = torch.device(device_type, index)
|
||||||
|
Loading…
Reference in New Issue
Block a user