Fix tp error when torch distributed is already initialized (#38294)

Fix the `tp_plan='auto'` path when `torch.distributed` has already been initialized by the caller: bind `current_device = getattr(torch, device_type)` before the `is_initialized()` check (previously it was only assigned on the auto-initialization path) and select the local accelerator device via `set_device(LOCAL_RANK)` outside that branch, so a pre-initialized process group no longer skips device selection.
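For context, a minimal sketch of the setup this fixes (the checkpoint name and the torchrun launch are illustrative assumptions, not part of this commit): a script that creates the process group itself before loading with `tp_plan='auto'` previously reached the device-selection code without `current_device` ever being assigned.

# repro.py, launched with e.g. `torchrun --nproc-per-node 2 repro.py`
# so that RANK / LOCAL_RANK / WORLD_SIZE are set in the environment.
import os
import torch.distributed as dist
from transformers import AutoModelForCausalLM

# The caller initializes torch.distributed itself...
dist.init_process_group(
    "nccl", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"])
)

# ...so initialize_tensor_parallelism() skips its auto-init branch, which before
# this fix was also the only place `current_device` was bound and the local
# device was selected.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", tp_plan="auto")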
Author: Marc Sun, 2025-05-22 12:34:05 +02:00 (committed by GitHub)
parent dcaf47dde5
commit 03a4c024dc


@@ -52,6 +52,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
     # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
     device_type = torch._C._get_accelerator().type
+    current_device = getattr(torch, device_type)
     if not torch.distributed.is_initialized():
         try:
             rank = int(os.environ["RANK"])
@@ -73,6 +74,9 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
                 "We tried to initialize torch.distributed for you, but it failed. Make "
                 "sure you init torch distributed in your script to use `tp_plan='auto'`."
             ) from e
+    if device_type != "cpu":
+        current_device.set_device(int(os.environ["LOCAL_RANK"]))
     index = current_device.current_device() if device_type != "cpu" else None
     tp_device = torch.device(device_type, index)
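A side note on the `getattr(torch, device_type)` idiom added in the first hunk, as a small illustrative sketch (the variable names below are mine, not from the patch): the accelerator type string reported by `torch._C._get_accelerator()` resolves to the matching torch submodule, which is what exposes `set_device()` and `current_device()`.

import os
import torch

device_type = torch._C._get_accelerator().type    # e.g. "cuda", "xpu", or "cpu"
if device_type != "cpu":
    accelerator = getattr(torch, device_type)      # e.g. torch.cuda or torch.xpu
    accelerator.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    print(device_type, accelerator.current_device())
else:
    print("No accelerator detected; tensor parallelism will target CPU.")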