diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md index 7f5d52363e4..37c193359ac 100644 --- a/docs/source/en/perf_infer_gpu_multi.md +++ b/docs/source/en/perf_infer_gpu_multi.md @@ -29,6 +29,7 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # Initialize distributed rank = int(os.environ["RANK"]) device = torch.device(f"cuda:{rank}") +torch.cuda.set_device(device) torch.distributed.init_process_group("nccl", device_id=device) # Retrieve tensor parallel model diff --git a/docs/source/zh/perf_infer_gpu_multi.md b/docs/source/zh/perf_infer_gpu_multi.md index 35e5bac465a..91a54e1d3f5 100644 --- a/docs/source/zh/perf_infer_gpu_multi.md +++ b/docs/source/zh/perf_infer_gpu_multi.md @@ -29,6 +29,7 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # 初始化分布式环境 rank = int(os.environ["RANK"]) device = torch.device(f"cuda:{rank}") +torch.cuda.set_device(device) torch.distributed.init_process_group("nccl", device_id=device) # 获取支持张量并行的模型