enable tp on CPU (#36299)

* enable tp on CPU

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* get rank from cpu

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable TP tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* em print

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix model id

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix conflict

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix index and add doc

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Author: jiqing-feng
Date: 2025-03-31 16:55:47 +08:00 (committed by GitHub)
Commit: 286393fbb1 (parent: 4705b04c74)
3 changed files with 56 additions and 113 deletions

Changed file 1 of 3:

@@ -44,11 +44,6 @@ import os
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# initialize distributed environment
-rank = int(os.environ["RANK"])
-device = torch.device(f"cuda:{rank}")
-torch.cuda.set_device(device)
-torch.distributed.init_process_group("nccl", device_id=device)
 
 # enable tensor parallelism
 model = AutoModelForCausalLM.from_pretrained(
@@ -59,7 +54,7 @@ model = AutoModelForCausalLM.from_pretrained(
 # prepare input tokens
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
 prompt = "Can I help"
-inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
 
 # distributed run
 outputs = model(inputs)
@@ -71,6 +66,13 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
 torchrun --nproc-per-node 4 demo.py
 ```
 
+For CPU, bind each rank to a different socket. For example, if you are using an Intel 4th Gen Xeon:
+```bash
+export OMP_NUM_THREADS=56
+numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
+```
+The CPU benchmark data will be released soon.
+
 You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
 
 For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
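
For reference, here is roughly what the docs snippet reads like after this change, reassembled from the context and added lines in the hunks above. The `from_pretrained(...)` arguments are not fully shown in the hunk, so the model id and the `tp_plan="auto"` argument below are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# enable tensor parallelism
# NOTE: the from_pretrained(...) arguments are truncated in the hunk above;
# the model id and tp_plan="auto" are assumed here for illustration.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    tp_plan="auto",
)

# prepare input tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# distributed run
outputs = model(inputs)
```

With the manual `init_process_group` block removed, the same script can be launched unchanged on GPU or CPU; the library now initializes the distributed environment itself when `tp_plan="auto"` is set.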

Changed file 2 of 3:

@@ -774,7 +774,8 @@ def _load_state_dict_into_meta_model(
     """
     tensor_device = "cpu"
     if device_map is not None and device_map.get("", None) is not None:
-        tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
+        if device_map[""] not in ("cpu", torch.device("cpu")):
+            tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
 
     if device_map is not None:
         device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
@@ -4110,24 +4111,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if tp_plan is not None:
             if not is_torch_greater_or_equal("2.5"):
                 raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
 
+            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+            device_type = torch._C._get_accelerator().type
             if not torch.distributed.is_initialized():
                 try:
                     rank = int(os.environ["RANK"])
                     world_size = int(os.environ["WORLD_SIZE"])
-                    torch.distributed.init_process_group(
-                        "nccl", rank=rank, world_size=world_size, init_method="env://"
-                    )
-                    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+                    if device_type == "cuda":
+                        torch.distributed.init_process_group(
+                            "nccl", rank=rank, world_size=world_size, init_method="env://"
+                        )
+                        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+                    elif device_type == "cpu":
+                        cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
+                        torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
                 except Exception as e:
                     raise EnvironmentError(
                         "We tried to initialize torch.distributed for you, but it failed, make"
                         "sure you init torch distributed in your script to use `tp_plan='auto'`"
                     ) from e
 
-            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
-            device_type = torch._C._get_accelerator().type
-            tp_device = torch.device(device_type, torch.cuda.current_device())
-            if tp_device.index > 0:
+            # Get device with index assuming equal number of devices per host
+            index = None if device_type == "cpu" else torch.cuda.current_device()
+            tp_device = torch.device(device_type, index)
+
+            if index is not None and index > 0:
                 import sys
 
                 sys.stdout = open(os.devnull, "w")
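
The backend-selection logic added in this hunk can be read in isolation as roughly the following sketch. The function name `init_distributed_for_tp` is illustrative only and is not part of the transformers API; the body mirrors the added lines above.

```python
import os

import torch


def init_distributed_for_tp() -> torch.device:
    """Illustrative sketch of the initialization path added above (not the real transformers API).

    Expects RANK / WORLD_SIZE (and LOCAL_RANK on CUDA) to be set, e.g. by torchrun.
    """
    # Detect the accelerator on the machine; returns "cpu" when none is available.
    device_type = torch._C._get_accelerator().type

    if not torch.distributed.is_initialized():
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        if device_type == "cuda":
            torch.distributed.init_process_group(
                "nccl", rank=rank, world_size=world_size, init_method="env://"
            )
            torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        elif device_type == "cpu":
            # Prefer oneCCL when CCL workers are configured, otherwise fall back to gloo.
            cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
            torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)

    # CPU has no per-rank device index; CUDA uses the current device of this rank.
    index = None if device_type == "cpu" else torch.cuda.current_device()
    return torch.device(device_type, index)
```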

Changed file 3 of 3:

@@ -12,18 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import subprocess
 import tempfile
 import textwrap
 
-# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
 from transformers import is_torch_available
-from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import LlamaModel
 from transformers.testing_utils import (
     TestCasePlus,
-    execute_subprocess_async,
     get_torch_dist_unique_port,
     require_torch_multi_gpu,
 )
@@ -33,7 +28,10 @@ if is_torch_available():
     import torch
 
 
+# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
 class TestTensorParallel(TestCasePlus):
+    nproc_per_node = 2
+
     def torchrun(self, script: str):
         """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
         with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
@@ -41,7 +39,7 @@ class TestTensorParallel(TestCasePlus):
             tmp.flush()
             tmp.seek(0)
             cmd = (
-                f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
+                f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
             ).split()
 
             # Note that the subprocess will be waited for here, and raise an error if not successful
@@ -50,44 +48,39 @@ class TestTensorParallel(TestCasePlus):
             except subprocess.CalledProcessError as e:
                 raise Exception(f"The following error was captured: {e.stderr}")
 
-    @require_torch_multi_gpu
-    def test_tp(self):
-        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
-            --master_port={get_torch_dist_unique_port()}
-            {self.test_file_dir}/test_tp.py
-        """.split()
-        output_dir = self.get_auto_remove_tmp_dir()
-        args = f"--output_dir {output_dir} --report_to none".split()
-        cmd = ["torchrun"] + distributed_args + args
-        print(cmd)
-        execute_subprocess_async(cmd, env=self.get_env())
-        # successful return here == success - any errors would have caused an error in the sub-call
-
-    @require_torch_multi_gpu
-    def test_loading_memory_consumption(self):
+    def test_model_forward(self):
         script_to_run = textwrap.dedent(
             """
             import torch
             import os
-            from transformers import AutoModelForCausalLM
+            from transformers import AutoModelForCausalLM, AutoTokenizer
 
-            model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+            model_id = "JackFram/llama-68m"
             rank = int(os.environ["RANK"])
             world_size = int(os.environ["WORLD_SIZE"])
 
-            device = torch.device(f"cuda:{rank}")
-            torch.distributed.init_process_group("nccl", device_id=device)
-
-            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
+            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
             torch.distributed.barrier()
 
-            # The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
-            expected_model_memory_per_device = (16 / world_size) + 1
-            overhead_factor = 1.2
+            has_dtensor = 0
+            for name, parameter in model.named_parameters():
+                if isinstance(parameter.data, torch.distributed.tensor.DTensor):
+                    has_dtensor = 1
+                    break
 
-            # Check that we do not use more than the expected sharded size during initialization
-            if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
-                raise ValueError("Loading the model used more than the expected fraction of model size per device")
+            assert has_dtensor == 1, "TP model must have DTensor"
+
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            prompt = "Can I help"
+            inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+
+            outputs = model(inputs)
+            next_token_logits = outputs[0][:, -1, :]
+            next_token = torch.argmax(next_token_logits, dim=-1)
+            response = tokenizer.decode(next_token)
+            assert response == "with"
 
             torch.distributed.barrier()
             torch.distributed.destroy_process_group()
@@ -96,69 +89,6 @@ class TestTensorParallel(TestCasePlus):
 
         self.torchrun(script_to_run)
 
-if __name__ == "__main__":
-    # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
-    # CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py
-    # or
-    # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
-
-    if not is_torch_available():
-        exit(0)
-
-    # Test settings
-    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-    bs = 1
-    seqlen = 4096
-    # Get distributed settings
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
-
-    # Initialize distributed
-    device = torch.device(f"cuda:{rank}")
-    torch.distributed.init_process_group("nccl", device_id=device)
-    device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
-
-    # Get model config
-    config = LlamaConfig.from_pretrained(model_id)
-    config.hidden_size = 2048
-    config.attention_bias = False
-    # Instantiate model
-    with device:
-        model = LlamaModel(config).to(dtype=torch.float16)
-
-    model.eval()
-
-    # Tensor Parallel
-    if world_size > 1:
-        model.tensor_parallel(device_mesh)
-
-    # Run model
-    inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
-
-    # Test cuda graphing explicitly
-    with torch.cuda.device(device):
-        print("Cuda graphing")
-        with torch.no_grad():
-            inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
-
-            # CUDA Graph setup
-            s = torch.cuda.Stream(device=device)
-            s.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(s):
-                for i in range(3):
-                    out = model(inputs)
-            torch.cuda.current_stream().wait_stream(s)
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g):
-                out = model(inputs)
-
-            for _ in range(2):
-                g.replay()
-            s.synchronize()
-
-            assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
-
-    # Test compile
-    with torch.no_grad():
-        out = model(inputs)
-        model.forward = torch.compile(model.forward, mode="reduce-overhead")
-        out = model(inputs)
-        out = model(inputs)
+
+
+@require_torch_multi_gpu
+class TestTensorParallelCuda(TestTensorParallel):
+    nproc_per_node = torch.cuda.device_count()
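
Outside the unittest harness, the core of the rewritten test can be sketched as a standalone script launched under torchrun. This is a hedged sketch based on the added test body above: the file name in the comment is hypothetical, and on a CPU-only host it relies on the gloo/ccl fallback introduced in the modeling change.

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical file name; launch with e.g.:
#   torchrun --nproc_per_node 2 check_tp.py
# On a CPU-only host, tp_plan="auto" now initializes the process group with ccl/gloo,
# so the same script runs without CUDA.
model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", torch_dtype="auto", tp_plan="auto")
torch.distributed.barrier()

# A tensor-parallel model should expose at least one parameter sharded as a DTensor.
has_dtensor = any(
    isinstance(p.data, torch.distributed.tensor.DTensor) for p in model.parameters()
)
assert has_dtensor, "TP model must have at least one DTensor parameter"

torch.distributed.barrier()
torch.distributed.destroy_process_group()
```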