Simplify Tensor Parallel implementation with PyTorch TP (#34184)

* Simplify Tensor Parallel implementation with PyTorch TP

* Move tp_plan to config

* Lint

* Format and warning

* Disable copy-from check

* Conditionally get attr from config

* make fix-copies

* Move base_model_tp_plan to PretrainedConfig

* Move TP into from_pretrained

* Add device context for load

* Do not serialize

* Move _tp_plan setting to post_init

* Add has_tp_plan

* Add test_tp

* Add 'Multi-gpu inference' doc

* Add backward support for device type identification

* Auto-detect accelerator

* supports_tp_plan

* copyright year

* Fix copy
Ke Wen, 2024-11-18 10:51:49 -08:00 (committed by GitHub)
parent 7df93d6ffb
commit 20142ab542
18 changed files with 357 additions and 92 deletions

docs/source/en/_toctree.yml

@@ -218,6 +218,8 @@
title: CPU inference
- local: perf_infer_gpu_one
title: GPU inference
- local: perf_infer_gpu_multi
title: Multi-GPU inference
title: Optimizing inference
- local: big_models
title: Instantiate a big model

docs/source/en/perf_infer_gpu_multi.md

@@ -0,0 +1,68 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Multi-GPU inference
Built-in Tensor Parallelism (TP) is now available with certain models using PyTorch. Tensor parallelism shards a model across multiple GPUs, enabling larger model sizes and parallelizing computations such as matrix multiplication.

To enable tensor parallelism, pass the argument `tp_plan="auto"` to [`~AutoModelForCausalLM.from_pretrained`]:
```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Initialize distributed
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)
# Retrieve tensor parallel model
model = AutoModelForCausalLM.from_pretrained(
model_id,
tp_plan="auto",
)
# Prepare input tokens
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# Distributed run
outputs = model(inputs)
```
You can use `torchrun` to launch the above script with multiple processes, each mapping to a GPU:
```
torchrun --nproc-per-node 4 demo.py
```
PyTorch tensor parallelism is currently supported for the following models:
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
You can request to add tensor parallel support for another model by opening a GitHub Issue or Pull Request.
### Expected speedups
You can benefit from considerable inference speedups, especially for inputs with large batch sizes or long sequences.
For a single forward pass on [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) with a sequence length of 512 and various batch sizes, the expected speedup is as follows:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct, seqlen = 512, python, w_ compile.png">
</div>

docs/source/en/performance.md

@@ -53,7 +53,7 @@ sections we go through the steps to run inference on CPU and single/multi-GPU se
* [Inference on a single CPU](perf_infer_cpu)
* [Inference on a single GPU](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_one)
* [Multi-GPU inference](perf_infer_gpu_multi)
* [XLA Integration for TensorFlow Models](tf_xla)

src/transformers/configuration_utils.py

@@ -71,6 +71,8 @@ class PretrainedConfig(PushToHubMixin):
outputs of the model during inference.
- **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
naming of attributes.
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-module FQNs of a base model to the tensor
parallel plan applied to that sub-module when `model.tensor_parallel` is called.
Common attributes (present in all subclasses):
@@ -194,6 +196,7 @@
sub_configs: Dict[str, "PretrainedConfig"] = {}
is_composition: bool = False
attribute_map: Dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None
_auto_class: Optional[str] = None
def __setattr__(self, key, value):
@@ -848,6 +851,9 @@ class PretrainedConfig(PushToHubMixin):
if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]
return serializable_config_dict
@@ -867,6 +873,9 @@
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]
# Transformers version when serializing the model
output["transformers_version"] = __version__

src/transformers/modeling_utils.py

@@ -55,6 +55,7 @@ from .pytorch_utils import ( # noqa: F401
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -1326,6 +1327,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False
# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
@@ -1370,6 +1377,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan
def dequantize(self):
"""
@@ -3399,6 +3409,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Cache path to the GGUF file
gguf_path = None
tp_plan = kwargs.pop("tp_plan", None)
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
if is_fsdp_enabled():
low_cpu_mem_usage = True
@@ -4000,6 +4015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
tp_device = None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
@@ -4012,6 +4028,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
elif tp_plan is not None:
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
# Get device with index assuming equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
init_contexts.append(tp_device)
if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
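
The `tp_device` appended to `init_contexts` works because, since PyTorch 2.0, a `torch.device` can be used as a context manager that sets the default device for tensor and parameter creation. A tiny device-neutral sketch of that mechanism:

```python
import torch

# A torch.device used as a context manager (PyTorch >= 2.0): tensors and
# module parameters created inside it are allocated on that device,
# no explicit .to() calls needed
with torch.device("meta"):
    layer = torch.nn.Linear(8, 8)
assert layer.weight.device.type == "meta"
```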
@@ -4145,6 +4171,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
load_contexts = []
# Make sure we load onto targeted device
if tp_device is not None:
load_contexts.append(tp_device)
with ContextManagers(load_contexts):
(
model,
missing_keys,
@@ -4254,6 +4286,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
}
return model, loading_info
if tp_plan is not None:
assert tp_device is not None, "tp_device not set!"
if not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# Assume we shard the model across the full world size
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
# Apply Tensor Parallelism
model.tensor_parallel(device_mesh)
return model
@classmethod
@@ -4943,6 +4985,54 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return self.hf_quantizer.is_trainable
@property
def supports_tp_plan(self):
"""
Returns whether the model has a tensor parallelism plan.
"""
if self._tp_plan is not None:
return True
# Check if base model has a TP plan
if getattr(self.base_model, "_tp_plan", None) is not None:
return True
return False
def tensor_parallel(self, device_mesh):
"""
Tensor parallelize the model across the given device mesh.
Args:
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
# In model configs, we use a neutral type (string) to specify
# parallel styles, here we translate them into torch TP types.
# Using tree_map because `tp_plan` is a dict.
tp_plan = torch.utils._pytree.tree_map(
translate_to_torch_parallel_style,
tp_plan,
)
# Apply TP to current module.
torch.distributed.tensor.parallel.parallelize_module(
mod,
device_mesh=device_mesh,
parallelize_plan=tp_plan,
)
# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
self.apply(tplize)
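
For reference, a minimal sketch of driving this API by hand, which is essentially what `from_pretrained(..., tp_plan="auto")` does internally. It assumes `torch.distributed` is already initialized (e.g. under `torchrun`) and that CUDA GPUs are available:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
if model.supports_tp_plan:
    # One-dimensional mesh: shard the model over the full world size
    world_size = torch.distributed.get_world_size()
    device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
    model.tensor_parallel(device_mesh)
```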
@property
def loss_function(self):
if getattr(self.config, "loss_type", None) is not None:

src/transformers/models/cohere/modeling_cohere.py

@@ -1068,7 +1068,7 @@ class CohereModel(CoherePreTrainedModel):
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

src/transformers/models/gemma/modeling_gemma.py

@@ -720,7 +720,10 @@ class GemmaModel(GemmaPreTrainedModel):
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warning("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -982,6 +985,7 @@ class GemmaModel(GemmaPreTrainedModel):
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)

src/transformers/models/gemma2/modeling_gemma2.py

@@ -740,7 +740,10 @@ class Gemma2Model(Gemma2PreTrainedModel):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warning("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -961,6 +964,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)

src/transformers/models/glm/modeling_glm.py

@@ -708,6 +708,8 @@ class GlmModel(GlmPreTrainedModel):
dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warning("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -967,6 +969,7 @@ class GlmModel(GlmPreTrainedModel):
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config: GlmConfig):
super().__init__(config)

src/transformers/models/llama/configuration_llama.py

@@ -141,6 +141,16 @@ class LlamaConfig(PretrainedConfig):
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `LlamaModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,
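
As a quick illustration of how this default plan reaches the model (via the `post_init` change in `modeling_utils.py` above), a hedged sketch with a deliberately tiny config:

```python
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

# Tiny config purely for illustration; real checkpoints supply their own sizes
config = LlamaConfig(num_hidden_layers=2, hidden_size=64, intermediate_size=128, num_attention_heads=4)
model = LlamaModel(config)  # a base model, so post_init attaches config.base_model_tp_plan
assert model._tp_plan == config.base_model_tp_plan
assert model.supports_tp_plan
```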

src/transformers/models/llama/modeling_llama.py

@@ -21,7 +21,6 @@ import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -240,25 +239,7 @@ class LlamaMLP(nn.Module):
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
@@ -320,31 +301,14 @@
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
@@ -386,11 +350,6 @@
attn_output = attn_output.reshape(bsz, q_len, -1)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
@@ -564,9 +523,10 @@ class LlamaSdpaAttention(LlamaAttention):
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
@@ -850,7 +810,10 @@
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
if getattr(config, "pretraining_tp", 1) != 1:
logger.warning("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
# Initialize weights and apply final processing
self.post_init()
@@ -1113,6 +1076,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
@@ -1211,11 +1175,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

src/transformers/models/nemotron/modeling_nemotron.py

@@ -980,7 +980,7 @@ class NemotronModel(NemotronPreTrainedModel):
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

src/transformers/models/olmo/modeling_olmo.py

@@ -1020,7 +1020,7 @@ class OlmoModel(OlmoPreTrainedModel):
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

src/transformers/models/olmo_1124/modeling_olmo_1124.py

@@ -971,6 +971,7 @@ class Olmo1124Model(Olmo1124PreTrainedModel):
return causal_mask
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO_1124,Llama->Olmo1124
class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

src/transformers/models/olmoe/modeling_olmoe.py

@@ -888,7 +888,7 @@ OLMOE_INPUTS_DOCSTRING = r"""
"The bare Olmoe Model outputting raw hidden-states without any specific head on top.",
OLMOE_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Olmoe
class OlmoeModel(OlmoePreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoeDecoderLayer`]

src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

@@ -775,7 +775,7 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma
class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

src/transformers/pytorch_utils.py

@@ -20,6 +20,11 @@ import torch
from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
)
from .utils import is_torch_xla_available, logging
@@ -329,3 +334,22 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
else:
# Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
return torch.isin(elements, test_elements)
def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we translate them into torch.distributed tensor-parallel
types.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
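
To make the mapping concrete, a small sketch of how a string plan is translated before being handed to `parallelize_module`, mirroring the `tree_map` call in `tensor_parallel` above (assumes this diff is applied):

```python
import torch
from transformers.pytorch_utils import translate_to_torch_parallel_style

plan = {"layers.*.mlp.up_proj": "colwise", "lm_head": "colwise_rep"}
# tree_map applies the translation to every value in the dict
torch_plan = torch.utils._pytree.tree_map(translate_to_torch_parallel_style, plan)
# -> {"layers.*.mlp.up_proj": ColwiseParallel(),
#     "lm_head": ColwiseParallel(output_layouts=Replicate())}
```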

tests/tp/test_tp.py (new file)

@@ -0,0 +1,91 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
)
if is_torch_available():
import torch
class TestTensorParallel(TestCasePlus):
@require_torch_multi_gpu
def test_tp(self):
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_tp.py
""".split()
output_dir = self.get_auto_remove_tmp_dir()
args = f"--output_dir {output_dir} --report_to none".split()
cmd = ["torchrun"] + distributed_args + args
print(cmd)
execute_subprocess_async(cmd, env=self.get_env())
# A successful return here means success; any failure would have raised an error in the subprocess call
if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
# CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py
# or
# PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
if not is_torch_available():
exit(0)
# Test settings
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
bs = 4
seqlen = 64
# Get distributed settings
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
# Initialize distributed
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)
device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
# Get model config
config = LlamaConfig.from_pretrained(model_id)
# Shrink model size
config.num_hidden_layers //= 8
config.vocab_size //= 8
# Instantiate model
with device:
model = LlamaModel(config)
model.eval()
# Tensor Parallel
if world_size > 1:
model.tensor_parallel(device_mesh)
# Run model
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
with torch.no_grad():
out = model(inputs)
assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])