Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
from_pretrained should handle xpu case (#37382)
* from_pretrained should handle xpu case

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fmt

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit ae5ce22664
parent 4f139f5a50
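For context, this change targets the tensor-parallel loading path in `from_pretrained`. Below is a minimal usage sketch, assuming an Intel XPU host launched with torchrun; the checkpoint name and launch command are illustrative placeholders, not part of the commit.

# Illustrative only: launch with torchrun so RANK, WORLD_SIZE and LOCAL_RANK are set, e.g.
#   torchrun --nproc-per-node=4 load_tp_xpu.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint

# With this patch, from_pretrained(tp_plan="auto") on an XPU host initializes
# torch.distributed with the "ccl" backend and binds each rank to its XPU device.
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)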
@@ -3989,6 +3989,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
                     elif device_type == "cpu":
                         cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
                         torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
+                    elif device_type == "xpu":
+                        torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
+                        torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
 
                 except Exception as e:
                     raise EnvironmentError(
@@ -3997,6 +4000,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
                     ) from e
 
             # Get device with index assuming equal number of devices per host
-            index = None if device_type == "cpu" else torch.cuda.current_device()
+            if device_type == "xpu":
+                index = torch.xpu.current_device()
+            else:
+                index = None if device_type == "cpu" else torch.cuda.current_device()
             tp_device = torch.device(device_type, index)
 
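As a standalone illustration of the patched logic, here is a minimal sketch mirroring the two hunks above: process-group backend selection per device type, followed by device-index resolution. The helper name is made up for illustration, and the cuda branch is assumed from the surrounding method rather than shown in this diff.

import os
import torch

def init_distributed_and_pick_device(device_type: str) -> torch.device:
    # Illustrative helper (not from the commit) mirroring the patched logic.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if not torch.distributed.is_initialized():
        if device_type == "cuda":
            # Assumed from the surrounding method, not part of this diff.
            torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
            torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        elif device_type == "cpu":
            # Use oneCCL if CCL workers are configured, otherwise fall back to gloo.
            cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
            torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
        elif device_type == "xpu":
            # New in this patch: XPU ranks use the ccl backend and bind via torch.xpu.
            torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
            torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))

    # Device index selection, assuming an equal number of devices per host.
    if device_type == "xpu":
        index = torch.xpu.current_device()
    else:
        index = None if device_type == "cpu" else torch.cuda.current_device()
    return torch.device(device_type, index)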