Fix tp error when torch distributed is already initialized (#38294)
fix tp error
commit 03a4c024dc (parent dcaf47dde5)
@@ -52,6 +52,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
     # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
     device_type = torch._C._get_accelerator().type
+    current_device = getattr(torch, device_type)
     if not torch.distributed.is_initialized():
         try:
             rank = int(os.environ["RANK"])
@@ -73,6 +74,9 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
                 "We tried to initialize torch.distributed for you, but it failed. Make "
                 "sure you init torch distributed in your script to use `tp_plan='auto'`."
             ) from e

+    if device_type != "cpu":
+        current_device.set_device(int(os.environ["LOCAL_RANK"]))
+
     index = current_device.current_device() if device_type != "cpu" else None
     tp_device = torch.device(device_type, index)
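For context, here is a minimal sketch of the scenario this commit addresses (the checkpoint name, backend, and launch command are illustrative assumptions, not part of the commit): a script that initializes torch.distributed itself and then loads a model with tp_plan="auto". Before this change, `current_device` was only assigned inside the branch that creates the process group, so the already-initialized path failed; the diff hoists that assignment (and the `set_device` call) so both paths work.

# Illustrative sketch, not part of the commit. Assumes a multi-GPU launch such as
#   torchrun --nproc-per-node 4 script.py
# and an arbitrary example checkpoint.
import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM

# The user creates the process group themselves, before transformers gets a chance to.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# With this fix, tp_plan="auto" reuses the existing process group; previously this
# path could fail because `current_device` was never assigned when
# torch.distributed.is_initialized() was already True.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    tp_plan="auto",
    torch_dtype=torch.bfloat16,
)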