From 03a4c024dcc25b0668f11228ba6dd83b8d63655c Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Thu, 22 May 2025 12:34:05 +0200
Subject: [PATCH] Fix tp error when torch distributed is already initialized
 (#38294)

fix tp error
---
 src/transformers/integrations/tensor_parallel.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 8312891941b..a9f8940e72e 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -52,6 +52,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
 
     # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
     device_type = torch._C._get_accelerator().type
+    current_device = getattr(torch, device_type)
     if not torch.distributed.is_initialized():
         try:
             rank = int(os.environ["RANK"])
@@ -73,6 +74,9 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
                 "We tried to initialize torch.distributed for you, but it failed. Make "
                 "sure you init torch distributed in your script to use `tp_plan='auto'`."
             ) from e
+
+    if device_type != "cpu":
+        current_device.set_device(int(os.environ["LOCAL_RANK"]))
     index = current_device.current_device() if device_type != "cpu" else None
     tp_device = torch.device(device_type, index)
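
Below is a minimal usage sketch (not part of the patch) of the code path this fix targets: a launcher script that initializes torch.distributed itself before loading a model with `tp_plan="auto"`, so `initialize_tensor_parallelism` takes the already-initialized branch and now sets the per-rank accelerator device from LOCAL_RANK. The script name, checkpoint id, NCCL backend, and torchrun invocation are assumptions for illustration, not taken from the PR.

    # repro_tp_init.py -- hypothetical script name; launch with e.g.:
    #   torchrun --nproc_per_node=2 repro_tp_init.py
    # Assumes a CUDA machine with at least 2 GPUs and a TP-supported checkpoint.
    import torch
    from transformers import AutoModelForCausalLM

    # User code initializes the process group first, so
    # torch.distributed.is_initialized() is already True inside transformers
    # when from_pretrained(..., tp_plan="auto") runs.
    torch.distributed.init_process_group(backend="nccl")

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",  # placeholder checkpoint
        tp_plan="auto",             # triggers initialize_tensor_parallelism()
    )
    print(f"rank {torch.distributed.get_rank()} -> {model.device}")

    torch.distributed.destroy_process_group()

With the patch applied, each rank is pinned to its LOCAL_RANK device even though transformers did not create the process group itself, instead of failing on (or mis-assigning) `current_device` in that branch.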