Fixing the docs corresponding to the breaking change in torch 2.6. (#36420)

This commit is contained in:
Nicolas Patry 2025-02-26 14:11:52 +01:00 committed by GitHub
parent 9a217fc327
commit b4965cecc5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 0 deletions

View File

@ -29,6 +29,7 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Initialize distributed # Initialize distributed
rank = int(os.environ["RANK"]) rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.distributed.init_process_group("nccl", device_id=device) torch.distributed.init_process_group("nccl", device_id=device)
# Retrieve tensor parallel model # Retrieve tensor parallel model

View File

@ -29,6 +29,7 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# 初始化分布式环境 # 初始化分布式环境
rank = int(os.environ["RANK"]) rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.distributed.init_process_group("nccl", device_id=device) torch.distributed.init_process_group("nccl", device_id=device)
# 获取支持张量并行的模型 # 获取支持张量并行的模型