
# Multi-GPU inference

Built-in Tensor Parallelism (TP) is now available with certain models when using PyTorch. Tensor parallelism shards a model onto multiple GPUs, which makes it possible to fit larger models in memory and parallelizes computations such as matrix multiplication across the GPUs.

To enable tensor parallel, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Initialize distributed
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

# Retrieve tensor parallel model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    tp_plan="auto",
)

# Prepare input tokens
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Distributed run
outputs = model(inputs)
```
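
The forward pass returns ordinary model outputs on every rank. As a minimal, hypothetical follow-up (not part of the original example), you could pick the greedy next token from the logits and decode it on a single rank to avoid duplicated prints:

```python
# Hypothetical continuation of the script above: greedy next-token prediction.
# `outputs`, `rank`, and `tokenizer` come from the example; nothing new is loaded.
next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)

# Every process runs the same script, so restrict printing to rank 0.
if rank == 0:
    print(tokenizer.decode(next_token_id[0]))
```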

You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU:

```bash
torchrun --nproc-per-node 4 demo.py
```
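
`torchrun` sets the usual distributed environment variables (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`) for each process. As a hypothetical sanity check, not part of the original script, you could verify near the top of the script that the number of launched processes does not exceed the number of visible GPUs:

```python
import os
import torch

# WORLD_SIZE is set by torchrun and equals --nproc-per-node on a single node.
world_size = int(os.environ["WORLD_SIZE"])
if world_size > torch.cuda.device_count():
    raise RuntimeError(
        f"Launched {world_size} processes but only "
        f"{torch.cuda.device_count()} GPUs are visible on this node."
    )
```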

PyTorch tensor parallel is currently supported for the following models:

* [Llama](https://huggingface.co/docs/transformers/model_doc/llama)

You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request.

## Expected speedups

You can benefit from considerable inference speedups, especially for inputs with large batch sizes or long sequences.

The expected speedup was benchmarked for a single forward pass on Llama with a sequence length of 512 and various batch sizes.
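
To get comparable numbers on your own hardware, you could time the forward pass yourself. The sketch below is hypothetical and reuses `model`, `inputs`, `device`, and `rank` from the script above:

```python
import time

# Warm up so CUDA kernels and collective communication are initialized.
with torch.no_grad():
    for _ in range(3):
        model(inputs)
torch.cuda.synchronize(device)

# Average the latency of several timed forward passes.
n_iters = 10
start = time.perf_counter()
with torch.no_grad():
    for _ in range(n_iters):
        model(inputs)
torch.cuda.synchronize(device)
latency_ms = (time.perf_counter() - start) / n_iters * 1000

if rank == 0:
    print(f"Average forward latency: {latency_ms:.1f} ms")
```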