From 286393fbb11e3c95439ed94b01781cba632e2dfb Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 31 Mar 2025 16:55:47 +0800 Subject: [PATCH] enable tp on CPU (#36299) * enable tp on CPU Signed-off-by: jiqing-feng * get rank from cpu Signed-off-by: jiqing-feng * update Signed-off-by: jiqing-feng * enable TP tests Signed-off-by: jiqing-feng * fix comment Signed-off-by: jiqing-feng * em print Signed-off-by: jiqing-feng * fix model id Signed-off-by: jiqing-feng * fix conflict Signed-off-by: jiqing-feng * fix index and add doc Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng --- docs/source/en/perf_infer_gpu_multi.md | 14 +- src/transformers/modeling_utils.py | 29 ++-- tests/tensor_parallel/test_tensor_parallel.py | 126 ++++-------------- 3 files changed, 56 insertions(+), 113 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md index 067633891ce..3aa1f09be55 100644 --- a/docs/source/en/perf_infer_gpu_multi.md +++ b/docs/source/en/perf_infer_gpu_multi.md @@ -44,11 +44,6 @@ import os import torch from transformers import AutoModelForCausalLM, AutoTokenizer -# initialize distributed environment -rank = int(os.environ["RANK"]) -device = torch.device(f"cuda:{rank}") -torch.cuda.set_device(device) -torch.distributed.init_process_group("nccl", device_id=device) # enable tensor parallelism model = AutoModelForCausalLM.from_pretrained( @@ -59,7 +54,7 @@ model = AutoModelForCausalLM.from_pretrained( # prepare input tokens tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") prompt = "Can I help" -inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device) +inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) # distributed run outputs = model(inputs) @@ -71,6 +66,13 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/ torchrun --nproc-per-node 4 demo.py ``` +For CPU, please binding different socket on each rank. For example, if you are using Intel 4th Gen Xeon: +```bash +export OMP_NUM_THREADS=56 +numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait +``` +The CPU benchmark data will be released soon. + You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences. For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8e63c7a6af8..5b27f0c3dbe 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -774,7 +774,8 @@ def _load_state_dict_into_meta_model( """ tensor_device = "cpu" if device_map is not None and device_map.get("", None) is not None: - tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""] + if device_map[""] not in ("cpu", torch.device("cpu")): + tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""] if device_map is not None: device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) @@ -4110,24 +4111,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if tp_plan is not None: if not is_torch_greater_or_equal("2.5"): raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.") + + # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. + device_type = torch._C._get_accelerator().type + if not torch.distributed.is_initialized(): try: rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) - torch.distributed.init_process_group( - "nccl", rank=rank, world_size=world_size, init_method="env://" - ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + if device_type == "cuda": + torch.distributed.init_process_group( + "nccl", rank=rank, world_size=world_size, init_method="env://" + ) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + elif device_type == "cpu": + cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo" + torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size) + except Exception as e: raise EnvironmentError( "We tried to initialize torch.distributed for you, but it failed, make" "sure you init torch distributed in your script to use `tp_plan='auto'`" ) from e - # Detect the accelerator on the machine. If no accelerator is available, it returns CPU. - device_type = torch._C._get_accelerator().type - tp_device = torch.device(device_type, torch.cuda.current_device()) - if tp_device.index > 0: + # Get device with index assuming equal number of devices per host + index = None if device_type == "cpu" else torch.cuda.current_device() + tp_device = torch.device(device_type, index) + + if index is not None and index > 0: import sys sys.stdout = open(os.devnull, "w") diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index b4e58fd7a0b..7276869d764 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -12,18 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import subprocess import tempfile import textwrap -# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py from transformers import is_torch_available -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaModel from transformers.testing_utils import ( TestCasePlus, - execute_subprocess_async, get_torch_dist_unique_port, require_torch_multi_gpu, ) @@ -33,7 +28,10 @@ if is_torch_available(): import torch +# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py class TestTensorParallel(TestCasePlus): + nproc_per_node = 2 + def torchrun(self, script: str): """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary.""" with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp: @@ -41,7 +39,7 @@ class TestTensorParallel(TestCasePlus): tmp.flush() tmp.seek(0) cmd = ( - f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}" + f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}" ).split() # Note that the subprocess will be waited for here, and raise an error if not successful @@ -50,44 +48,39 @@ class TestTensorParallel(TestCasePlus): except subprocess.CalledProcessError as e: raise Exception(f"The following error was captured: {e.stderr}") - @require_torch_multi_gpu - def test_tp(self): - distributed_args = f"""--nproc_per_node={torch.cuda.device_count()} - --master_port={get_torch_dist_unique_port()} - {self.test_file_dir}/test_tp.py - """.split() - output_dir = self.get_auto_remove_tmp_dir() - args = f"--output_dir {output_dir} --report_to none".split() - cmd = ["torchrun"] + distributed_args + args - print(cmd) - execute_subprocess_async(cmd, env=self.get_env()) - # successful return here == success - any errors would have caused an error in the sub-call - - @require_torch_multi_gpu - def test_loading_memory_consumption(self): + def test_model_forward(self): script_to_run = textwrap.dedent( """ import torch import os - from transformers import AutoModelForCausalLM + from transformers import AutoModelForCausalLM, AutoTokenizer - model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + model_id = "JackFram/llama-68m" rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) - device = torch.device(f"cuda:{rank}") - torch.distributed.init_process_group("nccl", device_id=device) - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto") + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto") torch.distributed.barrier() - # The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings) - expected_model_memory_per_device = (16 / world_size) + 1 - overhead_factor = 1.2 + has_dtensor = 0 + for name, parameter in model.named_parameters(): + if isinstance(parameter.data, torch.distributed.tensor.DTensor): + has_dtensor = 1 + break - # Check that we do not use more than the expected sharded size during initialization - if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor: - raise ValueError("Loading the model used more than the expected fraction of model size per device") + assert has_dtensor == 1, "TP model must has DTensor" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + prompt = "Can I help" + + inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device) + outputs = model(inputs) + + next_token_logits = outputs[0][:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1) + response = tokenizer.decode(next_token) + assert response == "with" torch.distributed.barrier() torch.distributed.destroy_process_group() @@ -96,69 +89,6 @@ class TestTensorParallel(TestCasePlus): self.torchrun(script_to_run) -if __name__ == "__main__": - # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs: - # CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py - # or - # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py - - if not is_torch_available(): - exit(0) - - # Test settings - model_id = "meta-llama/Meta-Llama-3-8B-Instruct" - bs = 1 - seqlen = 4096 - # Get distributed settings - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - # Initialize distributed - device = torch.device(f"cuda:{rank}") - torch.distributed.init_process_group("nccl", device_id=device) - device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,)) - - # Get model config - config = LlamaConfig.from_pretrained(model_id) - config.hidden_size = 2048 - config.attention_bias = False - # Instantiate model - with device: - model = LlamaModel(config).to(dtype=torch.float16) - - model.eval() - # Tensor Parallel - if world_size > 1: - model.tensor_parallel(device_mesh) - # Run model - - inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device) - - # Test cuda graphing explicitly - with torch.cuda.device(device): - print("Cuda graphing") - with torch.no_grad(): - inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device) - # CUDA Graph setup - s = torch.cuda.Stream(device=device) - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for i in range(3): - out = model(inputs) - torch.cuda.current_stream().wait_stream(s) - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - out = model(inputs) - - for _ in range(2): - g.replay() - s.synchronize() - - assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size]) - - # Test compile - with torch.no_grad(): - out = model(inputs) - model.forward = torch.compile(model.forward, mode="reduce-overhead") - out = model(inputs) - out = model(inputs) +@require_torch_multi_gpu +class TestTensorParallelCuda(TestTensorParallel): + nproc_per_node = torch.cuda.device_count()